Merge remote-tracking branch 'danielelmes/master' into fix_tests_on_windows

# Conflicts:
#	.github/scripts/trailing-newlines.sh
This commit is contained in:
evandrocoan 2020-03-23 18:44:11 -03:00
commit b1b3e5b87c
81 changed files with 3897 additions and 1760 deletions

1
.gitattributes vendored Normal file
View file

@ -0,0 +1 @@
*.ftl eol=lf

View file

@ -6,7 +6,8 @@ set -eo pipefail
# Because `set -e` does not work inside the subshell $()
rg --version > /dev/null 2>&1
files=$(rg -l '[^\n]\z' -g '!*.{png,svg,scss,json}' || true)
files=$(rg -l '[^\n]\z' -g '!*.{png,svg,scss,json,sql}' || true)
if [ "$files" != "" ]; then
echo "the following files are missing a newline on the last line:"
echo $files

View file

@ -1 +1 @@
2.1.22
2.1.24

View file

@ -6,13 +6,14 @@ package backend_proto;
message Empty {}
message OptionalInt32 {
sint32 val = 1;
}
message BackendInit {
string collection_path = 1;
string media_folder_path = 2;
string media_db_path = 3;
repeated string preferred_langs = 4;
string locale_folder_path = 5;
string log_path = 6;
repeated string preferred_langs = 1;
string locale_folder_path = 2;
bool server = 3;
}
message I18nBackendInit {
@ -27,8 +28,8 @@ message BackendInput {
TemplateRequirementsIn template_requirements = 16;
SchedTimingTodayIn sched_timing_today = 17;
Empty deck_tree = 18;
FindCardsIn find_cards = 19;
BrowserRowsIn browser_rows = 20;
SearchCardsIn search_cards = 19;
SearchNotesIn search_notes = 20;
RenderCardIn render_card = 21;
int64 local_minutes_west = 22;
string strip_av_tags = 23;
@ -44,6 +45,8 @@ message BackendInput {
CongratsLearnMsgIn congrats_learn_msg = 33;
Empty empty_trash = 34;
Empty restore_trash = 35;
OpenCollectionIn open_collection = 36;
Empty close_collection = 37;
}
}
@ -63,8 +66,8 @@ message BackendOutput {
// fallible commands
TemplateRequirementsOut template_requirements = 16;
DeckTreeOut deck_tree = 18;
FindCardsOut find_cards = 19;
BrowserRowsOut browser_rows = 20;
SearchCardsOut search_cards = 19;
SearchNotesOut search_notes = 20;
RenderCardOut render_card = 21;
string add_media_file = 26;
Empty sync_media = 27;
@ -72,6 +75,8 @@ message BackendOutput {
Empty trash_media_files = 29;
Empty empty_trash = 34;
Empty restore_trash = 35;
Empty open_collection = 36;
Empty close_collection = 37;
BackendError error = 2047;
}
@ -162,10 +167,10 @@ message TemplateRequirementAny {
message SchedTimingTodayIn {
int64 created_secs = 1;
sint32 created_mins_west = 2;
int64 now_secs = 3;
sint32 now_mins_west = 4;
sint32 rollover_hour = 5;
int64 now_secs = 2;
OptionalInt32 created_mins_west = 3;
OptionalInt32 now_mins_west = 4;
OptionalInt32 rollover_hour = 5;
}
message SchedTimingTodayOut {
@ -188,23 +193,6 @@ message DeckTreeNode {
bool collapsed = 7;
}
message FindCardsIn {
string search = 1;
}
message FindCardsOut {
repeated int64 card_ids = 1;
}
message BrowserRowsIn {
repeated int64 card_ids = 1;
}
message BrowserRowsOut {
// just sort fields for proof of concept
repeated string sort_fields = 1;
}
message RenderCardIn {
string question_template = 1;
string answer_template = 2;
@ -324,3 +312,58 @@ message CongratsLearnMsgIn {
float next_due = 1;
uint32 remaining = 2;
}
message OpenCollectionIn {
string collection_path = 1;
string media_folder_path = 2;
string media_db_path = 3;
string log_path = 4;
}
message SearchCardsIn {
string search = 1;
SortOrder order = 2;
}
message SearchCardsOut {
repeated int64 card_ids = 1;
}
message SortOrder {
oneof value {
Empty from_config = 1;
Empty none = 2;
string custom = 3;
BuiltinSearchOrder builtin = 4;
}
}
message SearchNotesIn {
string search = 1;
}
message SearchNotesOut {
repeated int64 note_ids = 2;
}
message BuiltinSearchOrder {
BuiltinSortKind kind = 1;
bool reverse = 2;
}
enum BuiltinSortKind {
NOTE_CREATION = 0;
NOTE_MOD = 1;
NOTE_FIELD = 2;
NOTE_TAGS = 3;
NOTE_TYPE = 4;
CARD_MOD = 5;
CARD_REPS = 6;
CARD_DUE = 7;
CARD_EASE = 8;
CARD_LAPSES = 9;
CARD_INTERVAL = 10;
CARD_DECK = 11;
CARD_TEMPLATE = 12;
}

View file

@ -124,30 +124,6 @@ insert or replace into cards values
)
self.col.log(self)
def flushSched(self) -> None:
self._preFlush()
# bug checks
self.col.db.execute(
"""update cards set
mod=?, usn=?, type=?, queue=?, due=?, ivl=?, factor=?, reps=?,
lapses=?, left=?, odue=?, odid=?, did=? where id = ?""",
self.mod,
self.usn,
self.type,
self.queue,
self.due,
self.ivl,
self.factor,
self.reps,
self.lapses,
self.left,
self.odue,
self.odid,
self.did,
self.id,
)
self.col.log(self)
def question(self, reload: bool = False, browser: bool = False) -> str:
return self.css() + self.render_output(reload, browser).question_text
@ -181,6 +157,8 @@ lapses=?, left=?, odue=?, odid=?, did=? where id = ?""",
def note_type(self) -> NoteType:
return self.col.models.get(self.note().mid)
# legacy aliases
flushSched = flush
q = question
a = answer
model = note_type

View file

@ -15,7 +15,7 @@ import time
import traceback
import unicodedata
import weakref
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import anki.find
import anki.latex # sets up hook
@ -23,7 +23,7 @@ import anki.template
from anki import hooks
from anki.cards import Card
from anki.consts import *
from anki.db import DB
from anki.dbproxy import DBProxy
from anki.decks import DeckManager
from anki.errors import AnkiError
from anki.lang import _, ngettext
@ -67,7 +67,7 @@ defaultConf = {
# this is initialized by storage.Collection
class _Collection:
db: Optional[DB]
db: Optional[DBProxy]
sched: Union[V1Scheduler, V2Scheduler]
crt: int
mod: int
@ -80,13 +80,12 @@ class _Collection:
def __init__(
self,
db: DB,
db: DBProxy,
backend: RustBackend,
server: Optional["anki.storage.ServerData"] = None,
log: bool = False,
) -> None:
self.backend = backend
self._debugLog = log
self._debugLog = not server
self.db = db
self.path = db._path
self._openLog()
@ -139,10 +138,6 @@ class _Collection:
self.sched = V1Scheduler(self)
elif ver == 2:
self.sched = V2Scheduler(self)
if not self.server:
self.conf["localOffset"] = self.sched._current_timezone_offset()
elif self.server.minutes_west is not None:
self.conf["localOffset"] = self.server.minutes_west
def changeSchedulerVer(self, ver: int) -> None:
if ver == self.schedVer():
@ -165,12 +160,13 @@ class _Collection:
self._loadScheduler()
# the sync code uses this to send the local timezone to AnkiWeb
def localOffset(self) -> Optional[int]:
"Minutes west of UTC. Only applies to V2 scheduler."
if isinstance(self.sched, V1Scheduler):
return None
else:
return self.sched._current_timezone_offset()
return self.backend.local_minutes_west(intTime())
# DB-related
##########################################################################
@ -220,8 +216,10 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
json.dumps(self.conf),
)
def save(self, name: Optional[str] = None, mod: Optional[int] = None) -> None:
"Flush, commit DB, and take out another write lock."
def save(
self, name: Optional[str] = None, mod: Optional[int] = None, trx: bool = True
) -> None:
"Flush, commit DB, and take out another write lock if trx=True."
# let the managers conditionally flush
self.models.flush()
self.decks.flush()
@ -230,8 +228,14 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
if self.db.mod:
self.flush(mod=mod)
self.db.commit()
self.lock()
self.db.mod = False
if trx:
self.db.begin()
elif not trx:
# if no changes were pending but calling code expects to be
# outside of a transaction, we need to roll back
self.db.rollback()
self._markOp(name)
self._lastSave = time.time()
@ -242,39 +246,24 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
return True
return None
def lock(self) -> None:
# make sure we don't accidentally bump mod time
mod = self.db.mod
self.db.execute("update col set mod=mod")
self.db.mod = mod
def close(self, save: bool = True) -> None:
"Disconnect from DB."
if self.db:
if save:
self.save()
self.save(trx=False)
else:
self.db.rollback()
if not self.server:
self.db.setAutocommit(True)
self.db.execute("pragma journal_mode = delete")
self.db.setAutocommit(False)
self.db.close()
self.backend.close_collection()
self.db = None
self.media.close()
self._closeLog()
def reopen(self) -> None:
"Reconnect to DB (after changing threads, etc)."
if not self.db:
self.db = DB(self.path)
self.media.connect()
self._openLog()
def rollback(self) -> None:
self.db.rollback()
self.db.begin()
self.load()
self.lock()
def modSchema(self, check: bool) -> None:
"Mark schema modified. Call this first so user can abort if necessary."
@ -305,10 +294,10 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
self.modSchema(check=False)
self.ls = self.scm
# ensure db is compacted before upload
self.db.setAutocommit(True)
self.save(trx=False)
self.db.execute("vacuum")
self.db.execute("analyze")
self.close()
self.close(save=False)
# Object creation helpers
##########################################################################
@ -626,11 +615,25 @@ where c.nid = n.id and c.id in %s group by nid"""
# Finding cards
##########################################################################
def findCards(self, query: str, order: Union[bool, str] = False) -> Any:
return anki.find.Finder(self).findCards(query, order)
# if order=True, use the sort order stored in the collection config
# if order=False, do no ordering
#
# if order is a string, that text is added after 'order by' in the sql statement.
# you must add ' asc' or ' desc' to the order, as Anki will replace asc with
# desc and vice versa when reverse is set in the collection config, eg
# order="c.ivl asc, c.due desc"
#
# if order is an int enum, sort using that builtin sort, eg
# col.find_cards("", order=BuiltinSortKind.CARD_DUE)
# the reverse argument only applies when a BuiltinSortKind is provided;
# otherwise the collection config defines whether reverse is set or not
def find_cards(
self, query: str, order: Union[bool, str, int] = False, reverse: bool = False,
) -> Sequence[int]:
return self.backend.search_cards(query, order, reverse)
def findNotes(self, query: str) -> Any:
return anki.find.Finder(self).findNotes(query)
def find_notes(self, query: str) -> Sequence[int]:
return self.backend.search_notes(query)
def findReplace(
self,
@ -646,6 +649,9 @@ where c.nid = n.id and c.id in %s group by nid"""
def findDupes(self, fieldName: str, search: str = "") -> List[Tuple[Any, list]]:
return anki.find.findDupes(self, fieldName, search)
findCards = find_cards
findNotes = find_notes
# Stats
##########################################################################
@ -793,7 +799,6 @@ select id from notes where mid = ?) limit 1"""
problems = []
# problems that don't require a full sync
syncable_problems = []
curs = self.db.cursor()
self.save()
oldSize = os.stat(self.path)[stat.ST_SIZE]
if self.db.scalar("pragma integrity_check") != "ok":
@ -942,16 +947,18 @@ select id from cards where odid > 0 and did in %s"""
self.updateFieldCache(self.models.nids(m))
# new cards can't have a due position > 32 bits, so wrap items over
# 2 million back to 1 million
curs.execute(
self.db.execute(
"""
update cards set due=1000000+due%1000000,mod=?,usn=? where due>=1000000
and type=0""",
[intTime(), self.usn()],
intTime(),
self.usn(),
)
if curs.rowcount:
rowcount = self.db.scalar("select changes()")
if rowcount:
problems.append(
"Found %d new cards with a due number >= 1,000,000 - consider repositioning them in the Browse screen."
% curs.rowcount
% rowcount
)
# new card position
self.conf["nextPos"] = (
@ -969,18 +976,20 @@ and type=0""",
self.usn(),
)
# v2 sched had a bug that could create decimal intervals
curs.execute(
self.db.execute(
"update cards set ivl=round(ivl),due=round(due) where ivl!=round(ivl) or due!=round(due)"
)
if curs.rowcount:
problems.append("Fixed %d cards with v2 scheduler bug." % curs.rowcount)
rowcount = self.db.scalar("select changes()")
if rowcount:
problems.append("Fixed %d cards with v2 scheduler bug." % rowcount)
curs.execute(
self.db.execute(
"update revlog set ivl=round(ivl),lastIvl=round(lastIvl) where ivl!=round(ivl) or lastIvl!=round(lastIvl)"
)
if curs.rowcount:
rowcount = self.db.scalar("select changes()")
if rowcount:
problems.append(
"Fixed %d review history entries with v2 scheduler bug." % curs.rowcount
"Fixed %d review history entries with v2 scheduler bug." % rowcount
)
# models
if self.models.ensureNotEmpty():
@ -1011,11 +1020,10 @@ and type=0""",
return len(to_fix)
def optimize(self) -> None:
self.db.setAutocommit(True)
self.save(trx=False)
self.db.execute("vacuum")
self.db.execute("analyze")
self.db.setAutocommit(False)
self.lock()
self.db.begin()
# Logging
##########################################################################

View file

@ -1,6 +1,14 @@
# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
"""
A convenience wrapper over pysqlite.
Anki's Collection class now uses dbproxy.py instead of this class,
but this class is still used by aqt's profile manager, and a number
of add-ons rely on it.
"""
import os
import time
from sqlite3 import Cursor

92
pylib/anki/dbproxy.py Normal file
View file

@ -0,0 +1,92 @@
# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
from __future__ import annotations
from typing import Any, Iterable, List, Optional, Sequence, Union
import anki
# DBValue is actually Union[str, int, float, None], but if defined
# that way, every call site needs to do a type check prior to using
# the return values.
ValueFromDB = Any
Row = Sequence[ValueFromDB]
ValueForDB = Union[str, int, float, None]
class DBProxy:
# Lifecycle
###############
def __init__(self, backend: anki.rsbackend.RustBackend, path: str) -> None:
self._backend = backend
self._path = path
self.mod = False
# Transactions
###############
def begin(self) -> None:
self._backend.db_begin()
def commit(self) -> None:
self._backend.db_commit()
def rollback(self) -> None:
self._backend.db_rollback()
# Querying
################
def _query(
self, sql: str, *args: ValueForDB, first_row_only: bool = False
) -> List[Row]:
# mark modified?
s = sql.strip().lower()
for stmt in "insert", "update", "delete":
if s.startswith(stmt):
self.mod = True
assert ":" not in sql
# fetch rows
return self._backend.db_query(sql, args, first_row_only)
# Query shortcuts
###################
def all(self, sql: str, *args: ValueForDB) -> List[Row]:
return self._query(sql, *args)
def list(self, sql: str, *args: ValueForDB) -> List[ValueFromDB]:
return [x[0] for x in self._query(sql, *args)]
def first(self, sql: str, *args: ValueForDB) -> Optional[Row]:
rows = self._query(sql, *args, first_row_only=True)
if rows:
return rows[0]
else:
return None
def scalar(self, sql: str, *args: ValueForDB) -> ValueFromDB:
rows = self._query(sql, *args, first_row_only=True)
if rows:
return rows[0][0]
else:
return None
# execute used to return a pysqlite cursor, but now is synonymous
# with .all()
execute = all
# Updates
################
def executemany(self, sql: str, args: Iterable[Sequence[ValueForDB]]) -> None:
self.mod = True
assert ":" not in sql
if isinstance(args, list):
list_args = args
else:
list_args = list(args)
self._backend.db_execute_many(sql, list_args)

View file

@ -397,20 +397,20 @@ class AnkiCollectionPackageExporter(AnkiPackageExporter):
AnkiPackageExporter.__init__(self, col)
def doExport(self, z, path):
# close our deck & write it into the zip file, and reopen
"Export collection. Caller must re-open afterwards."
# close our deck & write it into the zip file
self.count = self.col.cardCount()
v2 = self.col.schedVer() != 1
mdir = self.col.media.dir()
self.col.close()
if not v2:
z.write(self.col.path, "collection.anki2")
else:
self._addDummyCollection(z)
z.write(self.col.path, "collection.anki21")
self.col.reopen()
# copy all media
if not self.includeMedia:
return {}
mdir = self.col.media.dir()
return self._exportMedia(z, os.listdir(mdir), mdir)

View file

@ -4,500 +4,25 @@
from __future__ import annotations
import re
import sre_constants
import unicodedata
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union, cast
from typing import TYPE_CHECKING, Optional, Set
from anki import hooks
from anki.consts import *
from anki.hooks import *
from anki.utils import (
fieldChecksum,
ids2str,
intTime,
joinFields,
splitFields,
stripHTMLMedia,
)
from anki.utils import ids2str, intTime, joinFields, splitFields, stripHTMLMedia
if TYPE_CHECKING:
from anki.collection import _Collection
# Find
##########################################################################
class Finder:
def __init__(self, col: Optional[_Collection]) -> None:
self.col = col.weakref()
self.search = dict(
added=self._findAdded,
card=self._findTemplate,
deck=self._findDeck,
mid=self._findMid,
nid=self._findNids,
cid=self._findCids,
note=self._findModel,
prop=self._findProp,
rated=self._findRated,
tag=self._findTag,
dupe=self._findDupes,
flag=self._findFlag,
)
self.search["is"] = self._findCardState
hooks.search_terms_prepared(self.search)
print("Finder() is deprecated, please use col.find_cards() or .find_notes()")
def findCards(self, query: str, order: Union[bool, str] = False) -> List[Any]:
"Return a list of card ids for QUERY."
tokens = self._tokenize(query)
preds, args = self._where(tokens)
if preds is None:
raise Exception("invalidSearch")
order, rev = self._order(order)
sql = self._query(preds, order)
try:
res = self.col.db.list(sql, *args)
except:
# invalid grouping
return []
if rev:
res.reverse()
return res
def findCards(self, query, order):
return self.col.find_cards(query, order)
def findNotes(self, query: str) -> List[Any]:
tokens = self._tokenize(query)
preds, args = self._where(tokens)
if preds is None:
return []
if preds:
preds = "(" + preds + ")"
else:
preds = "1"
sql = (
"""
select distinct(n.id) from cards c, notes n where c.nid=n.id and """
+ preds
)
try:
res = self.col.db.list(sql, *args)
except:
# invalid grouping
return []
return res
# Tokenizing
######################################################################
def _tokenize(self, query: str) -> List[str]:
inQuote: Union[bool, str] = False
tokens = []
token = ""
for c in query:
# quoted text
if c in ("'", '"'):
if inQuote:
if c == inQuote:
inQuote = False
else:
token += c
elif token:
# quotes are allowed to start directly after a :
if token[-1] == ":":
inQuote = c
else:
token += c
else:
inQuote = c
# separator (space and ideographic space)
elif c in (" ", "\u3000"):
if inQuote:
token += c
elif token:
# space marks token finished
tokens.append(token)
token = ""
# nesting
elif c in ("(", ")"):
if inQuote:
token += c
else:
if c == ")" and token:
tokens.append(token)
token = ""
tokens.append(c)
# negation
elif c == "-":
if token:
token += c
elif not tokens or tokens[-1] != "-":
tokens.append("-")
# normal character
else:
token += c
# if we finished in a token, add it
if token:
tokens.append(token)
return tokens
# Query building
######################################################################
def _where(self, tokens: List[str]) -> Tuple[str, Optional[List[str]]]:
# state and query
s: Dict[str, Any] = dict(isnot=False, isor=False, join=False, q="", bad=False)
args: List[Any] = []
def add(txt, wrap=True):
# failed command?
if not txt:
# if it was to be negated then we can just ignore it
if s["isnot"]:
s["isnot"] = False
return None, None
else:
s["bad"] = True
return None, None
elif txt == "skip":
return None, None
# do we need a conjunction?
if s["join"]:
if s["isor"]:
s["q"] += " or "
s["isor"] = False
else:
s["q"] += " and "
if s["isnot"]:
s["q"] += " not "
s["isnot"] = False
if wrap:
txt = "(" + txt + ")"
s["q"] += txt
s["join"] = True
for token in tokens:
if s["bad"]:
return None, None
# special tokens
if token == "-":
s["isnot"] = True
elif token.lower() == "or":
s["isor"] = True
elif token == "(":
add(token, wrap=False)
s["join"] = False
elif token == ")":
s["q"] += ")"
# commands
elif ":" in token:
cmd, val = token.split(":", 1)
cmd = cmd.lower()
if cmd in self.search:
add(self.search[cmd]((val, args)))
else:
add(self._findField(cmd, val))
# normal text search
else:
add(self._findText(token, args))
if s["bad"]:
return None, None
return s["q"], args
def _query(self, preds: str, order: str) -> str:
# can we skip the note table?
if "n." not in preds and "n." not in order:
sql = "select c.id from cards c where "
else:
sql = "select c.id from cards c, notes n where c.nid=n.id and "
# combine with preds
if preds:
sql += "(" + preds + ")"
else:
sql += "1"
# order
if order:
sql += " " + order
return sql
# Ordering
######################################################################
def _order(self, order: Union[bool, str]) -> Tuple[str, bool]:
if not order:
return "", False
elif order is not True:
# custom order string provided
return " order by " + cast(str, order), False
# use deck default
type = self.col.conf["sortType"]
sort = None
if type.startswith("note"):
if type == "noteCrt":
sort = "n.id, c.ord"
elif type == "noteMod":
sort = "n.mod, c.ord"
elif type == "noteFld":
sort = "n.sfld collate nocase, c.ord"
elif type.startswith("card"):
if type == "cardMod":
sort = "c.mod"
elif type == "cardReps":
sort = "c.reps"
elif type == "cardDue":
sort = "c.type, c.due"
elif type == "cardEase":
sort = f"c.type == {CARD_TYPE_NEW}, c.factor"
elif type == "cardLapses":
sort = "c.lapses"
elif type == "cardIvl":
sort = "c.ivl"
if not sort:
# deck has invalid sort order; revert to noteCrt
sort = "n.id, c.ord"
return " order by " + sort, self.col.conf["sortBackwards"]
# Commands
######################################################################
def _findTag(self, args: Tuple[str, List[Any]]) -> str:
(val, list_args) = args
if val == "none":
return 'n.tags = ""'
val = val.replace("*", "%")
if not val.startswith("%"):
val = "% " + val
if not val.endswith("%") or val.endswith("\\%"):
val += " %"
list_args.append(val)
return "n.tags like ? escape '\\'"
def _findCardState(self, args: Tuple[str, List[Any]]) -> Optional[str]:
(val, __) = args
if val in ("review", "new", "learn"):
if val == "review":
n = 2
elif val == "new":
n = CARD_TYPE_NEW
else:
return f"queue in ({QUEUE_TYPE_LRN}, {QUEUE_TYPE_DAY_LEARN_RELEARN})"
return "type = %d" % n
elif val == "suspended":
return "c.queue = -1"
elif val == "buried":
return f"c.queue in ({QUEUE_TYPE_SIBLING_BURIED}, {QUEUE_TYPE_MANUALLY_BURIED})"
elif val == "due":
return f"""
(c.queue in ({QUEUE_TYPE_REV},{QUEUE_TYPE_DAY_LEARN_RELEARN}) and c.due <= %d) or
(c.queue = {QUEUE_TYPE_LRN} and c.due <= %d)""" % (
self.col.sched.today,
self.col.sched.dayCutoff,
)
else:
# unknown
return None
def _findFlag(self, args: Tuple[str, List[Any]]) -> Optional[str]:
(val, __) = args
if not val or len(val) != 1 or val not in "01234":
return None
mask = 2 ** 3 - 1
return "(c.flags & %d) == %d" % (mask, int(val))
def _findRated(self, args: Tuple[str, List[Any]]) -> Optional[str]:
# days(:optional_ease)
(val, __) = args
r = val.split(":")
try:
days = int(r[0])
except ValueError:
return None
days = min(days, 31)
# ease
ease = ""
if len(r) > 1:
if r[1] not in ("1", "2", "3", "4"):
return None
ease = "and ease=%s" % r[1]
cutoff = (self.col.sched.dayCutoff - 86400 * days) * 1000
return "c.id in (select cid from revlog where id>%d %s)" % (cutoff, ease)
def _findAdded(self, args: Tuple[str, List[Any]]) -> Optional[str]:
(val, __) = args
try:
days = int(val)
except ValueError:
return None
cutoff = (self.col.sched.dayCutoff - 86400 * days) * 1000
return "c.id > %d" % cutoff
def _findProp(self, args: Tuple[str, List[Any]]) -> Optional[str]:
# extract
(strval, __) = args
m = re.match("(^.+?)(<=|>=|!=|=|<|>)(.+?$)", strval)
if not m:
return None
prop, cmp, strval = m.groups()
prop = prop.lower() # pytype: disable=attribute-error
# is val valid?
try:
if prop == "ease":
val = float(strval)
else:
val = int(strval)
except ValueError:
return None
# is prop valid?
if prop not in ("due", "ivl", "reps", "lapses", "ease"):
return None
# query
q = []
if prop == "due":
val += self.col.sched.today
# only valid for review/daily learning
q.append(f"(c.queue in ({QUEUE_TYPE_REV},{QUEUE_TYPE_DAY_LEARN_RELEARN}))")
elif prop == "ease":
prop = "factor"
val = int(val * 1000)
q.append("(%s %s %s)" % (prop, cmp, val))
return " and ".join(q)
def _findText(self, val: str, args: List[str]) -> str:
val = val.replace("*", "%")
args.append("%" + val + "%")
args.append("%" + val + "%")
return "(n.sfld like ? escape '\\' or n.flds like ? escape '\\')"
def _findNids(self, args: Tuple[str, List[Any]]) -> Optional[str]:
(val, __) = args
if re.search("[^0-9,]", val):
return None
return "n.id in (%s)" % val
def _findCids(self, args) -> Optional[str]:
(val, __) = args
if re.search("[^0-9,]", val):
return None
return "c.id in (%s)" % val
def _findMid(self, args) -> Optional[str]:
(val, __) = args
if re.search("[^0-9]", val):
return None
return "n.mid = %s" % val
def _findModel(self, args: Tuple[str, List[Any]]) -> str:
(val, __) = args
ids = []
val = val.lower()
for m in self.col.models.all():
if unicodedata.normalize("NFC", m["name"].lower()) == val:
ids.append(m["id"])
return "n.mid in %s" % ids2str(ids)
def _findDeck(self, args: Tuple[str, List[Any]]) -> Optional[str]:
# if searching for all decks, skip
(val, __) = args
if val == "*":
return "skip"
# deck types
elif val == "filtered":
return "c.odid"
def dids(did):
if not did:
return None
return [did] + [a[1] for a in self.col.decks.children(did)]
# current deck?
ids = None
if val.lower() == "current":
ids = dids(self.col.decks.current()["id"])
elif "*" not in val:
# single deck
ids = dids(self.col.decks.id(val, create=False))
else:
# wildcard
ids = set()
val = re.escape(val).replace(r"\*", ".*")
for d in self.col.decks.all():
if re.match("(?i)" + val, unicodedata.normalize("NFC", d["name"])):
ids.update(dids(d["id"]))
if not ids:
return None
sids = ids2str(ids)
return "c.did in %s or c.odid in %s" % (sids, sids)
def _findTemplate(self, args: Tuple[str, List[Any]]) -> str:
# were we given an ordinal number?
(val, __) = args
try:
num = int(val) - 1
except:
num = None
if num is not None:
return "c.ord = %d" % num
# search for template names
lims = []
for m in self.col.models.all():
for t in m["tmpls"]:
if unicodedata.normalize("NFC", t["name"].lower()) == val.lower():
if m["type"] == MODEL_CLOZE:
# if the user has asked for a cloze card, we want
# to give all ordinals, so we just limit to the
# model instead
lims.append("(n.mid = %s)" % m["id"])
else:
lims.append("(n.mid = %s and c.ord = %s)" % (m["id"], t["ord"]))
return " or ".join(lims)
def _findField(self, field: str, val: str) -> Optional[str]:
field = field.lower()
val = val.replace("*", "%")
# find models that have that field
mods = {}
for m in self.col.models.all():
for f in m["flds"]:
if unicodedata.normalize("NFC", f["name"].lower()) == field:
mods[str(m["id"])] = (m, f["ord"])
if not mods:
# nothing has that field
return None
# gather nids
regex = re.escape(val).replace("_", ".").replace(re.escape("%"), ".*")
nids = []
for (id, mid, flds) in self.col.db.execute(
"""
select id, mid, flds from notes
where mid in %s and flds like ? escape '\\'"""
% (ids2str(list(mods.keys()))),
"%" + val + "%",
):
flds = splitFields(flds)
ord = mods[str(mid)][1]
strg = flds[ord]
try:
if re.search("(?si)^" + regex + "$", strg):
nids.append(id)
except sre_constants.error:
return None
if not nids:
return "0"
return "n.id in %s" % ids2str(nids)
def _findDupes(self, args) -> Optional[str]:
# caller must call stripHTMLMedia on passed val
(val, __) = args
try:
mid, val = val.split(",", 1)
except OSError:
return None
csum = fieldChecksum(val)
nids = []
for nid, flds in self.col.db.execute(
"select id, flds from notes where mid=? and csum=?", mid, csum
):
if stripHTMLMedia(splitFields(flds)[0]) == val:
nids.append(nid)
return "n.id in %s" % ids2str(nids)
def findNotes(self, query):
return self.col.find_notes(query)
# Find and replace
@ -555,11 +80,11 @@ def findReplace(
flds = joinFields(sflds)
if flds != origFlds:
nids.append(nid)
d.append(dict(nid=nid, flds=flds, u=col.usn(), m=intTime()))
d.append((flds, intTime(), col.usn(), nid))
if not d:
return 0
# replace
col.db.executemany("update notes set flds=:flds,mod=:m,usn=:u where id=:nid", d)
col.db.executemany("update notes set flds=?,mod=?,usn=? where id=?", d)
col.updateFieldCache(nids)
col.genCards(nids)
return len(d)
@ -595,7 +120,7 @@ def findDupes(
# limit search to notes with applicable field name
if search:
search = "(" + search + ") "
search += "'%s:*'" % fieldName
search += '"%s:*"' % fieldName.replace('"', '"')
# go through notes
vals: Dict[str, List[int]] = {}
dupes = []

View file

@ -492,32 +492,6 @@ class _SchemaWillChangeFilter:
schema_will_change = _SchemaWillChangeFilter()
class _SearchTermsPreparedHook:
_hooks: List[Callable[[Dict[str, Callable]], None]] = []
def append(self, cb: Callable[[Dict[str, Callable]], None]) -> None:
"""(searches: Dict[str, Callable])"""
self._hooks.append(cb)
def remove(self, cb: Callable[[Dict[str, Callable]], None]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, searches: Dict[str, Callable]) -> None:
for hook in self._hooks:
try:
hook(searches)
except:
# if the hook fails, remove it
self._hooks.remove(hook)
raise
# legacy support
runHook("search", searches)
search_terms_prepared = _SearchTermsPreparedHook()
class _SyncProgressDidChangeHook:
_hooks: List[Callable[[str], None]] = []

View file

@ -65,10 +65,7 @@ class Anki2Importer(Importer):
self._importCards()
self._importStaticMedia()
self._postImport()
self.dst.db.setAutocommit(True)
self.dst.db.execute("vacuum")
self.dst.db.execute("analyze")
self.dst.db.setAutocommit(False)
self.dst.optimize()
# Notes
######################################################################

View file

@ -287,7 +287,7 @@ content in the text file to the correct fields."""
return [intTime(), self.col.usn(), n.fieldsStr, id, n.fieldsStr]
def addUpdates(self, rows: List[List[Union[int, str]]]) -> None:
old = self.col.db.totalChanges()
changes = self.col.db.scalar("select total_changes()")
if self._tagsMapped:
self.col.db.executemany(
"""
@ -309,7 +309,8 @@ update notes set mod = ?, usn = ?, flds = ?
where id = ? and flds != ?""",
rows,
)
self.updateCount = self.col.db.totalChanges() - old
changes2 = self.col.db.scalar("select total_changes()")
self.updateCount = changes2 - changes
def processFields(
self, note: ForeignNote, fields: Optional[List[str]] = None

View file

@ -145,7 +145,7 @@ current_catalog: Optional[
] = None
# the current Fluent translation instance
current_i18n: Optional[anki.rsbackend.I18nBackend]
current_i18n: Optional[anki.rsbackend.RustBackend]
# path to locale folder
locale_folder = ""
@ -175,9 +175,9 @@ def set_lang(lang: str, locale_dir: str) -> None:
current_catalog = gettext.translation(
"anki", gettext_dir, languages=[lang], fallback=True
)
current_i18n = anki.rsbackend.I18nBackend(
preferred_langs=[lang], ftl_folder=ftl_dir
)
current_i18n = anki.rsbackend.RustBackend(ftl_folder=ftl_dir, langs=[lang])
locale_folder = locale_dir

View file

@ -171,8 +171,11 @@ class MediaManager:
##########################################################################
def check(self) -> MediaCheckOutput:
"This should be called while the collection is closed."
return self.col.backend.check_media()
output = self.col.backend.check_media()
# files may have been renamed on disk, so an undo at this point could
# break file references
self.col.save()
return output
def render_all_latex(
self, progress_cb: Optional[Callable[[int], bool]] = None

View file

@ -504,17 +504,9 @@ select id from notes where mid = ?)"""
for c in range(nfields):
flds.append(newflds.get(c, ""))
flds = joinFields(flds)
d.append(
dict(
nid=nid,
flds=flds,
mid=newModel["id"],
m=intTime(),
u=self.col.usn(),
)
)
d.append((flds, newModel["id"], intTime(), self.col.usn(), nid,))
self.col.db.executemany(
"update notes set flds=:flds,mid=:mid,mod=:m,usn=:u where id = :nid", d
"update notes set flds=?,mid=?,mod=?,usn=? where id = ?", d
)
self.col.updateFieldCache(nids)
@ -543,12 +535,10 @@ select id from notes where mid = ?)"""
# mapping from a regular note, so the map should be valid
new = map[ord]
if new is not None:
d.append(dict(cid=cid, new=new, u=self.col.usn(), m=intTime()))
d.append((new, self.col.usn(), intTime(), cid))
else:
deleted.append(cid)
self.col.db.executemany(
"update cards set ord=:new,usn=:u,mod=:m where id=:cid", d
)
self.col.db.executemany("update cards set ord=?,usn=?,mod=? where id=?", d)
self.col.remCards(deleted)
# Schema hash

View file

@ -48,19 +48,19 @@ class PythonBackend:
native = self.col.sched.deckDueTree()
return native_deck_tree_to_proto(native)
def find_cards(self, input: pb.FindCardsIn) -> pb.FindCardsOut:
cids = self.col.findCards(input.search)
return pb.FindCardsOut(card_ids=cids)
def browser_rows(self, input: pb.BrowserRowsIn) -> pb.BrowserRowsOut:
sort_fields = []
for cid in input.card_ids:
sort_fields.append(
self.col.db.scalar(
"select sfld from notes n,cards c where n.id=c.nid and c.id=?", cid
)
)
return pb.BrowserRowsOut(sort_fields=sort_fields)
# def find_cards(self, input: pb.FindCardsIn) -> pb.FindCardsOut:
# cids = self.col.findCards(input.search)
# return pb.FindCardsOut(card_ids=cids)
#
# def browser_rows(self, input: pb.BrowserRowsIn) -> pb.BrowserRowsOut:
# sort_fields = []
# for cid in input.card_ids:
# sort_fields.append(
# self.col.db.scalar(
# "select sfld from notes n,cards c where n.id=c.nid and c.id=?", cid
# )
# )
# return pb.BrowserRowsOut(sort_fields=sort_fields)
def native_deck_tree_to_proto(native):

View file

@ -5,21 +5,50 @@
import enum
import os
from dataclasses import dataclass
from typing import Callable, Dict, List, NewType, NoReturn, Optional, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
NewType,
NoReturn,
Optional,
Sequence,
Tuple,
Union,
)
import ankirspy # pytype: disable=import-error
import anki.backend_pb2 as pb
import anki.buildinfo
from anki import hooks
from anki.dbproxy import Row as DBRow
from anki.dbproxy import ValueForDB
from anki.fluent_pb2 import FluentString as TR
from anki.models import AllTemplateReqs
from anki.sound import AVTag, SoundOrVideoTag, TTSTag
from anki.types import assert_impossible_literal
from anki.utils import intTime
assert ankirspy.buildhash() == anki.buildinfo.buildhash
SchedTimingToday = pb.SchedTimingTodayOut
BuiltinSortKind = pb.BuiltinSortKind
try:
import orjson
except:
# add compat layer for 32 bit builds that can't use orjson
print("reverting to stock json")
import json
class orjson: # type: ignore
def dumps(obj: Any) -> bytes:
return json.dumps(obj).encode("utf8")
loads = json.loads
class Interrupted(Exception):
@ -186,16 +215,19 @@ def _on_progress(progress_bytes: bytes) -> bool:
class RustBackend:
def __init__(
self, col_path: str, media_folder_path: str, media_db_path: str, log_path: str
self,
ftl_folder: Optional[str] = None,
langs: Optional[List[str]] = None,
server: bool = False,
) -> None:
# pick up global defaults if not provided
if ftl_folder is None:
ftl_folder = os.path.join(anki.lang.locale_folder, "fluent")
if langs is None:
langs = [anki.lang.currentLang]
init_msg = pb.BackendInit(
collection_path=col_path,
media_folder_path=media_folder_path,
media_db_path=media_db_path,
locale_folder_path=ftl_folder,
preferred_langs=[anki.lang.currentLang],
log_path=log_path,
locale_folder_path=ftl_folder, preferred_langs=langs, server=server,
)
self._backend = ankirspy.open_backend(init_msg.SerializeToString())
self._backend.set_progress_callback(_on_progress)
@ -213,6 +245,26 @@ class RustBackend:
else:
return output
def open_collection(
self, col_path: str, media_folder_path: str, media_db_path: str, log_path: str
):
self._run_command(
pb.BackendInput(
open_collection=pb.OpenCollectionIn(
collection_path=col_path,
media_folder_path=media_folder_path,
media_db_path=media_db_path,
log_path=log_path,
)
),
release_gil=True,
)
def close_collection(self):
self._run_command(
pb.BackendInput(close_collection=pb.Empty()), release_gil=True
)
def template_requirements(
self, template_fronts: List[str], field_map: Dict[str, int]
) -> AllTemplateReqs:
@ -228,19 +280,33 @@ class RustBackend:
def sched_timing_today(
self,
created_secs: int,
created_mins_west: int,
now_secs: int,
now_mins_west: int,
rollover: int,
created_mins_west: Optional[int],
now_mins_west: Optional[int],
rollover: Optional[int],
) -> SchedTimingToday:
if created_mins_west is not None:
crt_west = pb.OptionalInt32(val=created_mins_west)
else:
crt_west = None
if now_mins_west is not None:
now_west = pb.OptionalInt32(val=now_mins_west)
else:
now_west = None
if rollover is not None:
roll = pb.OptionalInt32(val=rollover)
else:
roll = None
return self._run_command(
pb.BackendInput(
sched_timing_today=pb.SchedTimingTodayIn(
created_secs=created_secs,
created_mins_west=created_mins_west,
now_secs=now_secs,
now_mins_west=now_mins_west,
rollover_hour=rollover,
now_secs=intTime(),
created_mins_west=crt_west,
now_mins_west=now_west,
rollover_hour=roll,
)
)
).sched_timing_today
@ -366,6 +432,54 @@ class RustBackend:
def restore_trash(self):
self._run_command(pb.BackendInput(restore_trash=pb.Empty()))
def db_query(
self, sql: str, args: Sequence[ValueForDB], first_row_only: bool
) -> List[DBRow]:
return self._db_command(
dict(kind="query", sql=sql, args=args, first_row_only=first_row_only)
)
def db_execute_many(self, sql: str, args: List[List[ValueForDB]]) -> List[DBRow]:
return self._db_command(dict(kind="executemany", sql=sql, args=args))
def db_begin(self) -> None:
return self._db_command(dict(kind="begin"))
def db_commit(self) -> None:
return self._db_command(dict(kind="commit"))
def db_rollback(self) -> None:
return self._db_command(dict(kind="rollback"))
def _db_command(self, input: Dict[str, Any]) -> Any:
return orjson.loads(self._backend.db_command(orjson.dumps(input)))
def search_cards(
self, search: str, order: Union[bool, str, int], reverse: bool = False
) -> Sequence[int]:
if isinstance(order, str):
mode = pb.SortOrder(custom=order)
elif order is True:
mode = pb.SortOrder(from_config=pb.Empty())
elif order is False:
mode = pb.SortOrder(none=pb.Empty())
else:
# sadly we can't use the protobuf type in a Union, so we
# have to accept an int and convert it
kind = BuiltinSortKind.Value(BuiltinSortKind.Name(order))
mode = pb.SortOrder(
builtin=pb.BuiltinSearchOrder(kind=kind, reverse=reverse)
)
return self._run_command(
pb.BackendInput(search_cards=pb.SearchCardsIn(search=search, order=mode))
).search_cards.card_ids
def search_notes(self, search: str) -> Sequence[int]:
return self._run_command(
pb.BackendInput(search_notes=pb.SearchNotesIn(search=search))
).search_notes.note_ids
def translate_string_in(
key: TR, **kwargs: Union[str, int, float]
@ -379,19 +493,6 @@ def translate_string_in(
return pb.TranslateStringIn(key=key, args=args)
class I18nBackend:
def __init__(self, preferred_langs: List[str], ftl_folder: str) -> None:
init_msg = pb.I18nBackendInit(
locale_folder_path=ftl_folder, preferred_langs=preferred_langs
)
self._backend = ankirspy.open_i18n(init_msg.SerializeToString())
def translate(self, key: TR, **kwargs: Union[str, int, float]) -> str:
return self._backend.translate(
translate_string_in(key, **kwargs).SerializeToString()
)
# temporarily force logging of media handling
if "RUST_LOG" not in os.environ:
os.environ["RUST_LOG"] = "warn,anki::media=debug"

View file

@ -8,7 +8,7 @@ import random
import time
from heapq import *
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import anki
from anki import hooks
@ -80,7 +80,7 @@ class Scheduler(V2):
self._updateStats(card, "time", card.timeTaken())
card.mod = intTime()
card.usn = self.col.usn()
card.flushSched()
card.flush()
def counts(self, card: Optional[Card] = None) -> Tuple[int, int, int]:
counts = [self.newCount, self.lrnCount, self.revCount]
@ -286,11 +286,13 @@ and due <= ? limit %d"""
self._lrnQueue = self.col.db.all(
f"""
select due, id from cards where
did in %s and queue = {QUEUE_TYPE_LRN} and due < :lim
did in %s and queue = {QUEUE_TYPE_LRN} and due < ?
limit %d"""
% (self._deckLimit(), self.reportLimit),
lim=self.dayCutoff,
self.dayCutoff,
)
for i in range(len(self._lrnQueue)):
self._lrnQueue[i] = (self._lrnQueue[i][0], self._lrnQueue[i][1])
# as it arrives sorted by did first, we need to sort it
self._lrnQueue.sort()
return self._lrnQueue
@ -707,7 +709,7 @@ did = ? and queue = {QUEUE_TYPE_REV} and due <= ? limit ?""",
# Dynamic deck handling
##########################################################################
def rebuildDyn(self, did: Optional[int] = None) -> Optional[List[int]]: # type: ignore[override]
def rebuildDyn(self, did: Optional[int] = None) -> Optional[Sequence[int]]: # type: ignore[override]
"Rebuild a dynamic deck."
did = did or self.col.decks.selected()
deck = self.col.decks.get(did)
@ -721,7 +723,7 @@ did = ? and queue = {QUEUE_TYPE_REV} and due <= ? limit ?""",
self.col.decks.select(did)
return ids
def _fillDyn(self, deck: Dict[str, Any]) -> List[int]: # type: ignore[override]
def _fillDyn(self, deck: Dict[str, Any]) -> Sequence[int]: # type: ignore[override]
search, limit, order = deck["terms"][0]
orderlimit = self._dynOrder(order, limit)
if search.strip():
@ -751,7 +753,7 @@ due = odue, odue = 0, odid = 0, usn = ? where %s"""
self.col.usn(),
)
def _moveToDyn(self, did: int, ids: List[int]) -> None: # type: ignore[override]
def _moveToDyn(self, did: int, ids: Sequence[int]) -> None: # type: ignore[override]
deck = self.col.decks.get(did)
data = []
t = intTime()
@ -867,10 +869,9 @@ did = ?, queue = %s, due = ?, usn = ? where id = ?"""
def _updateCutoff(self) -> None:
oldToday = self.today
# days since col created
self.today = int((time.time() - self.col.crt) // 86400)
# end of day cutoff
self.dayCutoff = self.col.crt + (self.today + 1) * 86400
timing = self._timing_today()
self.today = timing.days_elapsed
self.dayCutoff = timing.next_day_at
if oldToday != self.today:
self.col.log(self.today, self.dayCutoff)
# update all daily counts, but don't save decks to prevent needless

View file

@ -3,7 +3,6 @@
from __future__ import annotations
import datetime
import itertools
import random
import time
@ -11,7 +10,7 @@ from heapq import *
from operator import itemgetter
# from anki.collection import _Collection
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
import anki # pylint: disable=unused-import
from anki import hooks
@ -82,7 +81,7 @@ class Scheduler:
self._updateStats(card, "time", card.timeTaken())
card.mod = intTime()
card.usn = self.col.usn()
card.flushSched()
card.flush()
def _answerCard(self, card: Card, ease: int) -> None:
if self._previewingCard(card):
@ -138,8 +137,8 @@ class Scheduler:
def dueForecast(self, days: int = 7) -> List[Any]:
"Return counts over next DAYS. Includes today."
daysd = dict(
self.col.db.all(
daysd: Dict[int, int] = dict(
self.col.db.all( # type: ignore
f"""
select due, count() from cards
where did in %s and queue = {QUEUE_TYPE_REV}
@ -542,14 +541,16 @@ select count() from cards where did in %s and queue = {QUEUE_TYPE_PREVIEW}
if self._lrnQueue:
return True
cutoff = intTime() + self.col.conf["collapseTime"]
self._lrnQueue = self.col.db.all(
self._lrnQueue = self.col.db.all( # type: ignore
f"""
select due, id from cards where
did in %s and queue in ({QUEUE_TYPE_LRN},{QUEUE_TYPE_PREVIEW}) and due < :lim
did in %s and queue in ({QUEUE_TYPE_LRN},{QUEUE_TYPE_PREVIEW}) and due < ?
limit %d"""
% (self._deckLimit(), self.reportLimit),
lim=cutoff,
cutoff,
)
for i in range(len(self._lrnQueue)):
self._lrnQueue[i] = (self._lrnQueue[i][0], self._lrnQueue[i][1])
# as it arrives sorted by did first, we need to sort it
self._lrnQueue.sort()
return self._lrnQueue
@ -1215,7 +1216,7 @@ due = (case when odue>0 then odue else due end), odue = 0, odid = 0, usn = ? whe
t = "c.due, c.ord"
return t + " limit %d" % l
def _moveToDyn(self, did: int, ids: List[int], start: int = -100000) -> None:
def _moveToDyn(self, did: int, ids: Sequence[int], start: int = -100000) -> None:
deck = self.col.decks.get(did)
data = []
u = self.col.usn()
@ -1353,13 +1354,8 @@ where id = ?
def _updateCutoff(self) -> None:
oldToday = self.today
timing = self._timing_today()
if self._new_timezone_enabled():
self.today = timing.days_elapsed
self.dayCutoff = timing.next_day_at
else:
self.today = self._daysSinceCreation()
self.dayCutoff = self._dayCutoff()
if oldToday != self.today:
self.col.log(self.today, self.dayCutoff)
@ -1385,51 +1381,39 @@ where id = ?
if time.time() > self.dayCutoff:
self.reset()
def _dayCutoff(self) -> int:
rolloverTime = self.col.conf.get("rollover", 4)
if rolloverTime < 0:
rolloverTime = 24 + rolloverTime
date = datetime.datetime.today()
date = date.replace(hour=rolloverTime, minute=0, second=0, microsecond=0)
if date < datetime.datetime.today():
date = date + datetime.timedelta(days=1)
stamp = int(time.mktime(date.timetuple()))
return stamp
def _daysSinceCreation(self) -> int:
startDate = datetime.datetime.fromtimestamp(self.col.crt)
startDate = startDate.replace(
hour=self._rolloverHour(), minute=0, second=0, microsecond=0
)
return int((time.time() - time.mktime(startDate.timetuple())) // 86400)
def _rolloverHour(self) -> int:
return self.col.conf.get("rollover", 4)
# New timezone handling
##########################################################################
def _new_timezone_enabled(self) -> bool:
return self.col.conf.get("creationOffset") is not None
def _timing_today(self) -> SchedTimingToday:
roll: Optional[int] = None
if self.col.schedVer() > 1:
roll = self._rolloverHour()
return self.col.backend.sched_timing_today(
self.col.crt,
self._creation_timezone_offset(),
intTime(),
self._current_timezone_offset(),
self._rolloverHour(),
roll,
)
def _current_timezone_offset(self) -> int:
def _current_timezone_offset(self) -> Optional[int]:
if self.col.server:
mins = self.col.server.minutes_west
if mins is not None:
return mins
# older Anki versions stored the local offset in
# the config
return self.col.conf.get("localOffset", 0)
else:
return self.col.backend.local_minutes_west(intTime())
return None
def _creation_timezone_offset(self) -> int:
return self.col.conf.get("creationOffset", 0)
def _creation_timezone_offset(self) -> Optional[int]:
return self.col.conf.get("creationOffset", None)
# New timezone handling - GUI helpers
##########################################################################
def new_timezone_enabled(self) -> bool:
return self.col.conf.get("creationOffset") is not None
def set_creation_offset(self):
"""Save the UTC west offset at the time of creation into the DB.
@ -1775,21 +1759,12 @@ and (queue={QUEUE_TYPE_NEW} or (queue={QUEUE_TYPE_REV} and due<=?))""",
mod = intTime()
for id in ids:
r = random.randint(imin, imax)
d.append(
dict(
id=id,
due=r + t,
ivl=max(1, r),
mod=mod,
usn=self.col.usn(),
fact=STARTING_FACTOR,
)
)
d.append((max(1, r), r + t, self.col.usn(), mod, STARTING_FACTOR, id,))
self.remFromDyn(ids)
self.col.db.executemany(
f"""
update cards set type={CARD_TYPE_REV},queue={QUEUE_TYPE_REV},ivl=:ivl,due=:due,odue=0,
usn=:usn,mod=:mod,factor=:fact where id=:id""",
update cards set type={CARD_TYPE_REV},queue={QUEUE_TYPE_REV},ivl=?,due=?,odue=0,
usn=?,mod=?,factor=? where id=?""",
d,
)
self.col.log(ids)
@ -1866,10 +1841,8 @@ and due >= ? and queue = {QUEUE_TYPE_NEW}"""
for id, nid in self.col.db.execute(
f"select id, nid from cards where type = {CARD_TYPE_NEW} and id in " + scids
):
d.append(dict(now=now, due=due[nid], usn=self.col.usn(), cid=id))
self.col.db.executemany(
"update cards set due=:due,mod=:now,usn=:usn where id = :cid", d
)
d.append((due[nid], now, self.col.usn(), id))
self.col.db.executemany("update cards set due=?,mod=?,usn=? where id = ?", d)
def randomizeCards(self, did: int) -> None:
cids = self.col.db.list("select id from cards where did = ?", did)

View file

@ -58,7 +58,7 @@ class CardStats:
self.addLine(_("Reviews"), "%d" % c.reps)
self.addLine(_("Lapses"), "%d" % c.lapses)
(cnt, total) = self.col.db.first(
"select count(), sum(time)/1000 from revlog where cid = :id", id=c.id
"select count(), sum(time)/1000 from revlog where cid = ?", c.id
)
if cnt:
self.addLine(_("Average Time"), self.time(total / float(cnt)))
@ -297,12 +297,12 @@ and due = ?"""
) -> Any:
lim = ""
if start is not None:
lim += " and due-:today >= %d" % start
lim += " and due-%d >= %d" % (self.col.sched.today, start)
if end is not None:
lim += " and day < %d" % end
return self.col.db.all(
f"""
select (due-:today)/:chunk as day,
select (due-?)/? as day,
sum(case when ivl < 21 then 1 else 0 end), -- yng
sum(case when ivl >= 21 then 1 else 0 end) -- mtr
from cards
@ -310,8 +310,8 @@ where did in %s and queue in ({QUEUE_TYPE_REV},{QUEUE_TYPE_DAY_LEARN_RELEARN})
%s
group by day order by day"""
% (self._limit(), lim),
today=self.col.sched.today,
chunk=chunk,
self.col.sched.today,
chunk,
)
# Added, reps and time spent
@ -527,14 +527,13 @@ group by day order by day"""
return self.col.db.all(
"""
select
(cast((id/1000.0 - :cut) / 86400.0 as int))/:chunk as day,
(cast((id/1000.0 - ?) / 86400.0 as int))/? as day,
count(id)
from cards %s
group by day order by day"""
% lim,
cut=self.col.sched.dayCutoff,
tf=tf,
chunk=chunk,
self.col.sched.dayCutoff,
chunk,
)
def _done(self, num: Optional[int] = 7, chunk: int = 1) -> Any:
@ -557,24 +556,28 @@ group by day order by day"""
return self.col.db.all(
f"""
select
(cast((id/1000.0 - :cut) / 86400.0 as int))/:chunk as day,
(cast((id/1000.0 - ?) / 86400.0 as int))/? as day,
sum(case when type = {REVLOG_LRN} then 1 else 0 end), -- lrn count
sum(case when type = {REVLOG_REV} and lastIvl < 21 then 1 else 0 end), -- yng count
sum(case when type = {REVLOG_REV} and lastIvl >= 21 then 1 else 0 end), -- mtr count
sum(case when type = {REVLOG_RELRN} then 1 else 0 end), -- lapse count
sum(case when type = {REVLOG_CRAM} then 1 else 0 end), -- cram count
sum(case when type = {REVLOG_LRN} then time/1000.0 else 0 end)/:tf, -- lrn time
sum(case when type = {REVLOG_LRN} then time/1000.0 else 0 end)/?, -- lrn time
-- yng + mtr time
sum(case when type = {REVLOG_REV} and lastIvl < 21 then time/1000.0 else 0 end)/:tf,
sum(case when type = {REVLOG_REV} and lastIvl >= 21 then time/1000.0 else 0 end)/:tf,
sum(case when type = {REVLOG_RELRN} then time/1000.0 else 0 end)/:tf, -- lapse time
sum(case when type = {REVLOG_CRAM} then time/1000.0 else 0 end)/:tf -- cram time
sum(case when type = {REVLOG_REV} and lastIvl < 21 then time/1000.0 else 0 end)/?,
sum(case when type = {REVLOG_REV} and lastIvl >= 21 then time/1000.0 else 0 end)/?,
sum(case when type = {REVLOG_RELRN} then time/1000.0 else 0 end)/?, -- lapse time
sum(case when type = {REVLOG_CRAM} then time/1000.0 else 0 end)/? -- cram time
from revlog %s
group by day order by day"""
% lim,
cut=self.col.sched.dayCutoff,
tf=tf,
chunk=chunk,
self.col.sched.dayCutoff,
chunk,
tf,
tf,
tf,
tf,
tf,
)
def _daysStudied(self) -> Any:
@ -592,11 +595,11 @@ group by day order by day"""
ret = self.col.db.first(
"""
select count(), abs(min(day)) from (select
(cast((id/1000 - :cut) / 86400.0 as int)+1) as day
(cast((id/1000 - ?) / 86400.0 as int)+1) as day
from revlog %s
group by day order by day)"""
% lim,
cut=self.col.sched.dayCutoff,
self.col.sched.dayCutoff,
)
assert ret
return ret
@ -655,12 +658,12 @@ group by day order by day)"""
data = [
self.col.db.all(
f"""
select ivl / :chunk as grp, count() from cards
select ivl / ? as grp, count() from cards
where did in %s and queue = {QUEUE_TYPE_REV} %s
group by grp
order by grp"""
% (self._limit(), lim),
chunk=chunk,
chunk,
)
]
return (
@ -866,14 +869,14 @@ order by thetype, ease"""
return self.col.db.all(
f"""
select
23 - ((cast((:cut - id/1000) / 3600.0 as int)) %% 24) as hour,
23 - ((cast((? - id/1000) / 3600.0 as int)) %% 24) as hour,
sum(case when ease = 1 then 0 else 1 end) /
cast(count() as float) * 100,
count()
from revlog where type in ({REVLOG_LRN},{REVLOG_REV},{REVLOG_RELRN}) %s
group by hour having count() > 30 order by hour"""
% lim,
cut=self.col.sched.dayCutoff - (rolloverHour * 3600),
self.col.sched.dayCutoff - (rolloverHour * 3600),
)
# Cards

View file

@ -4,12 +4,13 @@
import copy
import json
import os
import re
import weakref
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
from anki.collection import _Collection
from anki.consts import *
from anki.db import DB
from anki.dbproxy import DBProxy
from anki.lang import _
from anki.media import media_paths_from_col_path
from anki.rsbackend import RustBackend
@ -20,48 +21,42 @@ from anki.stdmodels import (
addForwardOptionalReverse,
addForwardReverse,
)
from anki.utils import intTime, isWin
from anki.utils import intTime
@dataclass
class ServerData:
minutes_west: Optional[int] = None
def Collection(
path: str, lock: bool = True, server: Optional[ServerData] = None, log: bool = False
path: str,
backend: Optional[RustBackend] = None,
server: Optional[ServerData] = None,
) -> _Collection:
"Open a new or existing collection. Path must be unicode."
assert path.endswith(".anki2")
if backend is None:
backend = RustBackend(server=server is not None)
(media_dir, media_db) = media_paths_from_col_path(path)
log_path = ""
if not server:
log_path = path.replace(".anki2", "2.log")
backend = RustBackend(path, media_dir, media_db, log_path)
path = os.path.abspath(path)
create = not os.path.exists(path)
if create:
base = os.path.basename(path)
for c in ("/", ":", "\\"):
assert c not in base
# connect
db = DB(path)
db.setAutocommit(True)
backend.open_collection(path, media_dir, media_db, log_path)
db = DBProxy(weakref.proxy(backend), path)
# initial setup required?
create = db.scalar("select models = '{}' from col")
if create:
ver = _createDB(db)
else:
ver = _upgradeSchema(db)
db.execute("pragma temp_store = memory")
db.execute("pragma cache_size = 10000")
if not isWin:
db.execute("pragma journal_mode = wal")
db.setAutocommit(False)
initial_db_setup(db)
# add db to col and do any remaining upgrades
col = _Collection(db, backend=backend, server=server, log=log)
if ver < SCHEMA_VERSION:
_upgrade(col, ver)
elif ver > SCHEMA_VERSION:
raise Exception("This file requires a newer version of Anki.")
elif create:
col = _Collection(db, backend=backend, server=server)
if create:
# add in reverse order so basic is default
addClozeModel(col)
addBasicTypingModel(col)
@ -69,267 +64,21 @@ def Collection(
addForwardReverse(col)
addBasicModel(col)
col.save()
if lock:
try:
col.lock()
except:
col.db.close()
raise
else:
db.begin()
return col
def _upgradeSchema(db: DB) -> Any:
ver = db.scalar("select ver from col")
if ver == SCHEMA_VERSION:
return ver
# add odid to cards, edue->odue
######################################################################
if db.scalar("select ver from col") == 1:
db.execute("alter table cards rename to cards2")
_addSchema(db, setColConf=False)
db.execute(
"""
insert into cards select
id, nid, did, ord, mod, usn, type, queue, due, ivl, factor, reps, lapses,
left, edue, 0, flags, data from cards2"""
)
db.execute("drop table cards2")
db.execute("update col set ver = 2")
_updateIndices(db)
# remove did from notes
######################################################################
if db.scalar("select ver from col") == 2:
db.execute("alter table notes rename to notes2")
_addSchema(db, setColConf=False)
db.execute(
"""
insert into notes select
id, guid, mid, mod, usn, tags, flds, sfld, csum, flags, data from notes2"""
)
db.execute("drop table notes2")
db.execute("update col set ver = 3")
_updateIndices(db)
return ver
def _upgrade(col, ver) -> None:
if ver < 3:
# new deck properties
for d in col.decks.all():
d["dyn"] = DECK_STD
d["collapsed"] = False
col.decks.save(d)
if ver < 4:
col.modSchema(check=False)
clozes = []
for m in col.models.all():
if not "{{cloze:" in m["tmpls"][0]["qfmt"]:
m["type"] = MODEL_STD
col.models.save(m)
else:
clozes.append(m)
for m in clozes:
_upgradeClozeModel(col, m)
col.db.execute("update col set ver = 4")
if ver < 5:
col.db.execute("update cards set odue = 0 where queue = 2")
col.db.execute("update col set ver = 5")
if ver < 6:
col.modSchema(check=False)
import anki.models
for m in col.models.all():
m["css"] = anki.models.defaultModel["css"]
for t in m["tmpls"]:
if "css" not in t:
# ankidroid didn't bump version
continue
m["css"] += "\n" + t["css"].replace(
".card ", ".card%d " % (t["ord"] + 1)
)
del t["css"]
col.models.save(m)
col.db.execute("update col set ver = 6")
if ver < 7:
col.modSchema(check=False)
col.db.execute(
"update cards set odue = 0 where (type = 1 or queue = 2) " "and not odid"
)
col.db.execute("update col set ver = 7")
if ver < 8:
col.modSchema(check=False)
col.db.execute("update cards set due = due / 1000 where due > 4294967296")
col.db.execute("update col set ver = 8")
if ver < 9:
# adding an empty file to a zip makes python's zip code think it's a
# folder, so remove any empty files
changed = False
dir = col.media.dir()
if dir:
for f in os.listdir(col.media.dir()):
if os.path.isfile(f) and not os.path.getsize(f):
os.unlink(f)
col.media.db.execute("delete from log where fname = ?", f)
col.media.db.execute("delete from media where fname = ?", f)
changed = True
if changed:
col.media.db.commit()
col.db.execute("update col set ver = 9")
if ver < 10:
col.db.execute(
"""
update cards set left = left + left*1000 where queue = 1"""
)
col.db.execute("update col set ver = 10")
if ver < 11:
col.modSchema(check=False)
for d in col.decks.all():
if d["dyn"]:
order = d["order"]
# failed order was removed
if order >= 5:
order -= 1
d["terms"] = [[d["search"], d["limit"], order]]
del d["search"]
del d["limit"]
del d["order"]
d["resched"] = True
d["return"] = True
else:
if "extendNew" not in d:
d["extendNew"] = 10
d["extendRev"] = 50
col.decks.save(d)
for c in col.decks.allConf():
r = c["rev"]
r["ivlFct"] = r.get("ivlfct", 1)
if "ivlfct" in r:
del r["ivlfct"]
r["maxIvl"] = 36500
col.decks.save(c)
for m in col.models.all():
for t in m["tmpls"]:
t["bqfmt"] = ""
t["bafmt"] = ""
col.models.save(m)
col.db.execute("update col set ver = 11")
def _upgradeClozeModel(col, m) -> None:
m["type"] = MODEL_CLOZE
# convert first template
t = m["tmpls"][0]
for type in "qfmt", "afmt":
t[type] = re.sub("{{cloze:1:(.+?)}}", r"{{cloze:\1}}", t[type])
t["name"] = _("Cloze")
# delete non-cloze cards for the model
rem = []
for t in m["tmpls"][1:]:
if "{{cloze:" not in t["qfmt"]:
rem.append(t)
for r in rem:
col.models.remTemplate(m, r)
del m["tmpls"][1:]
col.models._updateTemplOrds(m)
col.models.save(m)
# Creating a new collection
######################################################################
def _createDB(db: DB) -> int:
db.execute("pragma page_size = 4096")
db.execute("pragma legacy_file_format = 0")
db.execute("vacuum")
_addSchema(db)
_updateIndices(db)
db.execute("analyze")
return SCHEMA_VERSION
def _addSchema(db: DB, setColConf: bool = True) -> None:
db.executescript(
"""
create table if not exists col (
id integer primary key,
crt integer not null,
mod integer not null,
scm integer not null,
ver integer not null,
dty integer not null,
usn integer not null,
ls integer not null,
conf text not null,
models text not null,
decks text not null,
dconf text not null,
tags text not null
);
create table if not exists notes (
id integer primary key, /* 0 */
guid text not null, /* 1 */
mid integer not null, /* 2 */
mod integer not null, /* 3 */
usn integer not null, /* 4 */
tags text not null, /* 5 */
flds text not null, /* 6 */
sfld integer not null, /* 7 */
csum integer not null, /* 8 */
flags integer not null, /* 9 */
data text not null /* 10 */
);
create table if not exists cards (
id integer primary key, /* 0 */
nid integer not null, /* 1 */
did integer not null, /* 2 */
ord integer not null, /* 3 */
mod integer not null, /* 4 */
usn integer not null, /* 5 */
type integer not null, /* 6 */
queue integer not null, /* 7 */
due integer not null, /* 8 */
ivl integer not null, /* 9 */
factor integer not null, /* 10 */
reps integer not null, /* 11 */
lapses integer not null, /* 12 */
left integer not null, /* 13 */
odue integer not null, /* 14 */
odid integer not null, /* 15 */
flags integer not null, /* 16 */
data text not null /* 17 */
);
create table if not exists revlog (
id integer primary key,
cid integer not null,
usn integer not null,
ease integer not null,
ivl integer not null,
lastIvl integer not null,
factor integer not null,
time integer not null,
type integer not null
);
create table if not exists graves (
usn integer not null,
oid integer not null,
type integer not null
);
insert or ignore into col
values(1,0,0,%(s)s,%(v)s,0,0,0,'','{}','','','{}');
"""
% ({"v": SCHEMA_VERSION, "s": intTime(1000)})
)
if setColConf:
def initial_db_setup(db: DBProxy) -> None:
db.begin()
_addColVars(db, *_getColVars(db))
def _getColVars(db: DB) -> Tuple[Any, Any, Dict[str, Any]]:
def _getColVars(db: DBProxy) -> Tuple[Any, Any, Dict[str, Any]]:
import anki.collection
import anki.decks
@ -344,7 +93,7 @@ def _getColVars(db: DB) -> Tuple[Any, Any, Dict[str, Any]]:
def _addColVars(
db: DB, g: Dict[str, Any], gc: Dict[str, Any], c: Dict[str, Any]
db: DBProxy, g: Dict[str, Any], gc: Dict[str, Any], c: Dict[str, Any]
) -> None:
db.execute(
"""
@ -353,23 +102,3 @@ update col set conf = ?, decks = ?, dconf = ?""",
json.dumps({"1": g}),
json.dumps({"1": gc}),
)
def _updateIndices(db: DB) -> None:
"Add indices to the DB."
db.executescript(
"""
-- syncing
create index if not exists ix_notes_usn on notes (usn);
create index if not exists ix_cards_usn on cards (usn);
create index if not exists ix_revlog_usn on revlog (usn);
-- card spacing, etc
create index if not exists ix_cards_nid on cards (nid);
-- scheduling and deck limiting
create index if not exists ix_cards_sched on cards (did, queue, due);
-- revlog by card
create index if not exists ix_revlog_cid on revlog (cid);
-- field uniqueness
create index if not exists ix_notes_csum on notes (csum);
"""
)

View file

@ -8,8 +8,7 @@ import io
import json
import os
import random
import sqlite3
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import anki
from anki.consts import *
@ -32,7 +31,7 @@ class UnexpectedSchemaChange(Exception):
class Syncer:
cursor: Optional[sqlite3.Cursor]
chunkRows: Optional[List[Sequence]]
def __init__(self, col: anki.storage._Collection, server=None) -> None:
self.col = col.weakref()
@ -247,11 +246,11 @@ class Syncer:
def prepareToChunk(self) -> None:
self.tablesLeft = ["revlog", "cards", "notes"]
self.cursor = None
self.chunkRows = None
def cursorForTable(self, table) -> sqlite3.Cursor:
def getChunkRows(self, table) -> List[Sequence]:
lim = self.usnLim()
x = self.col.db.execute
x = self.col.db.all
d = (self.maxUsn, lim)
if table == "revlog":
return x(
@ -280,14 +279,15 @@ from notes where %s"""
lim = 250
while self.tablesLeft and lim:
curTable = self.tablesLeft[0]
if not self.cursor:
self.cursor = self.cursorForTable(curTable)
rows = self.cursor.fetchmany(lim)
if not self.chunkRows:
self.chunkRows = self.getChunkRows(curTable)
rows = self.chunkRows[:lim]
self.chunkRows = self.chunkRows[lim:]
fetched = len(rows)
if fetched != lim:
# table is empty
self.tablesLeft.pop(0)
self.cursor = None
self.chunkRows = None
# mark the objects as having been sent
self.col.db.execute(
"update %s set usn=? where usn=-1" % curTable, self.maxUsn

View file

@ -110,30 +110,25 @@ class TagManager:
else:
l = "tags "
fn = self.remFromStr
lim = " or ".join([l + "like :_%d" % c for c, t in enumerate(newTags)])
lim = " or ".join(l + "like ?" for x in newTags)
res = self.col.db.all(
"select id, tags from notes where id in %s and (%s)" % (ids2str(ids), lim),
**dict(
[
("_%d" % x, "%% %s %%" % y.replace("*", "%"))
for x, y in enumerate(newTags)
]
),
*["%% %s %%" % y.replace("*", "%") for x, y in enumerate(newTags)],
)
# update tags
nids = []
def fix(row):
nids.append(row[0])
return {
"id": row[0],
"t": fn(tags, row[1]),
"n": intTime(),
"u": self.col.usn(),
}
return [
fn(tags, row[1]),
intTime(),
self.col.usn(),
row[0],
]
self.col.db.executemany(
"update notes set tags=:t,mod=:n,usn=:u where id = :id",
"update notes set tags=?,mod=?,usn=? where id = ?",
[fix(row) for row in res],
)

View file

@ -22,7 +22,7 @@ from hashlib import sha1
from html.entities import name2codepoint
from typing import Iterable, Iterator, List, Optional, Union
from anki.db import DB
from anki.dbproxy import DBProxy
_tmpdir: Optional[str]
@ -142,7 +142,7 @@ def ids2str(ids: Iterable[Union[int, str]]) -> str:
return "(%s)" % ",".join(str(i) for i in ids)
def timestampID(db: DB, table: str) -> int:
def timestampID(db: DBProxy, table: str) -> int:
"Return a non-conflicting timestamp for table."
# be careful not to create multiple objects without flushing them, or they
# may share an ID.
@ -152,7 +152,7 @@ def timestampID(db: DB, table: str) -> int:
return t
def maxID(db: DB) -> int:
def maxID(db: DBProxy) -> int:
"Return the first safe ID to use."
now = intTime(1000)
for tbl in "cards", "notes":

View file

@ -21,6 +21,7 @@ setuptools.setup(
"requests",
"decorator",
"protobuf",
'orjson; platform_machine == "x86_64"',
'psutil; sys_platform == "win32"',
'distro; sys_platform != "darwin" and sys_platform != "win32"',
],

View file

@ -1,9 +1,22 @@
import os
import shutil
import tempfile
import time
from anki import Collection as aopen
# Between 2-4AM, shift the time back so test assumptions hold.
lt = time.localtime()
if lt.tm_hour >= 2 and lt.tm_hour < 4:
orig_time = time.time
def adjusted_time():
return orig_time() - 60 * 60 * 2
time.time = adjusted_time
else:
orig_time = None
def assertException(exception, func):
found = False
@ -22,7 +35,7 @@ def getEmptyCol():
os.close(fd)
os.unlink(nam)
col = aopen(nam)
col.db.close()
col.close()
getEmptyCol.master = nam
(fd, nam) = tempfile.mkstemp(suffix=".anki2")
shutil.copy(getEmptyCol.master, nam)
@ -48,3 +61,15 @@ def getUpgradeDeckPath(name="anki12.anki"):
testDir = os.path.dirname(__file__)
def errorsAfterMidnight(func):
lt = time.localtime()
if lt.tm_hour < 4:
print("test disabled around cutoff", func)
else:
func()
def isNearCutoff():
return orig_time is not None

View file

@ -6,6 +6,7 @@ import tempfile
from anki import Collection as aopen
from anki.exporting import *
from anki.importing import Anki2Importer
from tests.shared import errorsAfterMidnight
from tests.shared import getEmptyCol as getEmptyColOrig
@ -97,6 +98,7 @@ def test_export_ankipkg():
e.exportInto(newname)
@errorsAfterMidnight
def test_export_anki_due():
setup1()
deck = getEmptyCol()

View file

@ -2,8 +2,8 @@
import pytest
from anki.consts import *
from anki.find import Finder
from tests.shared import getEmptyCol
from anki.rsbackend import BuiltinSortKind
from tests.shared import getEmptyCol, isNearCutoff
class DummyCollection:
@ -11,32 +11,6 @@ class DummyCollection:
return None
def test_parse():
f = Finder(DummyCollection())
assert f._tokenize("hello world") == ["hello", "world"]
assert f._tokenize("hello world") == ["hello", "world"]
assert f._tokenize("one -two") == ["one", "-", "two"]
assert f._tokenize("one --two") == ["one", "-", "two"]
assert f._tokenize("one - two") == ["one", "-", "two"]
assert f._tokenize("one or -two") == ["one", "or", "-", "two"]
assert f._tokenize("'hello \"world\"'") == ['hello "world"']
assert f._tokenize('"hello world"') == ["hello world"]
assert f._tokenize("one (two or ( three or four))") == [
"one",
"(",
"two",
"or",
"(",
"three",
"or",
"four",
")",
")",
]
assert f._tokenize("embedded'string") == ["embedded'string"]
assert f._tokenize("deck:'two words'") == ["deck:two words"]
def test_findCards():
deck = getEmptyCol()
f = deck.newNote()
@ -68,6 +42,7 @@ def test_findCards():
f["Front"] = "test"
f["Back"] = "foo bar"
deck.addNote(f)
deck.save()
latestCardIds = [c.id for c in f.cards()]
# tag searches
assert len(deck.findCards("tag:*")) == 5
@ -117,9 +92,8 @@ def test_findCards():
assert len(deck.findCards("nid:%d" % f.id)) == 2
assert len(deck.findCards("nid:%d,%d" % (f1id, f2id))) == 2
# templates
with pytest.raises(Exception):
deck.findCards("card:foo")
assert len(deck.findCards("'card:card 1'")) == 4
assert len(deck.findCards("card:foo")) == 0
assert len(deck.findCards('"card:card 1"')) == 4
assert len(deck.findCards("card:reverse")) == 1
assert len(deck.findCards("card:1")) == 4
assert len(deck.findCards("card:2")) == 1
@ -133,16 +107,28 @@ def test_findCards():
assert len(deck.findCards("front:*")) == 5
# ordering
deck.conf["sortType"] = "noteCrt"
deck.flush()
assert deck.findCards("front:*", order=True)[-1] in latestCardIds
assert deck.findCards("", order=True)[-1] in latestCardIds
deck.conf["sortType"] = "noteFld"
deck.flush()
assert deck.findCards("", order=True)[0] == catCard.id
assert deck.findCards("", order=True)[-1] in latestCardIds
deck.conf["sortType"] = "cardMod"
deck.flush()
assert deck.findCards("", order=True)[-1] in latestCardIds
assert deck.findCards("", order=True)[0] == firstCardId
deck.conf["sortBackwards"] = True
deck.flush()
assert deck.findCards("", order=True)[0] in latestCardIds
assert (
deck.find_cards("", order=BuiltinSortKind.CARD_DUE, reverse=False)[0]
== firstCardId
)
assert (
deck.find_cards("", order=BuiltinSortKind.CARD_DUE, reverse=True)[0]
!= firstCardId
)
# model
assert len(deck.findCards("note:basic")) == 5
assert len(deck.findCards("-note:basic")) == 0
@ -153,8 +139,7 @@ def test_findCards():
assert len(deck.findCards("-deck:foo")) == 5
assert len(deck.findCards("deck:def*")) == 5
assert len(deck.findCards("deck:*EFAULT")) == 5
with pytest.raises(Exception):
deck.findCards("deck:*cefault")
assert len(deck.findCards("deck:*cefault")) == 0
# full search
f = deck.newNote()
f["Front"] = "hello<b>world</b>"
@ -177,6 +162,7 @@ def test_findCards():
deck.db.execute(
"update cards set did = ? where id = ?", deck.decks.id("Default::Child"), id
)
deck.save()
assert len(deck.findCards("deck:default")) == 7
assert len(deck.findCards("deck:default::child")) == 1
assert len(deck.findCards("deck:default -deck:default::*")) == 6
@ -195,20 +181,15 @@ def test_findCards():
assert len(deck.findCards("prop:ivl!=10")) > 1
assert len(deck.findCards("prop:due>0")) == 1
# due dates should work
deck.sched.today = 15
assert len(deck.findCards("prop:due=14")) == 0
assert len(deck.findCards("prop:due=15")) == 1
assert len(deck.findCards("prop:due=16")) == 0
# including negatives
deck.sched.today = 32
assert len(deck.findCards("prop:due=-1")) == 0
assert len(deck.findCards("prop:due=-2")) == 1
assert len(deck.findCards("prop:due=29")) == 0
assert len(deck.findCards("prop:due=30")) == 1
# ease factors
assert len(deck.findCards("prop:ease=2.3")) == 0
assert len(deck.findCards("prop:ease=2.2")) == 1
assert len(deck.findCards("prop:ease>2")) == 1
assert len(deck.findCards("-prop:ease>2")) > 1
# recently failed
if not isNearCutoff():
assert len(deck.findCards("rated:1:1")) == 0
assert len(deck.findCards("rated:1:2")) == 0
c = deck.sched.getCard()
@ -222,6 +203,13 @@ def test_findCards():
assert len(deck.findCards("rated:1")) == 2
assert len(deck.findCards("rated:0:2")) == 0
assert len(deck.findCards("rated:2:2")) == 1
# added
assert len(deck.findCards("added:0")) == 0
deck.db.execute("update cards set id = id - 86400*1000 where id = ?", id)
assert len(deck.findCards("added:1")) == deck.cardCount() - 1
assert len(deck.findCards("added:2")) == deck.cardCount()
else:
print("some find tests disabled near cutoff")
# empty field
assert len(deck.findCards("front:")) == 0
f = deck.newNote()
@ -235,17 +223,7 @@ def test_findCards():
assert len(deck.findCards("-(tag:monkey OR tag:sheep)")) == 6
assert len(deck.findCards("tag:monkey or (tag:sheep sheep)")) == 2
assert len(deck.findCards("tag:monkey or (tag:sheep octopus)")) == 1
# invalid grouping shouldn't error
assert len(deck.findCards(")")) == 0
assert len(deck.findCards("(()")) == 0
# added
assert len(deck.findCards("added:0")) == 0
deck.db.execute("update cards set id = id - 86400*1000 where id = ?", id)
assert len(deck.findCards("added:1")) == deck.cardCount() - 1
assert len(deck.findCards("added:2")) == deck.cardCount()
# flag
with pytest.raises(Exception):
deck.findCards("flag:01")
with pytest.raises(Exception):
deck.findCards("flag:12")

View file

@ -73,8 +73,6 @@ def test_deckIntegration():
with open(os.path.join(d.media.dir(), "foo.jpg"), "w") as f:
f.write("test")
# check media
d.close()
ret = d.media.check()
d.reopen()
assert ret.missing == ["fake2.png"]
assert ret.unused == ["foo.jpg"]

View file

@ -16,17 +16,6 @@ def getEmptyCol():
return col
# Between 2-4AM, shift the time back so test assumptions hold.
lt = time.localtime()
if lt.tm_hour >= 2 and lt.tm_hour < 4:
orig_time = time.time
def adjusted_time():
return orig_time() - 60 * 60 * 2
time.time = adjusted_time
def test_clock():
d = getEmptyCol()
if (d.sched.dayCutoff - intTime()) < 10 * 60:

View file

@ -37,11 +37,6 @@ hooks = [
args=["exporters: List[Tuple[str, Any]]"],
legacy_hook="exportersList",
),
Hook(
name="search_terms_prepared",
args=["searches: Dict[str, Callable]"],
legacy_hook="search",
),
Hook(
name="note_type_added",
args=["notetype: Dict[str, Any]"],

View file

@ -17,6 +17,7 @@ import anki.lang
import aqt.buildinfo
from anki import version as _version
from anki.consts import HELP_SITE
from anki.rsbackend import RustBackend
from anki.utils import checksum, isLin, isMac
from aqt.qt import *
from aqt.utils import locale_dir
@ -162,15 +163,15 @@ dialogs = DialogManager()
# Qt requires its translator to be installed before any GUI widgets are
# loaded, and we need the Qt language to match the gettext language or
# translated shortcuts will not work.
#
# The Qt translator needs to be retained to work.
# A reference to the Qt translator needs to be held to prevent it from
# being immediately deallocated.
_qtrans: Optional[QTranslator] = None
def setupLang(
def setupLangAndBackend(
pm: ProfileManager, app: QApplication, force: Optional[str] = None
) -> None:
) -> RustBackend:
global _qtrans
try:
locale.setlocale(locale.LC_ALL, "")
@ -218,6 +219,8 @@ def setupLang(
if _qtrans.load("qtbase_" + qt_lang, qt_dir):
app.installTranslator(_qtrans)
return anki.lang.current_i18n
# App initialisation
##########################################################################
@ -465,8 +468,8 @@ environment points to a valid, writable folder.""",
if opts.profile:
pm.openProfile(opts.profile)
# i18n
setupLang(pm, app, opts.lang)
# i18n & backend
backend = setupLangAndBackend(pm, app, opts.lang)
if isLin and pm.glMode() == "auto":
from aqt.utils import gfxDriverIsBroken
@ -483,7 +486,7 @@ environment points to a valid, writable folder.""",
# load the main window
import aqt.main
mw = aqt.main.AnkiQt(app, pm, opts, args)
mw = aqt.main.AnkiQt(app, pm, backend, opts, args)
if exec:
app.exec()
else:

View file

@ -167,8 +167,12 @@ class AddCards(QDialog):
def addNote(self, note) -> Optional[Note]:
note.model()["did"] = self.deckChooser.selectedId()
ret = note.dupeOrEmpty()
problem = None
if ret == 1:
showWarning(_("The first field is empty."), help="AddItems#AddError")
problem = _("The first field is empty.")
problem = gui_hooks.add_cards_will_add_note(problem, note)
if problem is not None:
showWarning(problem, help="AddItems#AddError")
return None
if "{{cloze:" in note.model()["tmpls"][0]["qfmt"]:
if not self.mw.col.models._availClozeOrds(

View file

@ -13,7 +13,7 @@ import unicodedata
from dataclasses import dataclass
from enum import Enum
from operator import itemgetter
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional, Sequence, Union
import anki
import aqt.forms
@ -69,6 +69,14 @@ class FindDupesDialog:
browser: Browser
@dataclass
class SearchContext:
search: str
order: Union[bool, str] = True
# if set, provided card ids will be used instead of the regular search
card_ids: Optional[Sequence[int]] = None
# Data model
##########################################################################
@ -82,7 +90,7 @@ class DataModel(QAbstractTableModel):
self.activeCols = self.col.conf.get(
"activeCols", ["noteFld", "template", "cardDue", "deck"]
)
self.cards: List[int] = []
self.cards: Sequence[int] = []
self.cardObjs: Dict[int, Card] = {}
def getCard(self, index: QModelIndex) -> Card:
@ -169,22 +177,21 @@ class DataModel(QAbstractTableModel):
# Filtering
######################################################################
def search(self, txt):
def search(self, txt: str) -> None:
self.beginReset()
t = time.time()
# the db progress handler may cause a refresh, so we need to zero out
# old data first
self.cards = []
invalid = False
try:
self.cards = self.col.findCards(txt, order=True)
ctx = SearchContext(search=txt)
gui_hooks.browser_will_search(ctx)
if ctx.card_ids is None:
ctx.card_ids = self.col.find_cards(txt, order=ctx.order)
gui_hooks.browser_did_search(ctx)
self.cards = ctx.card_ids
except Exception as e:
if str(e) == "invalidSearch":
self.cards = []
print("search failed:", e)
invalid = True
else:
raise
# print "fetch cards in %dms" % ((time.time() - t)*1000)
finally:
self.endReset()
if invalid:
@ -213,7 +220,7 @@ class DataModel(QAbstractTableModel):
def _reverse(self):
self.beginReset()
self.cards.reverse()
self.cards = list(reversed(self.cards))
self.endReset()
def saveSelection(self):
@ -275,6 +282,9 @@ class DataModel(QAbstractTableModel):
def columnType(self, column):
return self.activeCols[column]
def time_format(self):
return "%Y-%m-%d"
def columnData(self, index):
row = index.row()
col = index.column()
@ -302,11 +312,11 @@ class DataModel(QAbstractTableModel):
t = "(" + t + ")"
return t
elif type == "noteCrt":
return time.strftime("%Y-%m-%d", time.localtime(c.note().id / 1000))
return time.strftime(self.time_format(), time.localtime(c.note().id / 1000))
elif type == "noteMod":
return time.strftime("%Y-%m-%d", time.localtime(c.note().mod))
return time.strftime(self.time_format(), time.localtime(c.note().mod))
elif type == "cardMod":
return time.strftime("%Y-%m-%d", time.localtime(c.mod))
return time.strftime(self.time_format(), time.localtime(c.mod))
elif type == "cardReps":
return str(c.reps)
elif type == "cardLapses":
@ -363,7 +373,7 @@ class DataModel(QAbstractTableModel):
date = time.time() + ((c.due - self.col.sched.today) * 86400)
else:
return ""
return time.strftime("%Y-%m-%d", time.localtime(date))
return time.strftime(self.time_format(), time.localtime(date))
def isRTL(self, index):
col = index.column()
@ -388,15 +398,12 @@ class StatusDelegate(QItemDelegate):
self.model = model
def paint(self, painter, option, index):
self.browser.mw.progress.blockUpdates = True
try:
c = self.model.getCard(index)
except:
# in the the middle of a reset; return nothing so this row is not
# rendered until we have a chance to reset the model
return
finally:
self.browser.mw.progress.blockUpdates = True
if self.model.isRTL(index):
option.direction = Qt.RightToLeft
@ -833,6 +840,7 @@ class Browser(QMainWindow):
self.form.tableView.selectionModel()
self.form.tableView.setItemDelegate(StatusDelegate(self, self.model))
self.form.tableView.selectionModel().selectionChanged.connect(self.onRowChanged)
self.form.tableView.setWordWrap(False)
if not theme_manager.night_mode:
self.form.tableView.setStyleSheet(
"QTableView{ selection-background-color: rgba(150, 150, 150, 50); "
@ -915,30 +923,10 @@ QTableView {{ gridline-color: {grid} }}
def _onSortChanged(self, idx, ord):
type = self.model.activeCols[idx]
noSort = ("question", "answer", "template", "deck", "note", "noteTags")
noSort = ("question", "answer")
if type in noSort:
if type == "template":
showInfo(
_(
"""\
This column can't be sorted on, but you can search for individual card types, \
such as 'card:1'."""
)
)
elif type == "deck":
showInfo(
_(
"""\
This column can't be sorted on, but you can search for specific decks \
by clicking on one on the left."""
)
)
else:
showInfo(
_(
"Sorting on this column is not supported. Please "
"choose another."
)
_("Sorting on this column is not supported. Please " "choose another.")
)
type = self.col.conf["sortType"]
if self.col.conf["sortType"] != type:
@ -947,10 +935,14 @@ by clicking on one on the left."""
if type == "noteFld":
ord = not ord
self.col.conf["sortBackwards"] = ord
self.col.setMod()
self.col.save()
self.search()
else:
if self.col.conf["sortBackwards"] != ord:
self.col.conf["sortBackwards"] = ord
self.col.setMod()
self.col.save()
self.model.reverse()
self.setSortIndicator()

View file

@ -6,6 +6,7 @@ from __future__ import annotations
import os
import re
import time
from concurrent.futures import Future
from typing import List, Optional
import aqt
@ -25,7 +26,7 @@ class ExportDialog(QDialog):
):
QDialog.__init__(self, mw, Qt.Window)
self.mw = mw
self.col = mw.col
self.col = mw.col.weakref()
self.frm = aqt.forms.exporting.Ui_ExportDialog()
self.frm.setupUi(self)
self.exporter = None
@ -131,7 +132,7 @@ class ExportDialog(QDialog):
break
self.hide()
if file:
self.mw.progress.start(immediate=True)
# check we can write to file
try:
f = open(file, "wb")
f.close()
@ -139,38 +140,51 @@ class ExportDialog(QDialog):
showWarning(_("Couldn't save file: %s") % str(e))
else:
os.unlink(file)
exportedMedia = lambda cnt: self.mw.progress.update(
# progress handler
def exported_media(cnt):
self.mw.taskman.run_on_main(
lambda: self.mw.progress.update(
label=ngettext(
"Exported %d media file", "Exported %d media files", cnt
)
% cnt
)
hooks.media_files_did_export.append(exportedMedia)
)
def do_export():
self.exporter.exportInto(file)
hooks.media_files_did_export.remove(exportedMedia)
period = 3000
def on_done(future: Future):
self.mw.progress.finish()
hooks.media_files_did_export.remove(exported_media)
# raises if exporter failed
future.result()
self.on_export_finished()
self.mw.progress.start(immediate=True)
hooks.media_files_did_export.append(exported_media)
self.mw.taskman.run_in_background(do_export, on_done)
def on_export_finished(self):
if self.isVerbatim:
msg = _("Collection exported.")
self.mw.reopen()
else:
if self.isTextNote:
msg = (
ngettext(
"%d note exported.",
"%d notes exported.",
self.exporter.count,
"%d note exported.", "%d notes exported.", self.exporter.count,
)
% self.exporter.count
)
else:
msg = (
ngettext(
"%d card exported.",
"%d cards exported.",
self.exporter.count,
"%d card exported.", "%d cards exported.", self.exporter.count,
)
% self.exporter.count
)
tooltip(msg, period=period)
finally:
self.mw.progress.finish()
QDialog.accept(self)
tooltip(msg, period=3000)
QDialog.reject(self)

View file

@ -49,6 +49,43 @@ class _AddCardsDidAddNoteHook:
add_cards_did_add_note = _AddCardsDidAddNoteHook()
class _AddCardsWillAddNoteFilter:
"""Decides whether the note should be added to the collection or
not. It is assumed to come from the addCards window.
reason_to_already_reject is the first reason to reject that
was found, or None. If your filter wants to reject, it should
replace return the reason to reject. Otherwise return the
input."""
_hooks: List[Callable[[Optional[str], "anki.notes.Note"], Optional[str]]] = []
def append(
self, cb: Callable[[Optional[str], "anki.notes.Note"], Optional[str]]
) -> None:
"""(problem: Optional[str], note: anki.notes.Note)"""
self._hooks.append(cb)
def remove(
self, cb: Callable[[Optional[str], "anki.notes.Note"], Optional[str]]
) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, problem: Optional[str], note: anki.notes.Note) -> Optional[str]:
for filter in self._hooks:
try:
problem = filter(problem, note)
except:
# if the hook fails, remove it
self._hooks.remove(filter)
raise
return problem
add_cards_will_add_note = _AddCardsWillAddNoteFilter()
class _AddCardsWillShowHistoryMenuHook:
_hooks: List[Callable[["aqt.addcards.AddCards", QMenu], None]] = []
@ -272,6 +309,30 @@ class _AvPlayerWillPlayHook:
av_player_will_play = _AvPlayerWillPlayHook()
class _BackupDidCompleteHook:
_hooks: List[Callable[[], None]] = []
def append(self, cb: Callable[[], None]) -> None:
"""()"""
self._hooks.append(cb)
def remove(self, cb: Callable[[], None]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self) -> None:
for hook in self._hooks:
try:
hook()
except:
# if the hook fails, remove it
self._hooks.remove(hook)
raise
backup_did_complete = _BackupDidCompleteHook()
class _BrowserDidChangeRowHook:
_hooks: List[Callable[["aqt.browser.Browser"], None]] = []
@ -298,6 +359,32 @@ class _BrowserDidChangeRowHook:
browser_did_change_row = _BrowserDidChangeRowHook()
class _BrowserDidSearchHook:
"""Allows you to modify the list of returned card ids from a search."""
_hooks: List[Callable[["aqt.browser.SearchContext"], None]] = []
def append(self, cb: Callable[["aqt.browser.SearchContext"], None]) -> None:
"""(context: aqt.browser.SearchContext)"""
self._hooks.append(cb)
def remove(self, cb: Callable[["aqt.browser.SearchContext"], None]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, context: aqt.browser.SearchContext) -> None:
for hook in self._hooks:
try:
hook(context)
except:
# if the hook fails, remove it
self._hooks.remove(hook)
raise
browser_did_search = _BrowserDidSearchHook()
class _BrowserMenusDidInitHook:
_hooks: List[Callable[["aqt.browser.Browser"], None]] = []
@ -423,6 +510,42 @@ class _BrowserWillBuildTreeFilter:
browser_will_build_tree = _BrowserWillBuildTreeFilter()
class _BrowserWillSearchHook:
"""Allows you to modify the search text, or perform your own search.
You can modify context.search to change the text that is sent to the
searching backend.
If you set context.card_ids to a list of ids, the regular search will
not be performed, and the provided ids will be used instead.
Your add-on should check if context.card_ids is not None, and return
without making changes if it has been set.
"""
_hooks: List[Callable[["aqt.browser.SearchContext"], None]] = []
def append(self, cb: Callable[["aqt.browser.SearchContext"], None]) -> None:
"""(context: aqt.browser.SearchContext)"""
self._hooks.append(cb)
def remove(self, cb: Callable[["aqt.browser.SearchContext"], None]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, context: aqt.browser.SearchContext) -> None:
for hook in self._hooks:
try:
hook(context)
except:
# if the hook fails, remove it
self._hooks.remove(hook)
raise
browser_will_search = _BrowserWillSearchHook()
class _BrowserWillShowHook:
_hooks: List[Callable[["aqt.browser.Browser"], None]] = []
@ -1206,6 +1329,30 @@ class _MediaSyncDidStartOrStopHook:
media_sync_did_start_or_stop = _MediaSyncDidStartOrStopHook()
class _ModelsAdvancedWillShowHook:
_hooks: List[Callable[[QDialog], None]] = []
def append(self, cb: Callable[[QDialog], None]) -> None:
"""(advanced: QDialog)"""
self._hooks.append(cb)
def remove(self, cb: Callable[[QDialog], None]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, advanced: QDialog) -> None:
for hook in self._hooks:
try:
hook(advanced)
except:
# if the hook fails, remove it
self._hooks.remove(hook)
raise
models_advanced_will_show = _ModelsAdvancedWillShowHook()
class _OverviewDidRefreshHook:
"""Allow to update the overview window. E.g. add the deck name in the
title."""

View file

@ -9,6 +9,7 @@ import shutil
import traceback
import unicodedata
import zipfile
from concurrent.futures import Future
import anki.importing as importing
import aqt.deckchooser
@ -74,6 +75,7 @@ class ChangeMap(QDialog):
self.accept()
# called by importFile() when importing a mappable file like .csv
class ImportDialog(QDialog):
def __init__(self, mw: AnkiQt, importer) -> None:
QDialog.__init__(self, mw, Qt.Window)
@ -192,8 +194,12 @@ you can enter it here. Use \\t to represent tab."""
self.mw.col.decks.select(did)
self.mw.progress.start(immediate=True)
self.mw.checkpoint(_("Import"))
def on_done(future: Future):
self.mw.progress.finish()
try:
self.importer.run()
future.result()
except UnicodeDecodeError:
showUnicodeWarning()
return
@ -208,8 +214,7 @@ you can enter it here. Use \\t to represent tab."""
msg += traceback.format_exc()
showText(msg)
return
finally:
self.mw.progress.finish()
else:
txt = _("Importing complete.") + "\n"
if self.importer.log:
txt += "\n".join(self.importer.log)
@ -217,6 +222,8 @@ you can enter it here. Use \\t to represent tab."""
showText(txt)
self.mw.reset()
self.mw.taskman.run_in_background(self.importer.run, on_done)
def setupMappingFrame(self):
# qt seems to have a bug with adding/removing from a grid, so we add
# to a separate object and add/remove that instead
@ -380,15 +387,19 @@ def importFile(mw, file):
except:
showWarning(invalidZipMsg())
return
# we need to ask whether to import/replace
# we need to ask whether to import/replace; if it's
# a colpkg file then the rest of the import process
# will happen in setupApkgImport()
if not setupApkgImport(mw, importer):
return
# importing non-colpkg files
mw.progress.start(immediate=True)
try:
try:
importer.run()
finally:
def on_done(future: Future):
mw.progress.finish()
try:
future.result()
except zipfile.BadZipfile:
showWarning(invalidZipMsg())
except Exception as e:
@ -396,7 +407,7 @@ def importFile(mw, file):
if "invalidFile" in err:
msg = _(
"""\
Invalid file. Please restore from backup."""
Invalid file. Please restore from backup."""
)
showWarning(msg)
elif "invalidTempFolder" in err:
@ -405,7 +416,7 @@ Invalid file. Please restore from backup."""
showWarning(
_(
"""\
Unable to import from a read-only file."""
Unable to import from a read-only file."""
)
)
else:
@ -418,8 +429,11 @@ Unable to import from a read-only file."""
tooltip(log)
else:
showText(log)
mw.reset()
mw.taskman.run_in_background(importer.run, on_done)
def invalidZipMsg():
return _(
@ -459,10 +473,11 @@ def replaceWithApkg(mw, file, backup):
mw.unloadCollection(lambda: _replaceWithApkg(mw, file, backup))
def _replaceWithApkg(mw, file, backup):
def _replaceWithApkg(mw, filename, backup):
mw.progress.start(immediate=True)
z = zipfile.ZipFile(file)
def do_import():
z = zipfile.ZipFile(filename)
# v2 scheduler?
colname = "collection.anki21"
@ -471,23 +486,18 @@ def _replaceWithApkg(mw, file, backup):
except KeyError:
colname = "collection.anki2"
try:
with z.open(colname) as source, open(mw.pm.collectionPath(), "wb") as target:
shutil.copyfileobj(source, target)
except:
mw.progress.finish()
showWarning(_("The provided file is not a valid .apkg file."))
return
# because users don't have a backup of media, it's safer to import new
# data and rely on them running a media db check to get rid of any
# unwanted media. in the future we might also want to deduplicate this
# step
d = os.path.join(mw.pm.profileFolder(), "collection.media")
for n, (cStr, file) in enumerate(
json.loads(z.read("media").decode("utf8")).items()
):
mw.progress.update(
ngettext("Processed %d media file", "Processed %d media files", n) % n
mw.taskman.run_on_main(
lambda n=n: mw.progress.update(
ngettext("Processed %d media file", "Processed %d media files", n)
% n
)
)
size = z.getinfo(cStr).file_size
dest = os.path.join(d, unicodedata.normalize("NFC", file))
@ -496,11 +506,24 @@ def _replaceWithApkg(mw, file, backup):
continue
data = z.read(cStr)
open(dest, "wb").write(data)
z.close()
# reload
if not mw.loadCollection():
def on_done(future: Future):
mw.progress.finish()
try:
future.result()
except Exception as e:
print(e)
showWarning(_("The provided file is not a valid .apkg file."))
return
if not mw.loadCollection():
return
if backup:
mw.col.modSchema(check=False)
mw.progress.finish()
tooltip(_("Importing complete."))
mw.taskman.run_in_background(do_import, on_done)

View file

@ -12,6 +12,7 @@ import signal
import time
import zipfile
from argparse import Namespace
from concurrent.futures import Future
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
@ -28,6 +29,7 @@ from anki import hooks
from anki.collection import _Collection
from anki.hooks import runHook
from anki.lang import _, ngettext
from anki.rsbackend import RustBackend
from anki.sound import AVTag, SoundOrVideoTag
from anki.storage import Collection
from anki.utils import devMode, ids2str, intTime, isMac, isWin, splitFields
@ -77,10 +79,12 @@ class AnkiQt(QMainWindow):
self,
app: QApplication,
profileManager: ProfileManagerType,
backend: RustBackend,
opts: Namespace,
args: List[Any],
) -> None:
QMainWindow.__init__(self)
self.backend = backend
self.state = "startup"
self.opts = opts
self.col: Optional[_Collection] = None
@ -393,7 +397,7 @@ close the profile or restart Anki."""
# at this point there should be no windows left
self._checkForUnclosedWidgets()
self.maybeAutoSync(True)
self.maybeAutoSync()
def _checkForUnclosedWidgets(self) -> None:
for w in self.app.topLevelWidgets():
@ -458,18 +462,22 @@ close the profile or restart Anki."""
def _loadCollection(self) -> bool:
cpath = self.pm.collectionPath()
self.col = Collection(cpath, log=True)
self.col = Collection(cpath, backend=self.backend)
self.setEnabled(True)
self.progress.setupDB(self.col.db)
self.maybeEnableUndo()
gui_hooks.collection_did_load(self.col)
self.moveToState("deckBrowser")
return True
def reopen(self):
cpath = self.pm.collectionPath()
self.col = Collection(cpath, backend=self.backend)
def unloadCollection(self, onsuccess: Callable) -> None:
def callback():
self.setEnabled(False)
self.media_syncer.show_diag_until_finished()
self._unloadCollection()
onsuccess()
@ -561,6 +569,7 @@ from the profile screen."
fname = backups.pop(0)
path = os.path.join(dir, fname)
os.unlink(path)
gui_hooks.backup_did_complete()
def maybeOptimize(self) -> None:
# have two weeks passed?
@ -594,14 +603,6 @@ from the profile screen."
self.maybe_check_for_addon_updates()
self.deckBrowser.show()
def _colLoadingState(self, oldState) -> None:
"Run once, when col is loaded."
self.enableColMenuItems()
# ensure cwd is set if media dir exists
self.col.media.dir()
gui_hooks.collection_did_load(self.col)
self.moveToState("overview")
def _selectedDeck(self) -> Optional[Dict[str, Any]]:
did = self.col.decks.selected()
if not self.col.decks.nameOrNone(did):
@ -753,10 +754,7 @@ title="%s" %s>%s</button>""" % (
signal.signal(signal.SIGINT, self.onSigInt)
def onSigInt(self, signum, frame):
# interrupt any current transaction and schedule a rollback & quit
if self.col:
self.col.db.interrupt()
# schedule a rollback & quit
def quit():
self.col.db.rollback()
self.close()
@ -841,7 +839,7 @@ title="%s" %s>%s</button>""" % (
self.media_syncer.start()
# expects a current profile, but no collection loaded
def maybeAutoSync(self, closing=False) -> None:
def maybeAutoSync(self) -> None:
if (
not self.pm.profile["syncKey"]
or not self.pm.profile["autoSync"]
@ -853,10 +851,6 @@ title="%s" %s>%s</button>""" % (
# ok to sync
self._sync()
# if media still syncing at this point, pop up progress diag
if closing:
self.media_syncer.show_diag_until_finished()
def maybe_auto_sync_media(self) -> None:
if not self.pm.profile["autoSync"] or self.safeMode or self.restoringBackup:
return
@ -1262,9 +1256,12 @@ will be lost. Continue?"""
def onCheckDB(self):
"True if no problems"
self.progress.start(immediate=True)
ret, ok = self.col.fixIntegrity()
self.progress.start()
def onDone(future: Future):
self.progress.finish()
ret, ok = future.result()
if not ok:
showText(ret)
else:
@ -1280,7 +1277,8 @@ will be lost. Continue?"""
except Exception as e:
print("swallowed exception in reset hook:", e)
continue
return ret
self.taskman.run_in_background(self.col.fixIntegrity, onDone)
def on_check_media_db(self) -> None:
check_media_db(self)
@ -1363,11 +1361,42 @@ will be lost. Continue?"""
sys.stderr = self._oldStderr
sys.stdout = self._oldStdout
def _debugCard(self):
return self.reviewer.card.__dict__
def _card_repr(self, card: anki.cards.Card) -> None:
import pprint, copy
def _debugBrowserCard(self):
return aqt.dialogs._dialogs["Browser"][1].card.__dict__
if not card:
print("no card")
return
print("Front:", card.question())
print("\n")
print("Back:", card.answer())
print("\nNote:")
note = copy.copy(card.note())
for k, v in note.items():
print(f"- {k}:", v)
print("\n")
del note.fields
del note._fmap
del note._model
pprint.pprint(note.__dict__)
print("\nCard:")
c = copy.copy(card)
c._render_output = None
pprint.pprint(c.__dict__)
def _debugCard(self) -> Optional[anki.cards.Card]:
card = self.reviewer.card
self._card_repr(card)
return card
def _debugBrowserCard(self) -> Optional[anki.cards.Card]:
card = aqt.dialogs._dialogs["Browser"][1].card
self._card_repr(card)
return card
def onDebugPrint(self, frm):
cursor = frm.text.textCursor()
@ -1528,7 +1557,6 @@ Please ensure a profile is open and Anki is not busy, then try again."""
gc.disable()
def doGC(self) -> None:
assert not self.progress.inDB
gc.collect()
# Crash log

View file

@ -40,7 +40,6 @@ class MediaChecker:
def check(self) -> None:
self.progress_dialog = self.mw.progress.start()
hooks.bg_thread_progress_callback.append(self._on_progress)
self.mw.col.close()
self.mw.taskman.run_in_background(self._check, self._on_finished)
def _on_progress(self, proceed: bool, progress: Progress) -> bool:
@ -61,7 +60,6 @@ class MediaChecker:
hooks.bg_thread_progress_callback.remove(self._on_progress)
self.mw.progress.finish()
self.progress_dialog = None
self.mw.col.reopen()
exc = future.exception()
if isinstance(exc, Interrupted):

View file

@ -11,7 +11,14 @@ from typing import List, Union
import aqt
from anki import hooks
from anki.consts import SYNC_BASE
from anki.rsbackend import TR, Interrupted, MediaSyncProgress, Progress, ProgressKind
from anki.rsbackend import (
TR,
Interrupted,
MediaSyncProgress,
NetworkError,
Progress,
ProgressKind,
)
from anki.types import assert_impossible
from anki.utils import intTime
from aqt import gui_hooks
@ -100,6 +107,10 @@ class MediaSyncer:
if isinstance(exc, Interrupted):
self._log_and_notify(tr(TR.SYNC_MEDIA_ABORTED))
return
elif isinstance(exc, NetworkError):
# avoid popups for network errors
self._log_and_notify(str(exc))
return
self._log_and_notify(tr(TR.SYNC_MEDIA_FAILED))
showWarning(str(exc))

View file

@ -6,7 +6,7 @@ from operator import itemgetter
import aqt.clayout
from anki import stdmodels
from anki.lang import _, ngettext
from aqt import AnkiQt
from aqt import AnkiQt, gui_hooks
from aqt.qt import *
from aqt.utils import (
askUser,
@ -124,6 +124,7 @@ class Models(QDialog):
d.setWindowTitle(_("Options for %s") % self.model["name"])
frm.buttonBox.helpRequested.connect(lambda: openHelp("latex"))
restoreGeom(d, "modelopts")
gui_hooks.models_advanced_will_show(d)
d.exec_()
saveGeom(d, "modelopts")
self.model["latexsvg"] = frm.latexsvg.isChecked()

View file

@ -62,6 +62,8 @@ class Preferences(QDialog):
lang = anki.lang.currentLang
if lang in anki.lang.compatMap:
lang = anki.lang.compatMap[lang]
else:
lang = lang.replace("-", "_")
try:
return codes.index(lang)
except:
@ -98,7 +100,7 @@ class Preferences(QDialog):
f.new_timezone.setVisible(False)
else:
f.newSched.setChecked(True)
f.new_timezone.setChecked(self.mw.col.sched._new_timezone_enabled())
f.new_timezone.setChecked(self.mw.col.sched.new_timezone_enabled())
def updateCollection(self):
f = self.form
@ -124,7 +126,7 @@ class Preferences(QDialog):
qc["dayLearnFirst"] = f.dayLearnFirst.isChecked()
self._updateDayCutoff()
if self.mw.col.schedVer() != 1:
was_enabled = self.mw.col.sched._new_timezone_enabled()
was_enabled = self.mw.col.sched.new_timezone_enabled()
is_enabled = f.new_timezone.isChecked()
if was_enabled != is_enabled:
if is_enabled:

View file

@ -11,10 +11,6 @@ import aqt.forms
from anki.lang import _
from aqt.qt import *
# fixme: if mw->subwindow opens a progress dialog with mw as the parent, mw
# gets raised on finish on compiz. perhaps we should be using the progress
# dialog as the parent?
# Progress info
##########################################################################
@ -25,47 +21,18 @@ class ProgressManager:
self.app = QApplication.instance()
self.inDB = False
self.blockUpdates = False
self._show_timer: Optional[QTimer] = None
self._win = None
self._levels = 0
# SQLite progress handler
##########################################################################
def setupDB(self, db):
"Install a handler in the current DB."
self.lastDbProgress = 0
self.inDB = False
db.set_progress_handler(self._dbProgress, 10000)
def _dbProgress(self):
"Called from SQLite."
# do nothing if we don't have a progress window
if not self._win:
return
# make sure we're not executing too frequently
if (time.time() - self.lastDbProgress) < 0.01:
return
self.lastDbProgress = time.time()
# and we're in the main thread
if not self.mw.inMainThread():
return
# ensure timers don't fire
self.inDB = True
# handle GUI events
if not self.blockUpdates:
self._maybeShow()
self.app.processEvents(QEventLoop.ExcludeUserInputEvents)
self.inDB = False
# Safer timers
##########################################################################
# QTimer may fire in processEvents(). We provide a custom timer which
# automatically defers until the DB is not busy, and avoids running
# while a progress window is visible.
# A custom timer which avoids firing while a progress dialog is active
# (likely due to some long-running DB operation)
def timer(self, ms, func, repeat, requiresCollection=True):
def handler():
if self.inDB or self._levels:
if self._levels:
# retry in 100ms
self.timer(100, func, False, requiresCollection)
elif not self.mw.col and requiresCollection:
@ -114,10 +81,17 @@ class ProgressManager:
self._firstTime = time.time()
self._lastUpdate = time.time()
self._updating = False
self._show_timer = QTimer(self.mw)
self._show_timer.setSingleShot(True)
self._show_timer.start(600)
self._show_timer.timeout.connect(self._on_show_timer) # type: ignore
return self._win
def update(self, label=None, value=None, process=True, maybeShow=True):
# print self._min, self._counter, self._max, label, time.time() - self._lastTime
if not self.mw.inMainThread():
print("progress.update() called on wrong thread")
return
if self._updating:
return
if maybeShow:
@ -143,6 +117,9 @@ class ProgressManager:
if self._win:
self._closeWin()
self._unsetBusy()
if self._show_timer:
self._show_timer.stop()
self._show_timer = None
def clear(self):
"Restore the interface after an error."
@ -189,6 +166,10 @@ class ProgressManager:
"True if processing."
return self._levels
def _on_show_timer(self):
self._show_timer = None
self._showWin()
class ProgressDialog(QDialog):
def __init__(self, parent):

View file

@ -393,6 +393,7 @@ class SimpleMplayerSlaveModePlayer(SimpleMplayerPlayer):
The trailing newline is automatically added."""
str_args = [str(x) for x in args]
if self._process:
self._process.stdin.write(" ".join(str_args).encode("utf8") + b"\n")
self._process.stdin.flush()

View file

@ -364,7 +364,7 @@ class SyncThread(QThread):
self.syncMsg = ""
self.uname = ""
try:
self.col = Collection(self.path, log=True)
self.col = Collection(self.path)
except:
self.fireEvent("corrupt")
return

View file

@ -4,7 +4,7 @@
import platform
import sys
from typing import Dict
from typing import Dict, Optional
from anki.utils import isMac
from aqt import QApplication, gui_hooks, isWin
@ -17,6 +17,7 @@ class ThemeManager:
_icon_cache_light: Dict[str, QIcon] = {}
_icon_cache_dark: Dict[str, QIcon] = {}
_icon_size = 128
_macos_dark_mode_cached: Optional[bool] = None
def macos_dark_mode(self) -> bool:
if not getattr(sys, "frozen", False):
@ -25,9 +26,13 @@ class ThemeManager:
return False
if qtminor < 13:
return False
if self._macos_dark_mode_cached is None:
import darkdetect # pylint: disable=import-error
return darkdetect.isDark() is True
# cache the value, as the interface gets messed up
# if the value changes after starting Anki
self._macos_dark_mode_cached = darkdetect.isDark() is True
return self._macos_dark_mode_cached
def get_night_mode(self) -> bool:
return self.macos_dark_mode() or self._night_mode_preference

View file

@ -235,6 +235,26 @@ hooks = [
return True
""",
),
Hook(
name="browser_will_search",
args=["context: aqt.browser.SearchContext"],
doc="""Allows you to modify the search text, or perform your own search.
You can modify context.search to change the text that is sent to the
searching backend.
If you set context.card_ids to a list of ids, the regular search will
not be performed, and the provided ids will be used instead.
Your add-on should check if context.card_ids is not None, and return
without making changes if it has been set.
""",
),
Hook(
name="browser_did_search",
args=["context: aqt.browser.SearchContext"],
doc="""Allows you to modify the list of returned card ids from a search.""",
),
# States
###################
Hook(
@ -341,6 +361,7 @@ hooks = [
),
# Main
###################
Hook(name="backup_did_complete"),
Hook(name="profile_did_open", legacy_hook="profileLoaded"),
Hook(name="profile_will_close", legacy_hook="unloadProfile"),
Hook(
@ -412,6 +433,18 @@ def emptyNewCard():
args=["note: anki.notes.Note"],
legacy_hook="AddCards.noteAdded",
),
Hook(
name="add_cards_will_add_note",
args=["problem: Optional[str]", "note: anki.notes.Note"],
return_type="Optional[str]",
doc="""Decides whether the note should be added to the collection or
not. It is assumed to come from the addCards window.
reason_to_already_reject is the first reason to reject that
was found, or None. If your filter wants to reject, it should
replace return the reason to reject. Otherwise return the
input.""",
),
# Editing
###################
Hook(
@ -503,6 +536,9 @@ def emptyNewCard():
args=["dialog: aqt.addons.AddonsDialog", "add_on: aqt.addons.AddonMeta"],
doc="""Allows doing an action when a single add-on is selected.""",
),
# Model
###################
Hook(name="models_advanced_will_show", args=["advanced: QDialog"],),
# Other
###################
Hook(

View file

@ -45,6 +45,8 @@ img {
#typeans {
width: 100%;
// https://anki.tenderapp.com/discussions/beta-testing/1854-using-margin-auto-causes-horizontal-scrollbar-on-typesomething
box-sizing: border-box;
}
.typeGood {

View file

@ -1,6 +1,6 @@
[package]
name = "anki"
version = "2.1.22" # automatically updated
version = "2.1.24" # automatically updated
edition = "2018"
authors = ["Ankitects Pty Ltd and contributors"]
license = "AGPL-3.0-or-later"
@ -36,12 +36,15 @@ slog = { version = "2.5.2", features = ["max_level_trace", "release_max_level_de
slog-term = "2.5.0"
slog-async = "2.4.0"
slog-envlogger = "2.2.0"
serde_repr = "0.1.5"
num_enum = "0.4.2"
unicase = "2.6.0"
[target.'cfg(target_vendor="apple")'.dependencies]
rusqlite = { version = "0.21.0", features = ["trace"] }
rusqlite = { version = "0.21.0", features = ["trace", "functions", "collation"] }
[target.'cfg(not(target_vendor="apple"))'.dependencies]
rusqlite = { version = "0.21.0", features = ["trace", "bundled"] }
rusqlite = { version = "0.21.0", features = ["trace", "functions", "collation", "bundled"] }
[target.'cfg(linux)'.dependencies]
reqwest = { version = "0.10.1", features = ["json", "native-tls-vendored"] }

View file

@ -25,7 +25,7 @@ develop: .build/vernum ftl/repo
ftl/repo:
(cd ftl && ./scripts/fetch-latest-translations)
ALL_SOURCE := $(shell ${FIND} src -type f) $(wildcard ftl/*.ftl)
ALL_SOURCE := $(shell ${FIND} src -type f | egrep -v "i18n/autogen|i18n/ftl|_proto.rs") $(wildcard ftl/*.ftl)
# nightly currently required for ignoring files in rustfmt.toml
RUST_TOOLCHAIN := $(shell cat rust-toolchain)

View file

@ -31,3 +31,4 @@ sync-client-too-old =
sync-wrong-pass = AnkiWeb ID or password was incorrect; please try again.
sync-resync-required =
Please sync again. If this message keeps appearing, please post on the support site.
sync-must-wait-for-end = Anki is currently syncing. Please wait for the sync to complete, then try again.

View file

@ -0,0 +1,155 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::err::Result;
use crate::storage::StorageContext;
use rusqlite::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};
use rusqlite::OptionalExtension;
use serde_derive::{Deserialize, Serialize};
#[derive(Deserialize)]
#[serde(tag = "kind", rename_all = "lowercase")]
pub(super) enum DBRequest {
Query {
sql: String,
args: Vec<SqlValue>,
first_row_only: bool,
},
Begin,
Commit,
Rollback,
ExecuteMany {
sql: String,
args: Vec<Vec<SqlValue>>,
},
}
#[derive(Serialize)]
#[serde(untagged)]
pub(super) enum DBResult {
Rows(Vec<Vec<SqlValue>>),
None,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub(super) enum SqlValue {
Null,
String(String),
Int(i64),
Double(f64),
Blob(Vec<u8>),
}
impl ToSql for SqlValue {
fn to_sql(&self) -> std::result::Result<ToSqlOutput<'_>, rusqlite::Error> {
let val = match self {
SqlValue::Null => ValueRef::Null,
SqlValue::String(v) => ValueRef::Text(v.as_bytes()),
SqlValue::Int(v) => ValueRef::Integer(*v),
SqlValue::Double(v) => ValueRef::Real(*v),
SqlValue::Blob(v) => ValueRef::Blob(&v),
};
Ok(ToSqlOutput::Borrowed(val))
}
}
impl FromSql for SqlValue {
fn column_result(value: ValueRef<'_>) -> std::result::Result<Self, FromSqlError> {
let val = match value {
ValueRef::Null => SqlValue::Null,
ValueRef::Integer(i) => SqlValue::Int(i),
ValueRef::Real(v) => SqlValue::Double(v),
ValueRef::Text(v) => SqlValue::String(String::from_utf8_lossy(v).to_string()),
ValueRef::Blob(v) => SqlValue::Blob(v.to_vec()),
};
Ok(val)
}
}
pub(super) fn db_command_bytes(ctx: &StorageContext, input: &[u8]) -> Result<String> {
let req: DBRequest = serde_json::from_slice(input)?;
let resp = match req {
DBRequest::Query {
sql,
args,
first_row_only,
} => {
if first_row_only {
db_query_row(ctx, &sql, &args)?
} else {
db_query(ctx, &sql, &args)?
}
}
DBRequest::Begin => {
ctx.begin_trx()?;
DBResult::None
}
DBRequest::Commit => {
ctx.commit_trx()?;
DBResult::None
}
DBRequest::Rollback => {
ctx.rollback_trx()?;
DBResult::None
}
DBRequest::ExecuteMany { sql, args } => db_execute_many(ctx, &sql, &args)?,
};
Ok(serde_json::to_string(&resp)?)
}
pub(super) fn db_query_row(ctx: &StorageContext, sql: &str, args: &[SqlValue]) -> Result<DBResult> {
let mut stmt = ctx.db.prepare_cached(sql)?;
let columns = stmt.column_count();
let row = stmt
.query_row(args, |row| {
let mut orow = Vec::with_capacity(columns);
for i in 0..columns {
let v: SqlValue = row.get(i)?;
orow.push(v);
}
Ok(orow)
})
.optional()?;
let rows = if let Some(row) = row {
vec![row]
} else {
vec![]
};
Ok(DBResult::Rows(rows))
}
pub(super) fn db_query(ctx: &StorageContext, sql: &str, args: &[SqlValue]) -> Result<DBResult> {
let mut stmt = ctx.db.prepare_cached(sql)?;
let columns = stmt.column_count();
let res: std::result::Result<Vec<Vec<_>>, rusqlite::Error> = stmt
.query_map(args, |row| {
let mut orow = Vec::with_capacity(columns);
for i in 0..columns {
let v: SqlValue = row.get(i)?;
orow.push(v);
}
Ok(orow)
})?
.collect();
Ok(DBResult::Rows(res?))
}
pub(super) fn db_execute_many(
ctx: &StorageContext,
sql: &str,
args: &[Vec<SqlValue>],
) -> Result<DBResult> {
let mut stmt = ctx.db.prepare_cached(sql)?;
for params in args {
stmt.execute(params)?;
}
Ok(DBResult::None)
}

View file

@ -1,8 +1,11 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::backend::dbproxy::db_command_bytes;
use crate::backend_proto::backend_input::Value;
use crate::backend_proto::{Empty, RenderedTemplateReplacement, SyncMediaIn};
use crate::backend_proto::{BuiltinSortKind, Empty, RenderedTemplateReplacement, SyncMediaIn};
use crate::collection::{open_collection, Collection};
use crate::config::SortKind;
use crate::err::{AnkiError, NetworkErrorKind, Result, SyncErrorKind};
use crate::i18n::{tr_args, FString, I18n};
use crate::latex::{extract_latex, extract_latex_expanding_clozes, ExtractedLatex};
@ -12,6 +15,7 @@ use crate::media::sync::MediaSyncProgress;
use crate::media::MediaManager;
use crate::sched::cutoff::{local_minutes_west_for_stamp, sched_timing_today};
use crate::sched::timespan::{answer_button_time, learning_congrats, studied_today, time_span};
use crate::search::{search_cards, search_notes, SortMode};
use crate::template::{
render_card, without_legacy_template_directives, FieldMap, FieldRequirements, ParsedTemplate,
RenderedNode,
@ -22,18 +26,18 @@ use fluent::FluentValue;
use prost::Message;
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use tokio::runtime::Runtime;
mod dbproxy;
pub type ProtoProgressCallback = Box<dyn Fn(Vec<u8>) -> bool + Send>;
pub struct Backend {
#[allow(dead_code)]
col_path: PathBuf,
media_folder: PathBuf,
media_db: String,
col: Arc<Mutex<Option<Collection>>>,
progress_callback: Option<ProtoProgressCallback>,
i18n: I18n,
log: Logger,
server: bool,
}
enum Progress<'a> {
@ -55,6 +59,8 @@ fn anki_error_to_proto_error(err: AnkiError, i18n: &I18n) -> pb::BackendError {
}
AnkiError::SyncError { kind, .. } => V::SyncError(pb::SyncError { kind: kind.into() }),
AnkiError::Interrupted => V::Interrupted(Empty {}),
AnkiError::CollectionNotOpen => V::InvalidInput(pb::Empty {}),
AnkiError::CollectionAlreadyOpen => V::InvalidInput(pb::Empty {}),
};
pb::BackendError {
@ -103,50 +109,27 @@ pub fn init_backend(init_msg: &[u8]) -> std::result::Result<Backend, String> {
Err(_) => return Err("couldn't decode init request".into()),
};
let mut path = input.collection_path.clone();
path.push_str(".log");
let log_path = match input.log_path.as_str() {
"" => None,
path => Some(path),
};
let logger =
default_logger(log_path).map_err(|e| format!("Unable to open log file: {:?}", e))?;
let i18n = I18n::new(
&input.preferred_langs,
input.locale_folder_path,
log::terminal(),
);
match Backend::new(
&input.collection_path,
&input.media_folder_path,
&input.media_db_path,
i18n,
logger,
) {
Ok(backend) => Ok(backend),
Err(e) => Err(format!("{:?}", e)),
}
Ok(Backend::new(i18n, input.server))
}
impl Backend {
pub fn new(
col_path: &str,
media_folder: &str,
media_db: &str,
i18n: I18n,
log: Logger,
) -> Result<Backend> {
Ok(Backend {
col_path: col_path.into(),
media_folder: media_folder.into(),
media_db: media_db.into(),
pub fn new(i18n: I18n, server: bool) -> Backend {
Backend {
col: Arc::new(Mutex::new(None)),
progress_callback: None,
i18n,
log,
})
server,
}
}
pub fn i18n(&self) -> &I18n {
&self.i18n
}
/// Decode a request, process it, and return the encoded result.
@ -172,6 +155,22 @@ impl Backend {
buf
}
/// If collection is open, run the provided closure while holding
/// the mutex.
/// If collection is not open, return an error.
fn with_col<F, T>(&self, func: F) -> Result<T>
where
F: FnOnce(&mut Collection) -> Result<T>,
{
func(
self.col
.lock()
.unwrap()
.as_mut()
.ok_or(AnkiError::CollectionNotOpen)?,
)
}
fn run_command(&mut self, input: pb::BackendInput) -> pb::BackendOutput {
let oval = if let Some(ival) = input.value {
match self.run_command_inner(ival) {
@ -202,8 +201,6 @@ impl Backend {
OValue::SchedTimingToday(self.sched_timing_today(input))
}
Value::DeckTree(_) => todo!(),
Value::FindCards(_) => todo!(),
Value::BrowserRows(_) => todo!(),
Value::RenderCard(input) => OValue::RenderCard(self.render_template(input)?),
Value::LocalMinutesWest(stamp) => {
OValue::LocalMinutesWest(local_minutes_west_for_stamp(stamp))
@ -241,9 +238,63 @@ impl Backend {
self.restore_trash()?;
OValue::RestoreTrash(Empty {})
}
Value::OpenCollection(input) => {
self.open_collection(input)?;
OValue::OpenCollection(Empty {})
}
Value::CloseCollection(_) => {
self.close_collection()?;
OValue::CloseCollection(Empty {})
}
Value::SearchCards(input) => OValue::SearchCards(self.search_cards(input)?),
Value::SearchNotes(input) => OValue::SearchNotes(self.search_notes(input)?),
})
}
fn open_collection(&self, input: pb::OpenCollectionIn) -> Result<()> {
let mut col = self.col.lock().unwrap();
if col.is_some() {
return Err(AnkiError::CollectionAlreadyOpen);
}
let mut path = input.collection_path.clone();
path.push_str(".log");
let log_path = match input.log_path.as_str() {
"" => None,
path => Some(path),
};
let logger = default_logger(log_path)?;
let new_col = open_collection(
input.collection_path,
input.media_folder_path,
input.media_db_path,
self.server,
self.i18n.clone(),
logger,
)?;
*col = Some(new_col);
Ok(())
}
fn close_collection(&self) -> Result<()> {
let mut col = self.col.lock().unwrap();
if col.is_none() {
return Err(AnkiError::CollectionNotOpen);
}
if !col.as_ref().unwrap().can_close() {
return Err(AnkiError::invalid_input("can't close yet"));
}
*col = None;
Ok(())
}
fn fire_progress_callback(&self, progress: Progress) -> bool {
if let Some(cb) = &self.progress_callback {
let bytes = progress_to_proto_bytes(progress, &self.i18n);
@ -301,10 +352,10 @@ impl Backend {
fn sched_timing_today(&self, input: pb::SchedTimingTodayIn) -> pb::SchedTimingTodayOut {
let today = sched_timing_today(
input.created_secs as i64,
input.created_mins_west,
input.now_secs as i64,
input.now_mins_west,
input.rollover_hour as i8,
input.created_mins_west.map(|v| v.val),
input.now_mins_west.map(|v| v.val),
input.rollover_hour.map(|v| v.val as i8),
);
pb::SchedTimingTodayOut {
days_elapsed: today.days_elapsed,
@ -389,30 +440,60 @@ impl Backend {
}
fn add_media_file(&mut self, input: pb::AddMediaFileIn) -> Result<String> {
let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
self.with_col(|col| {
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
let mut ctx = mgr.dbctx();
Ok(mgr
.add_file(&mut ctx, &input.desired_name, &input.data)?
.into())
})
}
fn sync_media(&self, input: SyncMediaIn) -> Result<()> {
let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
// fixme: will block other db access
fn sync_media(&self, input: SyncMediaIn) -> Result<()> {
let mut guard = self.col.lock().unwrap();
let col = guard.as_mut().unwrap();
col.set_media_sync_running()?;
let folder = col.media_folder.clone();
let db = col.media_db.clone();
let log = col.log.clone();
drop(guard);
let res = self.sync_media_inner(input, folder, db, log);
self.with_col(|col| col.set_media_sync_finished())?;
res
}
fn sync_media_inner(
&self,
input: pb::SyncMediaIn,
folder: PathBuf,
db: PathBuf,
log: Logger,
) -> Result<()> {
let callback = |progress: &MediaSyncProgress| {
self.fire_progress_callback(Progress::MediaSync(progress))
};
let mgr = MediaManager::new(&folder, &db)?;
let mut rt = Runtime::new().unwrap();
rt.block_on(mgr.sync_media(callback, &input.endpoint, &input.hkey, self.log.clone()))
rt.block_on(mgr.sync_media(callback, &input.endpoint, &input.hkey, log))
}
fn check_media(&self) -> Result<pb::MediaCheckOut> {
let callback =
|progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32));
let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
let mut checker = MediaChecker::new(&mgr, &self.col_path, callback, &self.i18n, &self.log);
self.with_col(|col| {
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, callback);
let mut output = checker.check()?;
let report = checker.summarize_output(&mut output);
@ -423,12 +504,16 @@ impl Backend {
report,
have_trash: output.trash_count > 0,
})
})
})
}
fn remove_media_files(&self, fnames: &[String]) -> Result<()> {
let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
self.with_col(|col| {
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
let mut ctx = mgr.dbctx();
mgr.remove_files(&mut ctx, fnames)
})
}
fn translate_string(&self, input: pb::TranslateStringIn) -> String {
@ -466,20 +551,66 @@ impl Backend {
let callback =
|progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32));
let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
let mut checker = MediaChecker::new(&mgr, &self.col_path, callback, &self.i18n, &self.log);
self.with_col(|col| {
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, callback);
checker.empty_trash()
})
})
}
fn restore_trash(&self) -> Result<()> {
let callback =
|progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32));
let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
let mut checker = MediaChecker::new(&mgr, &self.col_path, callback, &self.i18n, &self.log);
self.with_col(|col| {
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, callback);
checker.restore_trash()
})
})
}
pub fn db_command(&self, input: &[u8]) -> Result<String> {
self.with_col(|col| col.with_ctx(|ctx| db_command_bytes(&ctx.storage, input)))
}
fn search_cards(&self, input: pb::SearchCardsIn) -> Result<pb::SearchCardsOut> {
self.with_col(|col| {
col.with_ctx(|ctx| {
let order = if let Some(order) = input.order {
use pb::sort_order::Value as V;
match order.value {
Some(V::None(_)) => SortMode::NoOrder,
Some(V::Custom(s)) => SortMode::Custom(s),
Some(V::FromConfig(_)) => SortMode::FromConfig,
Some(V::Builtin(b)) => SortMode::Builtin {
kind: sort_kind_from_pb(b.kind),
reverse: b.reverse,
},
None => SortMode::FromConfig,
}
} else {
SortMode::FromConfig
};
let cids = search_cards(ctx, &input.search, order)?;
Ok(pb::SearchCardsOut { card_ids: cids })
})
})
}
fn search_notes(&self, input: pb::SearchNotesIn) -> Result<pb::SearchNotesOut> {
self.with_col(|col| {
col.with_ctx(|ctx| {
let nids = search_notes(ctx, &input.search)?;
Ok(pb::SearchNotesOut { note_ids: nids })
})
})
}
}
@ -552,51 +683,24 @@ fn media_sync_progress(p: &MediaSyncProgress, i18n: &I18n) -> pb::MediaSyncProgr
}
}
/// Standalone I18n backend
/// This is a hack to allow translating strings in the GUI
/// when a collection is not open, and in the future it should
/// either be shared with or merged into the backend object.
///////////////////////////////////////////////////////
pub struct I18nBackend {
i18n: I18n,
}
pub fn init_i18n_backend(init_msg: &[u8]) -> Result<I18nBackend> {
let input: pb::I18nBackendInit = match pb::I18nBackendInit::decode(init_msg) {
Ok(req) => req,
Err(_) => return Err(AnkiError::invalid_input("couldn't decode init msg")),
};
let log = log::terminal();
let i18n = I18n::new(&input.preferred_langs, input.locale_folder_path, log);
Ok(I18nBackend { i18n })
}
impl I18nBackend {
pub fn translate(&self, req: &[u8]) -> String {
let req = match pb::TranslateStringIn::decode(req) {
Ok(req) => req,
Err(_e) => return "decoding error".into(),
};
self.translate_string(req)
}
fn translate_string(&self, input: pb::TranslateStringIn) -> String {
let key = match pb::FluentString::from_i32(input.key) {
Some(key) => key,
None => return "invalid key".to_string(),
};
let map = input
.args
.iter()
.map(|(k, v)| (k.as_str(), translate_arg_to_fluent_val(&v)))
.collect();
self.i18n.trn(key, map)
fn sort_kind_from_pb(kind: i32) -> SortKind {
use SortKind as SK;
match pb::BuiltinSortKind::from_i32(kind) {
Some(pbkind) => match pbkind {
BuiltinSortKind::NoteCreation => SK::NoteCreation,
BuiltinSortKind::NoteMod => SK::NoteMod,
BuiltinSortKind::NoteField => SK::NoteField,
BuiltinSortKind::NoteTags => SK::NoteTags,
BuiltinSortKind::NoteType => SK::NoteType,
BuiltinSortKind::CardMod => SK::CardMod,
BuiltinSortKind::CardReps => SK::CardReps,
BuiltinSortKind::CardDue => SK::CardDue,
BuiltinSortKind::CardEase => SK::CardEase,
BuiltinSortKind::CardLapses => SK::CardLapses,
BuiltinSortKind::CardInterval => SK::CardInterval,
BuiltinSortKind::CardDeck => SK::CardDeck,
BuiltinSortKind::CardTemplate => SK::CardTemplate,
},
_ => SortKind::NoteCreation,
}
}

33
rslib/src/card.rs Normal file
View file

@ -0,0 +1,33 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use num_enum::TryFromPrimitive;
use serde_repr::{Deserialize_repr, Serialize_repr};
#[derive(Serialize_repr, Deserialize_repr, Debug, PartialEq, TryFromPrimitive, Clone, Copy)]
#[repr(u8)]
pub enum CardType {
New = 0,
Learn = 1,
Review = 2,
Relearn = 3,
}
#[derive(Serialize_repr, Deserialize_repr, Debug, PartialEq, TryFromPrimitive, Clone, Copy)]
#[repr(i8)]
pub enum CardQueue {
/// due is the order cards are shown in
New = 0,
/// due is a unix timestamp
Learn = 1,
/// due is days since creation date
Review = 2,
DayLearn = 3,
/// due is a unix timestamp.
/// preview cards only placed here when failed.
PreviewRepeat = 4,
/// cards are not due in these states
Suspended = -1,
UserBuried = -2,
SchedBuried = -3,
}

131
rslib/src/collection.rs Normal file
View file

@ -0,0 +1,131 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::err::{AnkiError, Result};
use crate::i18n::I18n;
use crate::log::Logger;
use crate::storage::{SqliteStorage, StorageContext};
use std::path::PathBuf;
pub fn open_collection<P: Into<PathBuf>>(
path: P,
media_folder: P,
media_db: P,
server: bool,
i18n: I18n,
log: Logger,
) -> Result<Collection> {
let col_path = path.into();
let storage = SqliteStorage::open_or_create(&col_path)?;
let col = Collection {
storage,
col_path,
media_folder: media_folder.into(),
media_db: media_db.into(),
server,
i18n,
log,
state: CollectionState::Normal,
};
Ok(col)
}
#[derive(Debug, PartialEq)]
pub enum CollectionState {
Normal,
// in this state, the DB must not be closed
MediaSyncRunning,
}
pub struct Collection {
pub(crate) storage: SqliteStorage,
#[allow(dead_code)]
pub(crate) col_path: PathBuf,
pub(crate) media_folder: PathBuf,
pub(crate) media_db: PathBuf,
pub(crate) server: bool,
pub(crate) i18n: I18n,
pub(crate) log: Logger,
state: CollectionState,
}
pub(crate) enum CollectionOp {}
pub(crate) struct RequestContext<'a> {
pub storage: StorageContext<'a>,
pub i18n: &'a I18n,
pub log: &'a Logger,
pub should_commit: bool,
}
impl Collection {
/// Call the provided closure with a RequestContext that exists for
/// the duration of the call. The request will cache prepared sql
/// statements, so should be passed down the call tree.
///
/// This function should be used for read-only requests. To mutate
/// the database, use transact() instead.
pub(crate) fn with_ctx<F, R>(&self, func: F) -> Result<R>
where
F: FnOnce(&mut RequestContext) -> Result<R>,
{
let mut ctx = RequestContext {
storage: self.storage.context(self.server),
i18n: &self.i18n,
log: &self.log,
should_commit: true,
};
func(&mut ctx)
}
/// Execute the provided closure in a transaction, rolling back if
/// an error is returned.
pub(crate) fn transact<F, R>(&self, op: Option<CollectionOp>, func: F) -> Result<R>
where
F: FnOnce(&mut RequestContext) -> Result<R>,
{
self.with_ctx(|ctx| {
ctx.storage.begin_rust_trx()?;
let mut res = func(ctx);
if res.is_ok() && ctx.should_commit {
if let Err(e) = ctx.storage.mark_modified() {
res = Err(e);
} else if let Err(e) = ctx.storage.commit_rust_op(op) {
res = Err(e);
}
}
if res.is_err() || !ctx.should_commit {
ctx.storage.rollback_rust_trx()?;
}
res
})
}
pub(crate) fn set_media_sync_running(&mut self) -> Result<()> {
if self.state == CollectionState::Normal {
self.state = CollectionState::MediaSyncRunning;
Ok(())
} else {
Err(AnkiError::invalid_input("media sync already running"))
}
}
pub(crate) fn set_media_sync_finished(&mut self) -> Result<()> {
if self.state == CollectionState::MediaSyncRunning {
self.state = CollectionState::Normal;
Ok(())
} else {
Err(AnkiError::invalid_input("media sync not running"))
}
}
pub(crate) fn can_close(&self) -> bool {
self.state == CollectionState::Normal
}
}

63
rslib/src/config.rs Normal file
View file

@ -0,0 +1,63 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::types::ObjID;
use serde::Deserialize as DeTrait;
use serde_aux::field_attributes::deserialize_number_from_string;
use serde_derive::Deserialize;
use serde_json::Value;
pub(crate) fn default_on_invalid<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
T: Default + DeTrait<'de>,
D: serde::de::Deserializer<'de>,
{
let v: Value = DeTrait::deserialize(deserializer)?;
Ok(T::deserialize(v).unwrap_or_default())
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Config {
#[serde(
rename = "curDeck",
deserialize_with = "deserialize_number_from_string"
)]
pub(crate) current_deck_id: ObjID,
pub(crate) rollover: Option<i8>,
pub(crate) creation_offset: Option<i32>,
pub(crate) local_offset: Option<i32>,
#[serde(rename = "sortType", deserialize_with = "default_on_invalid")]
pub(crate) browser_sort_kind: SortKind,
#[serde(rename = "sortBackwards", deserialize_with = "default_on_invalid")]
pub(crate) browser_sort_reverse: bool,
}
#[derive(Deserialize, PartialEq, Debug)]
#[serde(rename_all = "camelCase")]
pub enum SortKind {
#[serde(rename = "noteCrt")]
NoteCreation,
NoteMod,
#[serde(rename = "noteFld")]
NoteField,
#[serde(rename = "note")]
NoteType,
NoteTags,
CardMod,
CardReps,
CardDue,
CardEase,
CardLapses,
#[serde(rename = "cardIvl")]
CardInterval,
#[serde(rename = "deck")]
CardDeck,
#[serde(rename = "template")]
CardTemplate,
}
impl Default for SortKind {
fn default() -> Self {
Self::NoteCreation
}
}

31
rslib/src/decks.rs Normal file
View file

@ -0,0 +1,31 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::types::ObjID;
use serde_aux::field_attributes::deserialize_number_from_string;
use serde_derive::Deserialize;
#[derive(Deserialize)]
pub struct Deck {
#[serde(deserialize_with = "deserialize_number_from_string")]
pub(crate) id: ObjID,
pub(crate) name: String,
}
pub(crate) fn child_ids<'a>(decks: &'a [Deck], name: &str) -> impl Iterator<Item = ObjID> + 'a {
let prefix = format!("{}::", name.to_ascii_lowercase());
decks
.iter()
.filter(move |d| d.name.to_ascii_lowercase().starts_with(&prefix))
.map(|d| d.id)
}
pub(crate) fn get_deck(decks: &[Deck], id: ObjID) -> Option<&Deck> {
for d in decks {
if d.id == id {
return Some(d);
}
}
None
}

View file

@ -20,7 +20,7 @@ pub enum AnkiError {
IOError { info: String },
#[fail(display = "DB error: {}", info)]
DBError { info: String },
DBError { info: String, kind: DBErrorKind },
#[fail(display = "Network error: {:?} {}", kind, info)]
NetworkError {
@ -33,6 +33,12 @@ pub enum AnkiError {
#[fail(display = "The user interrupted the operation.")]
Interrupted,
#[fail(display = "Operation requires an open collection.")]
CollectionNotOpen,
#[fail(display = "Close the existing collection first.")]
CollectionAlreadyOpen,
}
// error helpers
@ -112,6 +118,7 @@ impl From<rusqlite::Error> for AnkiError {
fn from(err: rusqlite::Error) -> Self {
AnkiError::DBError {
info: format!("{:?}", err),
kind: DBErrorKind::Other,
}
}
}
@ -120,6 +127,7 @@ impl From<rusqlite::types::FromSqlError> for AnkiError {
fn from(err: rusqlite::types::FromSqlError) -> Self {
AnkiError::DBError {
info: format!("{:?}", err),
kind: DBErrorKind::Other,
}
}
}
@ -215,3 +223,11 @@ impl From<serde_json::Error> for AnkiError {
AnkiError::sync_misc(err.to_string())
}
}
#[derive(Debug, PartialEq)]
pub enum DBErrorKind {
FileTooNew,
FileTooOld,
MissingEntity,
Other,
}

View file

@ -10,13 +10,21 @@ pub fn version() -> &'static str {
}
pub mod backend;
pub mod card;
pub mod cloze;
pub mod collection;
pub mod config;
pub mod decks;
pub mod err;
pub mod i18n;
pub mod latex;
pub mod log;
pub mod media;
pub mod notes;
pub mod notetypes;
pub mod sched;
pub mod search;
pub mod storage;
pub mod template;
pub mod template_filters;
pub mod text;

View file

@ -1,18 +1,17 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::err::{AnkiError, Result};
use crate::i18n::{tr_args, tr_strs, FString, I18n};
use crate::collection::RequestContext;
use crate::err::{AnkiError, DBErrorKind, Result};
use crate::i18n::{tr_args, tr_strs, FString};
use crate::latex::extract_latex_expanding_clozes;
use crate::log::{debug, Logger};
use crate::media::col::{
for_every_note, get_note_types, mark_collection_modified, open_or_create_collection_db,
set_note, Note,
};
use crate::log::debug;
use crate::media::database::MediaDatabaseContext;
use crate::media::files::{
data_for_file, filename_if_normalized, trash_folder, MEDIA_SYNC_FILESIZE_LIMIT,
data_for_file, filename_if_normalized, normalize_nfc_filename, trash_folder,
MEDIA_SYNC_FILESIZE_LIMIT,
};
use crate::notes::{for_every_note, set_note, Note};
use crate::text::{normalize_to_nfc, MediaRef};
use crate::{media::MediaManager, text::extract_media_refs};
use coarsetime::Instant;
@ -26,7 +25,7 @@ lazy_static! {
static ref REMOTE_FILENAME: Regex = Regex::new("(?i)^https?://").unwrap();
}
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Clone)]
pub struct MediaCheckOutput {
pub unused: Vec<String>,
pub missing: Vec<String>,
@ -45,38 +44,32 @@ struct MediaFolderCheck {
oversize: Vec<String>,
}
pub struct MediaChecker<'a, P>
pub struct MediaChecker<'a, 'b, P>
where
P: FnMut(usize) -> bool,
{
ctx: &'a mut RequestContext<'b>,
mgr: &'a MediaManager,
col_path: &'a Path,
progress_cb: P,
checked: usize,
progress_updated: Instant,
i18n: &'a I18n,
log: &'a Logger,
}
impl<P> MediaChecker<'_, P>
impl<P> MediaChecker<'_, '_, P>
where
P: FnMut(usize) -> bool,
{
pub fn new<'a>(
pub(crate) fn new<'a, 'b>(
ctx: &'a mut RequestContext<'b>,
mgr: &'a MediaManager,
col_path: &'a Path,
progress_cb: P,
i18n: &'a I18n,
log: &'a Logger,
) -> MediaChecker<'a, P> {
) -> MediaChecker<'a, 'b, P> {
MediaChecker {
ctx,
mgr,
col_path,
progress_cb,
checked: 0,
progress_updated: Instant::now(),
i18n,
log,
}
}
@ -100,7 +93,7 @@ where
pub fn summarize_output(&self, output: &mut MediaCheckOutput) -> String {
let mut buf = String::new();
let i = &self.i18n;
let i = &self.ctx.i18n;
// top summary area
if output.trash_count > 0 {
@ -279,7 +272,7 @@ where
}
})?;
let fname = self.mgr.add_file(ctx, disk_fname, &data)?;
debug!(self.log, "renamed"; "from"=>disk_fname, "to"=>&fname.as_ref());
debug!(self.ctx.log, "renamed"; "from"=>disk_fname, "to"=>&fname.as_ref());
assert_ne!(fname.as_ref(), disk_fname);
// remove the original file
@ -373,7 +366,7 @@ where
self.mgr
.add_file(&mut self.mgr.dbctx(), fname.as_ref(), &data)?;
} else {
debug!(self.log, "file disappeared while restoring trash"; "fname"=>fname.as_ref());
debug!(self.ctx.log, "file disappeared while restoring trash"; "fname"=>fname.as_ref());
}
fs::remove_file(dentry.path())?;
}
@ -387,14 +380,11 @@ where
&mut self,
renamed: &HashMap<String, String>,
) -> Result<HashSet<String>> {
let mut db = open_or_create_collection_db(self.col_path)?;
let trx = db.transaction()?;
let mut referenced_files = HashSet::new();
let note_types = get_note_types(&trx)?;
let note_types = self.ctx.storage.all_note_types()?;
let mut collection_modified = false;
for_every_note(&trx, |note| {
for_every_note(&self.ctx.storage.db, |note| {
self.checked += 1;
if self.checked % 10 == 0 {
self.maybe_fire_progress_cb()?;
@ -403,10 +393,16 @@ where
.get(&note.mid)
.ok_or_else(|| AnkiError::DBError {
info: "missing note type".to_string(),
kind: DBErrorKind::MissingEntity,
})?;
if fix_and_extract_media_refs(note, &mut referenced_files, renamed)? {
if fix_and_extract_media_refs(
note,
&mut referenced_files,
renamed,
&self.mgr.media_folder,
)? {
// note was modified, needs saving
set_note(&trx, note, nt)?;
set_note(&self.ctx.storage.db, note, nt)?;
collection_modified = true;
}
@ -415,9 +411,8 @@ where
Ok(())
})?;
if collection_modified {
mark_collection_modified(&trx)?;
trx.commit()?;
if !collection_modified {
self.ctx.should_commit = false;
}
Ok(referenced_files)
@ -429,11 +424,17 @@ fn fix_and_extract_media_refs(
note: &mut Note,
seen_files: &mut HashSet<String>,
renamed: &HashMap<String, String>,
media_folder: &Path,
) -> Result<bool> {
let mut updated = false;
for idx in 0..note.fields().len() {
let field = normalize_and_maybe_rename_files(&note.fields()[idx], renamed, seen_files);
let field = normalize_and_maybe_rename_files(
&note.fields()[idx],
renamed,
seen_files,
media_folder,
);
if let Cow::Owned(field) = field {
// field was modified, need to save
note.set_field(idx, field)?;
@ -450,6 +451,7 @@ fn normalize_and_maybe_rename_files<'a>(
field: &'a str,
renamed: &HashMap<String, String>,
seen_files: &mut HashSet<String>,
media_folder: &Path,
) -> Cow<'a, str> {
let refs = extract_media_refs(field);
let mut field: Cow<str> = field.into();
@ -466,7 +468,21 @@ fn normalize_and_maybe_rename_files<'a>(
if let Some(new_name) = renamed.get(fname.as_ref()) {
fname = new_name.to_owned().into();
}
// if it was not in NFC or was renamed, update the field
// if the filename was in NFC and was not renamed as part of the
// media check, it may have already been renamed during a previous
// sync. If that's the case and the renamed version exists on disk,
// we'll need to update the field to match it. It may be possible
// to remove this check in the future once we can be sure all media
// files stored on AnkiWeb are in normalized form.
if matches!(fname, Cow::Borrowed(_)) {
if let Cow::Owned(normname) = normalize_nfc_filename(fname.as_ref().into()) {
let path = media_folder.join(&normname);
if path.exists() {
fname = normname.into();
}
}
}
// update the field if the filename was modified
if let Cow::Owned(ref new_name) = fname {
field = rename_media_ref_in_field(field.as_ref(), &media_ref, new_name).into();
}
@ -510,41 +526,42 @@ fn extract_latex_refs(note: &Note, seen_files: &mut HashSet<String>, svg: bool)
}
#[cfg(test)]
mod test {
pub(crate) mod test {
pub(crate) const MEDIACHECK_ANKI2: &'static [u8] =
include_bytes!("../../tests/support/mediacheck.anki2");
use crate::collection::{open_collection, Collection};
use crate::err::Result;
use crate::i18n::I18n;
use crate::log;
use crate::log::Logger;
use crate::media::check::{MediaCheckOutput, MediaChecker};
use crate::media::files::trash_folder;
use crate::media::MediaManager;
use std::path::{Path, PathBuf};
use std::path::Path;
use std::{fs, io};
use tempfile::{tempdir, TempDir};
fn common_setup() -> Result<(TempDir, MediaManager, PathBuf, Logger, I18n)> {
fn common_setup() -> Result<(TempDir, MediaManager, Collection)> {
let dir = tempdir()?;
let media_dir = dir.path().join("media");
fs::create_dir(&media_dir)?;
let media_db = dir.path().join("media.db");
let col_path = dir.path().join("col.anki2");
fs::write(
&col_path,
&include_bytes!("../../tests/support/mediacheck.anki2")[..],
)?;
fs::write(&col_path, MEDIACHECK_ANKI2)?;
let mgr = MediaManager::new(&media_dir, media_db)?;
let mgr = MediaManager::new(&media_dir, media_db.clone())?;
let log = log::terminal();
let i18n = I18n::new(&["zz"], "dummy", log.clone());
Ok((dir, mgr, col_path, log, i18n))
let col = open_collection(col_path, media_dir, media_db, false, i18n, log)?;
Ok((dir, mgr, col))
}
#[test]
fn media_check() -> Result<()> {
let (_dir, mgr, col_path, log, i18n) = common_setup()?;
let (_dir, mgr, col) = common_setup()?;
// add some test files
fs::write(&mgr.media_folder.join("zerobytes"), "")?;
@ -555,8 +572,13 @@ mod test {
fs::write(&mgr.media_folder.join("unused.jpg"), "foo")?;
let progress = |_n| true;
let mut checker = MediaChecker::new(&mgr, &col_path, progress, &i18n, &log);
let mut output = checker.check()?;
let (output, report) = col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, progress);
let output = checker.check()?;
let summary = checker.summarize_output(&mut output.clone());
Ok((output, summary))
})?;
assert_eq!(
output,
@ -576,7 +598,6 @@ mod test {
assert!(fs::metadata(&mgr.media_folder.join("foo[.jpg")).is_err());
assert!(fs::metadata(&mgr.media_folder.join("foo.jpg")).is_ok());
let report = checker.summarize_output(&mut output);
assert_eq!(
report,
"Missing files: 1
@ -616,14 +637,16 @@ Unused: unused.jpg
#[test]
fn trash_handling() -> Result<()> {
let (_dir, mgr, col_path, log, i18n) = common_setup()?;
let (_dir, mgr, col) = common_setup()?;
let trash_folder = trash_folder(&mgr.media_folder)?;
fs::write(trash_folder.join("test.jpg"), "test")?;
let progress = |_n| true;
let mut checker = MediaChecker::new(&mgr, &col_path, progress, &i18n, &log);
checker.restore_trash()?;
col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, progress);
checker.restore_trash()
})?;
// file should have been moved to media folder
assert_eq!(files_in_dir(&trash_folder), Vec::<String>::new());
@ -634,7 +657,10 @@ Unused: unused.jpg
// if we repeat the process, restoring should do the same thing if the contents are equal
fs::write(trash_folder.join("test.jpg"), "test")?;
checker.restore_trash()?;
col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, progress);
checker.restore_trash()
})?;
assert_eq!(files_in_dir(&trash_folder), Vec::<String>::new());
assert_eq!(
files_in_dir(&mgr.media_folder),
@ -643,7 +669,10 @@ Unused: unused.jpg
// but rename if required
fs::write(trash_folder.join("test.jpg"), "test2")?;
checker.restore_trash()?;
col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, progress);
checker.restore_trash()
})?;
assert_eq!(files_in_dir(&trash_folder), Vec::<String>::new());
assert_eq!(
files_in_dir(&mgr.media_folder),
@ -658,13 +687,17 @@ Unused: unused.jpg
#[test]
fn unicode_normalization() -> Result<()> {
let (_dir, mgr, col_path, log, i18n) = common_setup()?;
let (_dir, mgr, col) = common_setup()?;
fs::write(&mgr.media_folder.join("ぱぱ.jpg"), "nfd encoding")?;
let progress = |_n| true;
let mut checker = MediaChecker::new(&mgr, &col_path, progress, &i18n, &log);
let mut output = checker.check()?;
let mut output = col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, progress);
checker.check()
})?;
output.missing.sort();
if cfg!(target_vendor = "apple") {

View file

@ -84,7 +84,7 @@ pub(crate) fn normalize_filename(fname: &str) -> Cow<str> {
}
/// See normalize_filename(). This function expects NFC-normalized input.
fn normalize_nfc_filename(mut fname: Cow<str>) -> Cow<str> {
pub(crate) fn normalize_nfc_filename(mut fname: Cow<str>) -> Cow<str> {
if fname.chars().any(disallowed_char) {
fname = fname.replace(disallowed_char, "").into()
}

View file

@ -12,7 +12,6 @@ use std::path::{Path, PathBuf};
pub mod changetracker;
pub mod check;
pub mod col;
pub mod database;
pub mod files;
pub mod sync;

View file

@ -717,6 +717,17 @@ fn zip_files<'a>(
break;
}
#[cfg(target_vendor = "apple")]
{
use unicode_normalization::is_nfc;
if !is_nfc(&file.fname) {
// older Anki versions stored non-normalized filenames in the DB; clean them up
debug!(log, "clean up non-nfc entry"; "fname"=>&file.fname);
invalid_entries.push(&file.fname);
continue;
}
}
let file_data = if file.sha1.is_some() {
match data_for_file(media_folder, &file.fname) {
Ok(data) => data,

View file

@ -1,17 +1,17 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
/// Basic note reading/updating functionality for the media DB check.
use crate::err::{AnkiError, Result};
/// At the moment, this is just basic note reading/updating functionality for
/// the media DB check.
use crate::err::{AnkiError, DBErrorKind, Result};
use crate::text::strip_html_preserving_image_filenames;
use crate::time::{i64_unix_millis, i64_unix_secs};
use crate::types::{ObjID, Timestamp, Usn};
use crate::time::i64_unix_secs;
use crate::{
notetypes::NoteType,
types::{ObjID, Timestamp, Usn},
};
use rusqlite::{params, Connection, Row, NO_PARAMS};
use serde_aux::field_attributes::deserialize_number_from_string;
use serde_derive::Deserialize;
use std::collections::HashMap;
use std::convert::TryInto;
use std::path::Path;
#[derive(Debug)]
pub(super) struct Note {
@ -40,55 +40,13 @@ impl Note {
}
}
fn field_checksum(text: &str) -> u32 {
/// Text must be passed to strip_html_preserving_image_filenames() by
/// caller prior to passing in here.
pub(crate) fn field_checksum(text: &str) -> u32 {
let digest = sha1::Sha1::from(text).digest().bytes();
u32::from_be_bytes(digest[..4].try_into().unwrap())
}
pub(super) fn open_or_create_collection_db(path: &Path) -> Result<Connection> {
let db = Connection::open(path)?;
db.pragma_update(None, "locking_mode", &"exclusive")?;
db.pragma_update(None, "page_size", &4096)?;
db.pragma_update(None, "cache_size", &(-40 * 1024))?;
db.pragma_update(None, "legacy_file_format", &false)?;
db.pragma_update(None, "journal", &"wal")?;
db.set_prepared_statement_cache_capacity(5);
Ok(db)
}
#[derive(Deserialize, Debug)]
pub(super) struct NoteType {
#[serde(deserialize_with = "deserialize_number_from_string")]
id: ObjID,
#[serde(rename = "sortf")]
sort_field_idx: u16,
#[serde(rename = "latexsvg", default)]
latex_svg: bool,
}
impl NoteType {
pub fn latex_uses_svg(&self) -> bool {
self.latex_svg
}
}
pub(super) fn get_note_types(db: &Connection) -> Result<HashMap<ObjID, NoteType>> {
let mut stmt = db.prepare("select models from col")?;
let note_types = stmt
.query_and_then(NO_PARAMS, |row| -> Result<HashMap<ObjID, NoteType>> {
let v: HashMap<ObjID, NoteType> = serde_json::from_str(row.get_raw(0).as_str()?)?;
Ok(v)
})?
.next()
.ok_or_else(|| AnkiError::DBError {
info: "col table empty".to_string(),
})??;
Ok(note_types)
}
#[allow(dead_code)]
fn get_note(db: &Connection, nid: ObjID) -> Result<Option<Note>> {
let mut stmt = db.prepare_cached("select id, mid, mod, usn, flds from notes where id=?")?;
@ -130,14 +88,20 @@ pub(super) fn set_note(db: &Connection, note: &mut Note, note_type: &NoteType) -
note.mtime_secs = i64_unix_secs();
// hard-coded for now
note.usn = -1;
let csum = field_checksum(&note.fields()[0]);
let sort_field = strip_html_preserving_image_filenames(
let field1_nohtml = strip_html_preserving_image_filenames(&note.fields()[0]);
let csum = field_checksum(field1_nohtml.as_ref());
let sort_field = if note_type.sort_field_idx == 0 {
field1_nohtml
} else {
strip_html_preserving_image_filenames(
note.fields()
.get(note_type.sort_field_idx as usize)
.ok_or_else(|| AnkiError::DBError {
info: "sort field out of range".to_string(),
kind: DBErrorKind::MissingEntity,
})?,
);
)
};
let mut stmt =
db.prepare_cached("update notes set mod=?,usn=?,flds=?,sfld=?,csum=? where id=?")?;
@ -152,8 +116,3 @@ pub(super) fn set_note(db: &Connection, note: &mut Note, note_type: &NoteType) -
Ok(())
}
pub(super) fn mark_collection_modified(db: &Connection) -> Result<()> {
db.execute("update col set mod=?", params![i64_unix_millis()])?;
Ok(())
}

39
rslib/src/notetypes.rs Normal file
View file

@ -0,0 +1,39 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::types::ObjID;
use serde_aux::field_attributes::deserialize_number_from_string;
use serde_derive::Deserialize;
#[derive(Deserialize, Debug)]
pub(crate) struct NoteType {
#[serde(deserialize_with = "deserialize_number_from_string")]
pub id: ObjID,
pub name: String,
#[serde(rename = "sortf")]
pub sort_field_idx: u16,
#[serde(rename = "latexsvg", default)]
pub latex_svg: bool,
#[serde(rename = "tmpls")]
pub templates: Vec<CardTemplate>,
#[serde(rename = "flds")]
pub fields: Vec<NoteField>,
}
#[derive(Deserialize, Debug)]
pub(crate) struct CardTemplate {
pub name: String,
pub ord: u16,
}
#[derive(Deserialize, Debug)]
pub(crate) struct NoteField {
pub name: String,
pub ord: u16,
}
impl NoteType {
pub fn latex_uses_svg(&self) -> bool {
self.latex_svg
}
}

View file

@ -3,6 +3,7 @@
use chrono::{Date, Duration, FixedOffset, Local, TimeZone};
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct SchedTimingToday {
/// The number of days that have passed since the collection was created.
pub days_elapsed: u32,
@ -17,7 +18,7 @@ pub struct SchedTimingToday {
/// - now_secs is a timestamp of the current time
/// - now_mins_west is the current offset west of UTC
/// - rollover_hour is the hour of the day the rollover happens (eg 4 for 4am)
pub fn sched_timing_today(
pub fn sched_timing_today_v2_new(
created_secs: i64,
created_mins_west: i32,
now_secs: i64,
@ -90,11 +91,84 @@ pub fn local_minutes_west_for_stamp(stamp: i64) -> i32 {
Local.timestamp(stamp, 0).offset().utc_minus_local() / 60
}
// Legacy code
// ----------------------------------
fn sched_timing_today_v1(crt: i64, now: i64) -> SchedTimingToday {
let days_elapsed = (now - crt) / 86_400;
let next_day_at = crt + (days_elapsed + 1) * 86_400;
SchedTimingToday {
days_elapsed: days_elapsed as u32,
next_day_at,
}
}
fn sched_timing_today_v2_legacy(
crt: i64,
rollover: i8,
now: i64,
mins_west: i32,
) -> SchedTimingToday {
let normalized_rollover = normalized_rollover_hour(rollover);
let offset = fixed_offset_from_minutes(mins_west);
let crt_at_rollover = offset
.timestamp(crt, 0)
.date()
.and_hms(normalized_rollover as u32, 0, 0)
.timestamp();
let days_elapsed = (now - crt_at_rollover) / 86_400;
let mut next_day_at = offset
.timestamp(now, 0)
.date()
.and_hms(normalized_rollover as u32, 0, 0)
.timestamp();
if next_day_at < now {
next_day_at += 86_400;
}
SchedTimingToday {
days_elapsed: days_elapsed as u32,
next_day_at,
}
}
// ----------------------------------
/// Based on provided input, get timing info from the relevant function.
pub(crate) fn sched_timing_today(
created_secs: i64,
now_secs: i64,
created_mins_west: Option<i32>,
now_mins_west: Option<i32>,
rollover_hour: Option<i8>,
) -> SchedTimingToday {
let now_west = now_mins_west.unwrap_or_else(|| local_minutes_west_for_stamp(now_secs));
match (rollover_hour, created_mins_west) {
(None, _) => {
// if rollover unset, v1 scheduler
sched_timing_today_v1(created_secs, now_secs)
}
(Some(roll), None) => {
// if creationOffset unset, v2 scheduler with legacy cutoff handling
sched_timing_today_v2_legacy(created_secs, roll, now_secs, now_west)
}
(Some(roll), Some(crt_west)) => {
// v2 scheduler, new cutoff handling
sched_timing_today_v2_new(created_secs, crt_west, now_secs, now_west, roll)
}
}
}
#[cfg(test)]
mod test {
use super::SchedTimingToday;
use crate::sched::cutoff::sched_timing_today_v1;
use crate::sched::cutoff::sched_timing_today_v2_legacy;
use crate::sched::cutoff::{
fixed_offset_from_minutes, local_minutes_west_for_stamp, normalized_rollover_hour,
sched_timing_today,
sched_timing_today_v2_new,
};
use chrono::{FixedOffset, Local, TimeZone, Utc};
@ -117,7 +191,7 @@ mod test {
// helper
fn elap(start: i64, end: i64, start_west: i32, end_west: i32, rollhour: i8) -> u32 {
let today = sched_timing_today(start, start_west, end, end_west, rollhour);
let today = sched_timing_today_v2_new(start, start_west, end, end_west, rollhour);
today.days_elapsed
}
@ -228,7 +302,7 @@ mod test {
// before the rollover, the next day should be later on the same day
let now = Local.ymd(2019, 1, 3).and_hms(2, 0, 0);
let next_day_at = Local.ymd(2019, 1, 3).and_hms(rollhour, 0, 0);
let today = sched_timing_today(
let today = sched_timing_today_v2_new(
crt.timestamp(),
crt.offset().utc_minus_local() / 60,
now.timestamp(),
@ -240,7 +314,7 @@ mod test {
// after the rollover, the next day should be the next day
let now = Local.ymd(2019, 1, 3).and_hms(rollhour, 0, 0);
let next_day_at = Local.ymd(2019, 1, 4).and_hms(rollhour, 0, 0);
let today = sched_timing_today(
let today = sched_timing_today_v2_new(
crt.timestamp(),
crt.offset().utc_minus_local() / 60,
now.timestamp(),
@ -252,7 +326,7 @@ mod test {
// after the rollover, the next day should be the next day
let now = Local.ymd(2019, 1, 3).and_hms(rollhour + 3, 0, 0);
let next_day_at = Local.ymd(2019, 1, 4).and_hms(rollhour, 0, 0);
let today = sched_timing_today(
let today = sched_timing_today_v2_new(
crt.timestamp(),
crt.offset().utc_minus_local() / 60,
now.timestamp(),
@ -261,4 +335,34 @@ mod test {
);
assert_eq!(today.next_day_at, next_day_at.timestamp());
}
#[test]
fn legacy_timing() {
let now = 1584491078;
let mins_west = -600;
assert_eq!(
sched_timing_today_v1(1575226800, now),
SchedTimingToday {
days_elapsed: 107,
next_day_at: 1584558000
}
);
assert_eq!(
sched_timing_today_v2_legacy(1533564000, 0, now, mins_west),
SchedTimingToday {
days_elapsed: 589,
next_day_at: 1584540000
}
);
assert_eq!(
sched_timing_today_v2_legacy(1524038400, 4, now, mins_west),
SchedTimingToday {
days_elapsed: 700,
next_day_at: 1584554400
}
);
}
}

154
rslib/src/search/cards.rs Normal file
View file

@ -0,0 +1,154 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use super::{parser::Node, sqlwriter::node_to_sql};
use crate::card::CardType;
use crate::collection::RequestContext;
use crate::config::SortKind;
use crate::err::Result;
use crate::search::parser::parse;
use crate::types::ObjID;
use rusqlite::params;
pub(crate) enum SortMode {
NoOrder,
FromConfig,
Builtin { kind: SortKind, reverse: bool },
Custom(String),
}
pub(crate) fn search_cards<'a, 'b>(
req: &'a mut RequestContext<'b>,
search: &'a str,
order: SortMode,
) -> Result<Vec<ObjID>> {
let top_node = Node::Group(parse(search)?);
let (sql, args) = node_to_sql(req, &top_node)?;
let mut sql = format!(
"select c.id from cards c, notes n where c.nid=n.id and {}",
sql
);
match order {
SortMode::NoOrder => (),
SortMode::FromConfig => {
let conf = req.storage.all_config()?;
prepare_sort(req, &conf.browser_sort_kind)?;
sql.push_str(" order by ");
write_order(&mut sql, &conf.browser_sort_kind, conf.browser_sort_reverse)?;
}
SortMode::Builtin { kind, reverse } => {
prepare_sort(req, &kind)?;
sql.push_str(" order by ");
write_order(&mut sql, &kind, reverse)?;
}
SortMode::Custom(order_clause) => {
sql.push_str(" order by ");
sql.push_str(&order_clause);
}
}
let mut stmt = req.storage.db.prepare(&sql)?;
let ids: Vec<i64> = stmt
.query_map(&args, |row| row.get(0))?
.collect::<std::result::Result<_, _>>()?;
Ok(ids)
}
/// Add the order clause to the sql.
fn write_order(sql: &mut String, kind: &SortKind, reverse: bool) -> Result<()> {
let tmp_str;
let order = match kind {
SortKind::NoteCreation => "n.id asc, c.ord asc",
SortKind::NoteMod => "n.mod asc, c.ord asc",
SortKind::NoteField => "n.sfld collate nocase asc, c.ord asc",
SortKind::CardMod => "c.mod asc",
SortKind::CardReps => "c.reps asc",
SortKind::CardDue => "c.type asc, c.due asc",
SortKind::CardEase => {
tmp_str = format!("c.type = {} asc, c.factor asc", CardType::New as i8);
&tmp_str
}
SortKind::CardLapses => "c.lapses asc",
SortKind::CardInterval => "c.ivl asc",
SortKind::NoteTags => "n.tags asc",
SortKind::CardDeck => "(select v from sort_order where k = c.did) asc",
SortKind::NoteType => "(select v from sort_order where k = n.mid) asc",
SortKind::CardTemplate => "(select v from sort_order where k1 = n.mid and k2 = c.ord) asc",
};
if order.is_empty() {
return Ok(());
}
if reverse {
sql.push_str(
&order
.to_ascii_lowercase()
.replace(" desc", "")
.replace(" asc", " desc"),
)
} else {
sql.push_str(order);
}
Ok(())
}
// In the future these items should be moved from JSON into separate SQL tables,
// - for now we use a temporary deck to sort them.
fn prepare_sort(req: &mut RequestContext, kind: &SortKind) -> Result<()> {
use SortKind::*;
match kind {
CardDeck | NoteType => {
prepare_sort_order_table(req)?;
let mut stmt = req
.storage
.db
.prepare("insert into sort_order (k,v) values (?,?)")?;
match kind {
CardDeck => {
for (k, v) in req.storage.all_decks()? {
stmt.execute(params![k, v.name])?;
}
}
NoteType => {
for (k, v) in req.storage.all_note_types()? {
stmt.execute(params![k, v.name])?;
}
}
_ => unreachable!(),
}
}
CardTemplate => {
prepare_sort_order_table2(req)?;
let mut stmt = req
.storage
.db
.prepare("insert into sort_order (k1,k2,v) values (?,?,?)")?;
for (ntid, nt) in req.storage.all_note_types()? {
for tmpl in nt.templates {
stmt.execute(params![ntid, tmpl.ord, tmpl.name])?;
}
}
}
_ => (),
}
Ok(())
}
fn prepare_sort_order_table(req: &mut RequestContext) -> Result<()> {
req.storage
.db
.execute_batch(include_str!("sort_order.sql"))?;
Ok(())
}
fn prepare_sort_order_table2(req: &mut RequestContext) -> Result<()> {
req.storage
.db
.execute_batch(include_str!("sort_order2.sql"))?;
Ok(())
}

7
rslib/src/search/mod.rs Normal file
View file

@ -0,0 +1,7 @@
mod cards;
mod notes;
mod parser;
mod sqlwriter;
pub(crate) use cards::{search_cards, SortMode};
pub(crate) use notes::search_notes;

28
rslib/src/search/notes.rs Normal file
View file

@ -0,0 +1,28 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use super::{parser::Node, sqlwriter::node_to_sql};
use crate::collection::RequestContext;
use crate::err::Result;
use crate::search::parser::parse;
use crate::types::ObjID;
pub(crate) fn search_notes<'a, 'b>(
req: &'a mut RequestContext<'b>,
search: &'a str,
) -> Result<Vec<ObjID>> {
let top_node = Node::Group(parse(search)?);
let (sql, args) = node_to_sql(req, &top_node)?;
let sql = format!(
"select n.id from cards c, notes n where c.nid=n.id and {}",
sql
);
let mut stmt = req.storage.db.prepare(&sql)?;
let ids: Vec<i64> = stmt
.query_map(&args, |row| row.get(0))?
.collect::<std::result::Result<_, _>>()?;
Ok(ids)
}

528
rslib/src/search/parser.rs Normal file
View file

@ -0,0 +1,528 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::err::{AnkiError, Result};
use crate::types::ObjID;
use nom::branch::alt;
use nom::bytes::complete::{escaped, is_not, tag, take_while1};
use nom::character::complete::{anychar, char, one_of};
use nom::character::is_digit;
use nom::combinator::{all_consuming, map, map_res};
use nom::sequence::{delimited, preceded, tuple};
use nom::{multi::many0, IResult};
use std::{borrow::Cow, num};
// fixme: need to preserve \ when used twice in string
struct ParseError {}
impl From<num::ParseIntError> for ParseError {
fn from(_: num::ParseIntError) -> Self {
ParseError {}
}
}
impl From<num::ParseFloatError> for ParseError {
fn from(_: num::ParseFloatError) -> Self {
ParseError {}
}
}
impl<I> From<nom::Err<(I, nom::error::ErrorKind)>> for ParseError {
fn from(_: nom::Err<(I, nom::error::ErrorKind)>) -> Self {
ParseError {}
}
}
type ParseResult<T> = std::result::Result<T, ParseError>;
#[derive(Debug, PartialEq)]
pub(super) enum Node<'a> {
And,
Or,
Not(Box<Node<'a>>),
Group(Vec<Node<'a>>),
Search(SearchNode<'a>),
}
#[derive(Debug, PartialEq)]
pub(super) enum SearchNode<'a> {
// text without a colon
UnqualifiedText(Cow<'a, str>),
// foo:bar, where foo doesn't match a term below
SingleField {
field: Cow<'a, str>,
text: Cow<'a, str>,
is_re: bool,
},
AddedInDays(u32),
CardTemplate(TemplateKind),
Deck(Cow<'a, str>),
NoteTypeID(ObjID),
NoteType(Cow<'a, str>),
Rated {
days: u32,
ease: Option<u8>,
},
Tag(Cow<'a, str>),
Duplicates {
note_type_id: ObjID,
text: String,
},
State(StateKind),
Flag(u8),
NoteIDs(Cow<'a, str>),
CardIDs(Cow<'a, str>),
Property {
operator: String,
kind: PropertyKind,
},
WholeCollection,
Regex(Cow<'a, str>),
NoCombining(Cow<'a, str>),
}
#[derive(Debug, PartialEq)]
pub(super) enum PropertyKind {
Due(i32),
Interval(u32),
Reps(u32),
Lapses(u32),
Ease(f32),
}
#[derive(Debug, PartialEq)]
pub(super) enum StateKind {
New,
Review,
Learning,
Due,
Buried,
Suspended,
}
#[derive(Debug, PartialEq)]
pub(super) enum TemplateKind {
Ordinal(u16),
Name(String),
}
/// Parse the input string into a list of nodes.
#[allow(dead_code)]
pub(super) fn parse(input: &str) -> Result<Vec<Node>> {
let input = input.trim();
if input.is_empty() {
return Ok(vec![Node::Search(SearchNode::WholeCollection)]);
}
let (_, nodes) = all_consuming(group_inner)(input)
.map_err(|_e| AnkiError::invalid_input("unable to parse search"))?;
Ok(nodes)
}
/// One or more nodes surrounded by brackets, eg (one OR two)
fn group(s: &str) -> IResult<&str, Node> {
map(delimited(char('('), group_inner, char(')')), |nodes| {
Node::Group(nodes)
})(s)
}
/// One or more nodes inside brackets, er 'one OR two -three'
fn group_inner(input: &str) -> IResult<&str, Vec<Node>> {
let mut remaining = input;
let mut nodes = vec![];
loop {
match node(remaining) {
Ok((rem, node)) => {
remaining = rem;
if nodes.len() % 2 == 0 {
// before adding the node, if the length is even then the node
// must not be a boolean
if matches!(node, Node::And | Node::Or) {
return Err(nom::Err::Failure(("", nom::error::ErrorKind::NoneOf)));
}
} else {
// if the length is odd, the next item must be a boolean. if it's
// not, add an implicit and
if !matches!(node, Node::And | Node::Or) {
nodes.push(Node::And);
}
}
nodes.push(node);
}
Err(e) => match e {
nom::Err::Error(_) => break,
_ => return Err(e),
},
};
}
if nodes.is_empty() {
Err(nom::Err::Error((remaining, nom::error::ErrorKind::Many1)))
} else {
// chomp any trailing whitespace
let (remaining, _) = whitespace0(remaining)?;
Ok((remaining, nodes))
}
}
fn whitespace0(s: &str) -> IResult<&str, Vec<char>> {
many0(one_of(" \u{3000}"))(s)
}
/// Optional leading space, then a (negated) group or text
fn node(s: &str) -> IResult<&str, Node> {
preceded(whitespace0, alt((negated_node, group, text)))(s)
}
fn negated_node(s: &str) -> IResult<&str, Node> {
map(preceded(char('-'), alt((group, text))), |node| {
Node::Not(Box::new(node))
})(s)
}
/// Either quoted or unquoted text
fn text(s: &str) -> IResult<&str, Node> {
alt((quoted_term, partially_quoted_term, unquoted_term))(s)
}
/// Determine if text is a qualified search, and handle escaped chars.
fn search_node_for_text(s: &str) -> ParseResult<SearchNode> {
let mut it = s.splitn(2, ':');
let (head, tail) = (
unescape_quotes(it.next().unwrap()),
it.next().map(unescape_quotes),
);
if let Some(tail) = tail {
search_node_for_text_with_argument(head, tail)
} else {
Ok(SearchNode::UnqualifiedText(head))
}
}
/// \" -> "
fn unescape_quotes(s: &str) -> Cow<str> {
if s.find(r#"\""#).is_some() {
s.replace(r#"\""#, "\"").into()
} else {
s.into()
}
}
/// Unquoted text, terminated by a space or )
fn unquoted_term(s: &str) -> IResult<&str, Node> {
map_res(
take_while1(|c| c != ' ' && c != ')' && c != '"'),
|text: &str| -> ParseResult<Node> {
Ok(if text.eq_ignore_ascii_case("or") {
Node::Or
} else if text.eq_ignore_ascii_case("and") {
Node::And
} else {
Node::Search(search_node_for_text(text)?)
})
},
)(s)
}
/// Quoted text, including the outer double quotes.
fn quoted_term(s: &str) -> IResult<&str, Node> {
map_res(quoted_term_str, |o| -> ParseResult<Node> {
Ok(Node::Search(search_node_for_text(o)?))
})(s)
}
fn quoted_term_str(s: &str) -> IResult<&str, &str> {
delimited(char('"'), quoted_term_inner, char('"'))(s)
}
/// Quoted text, terminated by a non-escaped double quote
fn quoted_term_inner(s: &str) -> IResult<&str, &str> {
escaped(is_not(r#""\"#), '\\', anychar)(s)
}
/// eg deck:"foo bar" - quotes must come after the :
fn partially_quoted_term(s: &str) -> IResult<&str, Node> {
let term = take_while1(|c| c != ' ' && c != ')' && c != ':');
let (s, (term, _, quoted_val)) = tuple((term, char(':'), quoted_term_str))(s)?;
match search_node_for_text_with_argument(term.into(), quoted_val.into()) {
Ok(search) => Ok((s, Node::Search(search))),
Err(_) => Err(nom::Err::Failure((s, nom::error::ErrorKind::NoneOf))),
}
}
/// Convert a colon-separated key/val pair into the relevant search type.
fn search_node_for_text_with_argument<'a>(
key: Cow<'a, str>,
val: Cow<'a, str>,
) -> ParseResult<SearchNode<'a>> {
Ok(match key.to_ascii_lowercase().as_str() {
"added" => SearchNode::AddedInDays(val.parse()?),
"deck" => SearchNode::Deck(val),
"note" => SearchNode::NoteType(val),
"tag" => SearchNode::Tag(val),
"mid" => SearchNode::NoteTypeID(val.parse()?),
"nid" => SearchNode::NoteIDs(check_id_list(val)?),
"cid" => SearchNode::CardIDs(check_id_list(val)?),
"card" => parse_template(val.as_ref()),
"is" => parse_state(val.as_ref())?,
"flag" => parse_flag(val.as_ref())?,
"rated" => parse_rated(val.as_ref())?,
"dupes" => parse_dupes(val.as_ref())?,
"prop" => parse_prop(val.as_ref())?,
"re" => SearchNode::Regex(val),
"nc" => SearchNode::NoCombining(val),
// anything else is a field search
_ => parse_single_field(key.as_ref(), val.as_ref()),
})
}
/// ensure a list of ids contains only numbers and commas, returning unchanged if true
/// used by nid: and cid:
fn check_id_list(s: Cow<str>) -> ParseResult<Cow<str>> {
if s.is_empty() || s.as_bytes().iter().any(|&c| !is_digit(c) && c != b',') {
Err(ParseError {})
} else {
Ok(s)
}
}
/// eg is:due
fn parse_state(s: &str) -> ParseResult<SearchNode<'static>> {
use StateKind::*;
Ok(SearchNode::State(match s {
"new" => New,
"review" => Review,
"learn" => Learning,
"due" => Due,
"buried" => Buried,
"suspended" => Suspended,
_ => return Err(ParseError {}),
}))
}
/// flag:0-4
fn parse_flag(s: &str) -> ParseResult<SearchNode<'static>> {
let n: u8 = s.parse()?;
if n > 4 {
Err(ParseError {})
} else {
Ok(SearchNode::Flag(n))
}
}
/// eg rated:3 or rated:10:2
/// second arg must be between 0-4
fn parse_rated(val: &str) -> ParseResult<SearchNode<'static>> {
let mut it = val.splitn(2, ':');
let days = it.next().unwrap().parse()?;
let ease = match it.next() {
Some(v) => {
let n: u8 = v.parse()?;
if n < 5 {
Some(n)
} else {
return Err(ParseError {});
}
}
None => None,
};
Ok(SearchNode::Rated { days, ease })
}
/// eg dupes:1231,hello
fn parse_dupes(val: &str) -> ParseResult<SearchNode<'static>> {
let mut it = val.splitn(2, ',');
let mid: ObjID = it.next().unwrap().parse()?;
let text = it.next().ok_or(ParseError {})?;
Ok(SearchNode::Duplicates {
note_type_id: mid,
text: text.into(),
})
}
/// eg prop:ivl>3, prop:ease!=2.5
fn parse_prop(val: &str) -> ParseResult<SearchNode<'static>> {
let (val, key) = alt((
tag("ivl"),
tag("due"),
tag("reps"),
tag("lapses"),
tag("ease"),
))(val)?;
let (val, operator) = alt((
tag("<="),
tag(">="),
tag("!="),
tag("="),
tag("<"),
tag(">"),
))(val)?;
let kind = if key == "ease" {
let num: f32 = val.parse()?;
PropertyKind::Ease(num)
} else if key == "due" {
let num: i32 = val.parse()?;
PropertyKind::Due(num)
} else {
let num: u32 = val.parse()?;
match key {
"ivl" => PropertyKind::Interval(num),
"reps" => PropertyKind::Reps(num),
"lapses" => PropertyKind::Lapses(num),
_ => unreachable!(),
}
};
Ok(SearchNode::Property {
operator: operator.to_string(),
kind,
})
}
fn parse_template(val: &str) -> SearchNode<'static> {
SearchNode::CardTemplate(match val.parse::<u16>() {
Ok(n) => TemplateKind::Ordinal(n.max(1) - 1),
Err(_) => TemplateKind::Name(val.into()),
})
}
fn parse_single_field(key: &str, mut val: &str) -> SearchNode<'static> {
let is_re = if val.starts_with("re:") {
val = val.trim_start_matches("re:");
true
} else {
false
};
SearchNode::SingleField {
field: key.to_string().into(),
text: val.to_string().into(),
is_re,
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn parsing() -> Result<()> {
use Node::*;
use SearchNode::*;
assert_eq!(parse("")?, vec![Search(SearchNode::WholeCollection)]);
assert_eq!(parse(" ")?, vec![Search(SearchNode::WholeCollection)]);
// leading/trailing/interspersed whitespace
assert_eq!(
parse(" t t2 ")?,
vec![
Search(UnqualifiedText("t".into())),
And,
Search(UnqualifiedText("t2".into()))
]
);
// including in groups
assert_eq!(
parse("( t t2 )")?,
vec![Group(vec![
Search(UnqualifiedText("t".into())),
And,
Search(UnqualifiedText("t2".into()))
])]
);
assert_eq!(
parse(r#"hello -(world and "foo:bar baz") OR test"#)?,
vec![
Search(UnqualifiedText("hello".into())),
And,
Not(Box::new(Group(vec![
Search(UnqualifiedText("world".into())),
And,
Search(SingleField {
field: "foo".into(),
text: "bar baz".into(),
is_re: false,
})
]))),
Or,
Search(UnqualifiedText("test".into()))
]
);
assert_eq!(
parse("foo:re:bar")?,
vec![Search(SingleField {
field: "foo".into(),
text: "bar".into(),
is_re: true
})]
);
// any character should be escapable in quotes
assert_eq!(
parse(r#""re:\btest""#)?,
vec![Search(Regex(r"\btest".into()))]
);
assert_eq!(parse("added:3")?, vec![Search(AddedInDays(3))]);
assert_eq!(
parse("card:front")?,
vec![Search(CardTemplate(TemplateKind::Name("front".into())))]
);
assert_eq!(
parse("card:3")?,
vec![Search(CardTemplate(TemplateKind::Ordinal(2)))]
);
// 0 must not cause a crash due to underflow
assert_eq!(
parse("card:0")?,
vec![Search(CardTemplate(TemplateKind::Ordinal(0)))]
);
assert_eq!(parse("deck:default")?, vec![Search(Deck("default".into()))]);
assert_eq!(
parse("deck:\"default one\"")?,
vec![Search(Deck("default one".into()))]
);
assert_eq!(parse("note:basic")?, vec![Search(NoteType("basic".into()))]);
assert_eq!(parse("tag:hard")?, vec![Search(Tag("hard".into()))]);
assert_eq!(
parse("nid:1237123712,2,3")?,
vec![Search(NoteIDs("1237123712,2,3".into()))]
);
assert!(parse("nid:1237123712_2,3").is_err());
assert_eq!(parse("is:due")?, vec![Search(State(StateKind::Due))]);
assert_eq!(parse("flag:3")?, vec![Search(Flag(3))]);
assert!(parse("flag:-1").is_err());
assert!(parse("flag:5").is_err());
assert_eq!(
parse("prop:ivl>3")?,
vec![Search(Property {
operator: ">".into(),
kind: PropertyKind::Interval(3)
})]
);
assert!(parse("prop:ivl>3.3").is_err());
assert_eq!(
parse("prop:ease<=3.3")?,
vec![Search(Property {
operator: "<=".into(),
kind: PropertyKind::Ease(3.3)
})]
);
Ok(())
}
}

View file

@ -0,0 +1,2 @@
drop table if exists sort_order;
create temporary table sort_order (k int primary key, v text);

View file

@ -0,0 +1,2 @@
drop table if exists sort_order;
create temporary table sort_order (k1 int, k2 int, v text, primary key (k1, k2)) without rowid;

View file

@ -0,0 +1,583 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use super::parser::{Node, PropertyKind, SearchNode, StateKind, TemplateKind};
use crate::card::CardQueue;
use crate::decks::child_ids;
use crate::decks::get_deck;
use crate::err::{AnkiError, Result};
use crate::notes::field_checksum;
use crate::text::matches_wildcard;
use crate::text::without_combining;
use crate::{
collection::RequestContext, text::strip_html_preserving_image_filenames, types::ObjID,
};
use std::fmt::Write;
struct SqlWriter<'a, 'b> {
req: &'a mut RequestContext<'b>,
sql: String,
args: Vec<String>,
}
pub(super) fn node_to_sql(req: &mut RequestContext, node: &Node) -> Result<(String, Vec<String>)> {
let mut sctx = SqlWriter::new(req);
sctx.write_node_to_sql(&node)?;
Ok((sctx.sql, sctx.args))
}
impl SqlWriter<'_, '_> {
fn new<'a, 'b>(req: &'a mut RequestContext<'b>) -> SqlWriter<'a, 'b> {
let sql = String::new();
let args = vec![];
SqlWriter { req, sql, args }
}
fn write_node_to_sql(&mut self, node: &Node) -> Result<()> {
match node {
Node::And => write!(self.sql, " and ").unwrap(),
Node::Or => write!(self.sql, " or ").unwrap(),
Node::Not(node) => {
write!(self.sql, "not ").unwrap();
self.write_node_to_sql(node)?;
}
Node::Group(nodes) => {
write!(self.sql, "(").unwrap();
for node in nodes {
self.write_node_to_sql(node)?;
}
write!(self.sql, ")").unwrap();
}
Node::Search(search) => self.write_search_node_to_sql(search)?,
};
Ok(())
}
fn write_search_node_to_sql(&mut self, node: &SearchNode) -> Result<()> {
match node {
SearchNode::UnqualifiedText(text) => self.write_unqualified(text),
SearchNode::SingleField { field, text, is_re } => {
self.write_single_field(field.as_ref(), text.as_ref(), *is_re)?
}
SearchNode::AddedInDays(days) => self.write_added(*days)?,
SearchNode::CardTemplate(template) => self.write_template(template)?,
SearchNode::Deck(deck) => self.write_deck(deck.as_ref())?,
SearchNode::NoteTypeID(ntid) => {
write!(self.sql, "n.mid = {}", ntid).unwrap();
}
SearchNode::NoteType(notetype) => self.write_note_type(notetype.as_ref())?,
SearchNode::Rated { days, ease } => self.write_rated(*days, *ease)?,
SearchNode::Tag(tag) => self.write_tag(tag),
SearchNode::Duplicates { note_type_id, text } => self.write_dupes(*note_type_id, text),
SearchNode::State(state) => self.write_state(state)?,
SearchNode::Flag(flag) => {
write!(self.sql, "(c.flags & 7) == {}", flag).unwrap();
}
SearchNode::NoteIDs(nids) => {
write!(self.sql, "n.id in ({})", nids).unwrap();
}
SearchNode::CardIDs(cids) => {
write!(self.sql, "c.id in ({})", cids).unwrap();
}
SearchNode::Property { operator, kind } => self.write_prop(operator, kind)?,
SearchNode::WholeCollection => write!(self.sql, "true").unwrap(),
SearchNode::Regex(re) => self.write_regex(re.as_ref()),
SearchNode::NoCombining(text) => self.write_no_combining(text.as_ref()),
};
Ok(())
}
fn write_unqualified(&mut self, text: &str) {
// implicitly wrap in %
let text = format!("%{}%", text);
self.args.push(text);
write!(
self.sql,
"(n.sfld like ?{n} escape '\\' or n.flds like ?{n} escape '\\')",
n = self.args.len(),
)
.unwrap();
}
fn write_no_combining(&mut self, text: &str) {
let text = format!("%{}%", without_combining(text));
self.args.push(text);
write!(
self.sql,
concat!(
"(coalesce(without_combining(cast(n.sfld as text)), n.sfld) like ?{n} escape '\\' ",
"or coalesce(without_combining(n.flds), n.flds) like ?{n} escape '\\')"
),
n = self.args.len(),
)
.unwrap();
}
fn write_tag(&mut self, text: &str) {
match text {
"none" => {
write!(self.sql, "n.tags = ''").unwrap();
}
"*" | "%" => {
write!(self.sql, "true").unwrap();
}
text => {
let tag = format!("% {} %", text.replace('*', "%"));
write!(self.sql, "n.tags like ? escape '\\'").unwrap();
self.args.push(tag);
}
}
}
fn write_rated(&mut self, days: u32, ease: Option<u8>) -> Result<()> {
let today_cutoff = self.req.storage.timing_today()?.next_day_at;
let days = days.min(365) as i64;
let target_cutoff_ms = (today_cutoff - 86_400 * days) * 1_000;
write!(
self.sql,
"c.id in (select cid from revlog where id>{}",
target_cutoff_ms
)
.unwrap();
if let Some(ease) = ease {
write!(self.sql, " and ease={})", ease).unwrap();
} else {
write!(self.sql, ")").unwrap();
}
Ok(())
}
fn write_prop(&mut self, op: &str, kind: &PropertyKind) -> Result<()> {
let timing = self.req.storage.timing_today()?;
match kind {
PropertyKind::Due(days) => {
let day = days + (timing.days_elapsed as i32);
write!(
self.sql,
"(c.queue in ({rev},{daylrn}) and due {op} {day})",
rev = CardQueue::Review as u8,
daylrn = CardQueue::DayLearn as u8,
op = op,
day = day
)
}
PropertyKind::Interval(ivl) => write!(self.sql, "ivl {} {}", op, ivl),
PropertyKind::Reps(reps) => write!(self.sql, "reps {} {}", op, reps),
PropertyKind::Lapses(days) => write!(self.sql, "lapses {} {}", op, days),
PropertyKind::Ease(ease) => {
write!(self.sql, "factor {} {}", op, (ease * 1000.0) as u32)
}
}
.unwrap();
Ok(())
}
fn write_state(&mut self, state: &StateKind) -> Result<()> {
let timing = self.req.storage.timing_today()?;
match state {
StateKind::New => write!(self.sql, "c.type = {}", CardQueue::New as i8),
StateKind::Review => write!(self.sql, "c.type = {}", CardQueue::Review as i8),
StateKind::Learning => write!(
self.sql,
"c.queue in ({},{})",
CardQueue::Learn as i8,
CardQueue::DayLearn as i8
),
StateKind::Buried => write!(
self.sql,
"c.queue in ({},{})",
CardQueue::SchedBuried as i8,
CardQueue::UserBuried as i8
),
StateKind::Suspended => write!(self.sql, "c.queue = {}", CardQueue::Suspended as i8),
StateKind::Due => write!(
self.sql,
"
(c.queue in ({rev},{daylrn}) and c.due <= {today}) or
(c.queue = {lrn} and c.due <= {daycutoff})",
rev = CardQueue::Review as i8,
daylrn = CardQueue::DayLearn as i8,
today = timing.days_elapsed,
lrn = CardQueue::Learn as i8,
daycutoff = timing.next_day_at,
),
}
.unwrap();
Ok(())
}
fn write_deck(&mut self, deck: &str) -> Result<()> {
match deck {
"*" => write!(self.sql, "true").unwrap(),
"filtered" => write!(self.sql, "c.odid > 0").unwrap(),
deck => {
let all_decks: Vec<_> = self
.req
.storage
.all_decks()?
.into_iter()
.map(|(_, v)| v)
.collect();
let dids_with_children = if deck == "current" {
let config = self.req.storage.all_config()?;
let mut dids_with_children = vec![config.current_deck_id];
let current = get_deck(&all_decks, config.current_deck_id)
.ok_or_else(|| AnkiError::invalid_input("invalid current deck"))?;
for child_did in child_ids(&all_decks, &current.name) {
dids_with_children.push(child_did);
}
dids_with_children
} else {
let mut dids_with_children = vec![];
for deck in all_decks.iter().filter(|d| matches_wildcard(&d.name, deck)) {
dids_with_children.push(deck.id);
for child_id in child_ids(&all_decks, &deck.name) {
dids_with_children.push(child_id);
}
}
dids_with_children
};
self.sql.push_str("c.did in ");
ids_to_string(&mut self.sql, &dids_with_children);
}
};
Ok(())
}
fn write_template(&mut self, template: &TemplateKind) -> Result<()> {
match template {
TemplateKind::Ordinal(n) => {
write!(self.sql, "c.ord = {}", n).unwrap();
}
TemplateKind::Name(name) => {
let note_types = self.req.storage.all_note_types()?;
let mut id_ords = vec![];
for nt in note_types.values() {
for tmpl in &nt.templates {
if matches_wildcard(&tmpl.name, name) {
id_ords.push((nt.id, tmpl.ord));
}
}
}
// sort for the benefit of unit tests
id_ords.sort();
if id_ords.is_empty() {
self.sql.push_str("false");
} else {
let v: Vec<_> = id_ords
.iter()
.map(|(ntid, ord)| format!("(n.mid = {} and c.ord = {})", ntid, ord))
.collect();
write!(self.sql, "({})", v.join(" or ")).unwrap();
}
}
};
Ok(())
}
fn write_note_type(&mut self, nt_name: &str) -> Result<()> {
let mut ntids: Vec<_> = self
.req
.storage
.all_note_types()?
.values()
.filter(|nt| matches_wildcard(&nt.name, nt_name))
.map(|nt| nt.id)
.collect();
self.sql.push_str("n.mid in ");
// sort for the benefit of unit tests
ntids.sort();
ids_to_string(&mut self.sql, &ntids);
Ok(())
}
fn write_single_field(&mut self, field_name: &str, val: &str, is_re: bool) -> Result<()> {
let note_types = self.req.storage.all_note_types()?;
let mut field_map = vec![];
for nt in note_types.values() {
for field in &nt.fields {
if matches_wildcard(&field.name, field_name) {
field_map.push((nt.id, field.ord));
}
}
}
// for now, sort the map for the benefit of unit tests
field_map.sort();
if field_map.is_empty() {
write!(self.sql, "false").unwrap();
return Ok(());
}
let cmp;
if is_re {
cmp = "regexp";
self.args.push(format!("(?i){}", val));
} else {
cmp = "like";
self.args.push(val.replace('*', "%"));
}
let arg_idx = self.args.len();
let searches: Vec<_> = field_map
.iter()
.map(|(ntid, ord)| {
format!(
"(n.mid = {mid} and field_at_index(n.flds, {ord}) {cmp} ?{n})",
mid = ntid,
ord = ord,
cmp = cmp,
n = arg_idx
)
})
.collect();
write!(self.sql, "({})", searches.join(" or ")).unwrap();
Ok(())
}
fn write_dupes(&mut self, ntid: ObjID, text: &str) {
let text_nohtml = strip_html_preserving_image_filenames(text);
let csum = field_checksum(text_nohtml.as_ref());
write!(
self.sql,
"(n.mid = {} and n.csum = {} and field_at_index(n.flds, 0) = ?",
ntid, csum
)
.unwrap();
self.args.push(text.to_string());
}
fn write_added(&mut self, days: u32) -> Result<()> {
let timing = self.req.storage.timing_today()?;
let cutoff = (timing.next_day_at - (86_400 * (days as i64))) * 1_000;
write!(self.sql, "c.id > {}", cutoff).unwrap();
Ok(())
}
fn write_regex(&mut self, word: &str) {
self.sql.push_str("n.flds regexp ?");
self.args.push(format!(r"(?i){}", word));
}
}
// Write a list of IDs as '(x,y,...)' into the provided string.
fn ids_to_string<T>(buf: &mut String, ids: &[T])
where
T: std::fmt::Display,
{
buf.push('(');
if !ids.is_empty() {
for id in ids.iter().skip(1) {
write!(buf, "{},", id).unwrap();
}
write!(buf, "{}", ids[0]).unwrap();
}
buf.push(')');
}
#[cfg(test)]
mod test {
use super::ids_to_string;
use crate::{collection::open_collection, i18n::I18n, log};
use std::{fs, path::PathBuf};
use tempfile::tempdir;
#[test]
fn ids_string() {
let mut s = String::new();
ids_to_string::<u8>(&mut s, &[]);
assert_eq!(s, "()");
s.clear();
ids_to_string(&mut s, &[7]);
assert_eq!(s, "(7)");
s.clear();
ids_to_string(&mut s, &[7, 6]);
assert_eq!(s, "(6,7)");
s.clear();
ids_to_string(&mut s, &[7, 6, 5]);
assert_eq!(s, "(6,5,7)");
s.clear();
}
use super::super::parser::parse;
use super::*;
// shortcut
fn s(req: &mut RequestContext, search: &str) -> (String, Vec<String>) {
let node = Node::Group(parse(search).unwrap());
node_to_sql(req, &node).unwrap()
}
#[test]
fn sql() -> Result<()> {
// re-use the mediacheck .anki2 file for now
use crate::media::check::test::MEDIACHECK_ANKI2;
let dir = tempdir().unwrap();
let col_path = dir.path().join("col.anki2");
fs::write(&col_path, MEDIACHECK_ANKI2).unwrap();
let i18n = I18n::new(&[""], "", log::terminal());
let col = open_collection(
&col_path,
&PathBuf::new(),
&PathBuf::new(),
false,
i18n,
log::terminal(),
)
.unwrap();
col.with_ctx(|ctx| {
// unqualified search
assert_eq!(
s(ctx, "test"),
(
"((n.sfld like ?1 escape '\\' or n.flds like ?1 escape '\\'))".into(),
vec!["%test%".into()]
)
);
assert_eq!(s(ctx, "te%st").1, vec!["%te%st%".to_string()]);
// user should be able to escape sql wildcards
assert_eq!(s(ctx, r#"te\%s\_t"#).1, vec!["%te\\%s\\_t%".to_string()]);
// qualified search
assert_eq!(
s(ctx, "front:te*st"),
(
concat!(
"(((n.mid = 1581236385344 and field_at_index(n.flds, 0) like ?1) or ",
"(n.mid = 1581236385345 and field_at_index(n.flds, 0) like ?1) or ",
"(n.mid = 1581236385346 and field_at_index(n.flds, 0) like ?1) or ",
"(n.mid = 1581236385347 and field_at_index(n.flds, 0) like ?1)))"
)
.into(),
vec!["te%st".into()]
)
);
// added
let timing = ctx.storage.timing_today().unwrap();
assert_eq!(
s(ctx, "added:3").0,
format!("(c.id > {})", (timing.next_day_at - (86_400 * 3)) * 1_000)
);
// deck
assert_eq!(s(ctx, "deck:default"), ("(c.did in (1))".into(), vec![],));
assert_eq!(s(ctx, "deck:current"), ("(c.did in (1))".into(), vec![],));
assert_eq!(s(ctx, "deck:missing"), ("(c.did in ())".into(), vec![],));
assert_eq!(s(ctx, "deck:d*"), ("(c.did in (1))".into(), vec![],));
assert_eq!(s(ctx, "deck:filtered"), ("(c.odid > 0)".into(), vec![],));
// card
assert_eq!(s(ctx, "card:front"), ("(false)".into(), vec![],));
assert_eq!(
s(ctx, r#""card:card 1""#),
(
concat!(
"(((n.mid = 1581236385344 and c.ord = 0) or ",
"(n.mid = 1581236385345 and c.ord = 0) or ",
"(n.mid = 1581236385346 and c.ord = 0) or ",
"(n.mid = 1581236385347 and c.ord = 0)))"
)
.into(),
vec![],
)
);
// IDs
assert_eq!(s(ctx, "mid:3"), ("(n.mid = 3)".into(), vec![]));
assert_eq!(s(ctx, "nid:3"), ("(n.id in (3))".into(), vec![]));
assert_eq!(s(ctx, "nid:3,4"), ("(n.id in (3,4))".into(), vec![]));
assert_eq!(s(ctx, "cid:3,4"), ("(c.id in (3,4))".into(), vec![]));
// flags
assert_eq!(s(ctx, "flag:2"), ("((c.flags & 7) == 2)".into(), vec![]));
assert_eq!(s(ctx, "flag:0"), ("((c.flags & 7) == 0)".into(), vec![]));
// dupes
assert_eq!(
s(ctx, "dupes:123,test"),
(
"((n.mid = 123 and n.csum = 2840236005 and field_at_index(n.flds, 0) = ?)"
.into(),
vec!["test".into()]
)
);
// tags
assert_eq!(
s(ctx, "tag:one"),
("(n.tags like ? escape '\\')".into(), vec!["% one %".into()])
);
assert_eq!(
s(ctx, "tag:o*e"),
("(n.tags like ? escape '\\')".into(), vec!["% o%e %".into()])
);
assert_eq!(s(ctx, "tag:none"), ("(n.tags = '')".into(), vec![]));
assert_eq!(s(ctx, "tag:*"), ("(true)".into(), vec![]));
// state
assert_eq!(
s(ctx, "is:suspended").0,
format!("(c.queue = {})", CardQueue::Suspended as i8)
);
assert_eq!(
s(ctx, "is:new").0,
format!("(c.type = {})", CardQueue::New as i8)
);
// rated
assert_eq!(
s(ctx, "rated:2").0,
format!(
"(c.id in (select cid from revlog where id>{}))",
(timing.next_day_at - (86_400 * 2)) * 1_000
)
);
assert_eq!(
s(ctx, "rated:400:1").0,
format!(
"(c.id in (select cid from revlog where id>{} and ease=1))",
(timing.next_day_at - (86_400 * 365)) * 1_000
)
);
// props
assert_eq!(s(ctx, "prop:lapses=3").0, "(lapses = 3)".to_string());
assert_eq!(s(ctx, "prop:ease>=2.5").0, "(factor >= 2500)".to_string());
assert_eq!(
s(ctx, "prop:due!=-1").0,
format!(
"((c.queue in (2,3) and due != {}))",
timing.days_elapsed - 1
)
);
// note types by name
assert_eq!(&s(ctx, "note:basic").0, "(n.mid in (1581236385347))");
assert_eq!(
&s(ctx, "note:basic*").0,
"(n.mid in (1581236385345,1581236385346,1581236385347,1581236385344))"
);
// regex
assert_eq!(
s(ctx, r"re:\bone"),
("(n.flds regexp ?)".into(), vec![r"(?i)\bone".into()])
);
Ok(())
})
.unwrap();
Ok(())
}
}

3
rslib/src/storage/mod.rs Normal file
View file

@ -0,0 +1,3 @@
mod sqlite;
pub(crate) use sqlite::{SqliteStorage, StorageContext};

View file

@ -0,0 +1,88 @@
create table col
(
id integer primary key,
crt integer not null,
mod integer not null,
scm integer not null,
ver integer not null,
dty integer not null,
usn integer not null,
ls integer not null,
conf text not null,
models text not null,
decks text not null,
dconf text not null,
tags text not null
);
create table notes
(
id integer primary key,
guid text not null,
mid integer not null,
mod integer not null,
usn integer not null,
tags text not null,
flds text not null,
sfld integer not null,
csum integer not null,
flags integer not null,
data text not null
);
create table cards
(
id integer primary key,
nid integer not null,
did integer not null,
ord integer not null,
mod integer not null,
usn integer not null,
type integer not null,
queue integer not null,
due integer not null,
ivl integer not null,
factor integer not null,
reps integer not null,
lapses integer not null,
left integer not null,
odue integer not null,
odid integer not null,
flags integer not null,
data text not null
);
create table revlog
(
id integer primary key,
cid integer not null,
usn integer not null,
ease integer not null,
ivl integer not null,
lastIvl integer not null,
factor integer not null,
time integer not null,
type integer not null
);
create table graves
(
usn integer not null,
oid integer not null,
type integer not null
);
-- syncing
create index ix_notes_usn on notes (usn);
create index ix_cards_usn on cards (usn);
create index ix_revlog_usn on revlog (usn);
-- card spacing, etc
create index ix_cards_nid on cards (nid);
-- scheduling and deck limiting
create index ix_cards_sched on cards (did, queue, due);
-- revlog by card
create index ix_revlog_cid on revlog (cid);
-- field uniqueness
create index ix_notes_csum on notes (csum);
insert into col values (1,0,0,0,0,0,0,0,'{}','{}','{}','{}','{}');

335
rslib/src/storage/sqlite.rs Normal file
View file

@ -0,0 +1,335 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::collection::CollectionOp;
use crate::config::Config;
use crate::err::Result;
use crate::err::{AnkiError, DBErrorKind};
use crate::time::{i64_unix_millis, i64_unix_secs};
use crate::{
decks::Deck,
notetypes::NoteType,
sched::cutoff::{sched_timing_today, SchedTimingToday},
text::without_combining,
types::{ObjID, Usn},
};
use regex::Regex;
use rusqlite::{params, Connection, NO_PARAMS};
use std::cmp::Ordering;
use std::{
borrow::Cow,
collections::HashMap,
path::{Path, PathBuf},
};
use unicase::UniCase;
const SCHEMA_MIN_VERSION: u8 = 11;
const SCHEMA_MAX_VERSION: u8 = 11;
fn unicase_compare(s1: &str, s2: &str) -> Ordering {
UniCase::new(s1).cmp(&UniCase::new(s2))
}
// currently public for dbproxy
#[derive(Debug)]
pub struct SqliteStorage {
// currently crate-visible for dbproxy
pub(crate) db: Connection,
// fixme: stored in wrong location?
path: PathBuf,
}
fn open_or_create_collection_db(path: &Path) -> Result<Connection> {
let mut db = Connection::open(path)?;
if std::env::var("TRACESQL").is_ok() {
db.trace(Some(trace));
}
db.busy_timeout(std::time::Duration::from_secs(0))?;
db.pragma_update(None, "locking_mode", &"exclusive")?;
db.pragma_update(None, "page_size", &4096)?;
db.pragma_update(None, "cache_size", &(-40 * 1024))?;
db.pragma_update(None, "legacy_file_format", &false)?;
db.pragma_update(None, "journal_mode", &"wal")?;
db.pragma_update(None, "temp_store", &"memory")?;
db.set_prepared_statement_cache_capacity(50);
add_field_index_function(&db)?;
add_regexp_function(&db)?;
add_without_combining_function(&db)?;
db.create_collation("unicase", unicase_compare)?;
Ok(db)
}
/// Adds sql function field_at_index(flds, index)
/// to split provided fields and return field at zero-based index.
/// If out of range, returns empty string.
fn add_field_index_function(db: &Connection) -> rusqlite::Result<()> {
db.create_scalar_function("field_at_index", 2, true, |ctx| {
let mut fields = ctx.get_raw(0).as_str()?.split('\x1f');
let idx: u16 = ctx.get(1)?;
Ok(fields.nth(idx as usize).unwrap_or("").to_string())
})
}
fn add_without_combining_function(db: &Connection) -> rusqlite::Result<()> {
db.create_scalar_function("without_combining", 1, true, |ctx| {
let text = ctx.get_raw(0).as_str()?;
Ok(match without_combining(text) {
Cow::Borrowed(_) => None,
Cow::Owned(o) => Some(o),
})
})
}
/// Adds sql function regexp(regex, string) -> is_match
/// Taken from the rusqlite docs
fn add_regexp_function(db: &Connection) -> rusqlite::Result<()> {
db.create_scalar_function("regexp", 2, true, move |ctx| {
assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
let saved_re: Option<&Regex> = ctx.get_aux(0)?;
let new_re = match saved_re {
None => {
let s = ctx.get::<String>(0)?;
match Regex::new(&s) {
Ok(r) => Some(r),
Err(err) => return Err(rusqlite::Error::UserFunctionError(Box::new(err))),
}
}
Some(_) => None,
};
let is_match = {
let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap());
let text = ctx
.get_raw(1)
.as_str()
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
re.is_match(text)
};
if let Some(re) = new_re {
ctx.set_aux(0, re);
}
Ok(is_match)
})
}
/// Fetch schema version from database.
/// Return (must_create, version)
fn schema_version(db: &Connection) -> Result<(bool, u8)> {
if !db
.prepare("select null from sqlite_master where type = 'table' and name = 'col'")?
.exists(NO_PARAMS)?
{
return Ok((true, SCHEMA_MAX_VERSION));
}
Ok((
false,
db.query_row("select ver from col", NO_PARAMS, |r| Ok(r.get(0)?))?,
))
}
fn trace(s: &str) {
println!("sql: {}", s)
}
impl SqliteStorage {
pub(crate) fn open_or_create(path: &Path) -> Result<Self> {
let db = open_or_create_collection_db(path)?;
let (create, ver) = schema_version(&db)?;
if create {
db.prepare_cached("begin exclusive")?.execute(NO_PARAMS)?;
db.execute_batch(include_str!("schema11.sql"))?;
db.execute("update col set crt=?, ver=?", params![i64_unix_secs(), ver])?;
db.prepare_cached("commit")?.execute(NO_PARAMS)?;
} else {
if ver > SCHEMA_MAX_VERSION {
return Err(AnkiError::DBError {
info: "".to_string(),
kind: DBErrorKind::FileTooNew,
});
}
if ver < SCHEMA_MIN_VERSION {
return Err(AnkiError::DBError {
info: "".to_string(),
kind: DBErrorKind::FileTooOld,
});
}
};
let storage = Self {
db,
path: path.to_owned(),
};
Ok(storage)
}
pub(crate) fn context(&self, server: bool) -> StorageContext {
StorageContext::new(&self.db, server)
}
}
pub(crate) struct StorageContext<'a> {
pub(crate) db: &'a Connection,
#[allow(dead_code)]
server: bool,
#[allow(dead_code)]
usn: Option<Usn>,
timing_today: Option<SchedTimingToday>,
}
impl StorageContext<'_> {
fn new(db: &Connection, server: bool) -> StorageContext {
StorageContext {
db,
server,
usn: None,
timing_today: None,
}
}
// Standard transaction start/stop
//////////////////////////////////////
pub(crate) fn begin_trx(&self) -> Result<()> {
self.db
.prepare_cached("begin exclusive")?
.execute(NO_PARAMS)?;
Ok(())
}
pub(crate) fn commit_trx(&self) -> Result<()> {
if !self.db.is_autocommit() {
self.db.prepare_cached("commit")?.execute(NO_PARAMS)?;
}
Ok(())
}
pub(crate) fn rollback_trx(&self) -> Result<()> {
if !self.db.is_autocommit() {
self.db.execute("rollback", NO_PARAMS)?;
}
Ok(())
}
// Savepoints
//////////////////////////////////////////
//
// This is necessary at the moment because Anki's current architecture uses
// long-running transactions as an undo mechanism. Once a proper undo
// mechanism has been added to all existing functionality, we could
// transition these to standard commits.
pub(crate) fn begin_rust_trx(&self) -> Result<()> {
self.db
.prepare_cached("savepoint rust")?
.execute(NO_PARAMS)?;
Ok(())
}
pub(crate) fn commit_rust_trx(&self) -> Result<()> {
self.db.prepare_cached("release rust")?.execute(NO_PARAMS)?;
Ok(())
}
pub(crate) fn commit_rust_op(&self, _op: Option<CollectionOp>) -> Result<()> {
self.commit_rust_trx()
}
pub(crate) fn rollback_rust_trx(&self) -> Result<()> {
self.db
.prepare_cached("rollback to rust")?
.execute(NO_PARAMS)?;
Ok(())
}
//////////////////////////////////////////
pub(crate) fn mark_modified(&self) -> Result<()> {
self.db
.prepare_cached("update col set mod=?")?
.execute(params![i64_unix_millis()])?;
Ok(())
}
#[allow(dead_code)]
pub(crate) fn usn(&mut self) -> Result<Usn> {
if self.server {
if self.usn.is_none() {
self.usn = Some(
self.db
.prepare_cached("select usn from col")?
.query_row(NO_PARAMS, |row| row.get(0))?,
);
}
Ok(*self.usn.as_ref().unwrap())
} else {
Ok(-1)
}
}
pub(crate) fn all_decks(&self) -> Result<HashMap<ObjID, Deck>> {
self.db
.query_row_and_then("select decks from col", NO_PARAMS, |row| -> Result<_> {
Ok(serde_json::from_str(row.get_raw(0).as_str()?)?)
})
}
pub(crate) fn all_config(&self) -> Result<Config> {
self.db
.query_row_and_then("select conf from col", NO_PARAMS, |row| -> Result<_> {
Ok(serde_json::from_str(row.get_raw(0).as_str()?)?)
})
}
pub(crate) fn all_note_types(&self) -> Result<HashMap<ObjID, NoteType>> {
let mut stmt = self.db.prepare("select models from col")?;
let note_types = stmt
.query_and_then(NO_PARAMS, |row| -> Result<HashMap<ObjID, NoteType>> {
let v: HashMap<ObjID, NoteType> = serde_json::from_str(row.get_raw(0).as_str()?)?;
Ok(v)
})?
.next()
.ok_or_else(|| AnkiError::DBError {
info: "col table empty".to_string(),
kind: DBErrorKind::MissingEntity,
})??;
Ok(note_types)
}
#[allow(dead_code)]
pub(crate) fn timing_today(&mut self) -> Result<SchedTimingToday> {
if self.timing_today.is_none() {
let crt: i64 = self
.db
.prepare_cached("select crt from col")?
.query_row(NO_PARAMS, |row| row.get(0))?;
let conf = self.all_config()?;
let now_offset = if self.server { conf.local_offset } else { None };
self.timing_today = Some(sched_timing_today(
crt,
i64_unix_secs(),
conf.creation_offset,
now_offset,
conf.rollover,
));
}
Ok(*self.timing_today.as_ref().unwrap())
}
}

View file

@ -5,7 +5,10 @@ use lazy_static::lazy_static;
use regex::{Captures, Regex};
use std::borrow::Cow;
use std::ptr;
use unicode_normalization::{is_nfc, UnicodeNormalization};
use unicase::eq as uni_eq;
use unicode_normalization::{
char::is_combining_mark, is_nfc, is_nfkd_quick, IsNormalized, UnicodeNormalization,
};
#[derive(Debug, PartialEq)]
pub enum AVTag {
@ -219,11 +222,43 @@ pub(crate) fn normalize_to_nfc(s: &str) -> Cow<str> {
}
}
/// True if search is equal to text, folding case.
/// Supports '*' to match 0 or more characters.
pub(crate) fn matches_wildcard(text: &str, search: &str) -> bool {
if search.contains('*') {
let search = format!("^(?i){}$", regex::escape(search).replace(r"\*", ".*"));
Regex::new(&search).unwrap().is_match(text)
} else {
uni_eq(text, search)
}
}
/// Convert provided string to NFKD form and strip combining characters.
pub(crate) fn without_combining(s: &str) -> Cow<str> {
// if the string is already normalized
if matches!(is_nfkd_quick(s.chars()), IsNormalized::Yes) {
// and no combining characters found, return unchanged
if !s.chars().any(is_combining_mark) {
return s.into();
}
}
// we need to create a new string without the combining marks
s.chars()
.nfkd()
.filter(|c| !is_combining_mark(*c))
.collect::<String>()
.into()
}
#[cfg(test)]
mod test {
use super::matches_wildcard;
use crate::text::without_combining;
use crate::text::{
extract_av_tags, strip_av_tags, strip_html, strip_html_preserving_image_filenames, AVTag,
};
use std::borrow::Cow;
#[test]
fn stripping() {
@ -265,4 +300,19 @@ mod test {
]
);
}
#[test]
fn wildcard() {
assert_eq!(matches_wildcard("foo", "bar"), false);
assert_eq!(matches_wildcard("foo", "Foo"), true);
assert_eq!(matches_wildcard("foo", "F*"), true);
assert_eq!(matches_wildcard("foo", "F*oo"), true);
assert_eq!(matches_wildcard("foo", "b*"), false);
}
#[test]
fn combining() {
assert!(matches!(without_combining("test"), Cow::Borrowed(_)));
assert!(matches!(without_combining("Über"), Cow::Owned(_)));
}
}

View file

@ -4,15 +4,35 @@
use std::time;
pub(crate) fn i64_unix_secs() -> i64 {
time::SystemTime::now()
.duration_since(time::SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs() as i64
elapsed().as_secs() as i64
}
pub(crate) fn i64_unix_millis() -> i64 {
elapsed().as_millis() as i64
}
#[cfg(not(test))]
fn elapsed() -> time::Duration {
time::SystemTime::now()
.duration_since(time::SystemTime::UNIX_EPOCH)
.unwrap()
.as_millis() as i64
}
// when running in CI, shift the current time away from the cutoff point
// to accomodate unit tests that depend on the current time
#[cfg(test)]
fn elapsed() -> time::Duration {
use chrono::{Local, Timelike};
let now = Local::now();
let mut elap = time::SystemTime::now()
.duration_since(time::SystemTime::UNIX_EPOCH)
.unwrap();
if now.hour() >= 2 && now.hour() < 4 {
elap -= time::Duration::from_secs(60 * 60 * 2);
}
elap
}

View file

@ -1,6 +1,6 @@
[package]
name = "ankirspy"
version = "2.1.22" # automatically updated
version = "2.1.24" # automatically updated
edition = "2018"
authors = ["Ankitects Pty Ltd and contributors"]

View file

@ -1,12 +1,11 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use anki::backend::{
init_backend, init_i18n_backend, Backend as RustBackend, I18nBackend as RustI18nBackend,
};
use anki::backend::{init_backend, Backend as RustBackend};
use pyo3::exceptions::Exception;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3::{exceptions, wrap_pyfunction};
use pyo3::{create_exception, exceptions, wrap_pyfunction};
// Regular backend
//////////////////////////////////
@ -16,6 +15,8 @@ struct Backend {
backend: RustBackend,
}
create_exception!(ankirspy, DBError, Exception);
#[pyfunction]
fn buildhash() -> &'static str {
include_str!("../../meta/buildhash").trim()
@ -70,29 +71,17 @@ impl Backend {
self.backend.set_progress_callback(Some(Box::new(func)));
}
}
}
// I18n backend
//////////////////////////////////
#[pyclass]
struct I18nBackend {
backend: RustI18nBackend,
}
#[pyfunction]
fn open_i18n(init_msg: &PyBytes) -> PyResult<I18nBackend> {
match init_i18n_backend(init_msg.as_bytes()) {
Ok(backend) => Ok(I18nBackend { backend }),
Err(e) => Err(exceptions::Exception::py_err(format!("{:?}", e))),
}
}
#[pymethods]
impl I18nBackend {
fn translate(&self, input: &PyBytes) -> String {
fn db_command(&mut self, py: Python, input: &PyBytes) -> PyResult<PyObject> {
let in_bytes = input.as_bytes();
self.backend.translate(in_bytes)
let out_res = py.allow_threads(move || {
self.backend
.db_command(in_bytes)
.map_err(|e| DBError::py_err(e.localized_description(&self.backend.i18n())))
});
let out_string = out_res?;
let out_obj = PyBytes::new(py, out_string.as_bytes());
Ok(out_obj.into())
}
}
@ -104,7 +93,6 @@ fn ankirspy(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Backend>()?;
m.add_wrapped(wrap_pyfunction!(buildhash)).unwrap();
m.add_wrapped(wrap_pyfunction!(open_backend)).unwrap();
m.add_wrapped(wrap_pyfunction!(open_i18n)).unwrap();
Ok(())
}