db/hooks/utils

This commit is contained in:
Damien Elmes 2019-12-19 14:20:57 +10:00
parent f69ef52845
commit 6ffe82ac54
3 changed files with 34 additions and 31 deletions

View file

@ -11,14 +11,14 @@ from typing import Any, List
DBError = sqlite.Error DBError = sqlite.Error
class DB: class DB:
def __init__(self, path, timeout=0) -> None: def __init__(self, path: str, timeout: int = 0) -> None:
self._db = sqlite.connect(path, timeout=timeout) self._db = sqlite.connect(path, timeout=timeout)
self._db.text_factory = self._textFactory self._db.text_factory = self._textFactory
self._path = path self._path = path
self.echo = os.environ.get("DBECHO") self.echo = os.environ.get("DBECHO")
self.mod = False self.mod = False
def execute(self, sql, *a, **ka) -> Cursor: def execute(self, sql: str, *a, **ka) -> Cursor:
s = sql.strip().lower() s = sql.strip().lower()
# mark modified? # mark modified?
for stmt in "insert", "update", "delete": for stmt in "insert", "update", "delete":
@ -38,7 +38,7 @@ class DB:
print(a, ka) print(a, ka)
return res return res
def executemany(self, sql, l) -> None: def executemany(self, sql: str, l: Any) -> None:
self.mod = True self.mod = True
t = time.time() t = time.time()
self._db.executemany(sql, l) self._db.executemany(sql, l)
@ -53,7 +53,7 @@ class DB:
if self.echo: if self.echo:
print("commit %0.3fms" % ((time.time() - t)*1000)) print("commit %0.3fms" % ((time.time() - t)*1000))
def executescript(self, sql) -> None: def executescript(self, sql: str) -> None:
self.mod = True self.mod = True
if self.echo: if self.echo:
print(sql) print(sql)
@ -100,14 +100,14 @@ class DB:
def interrupt(self) -> None: def interrupt(self) -> None:
self._db.interrupt() self._db.interrupt()
def setAutocommit(self, autocommit) -> None: def setAutocommit(self, autocommit: bool) -> None:
if autocommit: if autocommit:
self._db.isolation_level = None self._db.isolation_level = None
else: else:
self._db.isolation_level = '' self._db.isolation_level = ''
# strip out invalid utf-8 when reading from db # strip out invalid utf-8 when reading from db
def _textFactory(self, data) -> str: def _textFactory(self, data: bytes) -> str:
return str(data, errors="ignore") return str(data, errors="ignore")
def cursor(self, factory=Cursor) -> Cursor: def cursor(self, factory=Cursor) -> Cursor:

View file

@ -19,9 +19,10 @@ from typing import Dict, List, Callable, Any
# Hooks # Hooks
############################################################################## ##############################################################################
from typing import Callable, Dict, Union
_hooks: Dict[str, List[Callable[..., Any]]] = {} _hooks: Dict[str, List[Callable[..., Any]]] = {}
def runHook(hook, *args) -> None: def runHook(hook: str, *args) -> None:
"Run all functions on hook." "Run all functions on hook."
hook = _hooks.get(hook, None) hook = _hooks.get(hook, None)
if hook: if hook:
@ -32,7 +33,7 @@ def runHook(hook, *args) -> None:
hook.remove(func) hook.remove(func)
raise raise
def runFilter(hook, arg, *args) -> Any: def runFilter(hook: str, arg: Any, *args) -> Any:
hook = _hooks.get(hook, None) hook = _hooks.get(hook, None)
if hook: if hook:
for func in hook: for func in hook:
@ -43,7 +44,7 @@ def runFilter(hook, arg, *args) -> Any:
raise raise
return arg return arg
def addHook(hook, func) -> None: def addHook(hook: str, func: Callable) -> None:
"Add a function to hook. Ignore if already on hook." "Add a function to hook. Ignore if already on hook."
if not _hooks.get(hook, None): if not _hooks.get(hook, None):
_hooks[hook] = [] _hooks[hook] = []

View file

@ -24,12 +24,14 @@ from anki.lang import _, ngettext
import json # pylint: disable=unused-import import json # pylint: disable=unused-import
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
from anki.db import DB
from typing import Any, Iterator, List, Union
_tmpdir: Optional[str] _tmpdir: Optional[str]
# Time handling # Time handling
############################################################################## ##############################################################################
def intTime(scale=1) -> int: def intTime(scale: int = 1) -> int:
"The time in integer seconds. Pass scale=1000 to get milliseconds." "The time in integer seconds. Pass scale=1000 to get milliseconds."
return int(time.time()*scale) return int(time.time()*scale)
@ -51,7 +53,7 @@ inTimeTable = {
"seconds": lambda n: ngettext("in %s second", "in %s seconds", n), "seconds": lambda n: ngettext("in %s second", "in %s seconds", n),
} }
def shortTimeFmt(type) -> Any: def shortTimeFmt(type: str) -> Any:
return { return {
#T: year is an abbreviation for year. %s is a number of years #T: year is an abbreviation for year. %s is a number of years
"years": _("%sy"), "years": _("%sy"),
@ -67,7 +69,7 @@ def shortTimeFmt(type) -> Any:
"seconds": _("%ss"), "seconds": _("%ss"),
}[type] }[type]
def fmtTimeSpan(time, pad=0, point=0, short=False, inTime=False, unit=99) -> str: def fmtTimeSpan(time: Union[int, float], pad: int = 0, point: int = 0, short: bool = False, inTime: bool = False, unit: int = 99) -> str:
"Return a string representing a time span (eg '2 days')." "Return a string representing a time span (eg '2 days')."
(type, point) = optimalPeriod(time, point, unit) (type, point) = optimalPeriod(time, point, unit)
time = convertSecondsTo(time, type) time = convertSecondsTo(time, type)
@ -83,7 +85,7 @@ def fmtTimeSpan(time, pad=0, point=0, short=False, inTime=False, unit=99) -> str
timestr = "%%%(a)d.%(b)df" % {'a': pad, 'b': point} timestr = "%%%(a)d.%(b)df" % {'a': pad, 'b': point}
return locale.format_string(fmt % timestr, time) return locale.format_string(fmt % timestr, time)
def optimalPeriod(time, point, unit) -> Tuple[str, Any]: def optimalPeriod(time: Union[int, float], point: int, unit: int) -> Tuple[str, Any]:
if abs(time) < 60 or unit < 1: if abs(time) < 60 or unit < 1:
type = "seconds" type = "seconds"
point -= 1 point -= 1
@ -101,7 +103,7 @@ def optimalPeriod(time, point, unit) -> Tuple[str, Any]:
point += 1 point += 1
return (type, max(point, 0)) return (type, max(point, 0))
def convertSecondsTo(seconds, type) -> Any: def convertSecondsTo(seconds: Union[int, float], type: str) -> Any:
if type == "seconds": if type == "seconds":
return seconds return seconds
elif type == "minutes": elif type == "minutes":
@ -116,7 +118,7 @@ def convertSecondsTo(seconds, type) -> Any:
return seconds / 31536000 return seconds / 31536000
assert False assert False
def _pluralCount(time, point) -> int: def _pluralCount(time: Union[int, float], point: int) -> int:
if point: if point:
return 2 return 2
return math.floor(time) return math.floor(time)
@ -143,7 +145,7 @@ reTag = re.compile("(?s)<.*?>")
reEnts = re.compile(r"&#?\w+;") reEnts = re.compile(r"&#?\w+;")
reMedia = re.compile("(?i)<img[^>]+src=[\"']?([^\"'>]+)[\"']?[^>]*>") reMedia = re.compile("(?i)<img[^>]+src=[\"']?([^\"'>]+)[\"']?[^>]*>")
def stripHTML(s) -> str: def stripHTML(s: str) -> str:
s = reComment.sub("", s) s = reComment.sub("", s)
s = reStyle.sub("", s) s = reStyle.sub("", s)
s = reScript.sub("", s) s = reScript.sub("", s)
@ -151,7 +153,7 @@ def stripHTML(s) -> str:
s = entsToTxt(s) s = entsToTxt(s)
return s return s
def stripHTMLMedia(s) -> Any: def stripHTMLMedia(s: str) -> Any:
"Strip HTML but keep media filenames" "Strip HTML but keep media filenames"
s = reMedia.sub(" \\1 ", s) s = reMedia.sub(" \\1 ", s)
return stripHTML(s) return stripHTML(s)
@ -177,7 +179,7 @@ def htmlToTextLine(s) -> Any:
s = s.strip() s = s.strip()
return s return s
def entsToTxt(html) -> str: def entsToTxt(html: str) -> str:
# entitydefs defines nbsp as \xa0 instead of a standard space, so we # entitydefs defines nbsp as \xa0 instead of a standard space, so we
# replace it first # replace it first
html = html.replace("&nbsp;", " ") html = html.replace("&nbsp;", " ")
@ -216,11 +218,11 @@ def hexifyID(id) -> str:
def dehexifyID(id) -> int: def dehexifyID(id) -> int:
return int(id, 16) return int(id, 16)
def ids2str(ids) -> str: def ids2str(ids: Any) -> str:
"""Given a list of integers, return a string '(int1,int2,...)'.""" """Given a list of integers, return a string '(int1,int2,...)'."""
return "(%s)" % ",".join(str(i) for i in ids) return "(%s)" % ",".join(str(i) for i in ids)
def timestampID(db, table) -> int: def timestampID(db: DB, table: str) -> int:
"Return a non-conflicting timestamp for table." "Return a non-conflicting timestamp for table."
# be careful not to create multiple objects without flushing them, or they # be careful not to create multiple objects without flushing them, or they
# may share an ID. # may share an ID.
@ -229,7 +231,7 @@ def timestampID(db, table) -> int:
t += 1 t += 1
return t return t
def maxID(db) -> Any: def maxID(db: DB) -> Any:
"Return the first safe ID to use." "Return the first safe ID to use."
now = intTime(1000) now = intTime(1000)
for tbl in "cards", "notes": for tbl in "cards", "notes":
@ -237,7 +239,7 @@ def maxID(db) -> Any:
return now + 1 return now + 1
# used in ankiweb # used in ankiweb
def base62(num, extra="") -> str: def base62(num: int, extra: str = "") -> str:
s = string; table = s.ascii_letters + s.digits + extra s = string; table = s.ascii_letters + s.digits + extra
buf = "" buf = ""
while num: while num:
@ -246,7 +248,7 @@ def base62(num, extra="") -> str:
return buf return buf
_base91_extra_chars = "!#$%&()*+,-./:;<=>?@[]^_`{|}~" _base91_extra_chars = "!#$%&()*+,-./:;<=>?@[]^_`{|}~"
def base91(num) -> str: def base91(num: int) -> str:
# all printable characters minus quotes, backslash and separators # all printable characters minus quotes, backslash and separators
return base62(num, _base91_extra_chars) return base62(num, _base91_extra_chars)
@ -271,21 +273,21 @@ def _incGuid(guid) -> str:
# Fields # Fields
############################################################################## ##############################################################################
def joinFields(list) -> str: def joinFields(list: List[str]) -> str:
return "\x1f".join(list) return "\x1f".join(list)
def splitFields(string) -> Any: def splitFields(string: str) -> Any:
return string.split("\x1f") return string.split("\x1f")
# Checksums # Checksums
############################################################################## ##############################################################################
def checksum(data) -> str: def checksum(data: Union[bytes, str]) -> str:
if isinstance(data, str): if isinstance(data, str):
data = data.encode("utf-8") data = data.encode("utf-8")
return sha1(data).hexdigest() return sha1(data).hexdigest()
def fieldChecksum(data) -> int: def fieldChecksum(data: str) -> int:
# 32 bit unsigned number from first 8 digits of sha1 hash # 32 bit unsigned number from first 8 digits of sha1 hash
return int(checksum(stripHTMLMedia(data).encode("utf-8"))[:8], 16) return int(checksum(stripHTMLMedia(data).encode("utf-8"))[:8], 16)
@ -308,12 +310,12 @@ def tmpdir() -> Any:
os.mkdir(_tmpdir) os.mkdir(_tmpdir)
return _tmpdir return _tmpdir
def tmpfile(prefix="", suffix="") -> Any: def tmpfile(prefix: str = "", suffix: str = "") -> Any:
(fd, name) = tempfile.mkstemp(dir=tmpdir(), prefix=prefix, suffix=suffix) (fd, name) = tempfile.mkstemp(dir=tmpdir(), prefix=prefix, suffix=suffix)
os.close(fd) os.close(fd)
return name return name
def namedtmp(name, rm=True) -> Any: def namedtmp(name: str, rm: bool = True) -> Any:
"Return tmpdir+name. Deletes any existing file." "Return tmpdir+name. Deletes any existing file."
path = os.path.join(tmpdir(), name) path = os.path.join(tmpdir(), name)
if rm: if rm:
@ -327,13 +329,13 @@ def namedtmp(name, rm=True) -> Any:
############################################################################## ##############################################################################
@contextmanager @contextmanager
def noBundledLibs(): def noBundledLibs() -> Iterator[None]:
oldlpath = os.environ.pop("LD_LIBRARY_PATH", None) oldlpath = os.environ.pop("LD_LIBRARY_PATH", None)
yield yield
if oldlpath is not None: if oldlpath is not None:
os.environ["LD_LIBRARY_PATH"] = oldlpath os.environ["LD_LIBRARY_PATH"] = oldlpath
def call(argv, wait=True, **kwargs) -> int: def call(argv: List[str], wait: bool = True, **kwargs) -> int:
"Execute a command. If WAIT, return exit code." "Execute a command. If WAIT, return exit code."
# ensure we don't open a separate window for forking process on windows # ensure we don't open a separate window for forking process on windows
if isWin: if isWin: