mirror of
https://github.com/ankitects/anki.git
synced 2025-09-18 14:02:21 -04:00
288 lines
8.2 KiB
Python
288 lines
8.2 KiB
Python
# Copyright: Ankitects Pty Ltd and contributors
|
|
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
|
|
# pylint: skip-file
|
|
|
|
"""
|
|
Python bindings for Anki's Rust libraries.
|
|
|
|
Please do not access methods on the backend directly - they may be changed
or removed at any time. Instead, please use the methods on the collection.
For example, don't use col.backend.all_deck_config(); use
col.decks.all_config() instead.
|
|
|
|
If you need to access a backend method that is not currently accessible
|
|
via the collection, please send through a pull request that adds a method.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import enum
|
|
import json
|
|
import os
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
|
|
|
import anki._rsbridge
|
|
import anki.backend_pb2 as pb
|
|
import anki.buildinfo
|
|
from anki import hooks
|
|
from anki.dbproxy import Row as DBRow
|
|
from anki.dbproxy import ValueForDB
|
|
from anki.fluent_pb2 import FluentString as TR
|
|
from anki.types import assert_impossible_literal
|
|
|
|
try:
    from anki.rsbackend_gen import RustBackendGenerated
except ImportError:
    # Importing the generated bindings fails during initial setup; fall back
    # to an empty base class so that this module can still be loaded.
    class RustBackendGenerated:  # type: ignore
        """Empty stand-in used when anki.rsbackend_gen cannot be imported."""
|
|
|
|
|
|
if TYPE_CHECKING:
    # These names exist only for static type checkers and are never bound at
    # runtime. Annotations that mention them are not evaluated, because the
    # module starts with ``from __future__ import annotations``.
    from anki.fluent_pb2 import FluentStringValue as TRValue

    FormatTimeSpanContextValue = pb.FormatTimespanIn.ContextValue
|
|
|
|
# The compiled _rsbridge extension and the Python code must come from the same
# build. This is an explicit check rather than an ``assert``, because asserts
# are stripped when Python runs with -O, which would silently skip it.
if anki._rsbridge.buildhash() != anki.buildinfo.buildhash:
    raise RuntimeError(
        f"anki._rsbridge build hash {anki._rsbridge.buildhash()!r} does not "
        f"match anki.buildinfo build hash {anki.buildinfo.buildhash!r}"
    )
|
|
|
|
# Shorter public names for protobuf types returned by / passed to the backend.

# Scheduling and searching
SchedTimingToday = pb.SchedTimingTodayOut
CountsForDeckToday = pb.CountsForDeckTodayOut
BuiltinSortKind = pb.BuiltinSearchOrder.BuiltinSortKind

# Cards, notes, tags and note types
BackendCard = pb.Card
BackendNote = pb.Note
TagUsnTuple = pb.TagUsnTuple
NoteType = pb.NoteType
StockNoteType = pb.StockNoteType

# Decks
DeckTreeNode = pb.DeckTreeNode

# Syncing
SyncAuth = pb.SyncAuth
SyncOutput = pb.SyncCollectionOut
SyncStatus = pb.SyncStatusOut
|
|
|
|
try:
    import orjson

    to_json_bytes = orjson.dumps
    from_json_bytes = orjson.loads
except ImportError:
    # Compat layer for builds that can't use orjson (e.g. 32-bit builds).
    # Only a failed import should select this path; a bare ``except:`` would
    # also swallow KeyboardInterrupt/SystemExit and hide unrelated errors.

    def to_json_bytes(obj: Any) -> bytes:  # type: ignore
        """Serialize *obj* to JSON, returned as UTF-8 encoded bytes."""
        return json.dumps(obj).encode("utf8")

    # json.loads accepts bytes as well as str.
    from_json_bytes = json.loads
|
|
|
|
|
|
class Interrupted(Exception):
    """The backend reported that the requested operation was interrupted."""
|
|
|
|
|
|
class StringError(Exception):
    """Exception whose string form is its first positional argument.

    Any further positional arguments (which some subclasses use to carry
    extra data) are left out of ``str()``.
    """

    def __str__(self) -> str:
        # Never raise from __str__: an IndexError (no args) or TypeError
        # (non-str message) here would break printing, logging and tracebacks.
        if not self.args:
            return ""
        return str(self.args[0])  # pylint: disable=unsubscriptable-object
|
|
|
|
|
|
# Enum types for the ``kind`` carried by network and sync errors.
NetworkErrorKind = pb.NetworkError.NetworkErrorKind
SyncErrorKind = pb.SyncError.SyncErrorKind
|
|
|
|
|
|
class NetworkError(StringError):
    """Network failure reported by the backend.

    Constructed as ``NetworkError(message, kind)``; the message is the
    string form and ``kind()`` returns the second argument.
    """

    def kind(self) -> pb.NetworkError.NetworkErrorKindValue:
        error_kind = self.args[1]
        return error_kind
|
|
|
|
|
|
class SyncError(StringError):
    """Sync failure reported by the backend.

    Constructed as ``SyncError(message, kind)``; the message is the string
    form and ``kind()`` returns the second argument.
    """

    def kind(self) -> pb.SyncError.SyncErrorKindValue:
        error_kind = self.args[1]
        return error_kind
|
|
|
|
|
|
class IOError(StringError):
    """I/O failure reported by the backend.

    Note: inside this module the name shadows the built-in ``IOError``; this
    class derives from StringError, not OSError.
    """
|
|
|
|
|
|
class DBError(StringError):
    """Database failure reported by the backend; str() is the message."""
|
|
|
|
|
|
class TemplateError(StringError):
    """The backend failed to parse a card template; str() is the message."""
|
|
|
|
|
|
class NotFoundError(Exception):
    """The backend could not find the requested item."""
|
|
|
|
|
|
class ExistsError(Exception):
    """The backend reported that the item being created already exists."""
|
|
|
|
|
|
class DeckIsFilteredError(Exception):
    """The backend refused an operation because the deck is a filtered deck."""
|
|
|
|
|
|
class InvalidInput(StringError):
    """The backend rejected its input as invalid; str() is the message."""
|
|
|
|
|
|
def proto_exception_to_native(err: pb.BackendError) -> Exception:
    """Map a decoded ``BackendError`` message onto a Python exception.

    The exception class is chosen by whichever field of the ``value`` oneof
    is set. Unknown variants print a notice and fall back to StringError.
    The exception is returned, not raised.
    """
    variant = err.WhichOneof("value")

    # Variants whose exception takes no arguments.
    no_payload = {
        "interrupted": Interrupted,
        "not_found_error": NotFoundError,
        "exists": ExistsError,
        "deck_is_filtered": DeckIsFilteredError,
    }
    # Variants whose exception takes only the localized message.
    message_only = {
        "io_error": IOError,
        "db_error": DBError,
        "template_parse": TemplateError,
        "invalid_input": InvalidInput,
        "json_error": StringError,
        "proto_error": StringError,
    }

    if variant in no_payload:
        return no_payload[variant]()
    if variant in message_only:
        return message_only[variant](err.localized)
    if variant == "network_error":
        return NetworkError(err.localized, err.network_error.kind)
    if variant == "sync_error":
        return SyncError(err.localized, err.sync_error.kind)

    print("unhandled error type:", variant)
    return StringError(err.localized)
|
|
|
|
|
|
# Protobuf progress payload types, re-exported under shorter names.
MediaSyncProgress = pb.MediaSyncProgress
FullSyncProgress = pb.FullSyncProgress
NormalSyncProgress = pb.NormalSyncProgress
DatabaseCheckProgress = pb.DatabaseCheckProgress

# Enum of contexts accepted by the backend's timespan formatting.
FormatTimeSpanContext = pb.FormatTimespanIn.Context
|
|
|
|
|
|
class ProgressKind(enum.Enum):
    """Identifies which payload a Progress update carries."""

    # Used when a progress message carries no recognised payload.
    NoProgress = 0
    MediaSync = 1
    MediaCheck = 2
    FullSync = 3
    NormalSync = 4
    DatabaseCheck = 5
|
|
|
|
|
|
@dataclass
class Progress:
    """A progress update decoded from a backend ``pb.Progress`` message.

    ``kind`` says which payload ``val`` holds. ``val`` is the matching oneof
    field of the message, or ``""`` when no recognised field is set.
    """

    kind: ProgressKind
    # Use the module-level aliases consistently (the annotation previously
    # mixed ``pb.FullSyncProgress`` with the aliased names).
    val: Union[
        MediaSyncProgress,
        FullSyncProgress,
        NormalSyncProgress,
        DatabaseCheckProgress,
        str,
    ]

    @staticmethod
    def from_proto(proto: pb.Progress) -> Progress:
        """Build a Progress from *proto*, keyed on its ``value`` oneof."""
        kinds_by_field = {
            "media_sync": ProgressKind.MediaSync,
            "media_check": ProgressKind.MediaCheck,
            "full_sync": ProgressKind.FullSync,
            "normal_sync": ProgressKind.NormalSync,
            "database_check": ProgressKind.DatabaseCheck,
        }
        which = proto.WhichOneof("value")
        kind = kinds_by_field.get(which)
        if kind is None:
            # No field set (WhichOneof returned None) or an unrecognised one.
            return Progress(kind=ProgressKind.NoProgress, val="")
        # The oneof field name doubles as the attribute holding the payload.
        return Progress(kind=kind, val=getattr(proto, which))
|
|
|
|
|
|
class RustBackend(RustBackendGenerated):
    """Python handle to an open Rust backend instance.

    Most of the API is inherited from ``RustBackendGenerated``. This class
    adds backend construction, the JSON-based DB command channel and the
    translation of encoded backend errors into Python exceptions. Callers
    should prefer the collection-level methods over calling these directly.
    """

    def __init__(
        self,
        ftl_folder: Optional[str] = None,
        langs: Optional[List[str]] = None,
        server: bool = False,
    ) -> None:
        """Open a new backend.

        Args:
            ftl_folder: Folder holding the Fluent translation files. Defaults
                to the ``fluent`` subfolder of ``anki.lang.locale_folder``.
            langs: Preferred languages, sent as ``preferred_langs``. Defaults
                to ``[anki.lang.currentLang]``.
            server: Forwarded unchanged as ``BackendInit.server``.
        """
        if ftl_folder is None or langs is None:
            # This module never imports anki.lang itself, so the previous
            # ``anki.lang.<attr>`` access only worked if another module had
            # already imported it. A lazy, call-time import makes the
            # defaults independent of import order.
            from anki import lang as anki_lang

            if ftl_folder is None:
                ftl_folder = os.path.join(anki_lang.locale_folder, "fluent")
            if langs is None:
                langs = [anki_lang.currentLang]

        init_msg = pb.BackendInit(
            locale_folder_path=ftl_folder,
            preferred_langs=langs,
            server=server,
        )
        self._backend = anki._rsbridge.open_backend(init_msg.SerializeToString())

    def db_query(
        self, sql: str, args: Sequence[ValueForDB], first_row_only: bool
    ) -> List[DBRow]:
        """Run a query through the backend and return the decoded rows."""
        return self._db_command(
            dict(kind="query", sql=sql, args=args, first_row_only=first_row_only)
        )

    def db_execute_many(self, sql: str, args: List[List[ValueForDB]]) -> List[DBRow]:
        """Execute *sql* once per argument list in *args*."""
        return self._db_command(dict(kind="executemany", sql=sql, args=args))

    def db_begin(self) -> None:
        """Send a ``begin`` DB command."""
        return self._db_command(dict(kind="begin"))

    def db_commit(self) -> None:
        """Send a ``commit`` DB command."""
        return self._db_command(dict(kind="commit"))

    def db_rollback(self) -> None:
        """Send a ``rollback`` DB command."""
        return self._db_command(dict(kind="rollback"))

    def _db_command(self, input: Dict[str, Any]) -> Any:
        """Send a JSON-encoded DB command to the backend and decode the reply.

        Raises:
            The native exception for any error reported by the backend (see
            ``proto_exception_to_native``).
        """
        # Encode and decode outside the ``try``. Otherwise a local JSON error
        # (e.g. an unserializable argument) would be misread as an encoded
        # backend error, and ``bytes(e.args[0])`` would then raise an
        # unrelated TypeError that hides the real problem.
        input_bytes = to_json_bytes(input)
        try:
            output = self._backend.db_command(input_bytes)
        except Exception as e:
            raise self._backend_error_to_native(e)
        return from_json_bytes(output)

    def translate(self, key: TRValue, **kwargs: Union[str, int, float]) -> str:
        """Translate *key*, substituting the keyword arguments into the string."""
        return self.translate_string(translate_string_in(key, **kwargs))

    def format_time_span(
        self,
        seconds: float,
        context: FormatTimeSpanContextValue = FormatTimeSpanContext.INTERVALS,
    ) -> str:
        """Deprecated: print a notice and forward to ``self.format_timespan()``."""
        print(
            "please use col.format_timespan() instead of col.backend.format_time_span()"
        )
        return self.format_timespan(seconds=seconds, context=context)

    def _run_command(self, method: int, input: Any) -> bytes:
        """Serialize the protobuf *input* and invoke backend method *method*.

        Returns the raw reply bytes. Presumably the generated methods parse
        them; this is not visible here (TODO confirm in rsbackend_gen).
        """
        input_bytes = input.SerializeToString()
        try:
            return self._backend.command(method, input_bytes)
        except Exception as e:
            raise self._backend_error_to_native(e)

    @staticmethod
    def _backend_error_to_native(exc: Exception) -> Exception:
        """Decode the serialized ``pb.BackendError`` in ``exc.args[0]``.

        Returns the matching native exception for the caller to raise.
        """
        err = pb.BackendError()
        err.ParseFromString(bytes(exc.args[0]))
        return proto_exception_to_native(err)
|
|
|
|
|
|
def translate_string_in(
    key: TRValue, **kwargs: Union[str, int, float]
) -> pb.TranslateStringIn:
    """Build a ``TranslateStringIn`` request for *key*.

    Every keyword argument becomes a named translation argument. Values that
    are str are sent as ``str``; all others are sent as ``number``.
    """

    def as_arg_value(value: Union[str, int, float]) -> pb.TranslateArgValue:
        if isinstance(value, str):
            return pb.TranslateArgValue(str=value)
        return pb.TranslateArgValue(number=value)

    arg_values = {name: as_arg_value(value) for name, value in kwargs.items()}
    return pb.TranslateStringIn(key=key, args=arg_values)
|
|
|
|
|
|
# Default the Rust log filter to warn overall, with debug output for the media,
# sync and dbcheck modules. A RUST_LOG value that is already set is left
# untouched. (This was originally noted as a temporary measure for
# media-handling logs.)
os.environ.setdefault(
    "RUST_LOG", "warn,anki::media=debug,anki::sync=debug,anki::dbcheck=debug"
)
|