add some typehints, and remove some unused code

This commit is contained in:
Damien Elmes 2021-01-31 20:56:21 +10:00
parent f0f2da0f56
commit 7fda601aef
9 changed files with 58 additions and 111 deletions

View file

@ -1,10 +1,7 @@
from typing import Any
def buildhash(*args, **kwargs) -> Any: ...
def open_backend(*args, **kwargs) -> Any: ...
def buildhash() -> str: ...
def open_backend(data: bytes) -> Backend: ...
class Backend:
@classmethod
def __init__(self, *args, **kwargs) -> None: ...
def command(self, *args, **kwargs) -> Any: ...
def db_command(self, *args, **kwargs) -> Any: ...
def command(self, method: int, data: bytes) -> bytes: ...
def db_command(self, data: bytes) -> bytes: ...

View file

@ -92,7 +92,7 @@ REVLOG_RESCHED = 4
##########################################################################
def _tr(col: Optional[anki.collection.Collection]):
def _tr(col: Optional[anki.collection.Collection]) -> Any:
if col:
return col.tr
else:

View file

@ -32,7 +32,7 @@ class DB:
del d["_db"]
return f"{super().__repr__()} {pprint.pformat(d, width=300)}"
def execute(self, sql: str, *a, **ka) -> Cursor:
def execute(self, sql: str, *a: Any, **ka: Any) -> Cursor:
s = sql.strip().lower()
# mark modified?
for stmt in "insert", "update", "delete":
@ -76,36 +76,36 @@ class DB:
def rollback(self) -> None:
self._db.rollback()
def scalar(self, *a, **kw) -> Any:
def scalar(self, *a: Any, **kw: Any) -> Any:
res = self.execute(*a, **kw).fetchone()
if res:
return res[0]
return None
def all(self, *a, **kw) -> List:
def all(self, *a: Any, **kw: Any) -> List:
return self.execute(*a, **kw).fetchall()
def first(self, *a, **kw) -> Any:
def first(self, *a: Any, **kw: Any) -> Any:
c = self.execute(*a, **kw)
res = c.fetchone()
c.close()
return res
def list(self, *a, **kw) -> List:
def list(self, *a: Any, **kw: Any) -> List:
return [x[0] for x in self.execute(*a, **kw)]
def close(self) -> None:
self._db.text_factory = None
self._db.close()
def set_progress_handler(self, *args) -> None:
def set_progress_handler(self, *args: Any) -> None:
self._db.set_progress_handler(*args)
def __enter__(self) -> "DB":
self._db.execute("begin")
return self
def __exit__(self, exc_type, *args) -> None:
def __exit__(self, *args: Any) -> None:
self._db.close()
def totalChanges(self) -> Any:

View file

@ -4,6 +4,7 @@
from __future__ import annotations
import re
from re import Match
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import anki
@ -43,7 +44,11 @@ class DBProxy:
################
def _query(
self, sql: str, *args: ValueForDB, first_row_only: bool = False, **kwargs
self,
sql: str,
*args: ValueForDB,
first_row_only: bool = False,
**kwargs: ValueForDB,
) -> List[Row]:
# mark modified?
s = sql.strip().lower()
@ -57,20 +62,22 @@ class DBProxy:
# Query shortcuts
###################
def all(self, sql: str, *args: ValueForDB, **kwargs) -> List[Row]:
return self._query(sql, *args, **kwargs)
def all(self, sql: str, *args: ValueForDB, **kwargs: ValueForDB) -> List[Row]:
return self._query(sql, *args, first_row_only=False, **kwargs)
def list(self, sql: str, *args: ValueForDB, **kwargs) -> List[ValueFromDB]:
return [x[0] for x in self._query(sql, *args, **kwargs)]
def list(
self, sql: str, *args: ValueForDB, **kwargs: ValueForDB
) -> List[ValueFromDB]:
return [x[0] for x in self._query(sql, *args, first_row_only=False, **kwargs)]
def first(self, sql: str, *args: ValueForDB, **kwargs) -> Optional[Row]:
def first(self, sql: str, *args: ValueForDB, **kwargs: ValueForDB) -> Optional[Row]:
rows = self._query(sql, *args, first_row_only=True, **kwargs)
if rows:
return rows[0]
else:
return None
def scalar(self, sql: str, *args: ValueForDB, **kwargs) -> ValueFromDB:
def scalar(self, sql: str, *args: ValueForDB, **kwargs: ValueForDB) -> ValueFromDB:
rows = self._query(sql, *args, first_row_only=True, **kwargs)
if rows:
return rows[0][0]
@ -109,7 +116,7 @@ def emulate_named_args(
n = len(args2)
arg_num[key] = n
# update refs
def repl(m):
def repl(m: Match) -> str:
arg = m.group(1)
return f"?{arg_num[arg]}"

View file

@ -3,8 +3,6 @@
from __future__ import annotations
from typing import Any
import anki._backend.backend_pb2 as _pb
# fixme: notfounderror etc need to be in rsbackend.py
@ -88,17 +86,15 @@ def backend_exception_to_pylib(err: _pb.BackendError) -> Exception:
return StringError(err.localized)
# FIXME: this is only used with "abortSchemaMod", but currently some
# add-ons depend on it
class AnkiError(Exception):
def __init__(self, type, **data) -> None:
def __init__(self, type: str) -> None:
super().__init__()
self.type = type
self.data = data
def __str__(self) -> Any:
m = self.type
if self.data:
m += ": %s" % repr(self.data)
return m
def __str__(self) -> str:
return self.type
class DeckRenameError(Exception):
@ -106,5 +102,5 @@ class DeckRenameError(Exception):
super().__init__()
self.description = description
def __str__(self):
def __str__(self) -> str:
return "Couldn't rename deck: " + self.description

View file

@ -5,6 +5,8 @@
Wrapper for requests that adds a callback for tracking upload/download progress.
"""
from __future__ import annotations
import io
import os
from typing import Any, Callable, Dict, Optional
@ -28,24 +30,23 @@ class HttpClient:
self.progress_hook = progress_hook
self.session = requests.Session()
def __enter__(self):
def __enter__(self) -> HttpClient:
return self
def __exit__(self, *args):
def __exit__(self, *args: Any) -> None:
self.close()
def close(self):
def close(self) -> None:
if self.session:
self.session.close()
self.session = None
def __del__(self):
def __del__(self) -> None:
self.close()
def post(self, url: str, data: Any, headers: Optional[Dict[str, str]]) -> Response:
data = _MonitoringFile(
data, hook=self.progress_hook
) # pytype: disable=wrong-arg-types
def post(
self, url: str, data: bytes, headers: Optional[Dict[str, str]]
) -> Response:
headers["User-Agent"] = self._agentName()
return self.session.post(
url,
@ -56,7 +57,7 @@ class HttpClient:
verify=self.verify,
) # pytype: disable=wrong-arg-types
def get(self, url, headers=None) -> Response:
def get(self, url: str, headers: Dict[str, str] = None) -> Response:
if headers is None:
headers = {}
headers["User-Agent"] = self._agentName()
@ -64,7 +65,7 @@ class HttpClient:
url, stream=True, headers=headers, timeout=self.timeout, verify=self.verify
)
def streamContent(self, resp) -> bytes:
def streamContent(self, resp: Response) -> bytes:
resp.raise_for_status()
buf = io.BytesIO()
@ -87,15 +88,3 @@ if os.environ.get("ANKI_NOVERIFYSSL"):
import warnings
warnings.filterwarnings("ignore")
class _MonitoringFile(io.BufferedReader):
def __init__(self, raw: io.RawIOBase, hook: Optional[ProgressCallback]):
io.BufferedReader.__init__(self, raw)
self.hook = hook
def read(self, size=-1) -> bytes:
data = io.BufferedReader.read(self, HTTP_BUF_SIZE)
if self.hook:
self.hook(len(data), 0)
return data

View file

@ -5,7 +5,7 @@ from __future__ import annotations
import locale
import re
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional, Tuple
import anki
import anki._backend.backend_pb2 as _pb
@ -169,7 +169,7 @@ def ngettext(single: str, plural: str, n: int) -> str:
return plural
def tr_legacyglobal(*args, **kwargs) -> str:
def tr_legacyglobal(*args: Any, **kwargs: Any) -> str:
"Should use col.tr() instead."
if current_i18n:
return current_i18n.translate(*args, **kwargs)

View file

@ -48,7 +48,7 @@ class TagManager:
#############################################################
def register(
self, tags: Collection[str], usn: Optional[int] = None, clear=False
self, tags: Collection[str], usn: Optional[int] = None, clear: bool = False
) -> None:
print("tags.register() is deprecated and no longer works")
@ -56,10 +56,10 @@ class TagManager:
"Clear unused tags and add any missing tags from notes to the tag list."
self.clear_unused_tags()
def clear_unused_tags(self):
def clear_unused_tags(self) -> None:
self.col._backend.clear_unused_tags()
def byDeck(self, did, children=False) -> List[str]:
def byDeck(self, did: int, children: bool = False) -> List[str]:
basequery = "select n.tags from cards c, notes n WHERE c.nid = n.id"
if not children:
query = basequery + " AND c.did=?"
@ -72,7 +72,7 @@ class TagManager:
res = self.col.db.list(query)
return list(set(self.split(" ".join(res))))
def set_collapsed(self, tag: str, collapsed: bool):
def set_collapsed(self, tag: str, collapsed: bool) -> None:
"Set browser collapse state for tag, registering the tag if missing."
self.col._backend.set_tag_collapsed(name=tag, collapsed=collapsed)
@ -139,9 +139,9 @@ class TagManager:
def remFromStr(self, deltags: str, tags: str) -> str:
"Delete tags if they exist."
def wildcard(pat, str):
def wildcard(pat: str, repl: str):
pat = re.escape(pat).replace("\\*", ".*")
return re.match("^" + pat + "$", str, re.IGNORECASE)
return re.match("^" + pat + "$", repl, re.IGNORECASE)
currentTags = self.split(tags)
for tag in self.split(deltags):

View file

@ -5,7 +5,6 @@ from __future__ import annotations
# some add-ons expect json to be in the utils module
import json # pylint: disable=unused-import
import locale
import os
import platform
import random
@ -20,7 +19,7 @@ import traceback
from contextlib import contextmanager
from hashlib import sha1
from html.entities import name2codepoint
from typing import Iterable, Iterator, List, Optional, Union
from typing import Any, Iterable, Iterator, List, Match, Optional, Union
from anki.dbproxy import DBProxy
@ -46,22 +45,6 @@ def intTime(scale: int = 1) -> int:
return int(time.time() * scale)
# Locale
##############################################################################
def fmtPercentage(float_value, point=1) -> str:
"Return float with percentage sign"
fmt = "%" + "0.%(b)df" % {"b": point}
return locale.format_string(fmt, float_value) + "%"
def fmtFloat(float_value, point=1) -> str:
"Return a string with decimal separator according to current locale"
fmt = "%" + "0.%(b)df" % {"b": point}
return locale.format_string(fmt, float_value)
# HTML
##############################################################################
reComment = re.compile("(?s)<!--.*?-->")
@ -114,7 +97,7 @@ def entsToTxt(html: str) -> str:
# replace it first
html = html.replace("&nbsp;", " ")
def fixup(m):
def fixup(m: Match) -> str:
text = m.group(0)
if text[:2] == "&#":
# character reference
@ -140,14 +123,6 @@ def entsToTxt(html: str) -> str:
##############################################################################
def hexifyID(id) -> str:
return "%x" % int(id)
def dehexifyID(id) -> int:
return int(id, 16)
def ids2str(ids: Iterable[Union[int, str]]) -> str:
"""Given a list of integers, return a string '(int1,int2,...)'."""
return "(%s)" % ",".join(str(i) for i in ids)
@ -195,23 +170,6 @@ def guid64() -> str:
return base91(random.randint(0, 2 ** 64 - 1))
# increment a guid by one, for note type conflicts
def incGuid(guid) -> str:
return _incGuid(guid[::-1])[::-1]
def _incGuid(guid) -> str:
s = string
table = s.ascii_letters + s.digits + _base91_extra_chars
idx = table.index(guid[0])
if idx + 1 == len(table):
# overflow
guid = table[0] + _incGuid(guid[1:])
else:
guid = table[idx + 1] + guid[1:]
return guid
# Fields
##############################################################################
@ -250,7 +208,7 @@ def tmpdir() -> str:
global _tmpdir
if not _tmpdir:
def cleanup():
def cleanup() -> None:
if os.path.exists(_tmpdir):
shutil.rmtree(_tmpdir)
@ -294,7 +252,7 @@ def noBundledLibs() -> Iterator[None]:
os.environ["LD_LIBRARY_PATH"] = oldlpath
def call(argv: List[str], wait: bool = True, **kwargs) -> int:
def call(argv: List[str], wait: bool = True, **kwargs: Any) -> int:
"Execute a command. If WAIT, return exit code."
# ensure we don't open a separate window for forking process on windows
if isWin:
@ -338,7 +296,7 @@ devMode = os.getenv("ANKIDEV", "")
invalidFilenameChars = ':*?"<>|'
def invalidFilename(str, dirsep=True) -> Optional[str]:
def invalidFilename(str: str, dirsep: bool = True) -> Optional[str]:
for c in invalidFilenameChars:
if c in str:
return c
@ -384,7 +342,7 @@ class TimedLog:
def __init__(self) -> None:
self._last = time.time()
def log(self, s) -> None:
def log(self, s: str) -> None:
path, num, fn, y = traceback.extract_stack(limit=2)[0]
sys.stderr.write(
"%5dms: %s(): %s\n" % ((time.time() - self._last) * 1000, fn, s)