From bac05039a724a9db84d97732badbb744d96be1c3 Mon Sep 17 00:00:00 2001 From: Damien Elmes Date: Sun, 11 Jun 2023 22:17:41 +1000 Subject: [PATCH] Move protobuf generation into a separate crate; write .py interface in Rust A couple of motivations for this: - genbackend.py was somewhat messy, and difficult to change with the lack of types. The mobile clients used it as a base for their generation, so improving it will make life easier for them too, once they're ported. - It will make it easier to write a .ts generator in the future - We currently implement a bunch of helper methods on protobuf types which don't allow us to compile the protobuf types until we compile the Anki crate. If we change this in the future, we will be able to do more of the compilation up-front. We no longer need to record the services in the proto file, as we can extract the service order from the compiled protos. Support for map types has also been added. --- .cargo/config.toml | 1 + Cargo.lock | 45 ++++-- Cargo.toml | 1 + build/configure/src/pylib.rs | 15 +- build/configure/src/python.rs | 4 + build/configure/src/rust.rs | 25 ++- proto/anki/backend.proto | 25 --- pylib/anki/_backend.py | 9 +- pylib/tools/genbackend.py | 251 ------------------------------ rslib/Cargo.toml | 12 +- rslib/{build/main.rs => build.rs} | 4 - rslib/build/protobuf.rs | 139 ----------------- rslib/proto/Cargo.toml | 17 ++ rslib/proto/build.rs | 19 +++ rslib/proto/python.rs | 239 ++++++++++++++++++++++++++++ rslib/proto/rust.rs | 186 ++++++++++++++++++++++ rslib/proto/src/lib.rs | 2 + rslib/src/backend/mod.rs | 4 +- rslib/src/pb.rs | 4 +- 19 files changed, 542 insertions(+), 460 deletions(-) delete mode 100644 pylib/tools/genbackend.py rename rslib/{build/main.rs => build.rs} (85%) delete mode 100644 rslib/build/protobuf.rs create mode 100644 rslib/proto/Cargo.toml create mode 100644 rslib/proto/build.rs create mode 100644 rslib/proto/python.rs create mode 100644 rslib/proto/rust.rs create mode 100644 
rslib/proto/src/lib.rs diff --git a/.cargo/config.toml b/.cargo/config.toml index 838b8d683..744ad63a3 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,5 +1,6 @@ [env] STRINGS_JSON = { value = "out/rslib/i18n/strings.json", relative = true } +DESCRIPTORS_BIN = { value = "out/rslib/proto/descriptors.bin", relative = true } # build script will append .exe if necessary PROTOC = { value = "out/extracted/protoc/bin/protoc", relative = true } PYO3_NO_PYTHON = "1" diff --git a/Cargo.lock b/Cargo.lock index 2bc1a892b..a7bc7596d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,6 +83,8 @@ version = "0.0.0" dependencies = [ "ammonia", "anki_i18n", + "anki_proto", + "anyhow", "async-compression", "async-stream", "async-trait", @@ -106,6 +108,7 @@ dependencies = [ "htmlescape", "hyper", "id_tree", + "inflections", "itertools", "lazy_static", "nom", @@ -117,6 +120,8 @@ dependencies = [ "pin-project", "prost", "prost-build", + "prost-reflect", + "prost-types", "pulldown-cmark 0.9.2", "rand 0.8.5", "regex", @@ -180,6 +185,17 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "anki_proto" +version = "0.0.0" +dependencies = [ + "anyhow", + "inflections", + "prost-build", + "prost-reflect", + "prost-types", +] + [[package]] name = "anstream" version = "0.2.6" @@ -2956,9 +2972,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e48e50df39172a3e7eb17e14642445da64996989bc212b583015435d39a58537" +checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" dependencies = [ "bytes", "prost-derive", @@ -2966,9 +2982,9 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c828f93f5ca4826f97fedcbd3f9a536c16b12cff3dbbb4a007f932bbad95b12" +checksum = "119533552c9a7ffacc21e099c24a0ac8bb19c2a2a3f363de84cd9b844feab270" 
dependencies = [ "bytes", "heck", @@ -2988,9 +3004,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ea9b0f8cbe5e15a8a042d030bd96668db28ecb567ec37d691971ff5731d2b1b" +checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" dependencies = [ "anyhow", "itertools", @@ -3000,10 +3016,21 @@ dependencies = [ ] [[package]] -name = "prost-types" -version = "0.11.8" +name = "prost-reflect" +version = "0.11.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "379119666929a1afd7a043aa6cf96fa67a6dce9af60c88095a4686dbce4c9c88" +checksum = "000e1e05ebf7b26e1eba298e66fe4eee6eb19c567d0ffb35e0dd34231cdac4c8" +dependencies = [ + "once_cell", + "prost", + "prost-types", +] + +[[package]] +name = "prost-types" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13" dependencies = [ "prost", ] diff --git a/Cargo.toml b/Cargo.toml index 806f47ff9..e0babae9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "rslib/i18n", "rslib/i18n_helpers", "rslib/linkchecker", + "rslib/proto", "pylib/rsbridge", "build/configure", "build/ninja_gen", diff --git a/build/configure/src/pylib.rs b/build/configure/src/pylib.rs index 81d5c09f3..8799349ce 100644 --- a/build/configure/src/pylib.rs +++ b/build/configure/src/pylib.rs @@ -26,20 +26,7 @@ pub fn build_pylib(build: &mut Build) -> Result<()> { proto_files: inputs![glob!["proto/anki/*.proto"]], }, )?; - build.add( - "pylib/anki:_backend_generated.py", - RunCommand { - command: ":pyenv:bin", - args: "$script $out", - inputs: hashmap! { - "script" => inputs!["pylib/tools/genbackend.py"], - "" => inputs!["pylib/anki/_vendor/stringcase.py", ":pylib/anki:proto"] - }, - outputs: hashmap! 
{ - "out" => vec!["pylib/anki/_backend_generated.py"] - }, - }, - )?; + build.add( "pylib/anki:_fluent.py", RunCommand { diff --git a/build/configure/src/python.rs b/build/configure/src/python.rs index 9852828fd..a986f8935 100644 --- a/build/configure/src/python.rs +++ b/build/configure/src/python.rs @@ -108,6 +108,10 @@ impl BuildAction for GenPythonProto { build.add_inputs("protoc", inputs!["$protoc_binary"]); build.add_inputs("protoc-gen-mypy", inputs![":pyenv:protoc-gen-mypy"]); build.add_outputs("", python_outputs); + // not a direct dependency, but we include the output interface in our declared + // outputs + build.add_inputs("", inputs!["rslib/proto"]); + build.add_outputs("", vec!["pylib/anki/_backend_generated.py"]); } } diff --git a/build/configure/src/rust.rs b/build/configure/src/rust.rs index 4e0ea621a..92467cdb8 100644 --- a/build/configure/src/rust.rs +++ b/build/configure/src/rust.rs @@ -22,6 +22,7 @@ use crate::proto::setup_protoc; pub fn build_rust(build: &mut Build) -> Result<()> { prepare_translations(build)?; setup_protoc(build)?; + prepare_proto_descriptors(build)?; build_rsbridge(build) } @@ -81,6 +82,24 @@ fn prepare_translations(build: &mut Build) -> Result<()> { Ok(()) } +fn prepare_proto_descriptors(build: &mut Build) -> Result<()> { + // build anki_proto and spit out descriptors/Python interface + build.add( + "rslib/proto", + CargoBuild { + inputs: inputs![glob!["{proto,rslib/proto}/**"], "$protoc_binary",], + outputs: &[RustOutput::Data( + "descriptors.bin", + "$builddir/rslib/proto/descriptors.bin", + )], + target: None, + extra_args: "-p anki_proto", + release_override: None, + }, + )?; + Ok(()) +} + fn build_rsbridge(build: &mut Build) -> Result<()> { let features = if cfg!(target_os = "linux") { "rustls" @@ -91,12 +110,12 @@ fn build_rsbridge(build: &mut Build) -> Result<()> { "pylib/rsbridge", CargoBuild { inputs: inputs![ - glob!["{pylib/rsbridge/**,rslib/**,proto/**}"], - "$protoc_binary", - // declare a dependency on i18n so 
it gets built first, allowing + glob!["{pylib/rsbridge/**,rslib/**}"], + // declare a dependency on i18n/proto so it gets built first, allowing // things depending on strings.json to build faster, and ensuring // changes to the ftl files trigger a rebuild ":rslib/i18n", + ":rslib/proto", // when env vars change the build hash gets updated "$builddir/build.ninja", // building on Windows requires python3.lib diff --git a/proto/anki/backend.proto b/proto/anki/backend.proto index 8db0550a8..4de1f0d02 100644 --- a/proto/anki/backend.proto +++ b/proto/anki/backend.proto @@ -9,31 +9,6 @@ package anki.backend; import "anki/links.proto"; -/// while the protobuf descriptors expose the order services are defined in, -/// that information is not available in prost, so we define an enum to make -/// sure all clients agree on the service index -enum ServiceIndex { - SERVICE_INDEX_SCHEDULER = 0; - SERVICE_INDEX_DECKS = 1; - SERVICE_INDEX_NOTES = 2; - SERVICE_INDEX_SYNC = 3; - SERVICE_INDEX_NOTETYPES = 4; - SERVICE_INDEX_CONFIG = 5; - SERVICE_INDEX_CARD_RENDERING = 6; - SERVICE_INDEX_DECK_CONFIG = 7; - SERVICE_INDEX_TAGS = 8; - SERVICE_INDEX_SEARCH = 9; - SERVICE_INDEX_STATS = 10; - SERVICE_INDEX_MEDIA = 11; - SERVICE_INDEX_I18N = 12; - SERVICE_INDEX_COLLECTION = 13; - SERVICE_INDEX_CARDS = 14; - SERVICE_INDEX_LINKS = 15; - SERVICE_INDEX_IMPORT_EXPORT = 16; - SERVICE_INDEX_ANKIDROID = 17; - SERVICE_INDEX_IMAGE_OCCLUSION = 18; -} - message BackendInit { repeated string preferred_langs = 1; string locale_folder_path = 2; diff --git a/pylib/anki/_backend.py b/pylib/anki/_backend.py index e62dd5e36..7e5d68221 100644 --- a/pylib/anki/_backend.py +++ b/pylib/anki/_backend.py @@ -125,15 +125,10 @@ class RustBackend(RustBackendGenerated): for k, v in kwargs.items() } - input = i18n_pb2.TranslateStringRequest( - module_index=module_index, - message_index=message_index, - args=args, + return self.translate_string( + module_index=module_index, message_index=message_index, args=args ) - 
output_bytes = self.translate_string_raw(input.SerializeToString()) - return anki.generic_pb2.String.FromString(output_bytes).val - def format_time_span( self, seconds: Any, diff --git a/pylib/tools/genbackend.py b/pylib/tools/genbackend.py deleted file mode 100644 index 2b2d3d0ae..000000000 --- a/pylib/tools/genbackend.py +++ /dev/null @@ -1,251 +0,0 @@ -#!/usr/bin/env python3 -# Copyright: Ankitects Pty Ltd and contributors -# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html - -import re -import sys - -sys.path.append("out/pylib") -sys.path.append("pylib/anki/_vendor") - -import google.protobuf.descriptor -import stringcase - -import anki.backend_pb2 -import anki.card_rendering_pb2 -import anki.cards_pb2 -import anki.collection_pb2 -import anki.config_pb2 -import anki.deckconfig_pb2 -import anki.decks_pb2 -import anki.i18n_pb2 -import anki.image_occlusion_pb2 -import anki.import_export_pb2 -import anki.links_pb2 -import anki.media_pb2 -import anki.notes_pb2 -import anki.notetypes_pb2 -import anki.scheduler_pb2 -import anki.search_pb2 -import anki.stats_pb2 -import anki.sync_pb2 -import anki.tags_pb2 - -TYPE_DOUBLE = 1 -TYPE_FLOAT = 2 -TYPE_INT64 = 3 -TYPE_UINT64 = 4 -TYPE_INT32 = 5 -TYPE_FIXED64 = 6 -TYPE_FIXED32 = 7 -TYPE_BOOL = 8 -TYPE_STRING = 9 -TYPE_GROUP = 10 -TYPE_MESSAGE = 11 -TYPE_BYTES = 12 -TYPE_UINT32 = 13 -TYPE_ENUM = 14 -TYPE_SFIXED32 = 15 -TYPE_SFIXED64 = 16 -TYPE_SINT32 = 17 -TYPE_SINT64 = 18 - -LABEL_OPTIONAL = 1 -LABEL_REQUIRED = 2 -LABEL_REPEATED = 3 - -RAW_ONLY = {"TranslateString"} - - -def python_type(field): - type = python_type_inner(field) - if field.label == LABEL_REPEATED: - type = f"Sequence[{type}]" - return type - - -def python_type_inner(field): - type = field.type - if type == TYPE_BOOL: - return "bool" - elif type in (1, 2): - return "float" - elif type in (3, 4, 5, 6, 7, 13, 15, 16, 17, 18): - return "int" - elif type == TYPE_STRING: - return "str" - elif type == TYPE_BYTES: - return "bytes" - elif 
type == TYPE_MESSAGE: - return fullname(field.message_type.full_name) - elif type == TYPE_ENUM: - return fullname(field.enum_type.full_name) + ".V" - else: - raise Exception(f"unknown type: {type}") - - -def fullname(fullname: str) -> str: - # eg anki.generic.Empty -> anki.generic_pb2.Empty - components = fullname.split(".") - components[1] += "_pb2" - return ".".join(components) - - -# get_deck_i_d -> get_deck_id etc -def fix_snakecase(name): - for fix in "a_v", "i_d": - name = re.sub( - rf"(\w)({fix})(\w)", - lambda m: m.group(1) + m.group(2).replace("_", "") + m.group(3), - name, - ) - return name - - -def get_input_args(input_type): - fields = sorted(input_type.fields, key=lambda x: x.number) - self_star = ["self"] - if len(fields) >= 2: - self_star.append("*") - return ", ".join(self_star + [f"{f.name}: {python_type(f)}" for f in fields]) - - -def get_input_assign(input_type): - fields = sorted(input_type.fields, key=lambda x: x.number) - return ", ".join(f"{f.name}={f.name}" for f in fields) - - -def render_method(service_idx, method_idx, method): - name = fix_snakecase(stringcase.snakecase(method.name)) - input_name = method.input_type.name - - if ( - input_name.endswith("Request") or len(method.input_type.fields) < 2 - ) and not method.input_type.oneofs: - input_params = get_input_args(method.input_type) - input_assign_full = f"message = {fullname(method.input_type.full_name)}({get_input_assign(method.input_type)})" - else: - input_params = f"self, message: {fullname(method.input_type.full_name)}" - input_assign_full = "" - - if ( - len(method.output_type.fields) == 1 - and method.output_type.fields[0].type != TYPE_ENUM - ): - # unwrap single return arg - f = method.output_type.fields[0] - return_type = python_type(f) - single_attribute = f".{f.name}" - else: - return_type = fullname(method.output_type.full_name) - single_attribute = "" - - buf = f"""\ - def {name}_raw(self, message: bytes) -> bytes: - return self._run_command({service_idx}, {method_idx}, 
message) - -""" - - if not method.name in RAW_ONLY: - buf += f"""\ - def {name}({input_params}) -> {return_type}: - {input_assign_full} - raw_bytes = self._run_command({service_idx}, {method_idx}, message.SerializeToString()) - output = {fullname(method.output_type.full_name)}() - output.ParseFromString(raw_bytes) - return output{single_attribute} - -""" - - return buf - - -out: list[str] = [] - - -def render_service( - service: google.protobuf.descriptor.ServiceDescriptor, service_index: int -) -> None: - for method_index, method in enumerate(service.methods): - out.append(render_method(service_index, method_index, method)) - - -service_modules = dict( - I18N=anki.i18n_pb2, - COLLECTION=anki.collection_pb2, - CARDS=anki.cards_pb2, - NOTES=anki.notes_pb2, - DECKS=anki.decks_pb2, - DECK_CONFIG=anki.deckconfig_pb2, - NOTETYPES=anki.notetypes_pb2, - SCHEDULER=anki.scheduler_pb2, - SYNC=anki.sync_pb2, - CONFIG=anki.config_pb2, - SEARCH=anki.search_pb2, - STATS=anki.stats_pb2, - CARD_RENDERING=anki.card_rendering_pb2, - TAGS=anki.tags_pb2, - MEDIA=anki.media_pb2, - LINKS=anki.links_pb2, - IMPORT_EXPORT=anki.import_export_pb2, - IMAGE_OCCLUSION=anki.image_occlusion_pb2, -) - -for service in anki.backend_pb2.ServiceIndex.DESCRIPTOR.values: - # SERVICE_INDEX_TEST -> _TESTSERVICE - base = service.name.replace("SERVICE_INDEX_", "") - service_pkg = service_modules.get(base) - service_var = "_" + base.replace("_", "") + "SERVICE" - if service_var == "_ANKIDROIDSERVICE": - continue - service_obj = getattr(service_pkg, service_var) - service_index = service.number - render_service(service_obj, service_index) - -with open(sys.argv[1], "w", encoding="utf8") as f: - f.write( - '''# Copyright: Ankitects Pty Ltd and contributors -# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -# pylint: skip-file - -from __future__ import annotations - -""" -THIS FILE IS AUTOMATICALLY GENERATED - DO NOT EDIT. 
- -Please do not access methods on the backend directly - they may be changed -or removed at any time. Instead, please use the methods on the collection -instead. Eg, don't use col.backend.all_deck_config(), instead use -col.decks.all_config() -""" - -from typing import * - -import anki -import anki.backend_pb2 -import anki.i18n_pb2 -import anki.cards_pb2 -import anki.collection_pb2 -import anki.decks_pb2 -import anki.deckconfig_pb2 -import anki.links_pb2 -import anki.notes_pb2 -import anki.notetypes_pb2 -import anki.scheduler_pb2 -import anki.sync_pb2 -import anki.config_pb2 -import anki.search_pb2 -import anki.stats_pb2 -import anki.card_rendering_pb2 -import anki.tags_pb2 -import anki.media_pb2 -import anki.import_export_pb2 -import anki.image_occlusion_pb2 - -class RustBackendGenerated: - def _run_command(self, service: int, method: int, input: Any) -> bytes: - raise Exception("not implemented") - -''' - + "\n".join(out) - ) diff --git a/rslib/Cargo.toml b/rslib/Cargo.toml index 15e359dc4..4c7fc20f7 100644 --- a/rslib/Cargo.toml +++ b/rslib/Cargo.toml @@ -1,6 +1,5 @@ [package] name = "anki" -build = "build/main.rs" publish = false description = "Anki's Rust library code" @@ -10,10 +9,6 @@ edition.workspace = true license.workspace = true rust-version.workspace = true -[lib] -name = "anki" -path = "src/lib.rs" - [features] bench = ["criterion"] rustls = ["reqwest/rustls-tls", "reqwest/rustls-tls-native-roots"] @@ -27,7 +22,13 @@ required-features = ["bench"] # After updating anything below, run ../cargo/update_licenses.sh [build-dependencies] +anyhow = "1.0.71" +inflections = "1.1.1" +prost = "0.11.8" prost-build = "0.11.8" +prost-reflect = "0.11.4" +prost-types = "0.11.9" +regex = "1.7.3" which = "4.4.0" [dev-dependencies] @@ -42,6 +43,7 @@ features = ["json", "socks", "stream", "multipart"] [dependencies] anki_i18n = { path = "i18n" } +anki_proto = { path = "proto" } csv = { git = "https://github.com/ankitects/rust-csv.git", rev = 
"1c9d3aab6f79a7d815c69f925a46a4590c115f90" } percent-encoding-iri = { git = "https://github.com/ankitects/rust-url.git", rev = "bb930b8d089f4d30d7d19c12e54e66191de47b88" } diff --git a/rslib/build/main.rs b/rslib/build.rs similarity index 85% rename from rslib/build/main.rs rename to rslib/build.rs index 4ebf8519b..508c85277 100644 --- a/rslib/build/main.rs +++ b/rslib/build.rs @@ -3,11 +3,7 @@ use std::fs; -pub mod protobuf; - fn main() { - protobuf::write_backend_proto_rs(); - println!("cargo:rerun-if-changed=../out/buildhash"); let buildhash = fs::read_to_string("../out/buildhash").unwrap_or_default(); println!("cargo:rustc-env=BUILDHASH={buildhash}") diff --git a/rslib/build/protobuf.rs b/rslib/build/protobuf.rs deleted file mode 100644 index 18f67880e..000000000 --- a/rslib/build/protobuf.rs +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright: Ankitects Pty Ltd and contributors -// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html - -use std::env; -use std::fmt::Write; -use std::path::PathBuf; - -struct CustomGenerator {} - -fn write_method_trait(buf: &mut String, service: &prost_build::Service) { - buf.push_str( - r#" -pub trait Service { - fn run_method(&self, method: u32, input: &[u8]) -> Result> { - match method { -"#, - ); - for (idx, method) in service.methods.iter().enumerate() { - write!( - buf, - concat!(" ", - "{idx} => {{ let input = super::{input_type}::decode(input)?;\n", - "let output = self.{rust_method}(input)?;\n", - "let mut out_bytes = Vec::new(); output.encode(&mut out_bytes)?; Ok(out_bytes) }}, "), - idx = idx, - input_type = method.input_type, - rust_method = method.name - ) - .unwrap(); - } - buf.push_str( - r#" - _ => crate::invalid_input!("invalid command"), - } - } -"#, - ); - - for method in &service.methods { - write!( - buf, - concat!( - " fn {method_name}(&self, input: super::{input_type}) -> ", - "Result;\n" - ), - method_name = method.name, - input_type = method.input_type, - output_type = method.output_type 
- ) - .unwrap(); - } - buf.push_str("}\n"); -} - -impl prost_build::ServiceGenerator for CustomGenerator { - fn generate(&mut self, service: prost_build::Service, buf: &mut String) { - write!( - buf, - "pub mod {name}_service {{ - use prost::Message; - use crate::error::Result; - ", - name = service.name.replace("Service", "").to_ascii_lowercase() - ) - .unwrap(); - write_method_trait(buf, &service); - buf.push('}'); - } -} - -fn service_generator() -> Box { - Box::new(CustomGenerator {}) -} - -pub fn write_backend_proto_rs() { - set_protoc_path(); - let proto_dir = PathBuf::from("../proto"); - - let subfolders = &["anki"]; - let mut paths = vec![]; - for subfolder in subfolders { - for entry in proto_dir.join(subfolder).read_dir().unwrap() { - let entry = entry.unwrap(); - let path = entry.path(); - if path - .file_name() - .unwrap() - .to_str() - .unwrap() - .ends_with(".proto") - { - println!("cargo:rerun-if-changed={}", path.to_str().unwrap()); - paths.push(path); - } - } - } - - let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); - let mut config = prost_build::Config::new(); - config - .out_dir(&out_dir) - .service_generator(service_generator()) - .type_attribute( - "Deck.Filtered.SearchTerm.Order", - "#[derive(strum::EnumIter)]", - ) - .type_attribute( - "Deck.Normal.DayLimit", - "#[derive(Copy, Eq, serde_derive::Deserialize, serde_derive::Serialize)]", - ) - .type_attribute("HelpPageLinkRequest.HelpPage", "#[derive(strum::EnumIter)]") - .type_attribute("CsvMetadata.Delimiter", "#[derive(strum::EnumIter)]") - .type_attribute( - "Preferences.BackupLimits", - "#[derive(Copy, serde_derive::Deserialize, serde_derive::Serialize)]", - ) - .type_attribute( - "CsvMetadata.DupeResolution", - "#[derive(serde_derive::Deserialize, serde_derive::Serialize)]", - ) - .type_attribute( - "CsvMetadata.MatchScope", - "#[derive(serde_derive::Deserialize, serde_derive::Serialize)]", - ) - .compile_protos(paths.as_slice(), &[proto_dir]) - .unwrap(); -} - -/// Set PROTOC to 
the custom path provided by PROTOC_BINARY, or add .exe to -/// the standard path if on Windows. -fn set_protoc_path() { - if let Ok(custom_protoc) = env::var("PROTOC_BINARY") { - env::set_var("PROTOC", custom_protoc); - } else if let Ok(bundled_protoc) = env::var("PROTOC") { - if cfg!(windows) && !bundled_protoc.ends_with(".exe") { - env::set_var("PROTOC", format!("{bundled_protoc}.exe")); - } - } -} diff --git a/rslib/proto/Cargo.toml b/rslib/proto/Cargo.toml new file mode 100644 index 000000000..c34d4c191 --- /dev/null +++ b/rslib/proto/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "anki_proto" +publish = false +description = "Anki's Rust library protobuf code" + +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true + +[build-dependencies] +anyhow = "1.0.71" +inflections = "1.1.1" +prost-build = "0.11.9" +prost-reflect = "0.11.4" +prost-types = "0.11.9" diff --git a/rslib/proto/build.rs b/rslib/proto/build.rs new file mode 100644 index 000000000..c5cb66abe --- /dev/null +++ b/rslib/proto/build.rs @@ -0,0 +1,19 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +pub mod python; +pub mod rust; + +use std::env; +use std::path::PathBuf; + +use anyhow::Context; +use anyhow::Result; + +fn main() -> Result<()> { + let descriptors_path = PathBuf::from(env::var("DESCRIPTORS_BIN").context("DESCRIPTORS_BIN")?); + + let pool = rust::write_backend_proto_rs(&descriptors_path)?; + python::write_python_interface(&pool)?; + Ok(()) +} diff --git a/rslib/proto/python.rs b/rslib/proto/python.rs new file mode 100644 index 000000000..5b9ab966d --- /dev/null +++ b/rslib/proto/python.rs @@ -0,0 +1,239 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::fs::File; +use std::io::BufWriter; +use std::io::Write; +use std::path::Path; + 
+use anyhow::Context; +use anyhow::Result; +use inflections::Inflect; +use prost_reflect::DescriptorPool; +use prost_reflect::FieldDescriptor; +use prost_reflect::Kind; +use prost_reflect::MessageDescriptor; +use prost_reflect::MethodDescriptor; +use prost_reflect::ServiceDescriptor; + +pub(crate) fn write_python_interface(pool: &DescriptorPool) -> Result<()> { + let output_path = Path::new("../../out/pylib/anki/_backend_generated.py"); + let mut out = BufWriter::new( + File::create(output_path).with_context(|| format!("opening {output_path:?}"))?, + ); + write_header(&mut out)?; + + for service in pool.services() { + if service.name() == "AnkidroidService" { + continue; + } + for method in service.methods() { + render_method(&service, &method, &mut out); + } + } + + Ok(()) +} + +/// Generates text like the following: +/// +/// def get_field_names_raw(self, message: bytes) -> bytes: +/// return self._run_command(7, 16, message) +/// +/// def get_field_names(self, ntid: int) -> Sequence[str]: +/// message = anki.notetypes_pb2.NotetypeId(ntid=ntid) +/// raw_bytes = self._run_command(7, 16, message.SerializeToString()) +/// output = anki.generic_pb2.StringList() +/// output.ParseFromString(raw_bytes) +/// return output.vals +fn render_method(service: &ServiceDescriptor, method: &MethodDescriptor, out: &mut impl Write) { + let method_name = method.name().to_snake_case(); + let input = method.input(); + let output = method.output(); + let service_idx = service.index(); + let method_idx = method.index(); + + // raw bytes + write!( + out, + r#" def {method_name}_raw(self, message: bytes) -> bytes: + return self._run_command({service_idx}, {method_idx}, message) + +"# + ) + .unwrap(); + + // (possibly destructured) message + let (input_params, input_assign) = maybe_destructured_input(&input); + let output_constructor = full_name_to_python(output.full_name()); + let (output_msg_or_single_field, output_type) = maybe_destructured_output(&output); + write!( + out, + r#" def 
{method_name}({input_params}) -> {output_type}: + {input_assign} + raw_bytes = self._run_command({service_idx}, {method_idx}, message.SerializeToString()) + output = {output_constructor}() + output.ParseFromString(raw_bytes) + return {output_msg_or_single_field} + +"# + ) + .unwrap(); +} + +/// The input message is destructured into separate method arguments when +/// it has no oneofs (presumably including optional fields) and either: +/// - its name ends in Request, or +/// - it has fewer than two fields. +/// Otherwise destructuring is skipped, and the method takes the input +/// message directly. Returns (params_line, assignment_lines) +fn maybe_destructured_input(input: &MessageDescriptor) -> (String, String) { + if (input.name().ends_with("Request") || input.fields().len() < 2) + && input.oneofs().next().is_none() + { + // destructure + let method_args = build_method_arguments(input); + let input_type = full_name_to_python(input.full_name()); + let input_message_args = build_input_message_arguments(input); + let assignment = format!("message = {input_type}({input_message_args})",); + (method_args, assignment) + } else { + // no destructure + let params = format!("self, message: {}", full_name_to_python(input.full_name())); + let assignment = String::new(); + (params, assignment) + } +} + +/// e.g. "self, *, note_ids: Sequence[int], new_fields: Sequence[int]" +fn build_method_arguments(input: &MessageDescriptor) -> String { + let fields = input.fields(); + let mut args = vec!["self".to_string()]; + if fields.len() >= 2 { + args.push("*".to_string()); + } + for field in fields { + let arg = format!("{}: {}", field.name(), python_type(&field)); + args.push(arg); + } + args.join(", ") +} + +/// e.g. 
"note_ids=note_ids, new_fields=new_fields" +fn build_input_message_arguments(input: &MessageDescriptor) -> String { + input + .fields() + .map(|field| { + let name = field.name(); + format!("{name}={name}") + }) + .collect::<Vec<_>>() + .join(", ") +} + +// If output type has a single field and is not an enum, we return its single +// field value directly. Returns (expr, type), where expr is 'output' or +// 'output.<field_name>'. +fn maybe_destructured_output(output: &MessageDescriptor) -> (String, String) { + let first_field = output.fields().next(); + if output.fields().len() == 1 && !matches!(first_field.as_ref().unwrap().kind(), Kind::Enum(_)) + { + let field = first_field.unwrap(); + (format!("output.{}", field.name()), python_type(&field)) + } else { + ("output".into(), full_name_to_python(output.full_name())) + } +} + +/// e.g. uint32 -> int; repeated bool -> Sequence[bool] +fn python_type(field: &FieldDescriptor) -> String { + let kind = match field.kind() { + Kind::Int32 + | Kind::Int64 + | Kind::Uint32 + | Kind::Uint64 + | Kind::Sint32 + | Kind::Sint64 + | Kind::Fixed32 + | Kind::Fixed64 + | Kind::Sfixed32 + | Kind::Sfixed64 => "int".into(), + Kind::Float | Kind::Double => "float".into(), + Kind::Bool => "bool".into(), + Kind::String => "str".into(), + Kind::Bytes => "bytes".into(), + Kind::Message(msg) => full_name_to_python(msg.full_name()), + Kind::Enum(en) => format!("{}.V", full_name_to_python(en.full_name())), + }; + if field.is_list() { + format!("Sequence[{}]", kind) + } else if field.is_map() { + let map_kind = field.kind(); + let map_kind = map_kind.as_message().unwrap(); + let map_kv: Vec<_> = map_kind.fields().map(|f| python_type(&f)).collect(); + format!("Mapping[{}, {}]", map_kv[0], map_kv[1]) + } else { + kind + } +} + +// e.g. 
anki.import_export.ImportResponse -> +// anki.import_export_pb2.ImportResponse +fn full_name_to_python(name: &str) -> String { + let mut name = name.splitn(3, '.'); + format!( + "{}.{}_pb2.{}", + name.next().unwrap(), + name.next().unwrap(), + name.next().unwrap() + ) +} + +fn write_header(out: &mut impl Write) -> Result<()> { + out.write_all( + br#"# Copyright: Ankitects Pty Ltd and contributors +# License: GNU AGPL, version 3 or later; https://www.gnu.org/licenses/agpl.html +# pylint: skip-file + +from __future__ import annotations + +""" +THIS FILE IS AUTOMATICALLY GENERATED - DO NOT EDIT. + +Please do not access methods on the backend directly - they may be changed +or removed at any time. Instead, please use the methods on the collection +instead. Eg, don't use col.backend.all_deck_config(), instead use +col.decks.all_config() +""" + +from typing import * + +import anki +import anki.backend_pb2 +import anki.card_rendering_pb2 +import anki.cards_pb2 +import anki.collection_pb2 +import anki.config_pb2 +import anki.deckconfig_pb2 +import anki.decks_pb2 +import anki.i18n_pb2 +import anki.image_occlusion_pb2 +import anki.import_export_pb2 +import anki.links_pb2 +import anki.media_pb2 +import anki.notes_pb2 +import anki.notetypes_pb2 +import anki.scheduler_pb2 +import anki.search_pb2 +import anki.stats_pb2 +import anki.sync_pb2 +import anki.tags_pb2 + +class RustBackendGenerated: + def _run_command(self, service: int, method: int, input: Any) -> bytes: + raise Exception("not implemented") + +"#, + )?; + Ok(()) +} diff --git a/rslib/proto/rust.rs b/rslib/proto/rust.rs new file mode 100644 index 000000000..bdd3b4c8c --- /dev/null +++ b/rslib/proto/rust.rs @@ -0,0 +1,186 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::env; +use std::fmt::Write; +use std::fs; +use std::path::Path; +use std::path::PathBuf; + +use anyhow::Context; +use anyhow::Result; +use 
prost_build::ServiceGenerator; +use prost_reflect::DescriptorPool; + +pub fn write_backend_proto_rs(descriptors_path: &Path) -> Result<DescriptorPool> { + set_protoc_path(); + let proto_dir = PathBuf::from("../../proto"); + let paths = gather_proto_paths(&proto_dir)?; + let out_dir = Path::new("../../out/rslib/proto"); + fs::create_dir_all(out_dir).with_context(|| format!("{:?}", out_dir))?; + + prost_build::Config::new() + .out_dir(out_dir) + .file_descriptor_set_path(descriptors_path) + .service_generator(RustCodeGenerator::boxed()) + .type_attribute( + "Deck.Filtered.SearchTerm.Order", + "#[derive(strum::EnumIter)]", + ) + .type_attribute( + "Deck.Normal.DayLimit", + "#[derive(Copy, Eq, serde_derive::Deserialize, serde_derive::Serialize)]", + ) + .type_attribute("HelpPageLinkRequest.HelpPage", "#[derive(strum::EnumIter)]") + .type_attribute("CsvMetadata.Delimiter", "#[derive(strum::EnumIter)]") + .type_attribute( + "Preferences.BackupLimits", + "#[derive(Copy, serde_derive::Deserialize, serde_derive::Serialize)]", + ) + .type_attribute( + "CsvMetadata.DupeResolution", + "#[derive(serde_derive::Deserialize, serde_derive::Serialize)]", + ) + .type_attribute( + "CsvMetadata.MatchScope", + "#[derive(serde_derive::Deserialize, serde_derive::Serialize)]", + ) + .compile_protos(paths.as_slice(), &[proto_dir]) + .context("prost build")?; + + write_service_index(out_dir, descriptors_path) +} + +fn write_service_index(out_dir: &Path, descriptors_path: &Path) -> Result<DescriptorPool> { + let descriptors = fs::read(descriptors_path) + .with_context(|| format!("failed to read {descriptors_path:?}"))?; + let pool = + DescriptorPool::decode(descriptors.as_ref()).context("unable to decode descriptors")?; + let mut buf = String::new(); + + writeln!( + buf, + "#[derive(num_enum::TryFromPrimitive)] +#[repr(u32)] +pub enum ServiceIndex {{" + ) + .unwrap(); + for service in pool.services() { + writeln!( + buf, + " {} = {},", + service.name().replace("Service", ""), + service.index() + ) + .unwrap(); + } + 
writeln!(buf, "}}").unwrap(); + + fs::write(out_dir.join("service_index.rs"), buf).context("failed to write service index")?; + + Ok(pool) +} + +fn gather_proto_paths(proto_dir: &Path) -> Result> { + let subfolders = &["anki"]; + let mut paths = vec![]; + for subfolder in subfolders { + for entry in proto_dir.join(subfolder).read_dir().unwrap() { + let entry = entry.unwrap(); + let path = entry.path(); + if path + .file_name() + .unwrap() + .to_str() + .unwrap() + .ends_with(".proto") + { + println!("cargo:rerun-if-changed={}", path.to_str().unwrap()); + paths.push(path); + } + } + } + paths.sort(); + Ok(paths) +} + +struct RustCodeGenerator {} + +impl RustCodeGenerator { + fn boxed() -> Box { + Box::new(Self {}) + } + + fn write_method_trait(&mut self, buf: &mut String, service: &prost_build::Service) { + buf.push_str( + r#" +pub trait Service { + fn run_method(&self, method: u32, input: &[u8]) -> Result> { + match method { +"#, + ); + for (idx, method) in service.methods.iter().enumerate() { + write!( + buf, + concat!(" ", + "{idx} => {{ let input = super::{input_type}::decode(input)?;\n", + "let output = self.{rust_method}(input)?;\n", + "let mut out_bytes = Vec::new(); output.encode(&mut out_bytes)?; Ok(out_bytes) }}, "), + idx = idx, + input_type = method.input_type, + rust_method = method.name + ) + .unwrap(); + } + buf.push_str( + r#" + _ => crate::invalid_input!("invalid command"), + } + } +"#, + ); + + for method in &service.methods { + write!( + buf, + concat!( + " fn {method_name}(&self, input: super::{input_type}) -> ", + "Result;\n" + ), + method_name = method.name, + input_type = method.input_type, + output_type = method.output_type + ) + .unwrap(); + } + buf.push_str("}\n"); + } +} + +impl ServiceGenerator for RustCodeGenerator { + fn generate(&mut self, service: prost_build::Service, buf: &mut String) { + write!( + buf, + "pub mod {name}_service {{ + use prost::Message; + use crate::error::Result; + ", + name = service.name.replace("Service", 
"").to_ascii_lowercase() + ) + .unwrap(); + self.write_method_trait(buf, &service); + buf.push('}'); + } +} + +/// Set PROTOC to the custom path provided by PROTOC_BINARY, or add .exe to +/// the standard path if on Windows. +fn set_protoc_path() { + if let Ok(custom_protoc) = env::var("PROTOC_BINARY") { + env::set_var("PROTOC", custom_protoc); + } else if let Ok(bundled_protoc) = env::var("PROTOC") { + if cfg!(windows) && !bundled_protoc.ends_with(".exe") { + env::set_var("PROTOC", format!("{bundled_protoc}.exe")); + } + } +} diff --git a/rslib/proto/src/lib.rs b/rslib/proto/src/lib.rs new file mode 100644 index 000000000..cf499d047 --- /dev/null +++ b/rslib/proto/src/lib.rs @@ -0,0 +1,2 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html diff --git a/rslib/src/backend/mod.rs b/rslib/src/backend/mod.rs index ba7387d4e..d715862bd 100644 --- a/rslib/src/backend/mod.rs +++ b/rslib/src/backend/mod.rs @@ -64,7 +64,7 @@ use self::sync::SyncState; use self::tags::TagsService; use crate::backend::dbproxy::db_command_bytes; use crate::pb; -use crate::pb::backend::ServiceIndex; +use crate::pb::ServiceIndex; use crate::prelude::*; pub struct Backend { @@ -121,7 +121,7 @@ impl Backend { method: u32, input: &[u8], ) -> result::Result, Vec> { - ServiceIndex::from_i32(service as i32) + ServiceIndex::try_from(service) .or_invalid("invalid service") .and_then(|service| match service { ServiceIndex::Ankidroid => AnkidroidService::run_method(self, method, input), diff --git a/rslib/src/pb.rs b/rslib/src/pb.rs index 017e689ca..ce84b5e9d 100644 --- a/rslib/src/pb.rs +++ b/rslib/src/pb.rs @@ -4,11 +4,13 @@ macro_rules! 
protobuf { ($ident:ident, $name:literal) => { pub mod $ident { - include!(concat!(env!("OUT_DIR"), "/anki.", $name, ".rs")); + include!(concat!("../../out/rslib/proto/anki.", $name, ".rs")); } }; } +include!("../../out/rslib/proto/service_index.rs"); + protobuf!(ankidroid, "ankidroid"); protobuf!(backend, "backend"); protobuf!(card_rendering, "card_rendering");