From 6ffe82ac54341a0e284989f94c01dcca99127c9a Mon Sep 17 00:00:00 2001 From: Damien Elmes Date: Thu, 19 Dec 2019 14:20:57 +1000 Subject: [PATCH] db/hooks/utils --- anki/db.py | 12 ++++++------ anki/hooks.py | 7 ++++--- anki/utils.py | 46 ++++++++++++++++++++++++---------------------- 3 files changed, 34 insertions(+), 31 deletions(-) diff --git a/anki/db.py b/anki/db.py index a1c009b57..1fc3ef2be 100644 --- a/anki/db.py +++ b/anki/db.py @@ -11,14 +11,14 @@ from typing import Any, List DBError = sqlite.Error class DB: - def __init__(self, path, timeout=0) -> None: + def __init__(self, path: str, timeout: int = 0) -> None: self._db = sqlite.connect(path, timeout=timeout) self._db.text_factory = self._textFactory self._path = path self.echo = os.environ.get("DBECHO") self.mod = False - def execute(self, sql, *a, **ka) -> Cursor: + def execute(self, sql: str, *a, **ka) -> Cursor: s = sql.strip().lower() # mark modified? for stmt in "insert", "update", "delete": @@ -38,7 +38,7 @@ class DB: print(a, ka) return res - def executemany(self, sql, l) -> None: + def executemany(self, sql: str, l: Any) -> None: self.mod = True t = time.time() self._db.executemany(sql, l) @@ -53,7 +53,7 @@ class DB: if self.echo: print("commit %0.3fms" % ((time.time() - t)*1000)) - def executescript(self, sql) -> None: + def executescript(self, sql: str) -> None: self.mod = True if self.echo: print(sql) @@ -100,14 +100,14 @@ class DB: def interrupt(self) -> None: self._db.interrupt() - def setAutocommit(self, autocommit) -> None: + def setAutocommit(self, autocommit: bool) -> None: if autocommit: self._db.isolation_level = None else: self._db.isolation_level = '' # strip out invalid utf-8 when reading from db - def _textFactory(self, data) -> str: + def _textFactory(self, data: bytes) -> str: return str(data, errors="ignore") def cursor(self, factory=Cursor) -> Cursor: diff --git a/anki/hooks.py b/anki/hooks.py index 98c811a51..e8d283b21 100644 --- a/anki/hooks.py +++ b/anki/hooks.py @@ -19,9 +19,10 @@ from typing import Dict, List, Callable, Any # Hooks ############################################################################## +from typing import Callable, Dict, Union _hooks: Dict[str, List[Callable[..., Any]]] = {} -def runHook(hook, *args) -> None: +def runHook(hook: str, *args) -> None: "Run all functions on hook." hook = _hooks.get(hook, None) if hook: @@ -32,7 +33,7 @@ def runHook(hook, *args) -> None: hook.remove(func) raise -def runFilter(hook, arg, *args) -> Any: +def runFilter(hook: str, arg: Any, *args) -> Any: hook = _hooks.get(hook, None) if hook: for func in hook: @@ -43,7 +44,7 @@ def runFilter(hook, arg, *args) -> Any: raise return arg -def addHook(hook, func) -> None: +def addHook(hook: str, func: Callable) -> None: "Add a function to hook. Ignore if already on hook." if not _hooks.get(hook, None): _hooks[hook] = [] diff --git a/anki/utils.py b/anki/utils.py index cec62d6e4..61b1e1540 100644 --- a/anki/utils.py +++ b/anki/utils.py @@ -24,12 +24,14 @@ from anki.lang import _, ngettext import json # pylint: disable=unused-import from typing import Any, Optional, Tuple +from anki.db import DB +from typing import Any, Iterator, List, Union _tmpdir: Optional[str] # Time handling ############################################################################## -def intTime(scale=1) -> int: +def intTime(scale: int = 1) -> int: "The time in integer seconds. Pass scale=1000 to get milliseconds." return int(time.time()*scale) @@ -51,7 +53,7 @@ inTimeTable = { "seconds": lambda n: ngettext("in %s second", "in %s seconds", n), } -def shortTimeFmt(type) -> Any: +def shortTimeFmt(type: str) -> Any: return { #T: year is an abbreviation for year. %s is a number of years "years": _("%sy"), @@ -67,7 +69,7 @@ def shortTimeFmt(type) -> Any: "seconds": _("%ss"), }[type] -def fmtTimeSpan(time, pad=0, point=0, short=False, inTime=False, unit=99) -> str: +def fmtTimeSpan(time: Union[int, float], pad: int = 0, point: int = 0, short: bool = False, inTime: bool = False, unit: int = 99) -> str: "Return a string representing a time span (eg '2 days')." (type, point) = optimalPeriod(time, point, unit) time = convertSecondsTo(time, type) @@ -83,7 +85,7 @@ def fmtTimeSpan(time, pad=0, point=0, short=False, inTime=False, unit=99) -> str timestr = "%%%(a)d.%(b)df" % {'a': pad, 'b': point} return locale.format_string(fmt % timestr, time) -def optimalPeriod(time, point, unit) -> Tuple[str, Any]: +def optimalPeriod(time: Union[int, float], point: int, unit: int) -> Tuple[str, Any]: if abs(time) < 60 or unit < 1: type = "seconds" point -= 1 @@ -101,7 +103,7 @@ def optimalPeriod(time, point, unit) -> Tuple[str, Any]: point += 1 return (type, max(point, 0)) -def convertSecondsTo(seconds, type) -> Any: +def convertSecondsTo(seconds: Union[int, float], type: str) -> Any: if type == "seconds": return seconds elif type == "minutes": @@ -116,7 +118,7 @@ def convertSecondsTo(seconds, type) -> Any: return seconds / 31536000 assert False -def _pluralCount(time, point) -> int: +def _pluralCount(time: Union[int, float], point: int) -> int: if point: return 2 return math.floor(time) @@ -143,7 +145,7 @@ reTag = re.compile("(?s)<.*?>") reEnts = re.compile(r"&#?\w+;") reMedia = re.compile("(?i)]+src=[\"']?([^\"'>]+)[\"']?[^>]*>") -def stripHTML(s) -> str: +def stripHTML(s: str) -> str: s = reComment.sub("", s) s = reStyle.sub("", s) s = reScript.sub("", s) @@ -151,7 +153,7 @@ def stripHTML(s) -> str: s = entsToTxt(s) return s -def stripHTMLMedia(s) -> Any: +def stripHTMLMedia(s: str) -> Any: "Strip HTML but keep media filenames" s = reMedia.sub(" \\1 ", s) return stripHTML(s) @@ -177,7 +179,7 @@ def htmlToTextLine(s) -> Any: s = s.strip() return s -def entsToTxt(html) -> str: +def entsToTxt(html: str) -> str: # entitydefs defines nbsp as \xa0 instead of a standard space, so we # replace it first html = html.replace(" ", " ") @@ -216,11 +218,11 @@ def hexifyID(id) -> str: def dehexifyID(id) -> int: return int(id, 16) -def ids2str(ids) -> str: +def ids2str(ids: Any) -> str: """Given a list of integers, return a string '(int1,int2,...)'.""" return "(%s)" % ",".join(str(i) for i in ids) -def timestampID(db, table) -> int: +def timestampID(db: DB, table: str) -> int: "Return a non-conflicting timestamp for table." # be careful not to create multiple objects without flushing them, or they # may share an ID. @@ -229,7 +231,7 @@ def timestampID(db, table) -> int: t += 1 return t -def maxID(db) -> Any: +def maxID(db: DB) -> Any: "Return the first safe ID to use." now = intTime(1000) for tbl in "cards", "notes": @@ -237,7 +239,7 @@ def maxID(db) -> Any: return now + 1 # used in ankiweb -def base62(num, extra="") -> str: +def base62(num: int, extra: str = "") -> str: s = string; table = s.ascii_letters + s.digits + extra buf = "" while num: @@ -246,7 +248,7 @@ def base62(num, extra="") -> str: return buf _base91_extra_chars = "!#$%&()*+,-./:;<=>?@[]^_`{|}~" -def base91(num) -> str: +def base91(num: int) -> str: # all printable characters minus quotes, backslash and separators return base62(num, _base91_extra_chars) @@ -271,21 +273,21 @@ def _incGuid(guid) -> str: # Fields ############################################################################## -def joinFields(list) -> str: +def joinFields(list: List[str]) -> str: return "\x1f".join(list) -def splitFields(string) -> Any: +def splitFields(string: str) -> Any: return string.split("\x1f") # Checksums ############################################################################## -def checksum(data) -> str: +def checksum(data: Union[bytes, str]) -> str: if isinstance(data, str): data = data.encode("utf-8") return sha1(data).hexdigest() -def fieldChecksum(data) -> int: +def fieldChecksum(data: str) -> int: # 32 bit unsigned number from first 8 digits of sha1 hash return int(checksum(stripHTMLMedia(data).encode("utf-8"))[:8], 16) @@ -308,12 +310,12 @@ def tmpdir() -> Any: os.mkdir(_tmpdir) return _tmpdir -def tmpfile(prefix="", suffix="") -> Any: +def tmpfile(prefix: str = "", suffix: str = "") -> Any: (fd, name) = tempfile.mkstemp(dir=tmpdir(), prefix=prefix, suffix=suffix) os.close(fd) return name -def namedtmp(name, rm=True) -> Any: +def namedtmp(name: str, rm: bool = True) -> Any: "Return tmpdir+name. Deletes any existing file." path = os.path.join(tmpdir(), name) if rm: @@ -327,13 +329,13 @@ def namedtmp(name, rm=True) -> Any: ############################################################################## @contextmanager -def noBundledLibs(): +def noBundledLibs() -> Iterator[None]: oldlpath = os.environ.pop("LD_LIBRARY_PATH", None) yield if oldlpath is not None: os.environ["LD_LIBRARY_PATH"] = oldlpath -def call(argv, wait=True, **kwargs) -> int: +def call(argv: List[str], wait: bool = True, **kwargs) -> int: "Execute a command. If WAIT, return exit code." # ensure we don't open a separate window for forking process on windows if isWin: