start reworking protobuf handling

Will allow us to cut down on boilerplate by automatically generating
code from RPC service definitions
This commit is contained in:
Damien Elmes 2020-05-22 20:02:10 +10:00
parent 58a243aa6c
commit 9c20d9a02b
9 changed files with 363 additions and 105 deletions

1
proto/.gitignore vendored
View file

@ -1 +1,2 @@
fluent.proto

View file

@ -14,6 +14,11 @@ message OptionalUInt32 {
uint32 val = 1;
}
service BackendService {
rpc RenderExistingCard (RenderExistingCardIn) returns (RenderCardOut);
rpc RenderUncommittedCard (RenderUncommittedCardIn) returns (RenderCardOut);
}
// Protobuf stored in .anki2 files
// These should be moved to a separate file in the future
///////////////////////////////////////////////////////////

View file

@ -39,12 +39,14 @@ all: check
python -m pip install -r requirements.dev
@touch $@
PROTODEPS := $(wildcard ../proto/*.proto)
PROTODEPS := ../proto/backend.proto ../proto/fluent.proto
.build/py-proto: .build/dev-deps $(PROTODEPS)
protoc --proto_path=../proto --python_out=anki --mypy_out=anki $(PROTODEPS)
perl -i'' -pe 's/from fluent_pb2/from anki.fluent_pb2/' anki/backend_pb2.pyi
perl -i'' -pe 's/import fluent_pb2/import anki.fluent_pb2/' anki/backend_pb2.py
python tools/genbackend.py
python -m black anki/rsbackend.py
@touch $@
.build/hooks: tools/genhooks.py tools/hookslib.py
@ -52,7 +54,7 @@ PROTODEPS := $(wildcard ../proto/*.proto)
python -m black anki/hooks.py
@touch $@
BUILD_STEPS := .build/vernum .build/run-deps .build/dev-deps .build/py-proto anki/buildinfo.py .build/hooks
BUILD_STEPS := .build/vernum .build/run-deps .build/dev-deps anki/buildinfo.py .build/py-proto .build/hooks
# Checking
######################

View file

@ -65,6 +65,10 @@ except:
loads = json.loads
to_json_bytes = orjson.dumps
from_json_bytes = orjson.loads
class Interrupted(Exception):
pass
@ -161,22 +165,6 @@ def av_tag_to_native(tag: pb.AVTag) -> AVTag:
)
@dataclass
class TemplateReplacement:
field_name: str
current_text: str
filters: List[str]
TemplateReplacementList = List[Union[str, TemplateReplacement]]
@dataclass
class PartiallyRenderedCard:
qnodes: TemplateReplacementList
anodes: TemplateReplacementList
MediaSyncProgress = pb.MediaSyncProgress
MediaCheckOutput = pb.MediaCheckOut
@ -207,24 +195,6 @@ class Progress:
val: Union[MediaSyncProgress, str]
def proto_replacement_list_to_native(
nodes: List[pb.RenderedTemplateNode],
) -> TemplateReplacementList:
results: TemplateReplacementList = []
for node in nodes:
if node.WhichOneof("value") == "text":
results.append(node.text)
else:
results.append(
TemplateReplacement(
field_name=node.replacement.field_name,
current_text=node.replacement.current_text,
filters=list(node.replacement.filters),
)
)
return results
def proto_progress_to_native(progress: pb.Progress) -> Progress:
kind = progress.WhichOneof("value")
if kind == "media_sync":
@ -302,40 +272,6 @@ class RustBackend:
pb.BackendInput(sched_timing_today=pb.Empty())
).sched_timing_today
def render_existing_card(self, cid: int, browser: bool) -> PartiallyRenderedCard:
out = self._run_command(
pb.BackendInput(
render_existing_card=pb.RenderExistingCardIn(
card_id=cid, browser=browser,
)
)
).render_existing_card
qnodes = proto_replacement_list_to_native(out.question_nodes) # type: ignore
anodes = proto_replacement_list_to_native(out.answer_nodes) # type: ignore
return PartiallyRenderedCard(qnodes, anodes)
def render_uncommitted_card(
self, note: BackendNote, card_ord: int, template: Dict, fill_empty: bool
) -> PartiallyRenderedCard:
template_json = orjson.dumps(template)
out = self._run_command(
pb.BackendInput(
render_uncommitted_card=pb.RenderUncommittedCardIn(
note=note,
template=template_json,
card_ord=card_ord,
fill_empty=fill_empty,
)
)
).render_uncommitted_card
qnodes = proto_replacement_list_to_native(out.question_nodes) # type: ignore
anodes = proto_replacement_list_to_native(out.answer_nodes) # type: ignore
return PartiallyRenderedCard(qnodes, anodes)
def local_minutes_west(self, stamp: int) -> int:
return self._run_command(
pb.BackendInput(local_minutes_west=stamp)
@ -830,6 +766,39 @@ class RustBackend:
).cloze_numbers_in_note.numbers
)
def _run_command2(self, method: int, input: Any) -> bytes:
input_bytes = input.SerializeToString()
try:
return self._backend.command2(method, input_bytes)
except Exception as e:
err_bytes = bytes(e.args[0])
err = pb.BackendError()
err.ParseFromString(err_bytes)
raise proto_exception_to_native(err)
# The code in this section is automatically generated - any edits you make
# will be lost.
# @@AUTOGEN@@
def render_existing_card(self, card_id: int, browser: bool) -> pb.RenderCardOut:
input = pb.RenderExistingCardIn(card_id=card_id, browser=browser)
output = pb.RenderCardOut()
output.ParseFromString(self._run_command2(1, input))
return output
def render_uncommitted_card(
self, note: pb.Note, card_ord: int, template: bytes, fill_empty: bool
) -> pb.RenderCardOut:
input = pb.RenderUncommittedCardIn(
note=note, card_ord=card_ord, template=template, fill_empty=fill_empty
)
output = pb.RenderCardOut()
output.ParseFromString(self._run_command2(2, input))
return output
# @@AUTOGEN@@
def translate_string_in(
key: TR, **kwargs: Union[str, int, float]

View file

@ -29,7 +29,7 @@ template_legacy.py file, using the legacy addHook() system.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import anki
from anki import hooks
@ -37,7 +37,7 @@ from anki.cards import Card
from anki.decks import DeckManager
from anki.models import NoteType
from anki.notes import Note
from anki.rsbackend import PartiallyRenderedCard, TemplateReplacementList
from anki.rsbackend import pb, to_json_bytes
from anki.sound import AVTag
CARD_BLANK_HELP = (
@ -45,6 +45,47 @@ CARD_BLANK_HELP = (
)
@dataclass
class TemplateReplacement:
field_name: str
current_text: str
filters: List[str]
TemplateReplacementList = List[Union[str, TemplateReplacement]]
@dataclass
class PartiallyRenderedCard:
qnodes: TemplateReplacementList
anodes: TemplateReplacementList
@classmethod
def from_proto(cls, out: pb.RenderCardOut) -> PartiallyRenderedCard:
qnodes = cls.nodes_from_proto(out.question_nodes)
anodes = cls.nodes_from_proto(out.answer_nodes)
return PartiallyRenderedCard(qnodes, anodes)
@staticmethod
def nodes_from_proto(
nodes: Sequence[pb.RenderedTemplateNode],
) -> TemplateReplacementList:
results: TemplateReplacementList = []
for node in nodes:
if node.WhichOneof("value") == "text":
results.append(node.text)
else:
results.append(
TemplateReplacement(
field_name=node.replacement.field_name,
current_text=node.replacement.current_text,
filters=list(node.replacement.filters),
)
)
return results
class TemplateRenderContext:
"""Holds information for the duration of one card render.
@ -177,15 +218,16 @@ class TemplateRenderContext:
def _partially_render(self) -> PartiallyRenderedCard:
if self._template:
# card layout screen
return self._col.backend.render_uncommitted_card(
out = self._col.backend.render_uncommitted_card(
self._note.to_backend_note(),
self._card.ord,
self._template,
to_json_bytes(self._template),
self._fill_empty,
)
else:
# existing card (eg study mode)
return self._col.backend.render_existing_card(self._card.id, self._browser)
out = self._col.backend.render_existing_card(self._card.id, self._browser)
return PartiallyRenderedCard.from_proto(out)
@dataclass

107
pylib/tools/genbackend.py Executable file
View file

@ -0,0 +1,107 @@
#!/usr/bin/env python3
# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
import re
from anki import backend_pb2 as pb
import stringcase
TYPE_DOUBLE = 1
TYPE_FLOAT = 2
TYPE_INT64 = 3
TYPE_UINT64 = 4
TYPE_INT32 = 5
TYPE_FIXED64 = 6
TYPE_FIXED32 = 7
TYPE_BOOL = 8
TYPE_STRING = 9
TYPE_GROUP = 10
TYPE_MESSAGE = 11
TYPE_BYTES = 12
TYPE_UINT32 = 13
TYPE_ENUM = 14
TYPE_SFIXED32 = 15
TYPE_SFIXED64 = 16
TYPE_SINT32 = 17
TYPE_SINT64 = 18
LABEL_OPTIONAL = 1
LABEL_REQUIRED = 2
LABEL_REPEATED = 3
def python_type(field):
type = python_type_inner(field)
if field.label == LABEL_REPEATED:
type = f"List[{type}]"
return type
def python_type_inner(field):
type = field.type
if type == TYPE_BOOL:
return "bool"
elif type in (1, 2):
return "float"
elif type in (3, 4, 5, 6, 7, 13, 15, 16, 17, 18):
return "int"
elif type == TYPE_STRING:
return "str"
elif type == TYPE_BYTES:
return "bytes"
elif type == 11:
return "pb." + field.message_type.name
else:
raise Exception(f"unknown type: {type}")
def get_input_args(msg):
fields = sorted(msg.fields, key=lambda x: x.number)
return ", ".join(["self"] + [f"{f.name}: {python_type(f)}" for f in fields])
def get_input_assign(msg):
fields = sorted(msg.fields, key=lambda x: x.number)
return ", ".join(f"{f.name}={f.name}" for f in fields)
def render_method(method, idx):
input_args = get_input_args(method.input_type)
input_assign = get_input_assign(method.input_type)
name = stringcase.snakecase(method.name)
if len(method.output_type.fields) == 1:
# unwrap single return arg
f = method.output_type.fields[0]
single_field = f".{f.name}"
return_type = python_type(f)
else:
single_field = ""
return_type = f"pb.{method.output_type.name}"
return f"""\
def {name}({input_args}) -> {return_type}:
input = pb.{method.input_type.name}({input_assign})
output = pb.{method.output_type.name}()
output.ParseFromString(self._run_command2({idx+1}, input))
return output{single_field}
"""
out = []
for idx, method in enumerate(pb._BACKENDSERVICE.methods):
out.append(render_method(method, idx))
out = "\n".join(out)
path = "anki/rsbackend.py"
with open(path) as file:
orig = file.read()
new = re.sub(
"(?s)# @@AUTOGEN@@.*?# @@AUTOGEN@@\n",
f"# @@AUTOGEN@@\n\n{out}\n # @@AUTOGEN@@\n",
orig,
)
with open(path, "wb") as file:
file.write(new.encode("utf8"))

View file

@ -1,3 +1,4 @@
use std::fmt::Write;
use std::fs;
use std::path::Path;
@ -91,6 +92,82 @@ const FLUENT_KEYS: &[&str] = &[
}
}
struct CustomGenerator {}
fn write_method_enum(buf: &mut String, service: &prost_build::Service) {
buf.push_str(
r#"
use num_enum::TryFromPrimitive;
#[derive(PartialEq,TryFromPrimitive)]
#[repr(u32)]
pub enum BackendMethod {
"#,
);
for (idx, method) in service.methods.iter().enumerate() {
write!(buf, " {} = {},\n", method.proto_name, idx + 1).unwrap();
}
buf.push_str("}\n\n");
}
fn write_method_trait(buf: &mut String, service: &prost_build::Service) {
buf.push_str(
r#"
use prost::Message;
pub type BackendResult<T> = std::result::Result<T, crate::err::AnkiError>;
pub trait BackendService {
fn run_command_bytes2_inner(&mut self, method: u32, input: &[u8]) -> std::result::Result<Vec<u8>, crate::err::AnkiError> {
match method {
"#,
);
for (idx, method) in service.methods.iter().enumerate() {
write!(
buf,
concat!(" ",
"{idx} => {{ let input = {input_type}::decode(input)?;\n",
"let output = self.{rust_method}(input)?;\n",
"let mut out_bytes = Vec::new(); output.encode(&mut out_bytes)?; Ok(out_bytes) }}, "),
idx = idx + 1,
input_type = method.input_type,
rust_method = method.name
)
.unwrap();
}
buf.push_str(
r#"
_ => Err(crate::err::AnkiError::invalid_input("invalid command")),
}
}
"#,
);
for method in &service.methods {
write!(
buf,
concat!(
" fn {method_name}(&mut self, input: {input_type}) -> ",
"BackendResult<{output_type}>;\n"
),
method_name = method.name,
input_type = method.input_type,
output_type = method.output_type
)
.unwrap();
}
buf.push_str("}\n");
}
impl prost_build::ServiceGenerator for CustomGenerator {
fn generate(&mut self, service: prost_build::Service, buf: &mut String) {
write_method_enum(buf, &service);
write_method_trait(buf, &service);
}
}
fn service_generator() -> Box<dyn prost_build::ServiceGenerator> {
Box::new(CustomGenerator {})
}
fn main() -> std::io::Result<()> {
// write template.ftl
let mut buf = String::new();
@ -126,7 +203,12 @@ fn main() -> std::io::Result<()> {
// we avoid default OUT_DIR for now, as it breaks code completion
std::env::set_var("OUT_DIR", "src");
println!("cargo:rerun-if-changed=../proto/backend.proto");
prost_build::compile_protos(&["../proto/backend.proto"], &["../proto"]).unwrap();
let mut config = prost_build::Config::new();
config.service_generator(service_generator());
config
.compile_protos(&["../proto/backend.proto"], &["../proto"])
.unwrap();
// write the other language ftl files
let mut ftl_lang_dirs = vec!["./ftl/repo/core".to_string()];

View file

@ -1,11 +1,14 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
pub use crate::backend_proto::BackendMethod;
use crate::{
backend::dbproxy::db_command_bytes,
backend_proto as pb,
backend_proto::builtin_search_order::BuiltinSortKind,
backend_proto::{AddOrUpdateDeckConfigIn, Empty, RenderedTemplateReplacement, SyncMediaIn},
backend_proto::{
AddOrUpdateDeckConfigIn, BackendResult, Empty, RenderedTemplateReplacement, SyncMediaIn,
},
card::{Card, CardID},
card::{CardQueue, CardType},
cloze::add_cloze_numbers_in_string,
@ -37,13 +40,16 @@ use crate::{
use fluent::FluentValue;
use futures::future::{AbortHandle, Abortable};
use log::error;
use pb::backend_input::Value;
use pb::{backend_input::Value, BackendService};
use prost::Message;
use serde_json::Value as JsonValue;
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::{
result,
sync::{Arc, Mutex},
};
use tokio::runtime::Runtime;
mod dbproxy;
@ -141,6 +147,36 @@ pub fn init_backend(init_msg: &[u8]) -> std::result::Result<Backend, String> {
Ok(Backend::new(i18n, input.server))
}
impl BackendService for Backend {
fn render_existing_card(
&mut self,
input: pb::RenderExistingCardIn,
) -> BackendResult<pb::RenderCardOut> {
self.with_col(|col| {
col.render_existing_card(CardID(input.card_id), input.browser)
.map(Into::into)
})
}
fn render_uncommitted_card(
&mut self,
input: pb::RenderUncommittedCardIn,
) -> BackendResult<pb::RenderCardOut> {
let schema11: CardTemplateSchema11 = serde_json::from_slice(&input.template)?;
let template = schema11.into();
let mut note = input
.note
.ok_or_else(|| AnkiError::invalid_input("missing note"))?
.into();
let ord = input.card_ord as u16;
let fill_empty = input.fill_empty;
self.with_col(|col| {
col.render_uncommitted_card(&mut note, &template, ord, fill_empty)
.map(Into::into)
})
}
}
impl Backend {
pub fn new(i18n: I18n, server: bool) -> Backend {
Backend {
@ -179,6 +215,19 @@ impl Backend {
buf
}
pub fn run_command_bytes2(
&mut self,
method: u32,
input: &[u8],
) -> result::Result<Vec<u8>, Vec<u8>> {
self.run_command_bytes2_inner(method, input).map_err(|err| {
let backend_err = anki_error_to_proto_error(err, &self.i18n);
let mut bytes = Vec::new();
backend_err.encode(&mut bytes).unwrap();
bytes
})
}
/// If collection is open, run the provided closure while holding
/// the mutex.
/// If collection is not open, return an error.
@ -461,31 +510,6 @@ impl Backend {
self.with_col(|col| col.deck_tree(input.include_counts, lim))
}
fn render_existing_card(&self, input: pb::RenderExistingCardIn) -> Result<pb::RenderCardOut> {
self.with_col(|col| {
col.render_existing_card(CardID(input.card_id), input.browser)
.map(Into::into)
})
}
fn render_uncommitted_card(
&self,
input: pb::RenderUncommittedCardIn,
) -> Result<pb::RenderCardOut> {
let schema11: CardTemplateSchema11 = serde_json::from_slice(&input.template)?;
let template = schema11.into();
let mut note = input
.note
.ok_or_else(|| AnkiError::invalid_input("missing note"))?
.into();
let ord = input.card_ord as u16;
let fill_empty = input.fill_empty;
self.with_col(|col| {
col.render_uncommitted_card(&mut note, &template, ord, fill_empty)
.map(Into::into)
})
}
fn extract_av_tags(&self, input: pb::ExtractAvTagsIn) -> pb::ExtractAvTagsOut {
let (text, tags) = extract_av_tags(&input.text, input.question_side);
let pt_tags = tags

View file

@ -1,11 +1,12 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use anki::backend::{init_backend, Backend as RustBackend};
use anki::backend::{init_backend, Backend as RustBackend, BackendMethod};
use pyo3::exceptions::Exception;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3::{create_exception, exceptions, wrap_pyfunction};
use std::convert::TryFrom;
// Regular backend
//////////////////////////////////
@ -16,6 +17,7 @@ struct Backend {
}
create_exception!(ankirspy, DBError, Exception);
create_exception!(ankirspy, BackendError, Exception);
#[pyfunction]
fn buildhash() -> &'static str {
@ -30,6 +32,16 @@ fn open_backend(init_msg: &PyBytes) -> PyResult<Backend> {
}
}
fn want_release_gil(method: u32) -> bool {
if let Ok(method) = BackendMethod::try_from(method) {
match method {
_ => false,
}
} else {
false
}
}
#[pymethods]
impl Backend {
fn command(&mut self, py: Python, input: &PyBytes, release_gil: bool) -> PyObject {
@ -43,6 +55,20 @@ impl Backend {
out_obj.into()
}
fn command2(&mut self, py: Python, method: u32, input: &PyBytes) -> PyResult<PyObject> {
let in_bytes = input.as_bytes();
if want_release_gil(method) {
py.allow_threads(move || self.backend.run_command_bytes2(method, in_bytes))
} else {
self.backend.run_command_bytes2(method, in_bytes)
}
.map(|out_bytes| {
let out_obj = PyBytes::new(py, &out_bytes);
out_obj.into()
})
.map_err(|err_bytes| BackendError::py_err(err_bytes))
}
fn set_progress_callback(&mut self, callback: PyObject) {
if callback.is_none() {
self.backend.set_progress_callback(None);