Mirror of https://github.com/ankitects/anki.git

The progress handling code needs a rethink, as we now have two separate ways to flag that the media sync should abort. In the future, it may make sense to switch to polling the backend for progress, instead of passing a callback in.
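For context, here is a minimal sketch of the two abort paths mentioned above, as they exist in the file below. Only abort_media_sync() and the boolean return value of the progress callback come from the file; the module path anki.rsbackend, the backend instance, the want_abort flag, and the on_progress_update helper are illustrative assumptions.

from anki.rsbackend import RustBackend  # assumption: this file is importable as anki.rsbackend

want_abort = False  # assumption: a flag toggled from the GUI thread
backend = RustBackend()

def on_progress_update(progress_bytes: bytes) -> bool:
    # Path 1: returning False from the progress callback asks the backend to
    # stop the media sync at its next progress report (mirrors _on_progress below).
    return not want_abort

# registering the callback reaches into a private attribute, purely for illustration
backend._backend.set_progress_callback(on_progress_update)

# Path 2: telling the backend directly to abort, independently of the callback.
backend.abort_media_sync()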
# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html

# pylint: skip-file

import enum
import json
import os
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    NewType,
    NoReturn,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import ankirspy  # pytype: disable=import-error

import anki.backend_pb2 as pb
import anki.buildinfo
from anki import hooks
from anki.dbproxy import Row as DBRow
from anki.dbproxy import ValueForDB
from anki.fluent_pb2 import FluentString as TR
from anki.models import AllTemplateReqs
from anki.sound import AVTag, SoundOrVideoTag, TTSTag
from anki.types import assert_impossible_literal
from anki.utils import intTime

assert ankirspy.buildhash() == anki.buildinfo.buildhash

SchedTimingToday = pb.SchedTimingTodayOut
BuiltinSortKind = pb.BuiltinSortKind
BackendCard = pb.Card

try:
    import orjson
except:
    # add compat layer for 32 bit builds that can't use orjson
    print("reverting to stock json")

    class orjson:  # type: ignore
        def dumps(obj: Any) -> bytes:
            return json.dumps(obj).encode("utf8")

        loads = json.loads


class Interrupted(Exception):
    pass


class StringError(Exception):
    def __str__(self) -> str:
        return self.args[0]  # pylint: disable=unsubscriptable-object


NetworkErrorKind = pb.NetworkError.NetworkErrorKind
SyncErrorKind = pb.SyncError.SyncErrorKind


class NetworkError(StringError):
    def kind(self) -> NetworkErrorKind:
        return self.args[1]


class SyncError(StringError):
    def kind(self) -> SyncErrorKind:
        return self.args[1]


class IOError(StringError):
    pass


class DBError(StringError):
    pass


class TemplateError(StringError):
    pass


def proto_exception_to_native(err: pb.BackendError) -> Exception:
    val = err.WhichOneof("value")
    if val == "interrupted":
        return Interrupted()
    elif val == "network_error":
        return NetworkError(err.localized, err.network_error.kind)
    elif val == "sync_error":
        return SyncError(err.localized, err.sync_error.kind)
    elif val == "io_error":
        return IOError(err.localized)
    elif val == "db_error":
        return DBError(err.localized)
    elif val == "template_parse":
        return TemplateError(err.localized)
    elif val == "invalid_input":
        return StringError(err.localized)
    else:
        assert_impossible_literal(val)


def proto_template_reqs_to_legacy(
    reqs: List[pb.TemplateRequirement],
) -> AllTemplateReqs:
    legacy_reqs = []
    for (idx, req) in enumerate(reqs):
        kind = req.WhichOneof("value")
        # fixme: sorting is for the unit tests - should check if any
        # code depends on the order
        if kind == "any":
            legacy_reqs.append((idx, "any", sorted(req.any.ords)))
        elif kind == "all":
            legacy_reqs.append((idx, "all", sorted(req.all.ords)))
        else:
            l: List[int] = []
            legacy_reqs.append((idx, "none", l))
    return legacy_reqs


def av_tag_to_native(tag: pb.AVTag) -> AVTag:
    val = tag.WhichOneof("value")
    if val == "sound_or_video":
        return SoundOrVideoTag(filename=tag.sound_or_video)
    else:
        return TTSTag(
            field_text=tag.tts.field_text,
            lang=tag.tts.lang,
            voices=list(tag.tts.voices),
            other_args=list(tag.tts.other_args),
            speed=tag.tts.speed,
        )


@dataclass
class TemplateReplacement:
    field_name: str
    current_text: str
    filters: List[str]


TemplateReplacementList = List[Union[str, TemplateReplacement]]


MediaSyncProgress = pb.MediaSyncProgress

MediaCheckOutput = pb.MediaCheckOut

FormatTimeSpanContext = pb.FormatTimeSpanIn.Context


@dataclass
class ExtractedLatex:
    filename: str
    latex_body: str


@dataclass
class ExtractedLatexOutput:
    html: str
    latex: List[ExtractedLatex]


class ProgressKind(enum.Enum):
    MediaSync = 0
    MediaCheck = 1


@dataclass
class Progress:
    kind: ProgressKind
    val: Union[MediaSyncProgress, str]


def proto_replacement_list_to_native(
    nodes: List[pb.RenderedTemplateNode],
) -> TemplateReplacementList:
    results: TemplateReplacementList = []
    for node in nodes:
        if node.WhichOneof("value") == "text":
            results.append(node.text)
        else:
            results.append(
                TemplateReplacement(
                    field_name=node.replacement.field_name,
                    current_text=node.replacement.current_text,
                    filters=list(node.replacement.filters),
                )
            )
    return results


def proto_progress_to_native(progress: pb.Progress) -> Progress:
    kind = progress.WhichOneof("value")
    if kind == "media_sync":
        return Progress(kind=ProgressKind.MediaSync, val=progress.media_sync)
    elif kind == "media_check":
        return Progress(kind=ProgressKind.MediaCheck, val=progress.media_check)
    else:
        assert_impossible_literal(kind)


def _on_progress(progress_bytes: bytes) -> bool:
    progress = pb.Progress()
    progress.ParseFromString(progress_bytes)
    native_progress = proto_progress_to_native(progress)
    return hooks.bg_thread_progress_callback(True, native_progress)


class RustBackend:
    def __init__(
        self,
        ftl_folder: Optional[str] = None,
        langs: Optional[List[str]] = None,
        server: bool = False,
    ) -> None:
        # pick up global defaults if not provided
        if ftl_folder is None:
            ftl_folder = os.path.join(anki.lang.locale_folder, "fluent")
        if langs is None:
            langs = [anki.lang.currentLang]

        init_msg = pb.BackendInit(
            locale_folder_path=ftl_folder, preferred_langs=langs, server=server,
        )
        self._backend = ankirspy.open_backend(init_msg.SerializeToString())
        self._backend.set_progress_callback(_on_progress)

    def _run_command(
        self, input: pb.BackendInput, release_gil: bool = False
    ) -> pb.BackendOutput:
        input_bytes = input.SerializeToString()
        output_bytes = self._backend.command(input_bytes, release_gil)
        output = pb.BackendOutput()
        output.ParseFromString(output_bytes)
        kind = output.WhichOneof("value")
        if kind == "error":
            raise proto_exception_to_native(output.error)
        else:
            return output

    def open_collection(
        self, col_path: str, media_folder_path: str, media_db_path: str, log_path: str
    ):
        self._run_command(
            pb.BackendInput(
                open_collection=pb.OpenCollectionIn(
                    collection_path=col_path,
                    media_folder_path=media_folder_path,
                    media_db_path=media_db_path,
                    log_path=log_path,
                )
            ),
            release_gil=True,
        )

    def close_collection(self, downgrade=True):
        self._run_command(
            pb.BackendInput(
                close_collection=pb.CloseCollectionIn(downgrade_to_schema11=downgrade)
            ),
            release_gil=True,
        )

    def template_requirements(
        self, template_fronts: List[str], field_map: Dict[str, int]
    ) -> AllTemplateReqs:
        input = pb.BackendInput(
            template_requirements=pb.TemplateRequirementsIn(
                template_front=template_fronts, field_names_to_ordinals=field_map
            )
        )
        output = self._run_command(input).template_requirements
        reqs: List[pb.TemplateRequirement] = output.requirements  # type: ignore
        return proto_template_reqs_to_legacy(reqs)

    def sched_timing_today(
        self,
        created_secs: int,
        created_mins_west: Optional[int],
        now_mins_west: Optional[int],
        rollover: Optional[int],
    ) -> SchedTimingToday:
        if created_mins_west is not None:
            crt_west = pb.OptionalInt32(val=created_mins_west)
        else:
            crt_west = None

        if now_mins_west is not None:
            now_west = pb.OptionalInt32(val=now_mins_west)
        else:
            now_west = None

        if rollover is not None:
            roll = pb.OptionalInt32(val=rollover)
        else:
            roll = None

        return self._run_command(
            pb.BackendInput(
                sched_timing_today=pb.SchedTimingTodayIn(
                    created_secs=created_secs,
                    now_secs=intTime(),
                    created_mins_west=crt_west,
                    now_mins_west=now_west,
                    rollover_hour=roll,
                )
            )
        ).sched_timing_today

    def render_card(
        self, qfmt: str, afmt: str, fields: Dict[str, str], card_ord: int
    ) -> Tuple[TemplateReplacementList, TemplateReplacementList]:
        out = self._run_command(
            pb.BackendInput(
                render_card=pb.RenderCardIn(
                    question_template=qfmt,
                    answer_template=afmt,
                    fields=fields,
                    card_ordinal=card_ord,
                )
            )
        ).render_card

        qnodes = proto_replacement_list_to_native(out.question_nodes)  # type: ignore
        anodes = proto_replacement_list_to_native(out.answer_nodes)  # type: ignore

        return (qnodes, anodes)

    def local_minutes_west(self, stamp: int) -> int:
        return self._run_command(
            pb.BackendInput(local_minutes_west=stamp)
        ).local_minutes_west

    def strip_av_tags(self, text: str) -> str:
        return self._run_command(pb.BackendInput(strip_av_tags=text)).strip_av_tags

    def extract_av_tags(
        self, text: str, question_side: bool
    ) -> Tuple[str, List[AVTag]]:
        out = self._run_command(
            pb.BackendInput(
                extract_av_tags=pb.ExtractAVTagsIn(
                    text=text, question_side=question_side
                )
            )
        ).extract_av_tags
        native_tags = list(map(av_tag_to_native, out.av_tags))

        return out.text, native_tags

    def extract_latex(
        self, text: str, svg: bool, expand_clozes: bool
    ) -> ExtractedLatexOutput:
        out = self._run_command(
            pb.BackendInput(
                extract_latex=pb.ExtractLatexIn(
                    text=text, svg=svg, expand_clozes=expand_clozes
                )
            )
        ).extract_latex

        return ExtractedLatexOutput(
            html=out.text,
            latex=[
                ExtractedLatex(filename=l.filename, latex_body=l.latex_body)
                for l in out.latex
            ],
        )

    def add_file_to_media_folder(self, desired_name: str, data: bytes) -> str:
        return self._run_command(
            pb.BackendInput(
                add_media_file=pb.AddMediaFileIn(desired_name=desired_name, data=data)
            )
        ).add_media_file

    def sync_media(self, hkey: str, endpoint: str) -> None:
        self._run_command(
            pb.BackendInput(sync_media=pb.SyncMediaIn(hkey=hkey, endpoint=endpoint,)),
            release_gil=True,
        )

    def check_media(self) -> MediaCheckOutput:
        return self._run_command(
            pb.BackendInput(check_media=pb.Empty()), release_gil=True,
        ).check_media

    def trash_media_files(self, fnames: List[str]) -> None:
        self._run_command(
            pb.BackendInput(trash_media_files=pb.TrashMediaFilesIn(fnames=fnames))
        )

    def translate(self, key: TR, **kwargs: Union[str, int, float]) -> str:
        return self._run_command(
            pb.BackendInput(translate_string=translate_string_in(key, **kwargs))
        ).translate_string

    def format_time_span(
        self,
        seconds: float,
        context: FormatTimeSpanContext = FormatTimeSpanContext.INTERVALS,
    ) -> str:
        return self._run_command(
            pb.BackendInput(
                format_time_span=pb.FormatTimeSpanIn(seconds=seconds, context=context)
            )
        ).format_time_span

    def studied_today(self, cards: int, seconds: float) -> str:
        return self._run_command(
            pb.BackendInput(
                studied_today=pb.StudiedTodayIn(cards=cards, seconds=seconds)
            )
        ).studied_today

    def learning_congrats_msg(self, next_due: float, remaining: int) -> str:
        return self._run_command(
            pb.BackendInput(
                congrats_learn_msg=pb.CongratsLearnMsgIn(
                    next_due=next_due, remaining=remaining
                )
            )
        ).congrats_learn_msg

    def empty_trash(self):
        self._run_command(pb.BackendInput(empty_trash=pb.Empty()))

    def restore_trash(self):
        self._run_command(pb.BackendInput(restore_trash=pb.Empty()))

    def db_query(
        self, sql: str, args: Sequence[ValueForDB], first_row_only: bool
    ) -> List[DBRow]:
        return self._db_command(
            dict(kind="query", sql=sql, args=args, first_row_only=first_row_only)
        )

    def db_execute_many(self, sql: str, args: List[List[ValueForDB]]) -> List[DBRow]:
        return self._db_command(dict(kind="executemany", sql=sql, args=args))

    def db_begin(self) -> None:
        return self._db_command(dict(kind="begin"))

    def db_commit(self) -> None:
        return self._db_command(dict(kind="commit"))

    def db_rollback(self) -> None:
        return self._db_command(dict(kind="rollback"))

    def _db_command(self, input: Dict[str, Any]) -> Any:
        return orjson.loads(self._backend.db_command(orjson.dumps(input)))

    def search_cards(
        self, search: str, order: Union[bool, str, int], reverse: bool = False
    ) -> Sequence[int]:
        if isinstance(order, str):
            mode = pb.SortOrder(custom=order)
        elif order is True:
            mode = pb.SortOrder(from_config=pb.Empty())
        elif order is False:
            mode = pb.SortOrder(none=pb.Empty())
        else:
            # sadly we can't use the protobuf type in a Union, so we
            # have to accept an int and convert it
            kind = BuiltinSortKind.Value(BuiltinSortKind.Name(order))
            mode = pb.SortOrder(
                builtin=pb.BuiltinSearchOrder(kind=kind, reverse=reverse)
            )

        return self._run_command(
            pb.BackendInput(search_cards=pb.SearchCardsIn(search=search, order=mode))
        ).search_cards.card_ids

    def search_notes(self, search: str) -> Sequence[int]:
        return self._run_command(
            pb.BackendInput(search_notes=pb.SearchNotesIn(search=search))
        ).search_notes.note_ids

    def get_card(self, cid: int) -> Optional[pb.Card]:
        return self._run_command(pb.BackendInput(get_card=cid)).get_card.card

    def update_card(self, card: BackendCard) -> None:
        self._run_command(pb.BackendInput(update_card=card))

    def add_card(self, card: BackendCard) -> None:
        card.id = self._run_command(pb.BackendInput(add_card=card)).add_card

    def get_deck_config(self, dcid: int) -> Dict[str, Any]:
        jstr = self._run_command(pb.BackendInput(get_deck_config=dcid)).get_deck_config
        return json.loads(jstr)

    def add_or_update_deck_config(self, conf: Dict[str, Any]) -> None:
        conf_json = json.dumps(conf)
        id = self._run_command(
            pb.BackendInput(add_or_update_deck_config=conf_json)
        ).add_or_update_deck_config
        conf["id"] = id

    def all_deck_config(self) -> Sequence[Dict[str, Any]]:
        jstr = self._run_command(
            pb.BackendInput(all_deck_config=pb.Empty())
        ).all_deck_config
        return json.loads(jstr)

    def new_deck_config(self) -> Dict[str, Any]:
        jstr = self._run_command(
            pb.BackendInput(new_deck_config=pb.Empty())
        ).new_deck_config
        return json.loads(jstr)

    def remove_deck_config(self, dcid: int) -> None:
        self._run_command(pb.BackendInput(remove_deck_config=dcid))

    def abort_media_sync(self):
        self._run_command(pb.BackendInput(abort_media_sync=pb.Empty()))


def translate_string_in(
    key: TR, **kwargs: Union[str, int, float]
) -> pb.TranslateStringIn:
    args = {}
    for (k, v) in kwargs.items():
        if isinstance(v, str):
            args[k] = pb.TranslateArgValue(str=v)
        else:
            args[k] = pb.TranslateArgValue(number=v)
    return pb.TranslateStringIn(key=key, args=args)


# temporarily force logging of media handling
if "RUST_LOG" not in os.environ:
    os.environ["RUST_LOG"] = "warn,anki::media=debug"