# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
# pylint: skip-file

"""
Python bindings for Anki's Rust libraries.

Please do not access methods on the backend directly - they may be changed
or removed at any time. Please use the methods on the collection instead.
Eg, don't use col.backend.all_deck_config(), instead use
col.decks.all_config()
"""

import enum
import json
import os
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    NewType,
    NoReturn,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import ankirspy  # pytype: disable=import-error

import anki.backend_pb2 as pb
import anki.buildinfo
from anki import hooks
from anki.dbproxy import Row as DBRow
from anki.dbproxy import ValueForDB
from anki.fluent_pb2 import FluentString as TR
from anki.sound import AVTag, SoundOrVideoTag, TTSTag
from anki.types import assert_impossible_literal
from anki.utils import intTime

# Refuse to run when the compiled extension and the Python sources report
# different build hashes.
assert ankirspy.buildhash() == anki.buildinfo.buildhash

# Short aliases for protobuf-generated types used throughout this module.
SchedTimingToday = pb.SchedTimingTodayOut
BuiltinSortKind = pb.BuiltinSearchOrder.BuiltinSortKind
BackendCard = pb.Card
BackendNote = pb.Note
TagUsnTuple = pb.TagUsnTuple
NoteType = pb.NoteType
DeckTreeNode = pb.DeckTreeNode

try:
    import orjson
except ImportError:
    # add compat layer for 32 bit builds that can't use orjson
    # (only ImportError is caught so that KeyboardInterrupt/SystemExit and
    # unrelated failures are not silently swallowed)
    print("reverting to stock json")

    class orjson:  # type: ignore
        """Minimal stand-in exposing the two orjson functions this module uses."""

        @staticmethod
        def dumps(obj: Any) -> bytes:
            # orjson.dumps returns bytes, so encode the stdlib's str output.
            return json.dumps(obj).encode("utf8")

        loads = staticmethod(json.loads)


class Interrupted(Exception):
    """Raised when the backend reports an "interrupted" error."""


class StringError(Exception):
    """Base for backend errors whose first argument is the message text."""

    def __str__(self) -> str:
        return self.args[0]  # pylint: disable=unsubscriptable-object


NetworkErrorKind = pb.NetworkError.NetworkErrorKind
SyncErrorKind = pb.SyncError.SyncErrorKind


class NetworkError(StringError):
    """Backend network error; args are (message, kind)."""

    def kind(self) -> NetworkErrorKind:
        """Return the error kind stored as the second exception argument."""
        return self.args[1]


class SyncError(StringError):
    """Backend sync error; args are (message, kind)."""

    def kind(self) -> SyncErrorKind:
        """Return the error kind stored as the second exception argument."""
        return self.args[1]


# NOTE: intentionally shadows the builtin IOError inside this module; the
# name is part of the module's public interface.
class IOError(StringError):
    """Raised when the backend reports an "io_error"."""


class DBError(StringError):
    """Raised when the backend reports a "db_error"."""
class TemplateError(StringError):
    """Raised when the backend reports a "template_parse" error."""


class NotFoundError(Exception):
    """Raised when the backend reports a "not_found_error"."""


class ExistsError(Exception):
    """Raised when the backend reports an "exists" error."""


class DeckIsFilteredError(Exception):
    """Raised when the backend reports a "deck_is_filtered" error."""


class InvalidInput(StringError):
    """Raised when the backend reports an "invalid_input" error."""


def proto_exception_to_native(err: pb.BackendError) -> Exception:
    """Build the Python exception matching the oneof set on a backend error.

    The exception is returned, not raised; the caller decides when to raise.
    """
    val = err.WhichOneof("value")
    if val == "interrupted":
        return Interrupted()
    elif val == "network_error":
        return NetworkError(err.localized, err.network_error.kind)
    elif val == "sync_error":
        return SyncError(err.localized, err.sync_error.kind)
    elif val == "io_error":
        return IOError(err.localized)
    elif val == "db_error":
        return DBError(err.localized)
    elif val == "template_parse":
        return TemplateError(err.localized)
    elif val == "invalid_input":
        return InvalidInput(err.localized)
    elif val == "json_error":
        return StringError(err.localized)
    elif val == "not_found_error":
        return NotFoundError()
    elif val == "exists":
        return ExistsError()
    elif val == "deck_is_filtered":
        return DeckIsFilteredError()
    else:
        assert_impossible_literal(val)


def av_tag_to_native(tag: pb.AVTag) -> AVTag:
    """Convert a protobuf AV tag into a SoundOrVideoTag or a TTSTag."""
    if tag.WhichOneof("value") == "sound_or_video":
        return SoundOrVideoTag(filename=tag.sound_or_video)

    tts = tag.tts
    return TTSTag(
        field_text=tts.field_text,
        lang=tts.lang,
        # copy the repeated fields into plain lists
        voices=list(tts.voices),
        other_args=list(tts.other_args),
        speed=tts.speed,
    )


@dataclass
class TemplateReplacement:
    """A field replacement node produced when rendering a card template."""

    field_name: str
    current_text: str
    filters: List[str]


# A rendered template: plain text nodes interleaved with field replacements.
TemplateReplacementList = List[Union[str, TemplateReplacement]]

MediaSyncProgress = pb.MediaSyncProgress

MediaCheckOutput = pb.MediaCheckOut

FormatTimeSpanContext = pb.FormatTimeSpanIn.Context


@dataclass
class ExtractedLatex:
    """One LaTeX snippet returned by RustBackend.extract_latex()."""

    filename: str
    latex_body: str


@dataclass
class ExtractedLatexOutput:
    """Result of RustBackend.extract_latex(): text plus extracted snippets."""

    html: str
    latex: List[ExtractedLatex]


class ProgressKind(enum.Enum):
    """Which kind of operation a Progress update belongs to."""

    MediaSync = 0
    MediaCheck = 1


@dataclass
class Progress:
    """A progress update received from the backend."""

    kind: ProgressKind
    val: Union[MediaSyncProgress, str]


def proto_replacement_list_to_native(
    nodes: List[pb.RenderedTemplateNode],
) -> TemplateReplacementList:
    """Convert rendered template nodes into strings / TemplateReplacements."""
    results: TemplateReplacementList = []
    for node in nodes:
        if node.WhichOneof("value") == "text":
            results.append(node.text)
        else:
            replacement = node.replacement
            results.append(
                TemplateReplacement(
                    field_name=replacement.field_name,
                    current_text=replacement.current_text,
                    filters=list(replacement.filters),
                )
            )
    return results


def proto_progress_to_native(progress: pb.Progress) -> Progress:
    """Convert a protobuf progress message into a native Progress."""
    kind = progress.WhichOneof("value")
    if kind == "media_sync":
        return Progress(kind=ProgressKind.MediaSync, val=progress.media_sync)
    elif kind == "media_check":
        return Progress(kind=ProgressKind.MediaCheck, val=progress.media_check)
    else:
        assert_impossible_literal(kind)


def _on_progress(progress_bytes: bytes) -> bool:
    """Decode a serialized progress message and pass it to the progress hook.

    Returns whatever hooks.bg_thread_progress_callback() returns.
    """
    progress = pb.Progress()
    progress.ParseFromString(progress_bytes)
    native_progress = proto_progress_to_native(progress)
    return hooks.bg_thread_progress_callback(True, native_progress)


def _optional_int32(value: Optional[int]) -> Optional[pb.OptionalInt32]:
    """Wrap an int in pb.OptionalInt32, passing None through unchanged."""
    if value is None:
        return None
    return pb.OptionalInt32(val=value)


class RustBackend:
    """Thin wrapper that sends protobuf/JSON commands to the Rust backend.

    Each public method builds a pb.BackendInput, sends it via _run_command()
    and returns the matching field of the pb.BackendOutput. Backend errors
    are raised as the exceptions built by proto_exception_to_native().
    """

    def __init__(
        self,
        ftl_folder: Optional[str] = None,
        langs: Optional[List[str]] = None,
        server: bool = False,
    ) -> None:
        """Open a backend, falling back to the anki.lang globals for defaults."""
        # pick up global defaults if not provided
        if ftl_folder is None or langs is None:
            # anki.lang is not imported at the top of this file; import it
            # here so the lookup does not depend on some other module having
            # imported it first.
            import anki.lang

            if ftl_folder is None:
                ftl_folder = os.path.join(anki.lang.locale_folder, "fluent")
            if langs is None:
                langs = [anki.lang.currentLang]

        init_msg = pb.BackendInit(
            locale_folder_path=ftl_folder,
            preferred_langs=langs,
            server=server,
        )
        self._backend = ankirspy.open_backend(init_msg.SerializeToString())
        self._backend.set_progress_callback(_on_progress)

    def _run_command(
        self, input: pb.BackendInput, release_gil: bool = False
    ) -> pb.BackendOutput:
        """Send one command to the backend and return the decoded output.

        release_gil is forwarded unchanged to the backend's command() call.
        Raises the exception from proto_exception_to_native() when the
        output's oneof is "error".
        """
        input_bytes = input.SerializeToString()
        output_bytes = self._backend.command(input_bytes, release_gil)
        output = pb.BackendOutput()
        output.ParseFromString(output_bytes)
        if output.WhichOneof("value") == "error":
            raise proto_exception_to_native(output.error)
        return output

    def open_collection(
        self, col_path: str, media_folder_path: str, media_db_path: str, log_path: str
    ) -> None:
        """Ask the backend to open the collection at the given paths."""
        self._run_command(
            pb.BackendInput(
                open_collection=pb.OpenCollectionIn(
                    collection_path=col_path,
                    media_folder_path=media_folder_path,
                    media_db_path=media_db_path,
                    log_path=log_path,
                )
            ),
            release_gil=True,
        )

    def close_collection(self, downgrade: bool = True) -> None:
        """Ask the backend to close the collection.

        downgrade is sent as the request's downgrade_to_schema11 flag.
        """
        self._run_command(
            pb.BackendInput(
                close_collection=pb.CloseCollectionIn(downgrade_to_schema11=downgrade)
            ),
            release_gil=True,
        )

    def sched_timing_today(
        self,
        created_secs: int,
        created_mins_west: Optional[int],
        now_mins_west: Optional[int],
        rollover: Optional[int],
    ) -> SchedTimingToday:
        """Request today's scheduler timing; now_secs is taken from intTime().

        The three optional arguments are wrapped in pb.OptionalInt32 when
        given and sent as None otherwise.
        """
        return self._run_command(
            pb.BackendInput(
                sched_timing_today=pb.SchedTimingTodayIn(
                    created_secs=created_secs,
                    now_secs=intTime(),
                    created_mins_west=_optional_int32(created_mins_west),
                    now_mins_west=_optional_int32(now_mins_west),
                    rollover_hour=_optional_int32(rollover),
                )
            )
        ).sched_timing_today

    def render_card(
        self, qfmt: str, afmt: str, fields: Dict[str, str], card_ord: int
    ) -> Tuple[TemplateReplacementList, TemplateReplacementList]:
        """Render both templates and return (question nodes, answer nodes)."""
        out = self._run_command(
            pb.BackendInput(
                render_card=pb.RenderCardIn(
                    question_template=qfmt,
                    answer_template=afmt,
                    fields=fields,
                    card_ordinal=card_ord,
                )
            )
        ).render_card

        qnodes = proto_replacement_list_to_native(out.question_nodes)  # type: ignore
        anodes = proto_replacement_list_to_native(out.answer_nodes)  # type: ignore

        return (qnodes, anodes)

    def local_minutes_west(self, stamp: int) -> int:
        """Return the backend's local_minutes_west result for stamp."""
        return self._run_command(
            pb.BackendInput(local_minutes_west=stamp)
        ).local_minutes_west

    def strip_av_tags(self, text: str) -> str:
        """Return the backend's strip_av_tags result for text."""
        return self._run_command(pb.BackendInput(strip_av_tags=text)).strip_av_tags

    def extract_av_tags(
        self, text: str, question_side: bool
    ) -> Tuple[str, List[AVTag]]:
        """Return (text from the backend, AV tags converted to native types)."""
        out = self._run_command(
            pb.BackendInput(
                extract_av_tags=pb.ExtractAVTagsIn(
                    text=text, question_side=question_side
                )
            )
        ).extract_av_tags
        native_tags = [av_tag_to_native(tag) for tag in out.av_tags]

        return out.text, native_tags

    def extract_latex(
        self, text: str, svg: bool, expand_clozes: bool
    ) -> ExtractedLatexOutput:
        """Return the backend's extract_latex result as native dataclasses."""
        out = self._run_command(
            pb.BackendInput(
                extract_latex=pb.ExtractLatexIn(
                    text=text, svg=svg, expand_clozes=expand_clozes
                )
            )
        ).extract_latex

        return ExtractedLatexOutput(
            html=out.text,
            latex=[
                ExtractedLatex(filename=entry.filename, latex_body=entry.latex_body)
                for entry in out.latex
            ],
        )

    def add_file_to_media_folder(self, desired_name: str, data: bytes) -> str:
        """Send an add_media_file request and return the backend's string reply."""
        return self._run_command(
            pb.BackendInput(
                add_media_file=pb.AddMediaFileIn(desired_name=desired_name, data=data)
            )
        ).add_media_file

    def sync_media(self, hkey: str, endpoint: str) -> None:
        """Send a sync_media request with the given key and endpoint."""
        self._run_command(
            pb.BackendInput(sync_media=pb.SyncMediaIn(hkey=hkey, endpoint=endpoint)),
            release_gil=True,
        )

    def check_media(self) -> MediaCheckOutput:
        """Send a check_media request and return its output message."""
        return self._run_command(
            pb.BackendInput(check_media=pb.Empty()),
            release_gil=True,
        ).check_media

    def trash_media_files(self, fnames: List[str]) -> None:
        """Send a trash_media_files request for the given file names."""
        self._run_command(
            pb.BackendInput(trash_media_files=pb.TrashMediaFilesIn(fnames=fnames))
        )

    def translate(self, key: TR, **kwargs: Union[str, int, float]) -> str:
        """Return the backend's translation for key with the given arguments."""
        return self._run_command(
            pb.BackendInput(translate_string=translate_string_in(key, **kwargs))
        ).translate_string

    def format_time_span(
        self,
        seconds: float,
        context: FormatTimeSpanContext = FormatTimeSpanContext.INTERVALS,
    ) -> str:
        """Return the backend's formatted string for a time span."""
        return self._run_command(
            pb.BackendInput(
                format_time_span=pb.FormatTimeSpanIn(seconds=seconds, context=context)
            )
        ).format_time_span

    def studied_today(self, cards: int, seconds: float) -> str:
        """Return the backend's studied_today string."""
        return self._run_command(
            pb.BackendInput(
                studied_today=pb.StudiedTodayIn(cards=cards, seconds=seconds)
            )
        ).studied_today

    def learning_congrats_msg(self, next_due: float, remaining: int) -> str:
        """Return the backend's congrats_learn_msg string."""
        return self._run_command(
            pb.BackendInput(
                congrats_learn_msg=pb.CongratsLearnMsgIn(
                    next_due=next_due, remaining=remaining
                )
            )
        ).congrats_learn_msg

    def empty_trash(self) -> None:
        """Send an empty_trash request."""
        self._run_command(pb.BackendInput(empty_trash=pb.Empty()))

    def restore_trash(self) -> None:
        """Send a restore_trash request."""
        self._run_command(pb.BackendInput(restore_trash=pb.Empty()))

    def db_query(
        self, sql: str, args: Sequence[ValueForDB], first_row_only: bool
    ) -> List[DBRow]:
        """Run a "query" DB command and return the decoded result."""
        return self._db_command(
            dict(kind="query", sql=sql, args=args, first_row_only=first_row_only)
        )

    def db_execute_many(self, sql: str, args: List[List[ValueForDB]]) -> List[DBRow]:
        """Run an "executemany" DB command and return the decoded result."""
        return self._db_command(dict(kind="executemany", sql=sql, args=args))

    def db_begin(self) -> None:
        """Run a "begin" DB command."""
        return self._db_command(dict(kind="begin"))

    def db_commit(self) -> None:
        """Run a "commit" DB command."""
        return self._db_command(dict(kind="commit"))

    def db_rollback(self) -> None:
        """Run a "rollback" DB command."""
        return self._db_command(dict(kind="rollback"))

    def _db_command(self, input: Dict[str, Any]) -> Any:
        """JSON-encode a DB command, send it, and JSON-decode the reply."""
        return orjson.loads(self._backend.db_command(orjson.dumps(input)))

    def search_cards(
        self, search: str, order: Union[bool, str, int], reverse: bool = False
    ) -> Sequence[int]:
        """Return the card ids matching search.

        order selects the sort mode: a str is sent as a custom order, True
        uses the order from config, False requests no ordering, and any
        other value is treated as a BuiltinSortKind number. reverse is only
        sent with a builtin sort kind.
        """
        if isinstance(order, str):
            mode = pb.SortOrder(custom=order)
        elif order is True:
            mode = pb.SortOrder(from_config=pb.Empty())
        elif order is False:
            mode = pb.SortOrder(none=pb.Empty())
        else:
            # sadly we can't use the protobuf type in a Union, so we
            # have to accept an int and convert it
            kind = BuiltinSortKind.Value(BuiltinSortKind.Name(order))
            mode = pb.SortOrder(
                builtin=pb.BuiltinSearchOrder(kind=kind, reverse=reverse)
            )

        return self._run_command(
            pb.BackendInput(search_cards=pb.SearchCardsIn(search=search, order=mode))
        ).search_cards.card_ids

    def search_notes(self, search: str) -> Sequence[int]:
        """Return the note ids matching search."""
        return self._run_command(
            pb.BackendInput(search_notes=pb.SearchNotesIn(search=search))
        ).search_notes.note_ids

    def get_card(self, cid: int) -> Optional[pb.Card]:
        """Return the card for cid, or None if the reply has no card field."""
        output = self._run_command(pb.BackendInput(get_card=cid)).get_card
        if output.HasField("card"):
            return output.card
        return None

    def update_card(self, card: BackendCard) -> None:
        """Send an update_card request for card."""
        self._run_command(pb.BackendInput(update_card=card))

    # returns the new card id
    def add_card(self, card: BackendCard) -> int:
        """Send an add_card request and return the backend's reply."""
        return self._run_command(pb.BackendInput(add_card=card)).add_card

    def get_deck_config(self, dcid: int) -> Dict[str, Any]:
        """Return the JSON-decoded deck config for dcid."""
        jstr = self._run_command(pb.BackendInput(get_deck_config=dcid)).get_deck_config
        return orjson.loads(jstr)

    def add_or_update_deck_config(
        self, conf: Dict[str, Any], preserve_usn: bool
    ) -> None:
        """Send conf to the backend and store the returned id in conf["id"]."""
        conf_json = orjson.dumps(conf)
        config_id = self._run_command(
            pb.BackendInput(
                add_or_update_deck_config=pb.AddOrUpdateDeckConfigIn(
                    config=conf_json, preserve_usn_and_mtime=preserve_usn
                )
            )
        ).add_or_update_deck_config
        conf["id"] = config_id

    def all_deck_config(self) -> Sequence[Dict[str, Any]]:
        """Return the JSON-decoded all_deck_config reply."""
        jstr = self._run_command(
            pb.BackendInput(all_deck_config=pb.Empty())
        ).all_deck_config
        return orjson.loads(jstr)

    def new_deck_config(self) -> Dict[str, Any]:
        """Return the JSON-decoded new_deck_config reply."""
        jstr = self._run_command(
            pb.BackendInput(new_deck_config=pb.Empty())
        ).new_deck_config
        return orjson.loads(jstr)

    def remove_deck_config(self, dcid: int) -> None:
        """Send a remove_deck_config request for dcid."""
        self._run_command(pb.BackendInput(remove_deck_config=dcid))

    def abort_media_sync(self) -> None:
        """Send an abort_media_sync request."""
        self._run_command(pb.BackendInput(abort_media_sync=pb.Empty()))

    def all_tags(self) -> Iterable[TagUsnTuple]:
        """Return the tags field of the all_tags reply."""
        return self._run_command(pb.BackendInput(all_tags=pb.Empty())).all_tags.tags

    def register_tags(self, tags: str, usn: Optional[int], clear_first: bool) -> bool:
        """Send a register_tags request and return the backend's reply.

        When usn is None the request carries usn=0 with preserve_usn=False;
        otherwise the given usn is sent with preserve_usn=True.
        """
        if usn is None:
            preserve_usn = False
            usn_ = 0
        else:
            usn_ = usn
            preserve_usn = True

        return self._run_command(
            pb.BackendInput(
                register_tags=pb.RegisterTagsIn(
                    tags=tags,
                    usn=usn_,
                    preserve_usn=preserve_usn,
                    clear_first=clear_first,
                )
            )
        ).register_tags

    def before_upload(self) -> None:
        """Send a before_upload request."""
        self._run_command(pb.BackendInput(before_upload=pb.Empty()))

    def get_changed_tags(self, usn: int) -> List[str]:
        """Return the tags from the get_changed_tags reply as a list."""
        return list(
            self._run_command(
                pb.BackendInput(get_changed_tags=usn)
            ).get_changed_tags.tags
        )

    def get_config_json(self, key: str) -> Any:
        """Return the JSON-decoded config value stored under key.

        Raises KeyError(key) when the backend replies with empty bytes.
        """
        raw = self._run_command(pb.BackendInput(get_config_json=key)).get_config_json
        if raw == b"":
            # include the key so the error is actionable for the caller
            raise KeyError(key)
        return orjson.loads(raw)

    def set_config_json(self, key: str, val: Any) -> None:
        """Send val, JSON-encoded, as the config value for key."""
        self._run_command(
            pb.BackendInput(
                set_config_json=pb.SetConfigJson(key=key, val=orjson.dumps(val))
            )
        )

    def remove_config(self, key: str) -> None:
        """Send a set_config_json request with the remove field set for key."""
        self._run_command(
            pb.BackendInput(
                set_config_json=pb.SetConfigJson(key=key, remove=pb.Empty())
            )
        )

    def get_all_config(self) -> Dict[str, Any]:
        """Return the JSON-decoded get_all_config reply."""
        jstr = self._run_command(
            pb.BackendInput(get_all_config=pb.Empty())
        ).get_all_config
        return orjson.loads(jstr)

    def set_all_config(self, conf: Dict[str, Any]) -> None:
        """Send conf, JSON-encoded, in a set_all_config request."""
        self._run_command(pb.BackendInput(set_all_config=orjson.dumps(conf)))

    def get_changed_notetypes(self, usn: int) -> Dict[str, Dict[str, Any]]:
        """Return the JSON-decoded get_changed_notetypes reply for usn."""
        jstr = self._run_command(
            pb.BackendInput(get_changed_notetypes=usn)
        ).get_changed_notetypes
        return orjson.loads(jstr)

    def get_all_decks(self) -> Dict[str, Dict[str, Any]]:
        """Return the JSON-decoded get_all_decks reply."""
        jstr = self._run_command(
            pb.BackendInput(get_all_decks=pb.Empty())
        ).get_all_decks
        return orjson.loads(jstr)

    def all_stock_notetypes(self) -> List[NoteType]:
        """Return the notetypes from the all_stock_notetypes reply as a list."""
        return list(
            self._run_command(
                pb.BackendInput(all_stock_notetypes=pb.Empty())
            ).all_stock_notetypes.notetypes
        )

    def get_notetype_names_and_ids(self) -> List[pb.NoteTypeNameID]:
        """Return the entries of the get_notetype_names reply as a list."""
        return list(
            self._run_command(
                pb.BackendInput(get_notetype_names=pb.Empty())
            ).get_notetype_names.entries
        )

    def get_notetype_use_counts(self) -> List[pb.NoteTypeNameIDUseCount]:
        """Return the entries of the get_notetype_names_and_counts reply."""
        return list(
            self._run_command(
                pb.BackendInput(get_notetype_names_and_counts=pb.Empty())
            ).get_notetype_names_and_counts.entries
        )

    def get_notetype_legacy(self, ntid: int) -> Optional[Dict]:
        """Return the JSON-decoded notetype for ntid, or None if not found."""
        try:
            raw = self._run_command(
                pb.BackendInput(get_notetype_legacy=ntid)
            ).get_notetype_legacy
        except NotFoundError:
            return None
        return orjson.loads(raw)

    def get_notetype_id_by_name(self, name: str) -> Optional[int]:
        """Return the notetype id for name; a falsy reply becomes None."""
        return (
            self._run_command(
                pb.BackendInput(get_notetype_id_by_name=name)
            ).get_notetype_id_by_name
            or None
        )

    def add_or_update_notetype(self, nt: Dict[str, Any], preserve_usn: bool) -> None:
        """Send nt to the backend and store the returned id in nt["id"]."""
        bjson = orjson.dumps(nt)
        notetype_id = self._run_command(
            pb.BackendInput(
                add_or_update_notetype=pb.AddOrUpdateNotetypeIn(
                    json=bjson, preserve_usn_and_mtime=preserve_usn
                )
            ),
            release_gil=True,
        ).add_or_update_notetype
        nt["id"] = notetype_id

    def remove_notetype(self, ntid: int) -> None:
        """Send a remove_notetype request for ntid."""
        self._run_command(pb.BackendInput(remove_notetype=ntid), release_gil=True)

    def new_note(self, ntid: int) -> BackendNote:
        """Return the note from the new_note reply for notetype ntid."""
        return self._run_command(pb.BackendInput(new_note=ntid)).new_note

    def add_note(self, note: BackendNote, deck_id: int) -> int:
        """Send an add_note request and return the backend's reply."""
        return self._run_command(
            pb.BackendInput(add_note=pb.AddNoteIn(note=note, deck_id=deck_id))
        ).add_note

    def update_note(self, note: BackendNote) -> None:
        """Send an update_note request for note."""
        self._run_command(pb.BackendInput(update_note=note))

    def get_note(self, nid: int) -> Optional[BackendNote]:
        """Return the note for nid, or None if the backend reports not found."""
        try:
            return self._run_command(pb.BackendInput(get_note=nid)).get_note
        except NotFoundError:
            return None

    def empty_cards_report(self) -> pb.EmptyCardsReport:
        """Return the get_empty_cards reply."""
        return self._run_command(
            pb.BackendInput(get_empty_cards=pb.Empty()), release_gil=True
        ).get_empty_cards

    def get_deck_legacy(self, did: int) -> Optional[Dict]:
        """Return the JSON-decoded deck for did, or None if not found."""
        try:
            raw = self._run_command(
                pb.BackendInput(get_deck_legacy=did)
            ).get_deck_legacy
        except NotFoundError:
            return None
        return orjson.loads(raw)

    def get_deck_names_and_ids(self) -> List[pb.DeckNameID]:
        """Return the entries of the get_deck_names reply as a list."""
        return list(
            self._run_command(
                pb.BackendInput(get_deck_names=pb.Empty())
            ).get_deck_names.entries
        )

    def add_or_update_deck_legacy(
        self, deck: Dict[str, Any], preserve_usn: bool
    ) -> None:
        """Send deck to the backend and store the returned id in deck["id"]."""
        deck_json = orjson.dumps(deck)
        deck_id = self._run_command(
            pb.BackendInput(
                add_or_update_deck_legacy=pb.AddOrUpdateDeckLegacyIn(
                    deck=deck_json, preserve_usn_and_mtime=preserve_usn
                )
            )
        ).add_or_update_deck_legacy
        deck["id"] = deck_id

    def new_deck_legacy(self, filtered: bool) -> Dict[str, Any]:
        """Return the JSON-decoded new_deck_legacy reply."""
        jstr = self._run_command(
            pb.BackendInput(new_deck_legacy=filtered)
        ).new_deck_legacy
        return orjson.loads(jstr)

    def get_deck_id_by_name(self, name: str) -> Optional[int]:
        """Return the deck id for name; a falsy reply becomes None."""
        return (
            self._run_command(
                pb.BackendInput(get_deck_id_by_name=name)
            ).get_deck_id_by_name
            or None
        )

    def remove_deck(self, did: int) -> None:
        """Send a remove_deck request for did."""
        self._run_command(pb.BackendInput(remove_deck=did))

    def deck_tree(self, include_counts: bool) -> DeckTreeNode:
        """Return the deck_tree reply."""
        return self._run_command(
            pb.BackendInput(deck_tree=pb.DeckTreeIn(include_counts=include_counts))
        ).deck_tree

    def check_database(self) -> List[str]:
        """Return the problems from the check_database reply as a list."""
        return list(
            self._run_command(
                pb.BackendInput(check_database=pb.Empty()), release_gil=True
            ).check_database.problems
        )

    def legacy_deck_tree(self) -> Sequence:
        """Return element 5 of the JSON-decoded deck_tree_legacy reply."""
        raw = self._run_command(
            pb.BackendInput(deck_tree_legacy=pb.Empty())
        ).deck_tree_legacy
        # NOTE(review): index 5 presumably holds the root's children - confirm
        return orjson.loads(raw)[5]

    def field_names_for_note_ids(self, nids: List[int]) -> Sequence[str]:
        """Return the fields from the field_names_for_notes reply."""
        return self._run_command(
            pb.BackendInput(field_names_for_notes=pb.FieldNamesForNotesIn(nids=nids)),
            release_gil=True,
        ).field_names_for_notes.fields

    def find_and_replace(
        self,
        nids: List[int],
        search: str,
        repl: str,
        re: bool,
        nocase: bool,
        field_name: Optional[str],
    ) -> int:
        """Send a find_and_replace request and return the backend's reply.

        nocase is inverted into the request's match_case flag.
        """
        return self._run_command(
            pb.BackendInput(
                find_and_replace=pb.FindAndReplaceIn(
                    nids=nids,
                    search=search,
                    replacement=repl,
                    regex=re,
                    match_case=not nocase,
                    field_name=field_name,
                )
            ),
            release_gil=True,
        ).find_and_replace

    def after_note_updates(
        self, nids: List[int], generate_cards: bool, mark_notes_modified: bool
    ) -> None:
        """Send an after_note_updates request for the given note ids."""
        self._run_command(
            pb.BackendInput(
                after_note_updates=pb.AfterNoteUpdatesIn(
                    nids=nids,
                    generate_cards=generate_cards,
                    mark_notes_modified=mark_notes_modified,
                )
            ),
            release_gil=True,
        )

    def add_note_tags(self, nids: List[int], tags: str) -> int:
        """Send an add_note_tags request and return the backend's reply."""
        return self._run_command(
            pb.BackendInput(add_note_tags=pb.AddNoteTagsIn(nids=nids, tags=tags))
        ).add_note_tags

    def update_note_tags(
        self, nids: List[int], tags: str, replacement: str, regex: bool
    ) -> int:
        """Send an update_note_tags request and return the backend's reply."""
        return self._run_command(
            pb.BackendInput(
                update_note_tags=pb.UpdateNoteTagsIn(
                    nids=nids, tags=tags, replacement=replacement, regex=regex
                )
            )
        ).update_note_tags

    def set_local_minutes_west(self, mins: int) -> None:
        """Send a set_local_minutes_west request with mins."""
        self._run_command(pb.BackendInput(set_local_minutes_west=mins))


def translate_string_in(
    key: TR, **kwargs: Union[str, int, float]
) -> pb.TranslateStringIn:
    """Build a TranslateStringIn message for key from keyword arguments.

    str values are sent in the `str` field of TranslateArgValue; every other
    value is sent in the `number` field.
    """
    args = {}
    for name, value in kwargs.items():
        if isinstance(value, str):
            args[name] = pb.TranslateArgValue(str=value)
        else:
            args[name] = pb.TranslateArgValue(number=value)
    return pb.TranslateStringIn(key=key, args=args)


# temporarily force logging of media handling
# (setdefault leaves any RUST_LOG value already set in the environment alone)
os.environ.setdefault("RUST_LOG", "warn,anki::media=debug,anki::dbcheck=debug")