clone db.py into dbproxy.py

Damien Elmes 2020-03-02 14:13:57 +10:00
parent 0d43e9dca3
commit c1252d68f0
4 changed files with 127 additions and 15 deletions

@@ -23,7 +23,7 @@ import anki.template
 from anki import hooks
 from anki.cards import Card
 from anki.consts import *
-from anki.db import DB
+from anki.dbproxy import DBProxy
 from anki.decks import DeckManager
 from anki.errors import AnkiError
 from anki.lang import _, ngettext
@@ -67,7 +67,7 @@ defaultConf = {
 # this is initialized by storage.Collection
 class _Collection:
-    db: Optional[DB]
+    db: Optional[DBProxy]
     sched: Union[V1Scheduler, V2Scheduler]
     crt: int
     mod: int
@@ -80,7 +80,7 @@ class _Collection:
     def __init__(
         self,
-        db: DB,
+        db: DBProxy,
         backend: RustBackend,
         server: Optional["anki.storage.ServerData"] = None,
         log: bool = False,
@@ -267,7 +267,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
     def reopen(self) -> None:
         "Reconnect to DB (after changing threads, etc)."
         if not self.db:
-            self.db = DB(self.path)
+            self.db = DBProxy(self.path)
             self.media.connect()
             self._openLog()

pylib/anki/dbproxy.py (new file, 112 lines)

@@ -0,0 +1,112 @@
# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html

import os
import time
from sqlite3 import Cursor
from sqlite3 import dbapi2 as sqlite
from typing import Any, List, Type


class DBProxy:
    def __init__(self, path: str, timeout: int = 0) -> None:
        self._db = sqlite.connect(path, timeout=timeout)
        self._db.text_factory = self._textFactory
        self._path = path
        self.echo = os.environ.get("DBECHO")
        self.mod = False

    def execute(self, sql: str, *a, **ka) -> Cursor:
        s = sql.strip().lower()
        # mark modified?
        for stmt in "insert", "update", "delete":
            if s.startswith(stmt):
                self.mod = True
        t = time.time()
        if ka:
            # execute("...where id = :id", id=5)
            res = self._db.execute(sql, ka)
        else:
            # execute("...where id = ?", 5)
            res = self._db.execute(sql, a)
        if self.echo:
            # print a, ka
            print(sql, "%0.3fms" % ((time.time() - t) * 1000))
            if self.echo == "2":
                print(a, ka)
        return res

    def executemany(self, sql: str, l: Any) -> None:
        self.mod = True
        t = time.time()
        self._db.executemany(sql, l)
        if self.echo:
            print(sql, "%0.3fms" % ((time.time() - t) * 1000))
            if self.echo == "2":
                print(l)

    def commit(self) -> None:
        t = time.time()
        self._db.commit()
        if self.echo:
            print("commit %0.3fms" % ((time.time() - t) * 1000))

    def executescript(self, sql: str) -> None:
        self.mod = True
        if self.echo:
            print(sql)
        self._db.executescript(sql)

    def rollback(self) -> None:
        self._db.rollback()

    def scalar(self, *a, **kw) -> Any:
        res = self.execute(*a, **kw).fetchone()
        if res:
            return res[0]
        return None

    def all(self, *a, **kw) -> List:
        return self.execute(*a, **kw).fetchall()

    def first(self, *a, **kw) -> Any:
        c = self.execute(*a, **kw)
        res = c.fetchone()
        c.close()
        return res

    def list(self, *a, **kw) -> List:
        return [x[0] for x in self.execute(*a, **kw)]

    def close(self) -> None:
        self._db.text_factory = None
        self._db.close()

    def set_progress_handler(self, *args) -> None:
        self._db.set_progress_handler(*args)

    def __enter__(self) -> "DBProxy":
        self._db.execute("begin")
        return self

    def __exit__(self, exc_type, *args) -> None:
        self._db.close()

    def totalChanges(self) -> Any:
        return self._db.total_changes

    def interrupt(self) -> None:
        self._db.interrupt()

    def setAutocommit(self, autocommit: bool) -> None:
        if autocommit:
            self._db.isolation_level = None
        else:
            self._db.isolation_level = ""

    # strip out invalid utf-8 when reading from db
    def _textFactory(self, data: bytes) -> str:
        return str(data, errors="ignore")

    def cursor(self, factory: Type[Cursor] = Cursor) -> Cursor:
        return self._db.cursor(factory)
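
Editor's note: a minimal usage sketch of the class above, using only the methods shown in this commit; the in-memory path, table name, and values are invented for illustration and are not part of the change.

# Hypothetical DBProxy usage (illustrative only; not part of this commit)
from anki.dbproxy import DBProxy

db = DBProxy(":memory:")  # throwaway in-memory database for the example
db.setAutocommit(True)
db.execute("create table notes (id integer primary key, flds text)")
db.setAutocommit(False)

# positional placeholders: execute("... ?", value)
db.execute("insert into notes values (?, ?)", 1, "front")
# named placeholders: execute("... :name", name=value)
db.execute("insert into notes values (:id, :flds)", id=2, flds="back")
db.commit()

assert db.scalar("select count() from notes") == 2
print(db.list("select id from notes"))         # [1, 2]
print(db.first("select id, flds from notes"))  # (1, 'front')
db.close()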

@@ -9,7 +9,7 @@ from typing import Any, Dict, Optional, Tuple
 from anki.collection import _Collection
 from anki.consts import *
-from anki.db import DB
+from anki.dbproxy import DBProxy
 from anki.lang import _
 from anki.media import media_paths_from_col_path
 from anki.rsbackend import RustBackend
@@ -44,7 +44,7 @@ def Collection(
     for c in ("/", ":", "\\"):
         assert c not in base
     # connect
-    db = DB(path)
+    db = DBProxy(path)
     db.setAutocommit(True)
     if create:
         ver = _createDB(db)
@@ -78,7 +78,7 @@ def Collection(
     return col
-def _upgradeSchema(db: DB) -> Any:
+def _upgradeSchema(db: DBProxy) -> Any:
     ver = db.scalar("select ver from col")
     if ver == SCHEMA_VERSION:
         return ver
@@ -238,7 +238,7 @@ def _upgradeClozeModel(col, m) -> None:
 ######################################################################
-def _createDB(db: DB) -> int:
+def _createDB(db: DBProxy) -> int:
     db.execute("pragma page_size = 4096")
     db.execute("pragma legacy_file_format = 0")
     db.execute("vacuum")
@@ -248,7 +248,7 @@ def _createDB(db: DB) -> int:
     return SCHEMA_VERSION
-def _addSchema(db: DB, setColConf: bool = True) -> None:
+def _addSchema(db: DBProxy, setColConf: bool = True) -> None:
     db.executescript(
         """
 create table if not exists col (
@@ -329,7 +329,7 @@ values(1,0,0,%(s)s,%(v)s,0,0,0,'','{}','','','{}');
     _addColVars(db, *_getColVars(db))
-def _getColVars(db: DB) -> Tuple[Any, Any, Dict[str, Any]]:
+def _getColVars(db: DBProxy) -> Tuple[Any, Any, Dict[str, Any]]:
     import anki.collection
     import anki.decks
@@ -344,7 +344,7 @@ def _getColVars(db: DB) -> Tuple[Any, Any, Dict[str, Any]]:
 def _addColVars(
-    db: DB, g: Dict[str, Any], gc: Dict[str, Any], c: Dict[str, Any]
+    db: DBProxy, g: Dict[str, Any], gc: Dict[str, Any], c: Dict[str, Any]
 ) -> None:
     db.execute(
         """
@@ -355,7 +355,7 @@ update col set conf = ?, decks = ?, dconf = ?""",
     )
-def _updateIndices(db: DB) -> None:
+def _updateIndices(db: DBProxy) -> None:
     "Add indices to the DB."
     db.executescript(
         """

@@ -22,7 +22,7 @@ from hashlib import sha1
 from html.entities import name2codepoint
 from typing import Iterable, Iterator, List, Optional, Union
-from anki.db import DB
+from anki.dbproxy import DBProxy
 _tmpdir: Optional[str]
@@ -142,7 +142,7 @@ def ids2str(ids: Iterable[Union[int, str]]) -> str:
     return "(%s)" % ",".join(str(i) for i in ids)
-def timestampID(db: DB, table: str) -> int:
+def timestampID(db: DBProxy, table: str) -> int:
     "Return a non-conflicting timestamp for table."
     # be careful not to create multiple objects without flushing them, or they
     # may share an ID.
@@ -152,7 +152,7 @@ def timestampID(db: DB, table: str) -> int:
     return t
-def maxID(db: DB) -> int:
+def maxID(db: DBProxy) -> int:
    "Return the first safe ID to use."
     now = intTime(1000)
     for tbl in "cards", "notes":
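
Editor's note: the hunk boundaries cut off the bodies of timestampID and maxID. The sketch below reconstructs the ID-allocation pattern they implement against the new DBProxy type; it is illustrative, not a verbatim copy of utils.py.

# Rough reconstruction, assuming the pattern described above
from anki.dbproxy import DBProxy
from anki.utils import intTime

def timestamp_id_sketch(db: DBProxy, table: str) -> int:
    # start from the current time in ms and bump until the id is unused
    t = intTime(1000)
    while db.scalar("select id from %s where id = ?" % table, t):
        t += 1
    return t

def max_id_sketch(db: DBProxy) -> int:
    # first id safely above anything already stored in cards/notes
    now = intTime(1000)
    for tbl in ("cards", "notes"):
        now = max(now, db.scalar("select max(id) from %s" % tbl) or 0)
    return now + 1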