diff --git a/pylib/anki/dbproxy.py b/pylib/anki/dbproxy.py index de91bff60..d93c0ee50 100644 --- a/pylib/anki/dbproxy.py +++ b/pylib/anki/dbproxy.py @@ -4,48 +4,40 @@ # fixme: lossy utf8 handling # fixme: progress -import time from sqlite3 import Cursor from sqlite3 import dbapi2 as sqlite -from typing import Any, List +from typing import Any, Iterable, List class DBProxy: + # Lifecycle + ############### + def __init__(self, path: str) -> None: self._db = sqlite.connect(path, timeout=0) self._path = path self.mod = False - def execute(self, sql: str, *args) -> Cursor: - s = sql.strip().lower() - # mark modified? - for stmt in "insert", "update", "delete": - if s.startswith(stmt): - self.mod = True - res = self._db.execute(sql, args) - return res + def close(self) -> None: + self._db.close() - def executemany(self, sql: str, l: Any) -> None: - self.mod = True - t = time.time() - self._db.executemany(sql, l) + # Transactions + ############### def commit(self) -> None: - t = time.time() self._db.commit() - def executescript(self, sql: str) -> None: - self.mod = True - self._db.executescript(sql) - def rollback(self) -> None: self._db.rollback() - def scalar(self, sql: str, *args) -> Any: - res = self.execute(sql, *args).fetchone() - if res: - return res[0] - return None + def setAutocommit(self, autocommit: bool) -> None: + if autocommit: + self._db.isolation_level = None + else: + self._db.isolation_level = "" + + # Querying + ################ def all(self, sql: str, *args) -> List: return self.execute(sql, *args).fetchall() @@ -59,11 +51,31 @@ class DBProxy: def list(self, sql: str, *args) -> List: return [x[0] for x in self.execute(sql, *args)] - def close(self) -> None: - self._db.close() + def scalar(self, sql: str, *args) -> Any: + res = self.execute(sql, *args).fetchone() + if res: + return res[0] + return None - def setAutocommit(self, autocommit: bool) -> None: - if autocommit: - self._db.isolation_level = None - else: - self._db.isolation_level = "" + # Updates + ################ + + def executemany(self, sql: str, args: Iterable) -> None: + self.mod = True + self._db.executemany(sql, args) + + def executescript(self, sql: str) -> None: + self.mod = True + self._db.executescript(sql) + + # Cursor API + ############### + + def execute(self, sql: str, *args) -> Cursor: + s = sql.strip().lower() + # mark modified? + for stmt in "insert", "update", "delete": + if s.startswith(stmt): + self.mod = True + res = self._db.execute(sql, args) + return res