Monkeytype pylib/anki/find.py

This commit is contained in:
Alan Du 2020-02-26 00:24:32 -05:00
parent cb71cbad54
commit b157ee7570

View file

@ -1,10 +1,11 @@
# Copyright: Ankitects Pty Ltd and contributors # Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
from __future__ import annotations
import re import re
import sre_constants import sre_constants
import unicodedata import unicodedata
from typing import Any, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union, cast
from anki import hooks from anki import hooks
from anki.consts import * from anki.consts import *
@ -18,12 +19,15 @@ from anki.utils import (
stripHTMLMedia, stripHTMLMedia,
) )
if TYPE_CHECKING:
from anki.collection import _Collection
# Find # Find
########################################################################## ##########################################################################
class Finder: class Finder:
def __init__(self, col) -> None: def __init__(self, col: Optional[_Collection]) -> None:
self.col = col self.col = col
self.search = dict( self.search = dict(
added=self._findAdded, added=self._findAdded,
@ -42,7 +46,7 @@ class Finder:
self.search["is"] = self._findCardState self.search["is"] = self._findCardState
hooks.search_terms_prepared(self.search) hooks.search_terms_prepared(self.search)
def findCards(self, query, order=False) -> Any: def findCards(self, query: str, order: Union[bool, str] = False) -> List[Any]:
"Return a list of card ids for QUERY." "Return a list of card ids for QUERY."
tokens = self._tokenize(query) tokens = self._tokenize(query)
preds, args = self._where(tokens) preds, args = self._where(tokens)
@ -59,7 +63,7 @@ class Finder:
res.reverse() res.reverse()
return res return res
def findNotes(self, query) -> Any: def findNotes(self, query: str) -> List[Any]:
tokens = self._tokenize(query) tokens = self._tokenize(query)
preds, args = self._where(tokens) preds, args = self._where(tokens)
if preds is None: if preds is None:
@ -83,8 +87,8 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """
# Tokenizing # Tokenizing
###################################################################### ######################################################################
def _tokenize(self, query) -> List: def _tokenize(self, query: str) -> List[str]:
inQuote = False inQuote: Union[bool, str] = False
tokens = [] tokens = []
token = "" token = ""
for c in query: for c in query:
@ -137,7 +141,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """
# Query building # Query building
###################################################################### ######################################################################
def _where(self, tokens) -> Tuple[Any, Optional[List[str]]]: def _where(self, tokens: List[str]) -> Tuple[str, Optional[List[str]]]:
# state and query # state and query
s: Dict[str, Any] = dict(isnot=False, isor=False, join=False, q="", bad=False) s: Dict[str, Any] = dict(isnot=False, isor=False, join=False, q="", bad=False)
args: List[Any] = [] args: List[Any] = []
@ -197,7 +201,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """
return None, None return None, None
return s["q"], args return s["q"], args
def _query(self, preds, order) -> str: def _query(self, preds: str, order: str) -> str:
# can we skip the note table? # can we skip the note table?
if "n." not in preds and "n." not in order: if "n." not in preds and "n." not in order:
sql = "select c.id from cards c where " sql = "select c.id from cards c where "
@ -216,12 +220,12 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """
# Ordering # Ordering
###################################################################### ######################################################################
def _order(self, order) -> Tuple[Any, Any]: def _order(self, order: Union[bool, str]) -> Tuple[str, bool]:
if not order: if not order:
return "", False return "", False
elif order is not True: elif order is not True:
# custom order string provided # custom order string provided
return " order by " + order, False return " order by " + cast(str, order), False
# use deck default # use deck default
type = self.col.conf["sortType"] type = self.col.conf["sortType"]
sort = None sort = None
@ -253,8 +257,8 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """
# Commands # Commands
###################################################################### ######################################################################
def _findTag(self, args) -> str: def _findTag(self, args: Tuple[str, List[Any]]) -> str:
(val, args) = args (val, list_args) = args
if val == "none": if val == "none":
return 'n.tags = ""' return 'n.tags = ""'
val = val.replace("*", "%") val = val.replace("*", "%")
@ -262,11 +266,11 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """
val = "% " + val val = "% " + val
if not val.endswith("%") or val.endswith("\\%"): if not val.endswith("%") or val.endswith("\\%"):
val += " %" val += " %"
args.append(val) list_args.append(val)
return "n.tags like ? escape '\\'" return "n.tags like ? escape '\\'"
def _findCardState(self, args) -> Optional[str]: def _findCardState(self, args: Tuple[str, List[Any]]) -> Optional[str]:
(val, args) = args (val, __) = args
if val in ("review", "new", "learn"): if val in ("review", "new", "learn"):
if val == "review": if val == "review":
n = 2 n = 2
@ -290,17 +294,16 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """
# unknown # unknown
return None return None
def _findFlag(self, args) -> Optional[str]: def _findFlag(self, args: Tuple[str, List[Any]]) -> Optional[str]:
(val, args) = args (val, __) = args
if not val or len(val) != 1 or val not in "01234": if not val or len(val) != 1 or val not in "01234":
return None return None
val = int(val)
mask = 2 ** 3 - 1 mask = 2 ** 3 - 1
return "(c.flags & %d) == %d" % (mask, val) return "(c.flags & %d) == %d" % (mask, int(val))
def _findRated(self, args) -> Optional[str]: def _findRated(self, args: Tuple[str, List[Any]]) -> Optional[str]:
# days(:optional_ease) # days(:optional_ease)
(val, args) = args (val, __) = args
r = val.split(":") r = val.split(":")
try: try:
days = int(r[0]) days = int(r[0])
@ -316,8 +319,8 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """
cutoff = (self.col.sched.dayCutoff - 86400 * days) * 1000 cutoff = (self.col.sched.dayCutoff - 86400 * days) * 1000
return "c.id in (select cid from revlog where id>%d %s)" % (cutoff, ease) return "c.id in (select cid from revlog where id>%d %s)" % (cutoff, ease)
def _findAdded(self, args) -> Optional[str]: def _findAdded(self, args: Tuple[str, List[Any]]) -> Optional[str]:
(val, args) = args (val, __) = args
try: try:
days = int(val) days = int(val)
except ValueError: except ValueError:
@ -325,20 +328,20 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """
cutoff = (self.col.sched.dayCutoff - 86400 * days) * 1000 cutoff = (self.col.sched.dayCutoff - 86400 * days) * 1000
return "c.id > %d" % cutoff return "c.id > %d" % cutoff
def _findProp(self, args) -> Optional[str]: def _findProp(self, args: Tuple[str, List[Any]]) -> Optional[str]:
# extract # extract
(val, args) = args (strval, __) = args
m = re.match("(^.+?)(<=|>=|!=|=|<|>)(.+?$)", val) m = re.match("(^.+?)(<=|>=|!=|=|<|>)(.+?$)", strval)
if not m: if not m:
return None return None
prop, cmp, val = m.groups() prop, cmp, strval = m.groups()
prop = prop.lower() # pytype: disable=attribute-error prop = prop.lower() # pytype: disable=attribute-error
# is val valid? # is val valid?
try: try:
if prop == "ease": if prop == "ease":
val = float(val) val = float(strval)
else: else:
val = int(val) val = int(strval)
except ValueError: except ValueError:
return None return None
# is prop valid? # is prop valid?
@ -356,32 +359,32 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """
q.append("(%s %s %s)" % (prop, cmp, val)) q.append("(%s %s %s)" % (prop, cmp, val))
return " and ".join(q) return " and ".join(q)
def _findText(self, val, args) -> str: def _findText(self, val: str, args: List[str]) -> str:
val = val.replace("*", "%") val = val.replace("*", "%")
args.append("%" + val + "%") args.append("%" + val + "%")
args.append("%" + val + "%") args.append("%" + val + "%")
return "(n.sfld like ? escape '\\' or n.flds like ? escape '\\')" return "(n.sfld like ? escape '\\' or n.flds like ? escape '\\')"
def _findNids(self, args) -> Optional[str]: def _findNids(self, args: Tuple[str, List[Any]]) -> Optional[str]:
(val, args) = args (val, __) = args
if re.search("[^0-9,]", val): if re.search("[^0-9,]", val):
return None return None
return "n.id in (%s)" % val return "n.id in (%s)" % val
def _findCids(self, args) -> Optional[str]: def _findCids(self, args) -> Optional[str]:
(val, args) = args (val, __) = args
if re.search("[^0-9,]", val): if re.search("[^0-9,]", val):
return None return None
return "c.id in (%s)" % val return "c.id in (%s)" % val
def _findMid(self, args) -> Optional[str]: def _findMid(self, args) -> Optional[str]:
(val, args) = args (val, __) = args
if re.search("[^0-9]", val): if re.search("[^0-9]", val):
return None return None
return "n.mid = %s" % val return "n.mid = %s" % val
def _findModel(self, args) -> str: def _findModel(self, args: Tuple[str, List[Any]]) -> str:
(val, args) = args (val, __) = args
ids = [] ids = []
val = val.lower() val = val.lower()
for m in self.col.models.all(): for m in self.col.models.all():
@ -389,9 +392,9 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """
ids.append(m["id"]) ids.append(m["id"])
return "n.mid in %s" % ids2str(ids) return "n.mid in %s" % ids2str(ids)
def _findDeck(self, args) -> Optional[str]: def _findDeck(self, args: Tuple[str, List[Any]]) -> Optional[str]:
# if searching for all decks, skip # if searching for all decks, skip
(val, args) = args (val, __) = args
if val == "*": if val == "*":
return "skip" return "skip"
# deck types # deck types
@ -422,9 +425,9 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """
sids = ids2str(ids) sids = ids2str(ids)
return "c.did in %s or c.odid in %s" % (sids, sids) return "c.did in %s or c.odid in %s" % (sids, sids)
def _findTemplate(self, args) -> str: def _findTemplate(self, args: Tuple[str, List[Any]]) -> str:
# were we given an ordinal number? # were we given an ordinal number?
(val, args) = args (val, __) = args
try: try:
num = int(val) - 1 num = int(val) - 1
except: except:
@ -445,7 +448,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """
lims.append("(n.mid = %s and c.ord = %s)" % (m["id"], t["ord"])) lims.append("(n.mid = %s and c.ord = %s)" % (m["id"], t["ord"]))
return " or ".join(lims) return " or ".join(lims)
def _findField(self, field, val) -> Optional[str]: def _findField(self, field: str, val: str) -> Optional[str]:
field = field.lower() field = field.lower()
val = val.replace("*", "%") val = val.replace("*", "%")
# find models that have that field # find models that have that field
@ -481,7 +484,7 @@ where mid in %s and flds like ? escape '\\'"""
def _findDupes(self, args) -> Optional[str]: def _findDupes(self, args) -> Optional[str]:
# caller must call stripHTMLMedia on passed val # caller must call stripHTMLMedia on passed val
(val, args) = args (val, __) = args
try: try:
mid, val = val.split(",", 1) mid, val = val.split(",", 1)
except OSError: except OSError:
@ -500,9 +503,17 @@ where mid in %s and flds like ? escape '\\'"""
########################################################################## ##########################################################################
def findReplace(col, nids, src, dst, regex=False, field=None, fold=True) -> int: def findReplace(
col: _Collection,
nids: List[int],
src: str,
dst: str,
regex: bool = False,
field: Optional[str] = None,
fold: bool = True,
) -> int:
"Find and replace fields in a note." "Find and replace fields in a note."
mmap = {} mmap: Dict[str, Any] = {}
if field: if field:
for m in col.models.all(): for m in col.models.all():
for f in m["flds"]: for f in m["flds"]:
@ -516,10 +527,10 @@ def findReplace(col, nids, src, dst, regex=False, field=None, fold=True) -> int:
dst = dst.replace("\\", "\\\\") dst = dst.replace("\\", "\\\\")
if fold: if fold:
src = "(?i)" + src src = "(?i)" + src
regex = re.compile(src) compiled_re = re.compile(src)
def repl(str): def repl(s: str):
return re.sub(regex, dst, str) return compiled_re.sub(dst, s)
d = [] d = []
snids = ids2str(nids) snids = ids2str(nids)
@ -577,7 +588,9 @@ def fieldNamesForNotes(col, nids) -> List:
# Find duplicates # Find duplicates
########################################################################## ##########################################################################
# returns array of ("dupestr", [nids]) # returns array of ("dupestr", [nids])
def findDupes(col, fieldName, search="") -> List[Tuple[Any, List]]: def findDupes(
col: _Collection, fieldName: str, search: str = ""
) -> List[Tuple[Any, List]]:
# limit search to notes with applicable field name # limit search to notes with applicable field name
if search: if search:
search = "(" + search + ") " search = "(" + search + ") "