diff --git a/pylib/anki/collection.py b/pylib/anki/collection.py index 01f96fa8c..b21a78e26 100644 --- a/pylib/anki/collection.py +++ b/pylib/anki/collection.py @@ -23,7 +23,7 @@ import anki.template from anki import hooks from anki.cards import Card from anki.consts import * -from anki.db import DB +from anki.dbproxy import DBProxy from anki.decks import DeckManager from anki.errors import AnkiError from anki.lang import _, ngettext @@ -67,7 +67,7 @@ defaultConf = { # this is initialized by storage.Collection class _Collection: - db: Optional[DB] + db: Optional[DBProxy] sched: Union[V1Scheduler, V2Scheduler] crt: int mod: int @@ -80,7 +80,7 @@ class _Collection: def __init__( self, - db: DB, + db: DBProxy, backend: RustBackend, server: Optional["anki.storage.ServerData"] = None, log: bool = False, @@ -267,7 +267,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""", def reopen(self) -> None: "Reconnect to DB (after changing threads, etc)." if not self.db: - self.db = DB(self.path) + self.db = DBProxy(self.path) self.media.connect() self._openLog() diff --git a/pylib/anki/dbproxy.py b/pylib/anki/dbproxy.py new file mode 100644 index 000000000..1d11e2a03 --- /dev/null +++ b/pylib/anki/dbproxy.py @@ -0,0 +1,112 @@ +# Copyright: Ankitects Pty Ltd and contributors +# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +import os +import time +from sqlite3 import Cursor +from sqlite3 import dbapi2 as sqlite +from typing import Any, List, Type + + +class DBProxy: + def __init__(self, path: str, timeout: int = 0) -> None: + self._db = sqlite.connect(path, timeout=timeout) + self._db.text_factory = self._textFactory + self._path = path + self.echo = os.environ.get("DBECHO") + self.mod = False + + def execute(self, sql: str, *a, **ka) -> Cursor: + s = sql.strip().lower() + # mark modified? + for stmt in "insert", "update", "delete": + if s.startswith(stmt): + self.mod = True + t = time.time() + if ka: + # execute("...where id = :id", id=5) + res = self._db.execute(sql, ka) + else: + # execute("...where id = ?", 5) + res = self._db.execute(sql, a) + if self.echo: + # print a, ka + print(sql, "%0.3fms" % ((time.time() - t) * 1000)) + if self.echo == "2": + print(a, ka) + return res + + def executemany(self, sql: str, l: Any) -> None: + self.mod = True + t = time.time() + self._db.executemany(sql, l) + if self.echo: + print(sql, "%0.3fms" % ((time.time() - t) * 1000)) + if self.echo == "2": + print(l) + + def commit(self) -> None: + t = time.time() + self._db.commit() + if self.echo: + print("commit %0.3fms" % ((time.time() - t) * 1000)) + + def executescript(self, sql: str) -> None: + self.mod = True + if self.echo: + print(sql) + self._db.executescript(sql) + + def rollback(self) -> None: + self._db.rollback() + + def scalar(self, *a, **kw) -> Any: + res = self.execute(*a, **kw).fetchone() + if res: + return res[0] + return None + + def all(self, *a, **kw) -> List: + return self.execute(*a, **kw).fetchall() + + def first(self, *a, **kw) -> Any: + c = self.execute(*a, **kw) + res = c.fetchone() + c.close() + return res + + def list(self, *a, **kw) -> List: + return [x[0] for x in self.execute(*a, **kw)] + + def close(self) -> None: + self._db.text_factory = None + self._db.close() + + def set_progress_handler(self, *args) -> None: + self._db.set_progress_handler(*args) + + def __enter__(self) -> "DBProxy": + self._db.execute("begin") + return self + + def __exit__(self, exc_type, *args) -> None: + self._db.close() + + def totalChanges(self) -> Any: + return self._db.total_changes + + def interrupt(self) -> None: + self._db.interrupt() + + def setAutocommit(self, autocommit: bool) -> None: + if autocommit: + self._db.isolation_level = None + else: + self._db.isolation_level = "" + + # strip out invalid utf-8 when reading from db + def _textFactory(self, data: bytes) -> str: + return str(data, errors="ignore") + + def cursor(self, factory: Type[Cursor] = Cursor) -> Cursor: + return self._db.cursor(factory) diff --git a/pylib/anki/storage.py b/pylib/anki/storage.py index 665291cfd..0a1001466 100644 --- a/pylib/anki/storage.py +++ b/pylib/anki/storage.py @@ -9,7 +9,7 @@ from typing import Any, Dict, Optional, Tuple from anki.collection import _Collection from anki.consts import * -from anki.db import DB +from anki.dbproxy import DBProxy from anki.lang import _ from anki.media import media_paths_from_col_path from anki.rsbackend import RustBackend @@ -44,7 +44,7 @@ def Collection( for c in ("/", ":", "\\"): assert c not in base # connect - db = DB(path) + db = DBProxy(path) db.setAutocommit(True) if create: ver = _createDB(db) @@ -78,7 +78,7 @@ def Collection( return col -def _upgradeSchema(db: DB) -> Any: +def _upgradeSchema(db: DBProxy) -> Any: ver = db.scalar("select ver from col") if ver == SCHEMA_VERSION: return ver @@ -238,7 +238,7 @@ def _upgradeClozeModel(col, m) -> None: ###################################################################### -def _createDB(db: DB) -> int: +def _createDB(db: DBProxy) -> int: db.execute("pragma page_size = 4096") db.execute("pragma legacy_file_format = 0") db.execute("vacuum") @@ -248,7 +248,7 @@ def _createDB(db: DB) -> int: return SCHEMA_VERSION -def _addSchema(db: DB, setColConf: bool = True) -> None: +def _addSchema(db: DBProxy, setColConf: bool = True) -> None: db.executescript( """ create table if not exists col ( @@ -329,7 +329,7 @@ values(1,0,0,%(s)s,%(v)s,0,0,0,'','{}','','','{}'); _addColVars(db, *_getColVars(db)) -def _getColVars(db: DB) -> Tuple[Any, Any, Dict[str, Any]]: +def _getColVars(db: DBProxy) -> Tuple[Any, Any, Dict[str, Any]]: import anki.collection import anki.decks @@ -344,7 +344,7 @@ def _getColVars(db: DB) -> Tuple[Any, Any, Dict[str, Any]]: def _addColVars( - db: DB, g: Dict[str, Any], gc: Dict[str, Any], c: Dict[str, Any] + db: DBProxy, g: Dict[str, Any], gc: Dict[str, Any], c: Dict[str, Any] ) -> None: db.execute( """ @@ -355,7 +355,7 @@ update col set conf = ?, decks = ?, dconf = ?""", ) -def _updateIndices(db: DB) -> None: +def _updateIndices(db: DBProxy) -> None: "Add indices to the DB." db.executescript( """ diff --git a/pylib/anki/utils.py b/pylib/anki/utils.py index d99437e69..33c4ebe08 100644 --- a/pylib/anki/utils.py +++ b/pylib/anki/utils.py @@ -22,7 +22,7 @@ from hashlib import sha1 from html.entities import name2codepoint from typing import Iterable, Iterator, List, Optional, Union -from anki.db import DB +from anki.dbproxy import DBProxy _tmpdir: Optional[str] @@ -142,7 +142,7 @@ def ids2str(ids: Iterable[Union[int, str]]) -> str: return "(%s)" % ",".join(str(i) for i in ids) -def timestampID(db: DB, table: str) -> int: +def timestampID(db: DBProxy, table: str) -> int: "Return a non-conflicting timestamp for table." # be careful not to create multiple objects without flushing them, or they # may share an ID. @@ -152,7 +152,7 @@ def timestampID(db: DB, table: str) -> int: return t -def maxID(db: DB) -> int: +def maxID(db: DBProxy) -> int: "Return the first safe ID to use." now = intTime(1000) for tbl in "cards", "notes":