diff --git a/.cargo/config.toml b/.cargo/config.toml index 838b8d683..744ad63a3 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,5 +1,6 @@ [env] STRINGS_JSON = { value = "out/rslib/i18n/strings.json", relative = true } +DESCRIPTORS_BIN = { value = "out/rslib/proto/descriptors.bin", relative = true } # build script will append .exe if necessary PROTOC = { value = "out/extracted/protoc/bin/protoc", relative = true } PYO3_NO_PYTHON = "1" diff --git a/Cargo.lock b/Cargo.lock index 2bc1a892b..a7bc7596d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,6 +83,8 @@ version = "0.0.0" dependencies = [ "ammonia", "anki_i18n", + "anki_proto", + "anyhow", "async-compression", "async-stream", "async-trait", @@ -106,6 +108,7 @@ dependencies = [ "htmlescape", "hyper", "id_tree", + "inflections", "itertools", "lazy_static", "nom", @@ -117,6 +120,8 @@ dependencies = [ "pin-project", "prost", "prost-build", + "prost-reflect", + "prost-types", "pulldown-cmark 0.9.2", "rand 0.8.5", "regex", @@ -180,6 +185,17 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "anki_proto" +version = "0.0.0" +dependencies = [ + "anyhow", + "inflections", + "prost-build", + "prost-reflect", + "prost-types", +] + [[package]] name = "anstream" version = "0.2.6" @@ -2956,9 +2972,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e48e50df39172a3e7eb17e14642445da64996989bc212b583015435d39a58537" +checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" dependencies = [ "bytes", "prost-derive", @@ -2966,9 +2982,9 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c828f93f5ca4826f97fedcbd3f9a536c16b12cff3dbbb4a007f932bbad95b12" +checksum = "119533552c9a7ffacc21e099c24a0ac8bb19c2a2a3f363de84cd9b844feab270" dependencies = [ "bytes", "heck", @@ -2988,9 +3004,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ea9b0f8cbe5e15a8a042d030bd96668db28ecb567ec37d691971ff5731d2b1b" +checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" dependencies = [ "anyhow", "itertools", @@ -3000,10 +3016,21 @@ dependencies = [ ] [[package]] -name = "prost-types" -version = "0.11.8" +name = "prost-reflect" +version = "0.11.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "379119666929a1afd7a043aa6cf96fa67a6dce9af60c88095a4686dbce4c9c88" +checksum = "000e1e05ebf7b26e1eba298e66fe4eee6eb19c567d0ffb35e0dd34231cdac4c8" +dependencies = [ + "once_cell", + "prost", + "prost-types", +] + +[[package]] +name = "prost-types" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13" dependencies = [ "prost", ] diff --git a/Cargo.toml b/Cargo.toml index 806f47ff9..e0babae9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "rslib/i18n", "rslib/i18n_helpers", "rslib/linkchecker", + "rslib/proto", "pylib/rsbridge", "build/configure", "build/ninja_gen", diff --git a/build/configure/src/pylib.rs b/build/configure/src/pylib.rs index 81d5c09f3..8799349ce 100644 --- a/build/configure/src/pylib.rs +++ b/build/configure/src/pylib.rs @@ -26,20 +26,7 @@ pub fn build_pylib(build: &mut Build) -> Result<()> { proto_files: inputs![glob!["proto/anki/*.proto"]], }, )?; - build.add( - "pylib/anki:_backend_generated.py", - RunCommand { - command: ":pyenv:bin", - args: "$script $out", - inputs: hashmap! { - "script" => inputs!["pylib/tools/genbackend.py"], - "" => inputs!["pylib/anki/_vendor/stringcase.py", ":pylib/anki:proto"] - }, - outputs: hashmap! { - "out" => vec!["pylib/anki/_backend_generated.py"] - }, - }, - )?; + build.add( "pylib/anki:_fluent.py", RunCommand { diff --git a/build/configure/src/python.rs b/build/configure/src/python.rs index 9852828fd..a986f8935 100644 --- a/build/configure/src/python.rs +++ b/build/configure/src/python.rs @@ -108,6 +108,10 @@ impl BuildAction for GenPythonProto { build.add_inputs("protoc", inputs!["$protoc_binary"]); build.add_inputs("protoc-gen-mypy", inputs![":pyenv:protoc-gen-mypy"]); build.add_outputs("", python_outputs); + // not a direct dependency, but we include the output interface in our declared + // outputs + build.add_inputs("", inputs!["rslib/proto"]); + build.add_outputs("", vec!["pylib/anki/_backend_generated.py"]); } } diff --git a/build/configure/src/rust.rs b/build/configure/src/rust.rs index 4e0ea621a..92467cdb8 100644 --- a/build/configure/src/rust.rs +++ b/build/configure/src/rust.rs @@ -22,6 +22,7 @@ use crate::proto::setup_protoc; pub fn build_rust(build: &mut Build) -> Result<()> { prepare_translations(build)?; setup_protoc(build)?; + prepare_proto_descriptors(build)?; build_rsbridge(build) } @@ -81,6 +82,24 @@ fn prepare_translations(build: &mut Build) -> Result<()> { Ok(()) } +fn prepare_proto_descriptors(build: &mut Build) -> Result<()> { + // build anki_proto and spit out descriptors/Python interface + build.add( + "rslib/proto", + CargoBuild { + inputs: inputs![glob!["{proto,rslib/proto}/**"], "$protoc_binary",], + outputs: &[RustOutput::Data( + "descriptors.bin", + "$builddir/rslib/proto/descriptors.bin", + )], + target: None, + extra_args: "-p anki_proto", + release_override: None, + }, + )?; + Ok(()) +} + fn build_rsbridge(build: &mut Build) -> Result<()> { let features = if cfg!(target_os = "linux") { "rustls" @@ -91,12 +110,12 @@ fn build_rsbridge(build: &mut Build) -> Result<()> { "pylib/rsbridge", CargoBuild { inputs: inputs![ - glob!["{pylib/rsbridge/**,rslib/**,proto/**}"], - "$protoc_binary", - // declare a dependency on i18n so it gets built first, allowing + glob!["{pylib/rsbridge/**,rslib/**}"], + // declare a dependency on i18n/proto so it gets built first, allowing // things depending on strings.json to build faster, and ensuring // changes to the ftl files trigger a rebuild ":rslib/i18n", + ":rslib/proto", // when env vars change the build hash gets updated "$builddir/build.ninja", // building on Windows requires python3.lib diff --git a/proto/anki/backend.proto b/proto/anki/backend.proto index 8db0550a8..4de1f0d02 100644 --- a/proto/anki/backend.proto +++ b/proto/anki/backend.proto @@ -9,31 +9,6 @@ package anki.backend; import "anki/links.proto"; -/// while the protobuf descriptors expose the order services are defined in, -/// that information is not available in prost, so we define an enum to make -/// sure all clients agree on the service index -enum ServiceIndex { - SERVICE_INDEX_SCHEDULER = 0; - SERVICE_INDEX_DECKS = 1; - SERVICE_INDEX_NOTES = 2; - SERVICE_INDEX_SYNC = 3; - SERVICE_INDEX_NOTETYPES = 4; - SERVICE_INDEX_CONFIG = 5; - SERVICE_INDEX_CARD_RENDERING = 6; - SERVICE_INDEX_DECK_CONFIG = 7; - SERVICE_INDEX_TAGS = 8; - SERVICE_INDEX_SEARCH = 9; - SERVICE_INDEX_STATS = 10; - SERVICE_INDEX_MEDIA = 11; - SERVICE_INDEX_I18N = 12; - SERVICE_INDEX_COLLECTION = 13; - SERVICE_INDEX_CARDS = 14; - SERVICE_INDEX_LINKS = 15; - SERVICE_INDEX_IMPORT_EXPORT = 16; - SERVICE_INDEX_ANKIDROID = 17; - SERVICE_INDEX_IMAGE_OCCLUSION = 18; -} - message BackendInit { repeated string preferred_langs = 1; string locale_folder_path = 2; diff --git a/pylib/anki/_backend.py b/pylib/anki/_backend.py index e62dd5e36..7e5d68221 100644 --- a/pylib/anki/_backend.py +++ b/pylib/anki/_backend.py @@ -125,15 +125,10 @@ class RustBackend(RustBackendGenerated): for k, v in kwargs.items() } - input = i18n_pb2.TranslateStringRequest( - module_index=module_index, - message_index=message_index, - args=args, + return self.translate_string( + module_index=module_index, message_index=message_index, args=args ) - output_bytes = self.translate_string_raw(input.SerializeToString()) - return anki.generic_pb2.String.FromString(output_bytes).val - def format_time_span( self, seconds: Any, diff --git a/pylib/tools/genbackend.py b/pylib/tools/genbackend.py deleted file mode 100644 index 2b2d3d0ae..000000000 --- a/pylib/tools/genbackend.py +++ /dev/null @@ -1,251 +0,0 @@ -#!/usr/bin/env python3 -# Copyright: Ankitects Pty Ltd and contributors -# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html - -import re -import sys - -sys.path.append("out/pylib") -sys.path.append("pylib/anki/_vendor") - -import google.protobuf.descriptor -import stringcase - -import anki.backend_pb2 -import anki.card_rendering_pb2 -import anki.cards_pb2 -import anki.collection_pb2 -import anki.config_pb2 -import anki.deckconfig_pb2 -import anki.decks_pb2 -import anki.i18n_pb2 -import anki.image_occlusion_pb2 -import anki.import_export_pb2 -import anki.links_pb2 -import anki.media_pb2 -import anki.notes_pb2 -import anki.notetypes_pb2 -import anki.scheduler_pb2 -import anki.search_pb2 -import anki.stats_pb2 -import anki.sync_pb2 -import anki.tags_pb2 - -TYPE_DOUBLE = 1 -TYPE_FLOAT = 2 -TYPE_INT64 = 3 -TYPE_UINT64 = 4 -TYPE_INT32 = 5 -TYPE_FIXED64 = 6 -TYPE_FIXED32 = 7 -TYPE_BOOL = 8 -TYPE_STRING = 9 -TYPE_GROUP = 10 -TYPE_MESSAGE = 11 -TYPE_BYTES = 12 -TYPE_UINT32 = 13 -TYPE_ENUM = 14 -TYPE_SFIXED32 = 15 -TYPE_SFIXED64 = 16 -TYPE_SINT32 = 17 -TYPE_SINT64 = 18 - -LABEL_OPTIONAL = 1 -LABEL_REQUIRED = 2 -LABEL_REPEATED = 3 - -RAW_ONLY = {"TranslateString"} - - -def python_type(field): - type = python_type_inner(field) - if field.label == LABEL_REPEATED: - type = f"Sequence[{type}]" - return type - - -def python_type_inner(field): - type = field.type - if type == TYPE_BOOL: - return "bool" - elif type in (1, 2): - return "float" - elif type in (3, 4, 5, 6, 7, 13, 15, 16, 17, 18): - return "int" - elif type == TYPE_STRING: - return "str" - elif type == TYPE_BYTES: - return "bytes" - elif type == TYPE_MESSAGE: - return fullname(field.message_type.full_name) - elif type == TYPE_ENUM: - return fullname(field.enum_type.full_name) + ".V" - else: - raise Exception(f"unknown type: {type}") - - -def fullname(fullname: str) -> str: - # eg anki.generic.Empty -> anki.generic_pb2.Empty - components = fullname.split(".") - components[1] += "_pb2" - return ".".join(components) - - -# get_deck_i_d -> get_deck_id etc -def fix_snakecase(name): - for fix in "a_v", "i_d": - name = re.sub( - rf"(\w)({fix})(\w)", - lambda m: m.group(1) + m.group(2).replace("_", "") + m.group(3), - name, - ) - return name - - -def get_input_args(input_type): - fields = sorted(input_type.fields, key=lambda x: x.number) - self_star = ["self"] - if len(fields) >= 2: - self_star.append("*") - return ", ".join(self_star + [f"{f.name}: {python_type(f)}" for f in fields]) - - -def get_input_assign(input_type): - fields = sorted(input_type.fields, key=lambda x: x.number) - return ", ".join(f"{f.name}={f.name}" for f in fields) - - -def render_method(service_idx, method_idx, method): - name = fix_snakecase(stringcase.snakecase(method.name)) - input_name = method.input_type.name - - if ( - input_name.endswith("Request") or len(method.input_type.fields) < 2 - ) and not method.input_type.oneofs: - input_params = get_input_args(method.input_type) - input_assign_full = f"message = {fullname(method.input_type.full_name)}({get_input_assign(method.input_type)})" - else: - input_params = f"self, message: {fullname(method.input_type.full_name)}" - input_assign_full = "" - - if ( - len(method.output_type.fields) == 1 - and method.output_type.fields[0].type != TYPE_ENUM - ): - # unwrap single return arg - f = method.output_type.fields[0] - return_type = python_type(f) - single_attribute = f".{f.name}" - else: - return_type = fullname(method.output_type.full_name) - single_attribute = "" - - buf = f"""\ - def {name}_raw(self, message: bytes) -> bytes: - return self._run_command({service_idx}, {method_idx}, message) - -""" - - if not method.name in RAW_ONLY: - buf += f"""\ - def {name}({input_params}) -> {return_type}: - {input_assign_full} - raw_bytes = self._run_command({service_idx}, {method_idx}, message.SerializeToString()) - output = {fullname(method.output_type.full_name)}() - output.ParseFromString(raw_bytes) - return output{single_attribute} - -""" - - return buf - - -out: list[str] = [] - - -def render_service( - service: google.protobuf.descriptor.ServiceDescriptor, service_index: int -) -> None: - for method_index, method in enumerate(service.methods): - out.append(render_method(service_index, method_index, method)) - - -service_modules = dict( - I18N=anki.i18n_pb2, - COLLECTION=anki.collection_pb2, - CARDS=anki.cards_pb2, - NOTES=anki.notes_pb2, - DECKS=anki.decks_pb2, - DECK_CONFIG=anki.deckconfig_pb2, - NOTETYPES=anki.notetypes_pb2, - SCHEDULER=anki.scheduler_pb2, - SYNC=anki.sync_pb2, - CONFIG=anki.config_pb2, - SEARCH=anki.search_pb2, - STATS=anki.stats_pb2, - CARD_RENDERING=anki.card_rendering_pb2, - TAGS=anki.tags_pb2, - MEDIA=anki.media_pb2, - LINKS=anki.links_pb2, - IMPORT_EXPORT=anki.import_export_pb2, - IMAGE_OCCLUSION=anki.image_occlusion_pb2, -) - -for service in anki.backend_pb2.ServiceIndex.DESCRIPTOR.values: - # SERVICE_INDEX_TEST -> _TESTSERVICE - base = service.name.replace("SERVICE_INDEX_", "") - service_pkg = service_modules.get(base) - service_var = "_" + base.replace("_", "") + "SERVICE" - if service_var == "_ANKIDROIDSERVICE": - continue - service_obj = getattr(service_pkg, service_var) - service_index = service.number - render_service(service_obj, service_index) - -with open(sys.argv[1], "w", encoding="utf8") as f: - f.write( - '''# Copyright: Ankitects Pty Ltd and contributors -# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -# pylint: skip-file - -from __future__ import annotations - -""" -THIS FILE IS AUTOMATICALLY GENERATED - DO NOT EDIT. - -Please do not access methods on the backend directly - they may be changed -or removed at any time. Instead, please use the methods on the collection -instead. Eg, don't use col.backend.all_deck_config(), instead use -col.decks.all_config() -""" - -from typing import * - -import anki -import anki.backend_pb2 -import anki.i18n_pb2 -import anki.cards_pb2 -import anki.collection_pb2 -import anki.decks_pb2 -import anki.deckconfig_pb2 -import anki.links_pb2 -import anki.notes_pb2 -import anki.notetypes_pb2 -import anki.scheduler_pb2 -import anki.sync_pb2 -import anki.config_pb2 -import anki.search_pb2 -import anki.stats_pb2 -import anki.card_rendering_pb2 -import anki.tags_pb2 -import anki.media_pb2 -import anki.import_export_pb2 -import anki.image_occlusion_pb2 - -class RustBackendGenerated: - def _run_command(self, service: int, method: int, input: Any) -> bytes: - raise Exception("not implemented") - -''' - + "\n".join(out) - ) diff --git a/rslib/Cargo.toml b/rslib/Cargo.toml index 15e359dc4..4c7fc20f7 100644 --- a/rslib/Cargo.toml +++ b/rslib/Cargo.toml @@ -1,6 +1,5 @@ [package] name = "anki" -build = "build/main.rs" publish = false description = "Anki's Rust library code" @@ -10,10 +9,6 @@ edition.workspace = true license.workspace = true rust-version.workspace = true -[lib] -name = "anki" -path = "src/lib.rs" - [features] bench = ["criterion"] rustls = ["reqwest/rustls-tls", "reqwest/rustls-tls-native-roots"] @@ -27,7 +22,13 @@ required-features = ["bench"] # After updating anything below, run ../cargo/update_licenses.sh [build-dependencies] +anyhow = "1.0.71" +inflections = "1.1.1" +prost = "0.11.8" prost-build = "0.11.8" +prost-reflect = "0.11.4" +prost-types = "0.11.9" +regex = "1.7.3" which = "4.4.0" [dev-dependencies] @@ -42,6 +43,7 @@ features = ["json", "socks", "stream", "multipart"] [dependencies] anki_i18n = { path = "i18n" } +anki_proto = { path = "proto" } csv = { git = "https://github.com/ankitects/rust-csv.git", rev = "1c9d3aab6f79a7d815c69f925a46a4590c115f90" } percent-encoding-iri = { git = "https://github.com/ankitects/rust-url.git", rev = "bb930b8d089f4d30d7d19c12e54e66191de47b88" } diff --git a/rslib/build/main.rs b/rslib/build.rs similarity index 85% rename from rslib/build/main.rs rename to rslib/build.rs index 4ebf8519b..508c85277 100644 --- a/rslib/build/main.rs +++ b/rslib/build.rs @@ -3,11 +3,7 @@ use std::fs; -pub mod protobuf; - fn main() { - protobuf::write_backend_proto_rs(); - println!("cargo:rerun-if-changed=../out/buildhash"); let buildhash = fs::read_to_string("../out/buildhash").unwrap_or_default(); println!("cargo:rustc-env=BUILDHASH={buildhash}") diff --git a/rslib/build/protobuf.rs b/rslib/build/protobuf.rs deleted file mode 100644 index 18f67880e..000000000 --- a/rslib/build/protobuf.rs +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright: Ankitects Pty Ltd and contributors -// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html - -use std::env; -use std::fmt::Write; -use std::path::PathBuf; - -struct CustomGenerator {} - -fn write_method_trait(buf: &mut String, service: &prost_build::Service) { - buf.push_str( - r#" -pub trait Service { - fn run_method(&self, method: u32, input: &[u8]) -> Result> { - match method { -"#, - ); - for (idx, method) in service.methods.iter().enumerate() { - write!( - buf, - concat!(" ", - "{idx} => {{ let input = super::{input_type}::decode(input)?;\n", - "let output = self.{rust_method}(input)?;\n", - "let mut out_bytes = Vec::new(); output.encode(&mut out_bytes)?; Ok(out_bytes) }}, "), - idx = idx, - input_type = method.input_type, - rust_method = method.name - ) - .unwrap(); - } - buf.push_str( - r#" - _ => crate::invalid_input!("invalid command"), - } - } -"#, - ); - - for method in &service.methods { - write!( - buf, - concat!( - " fn {method_name}(&self, input: super::{input_type}) -> ", - "Result;\n" - ), - method_name = method.name, - input_type = method.input_type, - output_type = method.output_type - ) - .unwrap(); - } - buf.push_str("}\n"); -} - -impl prost_build::ServiceGenerator for CustomGenerator { - fn generate(&mut self, service: prost_build::Service, buf: &mut String) { - write!( - buf, - "pub mod {name}_service {{ - use prost::Message; - use crate::error::Result; - ", - name = service.name.replace("Service", "").to_ascii_lowercase() - ) - .unwrap(); - write_method_trait(buf, &service); - buf.push('}'); - } -} - -fn service_generator() -> Box { - Box::new(CustomGenerator {}) -} - -pub fn write_backend_proto_rs() { - set_protoc_path(); - let proto_dir = PathBuf::from("../proto"); - - let subfolders = &["anki"]; - let mut paths = vec![]; - for subfolder in subfolders { - for entry in proto_dir.join(subfolder).read_dir().unwrap() { - let entry = entry.unwrap(); - let path = entry.path(); - if path - .file_name() - .unwrap() - .to_str() - .unwrap() - .ends_with(".proto") - { - println!("cargo:rerun-if-changed={}", path.to_str().unwrap()); - paths.push(path); - } - } - } - - let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); - let mut config = prost_build::Config::new(); - config - .out_dir(&out_dir) - .service_generator(service_generator()) - .type_attribute( - "Deck.Filtered.SearchTerm.Order", - "#[derive(strum::EnumIter)]", - ) - .type_attribute( - "Deck.Normal.DayLimit", - "#[derive(Copy, Eq, serde_derive::Deserialize, serde_derive::Serialize)]", - ) - .type_attribute("HelpPageLinkRequest.HelpPage", "#[derive(strum::EnumIter)]") - .type_attribute("CsvMetadata.Delimiter", "#[derive(strum::EnumIter)]") - .type_attribute( - "Preferences.BackupLimits", - "#[derive(Copy, serde_derive::Deserialize, serde_derive::Serialize)]", - ) - .type_attribute( - "CsvMetadata.DupeResolution", - "#[derive(serde_derive::Deserialize, serde_derive::Serialize)]", - ) - .type_attribute( - "CsvMetadata.MatchScope", - "#[derive(serde_derive::Deserialize, serde_derive::Serialize)]", - ) - .compile_protos(paths.as_slice(), &[proto_dir]) - .unwrap(); -} - -/// Set PROTOC to the custom path provided by PROTOC_BINARY, or add .exe to -/// the standard path if on Windows. -fn set_protoc_path() { - if let Ok(custom_protoc) = env::var("PROTOC_BINARY") { - env::set_var("PROTOC", custom_protoc); - } else if let Ok(bundled_protoc) = env::var("PROTOC") { - if cfg!(windows) && !bundled_protoc.ends_with(".exe") { - env::set_var("PROTOC", format!("{bundled_protoc}.exe")); - } - } -} diff --git a/rslib/proto/Cargo.toml b/rslib/proto/Cargo.toml new file mode 100644 index 000000000..c34d4c191 --- /dev/null +++ b/rslib/proto/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "anki_proto" +publish = false +description = "Anki's Rust library protobuf code" + +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true + +[build-dependencies] +anyhow = "1.0.71" +inflections = "1.1.1" +prost-build = "0.11.9" +prost-reflect = "0.11.4" +prost-types = "0.11.9" diff --git a/rslib/proto/build.rs b/rslib/proto/build.rs new file mode 100644 index 000000000..c5cb66abe --- /dev/null +++ b/rslib/proto/build.rs @@ -0,0 +1,19 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +pub mod python; +pub mod rust; + +use std::env; +use std::path::PathBuf; + +use anyhow::Context; +use anyhow::Result; + +fn main() -> Result<()> { + let descriptors_path = PathBuf::from(env::var("DESCRIPTORS_BIN").context("DESCRIPTORS_BIN")?); + + let pool = rust::write_backend_proto_rs(&descriptors_path)?; + python::write_python_interface(&pool)?; + Ok(()) +} diff --git a/rslib/proto/python.rs b/rslib/proto/python.rs new file mode 100644 index 000000000..5b9ab966d --- /dev/null +++ b/rslib/proto/python.rs @@ -0,0 +1,239 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::fs::File; +use std::io::BufWriter; +use std::io::Write; +use std::path::Path; + +use anyhow::Context; +use anyhow::Result; +use inflections::Inflect; +use prost_reflect::DescriptorPool; +use prost_reflect::FieldDescriptor; +use prost_reflect::Kind; +use prost_reflect::MessageDescriptor; +use prost_reflect::MethodDescriptor; +use prost_reflect::ServiceDescriptor; + +pub(crate) fn write_python_interface(pool: &DescriptorPool) -> Result<()> { + let output_path = Path::new("../../out/pylib/anki/_backend_generated.py"); + let mut out = BufWriter::new( + File::create(output_path).with_context(|| format!("opening {output_path:?}"))?, + ); + write_header(&mut out)?; + + for service in pool.services() { + if service.name() == "AnkidroidService" { + continue; + } + for method in service.methods() { + render_method(&service, &method, &mut out); + } + } + + Ok(()) +} + +/// Generates text like the following: +/// +/// def get_field_names_raw(self, message: bytes) -> bytes: +/// return self._run_command(7, 16, message) +/// +/// def get_field_names(self, ntid: int) -> Sequence[str]: +/// message = anki.notetypes_pb2.NotetypeId(ntid=ntid) +/// raw_bytes = self._run_command(7, 16, message.SerializeToString()) +/// output = anki.generic_pb2.StringList() +/// output.ParseFromString(raw_bytes) +/// return output.vals +fn render_method(service: &ServiceDescriptor, method: &MethodDescriptor, out: &mut impl Write) { + let method_name = method.name().to_snake_case(); + let input = method.input(); + let output = method.output(); + let service_idx = service.index(); + let method_idx = method.index(); + + // raw bytes + write!( + out, + r#" def {method_name}_raw(self, message: bytes) -> bytes: + return self._run_command({service_idx}, {method_idx}, message) + +"# + ) + .unwrap(); + + // (possibly destructured) message + let (input_params, input_assign) = maybe_destructured_input(&input); + let output_constructor = full_name_to_python(output.full_name()); + let (output_msg_or_single_field, output_type) = maybe_destructured_output(&output); + write!( + out, + r#" def {method_name}({input_params}) -> {output_type}: + {input_assign} + raw_bytes = self._run_command({service_idx}, {method_idx}, message.SerializeToString()) + output = {output_constructor}() + output.ParseFromString(raw_bytes) + return {output_msg_or_single_field} + +"# + ) + .unwrap(); +} + +/// If any of the following apply to the input type: +/// - it has a single field +/// - its name ends in Request +/// - it has any optional fields +/// ...then destructuring will be skipped, and the method will take the input +/// message directly. Returns (params_line, assignment_lines) +fn maybe_destructured_input(input: &MessageDescriptor) -> (String, String) { + if (input.name().ends_with("Request") || input.fields().len() < 2) + && input.oneofs().next().is_none() + { + // destructure + let method_args = build_method_arguments(input); + let input_type = full_name_to_python(input.full_name()); + let input_message_args = build_input_message_arguments(input); + let assignment = format!("message = {input_type}({input_message_args})",); + (method_args, assignment) + } else { + // no destructure + let params = format!("self, message: {}", full_name_to_python(input.full_name())); + let assignment = String::new(); + (params, assignment) + } +} + +/// e.g. "self, *, note_ids: Sequence[int], new_fields: Sequence[int]" +fn build_method_arguments(input: &MessageDescriptor) -> String { + let fields = input.fields(); + let mut args = vec!["self".to_string()]; + if fields.len() >= 2 { + args.push("*".to_string()); + } + for field in fields { + let arg = format!("{}: {}", field.name(), python_type(&field)); + args.push(arg); + } + args.join(", ") +} + +/// e.g. "note_ids=note_ids, new_fields=new_fields" +fn build_input_message_arguments(input: &MessageDescriptor) -> String { + input + .fields() + .map(|field| { + let name = field.name(); + format!("{name}={name}") + }) + .collect::>() + .join(", ") +} + +// If output type has a single field and is not an enum, we return its single +// field value directly. Returns (expr, type), where expr is 'output' or +// 'output.'. +fn maybe_destructured_output(output: &MessageDescriptor) -> (String, String) { + let first_field = output.fields().next(); + if output.fields().len() == 1 && !matches!(first_field.as_ref().unwrap().kind(), Kind::Enum(_)) + { + let field = first_field.unwrap(); + (format!("output.{}", field.name()), python_type(&field)) + } else { + ("output".into(), full_name_to_python(output.full_name())) + } +} + +/// e.g. uint32 -> int; repeated bool -> Sequence[bool] +fn python_type(field: &FieldDescriptor) -> String { + let kind = match field.kind() { + Kind::Int32 + | Kind::Int64 + | Kind::Uint32 + | Kind::Uint64 + | Kind::Sint32 + | Kind::Sint64 + | Kind::Fixed32 + | Kind::Fixed64 + | Kind::Sfixed32 + | Kind::Sfixed64 => "int".into(), + Kind::Float | Kind::Double => "float".into(), + Kind::Bool => "bool".into(), + Kind::String => "str".into(), + Kind::Bytes => "bytes".into(), + Kind::Message(msg) => full_name_to_python(msg.full_name()), + Kind::Enum(en) => format!("{}.V", full_name_to_python(en.full_name())), + }; + if field.is_list() { + format!("Sequence[{}]", kind) + } else if field.is_map() { + let map_kind = field.kind(); + let map_kind = map_kind.as_message().unwrap(); + let map_kv: Vec<_> = map_kind.fields().map(|f| python_type(&f)).collect(); + format!("Mapping[{}, {}]", map_kv[0], map_kv[1]) + } else { + kind + } +} + +// e.g. anki.import_export.ImportResponse -> +// anki.import_export_pb2.ImportResponse +fn full_name_to_python(name: &str) -> String { + let mut name = name.splitn(3, '.'); + format!( + "{}.{}_pb2.{}", + name.next().unwrap(), + name.next().unwrap(), + name.next().unwrap() + ) +} + +fn write_header(out: &mut impl Write) -> Result<()> { + out.write_all( + br#"# Copyright: Ankitects Pty Ltd and contributors +# License: GNU AGPL, version 3 or later; https://www.gnu.org/licenses/agpl.html +# pylint: skip-file + +from __future__ import annotations + +""" +THIS FILE IS AUTOMATICALLY GENERATED - DO NOT EDIT. + +Please do not access methods on the backend directly - they may be changed +or removed at any time. Instead, please use the methods on the collection +instead. Eg, don't use col.backend.all_deck_config(), instead use +col.decks.all_config() +""" + +from typing import * + +import anki +import anki.backend_pb2 +import anki.card_rendering_pb2 +import anki.cards_pb2 +import anki.collection_pb2 +import anki.config_pb2 +import anki.deckconfig_pb2 +import anki.decks_pb2 +import anki.i18n_pb2 +import anki.image_occlusion_pb2 +import anki.import_export_pb2 +import anki.links_pb2 +import anki.media_pb2 +import anki.notes_pb2 +import anki.notetypes_pb2 +import anki.scheduler_pb2 +import anki.search_pb2 +import anki.stats_pb2 +import anki.sync_pb2 +import anki.tags_pb2 + +class RustBackendGenerated: + def _run_command(self, service: int, method: int, input: Any) -> bytes: + raise Exception("not implemented") + +"#, + )?; + Ok(()) +} diff --git a/rslib/proto/rust.rs b/rslib/proto/rust.rs new file mode 100644 index 000000000..bdd3b4c8c --- /dev/null +++ b/rslib/proto/rust.rs @@ -0,0 +1,186 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::env; +use std::fmt::Write; +use std::fs; +use std::path::Path; +use std::path::PathBuf; + +use anyhow::Context; +use anyhow::Result; +use prost_build::ServiceGenerator; +use prost_reflect::DescriptorPool; + +pub fn write_backend_proto_rs(descriptors_path: &Path) -> Result { + set_protoc_path(); + let proto_dir = PathBuf::from("../../proto"); + let paths = gather_proto_paths(&proto_dir)?; + let out_dir = Path::new("../../out/rslib/proto"); + fs::create_dir_all(out_dir).with_context(|| format!("{:?}", out_dir))?; + + prost_build::Config::new() + .out_dir(out_dir) + .file_descriptor_set_path(descriptors_path) + .service_generator(RustCodeGenerator::boxed()) + .type_attribute( + "Deck.Filtered.SearchTerm.Order", + "#[derive(strum::EnumIter)]", + ) + .type_attribute( + "Deck.Normal.DayLimit", + "#[derive(Copy, Eq, serde_derive::Deserialize, serde_derive::Serialize)]", + ) + .type_attribute("HelpPageLinkRequest.HelpPage", "#[derive(strum::EnumIter)]") + .type_attribute("CsvMetadata.Delimiter", "#[derive(strum::EnumIter)]") + .type_attribute( + "Preferences.BackupLimits", + "#[derive(Copy, serde_derive::Deserialize, serde_derive::Serialize)]", + ) + .type_attribute( + "CsvMetadata.DupeResolution", + "#[derive(serde_derive::Deserialize, serde_derive::Serialize)]", + ) + .type_attribute( + "CsvMetadata.MatchScope", + "#[derive(serde_derive::Deserialize, serde_derive::Serialize)]", + ) + .compile_protos(paths.as_slice(), &[proto_dir]) + .context("prost build")?; + + write_service_index(out_dir, descriptors_path) +} + +fn write_service_index(out_dir: &Path, descriptors_path: &Path) -> Result { + let descriptors = fs::read(descriptors_path) + .with_context(|| format!("failed to read {descriptors_path:?}"))?; + let pool = + DescriptorPool::decode(descriptors.as_ref()).context("unable to decode descriptors")?; + let mut buf = String::new(); + + writeln!( + buf, + "#[derive(num_enum::TryFromPrimitive)] +#[repr(u32)] +pub enum ServiceIndex {{" + ) + .unwrap(); + for service in pool.services() { + writeln!( + buf, + " {} = {},", + service.name().replace("Service", ""), + service.index() + ) + .unwrap(); + } + writeln!(buf, "}}").unwrap(); + + fs::write(out_dir.join("service_index.rs"), buf).context("failed to write service index")?; + + Ok(pool) +} + +fn gather_proto_paths(proto_dir: &Path) -> Result> { + let subfolders = &["anki"]; + let mut paths = vec![]; + for subfolder in subfolders { + for entry in proto_dir.join(subfolder).read_dir().unwrap() { + let entry = entry.unwrap(); + let path = entry.path(); + if path + .file_name() + .unwrap() + .to_str() + .unwrap() + .ends_with(".proto") + { + println!("cargo:rerun-if-changed={}", path.to_str().unwrap()); + paths.push(path); + } + } + } + paths.sort(); + Ok(paths) +} + +struct RustCodeGenerator {} + +impl RustCodeGenerator { + fn boxed() -> Box { + Box::new(Self {}) + } + + fn write_method_trait(&mut self, buf: &mut String, service: &prost_build::Service) { + buf.push_str( + r#" +pub trait Service { + fn run_method(&self, method: u32, input: &[u8]) -> Result> { + match method { +"#, + ); + for (idx, method) in service.methods.iter().enumerate() { + write!( + buf, + concat!(" ", + "{idx} => {{ let input = super::{input_type}::decode(input)?;\n", + "let output = self.{rust_method}(input)?;\n", + "let mut out_bytes = Vec::new(); output.encode(&mut out_bytes)?; Ok(out_bytes) }}, "), + idx = idx, + input_type = method.input_type, + rust_method = method.name + ) + .unwrap(); + } + buf.push_str( + r#" + _ => crate::invalid_input!("invalid command"), + } + } +"#, + ); + + for method in &service.methods { + write!( + buf, + concat!( + " fn {method_name}(&self, input: super::{input_type}) -> ", + "Result;\n" + ), + method_name = method.name, + input_type = method.input_type, + output_type = method.output_type + ) + .unwrap(); + } + buf.push_str("}\n"); + } +} + +impl ServiceGenerator for RustCodeGenerator { + fn generate(&mut self, service: prost_build::Service, buf: &mut String) { + write!( + buf, + "pub mod {name}_service {{ + use prost::Message; + use crate::error::Result; + ", + name = service.name.replace("Service", "").to_ascii_lowercase() + ) + .unwrap(); + self.write_method_trait(buf, &service); + buf.push('}'); + } +} + +/// Set PROTOC to the custom path provided by PROTOC_BINARY, or add .exe to +/// the standard path if on Windows. +fn set_protoc_path() { + if let Ok(custom_protoc) = env::var("PROTOC_BINARY") { + env::set_var("PROTOC", custom_protoc); + } else if let Ok(bundled_protoc) = env::var("PROTOC") { + if cfg!(windows) && !bundled_protoc.ends_with(".exe") { + env::set_var("PROTOC", format!("{bundled_protoc}.exe")); + } + } +} diff --git a/rslib/proto/src/lib.rs b/rslib/proto/src/lib.rs new file mode 100644 index 000000000..cf499d047 --- /dev/null +++ b/rslib/proto/src/lib.rs @@ -0,0 +1,2 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html diff --git a/rslib/src/backend/mod.rs b/rslib/src/backend/mod.rs index ba7387d4e..d715862bd 100644 --- a/rslib/src/backend/mod.rs +++ b/rslib/src/backend/mod.rs @@ -64,7 +64,7 @@ use self::sync::SyncState; use self::tags::TagsService; use crate::backend::dbproxy::db_command_bytes; use crate::pb; -use crate::pb::backend::ServiceIndex; +use crate::pb::ServiceIndex; use crate::prelude::*; pub struct Backend { @@ -121,7 +121,7 @@ impl Backend { method: u32, input: &[u8], ) -> result::Result, Vec> { - ServiceIndex::from_i32(service as i32) + ServiceIndex::try_from(service) .or_invalid("invalid service") .and_then(|service| match service { ServiceIndex::Ankidroid => AnkidroidService::run_method(self, method, input), diff --git a/rslib/src/pb.rs b/rslib/src/pb.rs index 017e689ca..ce84b5e9d 100644 --- a/rslib/src/pb.rs +++ b/rslib/src/pb.rs @@ -4,11 +4,13 @@ macro_rules! protobuf { ($ident:ident, $name:literal) => { pub mod $ident { - include!(concat!(env!("OUT_DIR"), "/anki.", $name, ".rs")); + include!(concat!("../../out/rslib/proto/anki.", $name, ".rs")); } }; } +include!("../../out/rslib/proto/service_index.rs"); + protobuf!(ankidroid, "ankidroid"); protobuf!(backend, "backend"); protobuf!(card_rendering, "card_rendering");