add some typehints, and remove some unused code

This commit is contained in:
Damien Elmes 2021-01-31 20:56:21 +10:00
parent f0f2da0f56
commit 7fda601aef
9 changed files with 58 additions and 111 deletions

View file

@ -1,10 +1,7 @@
from typing import Any def buildhash() -> str: ...
def open_backend(data: bytes) -> Backend: ...
def buildhash(*args, **kwargs) -> Any: ...
def open_backend(*args, **kwargs) -> Any: ...
class Backend: class Backend:
@classmethod @classmethod
def __init__(self, *args, **kwargs) -> None: ... def command(self, method: int, data: bytes) -> bytes: ...
def command(self, *args, **kwargs) -> Any: ... def db_command(self, data: bytes) -> bytes: ...
def db_command(self, *args, **kwargs) -> Any: ...

View file

@ -92,7 +92,7 @@ REVLOG_RESCHED = 4
########################################################################## ##########################################################################
def _tr(col: Optional[anki.collection.Collection]): def _tr(col: Optional[anki.collection.Collection]) -> Any:
if col: if col:
return col.tr return col.tr
else: else:

View file

@ -32,7 +32,7 @@ class DB:
del d["_db"] del d["_db"]
return f"{super().__repr__()} {pprint.pformat(d, width=300)}" return f"{super().__repr__()} {pprint.pformat(d, width=300)}"
def execute(self, sql: str, *a, **ka) -> Cursor: def execute(self, sql: str, *a: Any, **ka: Any) -> Cursor:
s = sql.strip().lower() s = sql.strip().lower()
# mark modified? # mark modified?
for stmt in "insert", "update", "delete": for stmt in "insert", "update", "delete":
@ -76,36 +76,36 @@ class DB:
def rollback(self) -> None: def rollback(self) -> None:
self._db.rollback() self._db.rollback()
def scalar(self, *a, **kw) -> Any: def scalar(self, *a: Any, **kw: Any) -> Any:
res = self.execute(*a, **kw).fetchone() res = self.execute(*a, **kw).fetchone()
if res: if res:
return res[0] return res[0]
return None return None
def all(self, *a, **kw) -> List: def all(self, *a: Any, **kw: Any) -> List:
return self.execute(*a, **kw).fetchall() return self.execute(*a, **kw).fetchall()
def first(self, *a, **kw) -> Any: def first(self, *a: Any, **kw: Any) -> Any:
c = self.execute(*a, **kw) c = self.execute(*a, **kw)
res = c.fetchone() res = c.fetchone()
c.close() c.close()
return res return res
def list(self, *a, **kw) -> List: def list(self, *a: Any, **kw: Any) -> List:
return [x[0] for x in self.execute(*a, **kw)] return [x[0] for x in self.execute(*a, **kw)]
def close(self) -> None: def close(self) -> None:
self._db.text_factory = None self._db.text_factory = None
self._db.close() self._db.close()
def set_progress_handler(self, *args) -> None: def set_progress_handler(self, *args: Any) -> None:
self._db.set_progress_handler(*args) self._db.set_progress_handler(*args)
def __enter__(self) -> "DB": def __enter__(self) -> "DB":
self._db.execute("begin") self._db.execute("begin")
return self return self
def __exit__(self, exc_type, *args) -> None: def __exit__(self, *args: Any) -> None:
self._db.close() self._db.close()
def totalChanges(self) -> Any: def totalChanges(self) -> Any:

View file

@ -4,6 +4,7 @@
from __future__ import annotations from __future__ import annotations
import re import re
from re import Match
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import anki import anki
@ -43,7 +44,11 @@ class DBProxy:
################ ################
def _query( def _query(
self, sql: str, *args: ValueForDB, first_row_only: bool = False, **kwargs self,
sql: str,
*args: ValueForDB,
first_row_only: bool = False,
**kwargs: ValueForDB,
) -> List[Row]: ) -> List[Row]:
# mark modified? # mark modified?
s = sql.strip().lower() s = sql.strip().lower()
@ -57,20 +62,22 @@ class DBProxy:
# Query shortcuts # Query shortcuts
################### ###################
def all(self, sql: str, *args: ValueForDB, **kwargs) -> List[Row]: def all(self, sql: str, *args: ValueForDB, **kwargs: ValueForDB) -> List[Row]:
return self._query(sql, *args, **kwargs) return self._query(sql, *args, first_row_only=False, **kwargs)
def list(self, sql: str, *args: ValueForDB, **kwargs) -> List[ValueFromDB]: def list(
return [x[0] for x in self._query(sql, *args, **kwargs)] self, sql: str, *args: ValueForDB, **kwargs: ValueForDB
) -> List[ValueFromDB]:
return [x[0] for x in self._query(sql, *args, first_row_only=False, **kwargs)]
def first(self, sql: str, *args: ValueForDB, **kwargs) -> Optional[Row]: def first(self, sql: str, *args: ValueForDB, **kwargs: ValueForDB) -> Optional[Row]:
rows = self._query(sql, *args, first_row_only=True, **kwargs) rows = self._query(sql, *args, first_row_only=True, **kwargs)
if rows: if rows:
return rows[0] return rows[0]
else: else:
return None return None
def scalar(self, sql: str, *args: ValueForDB, **kwargs) -> ValueFromDB: def scalar(self, sql: str, *args: ValueForDB, **kwargs: ValueForDB) -> ValueFromDB:
rows = self._query(sql, *args, first_row_only=True, **kwargs) rows = self._query(sql, *args, first_row_only=True, **kwargs)
if rows: if rows:
return rows[0][0] return rows[0][0]
@ -109,7 +116,7 @@ def emulate_named_args(
n = len(args2) n = len(args2)
arg_num[key] = n arg_num[key] = n
# update refs # update refs
def repl(m): def repl(m: Match) -> str:
arg = m.group(1) arg = m.group(1)
return f"?{arg_num[arg]}" return f"?{arg_num[arg]}"

View file

@ -3,8 +3,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any
import anki._backend.backend_pb2 as _pb import anki._backend.backend_pb2 as _pb
# fixme: notfounderror etc need to be in rsbackend.py # fixme: notfounderror etc need to be in rsbackend.py
@ -88,17 +86,15 @@ def backend_exception_to_pylib(err: _pb.BackendError) -> Exception:
return StringError(err.localized) return StringError(err.localized)
# FIXME: this is only used with "abortSchemaMod", but currently some
# add-ons depend on it
class AnkiError(Exception): class AnkiError(Exception):
def __init__(self, type, **data) -> None: def __init__(self, type: str) -> None:
super().__init__() super().__init__()
self.type = type self.type = type
self.data = data
def __str__(self) -> Any: def __str__(self) -> str:
m = self.type return self.type
if self.data:
m += ": %s" % repr(self.data)
return m
class DeckRenameError(Exception): class DeckRenameError(Exception):
@ -106,5 +102,5 @@ class DeckRenameError(Exception):
super().__init__() super().__init__()
self.description = description self.description = description
def __str__(self): def __str__(self) -> str:
return "Couldn't rename deck: " + self.description return "Couldn't rename deck: " + self.description

View file

@ -5,6 +5,8 @@
Wrapper for requests that adds a callback for tracking upload/download progress. Wrapper for requests that adds a callback for tracking upload/download progress.
""" """
from __future__ import annotations
import io import io
import os import os
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Dict, Optional
@ -28,24 +30,23 @@ class HttpClient:
self.progress_hook = progress_hook self.progress_hook = progress_hook
self.session = requests.Session() self.session = requests.Session()
def __enter__(self): def __enter__(self) -> HttpClient:
return self return self
def __exit__(self, *args): def __exit__(self, *args: Any) -> None:
self.close() self.close()
def close(self): def close(self) -> None:
if self.session: if self.session:
self.session.close() self.session.close()
self.session = None self.session = None
def __del__(self): def __del__(self) -> None:
self.close() self.close()
def post(self, url: str, data: Any, headers: Optional[Dict[str, str]]) -> Response: def post(
data = _MonitoringFile( self, url: str, data: bytes, headers: Optional[Dict[str, str]]
data, hook=self.progress_hook ) -> Response:
) # pytype: disable=wrong-arg-types
headers["User-Agent"] = self._agentName() headers["User-Agent"] = self._agentName()
return self.session.post( return self.session.post(
url, url,
@ -56,7 +57,7 @@ class HttpClient:
verify=self.verify, verify=self.verify,
) # pytype: disable=wrong-arg-types ) # pytype: disable=wrong-arg-types
def get(self, url, headers=None) -> Response: def get(self, url: str, headers: Dict[str, str] = None) -> Response:
if headers is None: if headers is None:
headers = {} headers = {}
headers["User-Agent"] = self._agentName() headers["User-Agent"] = self._agentName()
@ -64,7 +65,7 @@ class HttpClient:
url, stream=True, headers=headers, timeout=self.timeout, verify=self.verify url, stream=True, headers=headers, timeout=self.timeout, verify=self.verify
) )
def streamContent(self, resp) -> bytes: def streamContent(self, resp: Response) -> bytes:
resp.raise_for_status() resp.raise_for_status()
buf = io.BytesIO() buf = io.BytesIO()
@ -87,15 +88,3 @@ if os.environ.get("ANKI_NOVERIFYSSL"):
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
class _MonitoringFile(io.BufferedReader):
def __init__(self, raw: io.RawIOBase, hook: Optional[ProgressCallback]):
io.BufferedReader.__init__(self, raw)
self.hook = hook
def read(self, size=-1) -> bytes:
data = io.BufferedReader.read(self, HTTP_BUF_SIZE)
if self.hook:
self.hook(len(data), 0)
return data

View file

@ -5,7 +5,7 @@ from __future__ import annotations
import locale import locale
import re import re
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, Any, Optional, Tuple
import anki import anki
import anki._backend.backend_pb2 as _pb import anki._backend.backend_pb2 as _pb
@ -169,7 +169,7 @@ def ngettext(single: str, plural: str, n: int) -> str:
return plural return plural
def tr_legacyglobal(*args, **kwargs) -> str: def tr_legacyglobal(*args: Any, **kwargs: Any) -> str:
"Should use col.tr() instead." "Should use col.tr() instead."
if current_i18n: if current_i18n:
return current_i18n.translate(*args, **kwargs) return current_i18n.translate(*args, **kwargs)

View file

@ -48,7 +48,7 @@ class TagManager:
############################################################# #############################################################
def register( def register(
self, tags: Collection[str], usn: Optional[int] = None, clear=False self, tags: Collection[str], usn: Optional[int] = None, clear: bool = False
) -> None: ) -> None:
print("tags.register() is deprecated and no longer works") print("tags.register() is deprecated and no longer works")
@ -56,10 +56,10 @@ class TagManager:
"Clear unused tags and add any missing tags from notes to the tag list." "Clear unused tags and add any missing tags from notes to the tag list."
self.clear_unused_tags() self.clear_unused_tags()
def clear_unused_tags(self): def clear_unused_tags(self) -> None:
self.col._backend.clear_unused_tags() self.col._backend.clear_unused_tags()
def byDeck(self, did, children=False) -> List[str]: def byDeck(self, did: int, children: bool = False) -> List[str]:
basequery = "select n.tags from cards c, notes n WHERE c.nid = n.id" basequery = "select n.tags from cards c, notes n WHERE c.nid = n.id"
if not children: if not children:
query = basequery + " AND c.did=?" query = basequery + " AND c.did=?"
@ -72,7 +72,7 @@ class TagManager:
res = self.col.db.list(query) res = self.col.db.list(query)
return list(set(self.split(" ".join(res)))) return list(set(self.split(" ".join(res))))
def set_collapsed(self, tag: str, collapsed: bool): def set_collapsed(self, tag: str, collapsed: bool) -> None:
"Set browser collapse state for tag, registering the tag if missing." "Set browser collapse state for tag, registering the tag if missing."
self.col._backend.set_tag_collapsed(name=tag, collapsed=collapsed) self.col._backend.set_tag_collapsed(name=tag, collapsed=collapsed)
@ -139,9 +139,9 @@ class TagManager:
def remFromStr(self, deltags: str, tags: str) -> str: def remFromStr(self, deltags: str, tags: str) -> str:
"Delete tags if they exist." "Delete tags if they exist."
def wildcard(pat, str): def wildcard(pat: str, repl: str):
pat = re.escape(pat).replace("\\*", ".*") pat = re.escape(pat).replace("\\*", ".*")
return re.match("^" + pat + "$", str, re.IGNORECASE) return re.match("^" + pat + "$", repl, re.IGNORECASE)
currentTags = self.split(tags) currentTags = self.split(tags)
for tag in self.split(deltags): for tag in self.split(deltags):

View file

@ -5,7 +5,6 @@ from __future__ import annotations
# some add-ons expect json to be in the utils module # some add-ons expect json to be in the utils module
import json # pylint: disable=unused-import import json # pylint: disable=unused-import
import locale
import os import os
import platform import platform
import random import random
@ -20,7 +19,7 @@ import traceback
from contextlib import contextmanager from contextlib import contextmanager
from hashlib import sha1 from hashlib import sha1
from html.entities import name2codepoint from html.entities import name2codepoint
from typing import Iterable, Iterator, List, Optional, Union from typing import Any, Iterable, Iterator, List, Match, Optional, Union
from anki.dbproxy import DBProxy from anki.dbproxy import DBProxy
@ -46,22 +45,6 @@ def intTime(scale: int = 1) -> int:
return int(time.time() * scale) return int(time.time() * scale)
# Locale
##############################################################################
def fmtPercentage(float_value, point=1) -> str:
"Return float with percentage sign"
fmt = "%" + "0.%(b)df" % {"b": point}
return locale.format_string(fmt, float_value) + "%"
def fmtFloat(float_value, point=1) -> str:
"Return a string with decimal separator according to current locale"
fmt = "%" + "0.%(b)df" % {"b": point}
return locale.format_string(fmt, float_value)
# HTML # HTML
############################################################################## ##############################################################################
reComment = re.compile("(?s)<!--.*?-->") reComment = re.compile("(?s)<!--.*?-->")
@ -114,7 +97,7 @@ def entsToTxt(html: str) -> str:
# replace it first # replace it first
html = html.replace("&nbsp;", " ") html = html.replace("&nbsp;", " ")
def fixup(m): def fixup(m: Match) -> str:
text = m.group(0) text = m.group(0)
if text[:2] == "&#": if text[:2] == "&#":
# character reference # character reference
@ -140,14 +123,6 @@ def entsToTxt(html: str) -> str:
############################################################################## ##############################################################################
def hexifyID(id) -> str:
return "%x" % int(id)
def dehexifyID(id) -> int:
return int(id, 16)
def ids2str(ids: Iterable[Union[int, str]]) -> str: def ids2str(ids: Iterable[Union[int, str]]) -> str:
"""Given a list of integers, return a string '(int1,int2,...)'.""" """Given a list of integers, return a string '(int1,int2,...)'."""
return "(%s)" % ",".join(str(i) for i in ids) return "(%s)" % ",".join(str(i) for i in ids)
@ -195,23 +170,6 @@ def guid64() -> str:
return base91(random.randint(0, 2 ** 64 - 1)) return base91(random.randint(0, 2 ** 64 - 1))
# increment a guid by one, for note type conflicts
def incGuid(guid) -> str:
return _incGuid(guid[::-1])[::-1]
def _incGuid(guid) -> str:
s = string
table = s.ascii_letters + s.digits + _base91_extra_chars
idx = table.index(guid[0])
if idx + 1 == len(table):
# overflow
guid = table[0] + _incGuid(guid[1:])
else:
guid = table[idx + 1] + guid[1:]
return guid
# Fields # Fields
############################################################################## ##############################################################################
@ -250,7 +208,7 @@ def tmpdir() -> str:
global _tmpdir global _tmpdir
if not _tmpdir: if not _tmpdir:
def cleanup(): def cleanup() -> None:
if os.path.exists(_tmpdir): if os.path.exists(_tmpdir):
shutil.rmtree(_tmpdir) shutil.rmtree(_tmpdir)
@ -294,7 +252,7 @@ def noBundledLibs() -> Iterator[None]:
os.environ["LD_LIBRARY_PATH"] = oldlpath os.environ["LD_LIBRARY_PATH"] = oldlpath
def call(argv: List[str], wait: bool = True, **kwargs) -> int: def call(argv: List[str], wait: bool = True, **kwargs: Any) -> int:
"Execute a command. If WAIT, return exit code." "Execute a command. If WAIT, return exit code."
# ensure we don't open a separate window for forking process on windows # ensure we don't open a separate window for forking process on windows
if isWin: if isWin:
@ -338,7 +296,7 @@ devMode = os.getenv("ANKIDEV", "")
invalidFilenameChars = ':*?"<>|' invalidFilenameChars = ':*?"<>|'
def invalidFilename(str, dirsep=True) -> Optional[str]: def invalidFilename(str: str, dirsep: bool = True) -> Optional[str]:
for c in invalidFilenameChars: for c in invalidFilenameChars:
if c in str: if c in str:
return c return c
@ -384,7 +342,7 @@ class TimedLog:
def __init__(self) -> None: def __init__(self) -> None:
self._last = time.time() self._last = time.time()
def log(self, s) -> None: def log(self, s: str) -> None:
path, num, fn, y = traceback.extract_stack(limit=2)[0] path, num, fn, y = traceback.extract_stack(limit=2)[0]
sys.stderr.write( sys.stderr.write(
"%5dms: %s(): %s\n" % ((time.time() - self._last) * 1000, fn, s) "%5dms: %s(): %s\n" % ((time.time() - self._last) * 1000, fn, s)