diff --git a/pylib/anki/find.py b/pylib/anki/find.py index 847d4e05f..3faa6fe2a 100644 --- a/pylib/anki/find.py +++ b/pylib/anki/find.py @@ -1,10 +1,11 @@ # Copyright: Ankitects Pty Ltd and contributors # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +from __future__ import annotations import re import sre_constants import unicodedata -from typing import Any, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union, cast from anki import hooks from anki.consts import * @@ -18,12 +19,15 @@ from anki.utils import ( stripHTMLMedia, ) +if TYPE_CHECKING: + from anki.collection import _Collection + # Find ########################################################################## class Finder: - def __init__(self, col) -> None: + def __init__(self, col: Optional[_Collection]) -> None: self.col = col self.search = dict( added=self._findAdded, @@ -42,7 +46,7 @@ class Finder: self.search["is"] = self._findCardState hooks.search_terms_prepared(self.search) - def findCards(self, query, order=False) -> Any: + def findCards(self, query: str, order: Union[bool, str] = False) -> List[Any]: "Return a list of card ids for QUERY." tokens = self._tokenize(query) preds, args = self._where(tokens) @@ -59,7 +63,7 @@ class Finder: res.reverse() return res - def findNotes(self, query) -> Any: + def findNotes(self, query: str) -> List[Any]: tokens = self._tokenize(query) preds, args = self._where(tokens) if preds is None: @@ -83,8 +87,8 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """ # Tokenizing ###################################################################### - def _tokenize(self, query) -> List: - inQuote = False + def _tokenize(self, query: str) -> List[str]: + inQuote: Union[bool, str] = False tokens = [] token = "" for c in query: @@ -137,7 +141,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """ # Query building ###################################################################### - def _where(self, tokens) -> Tuple[Any, Optional[List[str]]]: + def _where(self, tokens: List[str]) -> Tuple[str, Optional[List[str]]]: # state and query s: Dict[str, Any] = dict(isnot=False, isor=False, join=False, q="", bad=False) args: List[Any] = [] @@ -197,7 +201,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """ return None, None return s["q"], args - def _query(self, preds, order) -> str: + def _query(self, preds: str, order: str) -> str: # can we skip the note table? if "n." not in preds and "n." not in order: sql = "select c.id from cards c where " @@ -216,12 +220,12 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """ # Ordering ###################################################################### - def _order(self, order) -> Tuple[Any, Any]: + def _order(self, order: Union[bool, str]) -> Tuple[str, bool]: if not order: return "", False elif order is not True: # custom order string provided - return " order by " + order, False + return " order by " + cast(str, order), False # use deck default type = self.col.conf["sortType"] sort = None @@ -253,8 +257,8 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """ # Commands ###################################################################### - def _findTag(self, args) -> str: - (val, args) = args + def _findTag(self, args: Tuple[str, List[Any]]) -> str: + (val, list_args) = args if val == "none": return 'n.tags = ""' val = val.replace("*", "%") @@ -262,11 +266,11 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """ val = "% " + val if not val.endswith("%") or val.endswith("\\%"): val += " %" - args.append(val) + list_args.append(val) return "n.tags like ? escape '\\'" - def _findCardState(self, args) -> Optional[str]: - (val, args) = args + def _findCardState(self, args: Tuple[str, List[Any]]) -> Optional[str]: + (val, __) = args if val in ("review", "new", "learn"): if val == "review": n = 2 @@ -290,17 +294,16 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """ # unknown return None - def _findFlag(self, args) -> Optional[str]: - (val, args) = args + def _findFlag(self, args: Tuple[str, List[Any]]) -> Optional[str]: + (val, __) = args if not val or len(val) != 1 or val not in "01234": return None - val = int(val) mask = 2 ** 3 - 1 - return "(c.flags & %d) == %d" % (mask, val) + return "(c.flags & %d) == %d" % (mask, int(val)) - def _findRated(self, args) -> Optional[str]: + def _findRated(self, args: Tuple[str, List[Any]]) -> Optional[str]: # days(:optional_ease) - (val, args) = args + (val, __) = args r = val.split(":") try: days = int(r[0]) @@ -316,8 +319,8 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """ cutoff = (self.col.sched.dayCutoff - 86400 * days) * 1000 return "c.id in (select cid from revlog where id>%d %s)" % (cutoff, ease) - def _findAdded(self, args) -> Optional[str]: - (val, args) = args + def _findAdded(self, args: Tuple[str, List[Any]]) -> Optional[str]: + (val, __) = args try: days = int(val) except ValueError: @@ -325,20 +328,20 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """ cutoff = (self.col.sched.dayCutoff - 86400 * days) * 1000 return "c.id > %d" % cutoff - def _findProp(self, args) -> Optional[str]: + def _findProp(self, args: Tuple[str, List[Any]]) -> Optional[str]: # extract - (val, args) = args - m = re.match("(^.+?)(<=|>=|!=|=|<|>)(.+?$)", val) + (strval, __) = args + m = re.match("(^.+?)(<=|>=|!=|=|<|>)(.+?$)", strval) if not m: return None - prop, cmp, val = m.groups() + prop, cmp, strval = m.groups() prop = prop.lower() # pytype: disable=attribute-error # is val valid? try: if prop == "ease": - val = float(val) + val = float(strval) else: - val = int(val) + val = int(strval) except ValueError: return None # is prop valid? @@ -356,32 +359,32 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """ q.append("(%s %s %s)" % (prop, cmp, val)) return " and ".join(q) - def _findText(self, val, args) -> str: + def _findText(self, val: str, args: List[str]) -> str: val = val.replace("*", "%") args.append("%" + val + "%") args.append("%" + val + "%") return "(n.sfld like ? escape '\\' or n.flds like ? escape '\\')" - def _findNids(self, args) -> Optional[str]: - (val, args) = args + def _findNids(self, args: Tuple[str, List[Any]]) -> Optional[str]: + (val, __) = args if re.search("[^0-9,]", val): return None return "n.id in (%s)" % val def _findCids(self, args) -> Optional[str]: - (val, args) = args + (val, __) = args if re.search("[^0-9,]", val): return None return "c.id in (%s)" % val def _findMid(self, args) -> Optional[str]: - (val, args) = args + (val, __) = args if re.search("[^0-9]", val): return None return "n.mid = %s" % val - def _findModel(self, args) -> str: - (val, args) = args + def _findModel(self, args: Tuple[str, List[Any]]) -> str: + (val, __) = args ids = [] val = val.lower() for m in self.col.models.all(): @@ -389,9 +392,9 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """ ids.append(m["id"]) return "n.mid in %s" % ids2str(ids) - def _findDeck(self, args) -> Optional[str]: + def _findDeck(self, args: Tuple[str, List[Any]]) -> Optional[str]: # if searching for all decks, skip - (val, args) = args + (val, __) = args if val == "*": return "skip" # deck types @@ -422,9 +425,9 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """ sids = ids2str(ids) return "c.did in %s or c.odid in %s" % (sids, sids) - def _findTemplate(self, args) -> str: + def _findTemplate(self, args: Tuple[str, List[Any]]) -> str: # were we given an ordinal number? - (val, args) = args + (val, __) = args try: num = int(val) - 1 except: @@ -445,7 +448,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """ lims.append("(n.mid = %s and c.ord = %s)" % (m["id"], t["ord"])) return " or ".join(lims) - def _findField(self, field, val) -> Optional[str]: + def _findField(self, field: str, val: str) -> Optional[str]: field = field.lower() val = val.replace("*", "%") # find models that have that field @@ -481,7 +484,7 @@ where mid in %s and flds like ? escape '\\'""" def _findDupes(self, args) -> Optional[str]: # caller must call stripHTMLMedia on passed val - (val, args) = args + (val, __) = args try: mid, val = val.split(",", 1) except OSError: @@ -500,9 +503,17 @@ where mid in %s and flds like ? escape '\\'""" ########################################################################## -def findReplace(col, nids, src, dst, regex=False, field=None, fold=True) -> int: +def findReplace( + col: _Collection, + nids: List[int], + src: str, + dst: str, + regex: bool = False, + field: Optional[str] = None, + fold: bool = True, +) -> int: "Find and replace fields in a note." - mmap = {} + mmap: Dict[str, Any] = {} if field: for m in col.models.all(): for f in m["flds"]: @@ -516,10 +527,10 @@ def findReplace(col, nids, src, dst, regex=False, field=None, fold=True) -> int: dst = dst.replace("\\", "\\\\") if fold: src = "(?i)" + src - regex = re.compile(src) + compiled_re = re.compile(src) - def repl(str): - return re.sub(regex, dst, str) + def repl(s: str): + return compiled_re.sub(dst, s) d = [] snids = ids2str(nids) @@ -577,7 +588,9 @@ def fieldNamesForNotes(col, nids) -> List: # Find duplicates ########################################################################## # returns array of ("dupestr", [nids]) -def findDupes(col, fieldName, search="") -> List[Tuple[Any, List]]: +def findDupes( + col: _Collection, fieldName: str, search: str = "" +) -> List[Tuple[Any, List]]: # limit search to notes with applicable field name if search: search = "(" + search + ") "