mirror of
https://github.com/ankitects/anki.git
synced 2025-11-11 15:17:12 -05:00
db/hooks/utils
This commit is contained in:
parent
f69ef52845
commit
6ffe82ac54
3 changed files with 34 additions and 31 deletions
12
anki/db.py
12
anki/db.py
|
|
@ -11,14 +11,14 @@ from typing import Any, List
|
||||||
DBError = sqlite.Error
|
DBError = sqlite.Error
|
||||||
|
|
||||||
class DB:
|
class DB:
|
||||||
def __init__(self, path, timeout=0) -> None:
|
def __init__(self, path: str, timeout: int = 0) -> None:
|
||||||
self._db = sqlite.connect(path, timeout=timeout)
|
self._db = sqlite.connect(path, timeout=timeout)
|
||||||
self._db.text_factory = self._textFactory
|
self._db.text_factory = self._textFactory
|
||||||
self._path = path
|
self._path = path
|
||||||
self.echo = os.environ.get("DBECHO")
|
self.echo = os.environ.get("DBECHO")
|
||||||
self.mod = False
|
self.mod = False
|
||||||
|
|
||||||
def execute(self, sql, *a, **ka) -> Cursor:
|
def execute(self, sql: str, *a, **ka) -> Cursor:
|
||||||
s = sql.strip().lower()
|
s = sql.strip().lower()
|
||||||
# mark modified?
|
# mark modified?
|
||||||
for stmt in "insert", "update", "delete":
|
for stmt in "insert", "update", "delete":
|
||||||
|
|
@ -38,7 +38,7 @@ class DB:
|
||||||
print(a, ka)
|
print(a, ka)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def executemany(self, sql, l) -> None:
|
def executemany(self, sql: str, l: Any) -> None:
|
||||||
self.mod = True
|
self.mod = True
|
||||||
t = time.time()
|
t = time.time()
|
||||||
self._db.executemany(sql, l)
|
self._db.executemany(sql, l)
|
||||||
|
|
@ -53,7 +53,7 @@ class DB:
|
||||||
if self.echo:
|
if self.echo:
|
||||||
print("commit %0.3fms" % ((time.time() - t)*1000))
|
print("commit %0.3fms" % ((time.time() - t)*1000))
|
||||||
|
|
||||||
def executescript(self, sql) -> None:
|
def executescript(self, sql: str) -> None:
|
||||||
self.mod = True
|
self.mod = True
|
||||||
if self.echo:
|
if self.echo:
|
||||||
print(sql)
|
print(sql)
|
||||||
|
|
@ -100,14 +100,14 @@ class DB:
|
||||||
def interrupt(self) -> None:
|
def interrupt(self) -> None:
|
||||||
self._db.interrupt()
|
self._db.interrupt()
|
||||||
|
|
||||||
def setAutocommit(self, autocommit) -> None:
|
def setAutocommit(self, autocommit: bool) -> None:
|
||||||
if autocommit:
|
if autocommit:
|
||||||
self._db.isolation_level = None
|
self._db.isolation_level = None
|
||||||
else:
|
else:
|
||||||
self._db.isolation_level = ''
|
self._db.isolation_level = ''
|
||||||
|
|
||||||
# strip out invalid utf-8 when reading from db
|
# strip out invalid utf-8 when reading from db
|
||||||
def _textFactory(self, data) -> str:
|
def _textFactory(self, data: bytes) -> str:
|
||||||
return str(data, errors="ignore")
|
return str(data, errors="ignore")
|
||||||
|
|
||||||
def cursor(self, factory=Cursor) -> Cursor:
|
def cursor(self, factory=Cursor) -> Cursor:
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,10 @@ from typing import Dict, List, Callable, Any
|
||||||
# Hooks
|
# Hooks
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
||||||
|
from typing import Callable, Dict, Union
|
||||||
_hooks: Dict[str, List[Callable[..., Any]]] = {}
|
_hooks: Dict[str, List[Callable[..., Any]]] = {}
|
||||||
|
|
||||||
def runHook(hook, *args) -> None:
|
def runHook(hook: str, *args) -> None:
|
||||||
"Run all functions on hook."
|
"Run all functions on hook."
|
||||||
hook = _hooks.get(hook, None)
|
hook = _hooks.get(hook, None)
|
||||||
if hook:
|
if hook:
|
||||||
|
|
@ -32,7 +33,7 @@ def runHook(hook, *args) -> None:
|
||||||
hook.remove(func)
|
hook.remove(func)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def runFilter(hook, arg, *args) -> Any:
|
def runFilter(hook: str, arg: Any, *args) -> Any:
|
||||||
hook = _hooks.get(hook, None)
|
hook = _hooks.get(hook, None)
|
||||||
if hook:
|
if hook:
|
||||||
for func in hook:
|
for func in hook:
|
||||||
|
|
@ -43,7 +44,7 @@ def runFilter(hook, arg, *args) -> Any:
|
||||||
raise
|
raise
|
||||||
return arg
|
return arg
|
||||||
|
|
||||||
def addHook(hook, func) -> None:
|
def addHook(hook: str, func: Callable) -> None:
|
||||||
"Add a function to hook. Ignore if already on hook."
|
"Add a function to hook. Ignore if already on hook."
|
||||||
if not _hooks.get(hook, None):
|
if not _hooks.get(hook, None):
|
||||||
_hooks[hook] = []
|
_hooks[hook] = []
|
||||||
|
|
|
||||||
|
|
@ -24,12 +24,14 @@ from anki.lang import _, ngettext
|
||||||
import json # pylint: disable=unused-import
|
import json # pylint: disable=unused-import
|
||||||
from typing import Any, Optional, Tuple
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
|
from anki.db import DB
|
||||||
|
from typing import Any, Iterator, List, Union
|
||||||
_tmpdir: Optional[str]
|
_tmpdir: Optional[str]
|
||||||
|
|
||||||
# Time handling
|
# Time handling
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
||||||
def intTime(scale=1) -> int:
|
def intTime(scale: int = 1) -> int:
|
||||||
"The time in integer seconds. Pass scale=1000 to get milliseconds."
|
"The time in integer seconds. Pass scale=1000 to get milliseconds."
|
||||||
return int(time.time()*scale)
|
return int(time.time()*scale)
|
||||||
|
|
||||||
|
|
@ -51,7 +53,7 @@ inTimeTable = {
|
||||||
"seconds": lambda n: ngettext("in %s second", "in %s seconds", n),
|
"seconds": lambda n: ngettext("in %s second", "in %s seconds", n),
|
||||||
}
|
}
|
||||||
|
|
||||||
def shortTimeFmt(type) -> Any:
|
def shortTimeFmt(type: str) -> Any:
|
||||||
return {
|
return {
|
||||||
#T: year is an abbreviation for year. %s is a number of years
|
#T: year is an abbreviation for year. %s is a number of years
|
||||||
"years": _("%sy"),
|
"years": _("%sy"),
|
||||||
|
|
@ -67,7 +69,7 @@ def shortTimeFmt(type) -> Any:
|
||||||
"seconds": _("%ss"),
|
"seconds": _("%ss"),
|
||||||
}[type]
|
}[type]
|
||||||
|
|
||||||
def fmtTimeSpan(time, pad=0, point=0, short=False, inTime=False, unit=99) -> str:
|
def fmtTimeSpan(time: Union[int, float], pad: int = 0, point: int = 0, short: bool = False, inTime: bool = False, unit: int = 99) -> str:
|
||||||
"Return a string representing a time span (eg '2 days')."
|
"Return a string representing a time span (eg '2 days')."
|
||||||
(type, point) = optimalPeriod(time, point, unit)
|
(type, point) = optimalPeriod(time, point, unit)
|
||||||
time = convertSecondsTo(time, type)
|
time = convertSecondsTo(time, type)
|
||||||
|
|
@ -83,7 +85,7 @@ def fmtTimeSpan(time, pad=0, point=0, short=False, inTime=False, unit=99) -> str
|
||||||
timestr = "%%%(a)d.%(b)df" % {'a': pad, 'b': point}
|
timestr = "%%%(a)d.%(b)df" % {'a': pad, 'b': point}
|
||||||
return locale.format_string(fmt % timestr, time)
|
return locale.format_string(fmt % timestr, time)
|
||||||
|
|
||||||
def optimalPeriod(time, point, unit) -> Tuple[str, Any]:
|
def optimalPeriod(time: Union[int, float], point: int, unit: int) -> Tuple[str, Any]:
|
||||||
if abs(time) < 60 or unit < 1:
|
if abs(time) < 60 or unit < 1:
|
||||||
type = "seconds"
|
type = "seconds"
|
||||||
point -= 1
|
point -= 1
|
||||||
|
|
@ -101,7 +103,7 @@ def optimalPeriod(time, point, unit) -> Tuple[str, Any]:
|
||||||
point += 1
|
point += 1
|
||||||
return (type, max(point, 0))
|
return (type, max(point, 0))
|
||||||
|
|
||||||
def convertSecondsTo(seconds, type) -> Any:
|
def convertSecondsTo(seconds: Union[int, float], type: str) -> Any:
|
||||||
if type == "seconds":
|
if type == "seconds":
|
||||||
return seconds
|
return seconds
|
||||||
elif type == "minutes":
|
elif type == "minutes":
|
||||||
|
|
@ -116,7 +118,7 @@ def convertSecondsTo(seconds, type) -> Any:
|
||||||
return seconds / 31536000
|
return seconds / 31536000
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
def _pluralCount(time, point) -> int:
|
def _pluralCount(time: Union[int, float], point: int) -> int:
|
||||||
if point:
|
if point:
|
||||||
return 2
|
return 2
|
||||||
return math.floor(time)
|
return math.floor(time)
|
||||||
|
|
@ -143,7 +145,7 @@ reTag = re.compile("(?s)<.*?>")
|
||||||
reEnts = re.compile(r"&#?\w+;")
|
reEnts = re.compile(r"&#?\w+;")
|
||||||
reMedia = re.compile("(?i)<img[^>]+src=[\"']?([^\"'>]+)[\"']?[^>]*>")
|
reMedia = re.compile("(?i)<img[^>]+src=[\"']?([^\"'>]+)[\"']?[^>]*>")
|
||||||
|
|
||||||
def stripHTML(s) -> str:
|
def stripHTML(s: str) -> str:
|
||||||
s = reComment.sub("", s)
|
s = reComment.sub("", s)
|
||||||
s = reStyle.sub("", s)
|
s = reStyle.sub("", s)
|
||||||
s = reScript.sub("", s)
|
s = reScript.sub("", s)
|
||||||
|
|
@ -151,7 +153,7 @@ def stripHTML(s) -> str:
|
||||||
s = entsToTxt(s)
|
s = entsToTxt(s)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def stripHTMLMedia(s) -> Any:
|
def stripHTMLMedia(s: str) -> Any:
|
||||||
"Strip HTML but keep media filenames"
|
"Strip HTML but keep media filenames"
|
||||||
s = reMedia.sub(" \\1 ", s)
|
s = reMedia.sub(" \\1 ", s)
|
||||||
return stripHTML(s)
|
return stripHTML(s)
|
||||||
|
|
@ -177,7 +179,7 @@ def htmlToTextLine(s) -> Any:
|
||||||
s = s.strip()
|
s = s.strip()
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def entsToTxt(html) -> str:
|
def entsToTxt(html: str) -> str:
|
||||||
# entitydefs defines nbsp as \xa0 instead of a standard space, so we
|
# entitydefs defines nbsp as \xa0 instead of a standard space, so we
|
||||||
# replace it first
|
# replace it first
|
||||||
html = html.replace(" ", " ")
|
html = html.replace(" ", " ")
|
||||||
|
|
@ -216,11 +218,11 @@ def hexifyID(id) -> str:
|
||||||
def dehexifyID(id) -> int:
|
def dehexifyID(id) -> int:
|
||||||
return int(id, 16)
|
return int(id, 16)
|
||||||
|
|
||||||
def ids2str(ids) -> str:
|
def ids2str(ids: Any) -> str:
|
||||||
"""Given a list of integers, return a string '(int1,int2,...)'."""
|
"""Given a list of integers, return a string '(int1,int2,...)'."""
|
||||||
return "(%s)" % ",".join(str(i) for i in ids)
|
return "(%s)" % ",".join(str(i) for i in ids)
|
||||||
|
|
||||||
def timestampID(db, table) -> int:
|
def timestampID(db: DB, table: str) -> int:
|
||||||
"Return a non-conflicting timestamp for table."
|
"Return a non-conflicting timestamp for table."
|
||||||
# be careful not to create multiple objects without flushing them, or they
|
# be careful not to create multiple objects without flushing them, or they
|
||||||
# may share an ID.
|
# may share an ID.
|
||||||
|
|
@ -229,7 +231,7 @@ def timestampID(db, table) -> int:
|
||||||
t += 1
|
t += 1
|
||||||
return t
|
return t
|
||||||
|
|
||||||
def maxID(db) -> Any:
|
def maxID(db: DB) -> Any:
|
||||||
"Return the first safe ID to use."
|
"Return the first safe ID to use."
|
||||||
now = intTime(1000)
|
now = intTime(1000)
|
||||||
for tbl in "cards", "notes":
|
for tbl in "cards", "notes":
|
||||||
|
|
@ -237,7 +239,7 @@ def maxID(db) -> Any:
|
||||||
return now + 1
|
return now + 1
|
||||||
|
|
||||||
# used in ankiweb
|
# used in ankiweb
|
||||||
def base62(num, extra="") -> str:
|
def base62(num: int, extra: str = "") -> str:
|
||||||
s = string; table = s.ascii_letters + s.digits + extra
|
s = string; table = s.ascii_letters + s.digits + extra
|
||||||
buf = ""
|
buf = ""
|
||||||
while num:
|
while num:
|
||||||
|
|
@ -246,7 +248,7 @@ def base62(num, extra="") -> str:
|
||||||
return buf
|
return buf
|
||||||
|
|
||||||
_base91_extra_chars = "!#$%&()*+,-./:;<=>?@[]^_`{|}~"
|
_base91_extra_chars = "!#$%&()*+,-./:;<=>?@[]^_`{|}~"
|
||||||
def base91(num) -> str:
|
def base91(num: int) -> str:
|
||||||
# all printable characters minus quotes, backslash and separators
|
# all printable characters minus quotes, backslash and separators
|
||||||
return base62(num, _base91_extra_chars)
|
return base62(num, _base91_extra_chars)
|
||||||
|
|
||||||
|
|
@ -271,21 +273,21 @@ def _incGuid(guid) -> str:
|
||||||
# Fields
|
# Fields
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
||||||
def joinFields(list) -> str:
|
def joinFields(list: List[str]) -> str:
|
||||||
return "\x1f".join(list)
|
return "\x1f".join(list)
|
||||||
|
|
||||||
def splitFields(string) -> Any:
|
def splitFields(string: str) -> Any:
|
||||||
return string.split("\x1f")
|
return string.split("\x1f")
|
||||||
|
|
||||||
# Checksums
|
# Checksums
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
||||||
def checksum(data) -> str:
|
def checksum(data: Union[bytes, str]) -> str:
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
data = data.encode("utf-8")
|
data = data.encode("utf-8")
|
||||||
return sha1(data).hexdigest()
|
return sha1(data).hexdigest()
|
||||||
|
|
||||||
def fieldChecksum(data) -> int:
|
def fieldChecksum(data: str) -> int:
|
||||||
# 32 bit unsigned number from first 8 digits of sha1 hash
|
# 32 bit unsigned number from first 8 digits of sha1 hash
|
||||||
return int(checksum(stripHTMLMedia(data).encode("utf-8"))[:8], 16)
|
return int(checksum(stripHTMLMedia(data).encode("utf-8"))[:8], 16)
|
||||||
|
|
||||||
|
|
@ -308,12 +310,12 @@ def tmpdir() -> Any:
|
||||||
os.mkdir(_tmpdir)
|
os.mkdir(_tmpdir)
|
||||||
return _tmpdir
|
return _tmpdir
|
||||||
|
|
||||||
def tmpfile(prefix="", suffix="") -> Any:
|
def tmpfile(prefix: str = "", suffix: str = "") -> Any:
|
||||||
(fd, name) = tempfile.mkstemp(dir=tmpdir(), prefix=prefix, suffix=suffix)
|
(fd, name) = tempfile.mkstemp(dir=tmpdir(), prefix=prefix, suffix=suffix)
|
||||||
os.close(fd)
|
os.close(fd)
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def namedtmp(name, rm=True) -> Any:
|
def namedtmp(name: str, rm: bool = True) -> Any:
|
||||||
"Return tmpdir+name. Deletes any existing file."
|
"Return tmpdir+name. Deletes any existing file."
|
||||||
path = os.path.join(tmpdir(), name)
|
path = os.path.join(tmpdir(), name)
|
||||||
if rm:
|
if rm:
|
||||||
|
|
@ -327,13 +329,13 @@ def namedtmp(name, rm=True) -> Any:
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def noBundledLibs():
|
def noBundledLibs() -> Iterator[None]:
|
||||||
oldlpath = os.environ.pop("LD_LIBRARY_PATH", None)
|
oldlpath = os.environ.pop("LD_LIBRARY_PATH", None)
|
||||||
yield
|
yield
|
||||||
if oldlpath is not None:
|
if oldlpath is not None:
|
||||||
os.environ["LD_LIBRARY_PATH"] = oldlpath
|
os.environ["LD_LIBRARY_PATH"] = oldlpath
|
||||||
|
|
||||||
def call(argv, wait=True, **kwargs) -> int:
|
def call(argv: List[str], wait: bool = True, **kwargs) -> int:
|
||||||
"Execute a command. If WAIT, return exit code."
|
"Execute a command. If WAIT, return exit code."
|
||||||
# ensure we don't open a separate window for forking process on windows
|
# ensure we don't open a separate window for forking process on windows
|
||||||
if isWin:
|
if isWin:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue