compute template requirements in Rust

On a 100-field template, what previously took ~75 seconds now takes ~3 seconds.
This commit is contained in:
Damien Elmes 2019-12-24 14:05:15 +10:00
parent ecfce51dbd
commit 3ce4d5fd3d
10 changed files with 574 additions and 4 deletions

View file

@ -556,6 +556,9 @@ select id from notes where mid = ?)"""
##########################################################################
def _updateRequired(self, m: NoteType) -> None:
    # Recompute which fields each template in model m requires.
    # Delegates to the fast Rust-backed implementation.
    self._updateRequiredNew(m)
def _updateRequiredLegacy(self, m: NoteType) -> None:
if m["type"] == MODEL_CLOZE:
# nothing to do
return
@ -566,6 +569,14 @@ select id from notes where mid = ?)"""
req.append([t["ord"], ret[0], ret[1]])
m["req"] = req
def _updateRequiredNew(self, m: NoteType) -> None:
    """Recompute model m's 'req' list via the Rust bridge.

    Much faster than the legacy Python implementation when the
    template has many fields.
    """
    fronts = [t["qfmt"] for t in m["tmpls"]]
    # map each field name to its ordinal (position) in the notetype
    field_map = {fld["name"]: idx for idx, fld in enumerate(m["flds"])}
    reqs = self.col.rust.template_requirements(fronts, field_map)
    # store as lists (not tuples) to match the legacy JSON-backed format
    m["req"] = [list(l) for l in reqs]
def _reqForTemplate(
self, m: NoteType, flds: List[str], t: Template
) -> Tuple[Union[str, List[int]], ...]:

View file

@ -1,8 +1,12 @@
from typing import Dict, List
import _ankirs # pytype: disable=import-error
import betterproto
from anki.proto import proto as pb
from .types import AllTemplateReqs
class BridgeException(Exception):
def __str__(self) -> str:
@ -10,10 +14,30 @@ class BridgeException(Exception):
(kind, obj) = betterproto.which_one_of(err, "value")
if kind == "invalid_input":
return f"invalid input: {obj.info}"
elif kind == "template_parse":
return f"template parse: {obj.info}"
else:
return f"unhandled error: {err} {obj}"
def proto_template_reqs_to_legacy(
    reqs: List[pb.TemplateRequirement],
) -> AllTemplateReqs:
    """Convert protobuf requirement messages into legacy (idx, kind, ords) tuples."""
    legacy_reqs = []
    for (idx, req) in enumerate(reqs):
        (kind, val) = betterproto.which_one_of(req, "value")
        # fixme: sorting is for the unit tests - should check if any
        # code depends on the order
        if kind == "any":
            # use the value which_one_of already extracted instead of
            # reaching back into the message attributes
            legacy_reqs.append((idx, "any", sorted(val.ords)))
        elif kind == "all":
            legacy_reqs.append((idx, "all", sorted(val.ords)))
        else:
            # "none" carries no ordinals
            ords: List[int] = []
            legacy_reqs.append((idx, "none", ords))
    return legacy_reqs
class RSBridge:
def __init__(self):
self._bridge = _ankirs.Bridge()
@ -33,5 +57,13 @@ class RSBridge:
output = self._run_command(input)
return output.plus_one.num
bridge = RSBridge()
def template_requirements(
    self, template_fronts: List[str], field_map: Dict[str, int]
) -> AllTemplateReqs:
    # Ask the Rust bridge which fields each template front requires.
    # field_map maps field names to their ordinal in the notetype.
    input = pb.BridgeInput(
        template_requirements=pb.TemplateRequirementsIn(
            template_front=template_fronts, field_names_to_ordinals=field_map
        )
    )
    # _run_command raises BridgeException on a bridge-side error
    output = self._run_command(input).template_requirements
    # convert the protobuf reply into the legacy tuple format
    return proto_template_reqs_to_legacy(output.requirements)

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, List, Tuple, Union
# Model attributes are stored in a dict keyed by strings. This type alias
# provides more descriptive function signatures than just 'Dict[str, Any]'
@ -31,3 +31,8 @@ QAData = Tuple[
# Corresponds to 'cardFlags' column. TODO: document
int,
]
# Loosely typed: valid values are "all", "any" or "none".
TemplateRequirementType = str  # Union["all", "any", "none"]
# template ordinal, type, list of field ordinals
TemplateRequiredFieldOrds = Tuple[int, TemplateRequirementType, List[int]]
# One entry per template in the notetype.
AllTemplateReqs = List[TemplateRequiredFieldOrds]

View file

@ -7,6 +7,7 @@ message Empty {}
message BridgeInput {
oneof value {
PlusOneIn plus_one = 2;
TemplateRequirementsIn template_requirements = 3;
}
}
@ -14,12 +15,14 @@ message BridgeOutput {
oneof value {
BridgeError error = 1;
PlusOneOut plus_one = 2;
TemplateRequirementsOut template_requirements = 3;
}
}
message BridgeError {
oneof value {
InvalidInputError invalid_input = 1;
TemplateParseError template_parse = 2;
}
}
@ -34,3 +37,32 @@ message PlusOneIn {
message PlusOneOut {
int32 num = 1;
}
// Returned when a card template fails to lex/parse.
message TemplateParseError {
  string info = 1;
}

// Request: compute field requirements for a list of template fronts.
message TemplateRequirementsIn {
  // the question format of each template
  repeated string template_front = 1;
  // maps each field name to its ordinal (position) in the notetype
  map<string, uint32> field_names_to_ordinals = 2;
}

// Response: one requirement entry per provided template front.
message TemplateRequirementsOut {
  repeated TemplateRequirement requirements = 1;
}

message TemplateRequirement {
  oneof value {
    // all of these field ordinals must be non-empty
    TemplateRequirementAll all = 1;
    // any one of these field ordinals being non-empty suffices
    TemplateRequirementAny any = 2;
    // the template can never be satisfied
    Empty none = 3;
  }
}

message TemplateRequirementAll {
  repeated uint32 ords = 1;
}

message TemplateRequirementAny {
  repeated uint32 ords = 1;
}

78
rs/Cargo.lock generated
View file

@ -15,10 +15,20 @@ version = "0.1.0"
dependencies = [
"bytes",
"failure",
"nom",
"prost",
"prost-build",
]
[[package]]
name = "arrayvec"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd9fd44efafa8690358b7408d253adf110036b88f55672a933f01d616ad9b1b9"
dependencies = [
"nodrop",
]
[[package]]
name = "autocfg"
version = "0.1.7"
@ -234,6 +244,19 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "lexical-core"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304bccb228c4b020f3a4835d247df0a02a7c4686098d4167762cfbbe4c5cb14"
dependencies = [
"arrayvec",
"cfg-if",
"rustc_version",
"ryu",
"static_assertions",
]
[[package]]
name = "libc"
version = "0.2.66"
@ -261,6 +284,23 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2eb04b9f127583ed176e163fb9ec6f3e793b87e21deedd5734a69386a18a0151"
[[package]]
name = "nodrop"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb"
[[package]]
name = "nom"
version = "5.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c618b63422da4401283884e6668d39f819a106ef51f5f59b81add00075da35ca"
dependencies = [
"lexical-core",
"memchr",
"version_check 0.1.5",
]
[[package]]
name = "num-traits"
version = "0.2.10"
@ -414,7 +454,7 @@ dependencies = [
"serde_json",
"spin",
"unindent",
"version_check",
"version_check 0.9.1",
]
[[package]]
@ -538,12 +578,36 @@ version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783"
[[package]]
name = "rustc_version"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
dependencies = [
"semver",
]
[[package]]
name = "ryu"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa8506c1de11c9c4e4c38863ccbe02a305c8188e85a05a784c9e11e1c3910c8"
[[package]]
name = "semver"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
dependencies = [
"semver-parser",
]
[[package]]
name = "semver-parser"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
[[package]]
name = "serde"
version = "1.0.104"
@ -581,6 +645,12 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]]
name = "static_assertions"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f3eb36b47e512f8f1c9e3d10c2c1965bc992bd9cdb024fa581e2194501c83d3"
[[package]]
name = "syn"
version = "0.15.44"
@ -662,6 +732,12 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63f18aa3b0e35fed5a0048f029558b1518095ffe2a0a31fb87c93dece93a4993"
[[package]]
name = "version_check"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd"
[[package]]
name = "version_check"
version = "0.9.1"

View file

@ -5,6 +5,7 @@ edition = "2018"
authors = ["Ankitects Pty Ltd and contributors"]
[dependencies]
nom = "5.0.1"
failure = "0.1.6"
prost = "0.5.0"
bytes = "0.4"

View file

@ -1,7 +1,9 @@
use crate::err::{AnkiError, Result};
use crate::proto as pt;
use crate::proto::bridge_input::Value;
use crate::template::{FieldMap, FieldRequirements, ParsedTemplate};
use prost::Message;
use std::collections::HashSet;
pub struct Bridge {}
@ -17,6 +19,9 @@ impl std::convert::From<AnkiError> for pt::BridgeError {
use pt::bridge_error::Value as V;
let value = match err {
AnkiError::InvalidInput { info } => V::InvalidInput(pt::InvalidInputError { info }),
AnkiError::TemplateParseError { info } => {
V::TemplateParse(pt::TemplateParseError { info })
}
};
pt::BridgeError { value: Some(value) }
@ -73,6 +78,9 @@ impl Bridge {
/// Dispatch a decoded bridge request to its handler, wrapping the
/// reply in the matching output variant.
fn run_command_inner(&self, ival: pt::bridge_input::Value) -> Result<pt::bridge_output::Value> {
    use pt::bridge_output::Value as OValue;
    Ok(match ival {
        Value::TemplateRequirements(input) => {
            OValue::TemplateRequirements(self.template_requirements(input)?)
        }
        Value::PlusOne(input) => OValue::PlusOne(self.plus_one(input)?),
    })
}
@ -81,4 +89,48 @@ impl Bridge {
let num = input.num + 1;
Ok(pt::PlusOneOut { num })
}
/// Compute per-template field requirements for the provided template
/// fronts, using the parsed-template machinery in the `template` module.
fn template_requirements(
    &self,
    input: pt::TemplateRequirementsIn,
) -> Result<pt::TemplateRequirementsOut> {
    // borrow the name -> ordinal mapping in the form the template code expects
    let map: FieldMap = input
        .field_names_to_ordinals
        .iter()
        .map(|(name, ord)| (name.as_str(), *ord as u16))
        .collect();
    // map each provided template into a requirements list; the closure is
    // infallible, so collect straight into a Vec instead of a Result
    use crate::proto::template_requirement::Value;
    let all_reqs: Vec<_> = input
        .template_front
        .into_iter()
        .map(|template| {
            let val = if let Ok(tmpl) = ParsedTemplate::from_text(&template) {
                // convert the rust structure into a protobuf one
                match tmpl.requirements(&map) {
                    FieldRequirements::Any(ords) => Value::Any(pt::TemplateRequirementAny {
                        ords: ords_hash_to_set(ords),
                    }),
                    FieldRequirements::All(ords) => Value::All(pt::TemplateRequirementAll {
                        ords: ords_hash_to_set(ords),
                    }),
                    FieldRequirements::None => Value::None(pt::Empty {}),
                }
            } else {
                // template parsing failures make card unsatisfiable
                Value::None(pt::Empty {})
            };
            pt::TemplateRequirement { value: Some(val) }
        })
        .collect();
    Ok(pt::TemplateRequirementsOut {
        requirements: all_reqs,
    })
}
}
/// Convert a set of 16-bit field ordinals into the `u32` vec used by the
/// protobuf messages. Order is unspecified (HashSet iteration order).
fn ords_hash_to_set(ords: HashSet<u16>) -> Vec<u32> {
    // consume the set directly rather than iterating by reference
    ords.into_iter().map(u32::from).collect()
}

View file

@ -6,10 +6,17 @@ pub type Result<T> = std::result::Result<T, AnkiError>;
pub enum AnkiError {
    /// The caller provided data the bridge could not handle.
    #[fail(display = "invalid input: {}", info)]
    InvalidInput { info: String },
    /// A card template failed to lex or parse.
    #[fail(display = "invalid card template: {}", info)]
    TemplateParseError { info: String },
}
// error helpers
impl AnkiError {
/// Shortcut for building a TemplateParseError from any string-like value.
pub(crate) fn parse<S: Into<String>>(s: S) -> AnkiError {
    AnkiError::TemplateParseError { info: s.into() }
}

/// Shortcut for building an InvalidInput error from any string-like value.
pub(crate) fn invalid_input<S: Into<String>>(s: S) -> AnkiError {
    AnkiError::InvalidInput { info: s.into() }
}

View file

@ -2,3 +2,4 @@ mod proto;
pub mod bridge;
pub mod err;
pub mod template;

353
rs/ankirs/src/template.rs Normal file
View file

@ -0,0 +1,353 @@
use crate::err::{AnkiError, Result};
use nom;
use nom::branch::alt;
use nom::bytes::complete::tag;
use nom::error::ErrorKind;
use nom::sequence::delimited;
use std::collections::{HashMap, HashSet};
/// Maps field names to their ordinal (position) in the notetype.
pub type FieldMap<'a> = HashMap<&'a str, u16>;
// Lexing
//----------------------------------------
/// Lexer output: one token per run of plain text or per `{{...}}` handle.
/// Each variant borrows from the template string being lexed.
#[derive(Debug)]
pub enum Token<'a> {
    /// literal text outside any handlebars
    Text(&'a str),
    /// `{{Field}}` - substitute the field's value
    Replacement(&'a str),
    /// `{{#Field}}` - start of a positive conditional section
    OpenConditional(&'a str),
    /// `{{^Field}}` - start of a negated conditional section
    OpenNegated(&'a str),
    /// `{{/Field}}` - end of a conditional section
    CloseConditional(&'a str),
}
/// a span of text, terminated by {{, }} or end of string
pub(crate) fn text_until_handlebars(s: &str) -> nom::IResult<&str, &str> {
let end = s.len();
let limited_end = end
.min(s.find("{{").unwrap_or(end))
.min(s.find("}}").unwrap_or(end));
let (output, input) = s.split_at(limited_end);
if output.is_empty() {
Err(nom::Err::Error((input, ErrorKind::TakeUntil)))
} else {
Ok((input, output))
}
}
/// text outside handlebars
fn text_token(s: &str) -> nom::IResult<&str, Token> {
    text_until_handlebars(s).map(|(input, output)| (input, Token::Text(output)))
}

/// text wrapped in handlebars; the content decides the token kind
fn handle_token(s: &str) -> nom::IResult<&str, Token> {
    delimited(tag("{{"), text_until_handlebars, tag("}}"))(s)
        .map(|(input, output)| (input, classify_handle(output)))
}
/// classify handle based on leading character
fn classify_handle(s: &str) -> Token {
    let trimmed = s.trim();
    // too short to hold a marker character plus a name
    if trimmed.len() < 2 {
        return Token::Replacement(trimmed);
    }
    // the marker characters are all ASCII, so checking the first byte is safe
    match trimmed.as_bytes()[0] {
        b'#' => Token::OpenConditional(trimmed[1..].trim_start()),
        b'/' => Token::CloseConditional(trimmed[1..].trim_start()),
        b'^' => Token::OpenNegated(trimmed[1..].trim_start()),
        _ => Token::Replacement(trimmed),
    }
}
/// the next token: a handle if one starts here, otherwise plain text
fn next_token(input: &str) -> nom::IResult<&str, Token> {
    alt((handle_token, text_token))(input)
}
/// Lazily lex `template`, yielding one `Result<Token>` per token.
fn tokens(template: &str) -> impl Iterator<Item = Result<Token>> {
    let mut data = template;
    std::iter::from_fn(move || {
        if data.is_empty() {
            return None;
        }
        match next_token(data) {
            Ok((i, o)) => {
                // advance past the consumed input and yield the token
                data = i;
                Some(Ok(o))
            }
            Err(e) => {
                // terminate after reporting the failure; previously the
                // unconsumed input was left in place, so a caller that kept
                // polling would receive the same error forever
                data = "";
                Some(Err(AnkiError::parse(format!("{:?}", e))))
            }
        }
    })
}
// Parsing
//----------------------------------------
/// One node of a parsed template's syntax tree.
#[derive(Debug, PartialEq)]
enum ParsedNode<'a> {
    /// literal text
    Text(&'a str),
    /// `{{filters:key}}` field replacement
    Replacement {
        key: &'a str,
        // filters are stored innermost-first (reverse of source order)
        filters: Vec<&'a str>,
    },
    /// `{{#key}}...{{/key}}` - children shown when the field is non-empty
    Conditional {
        key: &'a str,
        children: Vec<ParsedNode<'a>>,
    },
    /// `{{^key}}...{{/key}}` - children shown when the field is empty
    NegatedConditional {
        key: &'a str,
        children: Vec<ParsedNode<'a>>,
    },
}
/// A template parsed into a tree of nodes, borrowing from the source text.
#[derive(Debug)]
pub struct ParsedTemplate<'a>(Vec<ParsedNode<'a>>);

impl ParsedTemplate<'_> {
    /// Lex and parse the provided template text into a node tree.
    pub fn from_text(template: &str) -> Result<ParsedTemplate> {
        let mut iter = tokens(template);
        Ok(Self(parse_inner(&mut iter, None)?))
    }
}
/// Recursive-descent parse of the token stream into a node list.
///
/// `open_tag` is the name of the conditional section currently being
/// parsed (None at the top level). A matching CloseConditional returns
/// control to the caller; a mismatched or missing one is a parse error.
fn parse_inner<'a, I: Iterator<Item = Result<Token<'a>>>>(
    iter: &mut I,
    open_tag: Option<&'a str>,
) -> Result<Vec<ParsedNode<'a>>> {
    let mut nodes = vec![];
    while let Some(token) = iter.next() {
        use Token::*;
        nodes.push(match token? {
            Text(t) => ParsedNode::Text(t),
            Replacement(t) => {
                // "filter2:filter1:key" - the key is the last colon-separated
                // segment; rsplit leaves the filters innermost-first
                let mut it = t.rsplit(':');
                ParsedNode::Replacement {
                    key: it.next().unwrap(),
                    filters: it.collect(),
                }
            }
            // conditional sections recurse, consuming tokens up to the
            // matching close tag
            OpenConditional(t) => ParsedNode::Conditional {
                key: t,
                children: parse_inner(iter, Some(t))?,
            },
            OpenNegated(t) => ParsedNode::NegatedConditional {
                key: t,
                children: parse_inner(iter, Some(t))?,
            },
            CloseConditional(t) => {
                if let Some(open) = open_tag {
                    if open == t {
                        // matching closing tag, move back to parent
                        return Ok(nodes);
                    }
                }
                return Err(AnkiError::parse(format!(
                    "unbalanced closing tag: {:?} / {}",
                    open_tag, t
                )));
            }
        });
    }
    // ran out of tokens: an enclosing section that was never closed is an error
    if let Some(open) = open_tag {
        Err(AnkiError::parse(format!("unclosed conditional {}", open)))
    } else {
        Ok(nodes)
    }
}
// Checking if template is empty
//----------------------------------------
impl ParsedTemplate<'_> {
    /// true if provided fields are sufficient to render the template
    pub fn renders_with_fields(&self, nonempty_fields: &HashSet<&str>) -> bool {
        // the template renders iff it would produce some field content
        !template_is_empty(nonempty_fields, &self.0)
    }
}
/// True if, given the set of non-empty fields, rendering `nodes` would
/// produce no field content at all.
fn template_is_empty<'a>(nonempty_fields: &HashSet<&str>, nodes: &[ParsedNode<'a>]) -> bool {
    use ParsedNode::*;
    for node in nodes {
        match node {
            // ignore normal text
            Text(_) => (),
            Replacement { key, .. } => {
                if nonempty_fields.contains(*key) {
                    // a single replacement is enough
                    return false;
                }
            }
            Conditional { key, children } => {
                // children are only visible when the conditional's own
                // field is non-empty
                if !nonempty_fields.contains(*key) {
                    continue;
                }
                if !template_is_empty(nonempty_fields, children) {
                    return false;
                }
            }
            NegatedConditional { .. } => {
                // negated conditionals ignored when determining card generation
                continue;
            }
        }
    }
    true
}
// Compatibility with old Anki versions
//----------------------------------------
/// Which field ordinals must be non-empty for a card to be generated.
#[derive(Debug, Clone, PartialEq)]
pub enum FieldRequirements {
    /// any one of these fields being non-empty suffices
    Any(HashSet<u16>),
    /// all of these fields must be non-empty
    All(HashSet<u16>),
    /// the template can never be satisfied
    None,
}
impl ParsedTemplate<'_> {
    /// Return fields required by template.
    ///
    /// This is not able to represent negated expressions or combinations of
    /// Any and All, and is provided only for the sake of backwards
    /// compatibility.
    pub fn requirements(&self, field_map: &FieldMap) -> FieldRequirements {
        let mut nonempty: HashSet<_> = Default::default();
        let mut ords = HashSet::new();
        // Pass 1 ("any"): a field qualifies if the template renders with
        // that field alone being non-empty.
        for (name, ord) in field_map {
            nonempty.clear();
            nonempty.insert(*name);
            if self.renders_with_fields(&nonempty) {
                ords.insert(*ord);
            }
        }
        if !ords.is_empty() {
            return FieldRequirements::Any(ords);
        }
        // Pass 2 ("all"): start with every field non-empty, then drop each
        // field in turn; a field is not required if the template still
        // renders without it.
        nonempty.extend(field_map.keys());
        ords.extend(field_map.values().copied());
        for (name, ord) in field_map {
            // can we remove this field and still render?
            nonempty.remove(name);
            if self.renders_with_fields(&nonempty) {
                ords.remove(ord);
            }
            // restore for the next iteration
            nonempty.insert(*name);
        }
        // if even the full field set can't render the template, it is
        // unsatisfiable
        if !ords.is_empty() && self.renders_with_fields(&nonempty) {
            FieldRequirements::All(ords)
        } else {
            FieldRequirements::None
        }
    }
}
// Tests
//---------------------------------------
#[cfg(test)]
mod test {
use super::{FieldMap, ParsedNode::*, ParsedTemplate as PT};
use crate::template::FieldRequirements;
use std::collections::HashSet;
use std::iter::FromIterator;
#[test]
fn test_parsing() {
    // text, replacement and conditional nodes
    let tmpl = PT::from_text("foo {{bar}} {{#baz}} quux {{/baz}}").unwrap();
    assert_eq!(
        tmpl.0,
        vec![
            Text("foo "),
            Replacement {
                key: "bar",
                filters: vec![]
            },
            Text(" "),
            Conditional {
                key: "baz",
                children: vec![Text(" quux ")]
            }
        ]
    );

    // empty negated section
    let tmpl = PT::from_text("{{^baz}}{{/baz}}").unwrap();
    assert_eq!(
        tmpl.0,
        vec![NegatedConditional {
            key: "baz",
            children: vec![]
        }]
    );

    // mismatched or unclosed sections are parse errors
    PT::from_text("{{#mis}}{{/matched}}").unwrap_err();
    PT::from_text("{{/matched}}").unwrap_err();
    PT::from_text("{{#mis}}").unwrap_err();

    // whitespace
    assert_eq!(
        PT::from_text("{{ tag }}").unwrap().0,
        vec![Replacement {
            key: "tag",
            filters: vec![]
        }]
    );
}
#[test]
fn test_nonempty() {
    // fields "1" and "3" are the non-empty ones
    let fields = HashSet::from_iter(vec!["1", "3"].into_iter());

    let mut tmpl = PT::from_text("{{2}}{{1}}").unwrap();
    assert_eq!(tmpl.renders_with_fields(&fields), true);
    tmpl = PT::from_text("{{2}}{{type:cloze:1}}").unwrap();
    assert_eq!(tmpl.renders_with_fields(&fields), true);
    tmpl = PT::from_text("{{2}}{{4}}").unwrap();
    assert_eq!(tmpl.renders_with_fields(&fields), false);
    // negated sections are ignored, so nothing here can render
    tmpl = PT::from_text("{{#3}}{{^2}}{{1}}{{/2}}{{/3}}").unwrap();
    assert_eq!(tmpl.renders_with_fields(&fields), false);
}
#[test]
fn test_requirements() {
let field_map: FieldMap = vec!["a", "b"]
.iter()
.enumerate()
.map(|(a, b)| (*b, a as u16))
.collect();
let mut tmpl = PT::from_text("{{a}}{{b}}").unwrap();
assert_eq!(
tmpl.requirements(&field_map),
FieldRequirements::Any(HashSet::from_iter(vec![0, 1].into_iter()))
);
tmpl = PT::from_text("{{#a}}{{b}}{{/a}}").unwrap();
assert_eq!(
tmpl.requirements(&field_map),
FieldRequirements::All(HashSet::from_iter(vec![0, 1].into_iter()))
);
tmpl = PT::from_text("{{c}}").unwrap();
assert_eq!(tmpl.requirements(&field_map), FieldRequirements::None);
tmpl = PT::from_text("{{^a}}{{b}}{{/a}}").unwrap();
assert_eq!(tmpl.requirements(&field_map), FieldRequirements::None);
tmpl = PT::from_text("{{#a}}{{#b}}{{a}}{{/b}}{{/a}}").unwrap();
assert_eq!(
tmpl.requirements(&field_map),
FieldRequirements::All(HashSet::from_iter(vec![0, 1].into_iter()))
);
// fixme: handling of type in answer card reqs doesn't match desktop,
// which only requires first field
//
}
}