From 3ce4d5fd3d2aaa8cf222656942f42dcec7bc1913 Mon Sep 17 00:00:00 2001
From: Damien Elmes
Date: Tue, 24 Dec 2019 14:05:15 +1000
Subject: [PATCH] compute template requirements in Rust

on a 100 field template, what took ~75 seconds now takes ~3 seconds.
---
 anki/models.py            |  11 ++
 anki/rsbridge.py          |  36 +++-
 anki/types.py             |   7 +-
 proto/bridge.proto        |  32 ++++
 rs/Cargo.lock             |  78 ++++++++-
 rs/ankirs/Cargo.toml      |   1 +
 rs/ankirs/src/bridge.rs   |  52 ++++++
 rs/ankirs/src/err.rs      |   7 +
 rs/ankirs/src/lib.rs      |   1 +
 rs/ankirs/src/template.rs | 353 ++++++++++++++++++++++++++++++++++++++
 10 files changed, 574 insertions(+), 4 deletions(-)
 create mode 100644 rs/ankirs/src/template.rs

diff --git a/anki/models.py b/anki/models.py
index 89f091d54..ed00cdcae 100644
--- a/anki/models.py
+++ b/anki/models.py
@@ -556,6 +556,9 @@ select id from notes where mid = ?)"""
     ##########################################################################
 
     def _updateRequired(self, m: NoteType) -> None:
+        self._updateRequiredNew(m)
+
+    def _updateRequiredLegacy(self, m: NoteType) -> None:
         if m["type"] == MODEL_CLOZE:
             # nothing to do
             return
@@ -566,6 +569,14 @@ select id from notes where mid = ?)"""
             req.append([t["ord"], ret[0], ret[1]])
         m["req"] = req
 
+    def _updateRequiredNew(self, m: NoteType) -> None:
+        fronts = [t["qfmt"] for t in m["tmpls"]]
+        field_map = {}
+        for (idx, fld) in enumerate(m["flds"]):
+            field_map[fld["name"]] = idx
+        reqs = self.col.rust.template_requirements(fronts, field_map)
+        m["req"] = [list(l) for l in reqs]
+
     def _reqForTemplate(
         self, m: NoteType, flds: List[str], t: Template
     ) -> Tuple[Union[str, List[int]], ...]:
diff --git a/anki/rsbridge.py b/anki/rsbridge.py
index 17132d528..e7e4574c7 100644
--- a/anki/rsbridge.py
+++ b/anki/rsbridge.py
@@ -1,8 +1,12 @@
+from typing import Dict, List
+
 import _ankirs  # pytype: disable=import-error
 import betterproto
 
 from anki.proto import proto as pb
 
+from .types import AllTemplateReqs
+
 
 class BridgeException(Exception):
     def __str__(self) -> str:
@@ -10,10 +14,30 @@ class BridgeException(Exception):
         (kind, obj) = betterproto.which_one_of(err, "value")
         if kind == "invalid_input":
             return f"invalid input: {obj.info}"
+        elif kind == "template_parse":
+            return f"template parse: {obj.info}"
         else:
             return f"unhandled error: {err} {obj}"
 
 
+def proto_template_reqs_to_legacy(
+    reqs: List[pb.TemplateRequirement],
+) -> AllTemplateReqs:
+    legacy_reqs = []
+    for (idx, req) in enumerate(reqs):
+        (kind, val) = betterproto.which_one_of(req, "value")
+        # fixme: sorting is for the unit tests - should check if any
+        # code depends on the order
+        if kind == "any":
+            legacy_reqs.append((idx, "any", sorted(req.any.ords)))
+        elif kind == "all":
+            legacy_reqs.append((idx, "all", sorted(req.all.ords)))
+        else:
+            l: List[int] = []
+            legacy_reqs.append((idx, "none", l))
+    return legacy_reqs
+
+
 class RSBridge:
     def __init__(self):
         self._bridge = _ankirs.Bridge()
@@ -33,5 +57,13 @@ class RSBridge:
         output = self._run_command(input)
         return output.plus_one.num
 
-
-bridge = RSBridge()
+    def template_requirements(
+        self, template_fronts: List[str], field_map: Dict[str, int]
+    ) -> AllTemplateReqs:
+        input = pb.BridgeInput(
+            template_requirements=pb.TemplateRequirementsIn(
+                template_front=template_fronts, field_names_to_ordinals=field_map
+            )
+        )
+        output = self._run_command(input).template_requirements
+        return proto_template_reqs_to_legacy(output.requirements)
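The Python wrapper above only forwards the template fronts plus a field-name-to-ordinal map over the bridge, then converts the proto reply back into the legacy req format of (template ordinal, "all"/"any"/"none", [field ordinals]) tuples. For orientation, a minimal sketch of the same computation done directly against the new Rust API; this is not part of the patch, it assumes the crate is importable as ankirs, and the field names are invented:

    use ankirs::template::{FieldMap, FieldRequirements, ParsedTemplate};

    fn main() {
        // field name -> ordinal, as _updateRequiredNew builds it from m["flds"]
        let field_map: FieldMap = vec![("Front", 0u16), ("Back", 1u16)].into_iter().collect();
        // one question-template front, as collected from m["tmpls"]
        let tmpl = ParsedTemplate::from_text("{{Front}} {{#Back}}{{Back}}{{/Back}}").unwrap();
        match tmpl.requirements(&field_map) {
            FieldRequirements::Any(ords) => println!("any of {:?}", ords),
            FieldRequirements::All(ords) => println!("all of {:?}", ords),
            FieldRequirements::None => println!("never renders"),
        }
    }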
diff --git a/anki/types.py b/anki/types.py
index cdf742296..19395eaa2 100644
--- a/anki/types.py
+++ b/anki/types.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 # Model attributes are stored in a dict keyed by strings. This type alias
 # provides more descriptive function signatures than just 'Dict[str, Any]'
@@ -31,3 +31,8 @@ QAData = Tuple[
     # Corresponds to 'cardFlags' column. TODO: document
     int,
 ]
+
+TemplateRequirementType = str  # Union["all", "any", "none"]
+# template ordinal, type, list of field ordinals
+TemplateRequiredFieldOrds = Tuple[int, TemplateRequirementType, List[int]]
+AllTemplateReqs = List[TemplateRequiredFieldOrds]
diff --git a/proto/bridge.proto b/proto/bridge.proto
index a132e308b..3fed1978d 100644
--- a/proto/bridge.proto
+++ b/proto/bridge.proto
@@ -7,6 +7,7 @@ message Empty {}
 message BridgeInput {
     oneof value {
         PlusOneIn plus_one = 2;
+        TemplateRequirementsIn template_requirements = 3;
     }
 }
 
@@ -14,12 +15,14 @@ message BridgeOutput {
     oneof value {
         BridgeError error = 1;
         PlusOneOut plus_one = 2;
+        TemplateRequirementsOut template_requirements = 3;
     }
 }
 
 message BridgeError {
     oneof value {
         InvalidInputError invalid_input = 1;
+        TemplateParseError template_parse = 2;
     }
 }
 
@@ -34,3 +37,32 @@ message PlusOneIn {
 message PlusOneOut {
     int32 num = 1;
 }
+
+message TemplateParseError {
+    string info = 1;
+}
+
+message TemplateRequirementsIn {
+    repeated string template_front = 1;
+    map<string, uint32> field_names_to_ordinals = 2;
+}
+
+message TemplateRequirementsOut {
+    repeated TemplateRequirement requirements = 1;
+}
+
+message TemplateRequirement {
+    oneof value {
+        TemplateRequirementAll all = 1;
+        TemplateRequirementAny any = 2;
+        Empty none = 3;
+    }
+}
+
+message TemplateRequirementAll {
+    repeated uint32 ords = 1;
+}
+
+message TemplateRequirementAny {
+    repeated uint32 ords = 1;
+}
diff --git a/rs/Cargo.lock b/rs/Cargo.lock
index d63b0c82b..d72b38d5f 100644
--- a/rs/Cargo.lock
+++ b/rs/Cargo.lock
@@ -15,10 +15,20 @@ version = "0.1.0"
 dependencies = [
  "bytes",
  "failure",
+ "nom",
  "prost",
  "prost-build",
 ]
 
+[[package]]
+name = "arrayvec"
+version = "0.4.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cd9fd44efafa8690358b7408d253adf110036b88f55672a933f01d616ad9b1b9"
+dependencies = [
+ "nodrop",
+]
+
 [[package]]
 name = "autocfg"
 version = "0.1.7"
@@ -234,6 +244,19 @@ version = "1.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
 
+[[package]]
+name = "lexical-core"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2304bccb228c4b020f3a4835d247df0a02a7c4686098d4167762cfbbe4c5cb14"
+dependencies = [
+ "arrayvec",
+ "cfg-if",
+ "rustc_version",
+ "ryu",
+ "static_assertions",
+]
+
 [[package]]
 name = "libc"
 version = "0.2.66"
@@ -261,6 +284,23 @@ version = "0.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2eb04b9f127583ed176e163fb9ec6f3e793b87e21deedd5734a69386a18a0151"
 
+[[package]]
+name = "nodrop"
+version = "0.1.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb"
+
+[[package]]
+name = "nom"
+version = "5.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c618b63422da4401283884e6668d39f819a106ef51f5f59b81add00075da35ca"
+dependencies = [
+ "lexical-core",
+ "memchr",
+ "version_check 0.1.5",
+]
+
 [[package]]
 name = "num-traits"
 version = "0.2.10"
@@ -414,7 +454,7 @@ dependencies = [
  "serde_json",
  "spin",
  "unindent",
- "version_check",
+ "version_check 0.9.1",
 ]
 
 [[package]]
@@ -538,12 +578,36 @@ version = "0.1.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783"
 
+[[package]]
+name = "rustc_version"
+version = "0.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
+dependencies = [
+ "semver",
+]
+
 [[package]]
 name = "ryu"
 version = "1.0.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bfa8506c1de11c9c4e4c38863ccbe02a305c8188e85a05a784c9e11e1c3910c8"
 
+[[package]]
+name = "semver"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
+dependencies = [
+ "semver-parser",
+]
+
+[[package]]
+name = "semver-parser"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
+
 [[package]]
 name = "serde"
 version = "1.0.104"
@@ -581,6 +645,12 @@ version = "0.5.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
 
+[[package]]
+name = "static_assertions"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f3eb36b47e512f8f1c9e3d10c2c1965bc992bd9cdb024fa581e2194501c83d3"
+
 [[package]]
 name = "syn"
 version = "0.15.44"
@@ -662,6 +732,12 @@ version = "0.1.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "63f18aa3b0e35fed5a0048f029558b1518095ffe2a0a31fb87c93dece93a4993"
 
+[[package]]
+name = "version_check"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd"
+
 [[package]]
 name = "version_check"
 version = "0.9.1"
diff --git a/rs/ankirs/Cargo.toml b/rs/ankirs/Cargo.toml
index a9c34c962..eb11872f3 100644
--- a/rs/ankirs/Cargo.toml
+++ b/rs/ankirs/Cargo.toml
@@ -5,6 +5,7 @@ edition = "2018"
 authors = ["Ankitects Pty Ltd and contributors"]
 
 [dependencies]
+nom = "5.0.1"
 failure = "0.1.6"
 prost = "0.5.0"
 bytes = "0.4"
diff --git a/rs/ankirs/src/bridge.rs b/rs/ankirs/src/bridge.rs
index df5a8c5a3..549cad884 100644
--- a/rs/ankirs/src/bridge.rs
+++ b/rs/ankirs/src/bridge.rs
@@ -1,7 +1,9 @@
 use crate::err::{AnkiError, Result};
 use crate::proto as pt;
 use crate::proto::bridge_input::Value;
+use crate::template::{FieldMap, FieldRequirements, ParsedTemplate};
 use prost::Message;
+use std::collections::HashSet;
 
 pub struct Bridge {}
 
@@ -17,6 +19,9 @@ impl std::convert::From<AnkiError> for pt::BridgeError {
         use pt::bridge_error::Value as V;
         let value = match err {
             AnkiError::InvalidInput { info } => V::InvalidInput(pt::InvalidInputError { info }),
+            AnkiError::TemplateParseError { info } => {
+                V::TemplateParse(pt::TemplateParseError { info })
+            }
         };
 
         pt::BridgeError { value: Some(value) }
@@ -73,6 +78,9 @@ impl Bridge {
     fn run_command_inner(&self, ival: pt::bridge_input::Value) -> Result<pt::bridge_output::Value> {
         use pt::bridge_output::Value as OValue;
         Ok(match ival {
+            Value::TemplateRequirements(input) => {
+                OValue::TemplateRequirements(self.template_requirements(input)?)
+            }
             Value::PlusOne(input) => OValue::PlusOne(self.plus_one(input)?),
         })
     }
@@ -81,4 +89,48 @@ impl Bridge {
         let num = input.num + 1;
         Ok(pt::PlusOneOut { num })
     }
+
+    fn template_requirements(
+        &self,
+        input: pt::TemplateRequirementsIn,
+    ) -> Result<pt::TemplateRequirementsOut> {
+        let map: FieldMap = input
+            .field_names_to_ordinals
+            .iter()
+            .map(|(name, ord)| (name.as_str(), *ord as u16))
+            .collect();
+        // map each provided template into a requirements list
+        use crate::proto::template_requirement::Value;
+        let all_reqs = input
+            .template_front
+            .into_iter()
+            .map(|template| {
+                if let Ok(tmpl) = ParsedTemplate::from_text(&template) {
+                    // convert the rust structure into a protobuf one
+                    let val = match tmpl.requirements(&map) {
+                        FieldRequirements::Any(ords) => Value::Any(pt::TemplateRequirementAny {
+                            ords: ords_hash_to_set(ords),
+                        }),
+                        FieldRequirements::All(ords) => Value::All(pt::TemplateRequirementAll {
+                            ords: ords_hash_to_set(ords),
+                        }),
+                        FieldRequirements::None => Value::None(pt::Empty {}),
+                    };
+                    Ok(pt::TemplateRequirement { value: Some(val) })
+                } else {
+                    // template parsing failures make card unsatisfiable
+                    Ok(pt::TemplateRequirement {
+                        value: Some(Value::None(pt::Empty {})),
+                    })
+                }
+            })
+            .collect::<Result<Vec<_>>>()?;
+        Ok(pt::TemplateRequirementsOut {
+            requirements: all_reqs,
+        })
+    }
+}
+
+fn ords_hash_to_set(ords: HashSet<u16>) -> Vec<u32> {
+    ords.iter().map(|ord| *ord as u32).collect()
 }
diff --git a/rs/ankirs/src/err.rs b/rs/ankirs/src/err.rs
index 48b16d39e..acf3a2062 100644
--- a/rs/ankirs/src/err.rs
+++ b/rs/ankirs/src/err.rs
@@ -6,10 +6,17 @@ pub type Result<T> = std::result::Result<T, AnkiError>;
 pub enum AnkiError {
     #[fail(display = "invalid input: {}", info)]
     InvalidInput { info: String },
+
+    #[fail(display = "invalid card template: {}", info)]
+    TemplateParseError { info: String },
 }
 
 // error helpers
 impl AnkiError {
+    pub(crate) fn parse<S: Into<String>>(s: S) -> AnkiError {
+        AnkiError::TemplateParseError { info: s.into() }
+    }
+
     pub(crate) fn invalid_input<S: Into<String>>(s: S) -> AnkiError {
         AnkiError::InvalidInput { info: s.into() }
     }
diff --git a/rs/ankirs/src/lib.rs b/rs/ankirs/src/lib.rs
index eb1b6bf97..c05971249 100644
--- a/rs/ankirs/src/lib.rs
+++ b/rs/ankirs/src/lib.rs
@@ -2,3 +2,4 @@ mod proto;
 
 pub mod bridge;
 pub mod err;
+pub mod template;
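With prost, the map<string, uint32> declared in bridge.proto comes out as a HashMap<String, u32>, and each oneof becomes a nested Value enum; that is what run_command_inner() and template_requirements() above match on. A rough sketch of building the new request against the generated types (pt is the alias bridge.rs uses for the generated module); illustrative fragment only, not part of the patch:

    use std::collections::HashMap;

    fn build_requirements_input(
        fronts: Vec<String>,
        field_ords: HashMap<String, u32>,
    ) -> pt::BridgeInput {
        pt::BridgeInput {
            value: Some(pt::bridge_input::Value::TemplateRequirements(
                pt::TemplateRequirementsIn {
                    template_front: fronts,
                    field_names_to_ordinals: field_ords,
                },
            )),
        }
    }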
diff --git a/rs/ankirs/src/template.rs b/rs/ankirs/src/template.rs
new file mode 100644
index 000000000..89f7589e3
--- /dev/null
+++ b/rs/ankirs/src/template.rs
@@ -0,0 +1,353 @@
+use crate::err::{AnkiError, Result};
+use nom;
+use nom::branch::alt;
+use nom::bytes::complete::tag;
+use nom::error::ErrorKind;
+use nom::sequence::delimited;
+use std::collections::{HashMap, HashSet};
+
+pub type FieldMap<'a> = HashMap<&'a str, u16>;
+
+// Lexing
+//----------------------------------------
+
+#[derive(Debug)]
+pub enum Token<'a> {
+    Text(&'a str),
+    Replacement(&'a str),
+    OpenConditional(&'a str),
+    OpenNegated(&'a str),
+    CloseConditional(&'a str),
+}
+
+/// a span of text, terminated by {{, }} or end of string
+pub(crate) fn text_until_handlebars(s: &str) -> nom::IResult<&str, &str> {
+    let end = s.len();
+
+    let limited_end = end
+        .min(s.find("{{").unwrap_or(end))
+        .min(s.find("}}").unwrap_or(end));
+    let (output, input) = s.split_at(limited_end);
+    if output.is_empty() {
+        Err(nom::Err::Error((input, ErrorKind::TakeUntil)))
+    } else {
+        Ok((input, output))
+    }
+}
+
+/// text outside handlebars
+fn text_token(s: &str) -> nom::IResult<&str, Token> {
+    text_until_handlebars(s).map(|(input, output)| (input, Token::Text(output)))
+}
+
+/// text wrapped in handlebars
+fn handle_token(s: &str) -> nom::IResult<&str, Token> {
+    delimited(tag("{{"), text_until_handlebars, tag("}}"))(s)
+        .map(|(input, output)| (input, classify_handle(output)))
+}
+
+/// classify handle based on leading character
+fn classify_handle(s: &str) -> Token {
+    let start = s.trim();
+    if start.len() < 2 {
+        return Token::Replacement(start);
+    }
+    if start.starts_with('#') {
+        Token::OpenConditional(&start[1..].trim_start())
+    } else if start.starts_with('/') {
+        Token::CloseConditional(&start[1..].trim_start())
+    } else if start.starts_with('^') {
+        Token::OpenNegated(&start[1..].trim_start())
+    } else {
+        Token::Replacement(start)
+    }
+}
+
+fn next_token(input: &str) -> nom::IResult<&str, Token> {
+    alt((handle_token, text_token))(input)
+}
+
+fn tokens(template: &str) -> impl Iterator<Item = Result<Token>> {
+    let mut data = template;
+
+    std::iter::from_fn(move || {
+        if data.is_empty() {
+            return None;
+        }
+        match next_token(data) {
+            Ok((i, o)) => {
+                data = i;
+                Some(Ok(o))
+            }
+            Err(e) => Some(Err(AnkiError::parse(format!("{:?}", e)))),
+        }
+    })
+}
+
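The lexer tries handle_token first and falls back to text_token, so a template front is reduced to a flat stream of text and handle tokens before any nesting is considered. A sketch of an extra unit test, not in the patch, that could sit in the test module further below to spell out what tokens() yields:

    #[test]
    fn test_tokens_sketch() {
        let toks: Vec<_> = tokens("a{{B}}{{#C}}x{{/C}}")
            .map(|t| format!("{:?}", t.unwrap()))
            .collect();
        assert_eq!(
            toks,
            vec![
                "Text(\"a\")",
                "Replacement(\"B\")",
                "OpenConditional(\"C\")",
                "Text(\"x\")",
                "CloseConditional(\"C\")"
            ]
        );
    }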
+// Parsing
+//----------------------------------------
+
+#[derive(Debug, PartialEq)]
+enum ParsedNode<'a> {
+    Text(&'a str),
+    Replacement {
+        key: &'a str,
+        filters: Vec<&'a str>,
+    },
+    Conditional {
+        key: &'a str,
+        children: Vec<ParsedNode<'a>>,
+    },
+    NegatedConditional {
+        key: &'a str,
+        children: Vec<ParsedNode<'a>>,
+    },
+}
+
+#[derive(Debug)]
+pub struct ParsedTemplate<'a>(Vec<ParsedNode<'a>>);
+
+impl ParsedTemplate<'_> {
+    pub fn from_text(template: &str) -> Result<ParsedTemplate> {
+        let mut iter = tokens(template);
+        Ok(Self(parse_inner(&mut iter, None)?))
+    }
+}
+
+fn parse_inner<'a, I: Iterator<Item = Result<Token<'a>>>>(
+    iter: &mut I,
+    open_tag: Option<&'a str>,
+) -> Result<Vec<ParsedNode<'a>>> {
+    let mut nodes = vec![];
+
+    while let Some(token) = iter.next() {
+        use Token::*;
+        nodes.push(match token? {
+            Text(t) => ParsedNode::Text(t),
+            Replacement(t) => {
+                let mut it = t.rsplit(':');
+                ParsedNode::Replacement {
+                    key: it.next().unwrap(),
+                    filters: it.collect(),
+                }
+            }
+            OpenConditional(t) => ParsedNode::Conditional {
+                key: t,
+                children: parse_inner(iter, Some(t))?,
+            },
+            OpenNegated(t) => ParsedNode::NegatedConditional {
+                key: t,
+                children: parse_inner(iter, Some(t))?,
+            },
+            CloseConditional(t) => {
+                if let Some(open) = open_tag {
+                    if open == t {
+                        // matching closing tag, move back to parent
+                        return Ok(nodes);
+                    }
+                }
+                return Err(AnkiError::parse(format!(
+                    "unbalanced closing tag: {:?} / {}",
+                    open_tag, t
+                )));
+            }
+        });
+    }
+
+    if let Some(open) = open_tag {
+        Err(AnkiError::parse(format!("unclosed conditional {}", open)))
+    } else {
+        Ok(nodes)
+    }
+}
+
+// Checking if template is empty
+//----------------------------------------
+
+impl ParsedTemplate<'_> {
+    /// true if provided fields are sufficient to render the template
+    pub fn renders_with_fields(&self, nonempty_fields: &HashSet<&str>) -> bool {
+        !template_is_empty(nonempty_fields, &self.0)
+    }
+}
+
+fn template_is_empty<'a>(nonempty_fields: &HashSet<&str>, nodes: &[ParsedNode<'a>]) -> bool {
+    use ParsedNode::*;
+    for node in nodes {
+        match node {
+            // ignore normal text
+            Text(_) => (),
+            Replacement { key, .. } => {
+                if nonempty_fields.contains(*key) {
+                    // a single replacement is enough
+                    return false;
+                }
+            }
+            Conditional { key, children } => {
+                if !nonempty_fields.contains(*key) {
+                    continue;
+                }
+                if !template_is_empty(nonempty_fields, children) {
+                    return false;
+                }
+            }
+            NegatedConditional { .. } => {
+                // negated conditionals ignored when determining card generation
+                continue;
+            }
+        }
+    }
+
+    true
+}
+
+// Compatibility with old Anki versions
+//----------------------------------------
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum FieldRequirements {
+    Any(HashSet<u16>),
+    All(HashSet<u16>),
+    None,
+}
+
+impl ParsedTemplate<'_> {
+    /// Return fields required by template.
+    ///
+    /// This is not able to represent negated expressions or combinations of
+    /// Any and All, and is provided only for the sake of backwards
+    /// compatibility.
+    pub fn requirements(&self, field_map: &FieldMap) -> FieldRequirements {
+        let mut nonempty: HashSet<_> = Default::default();
+        let mut ords = HashSet::new();
+        for (name, ord) in field_map {
+            nonempty.clear();
+            nonempty.insert(*name);
+            if self.renders_with_fields(&nonempty) {
+                ords.insert(*ord);
+            }
+        }
+        if !ords.is_empty() {
+            return FieldRequirements::Any(ords);
+        }
+
+        nonempty.extend(field_map.keys());
+        ords.extend(field_map.values().copied());
+        for (name, ord) in field_map {
+            // can we remove this field and still render?
+            nonempty.remove(name);
+            if self.renders_with_fields(&nonempty) {
+                ords.remove(ord);
+            }
+            nonempty.insert(*name);
+        }
+        if !ords.is_empty() && self.renders_with_fields(&nonempty) {
+            FieldRequirements::All(ords)
+        } else {
+            FieldRequirements::None
+        }
+    }
+}
+
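requirements() approximates a template in two passes. It first checks whether any single field on its own is enough to render the template (an "any" requirement); only if no field qualifies does it start from the full field set and drop one field at a time, keeping the ordinals whose removal stops the template from rendering (an "all" requirement). Anything the old any/all format cannot express degrades to None, as the doc comment notes. A sketch of an extra case, not in the patch, that could be added to the test module below to show that degradation for a template satisfied by either a+b or c+d:

    #[test]
    fn test_requirements_combination_sketch() {
        let field_map: FieldMap = vec!["a", "b", "c", "d"]
            .iter()
            .enumerate()
            .map(|(i, name)| (*name, i as u16))
            .collect();
        // renders when a+b or c+d are filled; the legacy format cannot express that
        let tmpl = PT::from_text("{{#a}}{{b}}{{/a}}{{#c}}{{d}}{{/c}}").unwrap();
        assert_eq!(tmpl.requirements(&field_map), FieldRequirements::None);
    }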
+// Tests
+//---------------------------------------
+
+#[cfg(test)]
+mod test {
+    use super::{FieldMap, ParsedNode::*, ParsedTemplate as PT};
+    use crate::template::FieldRequirements;
+    use std::collections::HashSet;
+    use std::iter::FromIterator;
+
+    #[test]
+    fn test_parsing() {
+        let tmpl = PT::from_text("foo {{bar}} {{#baz}} quux {{/baz}}").unwrap();
+        assert_eq!(
+            tmpl.0,
+            vec![
+                Text("foo "),
+                Replacement {
+                    key: "bar",
+                    filters: vec![]
+                },
+                Text(" "),
+                Conditional {
+                    key: "baz",
+                    children: vec![Text(" quux ")]
+                }
+            ]
+        );
+
+        let tmpl = PT::from_text("{{^baz}}{{/baz}}").unwrap();
+        assert_eq!(
+            tmpl.0,
+            vec![NegatedConditional {
+                key: "baz",
+                children: vec![]
+            }]
+        );
+
+        PT::from_text("{{#mis}}{{/matched}}").unwrap_err();
+        PT::from_text("{{/matched}}").unwrap_err();
+        PT::from_text("{{#mis}}").unwrap_err();
+
+        // whitespace
+        assert_eq!(
+            PT::from_text("{{ tag }}").unwrap().0,
+            vec![Replacement {
+                key: "tag",
+                filters: vec![]
+            }]
+        );
+    }
+
+    #[test]
+    fn test_nonempty() {
+        let fields = HashSet::from_iter(vec!["1", "3"].into_iter());
+        let mut tmpl = PT::from_text("{{2}}{{1}}").unwrap();
+        assert_eq!(tmpl.renders_with_fields(&fields), true);
+        tmpl = PT::from_text("{{2}}{{type:cloze:1}}").unwrap();
+        assert_eq!(tmpl.renders_with_fields(&fields), true);
+        tmpl = PT::from_text("{{2}}{{4}}").unwrap();
+        assert_eq!(tmpl.renders_with_fields(&fields), false);
+        tmpl = PT::from_text("{{#3}}{{^2}}{{1}}{{/2}}{{/3}}").unwrap();
+        assert_eq!(tmpl.renders_with_fields(&fields), false);
+    }
+
+    #[test]
+    fn test_requirements() {
+        let field_map: FieldMap = vec!["a", "b"]
+            .iter()
+            .enumerate()
+            .map(|(a, b)| (*b, a as u16))
+            .collect();
+
+        let mut tmpl = PT::from_text("{{a}}{{b}}").unwrap();
+        assert_eq!(
+            tmpl.requirements(&field_map),
+            FieldRequirements::Any(HashSet::from_iter(vec![0, 1].into_iter()))
+        );
+
+        tmpl = PT::from_text("{{#a}}{{b}}{{/a}}").unwrap();
+        assert_eq!(
+            tmpl.requirements(&field_map),
+            FieldRequirements::All(HashSet::from_iter(vec![0, 1].into_iter()))
+        );
+
+        tmpl = PT::from_text("{{c}}").unwrap();
+        assert_eq!(tmpl.requirements(&field_map), FieldRequirements::None);
+
+        tmpl = PT::from_text("{{^a}}{{b}}{{/a}}").unwrap();
+        assert_eq!(tmpl.requirements(&field_map), FieldRequirements::None);
+
+        tmpl = PT::from_text("{{#a}}{{#b}}{{a}}{{/b}}{{/a}}").unwrap();
+        assert_eq!(
+            tmpl.requirements(&field_map),
+            FieldRequirements::All(HashSet::from_iter(vec![0, 1].into_iter()))
+        );
+
+        // fixme: handling of type in answer card reqs doesn't match desktop,
+        // which only requires first field
+        //
+    }
+}
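Beyond filling the legacy req cache, renders_with_fields() is the primitive that answers whether a card would come out blank for a given set of non-empty fields. A standalone sketch, not part of the patch, assuming the crate is importable as ankirs and using invented field names:

    use ankirs::template::ParsedTemplate;
    use std::collections::HashSet;

    fn main() {
        let tmpl = ParsedTemplate::from_text("{{#Front}}{{Front}}{{/Front}}{{Extra}}").unwrap();
        let mut nonempty = HashSet::new();
        nonempty.insert("Extra");
        // a filled-in Extra field alone is enough for this template to produce output
        assert!(tmpl.renders_with_fields(&nonempty));
        // with nothing filled in, the generated card would be empty
        assert!(!tmpl.renders_with_fields(&HashSet::new()));
    }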