Merge remote-tracking branch 'danielelmes/master' into fix_tests_on_windows

# Conflicts:
#	.github/scripts/trailing-newlines.sh
evandrocoan 2020-03-23 18:44:11 -03:00
commit b1b3e5b87c
81 changed files with 3897 additions and 1760 deletions

.gitattributes vendored Normal file

@@ -0,0 +1 @@
+*.ftl eol=lf

.github/scripts/trailing-newlines.sh

@@ -6,7 +6,8 @@ set -eo pipefail
 # Because `set -e` does not work inside the subshell $()
 rg --version > /dev/null 2>&1
-files=$(rg -l '[^\n]\z' -g '!*.{png,svg,scss,json}' || true)
+files=$(rg -l '[^\n]\z' -g '!*.{png,svg,scss,json,sql}' || true)
 if [ "$files" != "" ]; then
     echo "the following files are missing a newline on the last line:"
     echo $files

meta/version

@@ -1 +1 @@
-2.1.22
+2.1.24

rslib/backend.proto

@@ -6,13 +6,14 @@ package backend_proto;
 message Empty {}

+message OptionalInt32 {
+    sint32 val = 1;
+}
+
 message BackendInit {
-    string collection_path = 1;
-    string media_folder_path = 2;
-    string media_db_path = 3;
-    repeated string preferred_langs = 4;
-    string locale_folder_path = 5;
-    string log_path = 6;
+    repeated string preferred_langs = 1;
+    string locale_folder_path = 2;
+    bool server = 3;
 }

 message I18nBackendInit {
@@ -27,8 +28,8 @@ message BackendInput {
         TemplateRequirementsIn template_requirements = 16;
         SchedTimingTodayIn sched_timing_today = 17;
         Empty deck_tree = 18;
-        FindCardsIn find_cards = 19;
-        BrowserRowsIn browser_rows = 20;
+        SearchCardsIn search_cards = 19;
+        SearchNotesIn search_notes = 20;
         RenderCardIn render_card = 21;
         int64 local_minutes_west = 22;
         string strip_av_tags = 23;
@@ -44,6 +45,8 @@ message BackendInput {
         CongratsLearnMsgIn congrats_learn_msg = 33;
         Empty empty_trash = 34;
         Empty restore_trash = 35;
+        OpenCollectionIn open_collection = 36;
+        Empty close_collection = 37;
     }
 }
@@ -63,8 +66,8 @@ message BackendOutput {
         // fallible commands
         TemplateRequirementsOut template_requirements = 16;
         DeckTreeOut deck_tree = 18;
-        FindCardsOut find_cards = 19;
-        BrowserRowsOut browser_rows = 20;
+        SearchCardsOut search_cards = 19;
+        SearchNotesOut search_notes = 20;
         RenderCardOut render_card = 21;
         string add_media_file = 26;
         Empty sync_media = 27;
@@ -72,6 +75,8 @@ message BackendOutput {
         Empty trash_media_files = 29;
         Empty empty_trash = 34;
         Empty restore_trash = 35;
+        Empty open_collection = 36;
+        Empty close_collection = 37;

         BackendError error = 2047;
     }
@@ -162,10 +167,10 @@ message TemplateRequirementAny {
 message SchedTimingTodayIn {
     int64 created_secs = 1;
-    sint32 created_mins_west = 2;
-    int64 now_secs = 3;
-    sint32 now_mins_west = 4;
-    sint32 rollover_hour = 5;
+    int64 now_secs = 2;
+    OptionalInt32 created_mins_west = 3;
+    OptionalInt32 now_mins_west = 4;
+    OptionalInt32 rollover_hour = 5;
 }

 message SchedTimingTodayOut {
@@ -188,23 +193,6 @@ message DeckTreeNode {
     bool collapsed = 7;
 }

-message FindCardsIn {
-    string search = 1;
-}
-
-message FindCardsOut {
-    repeated int64 card_ids = 1;
-}
-
-message BrowserRowsIn {
-    repeated int64 card_ids = 1;
-}
-
-message BrowserRowsOut {
-    // just sort fields for proof of concept
-    repeated string sort_fields = 1;
-}
-
 message RenderCardIn {
     string question_template = 1;
     string answer_template = 2;
@@ -324,3 +312,58 @@ message CongratsLearnMsgIn {
     float next_due = 1;
     uint32 remaining = 2;
 }
+
+message OpenCollectionIn {
+    string collection_path = 1;
+    string media_folder_path = 2;
+    string media_db_path = 3;
+    string log_path = 4;
+}
+
+message SearchCardsIn {
+    string search = 1;
+    SortOrder order = 2;
+}
+
+message SearchCardsOut {
+    repeated int64 card_ids = 1;
+}
+
+message SortOrder {
+    oneof value {
+        Empty from_config = 1;
+        Empty none = 2;
+        string custom = 3;
+        BuiltinSearchOrder builtin = 4;
+    }
+}
+
+message SearchNotesIn {
+    string search = 1;
+}
+
+message SearchNotesOut {
+    repeated int64 note_ids = 2;
+}
+
+message BuiltinSearchOrder {
+    BuiltinSortKind kind = 1;
+    bool reverse = 2;
+}
+
+enum BuiltinSortKind {
+    NOTE_CREATION = 0;
+    NOTE_MOD = 1;
+    NOTE_FIELD = 2;
+    NOTE_TAGS = 3;
+    NOTE_TYPE = 4;
+    CARD_MOD = 5;
+    CARD_REPS = 6;
+    CARD_DUE = 7;
+    CARD_EASE = 8;
+    CARD_LAPSES = 9;
+    CARD_INTERVAL = 10;
+    CARD_DECK = 11;
+    CARD_TEMPLATE = 12;
+}
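For orientation, a minimal sketch (not part of the diff) of how a caller might fill in the new SortOrder oneof from Python, via the generated anki.backend_pb2 module that appears later in this commit; the search string and sort kind are made up:

    import anki.backend_pb2 as pb

    # sort by a builtin key, reversed
    order = pb.SortOrder(
        builtin=pb.BuiltinSearchOrder(kind=pb.BuiltinSortKind.CARD_DUE, reverse=True)
    )
    # other variants: pb.SortOrder(from_config=pb.Empty()),
    # pb.SortOrder(none=pb.Empty()), or pb.SortOrder(custom="c.ivl asc")
    msg = pb.BackendInput(
        search_cards=pb.SearchCardsIn(search="deck:current", order=order)
    )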

pylib/anki/cards.py

@@ -124,30 +124,6 @@ insert or replace into cards values
         )
         self.col.log(self)

-    def flushSched(self) -> None:
-        self._preFlush()
-        # bug checks
-        self.col.db.execute(
-            """update cards set
-mod=?, usn=?, type=?, queue=?, due=?, ivl=?, factor=?, reps=?,
-lapses=?, left=?, odue=?, odid=?, did=? where id = ?""",
-            self.mod,
-            self.usn,
-            self.type,
-            self.queue,
-            self.due,
-            self.ivl,
-            self.factor,
-            self.reps,
-            self.lapses,
-            self.left,
-            self.odue,
-            self.odid,
-            self.did,
-            self.id,
-        )
-        self.col.log(self)
-
     def question(self, reload: bool = False, browser: bool = False) -> str:
         return self.css() + self.render_output(reload, browser).question_text

@@ -181,6 +157,8 @@ lapses=?, left=?, odue=?, odid=?, did=? where id = ?""",
     def note_type(self) -> NoteType:
         return self.col.models.get(self.note().mid)

+    # legacy aliases
+    flushSched = flush
     q = question
     a = answer
     model = note_type

pylib/anki/collection.py

@@ -15,7 +15,7 @@ import time
 import traceback
 import unicodedata
 import weakref
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

 import anki.find
 import anki.latex  # sets up hook
@@ -23,7 +23,7 @@ import anki.template
 from anki import hooks
 from anki.cards import Card
 from anki.consts import *
-from anki.db import DB
+from anki.dbproxy import DBProxy
 from anki.decks import DeckManager
 from anki.errors import AnkiError
 from anki.lang import _, ngettext
@@ -67,7 +67,7 @@ defaultConf = {
 # this is initialized by storage.Collection
 class _Collection:
-    db: Optional[DB]
+    db: Optional[DBProxy]
     sched: Union[V1Scheduler, V2Scheduler]
     crt: int
     mod: int
@@ -80,13 +80,12 @@ class _Collection:
     def __init__(
         self,
-        db: DB,
+        db: DBProxy,
         backend: RustBackend,
         server: Optional["anki.storage.ServerData"] = None,
-        log: bool = False,
     ) -> None:
         self.backend = backend
-        self._debugLog = log
+        self._debugLog = not server
         self.db = db
         self.path = db._path
         self._openLog()
@@ -139,10 +138,6 @@ class _Collection:
             self.sched = V1Scheduler(self)
         elif ver == 2:
             self.sched = V2Scheduler(self)
-            if not self.server:
-                self.conf["localOffset"] = self.sched._current_timezone_offset()
-            elif self.server.minutes_west is not None:
-                self.conf["localOffset"] = self.server.minutes_west

     def changeSchedulerVer(self, ver: int) -> None:
         if ver == self.schedVer():
@@ -165,12 +160,13 @@ class _Collection:
         self._loadScheduler()

+    # the sync code uses this to send the local timezone to AnkiWeb
     def localOffset(self) -> Optional[int]:
         "Minutes west of UTC. Only applies to V2 scheduler."
         if isinstance(self.sched, V1Scheduler):
             return None
         else:
-            return self.sched._current_timezone_offset()
+            return self.backend.local_minutes_west(intTime())

     # DB-related
     ##########################################################################
@@ -220,8 +216,10 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
             json.dumps(self.conf),
         )

-    def save(self, name: Optional[str] = None, mod: Optional[int] = None) -> None:
-        "Flush, commit DB, and take out another write lock."
+    def save(
+        self, name: Optional[str] = None, mod: Optional[int] = None, trx: bool = True
+    ) -> None:
+        "Flush, commit DB, and take out another write lock if trx=True."
         # let the managers conditionally flush
         self.models.flush()
         self.decks.flush()
@@ -230,8 +228,14 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
         if self.db.mod:
             self.flush(mod=mod)
             self.db.commit()
-            self.lock()
             self.db.mod = False
+            if trx:
+                self.db.begin()
+        elif not trx:
+            # if no changes were pending but calling code expects to be
+            # outside of a transaction, we need to roll back
+            self.db.rollback()

         self._markOp(name)
         self._lastSave = time.time()
@@ -242,39 +246,24 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
             return True
         return None

-    def lock(self) -> None:
-        # make sure we don't accidentally bump mod time
-        mod = self.db.mod
-        self.db.execute("update col set mod=mod")
-        self.db.mod = mod
-
     def close(self, save: bool = True) -> None:
         "Disconnect from DB."
         if self.db:
             if save:
-                self.save()
+                self.save(trx=False)
             else:
                 self.db.rollback()
             if not self.server:
-                self.db.setAutocommit(True)
                 self.db.execute("pragma journal_mode = delete")
-                self.db.setAutocommit(False)
-            self.db.close()
+            self.backend.close_collection()
             self.db = None
             self.media.close()
             self._closeLog()

-    def reopen(self) -> None:
-        "Reconnect to DB (after changing threads, etc)."
-        if not self.db:
-            self.db = DB(self.path)
-            self.media.connect()
-            self._openLog()
-
     def rollback(self) -> None:
         self.db.rollback()
+        self.db.begin()
         self.load()
-        self.lock()

     def modSchema(self, check: bool) -> None:
         "Mark schema modified. Call this first so user can abort if necessary."
@@ -305,10 +294,10 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
         self.modSchema(check=False)
         self.ls = self.scm
         # ensure db is compacted before upload
-        self.db.setAutocommit(True)
+        self.save(trx=False)
         self.db.execute("vacuum")
         self.db.execute("analyze")
-        self.close()
+        self.close(save=False)

     # Object creation helpers
     ##########################################################################
@@ -626,11 +615,25 @@ where c.nid = n.id and c.id in %s group by nid"""
     # Finding cards
     ##########################################################################

-    def findCards(self, query: str, order: Union[bool, str] = False) -> Any:
-        return anki.find.Finder(self).findCards(query, order)
+    # if order=True, use the sort order stored in the collection config
+    # if order=False, do no ordering
+    #
+    # if order is a string, that text is added after 'order by' in the sql statement.
+    # you must add ' asc' or ' desc' to the order, as Anki will replace asc with
+    # desc and vice versa when reverse is set in the collection config, eg
+    # order="c.ivl asc, c.due desc"
+    #
+    # if order is an int enum, sort using that builtin sort, eg
+    # col.find_cards("", order=BuiltinSortKind.CARD_DUE)
+    # the reverse argument only applies when a BuiltinSortKind is provided;
+    # otherwise the collection config defines whether reverse is set or not
+    def find_cards(
+        self, query: str, order: Union[bool, str, int] = False, reverse: bool = False,
+    ) -> Sequence[int]:
+        return self.backend.search_cards(query, order, reverse)

-    def findNotes(self, query: str) -> Any:
-        return anki.find.Finder(self).findNotes(query)
+    def find_notes(self, query: str) -> Sequence[int]:
+        return self.backend.search_notes(query)

     def findReplace(
         self,
@@ -646,6 +649,9 @@ where c.nid = n.id and c.id in %s group by nid"""
     def findDupes(self, fieldName: str, search: str = "") -> List[Tuple[Any, list]]:
         return anki.find.findDupes(self, fieldName, search)

+    findCards = find_cards
+    findNotes = find_notes
+
     # Stats
     ##########################################################################
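An illustrative use of the new API documented by the comment above (queries are made up; `col` is an open collection, and BuiltinSortKind is re-exported from anki.rsbackend later in this diff):

    from anki.rsbackend import BuiltinSortKind

    cids = col.find_cards("deck:current", order=True)      # order from config
    cids = col.find_cards("tag:leech", order="c.due asc")  # custom SQL order
    cids = col.find_cards("", order=BuiltinSortKind.CARD_DUE, reverse=True)
    nids = col.find_notes("added:7")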
@@ -793,7 +799,6 @@ select id from notes where mid = ?) limit 1"""
         problems = []
         # problems that don't require a full sync
         syncable_problems = []
-        curs = self.db.cursor()
         self.save()
         oldSize = os.stat(self.path)[stat.ST_SIZE]
         if self.db.scalar("pragma integrity_check") != "ok":
@@ -942,16 +947,18 @@ select id from cards where odid > 0 and did in %s"""
                 self.updateFieldCache(self.models.nids(m))
         # new cards can't have a due position > 32 bits, so wrap items over
         # 2 million back to 1 million
-        curs.execute(
+        self.db.execute(
             """
 update cards set due=1000000+due%1000000,mod=?,usn=? where due>=1000000
 and type=0""",
-            [intTime(), self.usn()],
+            intTime(),
+            self.usn(),
         )
-        if curs.rowcount:
+        rowcount = self.db.scalar("select changes()")
+        if rowcount:
             problems.append(
                 "Found %d new cards with a due number >= 1,000,000 - consider repositioning them in the Browse screen."
-                % curs.rowcount
+                % rowcount
             )
         # new card position
         self.conf["nextPos"] = (
@@ -969,18 +976,20 @@ and type=0""",
             self.usn(),
         )
         # v2 sched had a bug that could create decimal intervals
-        curs.execute(
+        self.db.execute(
             "update cards set ivl=round(ivl),due=round(due) where ivl!=round(ivl) or due!=round(due)"
         )
-        if curs.rowcount:
-            problems.append("Fixed %d cards with v2 scheduler bug." % curs.rowcount)
+        rowcount = self.db.scalar("select changes()")
+        if rowcount:
+            problems.append("Fixed %d cards with v2 scheduler bug." % rowcount)

-        curs.execute(
+        self.db.execute(
             "update revlog set ivl=round(ivl),lastIvl=round(lastIvl) where ivl!=round(ivl) or lastIvl!=round(lastIvl)"
         )
-        if curs.rowcount:
+        rowcount = self.db.scalar("select changes()")
+        if rowcount:
             problems.append(
-                "Fixed %d review history entries with v2 scheduler bug." % curs.rowcount
+                "Fixed %d review history entries with v2 scheduler bug." % rowcount
             )
         # models
         if self.models.ensureNotEmpty():
@@ -1011,11 +1020,10 @@ and type=0""",
         return len(to_fix)

     def optimize(self) -> None:
-        self.db.setAutocommit(True)
+        self.save(trx=False)
         self.db.execute("vacuum")
         self.db.execute("analyze")
-        self.db.setAutocommit(False)
-        self.lock()
+        self.db.begin()

     # Logging
     ##########################################################################

pylib/anki/db.py

@@ -1,6 +1,14 @@
 # Copyright: Ankitects Pty Ltd and contributors
 # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html

+"""
+A convenience wrapper over pysqlite.
+
+Anki's Collection class now uses dbproxy.py instead of this class,
+but this class is still used by aqt's profile manager, and a number
+of add-ons rely on it.
+"""
+
 import os
 import time
 from sqlite3 import Cursor

pylib/anki/dbproxy.py Normal file

@@ -0,0 +1,92 @@
+# Copyright: Ankitects Pty Ltd and contributors
+# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
+
+from __future__ import annotations
+
+from typing import Any, Iterable, List, Optional, Sequence, Union
+
+import anki
+
+# DBValue is actually Union[str, int, float, None], but if defined
+# that way, every call site needs to do a type check prior to using
+# the return values.
+ValueFromDB = Any
+Row = Sequence[ValueFromDB]
+
+ValueForDB = Union[str, int, float, None]
+
+
+class DBProxy:
+    # Lifecycle
+    ###############
+
+    def __init__(self, backend: anki.rsbackend.RustBackend, path: str) -> None:
+        self._backend = backend
+        self._path = path
+        self.mod = False
+
+    # Transactions
+    ###############
+
+    def begin(self) -> None:
+        self._backend.db_begin()
+
+    def commit(self) -> None:
+        self._backend.db_commit()
+
+    def rollback(self) -> None:
+        self._backend.db_rollback()
+
+    # Querying
+    ################
+
+    def _query(
+        self, sql: str, *args: ValueForDB, first_row_only: bool = False
+    ) -> List[Row]:
+        # mark modified?
+        s = sql.strip().lower()
+        for stmt in "insert", "update", "delete":
+            if s.startswith(stmt):
+                self.mod = True
+        assert ":" not in sql
+        # fetch rows
+        return self._backend.db_query(sql, args, first_row_only)
+
+    # Query shortcuts
+    ###################
+
+    def all(self, sql: str, *args: ValueForDB) -> List[Row]:
+        return self._query(sql, *args)
+
+    def list(self, sql: str, *args: ValueForDB) -> List[ValueFromDB]:
+        return [x[0] for x in self._query(sql, *args)]
+
+    def first(self, sql: str, *args: ValueForDB) -> Optional[Row]:
+        rows = self._query(sql, *args, first_row_only=True)
+        if rows:
+            return rows[0]
+        else:
+            return None
+
+    def scalar(self, sql: str, *args: ValueForDB) -> ValueFromDB:
+        rows = self._query(sql, *args, first_row_only=True)
+        if rows:
+            return rows[0][0]
+        else:
+            return None
+
+    # execute used to return a pysqlite cursor, but now is synonymous
+    # with .all()
+    execute = all
+
+    # Updates
+    ################
+
+    def executemany(self, sql: str, args: Iterable[Sequence[ValueForDB]]) -> None:
+        self.mod = True
+        assert ":" not in sql
+        if isinstance(args, list):
+            list_args = args
+        else:
+            list_args = list(args)
+        self._backend.db_execute_many(sql, list_args)
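A quick sketch of the resulting call shapes (not part of the diff), assuming an open collection `col`; note the proxy only accepts positional "?" placeholders, since _query() and executemany() assert that ":" never appears in the SQL:

    ids = col.db.list("select id from cards where did = ?", 1)
    count = col.db.scalar("select count() from notes")
    col.db.executemany(
        "update cards set mod=?, usn=? where id = ?",
        [(1585000000, -1, cid) for cid in ids],  # made-up mod time/usn
    )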

pylib/anki/exporting.py

@@ -397,20 +397,20 @@ class AnkiCollectionPackageExporter(AnkiPackageExporter):
         AnkiPackageExporter.__init__(self, col)

     def doExport(self, z, path):
-        # close our deck & write it into the zip file, and reopen
+        "Export collection. Caller must re-open afterwards."
+        # close our deck & write it into the zip file
         self.count = self.col.cardCount()
         v2 = self.col.schedVer() != 1
+        mdir = self.col.media.dir()
         self.col.close()
         if not v2:
             z.write(self.col.path, "collection.anki2")
         else:
             self._addDummyCollection(z)
             z.write(self.col.path, "collection.anki21")
-        self.col.reopen()
         # copy all media
         if not self.includeMedia:
             return {}
-        mdir = self.col.media.dir()
         return self._exportMedia(z, os.listdir(mdir), mdir)

pylib/anki/find.py

@@ -4,500 +4,25 @@
 from __future__ import annotations

 import re
-import sre_constants
-import unicodedata
-from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Set

-from anki import hooks
-from anki.consts import *
 from anki.hooks import *
-from anki.utils import (
-    fieldChecksum,
-    ids2str,
-    intTime,
-    joinFields,
-    splitFields,
-    stripHTMLMedia,
-)
+from anki.utils import ids2str, intTime, joinFields, splitFields, stripHTMLMedia

 if TYPE_CHECKING:
     from anki.collection import _Collection

-# Find
-##########################################################################

 class Finder:
     def __init__(self, col: Optional[_Collection]) -> None:
         self.col = col.weakref()
-        self.search = dict(
-            added=self._findAdded,
-            card=self._findTemplate,
-            deck=self._findDeck,
-            mid=self._findMid,
-            nid=self._findNids,
-            cid=self._findCids,
-            note=self._findModel,
-            prop=self._findProp,
-            rated=self._findRated,
-            tag=self._findTag,
-            dupe=self._findDupes,
-            flag=self._findFlag,
-        )
-        self.search["is"] = self._findCardState
-        hooks.search_terms_prepared(self.search)
+        print("Finder() is deprecated, please use col.find_cards() or .find_notes()")

-    def findCards(self, query: str, order: Union[bool, str] = False) -> List[Any]:
-        "Return a list of card ids for QUERY."
-        tokens = self._tokenize(query)
-        preds, args = self._where(tokens)
-        if preds is None:
-            raise Exception("invalidSearch")
-        order, rev = self._order(order)
-        sql = self._query(preds, order)
-        try:
-            res = self.col.db.list(sql, *args)
-        except:
-            # invalid grouping
-            return []
-        if rev:
-            res.reverse()
-        return res
+    def findCards(self, query, order):
+        return self.col.find_cards(query, order)

-    def findNotes(self, query: str) -> List[Any]:
-        tokens = self._tokenize(query)
-        preds, args = self._where(tokens)
-        if preds is None:
-            return []
-        if preds:
-            preds = "(" + preds + ")"
-        else:
-            preds = "1"
-        sql = (
-            """
-select distinct(n.id) from cards c, notes n where c.nid=n.id and """
-            + preds
-        )
-        try:
-            res = self.col.db.list(sql, *args)
-        except:
-            # invalid grouping
-            return []
-        return res
-
-    # Tokenizing
-    ######################################################################
-
-    def _tokenize(self, query: str) -> List[str]:
-        inQuote: Union[bool, str] = False
-        tokens = []
-        token = ""
-        for c in query:
-            # quoted text
-            if c in ("'", '"'):
-                if inQuote:
-                    if c == inQuote:
-                        inQuote = False
-                    else:
-                        token += c
-                elif token:
-                    # quotes are allowed to start directly after a :
-                    if token[-1] == ":":
-                        inQuote = c
-                    else:
-                        token += c
-                else:
-                    inQuote = c
-            # separator (space and ideographic space)
-            elif c in (" ", "\u3000"):
-                if inQuote:
-                    token += c
-                elif token:
-                    # space marks token finished
-                    tokens.append(token)
-                    token = ""
-            # nesting
-            elif c in ("(", ")"):
-                if inQuote:
-                    token += c
-                else:
-                    if c == ")" and token:
-                        tokens.append(token)
-                        token = ""
-                    tokens.append(c)
-            # negation
-            elif c == "-":
-                if token:
-                    token += c
-                elif not tokens or tokens[-1] != "-":
-                    tokens.append("-")
-            # normal character
-            else:
-                token += c
-        # if we finished in a token, add it
-        if token:
-            tokens.append(token)
-        return tokens
-
-    # Query building
-    ######################################################################
-
-    def _where(self, tokens: List[str]) -> Tuple[str, Optional[List[str]]]:
-        # state and query
-        s: Dict[str, Any] = dict(isnot=False, isor=False, join=False, q="", bad=False)
-        args: List[Any] = []
-
-        def add(txt, wrap=True):
-            # failed command?
-            if not txt:
-                # if it was to be negated then we can just ignore it
-                if s["isnot"]:
-                    s["isnot"] = False
-                    return None, None
-                else:
-                    s["bad"] = True
-                    return None, None
-            elif txt == "skip":
-                return None, None
-            # do we need a conjunction?
-            if s["join"]:
-                if s["isor"]:
-                    s["q"] += " or "
-                    s["isor"] = False
-                else:
-                    s["q"] += " and "
-            if s["isnot"]:
-                s["q"] += " not "
-                s["isnot"] = False
-            if wrap:
-                txt = "(" + txt + ")"
-            s["q"] += txt
-            s["join"] = True
-
-        for token in tokens:
-            if s["bad"]:
-                return None, None
-            # special tokens
-            if token == "-":
-                s["isnot"] = True
-            elif token.lower() == "or":
-                s["isor"] = True
-            elif token == "(":
-                add(token, wrap=False)
-                s["join"] = False
-            elif token == ")":
-                s["q"] += ")"
-            # commands
-            elif ":" in token:
-                cmd, val = token.split(":", 1)
-                cmd = cmd.lower()
-                if cmd in self.search:
-                    add(self.search[cmd]((val, args)))
-                else:
-                    add(self._findField(cmd, val))
-            # normal text search
-            else:
-                add(self._findText(token, args))
-        if s["bad"]:
-            return None, None
-        return s["q"], args
-
-    def _query(self, preds: str, order: str) -> str:
-        # can we skip the note table?
-        if "n." not in preds and "n." not in order:
-            sql = "select c.id from cards c where "
-        else:
-            sql = "select c.id from cards c, notes n where c.nid=n.id and "
-        # combine with preds
-        if preds:
-            sql += "(" + preds + ")"
-        else:
-            sql += "1"
-        # order
-        if order:
-            sql += " " + order
-        return sql
-
-    # Ordering
-    ######################################################################
-
-    def _order(self, order: Union[bool, str]) -> Tuple[str, bool]:
-        if not order:
-            return "", False
-        elif order is not True:
-            # custom order string provided
-            return " order by " + cast(str, order), False
-        # use deck default
-        type = self.col.conf["sortType"]
-        sort = None
-        if type.startswith("note"):
-            if type == "noteCrt":
-                sort = "n.id, c.ord"
-            elif type == "noteMod":
-                sort = "n.mod, c.ord"
-            elif type == "noteFld":
-                sort = "n.sfld collate nocase, c.ord"
-        elif type.startswith("card"):
-            if type == "cardMod":
-                sort = "c.mod"
-            elif type == "cardReps":
-                sort = "c.reps"
-            elif type == "cardDue":
-                sort = "c.type, c.due"
-            elif type == "cardEase":
-                sort = f"c.type == {CARD_TYPE_NEW}, c.factor"
-            elif type == "cardLapses":
-                sort = "c.lapses"
-            elif type == "cardIvl":
-                sort = "c.ivl"
-        if not sort:
-            # deck has invalid sort order; revert to noteCrt
-            sort = "n.id, c.ord"
-        return " order by " + sort, self.col.conf["sortBackwards"]
-
-    # Commands
-    ######################################################################
-
-    def _findTag(self, args: Tuple[str, List[Any]]) -> str:
-        (val, list_args) = args
-        if val == "none":
-            return 'n.tags = ""'
-        val = val.replace("*", "%")
-        if not val.startswith("%"):
-            val = "% " + val
-        if not val.endswith("%") or val.endswith("\\%"):
-            val += " %"
-        list_args.append(val)
-        return "n.tags like ? escape '\\'"
-
-    def _findCardState(self, args: Tuple[str, List[Any]]) -> Optional[str]:
-        (val, __) = args
-        if val in ("review", "new", "learn"):
-            if val == "review":
-                n = 2
-            elif val == "new":
-                n = CARD_TYPE_NEW
-            else:
-                return f"queue in ({QUEUE_TYPE_LRN}, {QUEUE_TYPE_DAY_LEARN_RELEARN})"
-            return "type = %d" % n
-        elif val == "suspended":
-            return "c.queue = -1"
-        elif val == "buried":
-            return f"c.queue in ({QUEUE_TYPE_SIBLING_BURIED}, {QUEUE_TYPE_MANUALLY_BURIED})"
-        elif val == "due":
-            return f"""
-(c.queue in ({QUEUE_TYPE_REV},{QUEUE_TYPE_DAY_LEARN_RELEARN}) and c.due <= %d) or
-(c.queue = {QUEUE_TYPE_LRN} and c.due <= %d)""" % (
-                self.col.sched.today,
-                self.col.sched.dayCutoff,
-            )
-        else:
-            # unknown
-            return None
-
-    def _findFlag(self, args: Tuple[str, List[Any]]) -> Optional[str]:
-        (val, __) = args
-        if not val or len(val) != 1 or val not in "01234":
-            return None
-        mask = 2 ** 3 - 1
-        return "(c.flags & %d) == %d" % (mask, int(val))
-
-    def _findRated(self, args: Tuple[str, List[Any]]) -> Optional[str]:
-        # days(:optional_ease)
-        (val, __) = args
-        r = val.split(":")
-        try:
-            days = int(r[0])
-        except ValueError:
-            return None
-        days = min(days, 31)
-        # ease
-        ease = ""
-        if len(r) > 1:
-            if r[1] not in ("1", "2", "3", "4"):
-                return None
-            ease = "and ease=%s" % r[1]
-        cutoff = (self.col.sched.dayCutoff - 86400 * days) * 1000
-        return "c.id in (select cid from revlog where id>%d %s)" % (cutoff, ease)
-
-    def _findAdded(self, args: Tuple[str, List[Any]]) -> Optional[str]:
-        (val, __) = args
-        try:
-            days = int(val)
-        except ValueError:
-            return None
-        cutoff = (self.col.sched.dayCutoff - 86400 * days) * 1000
-        return "c.id > %d" % cutoff
-
-    def _findProp(self, args: Tuple[str, List[Any]]) -> Optional[str]:
-        # extract
-        (strval, __) = args
-        m = re.match("(^.+?)(<=|>=|!=|=|<|>)(.+?$)", strval)
-        if not m:
-            return None
-        prop, cmp, strval = m.groups()
-        prop = prop.lower()  # pytype: disable=attribute-error
-        # is val valid?
-        try:
-            if prop == "ease":
-                val = float(strval)
-            else:
-                val = int(strval)
-        except ValueError:
-            return None
-        # is prop valid?
-        if prop not in ("due", "ivl", "reps", "lapses", "ease"):
-            return None
-        # query
-        q = []
-        if prop == "due":
-            val += self.col.sched.today
-            # only valid for review/daily learning
-            q.append(f"(c.queue in ({QUEUE_TYPE_REV},{QUEUE_TYPE_DAY_LEARN_RELEARN}))")
-        elif prop == "ease":
-            prop = "factor"
-            val = int(val * 1000)
-        q.append("(%s %s %s)" % (prop, cmp, val))
-        return " and ".join(q)
-
-    def _findText(self, val: str, args: List[str]) -> str:
-        val = val.replace("*", "%")
-        args.append("%" + val + "%")
-        args.append("%" + val + "%")
-        return "(n.sfld like ? escape '\\' or n.flds like ? escape '\\')"
-
-    def _findNids(self, args: Tuple[str, List[Any]]) -> Optional[str]:
-        (val, __) = args
-        if re.search("[^0-9,]", val):
-            return None
-        return "n.id in (%s)" % val
-
-    def _findCids(self, args) -> Optional[str]:
-        (val, __) = args
-        if re.search("[^0-9,]", val):
-            return None
-        return "c.id in (%s)" % val
-
-    def _findMid(self, args) -> Optional[str]:
-        (val, __) = args
-        if re.search("[^0-9]", val):
-            return None
-        return "n.mid = %s" % val
-
-    def _findModel(self, args: Tuple[str, List[Any]]) -> str:
-        (val, __) = args
-        ids = []
-        val = val.lower()
-        for m in self.col.models.all():
-            if unicodedata.normalize("NFC", m["name"].lower()) == val:
-                ids.append(m["id"])
-        return "n.mid in %s" % ids2str(ids)
-
-    def _findDeck(self, args: Tuple[str, List[Any]]) -> Optional[str]:
-        # if searching for all decks, skip
-        (val, __) = args
-        if val == "*":
-            return "skip"
-        # deck types
-        elif val == "filtered":
-            return "c.odid"
-
-        def dids(did):
-            if not did:
-                return None
-            return [did] + [a[1] for a in self.col.decks.children(did)]
-
-        # current deck?
-        ids = None
-        if val.lower() == "current":
-            ids = dids(self.col.decks.current()["id"])
-        elif "*" not in val:
-            # single deck
-            ids = dids(self.col.decks.id(val, create=False))
-        else:
-            # wildcard
-            ids = set()
-            val = re.escape(val).replace(r"\*", ".*")
-            for d in self.col.decks.all():
-                if re.match("(?i)" + val, unicodedata.normalize("NFC", d["name"])):
-                    ids.update(dids(d["id"]))
-        if not ids:
-            return None
-        sids = ids2str(ids)
-        return "c.did in %s or c.odid in %s" % (sids, sids)
-
-    def _findTemplate(self, args: Tuple[str, List[Any]]) -> str:
-        # were we given an ordinal number?
-        (val, __) = args
-        try:
-            num = int(val) - 1
-        except:
-            num = None
-        if num is not None:
-            return "c.ord = %d" % num
-        # search for template names
-        lims = []
-        for m in self.col.models.all():
-            for t in m["tmpls"]:
-                if unicodedata.normalize("NFC", t["name"].lower()) == val.lower():
-                    if m["type"] == MODEL_CLOZE:
-                        # if the user has asked for a cloze card, we want
-                        # to give all ordinals, so we just limit to the
-                        # model instead
-                        lims.append("(n.mid = %s)" % m["id"])
-                    else:
-                        lims.append("(n.mid = %s and c.ord = %s)" % (m["id"], t["ord"]))
-        return " or ".join(lims)
-
-    def _findField(self, field: str, val: str) -> Optional[str]:
-        field = field.lower()
-        val = val.replace("*", "%")
-        # find models that have that field
-        mods = {}
-        for m in self.col.models.all():
-            for f in m["flds"]:
-                if unicodedata.normalize("NFC", f["name"].lower()) == field:
-                    mods[str(m["id"])] = (m, f["ord"])
-        if not mods:
-            # nothing has that field
-            return None
-        # gather nids
-        regex = re.escape(val).replace("_", ".").replace(re.escape("%"), ".*")
-        nids = []
-        for (id, mid, flds) in self.col.db.execute(
-            """
-select id, mid, flds from notes
-where mid in %s and flds like ? escape '\\'"""
-            % (ids2str(list(mods.keys()))),
-            "%" + val + "%",
-        ):
-            flds = splitFields(flds)
-            ord = mods[str(mid)][1]
-            strg = flds[ord]
-            try:
-                if re.search("(?si)^" + regex + "$", strg):
-                    nids.append(id)
-            except sre_constants.error:
-                return None
-        if not nids:
-            return "0"
-        return "n.id in %s" % ids2str(nids)
-
-    def _findDupes(self, args) -> Optional[str]:
-        # caller must call stripHTMLMedia on passed val
-        (val, __) = args
-        try:
-            mid, val = val.split(",", 1)
-        except OSError:
-            return None
-        csum = fieldChecksum(val)
-        nids = []
-        for nid, flds in self.col.db.execute(
-            "select id, flds from notes where mid=? and csum=?", mid, csum
-        ):
-            if stripHTMLMedia(splitFields(flds)[0]) == val:
-                nids.append(nid)
-        return "n.id in %s" % ids2str(nids)
+    def findNotes(self, query):
+        return self.col.find_notes(query)

 # Find and replace
@@ -555,11 +80,11 @@ def findReplace(
         flds = joinFields(sflds)
         if flds != origFlds:
             nids.append(nid)
-            d.append(dict(nid=nid, flds=flds, u=col.usn(), m=intTime()))
+            d.append((flds, intTime(), col.usn(), nid))
     if not d:
         return 0
     # replace
-    col.db.executemany("update notes set flds=:flds,mod=:m,usn=:u where id=:nid", d)
+    col.db.executemany("update notes set flds=?,mod=?,usn=? where id=?", d)
     col.updateFieldCache(nids)
     col.genCards(nids)
     return len(d)
@@ -595,7 +120,7 @@ def findDupes(
     # limit search to notes with applicable field name
     if search:
         search = "(" + search + ") "
-    search += "'%s:*'" % fieldName
+    search += '"%s:*"' % fieldName.replace('"', '"')
     # go through notes
     vals: Dict[str, List[int]] = {}
     dupes = []

pylib/anki/hooks.py

@@ -492,32 +492,6 @@ class _SchemaWillChangeFilter:

 schema_will_change = _SchemaWillChangeFilter()

-class _SearchTermsPreparedHook:
-    _hooks: List[Callable[[Dict[str, Callable]], None]] = []
-
-    def append(self, cb: Callable[[Dict[str, Callable]], None]) -> None:
-        """(searches: Dict[str, Callable])"""
-        self._hooks.append(cb)
-
-    def remove(self, cb: Callable[[Dict[str, Callable]], None]) -> None:
-        if cb in self._hooks:
-            self._hooks.remove(cb)
-
-    def __call__(self, searches: Dict[str, Callable]) -> None:
-        for hook in self._hooks:
-            try:
-                hook(searches)
-            except:
-                # if the hook fails, remove it
-                self._hooks.remove(hook)
-                raise
-        # legacy support
-        runHook("search", searches)
-
-search_terms_prepared = _SearchTermsPreparedHook()
-
 class _SyncProgressDidChangeHook:
     _hooks: List[Callable[[str], None]] = []

pylib/anki/importing/anki2.py

@@ -65,10 +65,7 @@ class Anki2Importer(Importer):
         self._importCards()
         self._importStaticMedia()
         self._postImport()
-        self.dst.db.setAutocommit(True)
-        self.dst.db.execute("vacuum")
-        self.dst.db.execute("analyze")
-        self.dst.db.setAutocommit(False)
+        self.dst.optimize()

     # Notes
     ######################################################################

pylib/anki/importing/noteimp.py

@@ -287,7 +287,7 @@ content in the text file to the correct fields."""
         return [intTime(), self.col.usn(), n.fieldsStr, id, n.fieldsStr]

     def addUpdates(self, rows: List[List[Union[int, str]]]) -> None:
-        old = self.col.db.totalChanges()
+        changes = self.col.db.scalar("select total_changes()")
         if self._tagsMapped:
             self.col.db.executemany(
                 """
@@ -309,7 +309,8 @@ update notes set mod = ?, usn = ?, flds = ?
 where id = ? and flds != ?""",
                 rows,
             )
-        self.updateCount = self.col.db.totalChanges() - old
+        changes2 = self.col.db.scalar("select total_changes()")
+        self.updateCount = changes2 - changes

     def processFields(
         self, note: ForeignNote, fields: Optional[List[str]] = None

pylib/anki/lang.py

@@ -145,7 +145,7 @@ current_catalog: Optional[
 ] = None

 # the current Fluent translation instance
-current_i18n: Optional[anki.rsbackend.I18nBackend]
+current_i18n: Optional[anki.rsbackend.RustBackend]

 # path to locale folder
 locale_folder = ""
@@ -175,9 +175,9 @@ def set_lang(lang: str, locale_dir: str) -> None:
     current_catalog = gettext.translation(
         "anki", gettext_dir, languages=[lang], fallback=True
     )
-    current_i18n = anki.rsbackend.I18nBackend(
-        preferred_langs=[lang], ftl_folder=ftl_dir
-    )
+
+    current_i18n = anki.rsbackend.RustBackend(ftl_folder=ftl_dir, langs=[lang])

     locale_folder = locale_dir

pylib/anki/media.py

@@ -171,8 +171,11 @@ class MediaManager:
     ##########################################################################

     def check(self) -> MediaCheckOutput:
-        "This should be called while the collection is closed."
-        return self.col.backend.check_media()
+        output = self.col.backend.check_media()
+        # files may have been renamed on disk, so an undo at this point could
+        # break file references
+        self.col.save()
+        return output

     def render_all_latex(
         self, progress_cb: Optional[Callable[[int], bool]] = None

pylib/anki/models.py

@ -504,17 +504,9 @@ select id from notes where mid = ?)"""
for c in range(nfields): for c in range(nfields):
flds.append(newflds.get(c, "")) flds.append(newflds.get(c, ""))
flds = joinFields(flds) flds = joinFields(flds)
d.append( d.append((flds, newModel["id"], intTime(), self.col.usn(), nid,))
dict(
nid=nid,
flds=flds,
mid=newModel["id"],
m=intTime(),
u=self.col.usn(),
)
)
self.col.db.executemany( self.col.db.executemany(
"update notes set flds=:flds,mid=:mid,mod=:m,usn=:u where id = :nid", d "update notes set flds=?,mid=?,mod=?,usn=? where id = ?", d
) )
self.col.updateFieldCache(nids) self.col.updateFieldCache(nids)
@ -543,12 +535,10 @@ select id from notes where mid = ?)"""
# mapping from a regular note, so the map should be valid # mapping from a regular note, so the map should be valid
new = map[ord] new = map[ord]
if new is not None: if new is not None:
d.append(dict(cid=cid, new=new, u=self.col.usn(), m=intTime())) d.append((new, self.col.usn(), intTime(), cid))
else: else:
deleted.append(cid) deleted.append(cid)
self.col.db.executemany( self.col.db.executemany("update cards set ord=?,usn=?,mod=? where id=?", d)
"update cards set ord=:new,usn=:u,mod=:m where id=:cid", d
)
self.col.remCards(deleted) self.col.remCards(deleted)
# Schema hash # Schema hash


@@ -48,19 +48,19 @@ class PythonBackend:
         native = self.col.sched.deckDueTree()
         return native_deck_tree_to_proto(native)

-    def find_cards(self, input: pb.FindCardsIn) -> pb.FindCardsOut:
-        cids = self.col.findCards(input.search)
-        return pb.FindCardsOut(card_ids=cids)
-
-    def browser_rows(self, input: pb.BrowserRowsIn) -> pb.BrowserRowsOut:
-        sort_fields = []
-        for cid in input.card_ids:
-            sort_fields.append(
-                self.col.db.scalar(
-                    "select sfld from notes n,cards c where n.id=c.nid and c.id=?", cid
-                )
-            )
-        return pb.BrowserRowsOut(sort_fields=sort_fields)
+    # def find_cards(self, input: pb.FindCardsIn) -> pb.FindCardsOut:
+    #     cids = self.col.findCards(input.search)
+    #     return pb.FindCardsOut(card_ids=cids)
+    #
+    # def browser_rows(self, input: pb.BrowserRowsIn) -> pb.BrowserRowsOut:
+    #     sort_fields = []
+    #     for cid in input.card_ids:
+    #         sort_fields.append(
+    #             self.col.db.scalar(
+    #                 "select sfld from notes n,cards c where n.id=c.nid and c.id=?", cid
+    #             )
+    #         )
+    #     return pb.BrowserRowsOut(sort_fields=sort_fields)

 def native_deck_tree_to_proto(native):

pylib/anki/rsbackend.py

@@ -5,21 +5,50 @@
 import enum
 import os
 from dataclasses import dataclass
-from typing import Callable, Dict, List, NewType, NoReturn, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    NewType,
+    NoReturn,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)

 import ankirspy  # pytype: disable=import-error

 import anki.backend_pb2 as pb
 import anki.buildinfo
 from anki import hooks
+from anki.dbproxy import Row as DBRow
+from anki.dbproxy import ValueForDB
 from anki.fluent_pb2 import FluentString as TR
 from anki.models import AllTemplateReqs
 from anki.sound import AVTag, SoundOrVideoTag, TTSTag
 from anki.types import assert_impossible_literal
+from anki.utils import intTime

 assert ankirspy.buildhash() == anki.buildinfo.buildhash

 SchedTimingToday = pb.SchedTimingTodayOut
+BuiltinSortKind = pb.BuiltinSortKind
+
+try:
+    import orjson
+except:
+    # add compat layer for 32 bit builds that can't use orjson
+    print("reverting to stock json")
+    import json
+
+    class orjson:  # type: ignore
+        def dumps(obj: Any) -> bytes:
+            return json.dumps(obj).encode("utf8")
+
+        loads = json.loads


 class Interrupted(Exception):
@@ -186,16 +215,19 @@ def _on_progress(progress_bytes: bytes) -> bool:
 class RustBackend:
     def __init__(
-        self, col_path: str, media_folder_path: str, media_db_path: str, log_path: str
+        self,
+        ftl_folder: Optional[str] = None,
+        langs: Optional[List[str]] = None,
+        server: bool = False,
     ) -> None:
-        ftl_folder = os.path.join(anki.lang.locale_folder, "fluent")
+        # pick up global defaults if not provided
+        if ftl_folder is None:
+            ftl_folder = os.path.join(anki.lang.locale_folder, "fluent")
+        if langs is None:
+            langs = [anki.lang.currentLang]
+
         init_msg = pb.BackendInit(
-            collection_path=col_path,
-            media_folder_path=media_folder_path,
-            media_db_path=media_db_path,
-            locale_folder_path=ftl_folder,
-            preferred_langs=[anki.lang.currentLang],
-            log_path=log_path,
+            locale_folder_path=ftl_folder, preferred_langs=langs, server=server,
         )
         self._backend = ankirspy.open_backend(init_msg.SerializeToString())
         self._backend.set_progress_callback(_on_progress)
@@ -213,6 +245,26 @@ class RustBackend:
         else:
             return output

+    def open_collection(
+        self, col_path: str, media_folder_path: str, media_db_path: str, log_path: str
+    ):
+        self._run_command(
+            pb.BackendInput(
+                open_collection=pb.OpenCollectionIn(
+                    collection_path=col_path,
+                    media_folder_path=media_folder_path,
+                    media_db_path=media_db_path,
+                    log_path=log_path,
+                )
+            ),
+            release_gil=True,
+        )
+
+    def close_collection(self):
+        self._run_command(
+            pb.BackendInput(close_collection=pb.Empty()), release_gil=True
+        )
+
     def template_requirements(
         self, template_fronts: List[str], field_map: Dict[str, int]
     ) -> AllTemplateReqs:
@@ -228,19 +280,33 @@ class RustBackend:
     def sched_timing_today(
         self,
         created_secs: int,
-        created_mins_west: int,
-        now_secs: int,
-        now_mins_west: int,
-        rollover: int,
+        created_mins_west: Optional[int],
+        now_mins_west: Optional[int],
+        rollover: Optional[int],
     ) -> SchedTimingToday:
+        if created_mins_west is not None:
+            crt_west = pb.OptionalInt32(val=created_mins_west)
+        else:
+            crt_west = None
+
+        if now_mins_west is not None:
+            now_west = pb.OptionalInt32(val=now_mins_west)
+        else:
+            now_west = None
+
+        if rollover is not None:
+            roll = pb.OptionalInt32(val=rollover)
+        else:
+            roll = None
+
         return self._run_command(
             pb.BackendInput(
                 sched_timing_today=pb.SchedTimingTodayIn(
                     created_secs=created_secs,
-                    created_mins_west=created_mins_west,
-                    now_secs=now_secs,
-                    now_mins_west=now_mins_west,
-                    rollover_hour=rollover,
+                    now_secs=intTime(),
+                    created_mins_west=crt_west,
+                    now_mins_west=now_west,
+                    rollover_hour=roll,
                 )
             )
         ).sched_timing_today
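An illustrative call (not from this diff) showing the new optional-parameter handling: passing None leaves the corresponding OptionalInt32 unset on the wire, so the Rust side falls back to its own defaults:

    # assumes an open collection `col`; rollover hour of 4 is made up
    timing = col.backend.sched_timing_today(col.crt, None, None, 4)
    print(timing.days_elapsed, timing.next_day_at)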
@@ -366,6 +432,54 @@ class RustBackend:
     def restore_trash(self):
         self._run_command(pb.BackendInput(restore_trash=pb.Empty()))

+    def db_query(
+        self, sql: str, args: Sequence[ValueForDB], first_row_only: bool
+    ) -> List[DBRow]:
+        return self._db_command(
+            dict(kind="query", sql=sql, args=args, first_row_only=first_row_only)
+        )
+
+    def db_execute_many(self, sql: str, args: List[List[ValueForDB]]) -> List[DBRow]:
+        return self._db_command(dict(kind="executemany", sql=sql, args=args))
+
+    def db_begin(self) -> None:
+        return self._db_command(dict(kind="begin"))
+
+    def db_commit(self) -> None:
+        return self._db_command(dict(kind="commit"))
+
+    def db_rollback(self) -> None:
+        return self._db_command(dict(kind="rollback"))
+
+    def _db_command(self, input: Dict[str, Any]) -> Any:
+        return orjson.loads(self._backend.db_command(orjson.dumps(input)))
+
+    def search_cards(
+        self, search: str, order: Union[bool, str, int], reverse: bool = False
+    ) -> Sequence[int]:
+        if isinstance(order, str):
+            mode = pb.SortOrder(custom=order)
+        elif order is True:
+            mode = pb.SortOrder(from_config=pb.Empty())
+        elif order is False:
+            mode = pb.SortOrder(none=pb.Empty())
+        else:
+            # sadly we can't use the protobuf type in a Union, so we
+            # have to accept an int and convert it
+            kind = BuiltinSortKind.Value(BuiltinSortKind.Name(order))
+            mode = pb.SortOrder(
+                builtin=pb.BuiltinSearchOrder(kind=kind, reverse=reverse)
+            )
+        return self._run_command(
+            pb.BackendInput(search_cards=pb.SearchCardsIn(search=search, order=mode))
+        ).search_cards.card_ids
+
+    def search_notes(self, search: str) -> Sequence[int]:
+        return self._run_command(
+            pb.BackendInput(search_notes=pb.SearchNotesIn(search=search))
+        ).search_notes.note_ids
+
 def translate_string_in(
     key: TR, **kwargs: Union[str, int, float]
@@ -379,19 +493,6 @@ def translate_string_in(
     return pb.TranslateStringIn(key=key, args=args)

-class I18nBackend:
-    def __init__(self, preferred_langs: List[str], ftl_folder: str) -> None:
-        init_msg = pb.I18nBackendInit(
-            locale_folder_path=ftl_folder, preferred_langs=preferred_langs
-        )
-        self._backend = ankirspy.open_i18n(init_msg.SerializeToString())
-
-    def translate(self, key: TR, **kwargs: Union[str, int, float]) -> str:
-        return self._backend.translate(
-            translate_string_in(key, **kwargs).SerializeToString()
-        )
-
 # temporarily force logging of media handling
 if "RUST_LOG" not in os.environ:
     os.environ["RUST_LOG"] = "warn,anki::media=debug"

pylib/anki/sched.py

@@ -8,7 +8,7 @@ import random
 import time
 from heapq import *
 from operator import itemgetter
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import anki
 from anki import hooks
@@ -80,7 +80,7 @@ class Scheduler(V2):
         self._updateStats(card, "time", card.timeTaken())
         card.mod = intTime()
         card.usn = self.col.usn()
-        card.flushSched()
+        card.flush()

     def counts(self, card: Optional[Card] = None) -> Tuple[int, int, int]:
         counts = [self.newCount, self.lrnCount, self.revCount]
@@ -286,11 +286,13 @@ and due <= ? limit %d"""
         self._lrnQueue = self.col.db.all(
             f"""
 select due, id from cards where
-did in %s and queue = {QUEUE_TYPE_LRN} and due < :lim
+did in %s and queue = {QUEUE_TYPE_LRN} and due < ?
 limit %d"""
             % (self._deckLimit(), self.reportLimit),
-            lim=self.dayCutoff,
+            self.dayCutoff,
         )
+        for i in range(len(self._lrnQueue)):
+            self._lrnQueue[i] = (self._lrnQueue[i][0], self._lrnQueue[i][1])
         # as it arrives sorted by did first, we need to sort it
         self._lrnQueue.sort()
         return self._lrnQueue
@@ -707,7 +709,7 @@ did = ? and queue = {QUEUE_TYPE_REV} and due <= ? limit ?""",
     # Dynamic deck handling
     ##########################################################################

-    def rebuildDyn(self, did: Optional[int] = None) -> Optional[List[int]]:  # type: ignore[override]
+    def rebuildDyn(self, did: Optional[int] = None) -> Optional[Sequence[int]]:  # type: ignore[override]
         "Rebuild a dynamic deck."
         did = did or self.col.decks.selected()
         deck = self.col.decks.get(did)
@@ -721,7 +723,7 @@ did = ? and queue = {QUEUE_TYPE_REV} and due <= ? limit ?""",
         self.col.decks.select(did)
         return ids

-    def _fillDyn(self, deck: Dict[str, Any]) -> List[int]:  # type: ignore[override]
+    def _fillDyn(self, deck: Dict[str, Any]) -> Sequence[int]:  # type: ignore[override]
         search, limit, order = deck["terms"][0]
         orderlimit = self._dynOrder(order, limit)
         if search.strip():
@@ -751,7 +753,7 @@ due = odue, odue = 0, odid = 0, usn = ? where %s"""
             self.col.usn(),
         )

-    def _moveToDyn(self, did: int, ids: List[int]) -> None:  # type: ignore[override]
+    def _moveToDyn(self, did: int, ids: Sequence[int]) -> None:  # type: ignore[override]
         deck = self.col.decks.get(did)
         data = []
         t = intTime()
@@ -867,10 +869,9 @@ did = ?, queue = %s, due = ?, usn = ? where id = ?"""

     def _updateCutoff(self) -> None:
         oldToday = self.today
-        # days since col created
-        self.today = int((time.time() - self.col.crt) // 86400)
-        # end of day cutoff
-        self.dayCutoff = self.col.crt + (self.today + 1) * 86400
+        timing = self._timing_today()
+        self.today = timing.days_elapsed
+        self.dayCutoff = timing.next_day_at
         if oldToday != self.today:
             self.col.log(self.today, self.dayCutoff)
         # update all daily counts, but don't save decks to prevent needless

pylib/anki/schedv2.py

@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
import datetime
import itertools import itertools
import random import random
import time import time
@ -11,7 +10,7 @@ from heapq import *
from operator import itemgetter from operator import itemgetter
# from anki.collection import _Collection # from anki.collection import _Collection
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
import anki # pylint: disable=unused-import import anki # pylint: disable=unused-import
from anki import hooks from anki import hooks
@ -82,7 +81,7 @@ class Scheduler:
self._updateStats(card, "time", card.timeTaken()) self._updateStats(card, "time", card.timeTaken())
card.mod = intTime() card.mod = intTime()
card.usn = self.col.usn() card.usn = self.col.usn()
card.flushSched() card.flush()
def _answerCard(self, card: Card, ease: int) -> None: def _answerCard(self, card: Card, ease: int) -> None:
if self._previewingCard(card): if self._previewingCard(card):
@ -138,8 +137,8 @@ class Scheduler:
def dueForecast(self, days: int = 7) -> List[Any]: def dueForecast(self, days: int = 7) -> List[Any]:
"Return counts over next DAYS. Includes today." "Return counts over next DAYS. Includes today."
daysd = dict( daysd: Dict[int, int] = dict(
self.col.db.all( self.col.db.all( # type: ignore
f""" f"""
select due, count() from cards select due, count() from cards
where did in %s and queue = {QUEUE_TYPE_REV} where did in %s and queue = {QUEUE_TYPE_REV}
@ -542,14 +541,16 @@ select count() from cards where did in %s and queue = {QUEUE_TYPE_PREVIEW}
if self._lrnQueue: if self._lrnQueue:
return True return True
cutoff = intTime() + self.col.conf["collapseTime"] cutoff = intTime() + self.col.conf["collapseTime"]
self._lrnQueue = self.col.db.all( self._lrnQueue = self.col.db.all( # type: ignore
f""" f"""
select due, id from cards where select due, id from cards where
did in %s and queue in ({QUEUE_TYPE_LRN},{QUEUE_TYPE_PREVIEW}) and due < :lim did in %s and queue in ({QUEUE_TYPE_LRN},{QUEUE_TYPE_PREVIEW}) and due < ?
limit %d""" limit %d"""
% (self._deckLimit(), self.reportLimit), % (self._deckLimit(), self.reportLimit),
lim=cutoff, cutoff,
) )
for i in range(len(self._lrnQueue)):
self._lrnQueue[i] = (self._lrnQueue[i][0], self._lrnQueue[i][1])
# as it arrives sorted by did first, we need to sort it # as it arrives sorted by did first, we need to sort it
self._lrnQueue.sort() self._lrnQueue.sort()
return self._lrnQueue return self._lrnQueue
@ -1215,7 +1216,7 @@ due = (case when odue>0 then odue else due end), odue = 0, odid = 0, usn = ? whe
t = "c.due, c.ord" t = "c.due, c.ord"
return t + " limit %d" % l return t + " limit %d" % l
def _moveToDyn(self, did: int, ids: List[int], start: int = -100000) -> None: def _moveToDyn(self, did: int, ids: Sequence[int], start: int = -100000) -> None:
deck = self.col.decks.get(did) deck = self.col.decks.get(did)
data = [] data = []
u = self.col.usn() u = self.col.usn()
@ -1353,13 +1354,8 @@ where id = ?
def _updateCutoff(self) -> None: def _updateCutoff(self) -> None:
oldToday = self.today oldToday = self.today
timing = self._timing_today() timing = self._timing_today()
self.today = timing.days_elapsed
if self._new_timezone_enabled(): self.dayCutoff = timing.next_day_at
self.today = timing.days_elapsed
self.dayCutoff = timing.next_day_at
else:
self.today = self._daysSinceCreation()
self.dayCutoff = self._dayCutoff()
if oldToday != self.today: if oldToday != self.today:
self.col.log(self.today, self.dayCutoff) self.col.log(self.today, self.dayCutoff)
@ -1385,51 +1381,39 @@ where id = ?
if time.time() > self.dayCutoff: if time.time() > self.dayCutoff:
self.reset() self.reset()
def _dayCutoff(self) -> int:
rolloverTime = self.col.conf.get("rollover", 4)
if rolloverTime < 0:
rolloverTime = 24 + rolloverTime
date = datetime.datetime.today()
date = date.replace(hour=rolloverTime, minute=0, second=0, microsecond=0)
if date < datetime.datetime.today():
date = date + datetime.timedelta(days=1)
stamp = int(time.mktime(date.timetuple()))
return stamp
def _daysSinceCreation(self) -> int:
startDate = datetime.datetime.fromtimestamp(self.col.crt)
startDate = startDate.replace(
hour=self._rolloverHour(), minute=0, second=0, microsecond=0
)
return int((time.time() - time.mktime(startDate.timetuple())) // 86400)
def _rolloverHour(self) -> int: def _rolloverHour(self) -> int:
return self.col.conf.get("rollover", 4) return self.col.conf.get("rollover", 4)
# New timezone handling
##########################################################################
def _new_timezone_enabled(self) -> bool:
return self.col.conf.get("creationOffset") is not None
def _timing_today(self) -> SchedTimingToday: def _timing_today(self) -> SchedTimingToday:
roll: Optional[int] = None
if self.col.schedVer() > 1:
roll = self._rolloverHour()
return self.col.backend.sched_timing_today( return self.col.backend.sched_timing_today(
self.col.crt, self.col.crt,
self._creation_timezone_offset(), self._creation_timezone_offset(),
intTime(),
self._current_timezone_offset(), self._current_timezone_offset(),
self._rolloverHour(), roll,
) )
def _current_timezone_offset(self) -> int: def _current_timezone_offset(self) -> Optional[int]:
if self.col.server: if self.col.server:
mins = self.col.server.minutes_west
if mins is not None:
return mins
# older Anki versions stored the local offset in
# the config
return self.col.conf.get("localOffset", 0) return self.col.conf.get("localOffset", 0)
else: else:
return self.col.backend.local_minutes_west(intTime()) return None
def _creation_timezone_offset(self) -> int: def _creation_timezone_offset(self) -> Optional[int]:
return self.col.conf.get("creationOffset", 0) return self.col.conf.get("creationOffset", None)
# New timezone handling - GUI helpers
##########################################################################
def new_timezone_enabled(self) -> bool:
return self.col.conf.get("creationOffset") is not None
def set_creation_offset(self): def set_creation_offset(self):
"""Save the UTC west offset at the time of creation into the DB. """Save the UTC west offset at the time of creation into the DB.
@ -1775,21 +1759,12 @@ and (queue={QUEUE_TYPE_NEW} or (queue={QUEUE_TYPE_REV} and due<=?))""",
mod = intTime() mod = intTime()
for id in ids: for id in ids:
r = random.randint(imin, imax) r = random.randint(imin, imax)
d.append( d.append((max(1, r), r + t, self.col.usn(), mod, STARTING_FACTOR, id,))
dict(
id=id,
due=r + t,
ivl=max(1, r),
mod=mod,
usn=self.col.usn(),
fact=STARTING_FACTOR,
)
)
self.remFromDyn(ids) self.remFromDyn(ids)
self.col.db.executemany( self.col.db.executemany(
f""" f"""
update cards set type={CARD_TYPE_REV},queue={QUEUE_TYPE_REV},ivl=:ivl,due=:due,odue=0, update cards set type={CARD_TYPE_REV},queue={QUEUE_TYPE_REV},ivl=?,due=?,odue=0,
usn=:usn,mod=:mod,factor=:fact where id=:id""", usn=?,mod=?,factor=? where id=?""",
d, d,
) )
self.col.log(ids) self.col.log(ids)
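The reschedule above also shows the parameter-binding change made throughout this commit: named placeholders (:ivl, :due, ...) with dict arguments become positional ? placeholders with tuples, presumably because the new DBProxy layer only accepts qmark style. The pattern in miniature, against plain sqlite3 rather than the real schema:

    import sqlite3

    db = sqlite3.connect(":memory:")
    db.execute("create table cards (id integer primary key, ivl, due, usn, mod, factor)")
    db.execute("insert into cards values (1, 0, 0, 0, 0, 0)")
    # positional qmark style: tuple order must match the ? order in the SQL
    rows = [(1, 5, -1, 0, 2500, 1)]
    db.executemany("update cards set ivl=?,due=?,usn=?,mod=?,factor=? where id=?", rows)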
@ -1866,10 +1841,8 @@ and due >= ? and queue = {QUEUE_TYPE_NEW}"""
for id, nid in self.col.db.execute( for id, nid in self.col.db.execute(
f"select id, nid from cards where type = {CARD_TYPE_NEW} and id in " + scids f"select id, nid from cards where type = {CARD_TYPE_NEW} and id in " + scids
): ):
d.append(dict(now=now, due=due[nid], usn=self.col.usn(), cid=id)) d.append((due[nid], now, self.col.usn(), id))
self.col.db.executemany( self.col.db.executemany("update cards set due=?,mod=?,usn=? where id = ?", d)
"update cards set due=:due,mod=:now,usn=:usn where id = :cid", d
)
def randomizeCards(self, did: int) -> None: def randomizeCards(self, did: int) -> None:
cids = self.col.db.list("select id from cards where did = ?", did) cids = self.col.db.list("select id from cards where did = ?", did)

View file

@ -58,7 +58,7 @@ class CardStats:
self.addLine(_("Reviews"), "%d" % c.reps) self.addLine(_("Reviews"), "%d" % c.reps)
self.addLine(_("Lapses"), "%d" % c.lapses) self.addLine(_("Lapses"), "%d" % c.lapses)
(cnt, total) = self.col.db.first( (cnt, total) = self.col.db.first(
"select count(), sum(time)/1000 from revlog where cid = :id", id=c.id "select count(), sum(time)/1000 from revlog where cid = ?", c.id
) )
if cnt: if cnt:
self.addLine(_("Average Time"), self.time(total / float(cnt))) self.addLine(_("Average Time"), self.time(total / float(cnt)))
@ -297,12 +297,12 @@ and due = ?"""
) -> Any: ) -> Any:
lim = "" lim = ""
if start is not None: if start is not None:
lim += " and due-:today >= %d" % start lim += " and due-%d >= %d" % (self.col.sched.today, start)
if end is not None: if end is not None:
lim += " and day < %d" % end lim += " and day < %d" % end
return self.col.db.all( return self.col.db.all(
f""" f"""
select (due-:today)/:chunk as day, select (due-?)/? as day,
sum(case when ivl < 21 then 1 else 0 end), -- yng sum(case when ivl < 21 then 1 else 0 end), -- yng
sum(case when ivl >= 21 then 1 else 0 end) -- mtr sum(case when ivl >= 21 then 1 else 0 end) -- mtr
from cards from cards
@ -310,8 +310,8 @@ where did in %s and queue in ({QUEUE_TYPE_REV},{QUEUE_TYPE_DAY_LEARN_RELEARN})
%s %s
group by day order by day""" group by day order by day"""
% (self._limit(), lim), % (self._limit(), lim),
today=self.col.sched.today, self.col.sched.today,
chunk=chunk, chunk,
) )
# Added, reps and time spent # Added, reps and time spent
@ -527,14 +527,13 @@ group by day order by day"""
return self.col.db.all( return self.col.db.all(
""" """
select select
(cast((id/1000.0 - :cut) / 86400.0 as int))/:chunk as day, (cast((id/1000.0 - ?) / 86400.0 as int))/? as day,
count(id) count(id)
from cards %s from cards %s
group by day order by day""" group by day order by day"""
% lim, % lim,
cut=self.col.sched.dayCutoff, self.col.sched.dayCutoff,
tf=tf, chunk,
chunk=chunk,
) )
def _done(self, num: Optional[int] = 7, chunk: int = 1) -> Any: def _done(self, num: Optional[int] = 7, chunk: int = 1) -> Any:
@ -557,24 +556,28 @@ group by day order by day"""
return self.col.db.all( return self.col.db.all(
f""" f"""
select select
(cast((id/1000.0 - :cut) / 86400.0 as int))/:chunk as day, (cast((id/1000.0 - ?) / 86400.0 as int))/? as day,
sum(case when type = {REVLOG_LRN} then 1 else 0 end), -- lrn count sum(case when type = {REVLOG_LRN} then 1 else 0 end), -- lrn count
sum(case when type = {REVLOG_REV} and lastIvl < 21 then 1 else 0 end), -- yng count sum(case when type = {REVLOG_REV} and lastIvl < 21 then 1 else 0 end), -- yng count
sum(case when type = {REVLOG_REV} and lastIvl >= 21 then 1 else 0 end), -- mtr count sum(case when type = {REVLOG_REV} and lastIvl >= 21 then 1 else 0 end), -- mtr count
sum(case when type = {REVLOG_RELRN} then 1 else 0 end), -- lapse count sum(case when type = {REVLOG_RELRN} then 1 else 0 end), -- lapse count
sum(case when type = {REVLOG_CRAM} then 1 else 0 end), -- cram count sum(case when type = {REVLOG_CRAM} then 1 else 0 end), -- cram count
sum(case when type = {REVLOG_LRN} then time/1000.0 else 0 end)/:tf, -- lrn time sum(case when type = {REVLOG_LRN} then time/1000.0 else 0 end)/?, -- lrn time
-- yng + mtr time -- yng + mtr time
sum(case when type = {REVLOG_REV} and lastIvl < 21 then time/1000.0 else 0 end)/:tf, sum(case when type = {REVLOG_REV} and lastIvl < 21 then time/1000.0 else 0 end)/?,
sum(case when type = {REVLOG_REV} and lastIvl >= 21 then time/1000.0 else 0 end)/:tf, sum(case when type = {REVLOG_REV} and lastIvl >= 21 then time/1000.0 else 0 end)/?,
sum(case when type = {REVLOG_RELRN} then time/1000.0 else 0 end)/:tf, -- lapse time sum(case when type = {REVLOG_RELRN} then time/1000.0 else 0 end)/?, -- lapse time
sum(case when type = {REVLOG_CRAM} then time/1000.0 else 0 end)/:tf -- cram time sum(case when type = {REVLOG_CRAM} then time/1000.0 else 0 end)/? -- cram time
from revlog %s from revlog %s
group by day order by day""" group by day order by day"""
% lim, % lim,
cut=self.col.sched.dayCutoff, self.col.sched.dayCutoff,
tf=tf, chunk,
chunk=chunk, tf,
tf,
tf,
tf,
tf,
) )
def _daysStudied(self) -> Any: def _daysStudied(self) -> Any:
@ -592,11 +595,11 @@ group by day order by day"""
ret = self.col.db.first( ret = self.col.db.first(
""" """
select count(), abs(min(day)) from (select select count(), abs(min(day)) from (select
(cast((id/1000 - :cut) / 86400.0 as int)+1) as day (cast((id/1000 - ?) / 86400.0 as int)+1) as day
from revlog %s from revlog %s
group by day order by day)""" group by day order by day)"""
% lim, % lim,
cut=self.col.sched.dayCutoff, self.col.sched.dayCutoff,
) )
assert ret assert ret
return ret return ret
@ -655,12 +658,12 @@ group by day order by day)"""
data = [ data = [
self.col.db.all( self.col.db.all(
f""" f"""
select ivl / :chunk as grp, count() from cards select ivl / ? as grp, count() from cards
where did in %s and queue = {QUEUE_TYPE_REV} %s where did in %s and queue = {QUEUE_TYPE_REV} %s
group by grp group by grp
order by grp""" order by grp"""
% (self._limit(), lim), % (self._limit(), lim),
chunk=chunk, chunk,
) )
] ]
return ( return (
@ -866,14 +869,14 @@ order by thetype, ease"""
return self.col.db.all( return self.col.db.all(
f""" f"""
select select
23 - ((cast((:cut - id/1000) / 3600.0 as int)) %% 24) as hour, 23 - ((cast((? - id/1000) / 3600.0 as int)) %% 24) as hour,
sum(case when ease = 1 then 0 else 1 end) / sum(case when ease = 1 then 0 else 1 end) /
cast(count() as float) * 100, cast(count() as float) * 100,
count() count()
from revlog where type in ({REVLOG_LRN},{REVLOG_REV},{REVLOG_RELRN}) %s from revlog where type in ({REVLOG_LRN},{REVLOG_REV},{REVLOG_RELRN}) %s
group by hour having count() > 30 order by hour""" group by hour having count() > 30 order by hour"""
% lim, % lim,
cut=self.col.sched.dayCutoff - (rolloverHour * 3600), self.col.sched.dayCutoff - (rolloverHour * 3600),
) )
# Cards # Cards

View file

@ -4,12 +4,13 @@
import copy import copy
import json import json
import os import os
import re import weakref
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
from anki.collection import _Collection from anki.collection import _Collection
from anki.consts import * from anki.consts import *
from anki.db import DB from anki.dbproxy import DBProxy
from anki.lang import _ from anki.lang import _
from anki.media import media_paths_from_col_path from anki.media import media_paths_from_col_path
from anki.rsbackend import RustBackend from anki.rsbackend import RustBackend
@ -20,48 +21,42 @@ from anki.stdmodels import (
addForwardOptionalReverse, addForwardOptionalReverse,
addForwardReverse, addForwardReverse,
) )
from anki.utils import intTime, isWin from anki.utils import intTime
@dataclass
class ServerData: class ServerData:
minutes_west: Optional[int] = None minutes_west: Optional[int] = None
def Collection( def Collection(
path: str, lock: bool = True, server: Optional[ServerData] = None, log: bool = False path: str,
backend: Optional[RustBackend] = None,
server: Optional[ServerData] = None,
) -> _Collection: ) -> _Collection:
"Open a new or existing collection. Path must be unicode." "Open a new or existing collection. Path must be unicode."
assert path.endswith(".anki2") assert path.endswith(".anki2")
if backend is None:
backend = RustBackend(server=server is not None)
(media_dir, media_db) = media_paths_from_col_path(path) (media_dir, media_db) = media_paths_from_col_path(path)
log_path = "" log_path = ""
if not server: if not server:
log_path = path.replace(".anki2", "2.log") log_path = path.replace(".anki2", "2.log")
backend = RustBackend(path, media_dir, media_db, log_path)
path = os.path.abspath(path) path = os.path.abspath(path)
create = not os.path.exists(path)
if create:
base = os.path.basename(path)
for c in ("/", ":", "\\"):
assert c not in base
# connect # connect
db = DB(path) backend.open_collection(path, media_dir, media_db, log_path)
db.setAutocommit(True) db = DBProxy(weakref.proxy(backend), path)
# initial setup required?
create = db.scalar("select models = '{}' from col")
if create: if create:
ver = _createDB(db) initial_db_setup(db)
else:
ver = _upgradeSchema(db)
db.execute("pragma temp_store = memory")
db.execute("pragma cache_size = 10000")
if not isWin:
db.execute("pragma journal_mode = wal")
db.setAutocommit(False)
# add db to col and do any remaining upgrades # add db to col and do any remaining upgrades
col = _Collection(db, backend=backend, server=server, log=log) col = _Collection(db, backend=backend, server=server)
if ver < SCHEMA_VERSION: if create:
_upgrade(col, ver)
elif ver > SCHEMA_VERSION:
raise Exception("This file requires a newer version of Anki.")
elif create:
# add in reverse order so basic is default # add in reverse order so basic is default
addClozeModel(col) addClozeModel(col)
addBasicTypingModel(col) addBasicTypingModel(col)
@ -69,267 +64,21 @@ def Collection(
addForwardReverse(col) addForwardReverse(col)
addBasicModel(col) addBasicModel(col)
col.save() col.save()
if lock: else:
try: db.begin()
col.lock()
except:
col.db.close()
raise
return col return col
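A minimal sketch of the open/close flow implied by the new signature (the path is illustrative; a backend is created implicitly when none is passed):

    col = Collection("/tmp/collection.anki2")  # opens via backend.open_collection()
    # ... use col ...
    col.close()                                # releases the backend's collection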
def _upgradeSchema(db: DB) -> Any:
ver = db.scalar("select ver from col")
if ver == SCHEMA_VERSION:
return ver
# add odid to cards, edue->odue
######################################################################
if db.scalar("select ver from col") == 1:
db.execute("alter table cards rename to cards2")
_addSchema(db, setColConf=False)
db.execute(
"""
insert into cards select
id, nid, did, ord, mod, usn, type, queue, due, ivl, factor, reps, lapses,
left, edue, 0, flags, data from cards2"""
)
db.execute("drop table cards2")
db.execute("update col set ver = 2")
_updateIndices(db)
# remove did from notes
######################################################################
if db.scalar("select ver from col") == 2:
db.execute("alter table notes rename to notes2")
_addSchema(db, setColConf=False)
db.execute(
"""
insert into notes select
id, guid, mid, mod, usn, tags, flds, sfld, csum, flags, data from notes2"""
)
db.execute("drop table notes2")
db.execute("update col set ver = 3")
_updateIndices(db)
return ver
def _upgrade(col, ver) -> None:
if ver < 3:
# new deck properties
for d in col.decks.all():
d["dyn"] = DECK_STD
d["collapsed"] = False
col.decks.save(d)
if ver < 4:
col.modSchema(check=False)
clozes = []
for m in col.models.all():
if not "{{cloze:" in m["tmpls"][0]["qfmt"]:
m["type"] = MODEL_STD
col.models.save(m)
else:
clozes.append(m)
for m in clozes:
_upgradeClozeModel(col, m)
col.db.execute("update col set ver = 4")
if ver < 5:
col.db.execute("update cards set odue = 0 where queue = 2")
col.db.execute("update col set ver = 5")
if ver < 6:
col.modSchema(check=False)
import anki.models
for m in col.models.all():
m["css"] = anki.models.defaultModel["css"]
for t in m["tmpls"]:
if "css" not in t:
# ankidroid didn't bump version
continue
m["css"] += "\n" + t["css"].replace(
".card ", ".card%d " % (t["ord"] + 1)
)
del t["css"]
col.models.save(m)
col.db.execute("update col set ver = 6")
if ver < 7:
col.modSchema(check=False)
col.db.execute(
"update cards set odue = 0 where (type = 1 or queue = 2) " "and not odid"
)
col.db.execute("update col set ver = 7")
if ver < 8:
col.modSchema(check=False)
col.db.execute("update cards set due = due / 1000 where due > 4294967296")
col.db.execute("update col set ver = 8")
if ver < 9:
# adding an empty file to a zip makes python's zip code think it's a
# folder, so remove any empty files
changed = False
dir = col.media.dir()
if dir:
for f in os.listdir(col.media.dir()):
if os.path.isfile(f) and not os.path.getsize(f):
os.unlink(f)
col.media.db.execute("delete from log where fname = ?", f)
col.media.db.execute("delete from media where fname = ?", f)
changed = True
if changed:
col.media.db.commit()
col.db.execute("update col set ver = 9")
if ver < 10:
col.db.execute(
"""
update cards set left = left + left*1000 where queue = 1"""
)
col.db.execute("update col set ver = 10")
if ver < 11:
col.modSchema(check=False)
for d in col.decks.all():
if d["dyn"]:
order = d["order"]
# failed order was removed
if order >= 5:
order -= 1
d["terms"] = [[d["search"], d["limit"], order]]
del d["search"]
del d["limit"]
del d["order"]
d["resched"] = True
d["return"] = True
else:
if "extendNew" not in d:
d["extendNew"] = 10
d["extendRev"] = 50
col.decks.save(d)
for c in col.decks.allConf():
r = c["rev"]
r["ivlFct"] = r.get("ivlfct", 1)
if "ivlfct" in r:
del r["ivlfct"]
r["maxIvl"] = 36500
col.decks.save(c)
for m in col.models.all():
for t in m["tmpls"]:
t["bqfmt"] = ""
t["bafmt"] = ""
col.models.save(m)
col.db.execute("update col set ver = 11")
def _upgradeClozeModel(col, m) -> None:
m["type"] = MODEL_CLOZE
# convert first template
t = m["tmpls"][0]
for type in "qfmt", "afmt":
t[type] = re.sub("{{cloze:1:(.+?)}}", r"{{cloze:\1}}", t[type])
t["name"] = _("Cloze")
# delete non-cloze cards for the model
rem = []
for t in m["tmpls"][1:]:
if "{{cloze:" not in t["qfmt"]:
rem.append(t)
for r in rem:
col.models.remTemplate(m, r)
del m["tmpls"][1:]
col.models._updateTemplOrds(m)
col.models.save(m)
# Creating a new collection # Creating a new collection
###################################################################### ######################################################################
def _createDB(db: DB) -> int: def initial_db_setup(db: DBProxy) -> None:
db.execute("pragma page_size = 4096") db.begin()
db.execute("pragma legacy_file_format = 0") _addColVars(db, *_getColVars(db))
db.execute("vacuum")
_addSchema(db)
_updateIndices(db)
db.execute("analyze")
return SCHEMA_VERSION
def _addSchema(db: DB, setColConf: bool = True) -> None: def _getColVars(db: DBProxy) -> Tuple[Any, Any, Dict[str, Any]]:
db.executescript(
"""
create table if not exists col (
id integer primary key,
crt integer not null,
mod integer not null,
scm integer not null,
ver integer not null,
dty integer not null,
usn integer not null,
ls integer not null,
conf text not null,
models text not null,
decks text not null,
dconf text not null,
tags text not null
);
create table if not exists notes (
id integer primary key, /* 0 */
guid text not null, /* 1 */
mid integer not null, /* 2 */
mod integer not null, /* 3 */
usn integer not null, /* 4 */
tags text not null, /* 5 */
flds text not null, /* 6 */
sfld integer not null, /* 7 */
csum integer not null, /* 8 */
flags integer not null, /* 9 */
data text not null /* 10 */
);
create table if not exists cards (
id integer primary key, /* 0 */
nid integer not null, /* 1 */
did integer not null, /* 2 */
ord integer not null, /* 3 */
mod integer not null, /* 4 */
usn integer not null, /* 5 */
type integer not null, /* 6 */
queue integer not null, /* 7 */
due integer not null, /* 8 */
ivl integer not null, /* 9 */
factor integer not null, /* 10 */
reps integer not null, /* 11 */
lapses integer not null, /* 12 */
left integer not null, /* 13 */
odue integer not null, /* 14 */
odid integer not null, /* 15 */
flags integer not null, /* 16 */
data text not null /* 17 */
);
create table if not exists revlog (
id integer primary key,
cid integer not null,
usn integer not null,
ease integer not null,
ivl integer not null,
lastIvl integer not null,
factor integer not null,
time integer not null,
type integer not null
);
create table if not exists graves (
usn integer not null,
oid integer not null,
type integer not null
);
insert or ignore into col
values(1,0,0,%(s)s,%(v)s,0,0,0,'','{}','','','{}');
"""
% ({"v": SCHEMA_VERSION, "s": intTime(1000)})
)
if setColConf:
_addColVars(db, *_getColVars(db))
def _getColVars(db: DB) -> Tuple[Any, Any, Dict[str, Any]]:
import anki.collection import anki.collection
import anki.decks import anki.decks
@ -344,7 +93,7 @@ def _getColVars(db: DB) -> Tuple[Any, Any, Dict[str, Any]]:
def _addColVars( def _addColVars(
db: DB, g: Dict[str, Any], gc: Dict[str, Any], c: Dict[str, Any] db: DBProxy, g: Dict[str, Any], gc: Dict[str, Any], c: Dict[str, Any]
) -> None: ) -> None:
db.execute( db.execute(
""" """
@ -353,23 +102,3 @@ update col set conf = ?, decks = ?, dconf = ?""",
json.dumps({"1": g}), json.dumps({"1": g}),
json.dumps({"1": gc}), json.dumps({"1": gc}),
) )
def _updateIndices(db: DB) -> None:
"Add indices to the DB."
db.executescript(
"""
-- syncing
create index if not exists ix_notes_usn on notes (usn);
create index if not exists ix_cards_usn on cards (usn);
create index if not exists ix_revlog_usn on revlog (usn);
-- card spacing, etc
create index if not exists ix_cards_nid on cards (nid);
-- scheduling and deck limiting
create index if not exists ix_cards_sched on cards (did, queue, due);
-- revlog by card
create index if not exists ix_revlog_cid on revlog (cid);
-- field uniqueness
create index if not exists ix_notes_csum on notes (csum);
"""
)

View file

@ -8,8 +8,7 @@ import io
import json import json
import os import os
import random import random
import sqlite3 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import anki import anki
from anki.consts import * from anki.consts import *
@ -32,7 +31,7 @@ class UnexpectedSchemaChange(Exception):
class Syncer: class Syncer:
cursor: Optional[sqlite3.Cursor] chunkRows: Optional[List[Sequence]]
def __init__(self, col: anki.storage._Collection, server=None) -> None: def __init__(self, col: anki.storage._Collection, server=None) -> None:
self.col = col.weakref() self.col = col.weakref()
@ -247,11 +246,11 @@ class Syncer:
def prepareToChunk(self) -> None: def prepareToChunk(self) -> None:
self.tablesLeft = ["revlog", "cards", "notes"] self.tablesLeft = ["revlog", "cards", "notes"]
self.cursor = None self.chunkRows = None
def cursorForTable(self, table) -> sqlite3.Cursor: def getChunkRows(self, table) -> List[Sequence]:
lim = self.usnLim() lim = self.usnLim()
x = self.col.db.execute x = self.col.db.all
d = (self.maxUsn, lim) d = (self.maxUsn, lim)
if table == "revlog": if table == "revlog":
return x( return x(
@ -280,14 +279,15 @@ from notes where %s"""
lim = 250 lim = 250
while self.tablesLeft and lim: while self.tablesLeft and lim:
curTable = self.tablesLeft[0] curTable = self.tablesLeft[0]
if not self.cursor: if not self.chunkRows:
self.cursor = self.cursorForTable(curTable) self.chunkRows = self.getChunkRows(curTable)
rows = self.cursor.fetchmany(lim) rows = self.chunkRows[:lim]
self.chunkRows = self.chunkRows[lim:]
fetched = len(rows) fetched = len(rows)
if fetched != lim: if fetched != lim:
# table is empty # table is empty
self.tablesLeft.pop(0) self.tablesLeft.pop(0)
self.cursor = None self.chunkRows = None
# mark the objects as having been sent # mark the objects as having been sent
self.col.db.execute( self.col.db.execute(
"update %s set usn=? where usn=-1" % curTable, self.maxUsn "update %s set usn=? where usn=-1" % curTable, self.maxUsn

View file

@ -110,30 +110,25 @@ class TagManager:
else: else:
l = "tags " l = "tags "
fn = self.remFromStr fn = self.remFromStr
lim = " or ".join([l + "like :_%d" % c for c, t in enumerate(newTags)]) lim = " or ".join(l + "like ?" for x in newTags)
res = self.col.db.all( res = self.col.db.all(
"select id, tags from notes where id in %s and (%s)" % (ids2str(ids), lim), "select id, tags from notes where id in %s and (%s)" % (ids2str(ids), lim),
**dict( *["%% %s %%" % y.replace("*", "%") for x, y in enumerate(newTags)],
[
("_%d" % x, "%% %s %%" % y.replace("*", "%"))
for x, y in enumerate(newTags)
]
),
) )
# update tags # update tags
nids = [] nids = []
def fix(row): def fix(row):
nids.append(row[0]) nids.append(row[0])
return { return [
"id": row[0], fn(tags, row[1]),
"t": fn(tags, row[1]), intTime(),
"n": intTime(), self.col.usn(),
"u": self.col.usn(), row[0],
} ]
self.col.db.executemany( self.col.db.executemany(
"update notes set tags=:t,mod=:n,usn=:u where id = :id", "update notes set tags=?,mod=?,usn=? where id = ?",
[fix(row) for row in res], [fix(row) for row in res],
) )

View file

@ -22,7 +22,7 @@ from hashlib import sha1
from html.entities import name2codepoint from html.entities import name2codepoint
from typing import Iterable, Iterator, List, Optional, Union from typing import Iterable, Iterator, List, Optional, Union
from anki.db import DB from anki.dbproxy import DBProxy
_tmpdir: Optional[str] _tmpdir: Optional[str]
@ -142,7 +142,7 @@ def ids2str(ids: Iterable[Union[int, str]]) -> str:
return "(%s)" % ",".join(str(i) for i in ids) return "(%s)" % ",".join(str(i) for i in ids)
def timestampID(db: DB, table: str) -> int: def timestampID(db: DBProxy, table: str) -> int:
"Return a non-conflicting timestamp for table." "Return a non-conflicting timestamp for table."
# be careful not to create multiple objects without flushing them, or they # be careful not to create multiple objects without flushing them, or they
# may share an ID. # may share an ID.
@ -152,7 +152,7 @@ def timestampID(db: DB, table: str) -> int:
return t return t
def maxID(db: DB) -> int: def maxID(db: DBProxy) -> int:
"Return the first safe ID to use." "Return the first safe ID to use."
now = intTime(1000) now = intTime(1000)
for tbl in "cards", "notes": for tbl in "cards", "notes":

View file

@ -21,6 +21,7 @@ setuptools.setup(
"requests", "requests",
"decorator", "decorator",
"protobuf", "protobuf",
'orjson; platform_machine == "x86_64"',
'psutil; sys_platform == "win32"', 'psutil; sys_platform == "win32"',
'distro; sys_platform != "darwin" and sys_platform != "win32"', 'distro; sys_platform != "darwin" and sys_platform != "win32"',
], ],

View file

@ -1,9 +1,22 @@
import os import os
import shutil import shutil
import tempfile import tempfile
import time
from anki import Collection as aopen from anki import Collection as aopen
# Between 2-4AM, shift the time back so test assumptions hold.
lt = time.localtime()
if lt.tm_hour >= 2 and lt.tm_hour < 4:
orig_time = time.time
def adjusted_time():
return orig_time() - 60 * 60 * 2
time.time = adjusted_time
else:
orig_time = None
def assertException(exception, func): def assertException(exception, func):
found = False found = False
@ -22,7 +35,7 @@ def getEmptyCol():
os.close(fd) os.close(fd)
os.unlink(nam) os.unlink(nam)
col = aopen(nam) col = aopen(nam)
col.db.close() col.close()
getEmptyCol.master = nam getEmptyCol.master = nam
(fd, nam) = tempfile.mkstemp(suffix=".anki2") (fd, nam) = tempfile.mkstemp(suffix=".anki2")
shutil.copy(getEmptyCol.master, nam) shutil.copy(getEmptyCol.master, nam)
@ -48,3 +61,15 @@ def getUpgradeDeckPath(name="anki12.anki"):
testDir = os.path.dirname(__file__) testDir = os.path.dirname(__file__)
def errorsAfterMidnight(func):
lt = time.localtime()
if lt.tm_hour < 4:
print("test disabled around cutoff", func)
else:
func()
def isNearCutoff():
return orig_time is not None

View file

@ -6,6 +6,7 @@ import tempfile
from anki import Collection as aopen from anki import Collection as aopen
from anki.exporting import * from anki.exporting import *
from anki.importing import Anki2Importer from anki.importing import Anki2Importer
from tests.shared import errorsAfterMidnight
from tests.shared import getEmptyCol as getEmptyColOrig from tests.shared import getEmptyCol as getEmptyColOrig
@ -97,6 +98,7 @@ def test_export_ankipkg():
e.exportInto(newname) e.exportInto(newname)
@errorsAfterMidnight
def test_export_anki_due(): def test_export_anki_due():
setup1() setup1()
deck = getEmptyCol() deck = getEmptyCol()

View file

@ -2,8 +2,8 @@
import pytest import pytest
from anki.consts import * from anki.consts import *
from anki.find import Finder from anki.rsbackend import BuiltinSortKind
from tests.shared import getEmptyCol from tests.shared import getEmptyCol, isNearCutoff
class DummyCollection: class DummyCollection:
@ -11,32 +11,6 @@ class DummyCollection:
return None return None
def test_parse():
f = Finder(DummyCollection())
assert f._tokenize("hello world") == ["hello", "world"]
assert f._tokenize("hello world") == ["hello", "world"]
assert f._tokenize("one -two") == ["one", "-", "two"]
assert f._tokenize("one --two") == ["one", "-", "two"]
assert f._tokenize("one - two") == ["one", "-", "two"]
assert f._tokenize("one or -two") == ["one", "or", "-", "two"]
assert f._tokenize("'hello \"world\"'") == ['hello "world"']
assert f._tokenize('"hello world"') == ["hello world"]
assert f._tokenize("one (two or ( three or four))") == [
"one",
"(",
"two",
"or",
"(",
"three",
"or",
"four",
")",
")",
]
assert f._tokenize("embedded'string") == ["embedded'string"]
assert f._tokenize("deck:'two words'") == ["deck:two words"]
def test_findCards(): def test_findCards():
deck = getEmptyCol() deck = getEmptyCol()
f = deck.newNote() f = deck.newNote()
@ -68,6 +42,7 @@ def test_findCards():
f["Front"] = "test" f["Front"] = "test"
f["Back"] = "foo bar" f["Back"] = "foo bar"
deck.addNote(f) deck.addNote(f)
deck.save()
latestCardIds = [c.id for c in f.cards()] latestCardIds = [c.id for c in f.cards()]
# tag searches # tag searches
assert len(deck.findCards("tag:*")) == 5 assert len(deck.findCards("tag:*")) == 5
@ -117,9 +92,8 @@ def test_findCards():
assert len(deck.findCards("nid:%d" % f.id)) == 2 assert len(deck.findCards("nid:%d" % f.id)) == 2
assert len(deck.findCards("nid:%d,%d" % (f1id, f2id))) == 2 assert len(deck.findCards("nid:%d,%d" % (f1id, f2id))) == 2
# templates # templates
with pytest.raises(Exception): assert len(deck.findCards("card:foo")) == 0
deck.findCards("card:foo") assert len(deck.findCards('"card:card 1"')) == 4
assert len(deck.findCards("'card:card 1'")) == 4
assert len(deck.findCards("card:reverse")) == 1 assert len(deck.findCards("card:reverse")) == 1
assert len(deck.findCards("card:1")) == 4 assert len(deck.findCards("card:1")) == 4
assert len(deck.findCards("card:2")) == 1 assert len(deck.findCards("card:2")) == 1
@ -133,16 +107,28 @@ def test_findCards():
assert len(deck.findCards("front:*")) == 5 assert len(deck.findCards("front:*")) == 5
# ordering # ordering
deck.conf["sortType"] = "noteCrt" deck.conf["sortType"] = "noteCrt"
deck.flush()
assert deck.findCards("front:*", order=True)[-1] in latestCardIds assert deck.findCards("front:*", order=True)[-1] in latestCardIds
assert deck.findCards("", order=True)[-1] in latestCardIds assert deck.findCards("", order=True)[-1] in latestCardIds
deck.conf["sortType"] = "noteFld" deck.conf["sortType"] = "noteFld"
deck.flush()
assert deck.findCards("", order=True)[0] == catCard.id assert deck.findCards("", order=True)[0] == catCard.id
assert deck.findCards("", order=True)[-1] in latestCardIds assert deck.findCards("", order=True)[-1] in latestCardIds
deck.conf["sortType"] = "cardMod" deck.conf["sortType"] = "cardMod"
deck.flush()
assert deck.findCards("", order=True)[-1] in latestCardIds assert deck.findCards("", order=True)[-1] in latestCardIds
assert deck.findCards("", order=True)[0] == firstCardId assert deck.findCards("", order=True)[0] == firstCardId
deck.conf["sortBackwards"] = True deck.conf["sortBackwards"] = True
deck.flush()
assert deck.findCards("", order=True)[0] in latestCardIds assert deck.findCards("", order=True)[0] in latestCardIds
assert (
deck.find_cards("", order=BuiltinSortKind.CARD_DUE, reverse=False)[0]
== firstCardId
)
assert (
deck.find_cards("", order=BuiltinSortKind.CARD_DUE, reverse=True)[0]
!= firstCardId
)
# model # model
assert len(deck.findCards("note:basic")) == 5 assert len(deck.findCards("note:basic")) == 5
assert len(deck.findCards("-note:basic")) == 0 assert len(deck.findCards("-note:basic")) == 0
@ -153,8 +139,7 @@ def test_findCards():
assert len(deck.findCards("-deck:foo")) == 5 assert len(deck.findCards("-deck:foo")) == 5
assert len(deck.findCards("deck:def*")) == 5 assert len(deck.findCards("deck:def*")) == 5
assert len(deck.findCards("deck:*EFAULT")) == 5 assert len(deck.findCards("deck:*EFAULT")) == 5
with pytest.raises(Exception): assert len(deck.findCards("deck:*cefault")) == 0
deck.findCards("deck:*cefault")
# full search # full search
f = deck.newNote() f = deck.newNote()
f["Front"] = "hello<b>world</b>" f["Front"] = "hello<b>world</b>"
@ -177,6 +162,7 @@ def test_findCards():
deck.db.execute( deck.db.execute(
"update cards set did = ? where id = ?", deck.decks.id("Default::Child"), id "update cards set did = ? where id = ?", deck.decks.id("Default::Child"), id
) )
deck.save()
assert len(deck.findCards("deck:default")) == 7 assert len(deck.findCards("deck:default")) == 7
assert len(deck.findCards("deck:default::child")) == 1 assert len(deck.findCards("deck:default::child")) == 1
assert len(deck.findCards("deck:default -deck:default::*")) == 6 assert len(deck.findCards("deck:default -deck:default::*")) == 6
@ -195,33 +181,35 @@ def test_findCards():
assert len(deck.findCards("prop:ivl!=10")) > 1 assert len(deck.findCards("prop:ivl!=10")) > 1
assert len(deck.findCards("prop:due>0")) == 1 assert len(deck.findCards("prop:due>0")) == 1
# due dates should work # due dates should work
deck.sched.today = 15 assert len(deck.findCards("prop:due=29")) == 0
assert len(deck.findCards("prop:due=14")) == 0 assert len(deck.findCards("prop:due=30")) == 1
assert len(deck.findCards("prop:due=15")) == 1
assert len(deck.findCards("prop:due=16")) == 0
# including negatives
deck.sched.today = 32
assert len(deck.findCards("prop:due=-1")) == 0
assert len(deck.findCards("prop:due=-2")) == 1
# ease factors # ease factors
assert len(deck.findCards("prop:ease=2.3")) == 0 assert len(deck.findCards("prop:ease=2.3")) == 0
assert len(deck.findCards("prop:ease=2.2")) == 1 assert len(deck.findCards("prop:ease=2.2")) == 1
assert len(deck.findCards("prop:ease>2")) == 1 assert len(deck.findCards("prop:ease>2")) == 1
assert len(deck.findCards("-prop:ease>2")) > 1 assert len(deck.findCards("-prop:ease>2")) > 1
# recently failed # recently failed
assert len(deck.findCards("rated:1:1")) == 0 if not isNearCutoff():
assert len(deck.findCards("rated:1:2")) == 0 assert len(deck.findCards("rated:1:1")) == 0
c = deck.sched.getCard() assert len(deck.findCards("rated:1:2")) == 0
deck.sched.answerCard(c, 2) c = deck.sched.getCard()
assert len(deck.findCards("rated:1:1")) == 0 deck.sched.answerCard(c, 2)
assert len(deck.findCards("rated:1:2")) == 1 assert len(deck.findCards("rated:1:1")) == 0
c = deck.sched.getCard() assert len(deck.findCards("rated:1:2")) == 1
deck.sched.answerCard(c, 1) c = deck.sched.getCard()
assert len(deck.findCards("rated:1:1")) == 1 deck.sched.answerCard(c, 1)
assert len(deck.findCards("rated:1:2")) == 1 assert len(deck.findCards("rated:1:1")) == 1
assert len(deck.findCards("rated:1")) == 2 assert len(deck.findCards("rated:1:2")) == 1
assert len(deck.findCards("rated:0:2")) == 0 assert len(deck.findCards("rated:1")) == 2
assert len(deck.findCards("rated:2:2")) == 1 assert len(deck.findCards("rated:0:2")) == 0
assert len(deck.findCards("rated:2:2")) == 1
# added
assert len(deck.findCards("added:0")) == 0
deck.db.execute("update cards set id = id - 86400*1000 where id = ?", id)
assert len(deck.findCards("added:1")) == deck.cardCount() - 1
assert len(deck.findCards("added:2")) == deck.cardCount()
else:
print("some find tests disabled near cutoff")
# empty field # empty field
assert len(deck.findCards("front:")) == 0 assert len(deck.findCards("front:")) == 0
f = deck.newNote() f = deck.newNote()
@ -235,17 +223,7 @@ def test_findCards():
assert len(deck.findCards("-(tag:monkey OR tag:sheep)")) == 6 assert len(deck.findCards("-(tag:monkey OR tag:sheep)")) == 6
assert len(deck.findCards("tag:monkey or (tag:sheep sheep)")) == 2 assert len(deck.findCards("tag:monkey or (tag:sheep sheep)")) == 2
assert len(deck.findCards("tag:monkey or (tag:sheep octopus)")) == 1 assert len(deck.findCards("tag:monkey or (tag:sheep octopus)")) == 1
# invalid grouping shouldn't error
assert len(deck.findCards(")")) == 0
assert len(deck.findCards("(()")) == 0
# added
assert len(deck.findCards("added:0")) == 0
deck.db.execute("update cards set id = id - 86400*1000 where id = ?", id)
assert len(deck.findCards("added:1")) == deck.cardCount() - 1
assert len(deck.findCards("added:2")) == deck.cardCount()
# flag # flag
with pytest.raises(Exception):
deck.findCards("flag:01")
with pytest.raises(Exception): with pytest.raises(Exception):
deck.findCards("flag:12") deck.findCards("flag:12")

View file

@ -73,8 +73,6 @@ def test_deckIntegration():
with open(os.path.join(d.media.dir(), "foo.jpg"), "w") as f: with open(os.path.join(d.media.dir(), "foo.jpg"), "w") as f:
f.write("test") f.write("test")
# check media # check media
d.close()
ret = d.media.check() ret = d.media.check()
d.reopen()
assert ret.missing == ["fake2.png"] assert ret.missing == ["fake2.png"]
assert ret.unused == ["foo.jpg"] assert ret.unused == ["foo.jpg"]

View file

@ -16,17 +16,6 @@ def getEmptyCol():
return col return col
# Between 2-4AM, shift the time back so test assumptions hold.
lt = time.localtime()
if lt.tm_hour >= 2 and lt.tm_hour < 4:
orig_time = time.time
def adjusted_time():
return orig_time() - 60 * 60 * 2
time.time = adjusted_time
def test_clock(): def test_clock():
d = getEmptyCol() d = getEmptyCol()
if (d.sched.dayCutoff - intTime()) < 10 * 60: if (d.sched.dayCutoff - intTime()) < 10 * 60:

View file

@ -37,11 +37,6 @@ hooks = [
args=["exporters: List[Tuple[str, Any]]"], args=["exporters: List[Tuple[str, Any]]"],
legacy_hook="exportersList", legacy_hook="exportersList",
), ),
Hook(
name="search_terms_prepared",
args=["searches: Dict[str, Callable]"],
legacy_hook="search",
),
Hook( Hook(
name="note_type_added", name="note_type_added",
args=["notetype: Dict[str, Any]"], args=["notetype: Dict[str, Any]"],

View file

@ -17,6 +17,7 @@ import anki.lang
import aqt.buildinfo import aqt.buildinfo
from anki import version as _version from anki import version as _version
from anki.consts import HELP_SITE from anki.consts import HELP_SITE
from anki.rsbackend import RustBackend
from anki.utils import checksum, isLin, isMac from anki.utils import checksum, isLin, isMac
from aqt.qt import * from aqt.qt import *
from aqt.utils import locale_dir from aqt.utils import locale_dir
@ -162,15 +163,15 @@ dialogs = DialogManager()
# Qt requires its translator to be installed before any GUI widgets are # Qt requires its translator to be installed before any GUI widgets are
# loaded, and we need the Qt language to match the gettext language or # loaded, and we need the Qt language to match the gettext language or
# translated shortcuts will not work. # translated shortcuts will not work.
#
# The Qt translator needs to be retained to work.
# A reference to the Qt translator needs to be held to prevent it from
# being immediately deallocated.
_qtrans: Optional[QTranslator] = None _qtrans: Optional[QTranslator] = None
def setupLang( def setupLangAndBackend(
pm: ProfileManager, app: QApplication, force: Optional[str] = None pm: ProfileManager, app: QApplication, force: Optional[str] = None
) -> None: ) -> RustBackend:
global _qtrans global _qtrans
try: try:
locale.setlocale(locale.LC_ALL, "") locale.setlocale(locale.LC_ALL, "")
@ -218,6 +219,8 @@ def setupLang(
if _qtrans.load("qtbase_" + qt_lang, qt_dir): if _qtrans.load("qtbase_" + qt_lang, qt_dir):
app.installTranslator(_qtrans) app.installTranslator(_qtrans)
return anki.lang.current_i18n
# App initialisation # App initialisation
########################################################################## ##########################################################################
@ -465,8 +468,8 @@ environment points to a valid, writable folder.""",
if opts.profile: if opts.profile:
pm.openProfile(opts.profile) pm.openProfile(opts.profile)
# i18n # i18n & backend
setupLang(pm, app, opts.lang) backend = setupLangAndBackend(pm, app, opts.lang)
if isLin and pm.glMode() == "auto": if isLin and pm.glMode() == "auto":
from aqt.utils import gfxDriverIsBroken from aqt.utils import gfxDriverIsBroken
@ -483,7 +486,7 @@ environment points to a valid, writable folder.""",
# load the main window # load the main window
import aqt.main import aqt.main
mw = aqt.main.AnkiQt(app, pm, opts, args) mw = aqt.main.AnkiQt(app, pm, backend, opts, args)
if exec: if exec:
app.exec() app.exec()
else: else:

View file

@ -167,8 +167,12 @@ class AddCards(QDialog):
def addNote(self, note) -> Optional[Note]: def addNote(self, note) -> Optional[Note]:
note.model()["did"] = self.deckChooser.selectedId() note.model()["did"] = self.deckChooser.selectedId()
ret = note.dupeOrEmpty() ret = note.dupeOrEmpty()
problem = None
if ret == 1: if ret == 1:
showWarning(_("The first field is empty."), help="AddItems#AddError") problem = _("The first field is empty.")
problem = gui_hooks.add_cards_will_add_note(problem, note)
if problem is not None:
showWarning(problem, help="AddItems#AddError")
return None return None
if "{{cloze:" in note.model()["tmpls"][0]["qfmt"]: if "{{cloze:" in note.model()["tmpls"][0]["qfmt"]:
if not self.mw.col.models._availClozeOrds( if not self.mw.col.models._availClozeOrds(

View file

@ -13,7 +13,7 @@ import unicodedata
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from operator import itemgetter from operator import itemgetter
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Sequence, Union
import anki import anki
import aqt.forms import aqt.forms
@ -69,6 +69,14 @@ class FindDupesDialog:
browser: Browser browser: Browser
@dataclass
class SearchContext:
search: str
order: Union[bool, str] = True
# if set, provided card ids will be used instead of the regular search
card_ids: Optional[Sequence[int]] = None
# Data model # Data model
########################################################################## ##########################################################################
@ -82,7 +90,7 @@ class DataModel(QAbstractTableModel):
self.activeCols = self.col.conf.get( self.activeCols = self.col.conf.get(
"activeCols", ["noteFld", "template", "cardDue", "deck"] "activeCols", ["noteFld", "template", "cardDue", "deck"]
) )
self.cards: List[int] = [] self.cards: Sequence[int] = []
self.cardObjs: Dict[int, Card] = {} self.cardObjs: Dict[int, Card] = {}
def getCard(self, index: QModelIndex) -> Card: def getCard(self, index: QModelIndex) -> Card:
@ -169,23 +177,22 @@ class DataModel(QAbstractTableModel):
# Filtering # Filtering
###################################################################### ######################################################################
def search(self, txt): def search(self, txt: str) -> None:
self.beginReset() self.beginReset()
t = time.time()
# the db progress handler may cause a refresh, so we need to zero out
# old data first
self.cards = [] self.cards = []
invalid = False invalid = False
try: try:
self.cards = self.col.findCards(txt, order=True) ctx = SearchContext(search=txt)
gui_hooks.browser_will_search(ctx)
if ctx.card_ids is None:
ctx.card_ids = self.col.find_cards(txt, order=ctx.order)
gui_hooks.browser_did_search(ctx)
self.cards = ctx.card_ids
except Exception as e: except Exception as e:
if str(e) == "invalidSearch": print("search failed:", e)
self.cards = [] invalid = True
invalid = True finally:
else: self.endReset()
raise
# print "fetch cards in %dms" % ((time.time() - t)*1000)
self.endReset()
if invalid: if invalid:
showWarning(_("Invalid search - please check for typing mistakes.")) showWarning(_("Invalid search - please check for typing mistakes."))
@ -213,7 +220,7 @@ class DataModel(QAbstractTableModel):
def _reverse(self): def _reverse(self):
self.beginReset() self.beginReset()
self.cards.reverse() self.cards = list(reversed(self.cards))
self.endReset() self.endReset()
def saveSelection(self): def saveSelection(self):
@ -275,6 +282,9 @@ class DataModel(QAbstractTableModel):
def columnType(self, column): def columnType(self, column):
return self.activeCols[column] return self.activeCols[column]
def time_format(self):
return "%Y-%m-%d"
def columnData(self, index): def columnData(self, index):
row = index.row() row = index.row()
col = index.column() col = index.column()
@ -302,11 +312,11 @@ class DataModel(QAbstractTableModel):
t = "(" + t + ")" t = "(" + t + ")"
return t return t
elif type == "noteCrt": elif type == "noteCrt":
return time.strftime("%Y-%m-%d", time.localtime(c.note().id / 1000)) return time.strftime(self.time_format(), time.localtime(c.note().id / 1000))
elif type == "noteMod": elif type == "noteMod":
return time.strftime("%Y-%m-%d", time.localtime(c.note().mod)) return time.strftime(self.time_format(), time.localtime(c.note().mod))
elif type == "cardMod": elif type == "cardMod":
return time.strftime("%Y-%m-%d", time.localtime(c.mod)) return time.strftime(self.time_format(), time.localtime(c.mod))
elif type == "cardReps": elif type == "cardReps":
return str(c.reps) return str(c.reps)
elif type == "cardLapses": elif type == "cardLapses":
@ -363,7 +373,7 @@ class DataModel(QAbstractTableModel):
date = time.time() + ((c.due - self.col.sched.today) * 86400) date = time.time() + ((c.due - self.col.sched.today) * 86400)
else: else:
return "" return ""
return time.strftime("%Y-%m-%d", time.localtime(date)) return time.strftime(self.time_format(), time.localtime(date))
def isRTL(self, index): def isRTL(self, index):
col = index.column() col = index.column()
@ -388,15 +398,12 @@ class StatusDelegate(QItemDelegate):
self.model = model self.model = model
def paint(self, painter, option, index): def paint(self, painter, option, index):
self.browser.mw.progress.blockUpdates = True
try: try:
c = self.model.getCard(index) c = self.model.getCard(index)
except: except:
# in the middle of a reset; return nothing so this row is not # in the middle of a reset; return nothing so this row is not
# rendered until we have a chance to reset the model # rendered until we have a chance to reset the model
return return
finally:
self.browser.mw.progress.blockUpdates = True
if self.model.isRTL(index): if self.model.isRTL(index):
option.direction = Qt.RightToLeft option.direction = Qt.RightToLeft
@ -833,6 +840,7 @@ class Browser(QMainWindow):
self.form.tableView.selectionModel() self.form.tableView.selectionModel()
self.form.tableView.setItemDelegate(StatusDelegate(self, self.model)) self.form.tableView.setItemDelegate(StatusDelegate(self, self.model))
self.form.tableView.selectionModel().selectionChanged.connect(self.onRowChanged) self.form.tableView.selectionModel().selectionChanged.connect(self.onRowChanged)
self.form.tableView.setWordWrap(False)
if not theme_manager.night_mode: if not theme_manager.night_mode:
self.form.tableView.setStyleSheet( self.form.tableView.setStyleSheet(
"QTableView{ selection-background-color: rgba(150, 150, 150, 50); " "QTableView{ selection-background-color: rgba(150, 150, 150, 50); "
@ -915,31 +923,11 @@ QTableView {{ gridline-color: {grid} }}
def _onSortChanged(self, idx, ord): def _onSortChanged(self, idx, ord):
type = self.model.activeCols[idx] type = self.model.activeCols[idx]
noSort = ("question", "answer", "template", "deck", "note", "noteTags") noSort = ("question", "answer")
if type in noSort: if type in noSort:
if type == "template": showInfo(
showInfo( _("Sorting on this column is not supported. Please " "choose another.")
_( )
"""\
This column can't be sorted on, but you can search for individual card types, \
such as 'card:1'."""
)
)
elif type == "deck":
showInfo(
_(
"""\
This column can't be sorted on, but you can search for specific decks \
by clicking on one on the left."""
)
)
else:
showInfo(
_(
"Sorting on this column is not supported. Please "
"choose another."
)
)
type = self.col.conf["sortType"] type = self.col.conf["sortType"]
if self.col.conf["sortType"] != type: if self.col.conf["sortType"] != type:
self.col.conf["sortType"] = type self.col.conf["sortType"] = type
@ -947,10 +935,14 @@ by clicking on one on the left."""
if type == "noteFld": if type == "noteFld":
ord = not ord ord = not ord
self.col.conf["sortBackwards"] = ord self.col.conf["sortBackwards"] = ord
self.col.setMod()
self.col.save()
self.search() self.search()
else: else:
if self.col.conf["sortBackwards"] != ord: if self.col.conf["sortBackwards"] != ord:
self.col.conf["sortBackwards"] = ord self.col.conf["sortBackwards"] = ord
self.col.setMod()
self.col.save()
self.model.reverse() self.model.reverse()
self.setSortIndicator() self.setSortIndicator()

View file

@ -6,6 +6,7 @@ from __future__ import annotations
import os import os
import re import re
import time import time
from concurrent.futures import Future
from typing import List, Optional from typing import List, Optional
import aqt import aqt
@ -25,7 +26,7 @@ class ExportDialog(QDialog):
): ):
QDialog.__init__(self, mw, Qt.Window) QDialog.__init__(self, mw, Qt.Window)
self.mw = mw self.mw = mw
self.col = mw.col self.col = mw.col.weakref()
self.frm = aqt.forms.exporting.Ui_ExportDialog() self.frm = aqt.forms.exporting.Ui_ExportDialog()
self.frm.setupUi(self) self.frm.setupUi(self)
self.exporter = None self.exporter = None
@ -131,7 +132,7 @@ class ExportDialog(QDialog):
break break
self.hide() self.hide()
if file: if file:
self.mw.progress.start(immediate=True) # check we can write to file
try: try:
f = open(file, "wb") f = open(file, "wb")
f.close() f.close()
@ -139,38 +140,51 @@ class ExportDialog(QDialog):
showWarning(_("Couldn't save file: %s") % str(e)) showWarning(_("Couldn't save file: %s") % str(e))
else: else:
os.unlink(file) os.unlink(file)
exportedMedia = lambda cnt: self.mw.progress.update(
label=ngettext( # progress handler
"Exported %d media file", "Exported %d media files", cnt def exported_media(cnt):
self.mw.taskman.run_on_main(
lambda: self.mw.progress.update(
label=ngettext(
"Exported %d media file", "Exported %d media files", cnt
)
% cnt
) )
% cnt
) )
hooks.media_files_did_export.append(exportedMedia)
def do_export():
self.exporter.exportInto(file) self.exporter.exportInto(file)
hooks.media_files_did_export.remove(exportedMedia)
period = 3000 def on_done(future: Future):
if self.isVerbatim:
msg = _("Collection exported.")
else:
if self.isTextNote:
msg = (
ngettext(
"%d note exported.",
"%d notes exported.",
self.exporter.count,
)
% self.exporter.count
)
else:
msg = (
ngettext(
"%d card exported.",
"%d cards exported.",
self.exporter.count,
)
% self.exporter.count
)
tooltip(msg, period=period)
finally:
self.mw.progress.finish() self.mw.progress.finish()
QDialog.accept(self) hooks.media_files_did_export.remove(exported_media)
# raises if exporter failed
future.result()
self.on_export_finished()
self.mw.progress.start(immediate=True)
hooks.media_files_did_export.append(exported_media)
self.mw.taskman.run_in_background(do_export, on_done)
def on_export_finished(self):
if self.isVerbatim:
msg = _("Collection exported.")
self.mw.reopen()
else:
if self.isTextNote:
msg = (
ngettext(
"%d note exported.", "%d notes exported.", self.exporter.count,
)
% self.exporter.count
)
else:
msg = (
ngettext(
"%d card exported.", "%d cards exported.", self.exporter.count,
)
% self.exporter.count
)
tooltip(msg, period=3000)
QDialog.reject(self)
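The export path above is the first of several places this commit moves blocking work off the main thread; the recurring shape, assuming mw.taskman.run_in_background(task, on_done) invokes on_done with the worker's Future (expensive_step is a placeholder for the real work):

    from concurrent.futures import Future

    def do_work():
        return expensive_step()    # runs in a background thread

    def on_done(future: Future):
        mw.progress.finish()       # always dismiss the progress dialog
        result = future.result()   # re-raises any exception from do_work()
        tooltip(str(result))

    mw.progress.start(immediate=True)
    mw.taskman.run_in_background(do_work, on_done)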

View file

@ -49,6 +49,43 @@ class _AddCardsDidAddNoteHook:
add_cards_did_add_note = _AddCardsDidAddNoteHook() add_cards_did_add_note = _AddCardsDidAddNoteHook()
class _AddCardsWillAddNoteFilter:
"""Decides whether the note should be added to the collection or
not. It is assumed to come from the AddCards window.
problem is the first reason to reject that was found, or None.
If your filter wants to reject the note, it should return the
reason to reject; otherwise return the input unchanged."""
_hooks: List[Callable[[Optional[str], "anki.notes.Note"], Optional[str]]] = []
def append(
self, cb: Callable[[Optional[str], "anki.notes.Note"], Optional[str]]
) -> None:
"""(problem: Optional[str], note: anki.notes.Note)"""
self._hooks.append(cb)
def remove(
self, cb: Callable[[Optional[str], "anki.notes.Note"], Optional[str]]
) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, problem: Optional[str], note: anki.notes.Note) -> Optional[str]:
for filter in self._hooks:
try:
problem = filter(problem, note)
except:
# if the hook fails, remove it
self._hooks.remove(filter)
raise
return problem
add_cards_will_add_note = _AddCardsWillAddNoteFilter()
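A minimal sketch of an add-on registering this filter (the veto condition is illustrative):

    from aqt import gui_hooks

    def my_filter(problem, note):
        if problem is None and not note.fields[0].strip():
            return "first field is blank"  # veto with a reason
        return problem                     # otherwise pass through

    gui_hooks.add_cards_will_add_note.append(my_filter)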
class _AddCardsWillShowHistoryMenuHook: class _AddCardsWillShowHistoryMenuHook:
_hooks: List[Callable[["aqt.addcards.AddCards", QMenu], None]] = [] _hooks: List[Callable[["aqt.addcards.AddCards", QMenu], None]] = []
@ -272,6 +309,30 @@ class _AvPlayerWillPlayHook:
av_player_will_play = _AvPlayerWillPlayHook() av_player_will_play = _AvPlayerWillPlayHook()
class _BackupDidCompleteHook:
_hooks: List[Callable[[], None]] = []
def append(self, cb: Callable[[], None]) -> None:
"""()"""
self._hooks.append(cb)
def remove(self, cb: Callable[[], None]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self) -> None:
for hook in self._hooks:
try:
hook()
except:
# if the hook fails, remove it
self._hooks.remove(hook)
raise
backup_did_complete = _BackupDidCompleteHook()
class _BrowserDidChangeRowHook: class _BrowserDidChangeRowHook:
_hooks: List[Callable[["aqt.browser.Browser"], None]] = [] _hooks: List[Callable[["aqt.browser.Browser"], None]] = []
@ -298,6 +359,32 @@ class _BrowserDidChangeRowHook:
browser_did_change_row = _BrowserDidChangeRowHook() browser_did_change_row = _BrowserDidChangeRowHook()
class _BrowserDidSearchHook:
"""Allows you to modify the list of returned card ids from a search."""
_hooks: List[Callable[["aqt.browser.SearchContext"], None]] = []
def append(self, cb: Callable[["aqt.browser.SearchContext"], None]) -> None:
"""(context: aqt.browser.SearchContext)"""
self._hooks.append(cb)
def remove(self, cb: Callable[["aqt.browser.SearchContext"], None]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, context: aqt.browser.SearchContext) -> None:
for hook in self._hooks:
try:
hook(context)
except:
# if the hook fails, remove it
self._hooks.remove(hook)
raise
browser_did_search = _BrowserDidSearchHook()
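A sketch of an add-on post-processing results via this hook (the filtering rule is illustrative):

    from aqt import gui_hooks

    def drop_even_ids(ctx):
        # ctx is an aqt.browser.SearchContext; card_ids is already populated here
        ctx.card_ids = [cid for cid in ctx.card_ids if cid % 2]

    gui_hooks.browser_did_search.append(drop_even_ids)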
class _BrowserMenusDidInitHook: class _BrowserMenusDidInitHook:
_hooks: List[Callable[["aqt.browser.Browser"], None]] = [] _hooks: List[Callable[["aqt.browser.Browser"], None]] = []
@ -423,6 +510,42 @@ class _BrowserWillBuildTreeFilter:
browser_will_build_tree = _BrowserWillBuildTreeFilter() browser_will_build_tree = _BrowserWillBuildTreeFilter()
class _BrowserWillSearchHook:
"""Allows you to modify the search text, or perform your own search.
You can modify context.search to change the text that is sent to the
searching backend.
If you set context.card_ids to a list of ids, the regular search will
not be performed, and the provided ids will be used instead.
Your add-on should check if context.card_ids is not None, and return
without making changes if it has been set.
"""
_hooks: List[Callable[["aqt.browser.SearchContext"], None]] = []
def append(self, cb: Callable[["aqt.browser.SearchContext"], None]) -> None:
"""(context: aqt.browser.SearchContext)"""
self._hooks.append(cb)
def remove(self, cb: Callable[["aqt.browser.SearchContext"], None]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, context: aqt.browser.SearchContext) -> None:
for hook in self._hooks:
try:
hook(context)
except:
# if the hook fails, remove it
self._hooks.remove(hook)
raise
browser_will_search = _BrowserWillSearchHook()
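And the corresponding pre-search filter; per the docstring above, it should leave the context alone if another add-on has already supplied card_ids (the search rewrite is illustrative):

    from aqt import gui_hooks

    def expand_shortcut(ctx):
        if ctx.card_ids is not None:
            return  # another add-on already performed the search
        ctx.search = ctx.search.replace("mydecks:", "deck:Main or deck:Side ")

    gui_hooks.browser_will_search.append(expand_shortcut)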
class _BrowserWillShowHook: class _BrowserWillShowHook:
_hooks: List[Callable[["aqt.browser.Browser"], None]] = [] _hooks: List[Callable[["aqt.browser.Browser"], None]] = []
@ -1206,6 +1329,30 @@ class _MediaSyncDidStartOrStopHook:
media_sync_did_start_or_stop = _MediaSyncDidStartOrStopHook() media_sync_did_start_or_stop = _MediaSyncDidStartOrStopHook()
class _ModelsAdvancedWillShowHook:
_hooks: List[Callable[[QDialog], None]] = []
def append(self, cb: Callable[[QDialog], None]) -> None:
"""(advanced: QDialog)"""
self._hooks.append(cb)
def remove(self, cb: Callable[[QDialog], None]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, advanced: QDialog) -> None:
for hook in self._hooks:
try:
hook(advanced)
except:
# if the hook fails, remove it
self._hooks.remove(hook)
raise
models_advanced_will_show = _ModelsAdvancedWillShowHook()
class _OverviewDidRefreshHook: class _OverviewDidRefreshHook:
"""Allow to update the overview window. E.g. add the deck name in the """Allow to update the overview window. E.g. add the deck name in the
title.""" title."""

View file

@ -9,6 +9,7 @@ import shutil
import traceback import traceback
import unicodedata import unicodedata
import zipfile import zipfile
from concurrent.futures import Future
import anki.importing as importing import anki.importing as importing
import aqt.deckchooser import aqt.deckchooser
@ -74,6 +75,7 @@ class ChangeMap(QDialog):
self.accept() self.accept()
# called by importFile() when importing a mappable file like .csv
class ImportDialog(QDialog): class ImportDialog(QDialog):
def __init__(self, mw: AnkiQt, importer) -> None: def __init__(self, mw: AnkiQt, importer) -> None:
QDialog.__init__(self, mw, Qt.Window) QDialog.__init__(self, mw, Qt.Window)
@ -192,30 +194,35 @@ you can enter it here. Use \\t to represent tab."""
self.mw.col.decks.select(did) self.mw.col.decks.select(did)
self.mw.progress.start(immediate=True) self.mw.progress.start(immediate=True)
self.mw.checkpoint(_("Import")) self.mw.checkpoint(_("Import"))
-        try:
-            self.importer.run()
-        except UnicodeDecodeError:
-            showUnicodeWarning()
-            return
-        except Exception as e:
-            msg = tr(TR.IMPORTING_FAILED_DEBUG_INFO) + "\n"
-            err = repr(str(e))
-            if "1-character string" in err:
-                msg += err
-            elif "invalidTempFolder" in err:
-                msg += self.mw.errorHandler.tempFolderMsg()
-            else:
-                msg += traceback.format_exc()
-            showText(msg)
-            return
-        finally:
-            self.mw.progress.finish()
-        txt = _("Importing complete.") + "\n"
-        if self.importer.log:
-            txt += "\n".join(self.importer.log)
-        self.close()
-        showText(txt)
-        self.mw.reset()
+        def on_done(future: Future):
+            self.mw.progress.finish()
+
+            try:
+                future.result()
+            except UnicodeDecodeError:
+                showUnicodeWarning()
+                return
+            except Exception as e:
+                msg = tr(TR.IMPORTING_FAILED_DEBUG_INFO) + "\n"
+                err = repr(str(e))
+                if "1-character string" in err:
+                    msg += err
+                elif "invalidTempFolder" in err:
+                    msg += self.mw.errorHandler.tempFolderMsg()
+                else:
+                    msg += traceback.format_exc()
+                showText(msg)
+                return
+            else:
+                txt = _("Importing complete.") + "\n"
+                if self.importer.log:
+                    txt += "\n".join(self.importer.log)
+                self.close()
+                showText(txt)
+                self.mw.reset()
+
+        self.mw.taskman.run_in_background(self.importer.run, on_done)
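    # A minimal sketch (illustrative; not part of this commit) of the
    # taskman pattern used above: run a blocking callable on a worker
    # thread, then handle its outcome back on the main thread. Calling
    # future.result() re-raises any exception raised by the worker.
    #
    #   def do_slow_work():
    #       return "result"
    #
    #   def on_done(future: Future) -> None:
    #       try:
    #           result = future.result()
    #       except Exception as e:
    #           showWarning(str(e))
    #           return
    #       tooltip(result)
    #
    #   mw.taskman.run_in_background(do_slow_work, on_done)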
    def setupMappingFrame(self):
        # qt seems to have a bug with adding/removing from a grid, so we add
@ -380,45 +387,52 @@ def importFile(mw, file):
    except:
        showWarning(invalidZipMsg())
        return
-    # we need to ask whether to import/replace
+    # we need to ask whether to import/replace; if it's
+    # a colpkg file then the rest of the import process
+    # will happen in setupApkgImport()
    if not setupApkgImport(mw, importer):
        return

+    # importing non-colpkg files
    mw.progress.start(immediate=True)
-    try:
-        try:
-            importer.run()
-        finally:
-            mw.progress.finish()
-    except zipfile.BadZipfile:
-        showWarning(invalidZipMsg())
-    except Exception as e:
-        err = repr(str(e))
-        if "invalidFile" in err:
-            msg = _(
-                """\
-Invalid file. Please restore from backup."""
-            )
-            showWarning(msg)
-        elif "invalidTempFolder" in err:
-            showWarning(mw.errorHandler.tempFolderMsg())
-        elif "readonly" in err:
-            showWarning(
-                _(
-                    """\
-Unable to import from a read-only file."""
-                )
-            )
-        else:
-            msg = tr(TR.IMPORTING_FAILED_DEBUG_INFO) + "\n"
-            msg += str(traceback.format_exc())
-            showText(msg)
-    else:
-        log = "\n".join(importer.log)
-        if "\n" not in log:
-            tooltip(log)
-        else:
-            showText(log)
-    mw.reset()
+    def on_done(future: Future):
+        mw.progress.finish()
+        try:
+            future.result()
+        except zipfile.BadZipfile:
+            showWarning(invalidZipMsg())
+        except Exception as e:
+            err = repr(str(e))
+            if "invalidFile" in err:
+                msg = _(
+                    """\
+Invalid file. Please restore from backup."""
+                )
+                showWarning(msg)
+            elif "invalidTempFolder" in err:
+                showWarning(mw.errorHandler.tempFolderMsg())
+            elif "readonly" in err:
+                showWarning(
+                    _(
+                        """\
+Unable to import from a read-only file."""
+                    )
+                )
+            else:
+                msg = tr(TR.IMPORTING_FAILED_DEBUG_INFO) + "\n"
+                msg += str(traceback.format_exc())
+                showText(msg)
+        else:
+            log = "\n".join(importer.log)
+            if "\n" not in log:
+                tooltip(log)
+            else:
+                showText(log)
+
+        mw.reset()
+
+    mw.taskman.run_in_background(importer.run, on_done)
def invalidZipMsg():
@ -459,48 +473,57 @@ def replaceWithApkg(mw, file, backup):
    mw.unloadCollection(lambda: _replaceWithApkg(mw, file, backup))
-def _replaceWithApkg(mw, file, backup):
+def _replaceWithApkg(mw, filename, backup):
    mw.progress.start(immediate=True)

-    z = zipfile.ZipFile(file)
-
-    # v2 scheduler?
-    colname = "collection.anki21"
-    try:
-        z.getinfo(colname)
-    except KeyError:
-        colname = "collection.anki2"
-
-    try:
-        with z.open(colname) as source, open(mw.pm.collectionPath(), "wb") as target:
-            shutil.copyfileobj(source, target)
-    except:
-        mw.progress.finish()
-        showWarning(_("The provided file is not a valid .apkg file."))
-        return
-    # because users don't have a backup of media, it's safer to import new
-    # data and rely on them running a media db check to get rid of any
-    # unwanted media. in the future we might also want to deduplicate this
-    # step
-    d = os.path.join(mw.pm.profileFolder(), "collection.media")
-    for n, (cStr, file) in enumerate(
-        json.loads(z.read("media").decode("utf8")).items()
-    ):
-        mw.progress.update(
-            ngettext("Processed %d media file", "Processed %d media files", n) % n
-        )
-        size = z.getinfo(cStr).file_size
-        dest = os.path.join(d, unicodedata.normalize("NFC", file))
-        # if we have a matching file size
-        if os.path.exists(dest) and size == os.stat(dest).st_size:
-            continue
-        data = z.read(cStr)
-        open(dest, "wb").write(data)
-    z.close()
-    # reload
-    if not mw.loadCollection():
-        mw.progress.finish()
-        return
-    if backup:
-        mw.col.modSchema(check=False)
-    mw.progress.finish()
+    def do_import():
+        z = zipfile.ZipFile(filename)
+
+        # v2 scheduler?
+        colname = "collection.anki21"
+        try:
+            z.getinfo(colname)
+        except KeyError:
+            colname = "collection.anki2"
+
+        with z.open(colname) as source, open(mw.pm.collectionPath(), "wb") as target:
+            shutil.copyfileobj(source, target)
+
+        d = os.path.join(mw.pm.profileFolder(), "collection.media")
+        for n, (cStr, file) in enumerate(
+            json.loads(z.read("media").decode("utf8")).items()
+        ):
+            mw.taskman.run_on_main(
+                lambda n=n: mw.progress.update(
+                    ngettext("Processed %d media file", "Processed %d media files", n)
+                    % n
+                )
+            )
+            size = z.getinfo(cStr).file_size
+            dest = os.path.join(d, unicodedata.normalize("NFC", file))
+            # if we have a matching file size
+            if os.path.exists(dest) and size == os.stat(dest).st_size:
+                continue
+            data = z.read(cStr)
+            open(dest, "wb").write(data)
+        z.close()
+
+    def on_done(future: Future):
+        mw.progress.finish()
+
+        try:
+            future.result()
+        except Exception as e:
+            print(e)
+            showWarning(_("The provided file is not a valid .apkg file."))
+            return
+
+        if not mw.loadCollection():
+            return
+        if backup:
+            mw.col.modSchema(check=False)
+
+        tooltip(_("Importing complete."))
+
+    mw.taskman.run_in_background(do_import, on_done)
View file
@ -12,6 +12,7 @@ import signal
import time
import zipfile
from argparse import Namespace
+from concurrent.futures import Future
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
@ -28,6 +29,7 @@ from anki import hooks
from anki.collection import _Collection
from anki.hooks import runHook
from anki.lang import _, ngettext
+from anki.rsbackend import RustBackend
from anki.sound import AVTag, SoundOrVideoTag
from anki.storage import Collection
from anki.utils import devMode, ids2str, intTime, isMac, isWin, splitFields
@ -77,10 +79,12 @@ class AnkiQt(QMainWindow):
        self,
        app: QApplication,
        profileManager: ProfileManagerType,
+        backend: RustBackend,
        opts: Namespace,
        args: List[Any],
    ) -> None:
        QMainWindow.__init__(self)
+        self.backend = backend
        self.state = "startup"
        self.opts = opts
        self.col: Optional[_Collection] = None
@ -393,7 +397,7 @@ close the profile or restart Anki."""
        # at this point there should be no windows left
        self._checkForUnclosedWidgets()
-        self.maybeAutoSync(True)
+        self.maybeAutoSync()

    def _checkForUnclosedWidgets(self) -> None:
        for w in self.app.topLevelWidgets():
@ -458,18 +462,22 @@ close the profile or restart Anki."""
    def _loadCollection(self) -> bool:
        cpath = self.pm.collectionPath()
-        self.col = Collection(cpath, log=True)
+        self.col = Collection(cpath, backend=self.backend)

        self.setEnabled(True)
-        self.progress.setupDB(self.col.db)
        self.maybeEnableUndo()
+        gui_hooks.collection_did_load(self.col)
        self.moveToState("deckBrowser")
        return True

+    def reopen(self):
+        cpath = self.pm.collectionPath()
+        self.col = Collection(cpath, backend=self.backend)
+
    def unloadCollection(self, onsuccess: Callable) -> None:
        def callback():
            self.setEnabled(False)
+            self.media_syncer.show_diag_until_finished()
            self._unloadCollection()
            onsuccess()
@ -561,6 +569,7 @@ from the profile screen."
            fname = backups.pop(0)
            path = os.path.join(dir, fname)
            os.unlink(path)
+        gui_hooks.backup_did_complete()

    def maybeOptimize(self) -> None:
        # have two weeks passed?
@ -594,14 +603,6 @@ from the profile screen."
        self.maybe_check_for_addon_updates()
        self.deckBrowser.show()

-    def _colLoadingState(self, oldState) -> None:
-        "Run once, when col is loaded."
-        self.enableColMenuItems()
-        # ensure cwd is set if media dir exists
-        self.col.media.dir()
-        gui_hooks.collection_did_load(self.col)
-        self.moveToState("overview")
-
    def _selectedDeck(self) -> Optional[Dict[str, Any]]:
        did = self.col.decks.selected()
        if not self.col.decks.nameOrNone(did):
@ -753,10 +754,7 @@ title="%s" %s>%s</button>""" % (
        signal.signal(signal.SIGINT, self.onSigInt)

    def onSigInt(self, signum, frame):
-        # interrupt any current transaction and schedule a rollback & quit
-        if self.col:
-            self.col.db.interrupt()
+        # schedule a rollback & quit

        def quit():
            self.col.db.rollback()
            self.close()
@ -841,7 +839,7 @@ title="%s" %s>%s</button>""" % (
        self.media_syncer.start()

    # expects a current profile, but no collection loaded
-    def maybeAutoSync(self, closing=False) -> None:
+    def maybeAutoSync(self) -> None:
        if (
            not self.pm.profile["syncKey"]
            or not self.pm.profile["autoSync"]
@ -853,10 +851,6 @@ title="%s" %s>%s</button>""" % (
        # ok to sync
        self._sync()

-        # if media still syncing at this point, pop up progress diag
-        if closing:
-            self.media_syncer.show_diag_until_finished()
-
    def maybe_auto_sync_media(self) -> None:
        if not self.pm.profile["autoSync"] or self.safeMode or self.restoringBackup:
            return
@ -1262,25 +1256,29 @@ will be lost. Continue?"""
    def onCheckDB(self):
        "True if no problems"
-        self.progress.start(immediate=True)
-        ret, ok = self.col.fixIntegrity()
-        self.progress.finish()
-        if not ok:
-            showText(ret)
-        else:
-            tooltip(ret)
-
-        # if an error has directed the user to check the database,
-        # silently clean up any broken reset hooks which distract from
-        # the underlying issue
-        while True:
-            try:
-                self.reset()
-                break
-            except Exception as e:
-                print("swallowed exception in reset hook:", e)
-                continue
-        return ret
+        self.progress.start()
+
+        def onDone(future: Future):
+            self.progress.finish()
+            ret, ok = future.result()
+
+            if not ok:
+                showText(ret)
+            else:
+                tooltip(ret)
+
+            # if an error has directed the user to check the database,
+            # silently clean up any broken reset hooks which distract from
+            # the underlying issue
+            while True:
+                try:
+                    self.reset()
+                    break
+                except Exception as e:
+                    print("swallowed exception in reset hook:", e)
+                    continue
+
+        self.taskman.run_in_background(self.col.fixIntegrity, onDone)
    def on_check_media_db(self) -> None:
        check_media_db(self)
@ -1363,11 +1361,42 @@ will be lost. Continue?"""
        sys.stderr = self._oldStderr
        sys.stdout = self._oldStdout
-    def _debugCard(self):
-        return self.reviewer.card.__dict__
-
-    def _debugBrowserCard(self):
-        return aqt.dialogs._dialogs["Browser"][1].card.__dict__
+    def _card_repr(self, card: anki.cards.Card) -> None:
+        import pprint, copy
+
+        if not card:
+            print("no card")
+            return
+
+        print("Front:", card.question())
+        print("\n")
+        print("Back:", card.answer())
+        print("\nNote:")
+        note = copy.copy(card.note())
+        for k, v in note.items():
+            print(f"- {k}:", v)
+        print("\n")
+        del note.fields
+        del note._fmap
+        del note._model
+        pprint.pprint(note.__dict__)
+        print("\nCard:")
+        c = copy.copy(card)
+        c._render_output = None
+        pprint.pprint(c.__dict__)
+
+    def _debugCard(self) -> Optional[anki.cards.Card]:
+        card = self.reviewer.card
+        self._card_repr(card)
+        return card
+
+    def _debugBrowserCard(self) -> Optional[anki.cards.Card]:
+        card = aqt.dialogs._dialogs["Browser"][1].card
+        self._card_repr(card)
+        return card
    def onDebugPrint(self, frm):
        cursor = frm.text.textCursor()
@ -1528,7 +1557,6 @@ Please ensure a profile is open and Anki is not busy, then try again."""
        gc.disable()

    def doGC(self) -> None:
-        assert not self.progress.inDB
        gc.collect()

    # Crash log
View file
@ -40,7 +40,6 @@ class MediaChecker:
    def check(self) -> None:
        self.progress_dialog = self.mw.progress.start()
        hooks.bg_thread_progress_callback.append(self._on_progress)
-        self.mw.col.close()
        self.mw.taskman.run_in_background(self._check, self._on_finished)

    def _on_progress(self, proceed: bool, progress: Progress) -> bool:
@ -61,7 +60,6 @@ class MediaChecker:
        hooks.bg_thread_progress_callback.remove(self._on_progress)
        self.mw.progress.finish()
        self.progress_dialog = None
-        self.mw.col.reopen()

        exc = future.exception()
        if isinstance(exc, Interrupted):
View file
@ -11,7 +11,14 @@ from typing import List, Union
import aqt
from anki import hooks
from anki.consts import SYNC_BASE
-from anki.rsbackend import TR, Interrupted, MediaSyncProgress, Progress, ProgressKind
+from anki.rsbackend import (
+    TR,
+    Interrupted,
+    MediaSyncProgress,
+    NetworkError,
+    Progress,
+    ProgressKind,
+)
from anki.types import assert_impossible
from anki.utils import intTime
from aqt import gui_hooks
@ -100,6 +107,10 @@ class MediaSyncer:
        if isinstance(exc, Interrupted):
            self._log_and_notify(tr(TR.SYNC_MEDIA_ABORTED))
            return
+        elif isinstance(exc, NetworkError):
+            # avoid popups for network errors
+            self._log_and_notify(str(exc))
+            return

        self._log_and_notify(tr(TR.SYNC_MEDIA_FAILED))
        showWarning(str(exc))
View file
@ -6,7 +6,7 @@ from operator import itemgetter
import aqt.clayout
from anki import stdmodels
from anki.lang import _, ngettext
-from aqt import AnkiQt
+from aqt import AnkiQt, gui_hooks
from aqt.qt import *
from aqt.utils import (
    askUser,
@ -124,6 +124,7 @@ class Models(QDialog):
        d.setWindowTitle(_("Options for %s") % self.model["name"])
        frm.buttonBox.helpRequested.connect(lambda: openHelp("latex"))
        restoreGeom(d, "modelopts")
+        gui_hooks.models_advanced_will_show(d)
        d.exec_()
        saveGeom(d, "modelopts")
        self.model["latexsvg"] = frm.latexsvg.isChecked()
View file
@ -62,6 +62,8 @@ class Preferences(QDialog):
        lang = anki.lang.currentLang
        if lang in anki.lang.compatMap:
            lang = anki.lang.compatMap[lang]
+        else:
+            lang = lang.replace("-", "_")
        try:
            return codes.index(lang)
        except:
@ -98,7 +100,7 @@ class Preferences(QDialog):
            f.new_timezone.setVisible(False)
        else:
            f.newSched.setChecked(True)
-            f.new_timezone.setChecked(self.mw.col.sched._new_timezone_enabled())
+            f.new_timezone.setChecked(self.mw.col.sched.new_timezone_enabled())

    def updateCollection(self):
        f = self.form
@ -124,7 +126,7 @@ class Preferences(QDialog):
        qc["dayLearnFirst"] = f.dayLearnFirst.isChecked()
        self._updateDayCutoff()
        if self.mw.col.schedVer() != 1:
-            was_enabled = self.mw.col.sched._new_timezone_enabled()
+            was_enabled = self.mw.col.sched.new_timezone_enabled()
            is_enabled = f.new_timezone.isChecked()
            if was_enabled != is_enabled:
                if is_enabled:
View file
@ -11,10 +11,6 @@ import aqt.forms
from anki.lang import _
from aqt.qt import *

-# fixme: if mw->subwindow opens a progress dialog with mw as the parent, mw
-# gets raised on finish on compiz. perhaps we should be using the progress
-# dialog as the parent?

# Progress info
##########################################################################
@ -25,47 +21,18 @@ class ProgressManager:
        self.app = QApplication.instance()
        self.inDB = False
        self.blockUpdates = False
+        self._show_timer: Optional[QTimer] = None
        self._win = None
        self._levels = 0
-    # SQLite progress handler
-    ##########################################################################
-
-    def setupDB(self, db):
-        "Install a handler in the current DB."
-        self.lastDbProgress = 0
-        self.inDB = False
-        db.set_progress_handler(self._dbProgress, 10000)
-
-    def _dbProgress(self):
-        "Called from SQLite."
-        # do nothing if we don't have a progress window
-        if not self._win:
-            return
-        # make sure we're not executing too frequently
-        if (time.time() - self.lastDbProgress) < 0.01:
-            return
-        self.lastDbProgress = time.time()
-        # and we're in the main thread
-        if not self.mw.inMainThread():
-            return
-        # ensure timers don't fire
-        self.inDB = True
-        # handle GUI events
-        if not self.blockUpdates:
-            self._maybeShow()
-        self.app.processEvents(QEventLoop.ExcludeUserInputEvents)
-        self.inDB = False
-
    # Safer timers
    ##########################################################################
-    # QTimer may fire in processEvents(). We provide a custom timer which
-    # automatically defers until the DB is not busy, and avoids running
-    # while a progress window is visible.
+    # A custom timer which avoids firing while a progress dialog is active
+    # (likely due to some long-running DB operation)

    def timer(self, ms, func, repeat, requiresCollection=True):
        def handler():
-            if self.inDB or self._levels:
+            if self._levels:
                # retry in 100ms
                self.timer(100, func, False, requiresCollection)
            elif not self.mw.col and requiresCollection:
@ -114,10 +81,17 @@ class ProgressManager:
        self._firstTime = time.time()
        self._lastUpdate = time.time()
        self._updating = False
+        self._show_timer = QTimer(self.mw)
+        self._show_timer.setSingleShot(True)
+        self._show_timer.start(600)
+        self._show_timer.timeout.connect(self._on_show_timer)  # type: ignore
        return self._win

    def update(self, label=None, value=None, process=True, maybeShow=True):
        # print self._min, self._counter, self._max, label, time.time() - self._lastTime
+        if not self.mw.inMainThread():
+            print("progress.update() called on wrong thread")
+            return
        if self._updating:
            return
        if maybeShow:
@ -143,6 +117,9 @@ class ProgressManager:
        if self._win:
            self._closeWin()
        self._unsetBusy()
+        if self._show_timer:
+            self._show_timer.stop()
+            self._show_timer = None

    def clear(self):
        "Restore the interface after an error."
@ -189,6 +166,10 @@
        "True if processing."
        return self._levels

+    def _on_show_timer(self):
+        self._show_timer = None
+        self._showWin()
+

class ProgressDialog(QDialog):
    def __init__(self, parent):
View file
@ -393,8 +393,9 @@ class SimpleMplayerSlaveModePlayer(SimpleMplayerPlayer):
        The trailing newline is automatically added."""
        str_args = [str(x) for x in args]
-        self._process.stdin.write(" ".join(str_args).encode("utf8") + b"\n")
-        self._process.stdin.flush()
+        if self._process:
+            self._process.stdin.write(" ".join(str_args).encode("utf8") + b"\n")
+            self._process.stdin.flush()
    def seek_relative(self, secs: int) -> None:
        self.command("seek", secs, 0)
View file
@ -364,7 +364,7 @@ class SyncThread(QThread):
        self.syncMsg = ""
        self.uname = ""
        try:
-            self.col = Collection(self.path, log=True)
+            self.col = Collection(self.path)
        except:
            self.fireEvent("corrupt")
            return
View file
@ -4,7 +4,7 @@
import platform
import sys
-from typing import Dict
+from typing import Dict, Optional

from anki.utils import isMac
from aqt import QApplication, gui_hooks, isWin
@ -17,6 +17,7 @@ class ThemeManager:
    _icon_cache_light: Dict[str, QIcon] = {}
    _icon_cache_dark: Dict[str, QIcon] = {}
    _icon_size = 128
+    _macos_dark_mode_cached: Optional[bool] = None

    def macos_dark_mode(self) -> bool:
        if not getattr(sys, "frozen", False):
@ -25,9 +26,13 @@ class ThemeManager:
            return False
        if qtminor < 13:
            return False
-        import darkdetect  # pylint: disable=import-error
-
-        return darkdetect.isDark() is True
+        if self._macos_dark_mode_cached is None:
+            import darkdetect  # pylint: disable=import-error
+
+            # cache the value, as the interface gets messed up
+            # if the value changes after starting Anki
+            self._macos_dark_mode_cached = darkdetect.isDark() is True
+        return self._macos_dark_mode_cached

    def get_night_mode(self) -> bool:
        return self.macos_dark_mode() or self._night_mode_preference
View file
@ -235,6 +235,26 @@ hooks = [
            return True
        """,
    ),
Hook(
name="browser_will_search",
args=["context: aqt.browser.SearchContext"],
doc="""Allows you to modify the search text, or perform your own search.
You can modify context.search to change the text that is sent to the
searching backend.
If you set context.card_ids to a list of ids, the regular search will
not be performed, and the provided ids will be used instead.
Your add-on should check if context.card_ids is not None, and return
without making changes if it has been set.
""",
),
Hook(
name="browser_did_search",
args=["context: aqt.browser.SearchContext"],
doc="""Allows you to modify the list of returned card ids from a search.""",
),
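    # Example filter (illustrative; not part of this commit) using the two
    # hooks above: drop suspended cards from every browser search result.
    #
    #   def on_browser_did_search(context) -> None:
    #       if context.card_ids:
    #           col = aqt.mw.col
    #           context.card_ids = [
    #               cid for cid in context.card_ids
    #               if col.getCard(cid).queue != -1  # -1 == suspended
    #           ]
    #
    #   gui_hooks.browser_did_search.append(on_browser_did_search)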
    # States
    ###################
    Hook(
@ -341,6 +361,7 @@ hooks = [
    ),
    # Main
    ###################
+    Hook(name="backup_did_complete"),
    Hook(name="profile_did_open", legacy_hook="profileLoaded"),
    Hook(name="profile_will_close", legacy_hook="unloadProfile"),
    Hook(
@ -412,6 +433,18 @@ def emptyNewCard():
        args=["note: anki.notes.Note"],
        legacy_hook="AddCards.noteAdded",
    ),
Hook(
name="add_cards_will_add_note",
args=["problem: Optional[str]", "note: anki.notes.Note"],
return_type="Optional[str]",
        doc="""Decides whether the note should be added to the collection or
        not. It is assumed to come from the addCards window.

        problem is the first reason to reject that was found, or None.
        If your filter wants to reject the note, it should return the
        reason to reject; otherwise it should return its input.""",
),
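    # Example filter (illustrative; not part of this commit) for the hook
    # above: reject notes whose first field is empty, while passing along
    # any earlier filter's rejection reason.
    #
    #   def on_add_cards_will_add_note(problem, note):
    #       if problem is not None:
    #           return problem
    #       if not note.fields[0].strip():
    #           return "first field must not be empty"
    #       return None
    #
    #   gui_hooks.add_cards_will_add_note.append(on_add_cards_will_add_note)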
    # Editing
    ###################
    Hook(
@ -503,6 +536,9 @@ def emptyNewCard():
        args=["dialog: aqt.addons.AddonsDialog", "add_on: aqt.addons.AddonMeta"],
        doc="""Allows doing an action when a single add-on is selected.""",
    ),
+    # Model
+    ###################
+    Hook(name="models_advanced_will_show", args=["advanced: QDialog"],),
    # Other
    ###################
    Hook(
View file
@ -45,6 +45,8 @@ img {
#typeans {
    width: 100%;
+    // https://anki.tenderapp.com/discussions/beta-testing/1854-using-margin-auto-causes-horizontal-scrollbar-on-typesomething
+    box-sizing: border-box;
}

.typeGood {
View file
@ -1,6 +1,6 @@
[package]
name = "anki"
-version = "2.1.22" # automatically updated
+version = "2.1.24" # automatically updated
edition = "2018"
authors = ["Ankitects Pty Ltd and contributors"]
license = "AGPL-3.0-or-later"
@ -36,12 +36,15 @@ slog = { version = "2.5.2", features = ["max_level_trace", "release_max_level_de
slog-term = "2.5.0"
slog-async = "2.4.0"
slog-envlogger = "2.2.0"
+serde_repr = "0.1.5"
+num_enum = "0.4.2"
+unicase = "2.6.0"

[target.'cfg(target_vendor="apple")'.dependencies]
-rusqlite = { version = "0.21.0", features = ["trace"] }
+rusqlite = { version = "0.21.0", features = ["trace", "functions", "collation"] }

[target.'cfg(not(target_vendor="apple"))'.dependencies]
-rusqlite = { version = "0.21.0", features = ["trace", "bundled"] }
+rusqlite = { version = "0.21.0", features = ["trace", "functions", "collation", "bundled"] }

[target.'cfg(linux)'.dependencies]
reqwest = { version = "0.10.1", features = ["json", "native-tls-vendored"] }
View file
@ -25,7 +25,7 @@ develop: .build/vernum ftl/repo
ftl/repo:
	(cd ftl && ./scripts/fetch-latest-translations)

-ALL_SOURCE := $(shell ${FIND} src -type f) $(wildcard ftl/*.ftl)
+ALL_SOURCE := $(shell ${FIND} src -type f | egrep -v "i18n/autogen|i18n/ftl|_proto.rs") $(wildcard ftl/*.ftl)

# nightly currently required for ignoring files in rustfmt.toml
RUST_TOOLCHAIN := $(shell cat rust-toolchain)
View file
@ -31,3 +31,4 @@ sync-client-too-old =
sync-wrong-pass = AnkiWeb ID or password was incorrect; please try again.
sync-resync-required =
    Please sync again. If this message keeps appearing, please post on the support site.
sync-must-wait-for-end = Anki is currently syncing. Please wait for the sync to complete, then try again.
View file
@ -0,0 +1,155 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::err::Result;
use crate::storage::StorageContext;
use rusqlite::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};
use rusqlite::OptionalExtension;
use serde_derive::{Deserialize, Serialize};
#[derive(Deserialize)]
#[serde(tag = "kind", rename_all = "lowercase")]
pub(super) enum DBRequest {
Query {
sql: String,
args: Vec<SqlValue>,
first_row_only: bool,
},
Begin,
Commit,
Rollback,
ExecuteMany {
sql: String,
args: Vec<Vec<SqlValue>>,
},
}
#[derive(Serialize)]
#[serde(untagged)]
pub(super) enum DBResult {
Rows(Vec<Vec<SqlValue>>),
None,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub(super) enum SqlValue {
Null,
String(String),
Int(i64),
Double(f64),
Blob(Vec<u8>),
}
impl ToSql for SqlValue {
fn to_sql(&self) -> std::result::Result<ToSqlOutput<'_>, rusqlite::Error> {
let val = match self {
SqlValue::Null => ValueRef::Null,
SqlValue::String(v) => ValueRef::Text(v.as_bytes()),
SqlValue::Int(v) => ValueRef::Integer(*v),
SqlValue::Double(v) => ValueRef::Real(*v),
SqlValue::Blob(v) => ValueRef::Blob(&v),
};
Ok(ToSqlOutput::Borrowed(val))
}
}
impl FromSql for SqlValue {
fn column_result(value: ValueRef<'_>) -> std::result::Result<Self, FromSqlError> {
let val = match value {
ValueRef::Null => SqlValue::Null,
ValueRef::Integer(i) => SqlValue::Int(i),
ValueRef::Real(v) => SqlValue::Double(v),
ValueRef::Text(v) => SqlValue::String(String::from_utf8_lossy(v).to_string()),
ValueRef::Blob(v) => SqlValue::Blob(v.to_vec()),
};
Ok(val)
}
}
pub(super) fn db_command_bytes(ctx: &StorageContext, input: &[u8]) -> Result<String> {
let req: DBRequest = serde_json::from_slice(input)?;
let resp = match req {
DBRequest::Query {
sql,
args,
first_row_only,
} => {
if first_row_only {
db_query_row(ctx, &sql, &args)?
} else {
db_query(ctx, &sql, &args)?
}
}
DBRequest::Begin => {
ctx.begin_trx()?;
DBResult::None
}
DBRequest::Commit => {
ctx.commit_trx()?;
DBResult::None
}
DBRequest::Rollback => {
ctx.rollback_trx()?;
DBResult::None
}
DBRequest::ExecuteMany { sql, args } => db_execute_many(ctx, &sql, &args)?,
};
Ok(serde_json::to_string(&resp)?)
}
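// An illustrative request/response pair for the JSON protocol handled above
// (example values, not from this commit). A query request from the DB proxy:
//
//   {"kind": "query",
//    "sql": "select id, name from decks where id = ?",
//    "args": [1],
//    "first_row_only": false}
//
// and a possible response:
//
//   [[1, "Default"]]
//
// Begin/Commit/Rollback requests carry only the "kind" tag and return JSON
// null, since the unit variant DBResult::None serializes to null under
// #[serde(untagged)].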
pub(super) fn db_query_row(ctx: &StorageContext, sql: &str, args: &[SqlValue]) -> Result<DBResult> {
let mut stmt = ctx.db.prepare_cached(sql)?;
let columns = stmt.column_count();
let row = stmt
.query_row(args, |row| {
let mut orow = Vec::with_capacity(columns);
for i in 0..columns {
let v: SqlValue = row.get(i)?;
orow.push(v);
}
Ok(orow)
})
.optional()?;
let rows = if let Some(row) = row {
vec![row]
} else {
vec![]
};
Ok(DBResult::Rows(rows))
}
pub(super) fn db_query(ctx: &StorageContext, sql: &str, args: &[SqlValue]) -> Result<DBResult> {
let mut stmt = ctx.db.prepare_cached(sql)?;
let columns = stmt.column_count();
let res: std::result::Result<Vec<Vec<_>>, rusqlite::Error> = stmt
.query_map(args, |row| {
let mut orow = Vec::with_capacity(columns);
for i in 0..columns {
let v: SqlValue = row.get(i)?;
orow.push(v);
}
Ok(orow)
})?
.collect();
Ok(DBResult::Rows(res?))
}
pub(super) fn db_execute_many(
ctx: &StorageContext,
sql: &str,
args: &[Vec<SqlValue>],
) -> Result<DBResult> {
let mut stmt = ctx.db.prepare_cached(sql)?;
for params in args {
stmt.execute(params)?;
}
Ok(DBResult::None)
}
View file
@ -1,8 +1,11 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html

+use crate::backend::dbproxy::db_command_bytes;
use crate::backend_proto::backend_input::Value;
-use crate::backend_proto::{Empty, RenderedTemplateReplacement, SyncMediaIn};
+use crate::backend_proto::{BuiltinSortKind, Empty, RenderedTemplateReplacement, SyncMediaIn};
+use crate::collection::{open_collection, Collection};
+use crate::config::SortKind;
use crate::err::{AnkiError, NetworkErrorKind, Result, SyncErrorKind};
use crate::i18n::{tr_args, FString, I18n};
use crate::latex::{extract_latex, extract_latex_expanding_clozes, ExtractedLatex};
@ -12,6 +15,7 @@ use crate::media::sync::MediaSyncProgress;
use crate::media::MediaManager;
use crate::sched::cutoff::{local_minutes_west_for_stamp, sched_timing_today};
use crate::sched::timespan::{answer_button_time, learning_congrats, studied_today, time_span};
+use crate::search::{search_cards, search_notes, SortMode};
use crate::template::{
    render_card, without_legacy_template_directives, FieldMap, FieldRequirements, ParsedTemplate,
    RenderedNode,
@ -22,18 +26,18 @@ use fluent::FluentValue;
use prost::Message;
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
+use std::sync::{Arc, Mutex};
use tokio::runtime::Runtime;

+mod dbproxy;
+
pub type ProtoProgressCallback = Box<dyn Fn(Vec<u8>) -> bool + Send>;

pub struct Backend {
-    #[allow(dead_code)]
-    col_path: PathBuf,
-    media_folder: PathBuf,
-    media_db: String,
+    col: Arc<Mutex<Option<Collection>>>,
    progress_callback: Option<ProtoProgressCallback>,
    i18n: I18n,
-    log: Logger,
+    server: bool,
}

enum Progress<'a> {
@ -55,6 +59,8 @@ fn anki_error_to_proto_error(err: AnkiError, i18n: &I18n) -> pb::BackendError {
        }
        AnkiError::SyncError { kind, .. } => V::SyncError(pb::SyncError { kind: kind.into() }),
        AnkiError::Interrupted => V::Interrupted(Empty {}),
+        AnkiError::CollectionNotOpen => V::InvalidInput(pb::Empty {}),
+        AnkiError::CollectionAlreadyOpen => V::InvalidInput(pb::Empty {}),
    };

    pb::BackendError {
@ -103,50 +109,27 @@ pub fn init_backend(init_msg: &[u8]) -> std::result::Result<Backend, String> {
        Err(_) => return Err("couldn't decode init request".into()),
    };

-    let mut path = input.collection_path.clone();
-    path.push_str(".log");
-
-    let log_path = match input.log_path.as_str() {
-        "" => None,
-        path => Some(path),
-    };
-    let logger =
-        default_logger(log_path).map_err(|e| format!("Unable to open log file: {:?}", e))?;
-
    let i18n = I18n::new(
        &input.preferred_langs,
        input.locale_folder_path,
        log::terminal(),
    );

-    match Backend::new(
-        &input.collection_path,
-        &input.media_folder_path,
-        &input.media_db_path,
-        i18n,
-        logger,
-    ) {
-        Ok(backend) => Ok(backend),
-        Err(e) => Err(format!("{:?}", e)),
-    }
+    Ok(Backend::new(i18n, input.server))
}

impl Backend {
-    pub fn new(
-        col_path: &str,
-        media_folder: &str,
-        media_db: &str,
-        i18n: I18n,
-        log: Logger,
-    ) -> Result<Backend> {
-        Ok(Backend {
-            col_path: col_path.into(),
-            media_folder: media_folder.into(),
-            media_db: media_db.into(),
-            progress_callback: None,
-            i18n,
-            log,
-        })
-    }
+    pub fn new(i18n: I18n, server: bool) -> Backend {
+        Backend {
+            col: Arc::new(Mutex::new(None)),
+            progress_callback: None,
+            i18n,
+            server,
+        }
+    }
+
+    pub fn i18n(&self) -> &I18n {
+        &self.i18n
+    }
    /// Decode a request, process it, and return the encoded result.
@ -172,6 +155,22 @@ impl Backend {
        buf
    }
/// If collection is open, run the provided closure while holding
/// the mutex.
/// If collection is not open, return an error.
fn with_col<F, T>(&self, func: F) -> Result<T>
where
F: FnOnce(&mut Collection) -> Result<T>,
{
func(
self.col
.lock()
.unwrap()
.as_mut()
.ok_or(AnkiError::CollectionNotOpen)?,
)
}
    fn run_command(&mut self, input: pb::BackendInput) -> pb::BackendOutput {
        let oval = if let Some(ival) = input.value {
            match self.run_command_inner(ival) {
@ -202,8 +201,6 @@ impl Backend {
                OValue::SchedTimingToday(self.sched_timing_today(input))
            }
            Value::DeckTree(_) => todo!(),
-            Value::FindCards(_) => todo!(),
-            Value::BrowserRows(_) => todo!(),
            Value::RenderCard(input) => OValue::RenderCard(self.render_template(input)?),
            Value::LocalMinutesWest(stamp) => {
                OValue::LocalMinutesWest(local_minutes_west_for_stamp(stamp))
@ -241,9 +238,63 @@ impl Backend {
                self.restore_trash()?;
                OValue::RestoreTrash(Empty {})
            }
Value::OpenCollection(input) => {
self.open_collection(input)?;
OValue::OpenCollection(Empty {})
}
Value::CloseCollection(_) => {
self.close_collection()?;
OValue::CloseCollection(Empty {})
}
Value::SearchCards(input) => OValue::SearchCards(self.search_cards(input)?),
Value::SearchNotes(input) => OValue::SearchNotes(self.search_notes(input)?),
        })
    }
fn open_collection(&self, input: pb::OpenCollectionIn) -> Result<()> {
let mut col = self.col.lock().unwrap();
if col.is_some() {
return Err(AnkiError::CollectionAlreadyOpen);
}
let mut path = input.collection_path.clone();
path.push_str(".log");
let log_path = match input.log_path.as_str() {
"" => None,
path => Some(path),
};
let logger = default_logger(log_path)?;
let new_col = open_collection(
input.collection_path,
input.media_folder_path,
input.media_db_path,
self.server,
self.i18n.clone(),
logger,
)?;
*col = Some(new_col);
Ok(())
}
fn close_collection(&self) -> Result<()> {
let mut col = self.col.lock().unwrap();
if col.is_none() {
return Err(AnkiError::CollectionNotOpen);
}
if !col.as_ref().unwrap().can_close() {
return Err(AnkiError::invalid_input("can't close yet"));
}
*col = None;
Ok(())
}
    fn fire_progress_callback(&self, progress: Progress) -> bool {
        if let Some(cb) = &self.progress_callback {
            let bytes = progress_to_proto_bytes(progress, &self.i18n);
@ -301,10 +352,10 @@ impl Backend {
    fn sched_timing_today(&self, input: pb::SchedTimingTodayIn) -> pb::SchedTimingTodayOut {
        let today = sched_timing_today(
            input.created_secs as i64,
-            input.created_mins_west,
            input.now_secs as i64,
-            input.now_mins_west,
-            input.rollover_hour as i8,
+            input.created_mins_west.map(|v| v.val),
+            input.now_mins_west.map(|v| v.val),
+            input.rollover_hour.map(|v| v.val as i8),
        );
        pb::SchedTimingTodayOut {
            days_elapsed: today.days_elapsed,
@ -389,46 +440,80 @@ impl Backend {
    }

    fn add_media_file(&mut self, input: pb::AddMediaFileIn) -> Result<String> {
-        let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
-        let mut ctx = mgr.dbctx();
-        Ok(mgr
-            .add_file(&mut ctx, &input.desired_name, &input.data)?
-            .into())
+        self.with_col(|col| {
+            let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
+            let mut ctx = mgr.dbctx();
+            Ok(mgr
+                .add_file(&mut ctx, &input.desired_name, &input.data)?
+                .into())
+        })
    }
-    fn sync_media(&self, input: SyncMediaIn) -> Result<()> {
-        let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
-
-        let callback = |progress: &MediaSyncProgress| {
-            self.fire_progress_callback(Progress::MediaSync(progress))
-        };
-
-        let mut rt = Runtime::new().unwrap();
-        rt.block_on(mgr.sync_media(callback, &input.endpoint, &input.hkey, self.log.clone()))
-    }
+    // fixme: will block other db access
+
+    fn sync_media(&self, input: SyncMediaIn) -> Result<()> {
+        let mut guard = self.col.lock().unwrap();
+
+        let col = guard.as_mut().unwrap();
+        col.set_media_sync_running()?;
+
+        let folder = col.media_folder.clone();
+        let db = col.media_db.clone();
+        let log = col.log.clone();
+
+        drop(guard);
+
+        let res = self.sync_media_inner(input, folder, db, log);
+
+        self.with_col(|col| col.set_media_sync_finished())?;
+
+        res
+    }
+
+    fn sync_media_inner(
+        &self,
+        input: pb::SyncMediaIn,
+        folder: PathBuf,
+        db: PathBuf,
+        log: Logger,
+    ) -> Result<()> {
+        let callback = |progress: &MediaSyncProgress| {
+            self.fire_progress_callback(Progress::MediaSync(progress))
+        };
+
+        let mgr = MediaManager::new(&folder, &db)?;
+        let mut rt = Runtime::new().unwrap();
+        rt.block_on(mgr.sync_media(callback, &input.endpoint, &input.hkey, log))
+    }
    fn check_media(&self) -> Result<pb::MediaCheckOut> {
        let callback =
            |progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32));

-        let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
-        let mut checker = MediaChecker::new(&mgr, &self.col_path, callback, &self.i18n, &self.log);
-        let mut output = checker.check()?;
-
-        let report = checker.summarize_output(&mut output);
-
-        Ok(pb::MediaCheckOut {
-            unused: output.unused,
-            missing: output.missing,
-            report,
-            have_trash: output.trash_count > 0,
-        })
+        self.with_col(|col| {
+            let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
+            col.transact(None, |ctx| {
+                let mut checker = MediaChecker::new(ctx, &mgr, callback);
+                let mut output = checker.check()?;
+
+                let report = checker.summarize_output(&mut output);
+
+                Ok(pb::MediaCheckOut {
+                    unused: output.unused,
+                    missing: output.missing,
+                    report,
+                    have_trash: output.trash_count > 0,
+                })
+            })
+        })
    }
    fn remove_media_files(&self, fnames: &[String]) -> Result<()> {
-        let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
-        let mut ctx = mgr.dbctx();
-        mgr.remove_files(&mut ctx, fnames)
+        self.with_col(|col| {
+            let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
+            let mut ctx = mgr.dbctx();
+            mgr.remove_files(&mut ctx, fnames)
+        })
    }
    fn translate_string(&self, input: pb::TranslateStringIn) -> String {
@ -466,20 +551,66 @@ impl Backend {
        let callback =
            |progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32));

-        let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
-        let mut checker = MediaChecker::new(&mgr, &self.col_path, callback, &self.i18n, &self.log);
-
-        checker.empty_trash()
+        self.with_col(|col| {
+            let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
+            col.transact(None, |ctx| {
+                let mut checker = MediaChecker::new(ctx, &mgr, callback);
+                checker.empty_trash()
+            })
+        })
    }
    fn restore_trash(&self) -> Result<()> {
        let callback =
            |progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32));

-        let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
-        let mut checker = MediaChecker::new(&mgr, &self.col_path, callback, &self.i18n, &self.log);
-
-        checker.restore_trash()
+        self.with_col(|col| {
+            let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
+            col.transact(None, |ctx| {
+                let mut checker = MediaChecker::new(ctx, &mgr, callback);
+                checker.restore_trash()
+            })
+        })
    }
pub fn db_command(&self, input: &[u8]) -> Result<String> {
self.with_col(|col| col.with_ctx(|ctx| db_command_bytes(&ctx.storage, input)))
}
fn search_cards(&self, input: pb::SearchCardsIn) -> Result<pb::SearchCardsOut> {
self.with_col(|col| {
col.with_ctx(|ctx| {
let order = if let Some(order) = input.order {
use pb::sort_order::Value as V;
match order.value {
Some(V::None(_)) => SortMode::NoOrder,
Some(V::Custom(s)) => SortMode::Custom(s),
Some(V::FromConfig(_)) => SortMode::FromConfig,
Some(V::Builtin(b)) => SortMode::Builtin {
kind: sort_kind_from_pb(b.kind),
reverse: b.reverse,
},
None => SortMode::FromConfig,
}
} else {
SortMode::FromConfig
};
let cids = search_cards(ctx, &input.search, order)?;
Ok(pb::SearchCardsOut { card_ids: cids })
})
})
}
fn search_notes(&self, input: pb::SearchNotesIn) -> Result<pb::SearchNotesOut> {
self.with_col(|col| {
col.with_ctx(|ctx| {
let nids = search_notes(ctx, &input.search)?;
Ok(pb::SearchNotesOut { note_ids: nids })
})
})
} }
}
@ -552,51 +683,24 @@ fn media_sync_progress(p: &MediaSyncProgress, i18n: &I18n) -> pb::MediaSyncProgr
    }
}

-/// Standalone I18n backend
-/// This is a hack to allow translating strings in the GUI
-/// when a collection is not open, and in the future it should
-/// either be shared with or merged into the backend object.
-///////////////////////////////////////////////////////
-
-pub struct I18nBackend {
-    i18n: I18n,
-}
-
-pub fn init_i18n_backend(init_msg: &[u8]) -> Result<I18nBackend> {
-    let input: pb::I18nBackendInit = match pb::I18nBackendInit::decode(init_msg) {
-        Ok(req) => req,
-        Err(_) => return Err(AnkiError::invalid_input("couldn't decode init msg")),
-    };
-
-    let log = log::terminal();
-
-    let i18n = I18n::new(&input.preferred_langs, input.locale_folder_path, log);
-
-    Ok(I18nBackend { i18n })
-}
-
-impl I18nBackend {
-    pub fn translate(&self, req: &[u8]) -> String {
-        let req = match pb::TranslateStringIn::decode(req) {
-            Ok(req) => req,
-            Err(_e) => return "decoding error".into(),
-        };
-
-        self.translate_string(req)
-    }
-
-    fn translate_string(&self, input: pb::TranslateStringIn) -> String {
-        let key = match pb::FluentString::from_i32(input.key) {
-            Some(key) => key,
-            None => return "invalid key".to_string(),
-        };
-        let map = input
-            .args
-            .iter()
-            .map(|(k, v)| (k.as_str(), translate_arg_to_fluent_val(&v)))
-            .collect();
-
-        self.i18n.trn(key, map)
-    }
-}
+fn sort_kind_from_pb(kind: i32) -> SortKind {
+    use SortKind as SK;
+    match pb::BuiltinSortKind::from_i32(kind) {
+        Some(pbkind) => match pbkind {
+            BuiltinSortKind::NoteCreation => SK::NoteCreation,
+            BuiltinSortKind::NoteMod => SK::NoteMod,
+            BuiltinSortKind::NoteField => SK::NoteField,
+            BuiltinSortKind::NoteTags => SK::NoteTags,
+            BuiltinSortKind::NoteType => SK::NoteType,
+            BuiltinSortKind::CardMod => SK::CardMod,
+            BuiltinSortKind::CardReps => SK::CardReps,
+            BuiltinSortKind::CardDue => SK::CardDue,
+            BuiltinSortKind::CardEase => SK::CardEase,
+            BuiltinSortKind::CardLapses => SK::CardLapses,
+            BuiltinSortKind::CardInterval => SK::CardInterval,
+            BuiltinSortKind::CardDeck => SK::CardDeck,
+            BuiltinSortKind::CardTemplate => SK::CardTemplate,
+        },
+        _ => SortKind::NoteCreation,
+    }
+}
33 rslib/src/card.rs Normal file
View file
@ -0,0 +1,33 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use num_enum::TryFromPrimitive;
use serde_repr::{Deserialize_repr, Serialize_repr};
#[derive(Serialize_repr, Deserialize_repr, Debug, PartialEq, TryFromPrimitive, Clone, Copy)]
#[repr(u8)]
pub enum CardType {
New = 0,
Learn = 1,
Review = 2,
Relearn = 3,
}
#[derive(Serialize_repr, Deserialize_repr, Debug, PartialEq, TryFromPrimitive, Clone, Copy)]
#[repr(i8)]
pub enum CardQueue {
/// due is the order cards are shown in
New = 0,
/// due is a unix timestamp
Learn = 1,
/// due is days since creation date
Review = 2,
DayLearn = 3,
/// due is a unix timestamp.
/// preview cards only placed here when failed.
PreviewRepeat = 4,
/// cards are not due in these states
Suspended = -1,
UserBuried = -2,
SchedBuried = -3,
}
131 rslib/src/collection.rs Normal file
View file
@ -0,0 +1,131 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::err::{AnkiError, Result};
use crate::i18n::I18n;
use crate::log::Logger;
use crate::storage::{SqliteStorage, StorageContext};
use std::path::PathBuf;
pub fn open_collection<P: Into<PathBuf>>(
path: P,
media_folder: P,
media_db: P,
server: bool,
i18n: I18n,
log: Logger,
) -> Result<Collection> {
let col_path = path.into();
let storage = SqliteStorage::open_or_create(&col_path)?;
let col = Collection {
storage,
col_path,
media_folder: media_folder.into(),
media_db: media_db.into(),
server,
i18n,
log,
state: CollectionState::Normal,
};
Ok(col)
}
#[derive(Debug, PartialEq)]
pub enum CollectionState {
Normal,
// in this state, the DB must not be closed
MediaSyncRunning,
}
pub struct Collection {
pub(crate) storage: SqliteStorage,
#[allow(dead_code)]
pub(crate) col_path: PathBuf,
pub(crate) media_folder: PathBuf,
pub(crate) media_db: PathBuf,
pub(crate) server: bool,
pub(crate) i18n: I18n,
pub(crate) log: Logger,
state: CollectionState,
}
pub(crate) enum CollectionOp {}
pub(crate) struct RequestContext<'a> {
pub storage: StorageContext<'a>,
pub i18n: &'a I18n,
pub log: &'a Logger,
pub should_commit: bool,
}
impl Collection {
/// Call the provided closure with a RequestContext that exists for
/// the duration of the call. The request will cache prepared sql
/// statements, so should be passed down the call tree.
///
/// This function should be used for read-only requests. To mutate
/// the database, use transact() instead.
pub(crate) fn with_ctx<F, R>(&self, func: F) -> Result<R>
where
F: FnOnce(&mut RequestContext) -> Result<R>,
{
let mut ctx = RequestContext {
storage: self.storage.context(self.server),
i18n: &self.i18n,
log: &self.log,
should_commit: true,
};
func(&mut ctx)
}
/// Execute the provided closure in a transaction, rolling back if
/// an error is returned.
pub(crate) fn transact<F, R>(&self, op: Option<CollectionOp>, func: F) -> Result<R>
where
F: FnOnce(&mut RequestContext) -> Result<R>,
{
self.with_ctx(|ctx| {
ctx.storage.begin_rust_trx()?;
let mut res = func(ctx);
if res.is_ok() && ctx.should_commit {
if let Err(e) = ctx.storage.mark_modified() {
res = Err(e);
} else if let Err(e) = ctx.storage.commit_rust_op(op) {
res = Err(e);
}
}
if res.is_err() || !ctx.should_commit {
ctx.storage.rollback_rust_trx()?;
}
res
})
}
pub(crate) fn set_media_sync_running(&mut self) -> Result<()> {
if self.state == CollectionState::Normal {
self.state = CollectionState::MediaSyncRunning;
Ok(())
} else {
Err(AnkiError::invalid_input("media sync already running"))
}
}
pub(crate) fn set_media_sync_finished(&mut self) -> Result<()> {
if self.state == CollectionState::MediaSyncRunning {
self.state = CollectionState::Normal;
Ok(())
} else {
Err(AnkiError::invalid_input("media sync not running"))
}
}
pub(crate) fn can_close(&self) -> bool {
self.state == CollectionState::Normal
}
}
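// Illustrative usage (not part of this commit): with_ctx() is for reads,
// transact() wraps mutations, and returning Err from the closure rolls the
// transaction back. A hypothetical caller might look like:
//
//   col.transact(None, |ctx| {
//       let count: i64 = ctx
//           .storage
//           .db
//           .query_row("select count(*) from cards", rusqlite::NO_PARAMS, |r| r.get(0))?;
//       // ... mutate something, or return Err to roll everything back ...
//       Ok(count)
//   })?;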
63 rslib/src/config.rs Normal file
View file
@ -0,0 +1,63 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::types::ObjID;
use serde::Deserialize as DeTrait;
use serde_aux::field_attributes::deserialize_number_from_string;
use serde_derive::Deserialize;
use serde_json::Value;
pub(crate) fn default_on_invalid<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
T: Default + DeTrait<'de>,
D: serde::de::Deserializer<'de>,
{
let v: Value = DeTrait::deserialize(deserializer)?;
Ok(T::deserialize(v).unwrap_or_default())
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Config {
#[serde(
rename = "curDeck",
deserialize_with = "deserialize_number_from_string"
)]
pub(crate) current_deck_id: ObjID,
pub(crate) rollover: Option<i8>,
pub(crate) creation_offset: Option<i32>,
pub(crate) local_offset: Option<i32>,
#[serde(rename = "sortType", deserialize_with = "default_on_invalid")]
pub(crate) browser_sort_kind: SortKind,
#[serde(rename = "sortBackwards", deserialize_with = "default_on_invalid")]
pub(crate) browser_sort_reverse: bool,
}
#[derive(Deserialize, PartialEq, Debug)]
#[serde(rename_all = "camelCase")]
pub enum SortKind {
#[serde(rename = "noteCrt")]
NoteCreation,
NoteMod,
#[serde(rename = "noteFld")]
NoteField,
#[serde(rename = "note")]
NoteType,
NoteTags,
CardMod,
CardReps,
CardDue,
CardEase,
CardLapses,
#[serde(rename = "cardIvl")]
CardInterval,
#[serde(rename = "deck")]
CardDeck,
#[serde(rename = "template")]
CardTemplate,
}
impl Default for SortKind {
fn default() -> Self {
Self::NoteCreation
}
}
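// Illustrative collection-config JSON accepted by the deserializer above
// (example values, not from this commit):
//
//   {"curDeck": "1", "rollover": 4, "sortType": "noteFld", "sortBackwards": false}
//
// curDeck may arrive as either a string or a number thanks to
// deserialize_number_from_string, and an unrecognized sortType falls back
// to SortKind::NoteCreation via default_on_invalid.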
31 rslib/src/decks.rs Normal file
View file
@ -0,0 +1,31 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::types::ObjID;
use serde_aux::field_attributes::deserialize_number_from_string;
use serde_derive::Deserialize;
#[derive(Deserialize)]
pub struct Deck {
#[serde(deserialize_with = "deserialize_number_from_string")]
pub(crate) id: ObjID,
pub(crate) name: String,
}
pub(crate) fn child_ids<'a>(decks: &'a [Deck], name: &str) -> impl Iterator<Item = ObjID> + 'a {
let prefix = format!("{}::", name.to_ascii_lowercase());
decks
.iter()
.filter(move |d| d.name.to_ascii_lowercase().starts_with(&prefix))
.map(|d| d.id)
}
pub(crate) fn get_deck(decks: &[Deck], id: ObjID) -> Option<&Deck> {
for d in decks {
if d.id == id {
return Some(d);
}
}
None
}
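// Illustrative behaviour (not part of this commit): with decks named
// "French" (id 1) and "French::Verbs" (id 2), child_ids(&decks, "French")
// yields [2]. Matching is case-insensitive and keys on the "parent::"
// name prefix, so "FRENCH::Verbs" would match as well.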
View file
@ -20,7 +20,7 @@ pub enum AnkiError {
    IOError { info: String },

    #[fail(display = "DB error: {}", info)]
-    DBError { info: String },
+    DBError { info: String, kind: DBErrorKind },

    #[fail(display = "Network error: {:?} {}", kind, info)]
    NetworkError {
@ -33,6 +33,12 @@ pub enum AnkiError {
    #[fail(display = "The user interrupted the operation.")]
    Interrupted,
#[fail(display = "Operation requires an open collection.")]
CollectionNotOpen,
#[fail(display = "Close the existing collection first.")]
CollectionAlreadyOpen,
}

// error helpers
@ -112,6 +118,7 @@ impl From<rusqlite::Error> for AnkiError {
    fn from(err: rusqlite::Error) -> Self {
        AnkiError::DBError {
            info: format!("{:?}", err),
kind: DBErrorKind::Other,
        }
    }
}
@ -120,6 +127,7 @@ impl From<rusqlite::types::FromSqlError> for AnkiError {
    fn from(err: rusqlite::types::FromSqlError) -> Self {
        AnkiError::DBError {
            info: format!("{:?}", err),
kind: DBErrorKind::Other,
        }
    }
}
@ -215,3 +223,11 @@ impl From<serde_json::Error> for AnkiError {
        AnkiError::sync_misc(err.to_string())
    }
}
#[derive(Debug, PartialEq)]
pub enum DBErrorKind {
FileTooNew,
FileTooOld,
MissingEntity,
Other,
}
View file
@ -10,13 +10,21 @@ pub fn version() -> &'static str {
}

pub mod backend;
+pub mod card;
pub mod cloze;
+pub mod collection;
+pub mod config;
+pub mod decks;
pub mod err;
pub mod i18n;
pub mod latex;
pub mod log;
pub mod media;
+pub mod notes;
+pub mod notetypes;
pub mod sched;
+pub mod search;
+pub mod storage;
pub mod template;
pub mod template_filters;
pub mod text;
View file
@ -1,18 +1,17 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html

-use crate::err::{AnkiError, Result};
-use crate::i18n::{tr_args, tr_strs, FString, I18n};
+use crate::collection::RequestContext;
+use crate::err::{AnkiError, DBErrorKind, Result};
+use crate::i18n::{tr_args, tr_strs, FString};
use crate::latex::extract_latex_expanding_clozes;
-use crate::log::{debug, Logger};
-use crate::media::col::{
-    for_every_note, get_note_types, mark_collection_modified, open_or_create_collection_db,
-    set_note, Note,
-};
+use crate::log::debug;
use crate::media::database::MediaDatabaseContext;
use crate::media::files::{
-    data_for_file, filename_if_normalized, trash_folder, MEDIA_SYNC_FILESIZE_LIMIT,
+    data_for_file, filename_if_normalized, normalize_nfc_filename, trash_folder,
+    MEDIA_SYNC_FILESIZE_LIMIT,
};
+use crate::notes::{for_every_note, set_note, Note};
use crate::text::{normalize_to_nfc, MediaRef};
use crate::{media::MediaManager, text::extract_media_refs};
use coarsetime::Instant;
@ -26,7 +25,7 @@ lazy_static! {
static ref REMOTE_FILENAME: Regex = Regex::new("(?i)^https?://").unwrap(); static ref REMOTE_FILENAME: Regex = Regex::new("(?i)^https?://").unwrap();
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Clone)]
pub struct MediaCheckOutput { pub struct MediaCheckOutput {
pub unused: Vec<String>, pub unused: Vec<String>,
pub missing: Vec<String>, pub missing: Vec<String>,
@ -45,38 +44,32 @@ struct MediaFolderCheck {
oversize: Vec<String>, oversize: Vec<String>,
} }
pub struct MediaChecker<'a, P> pub struct MediaChecker<'a, 'b, P>
where where
P: FnMut(usize) -> bool, P: FnMut(usize) -> bool,
{ {
ctx: &'a mut RequestContext<'b>,
mgr: &'a MediaManager, mgr: &'a MediaManager,
col_path: &'a Path,
progress_cb: P, progress_cb: P,
checked: usize, checked: usize,
progress_updated: Instant, progress_updated: Instant,
i18n: &'a I18n,
log: &'a Logger,
} }
impl<P> MediaChecker<'_, P> impl<P> MediaChecker<'_, '_, P>
where where
P: FnMut(usize) -> bool, P: FnMut(usize) -> bool,
{ {
pub fn new<'a>( pub(crate) fn new<'a, 'b>(
ctx: &'a mut RequestContext<'b>,
mgr: &'a MediaManager, mgr: &'a MediaManager,
col_path: &'a Path,
progress_cb: P, progress_cb: P,
i18n: &'a I18n, ) -> MediaChecker<'a, 'b, P> {
log: &'a Logger,
) -> MediaChecker<'a, P> {
MediaChecker { MediaChecker {
ctx,
mgr, mgr,
col_path,
progress_cb, progress_cb,
checked: 0, checked: 0,
progress_updated: Instant::now(), progress_updated: Instant::now(),
i18n,
log,
} }
} }
@ -100,7 +93,7 @@ where
pub fn summarize_output(&self, output: &mut MediaCheckOutput) -> String { pub fn summarize_output(&self, output: &mut MediaCheckOutput) -> String {
let mut buf = String::new(); let mut buf = String::new();
let i = &self.i18n; let i = &self.ctx.i18n;
// top summary area // top summary area
if output.trash_count > 0 { if output.trash_count > 0 {
@ -279,7 +272,7 @@ where
} }
})?; })?;
let fname = self.mgr.add_file(ctx, disk_fname, &data)?; let fname = self.mgr.add_file(ctx, disk_fname, &data)?;
debug!(self.log, "renamed"; "from"=>disk_fname, "to"=>&fname.as_ref()); debug!(self.ctx.log, "renamed"; "from"=>disk_fname, "to"=>&fname.as_ref());
assert_ne!(fname.as_ref(), disk_fname); assert_ne!(fname.as_ref(), disk_fname);
// remove the original file // remove the original file
@ -373,7 +366,7 @@ where
self.mgr self.mgr
.add_file(&mut self.mgr.dbctx(), fname.as_ref(), &data)?; .add_file(&mut self.mgr.dbctx(), fname.as_ref(), &data)?;
} else { } else {
debug!(self.log, "file disappeared while restoring trash"; "fname"=>fname.as_ref()); debug!(self.ctx.log, "file disappeared while restoring trash"; "fname"=>fname.as_ref());
} }
fs::remove_file(dentry.path())?; fs::remove_file(dentry.path())?;
} }
@ -387,14 +380,11 @@ where
&mut self, &mut self,
renamed: &HashMap<String, String>, renamed: &HashMap<String, String>,
) -> Result<HashSet<String>> { ) -> Result<HashSet<String>> {
let mut db = open_or_create_collection_db(self.col_path)?;
let trx = db.transaction()?;
let mut referenced_files = HashSet::new(); let mut referenced_files = HashSet::new();
let note_types = get_note_types(&trx)?; let note_types = self.ctx.storage.all_note_types()?;
let mut collection_modified = false; let mut collection_modified = false;
for_every_note(&trx, |note| { for_every_note(&self.ctx.storage.db, |note| {
self.checked += 1; self.checked += 1;
if self.checked % 10 == 0 { if self.checked % 10 == 0 {
self.maybe_fire_progress_cb()?; self.maybe_fire_progress_cb()?;
@ -403,10 +393,16 @@ where
.get(&note.mid) .get(&note.mid)
.ok_or_else(|| AnkiError::DBError { .ok_or_else(|| AnkiError::DBError {
info: "missing note type".to_string(), info: "missing note type".to_string(),
kind: DBErrorKind::MissingEntity,
})?; })?;
if fix_and_extract_media_refs(note, &mut referenced_files, renamed)? { if fix_and_extract_media_refs(
note,
&mut referenced_files,
renamed,
&self.mgr.media_folder,
)? {
// note was modified, needs saving // note was modified, needs saving
set_note(&trx, note, nt)?; set_note(&self.ctx.storage.db, note, nt)?;
collection_modified = true; collection_modified = true;
} }
@ -415,9 +411,8 @@ where
Ok(()) Ok(())
})?; })?;
if collection_modified { if !collection_modified {
mark_collection_modified(&trx)?; self.ctx.should_commit = false;
trx.commit()?;
} }
Ok(referenced_files) Ok(referenced_files)
@ -429,11 +424,17 @@ fn fix_and_extract_media_refs(
note: &mut Note, note: &mut Note,
seen_files: &mut HashSet<String>, seen_files: &mut HashSet<String>,
renamed: &HashMap<String, String>, renamed: &HashMap<String, String>,
media_folder: &Path,
) -> Result<bool> { ) -> Result<bool> {
let mut updated = false; let mut updated = false;
for idx in 0..note.fields().len() { for idx in 0..note.fields().len() {
let field = normalize_and_maybe_rename_files(&note.fields()[idx], renamed, seen_files); let field = normalize_and_maybe_rename_files(
&note.fields()[idx],
renamed,
seen_files,
media_folder,
);
if let Cow::Owned(field) = field { if let Cow::Owned(field) = field {
// field was modified, need to save // field was modified, need to save
note.set_field(idx, field)?; note.set_field(idx, field)?;
@ -450,6 +451,7 @@ fn normalize_and_maybe_rename_files<'a>(
field: &'a str, field: &'a str,
renamed: &HashMap<String, String>, renamed: &HashMap<String, String>,
seen_files: &mut HashSet<String>, seen_files: &mut HashSet<String>,
media_folder: &Path,
) -> Cow<'a, str> { ) -> Cow<'a, str> {
let refs = extract_media_refs(field); let refs = extract_media_refs(field);
let mut field: Cow<str> = field.into(); let mut field: Cow<str> = field.into();
@ -466,7 +468,21 @@ fn normalize_and_maybe_rename_files<'a>(
if let Some(new_name) = renamed.get(fname.as_ref()) { if let Some(new_name) = renamed.get(fname.as_ref()) {
fname = new_name.to_owned().into(); fname = new_name.to_owned().into();
} }
// if it was not in NFC or was renamed, update the field // if the filename was in NFC and was not renamed as part of the
// media check, it may have already been renamed during a previous
// sync. If that's the case and the renamed version exists on disk,
// we'll need to update the field to match it. It may be possible
// to remove this check in the future once we can be sure all media
// files stored on AnkiWeb are in normalized form.
if matches!(fname, Cow::Borrowed(_)) {
if let Cow::Owned(normname) = normalize_nfc_filename(fname.as_ref().into()) {
let path = media_folder.join(&normname);
if path.exists() {
fname = normname.into();
}
}
}
// update the field if the filename was modified
if let Cow::Owned(ref new_name) = fname { if let Cow::Owned(ref new_name) = fname {
field = rename_media_ref_in_field(field.as_ref(), &media_ref, new_name).into(); field = rename_media_ref_in_field(field.as_ref(), &media_ref, new_name).into();
} }
@ -510,41 +526,42 @@ fn extract_latex_refs(note: &Note, seen_files: &mut HashSet<String>, svg: bool)
} }
#[cfg(test)] #[cfg(test)]
mod test { pub(crate) mod test {
pub(crate) const MEDIACHECK_ANKI2: &'static [u8] =
include_bytes!("../../tests/support/mediacheck.anki2");
use crate::collection::{open_collection, Collection};
use crate::err::Result; use crate::err::Result;
use crate::i18n::I18n; use crate::i18n::I18n;
use crate::log; use crate::log;
use crate::log::Logger;
use crate::media::check::{MediaCheckOutput, MediaChecker}; use crate::media::check::{MediaCheckOutput, MediaChecker};
use crate::media::files::trash_folder; use crate::media::files::trash_folder;
use crate::media::MediaManager; use crate::media::MediaManager;
use std::path::{Path, PathBuf}; use std::path::Path;
use std::{fs, io}; use std::{fs, io};
use tempfile::{tempdir, TempDir}; use tempfile::{tempdir, TempDir};
fn common_setup() -> Result<(TempDir, MediaManager, PathBuf, Logger, I18n)> { fn common_setup() -> Result<(TempDir, MediaManager, Collection)> {
let dir = tempdir()?; let dir = tempdir()?;
let media_dir = dir.path().join("media"); let media_dir = dir.path().join("media");
fs::create_dir(&media_dir)?; fs::create_dir(&media_dir)?;
let media_db = dir.path().join("media.db"); let media_db = dir.path().join("media.db");
let col_path = dir.path().join("col.anki2"); let col_path = dir.path().join("col.anki2");
fs::write( fs::write(&col_path, MEDIACHECK_ANKI2)?;
&col_path,
&include_bytes!("../../tests/support/mediacheck.anki2")[..],
)?;
let mgr = MediaManager::new(&media_dir, media_db)?; let mgr = MediaManager::new(&media_dir, media_db.clone())?;
let log = log::terminal(); let log = log::terminal();
let i18n = I18n::new(&["zz"], "dummy", log.clone()); let i18n = I18n::new(&["zz"], "dummy", log.clone());
Ok((dir, mgr, col_path, log, i18n)) let col = open_collection(col_path, media_dir, media_db, false, i18n, log)?;
Ok((dir, mgr, col))
} }
#[test] #[test]
fn media_check() -> Result<()> { fn media_check() -> Result<()> {
let (_dir, mgr, col_path, log, i18n) = common_setup()?; let (_dir, mgr, col) = common_setup()?;
// add some test files // add some test files
fs::write(&mgr.media_folder.join("zerobytes"), "")?; fs::write(&mgr.media_folder.join("zerobytes"), "")?;
@ -555,8 +572,13 @@ mod test {
fs::write(&mgr.media_folder.join("unused.jpg"), "foo")?; fs::write(&mgr.media_folder.join("unused.jpg"), "foo")?;
let progress = |_n| true; let progress = |_n| true;
let mut checker = MediaChecker::new(&mgr, &col_path, progress, &i18n, &log);
let mut output = checker.check()?; let (output, report) = col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, progress);
let output = checker.check()?;
let summary = checker.summarize_output(&mut output.clone());
Ok((output, summary))
})?;
assert_eq!( assert_eq!(
output, output,
@ -576,7 +598,6 @@ mod test {
assert!(fs::metadata(&mgr.media_folder.join("foo[.jpg")).is_err()); assert!(fs::metadata(&mgr.media_folder.join("foo[.jpg")).is_err());
assert!(fs::metadata(&mgr.media_folder.join("foo.jpg")).is_ok()); assert!(fs::metadata(&mgr.media_folder.join("foo.jpg")).is_ok());
let report = checker.summarize_output(&mut output);
assert_eq!( assert_eq!(
report, report,
"Missing files: 1 "Missing files: 1
@ -616,14 +637,16 @@ Unused: unused.jpg
#[test] #[test]
fn trash_handling() -> Result<()> { fn trash_handling() -> Result<()> {
let (_dir, mgr, col_path, log, i18n) = common_setup()?; let (_dir, mgr, col) = common_setup()?;
let trash_folder = trash_folder(&mgr.media_folder)?; let trash_folder = trash_folder(&mgr.media_folder)?;
fs::write(trash_folder.join("test.jpg"), "test")?; fs::write(trash_folder.join("test.jpg"), "test")?;
let progress = |_n| true; let progress = |_n| true;
let mut checker = MediaChecker::new(&mgr, &col_path, progress, &i18n, &log);
checker.restore_trash()?; col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, progress);
checker.restore_trash()
})?;
// file should have been moved to media folder // file should have been moved to media folder
assert_eq!(files_in_dir(&trash_folder), Vec::<String>::new()); assert_eq!(files_in_dir(&trash_folder), Vec::<String>::new());
@ -634,7 +657,10 @@ Unused: unused.jpg
// if we repeat the process, restoring should do the same thing if the contents are equal // if we repeat the process, restoring should do the same thing if the contents are equal
fs::write(trash_folder.join("test.jpg"), "test")?; fs::write(trash_folder.join("test.jpg"), "test")?;
checker.restore_trash()?; col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, progress);
checker.restore_trash()
})?;
assert_eq!(files_in_dir(&trash_folder), Vec::<String>::new()); assert_eq!(files_in_dir(&trash_folder), Vec::<String>::new());
assert_eq!( assert_eq!(
files_in_dir(&mgr.media_folder), files_in_dir(&mgr.media_folder),
@ -643,7 +669,10 @@ Unused: unused.jpg
// but rename if required // but rename if required
fs::write(trash_folder.join("test.jpg"), "test2")?; fs::write(trash_folder.join("test.jpg"), "test2")?;
checker.restore_trash()?; col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, progress);
checker.restore_trash()
})?;
assert_eq!(files_in_dir(&trash_folder), Vec::<String>::new()); assert_eq!(files_in_dir(&trash_folder), Vec::<String>::new());
assert_eq!( assert_eq!(
files_in_dir(&mgr.media_folder), files_in_dir(&mgr.media_folder),
@ -658,13 +687,17 @@ Unused: unused.jpg
#[test] #[test]
fn unicode_normalization() -> Result<()> { fn unicode_normalization() -> Result<()> {
let (_dir, mgr, col_path, log, i18n) = common_setup()?; let (_dir, mgr, col) = common_setup()?;
fs::write(&mgr.media_folder.join("ぱぱ.jpg"), "nfd encoding")?; fs::write(&mgr.media_folder.join("ぱぱ.jpg"), "nfd encoding")?;
let progress = |_n| true; let progress = |_n| true;
let mut checker = MediaChecker::new(&mgr, &col_path, progress, &i18n, &log);
let mut output = checker.check()?; let mut output = col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, progress);
checker.check()
})?;
output.missing.sort(); output.missing.sort();
if cfg!(target_vendor = "apple") { if cfg!(target_vendor = "apple") {

View file

@ -84,7 +84,7 @@ pub(crate) fn normalize_filename(fname: &str) -> Cow<str> {
} }
/// See normalize_filename(). This function expects NFC-normalized input. /// See normalize_filename(). This function expects NFC-normalized input.
fn normalize_nfc_filename(mut fname: Cow<str>) -> Cow<str> { pub(crate) fn normalize_nfc_filename(mut fname: Cow<str>) -> Cow<str> {
if fname.chars().any(disallowed_char) { if fname.chars().any(disallowed_char) {
fname = fname.replace(disallowed_char, "").into() fname = fname.replace(disallowed_char, "").into()
} }
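Exposing normalize_nfc_filename as pub(crate) lets the media check reuse it on filenames referenced in fields. A rough sketch of its Cow behavior, assuming short names that only trip the disallowed-character branch ("foo[.jpg" mirrors the media_check test above):
use std::borrow::Cow;

let ok = normalize_nfc_filename(Cow::from("ok.jpg"));
assert!(matches!(ok, Cow::Borrowed(_))); // already clean: handed back unchanged
let fixed = normalize_nfc_filename(Cow::from("foo[.jpg"));
assert_eq!(fixed.as_ref(), "foo.jpg"); // '[' is a disallowed character and is stripped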

View file

@ -12,7 +12,6 @@ use std::path::{Path, PathBuf};
pub mod changetracker; pub mod changetracker;
pub mod check; pub mod check;
pub mod col;
pub mod database; pub mod database;
pub mod files; pub mod files;
pub mod sync; pub mod sync;

View file

@ -717,6 +717,17 @@ fn zip_files<'a>(
break; break;
} }
#[cfg(target_vendor = "apple")]
{
use unicode_normalization::is_nfc;
if !is_nfc(&file.fname) {
// older Anki versions stored non-normalized filenames in the DB; clean them up
debug!(log, "clean up non-nfc entry"; "fname"=>&file.fname);
invalid_entries.push(&file.fname);
continue;
}
}
let file_data = if file.sha1.is_some() { let file_data = if file.sha1.is_some() {
match data_for_file(media_folder, &file.fname) { match data_for_file(media_folder, &file.fname) {
Ok(data) => data, Ok(data) => data,
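is_nfc is the cheap form check from the unicode_normalization crate; an illustrative sketch of the difference it guards against (the string is hypothetical, not from this diff):
use unicode_normalization::{is_nfc, UnicodeNormalization};

let nfd: String = "ぱ".nfd().collect(); // "は" + combining mark, as older clients stored it
assert!(!is_nfc(&nfd));
let nfc: String = nfd.nfc().collect(); // single precomposed codepoint
assert!(is_nfc(&nfc));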

View file

@ -1,17 +1,17 @@
// Copyright: Ankitects Pty Ltd and contributors // Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
/// Basic note reading/updating functionality for the media DB check. /// At the moment, this is just basic note reading/updating functionality for
use crate::err::{AnkiError, Result}; /// the media DB check.
use crate::err::{AnkiError, DBErrorKind, Result};
use crate::text::strip_html_preserving_image_filenames; use crate::text::strip_html_preserving_image_filenames;
use crate::time::{i64_unix_millis, i64_unix_secs}; use crate::time::i64_unix_secs;
use crate::types::{ObjID, Timestamp, Usn}; use crate::{
notetypes::NoteType,
types::{ObjID, Timestamp, Usn},
};
use rusqlite::{params, Connection, Row, NO_PARAMS}; use rusqlite::{params, Connection, Row, NO_PARAMS};
use serde_aux::field_attributes::deserialize_number_from_string;
use serde_derive::Deserialize;
use std::collections::HashMap;
use std::convert::TryInto; use std::convert::TryInto;
use std::path::Path;
#[derive(Debug)] #[derive(Debug)]
pub(super) struct Note { pub(super) struct Note {
@ -40,55 +40,13 @@ impl Note {
} }
} }
fn field_checksum(text: &str) -> u32 { /// Text must be passed to strip_html_preserving_image_filenames() by
/// caller prior to passing in here.
pub(crate) fn field_checksum(text: &str) -> u32 {
let digest = sha1::Sha1::from(text).digest().bytes(); let digest = sha1::Sha1::from(text).digest().bytes();
u32::from_be_bytes(digest[..4].try_into().unwrap()) u32::from_be_bytes(digest[..4].try_into().unwrap())
} }
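A worked example: the checksum is just the first four bytes of the SHA1 digest, big-endian. 2840236005 is the same value the dupes: test in search/sqlwriter.rs expects for the text "test".
use std::convert::TryInto;

let digest = sha1::Sha1::from("test").digest().bytes();
let csum = u32::from_be_bytes(digest[..4].try_into().unwrap());
assert_eq!(csum, 2840236005);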
pub(super) fn open_or_create_collection_db(path: &Path) -> Result<Connection> {
let db = Connection::open(path)?;
db.pragma_update(None, "locking_mode", &"exclusive")?;
db.pragma_update(None, "page_size", &4096)?;
db.pragma_update(None, "cache_size", &(-40 * 1024))?;
db.pragma_update(None, "legacy_file_format", &false)?;
db.pragma_update(None, "journal", &"wal")?;
db.set_prepared_statement_cache_capacity(5);
Ok(db)
}
#[derive(Deserialize, Debug)]
pub(super) struct NoteType {
#[serde(deserialize_with = "deserialize_number_from_string")]
id: ObjID,
#[serde(rename = "sortf")]
sort_field_idx: u16,
#[serde(rename = "latexsvg", default)]
latex_svg: bool,
}
impl NoteType {
pub fn latex_uses_svg(&self) -> bool {
self.latex_svg
}
}
pub(super) fn get_note_types(db: &Connection) -> Result<HashMap<ObjID, NoteType>> {
let mut stmt = db.prepare("select models from col")?;
let note_types = stmt
.query_and_then(NO_PARAMS, |row| -> Result<HashMap<ObjID, NoteType>> {
let v: HashMap<ObjID, NoteType> = serde_json::from_str(row.get_raw(0).as_str()?)?;
Ok(v)
})?
.next()
.ok_or_else(|| AnkiError::DBError {
info: "col table empty".to_string(),
})??;
Ok(note_types)
}
#[allow(dead_code)] #[allow(dead_code)]
fn get_note(db: &Connection, nid: ObjID) -> Result<Option<Note>> { fn get_note(db: &Connection, nid: ObjID) -> Result<Option<Note>> {
let mut stmt = db.prepare_cached("select id, mid, mod, usn, flds from notes where id=?")?; let mut stmt = db.prepare_cached("select id, mid, mod, usn, flds from notes where id=?")?;
@ -130,14 +88,20 @@ pub(super) fn set_note(db: &Connection, note: &mut Note, note_type: &NoteType) -
note.mtime_secs = i64_unix_secs(); note.mtime_secs = i64_unix_secs();
// hard-coded for now // hard-coded for now
note.usn = -1; note.usn = -1;
let csum = field_checksum(&note.fields()[0]); let field1_nohtml = strip_html_preserving_image_filenames(&note.fields()[0]);
let sort_field = strip_html_preserving_image_filenames( let csum = field_checksum(field1_nohtml.as_ref());
note.fields() let sort_field = if note_type.sort_field_idx == 0 {
.get(note_type.sort_field_idx as usize) field1_nohtml
.ok_or_else(|| AnkiError::DBError { } else {
info: "sort field out of range".to_string(), strip_html_preserving_image_filenames(
})?, note.fields()
); .get(note_type.sort_field_idx as usize)
.ok_or_else(|| AnkiError::DBError {
info: "sort field out of range".to_string(),
kind: DBErrorKind::MissingEntity,
})?,
)
};
let mut stmt = let mut stmt =
db.prepare_cached("update notes set mod=?,usn=?,flds=?,sfld=?,csum=? where id=?")?; db.prepare_cached("update notes set mod=?,usn=?,flds=?,sfld=?,csum=? where id=?")?;
@ -152,8 +116,3 @@ pub(super) fn set_note(db: &Connection, note: &mut Note, note_type: &NoteType) -
Ok(()) Ok(())
} }
pub(super) fn mark_collection_modified(db: &Connection) -> Result<()> {
db.execute("update col set mod=?", params![i64_unix_millis()])?;
Ok(())
}

39
rslib/src/notetypes.rs Normal file
View file

@ -0,0 +1,39 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::types::ObjID;
use serde_aux::field_attributes::deserialize_number_from_string;
use serde_derive::Deserialize;
#[derive(Deserialize, Debug)]
pub(crate) struct NoteType {
#[serde(deserialize_with = "deserialize_number_from_string")]
pub id: ObjID,
pub name: String,
#[serde(rename = "sortf")]
pub sort_field_idx: u16,
#[serde(rename = "latexsvg", default)]
pub latex_svg: bool,
#[serde(rename = "tmpls")]
pub templates: Vec<CardTemplate>,
#[serde(rename = "flds")]
pub fields: Vec<NoteField>,
}
#[derive(Deserialize, Debug)]
pub(crate) struct CardTemplate {
pub name: String,
pub ord: u16,
}
#[derive(Deserialize, Debug)]
pub(crate) struct NoteField {
pub name: String,
pub ord: u16,
}
impl NoteType {
pub fn latex_uses_svg(&self) -> bool {
self.latex_svg
}
}

View file

@ -3,6 +3,7 @@
use chrono::{Date, Duration, FixedOffset, Local, TimeZone}; use chrono::{Date, Duration, FixedOffset, Local, TimeZone};
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct SchedTimingToday { pub struct SchedTimingToday {
/// The number of days that have passed since the collection was created. /// The number of days that have passed since the collection was created.
pub days_elapsed: u32, pub days_elapsed: u32,
@ -17,7 +18,7 @@ pub struct SchedTimingToday {
/// - now_secs is a timestamp of the current time /// - now_secs is a timestamp of the current time
/// - now_mins_west is the current offset west of UTC /// - now_mins_west is the current offset west of UTC
/// - rollover_hour is the hour of the day the rollover happens (eg 4 for 4am) /// - rollover_hour is the hour of the day the rollover happens (eg 4 for 4am)
pub fn sched_timing_today( pub fn sched_timing_today_v2_new(
created_secs: i64, created_secs: i64,
created_mins_west: i32, created_mins_west: i32,
now_secs: i64, now_secs: i64,
@ -90,11 +91,84 @@ pub fn local_minutes_west_for_stamp(stamp: i64) -> i32 {
Local.timestamp(stamp, 0).offset().utc_minus_local() / 60 Local.timestamp(stamp, 0).offset().utc_minus_local() / 60
} }
// Legacy code
// ----------------------------------
fn sched_timing_today_v1(crt: i64, now: i64) -> SchedTimingToday {
let days_elapsed = (now - crt) / 86_400;
let next_day_at = crt + (days_elapsed + 1) * 86_400;
SchedTimingToday {
days_elapsed: days_elapsed as u32,
next_day_at,
}
}
fn sched_timing_today_v2_legacy(
crt: i64,
rollover: i8,
now: i64,
mins_west: i32,
) -> SchedTimingToday {
let normalized_rollover = normalized_rollover_hour(rollover);
let offset = fixed_offset_from_minutes(mins_west);
let crt_at_rollover = offset
.timestamp(crt, 0)
.date()
.and_hms(normalized_rollover as u32, 0, 0)
.timestamp();
let days_elapsed = (now - crt_at_rollover) / 86_400;
let mut next_day_at = offset
.timestamp(now, 0)
.date()
.and_hms(normalized_rollover as u32, 0, 0)
.timestamp();
if next_day_at < now {
next_day_at += 86_400;
}
SchedTimingToday {
days_elapsed: days_elapsed as u32,
next_day_at,
}
}
// ----------------------------------
/// Based on provided input, get timing info from the relevant function.
pub(crate) fn sched_timing_today(
created_secs: i64,
now_secs: i64,
created_mins_west: Option<i32>,
now_mins_west: Option<i32>,
rollover_hour: Option<i8>,
) -> SchedTimingToday {
let now_west = now_mins_west.unwrap_or_else(|| local_minutes_west_for_stamp(now_secs));
match (rollover_hour, created_mins_west) {
(None, _) => {
// if rollover unset, v1 scheduler
sched_timing_today_v1(created_secs, now_secs)
}
(Some(roll), None) => {
// if creationOffset unset, v2 scheduler with legacy cutoff handling
sched_timing_today_v2_legacy(created_secs, roll, now_secs, now_west)
}
(Some(roll), Some(crt_west)) => {
// v2 scheduler, new cutoff handling
sched_timing_today_v2_new(created_secs, crt_west, now_secs, now_west, roll)
}
}
}
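Concretely, feeding the dispatcher the same stamps as the legacy_timing test below routes each case to the expected implementation (a sketch):
let now = 1584491078;
// rollover hour unset -> v1 scheduler
assert_eq!(
    sched_timing_today(1575226800, now, None, None, None).days_elapsed,
    107
);
// rollover set, creation offset unset -> v2 with legacy cutoff handling
assert_eq!(
    sched_timing_today(1533564000, now, None, Some(-600), Some(0)).days_elapsed,
    589
);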
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::SchedTimingToday;
use crate::sched::cutoff::sched_timing_today_v1;
use crate::sched::cutoff::sched_timing_today_v2_legacy;
use crate::sched::cutoff::{ use crate::sched::cutoff::{
fixed_offset_from_minutes, local_minutes_west_for_stamp, normalized_rollover_hour, fixed_offset_from_minutes, local_minutes_west_for_stamp, normalized_rollover_hour,
sched_timing_today, sched_timing_today_v2_new,
}; };
use chrono::{FixedOffset, Local, TimeZone, Utc}; use chrono::{FixedOffset, Local, TimeZone, Utc};
@ -117,7 +191,7 @@ mod test {
// helper // helper
fn elap(start: i64, end: i64, start_west: i32, end_west: i32, rollhour: i8) -> u32 { fn elap(start: i64, end: i64, start_west: i32, end_west: i32, rollhour: i8) -> u32 {
let today = sched_timing_today(start, start_west, end, end_west, rollhour); let today = sched_timing_today_v2_new(start, start_west, end, end_west, rollhour);
today.days_elapsed today.days_elapsed
} }
@ -228,7 +302,7 @@ mod test {
// before the rollover, the next day should be later on the same day // before the rollover, the next day should be later on the same day
let now = Local.ymd(2019, 1, 3).and_hms(2, 0, 0); let now = Local.ymd(2019, 1, 3).and_hms(2, 0, 0);
let next_day_at = Local.ymd(2019, 1, 3).and_hms(rollhour, 0, 0); let next_day_at = Local.ymd(2019, 1, 3).and_hms(rollhour, 0, 0);
let today = sched_timing_today( let today = sched_timing_today_v2_new(
crt.timestamp(), crt.timestamp(),
crt.offset().utc_minus_local() / 60, crt.offset().utc_minus_local() / 60,
now.timestamp(), now.timestamp(),
@ -240,7 +314,7 @@ mod test {
// after the rollover, the next day should be the next day // after the rollover, the next day should be the next day
let now = Local.ymd(2019, 1, 3).and_hms(rollhour, 0, 0); let now = Local.ymd(2019, 1, 3).and_hms(rollhour, 0, 0);
let next_day_at = Local.ymd(2019, 1, 4).and_hms(rollhour, 0, 0); let next_day_at = Local.ymd(2019, 1, 4).and_hms(rollhour, 0, 0);
let today = sched_timing_today( let today = sched_timing_today_v2_new(
crt.timestamp(), crt.timestamp(),
crt.offset().utc_minus_local() / 60, crt.offset().utc_minus_local() / 60,
now.timestamp(), now.timestamp(),
@ -252,7 +326,7 @@ mod test {
// after the rollover, the next day should be the next day // after the rollover, the next day should be the next day
let now = Local.ymd(2019, 1, 3).and_hms(rollhour + 3, 0, 0); let now = Local.ymd(2019, 1, 3).and_hms(rollhour + 3, 0, 0);
let next_day_at = Local.ymd(2019, 1, 4).and_hms(rollhour, 0, 0); let next_day_at = Local.ymd(2019, 1, 4).and_hms(rollhour, 0, 0);
let today = sched_timing_today( let today = sched_timing_today_v2_new(
crt.timestamp(), crt.timestamp(),
crt.offset().utc_minus_local() / 60, crt.offset().utc_minus_local() / 60,
now.timestamp(), now.timestamp(),
@ -261,4 +335,34 @@ mod test {
); );
assert_eq!(today.next_day_at, next_day_at.timestamp()); assert_eq!(today.next_day_at, next_day_at.timestamp());
} }
#[test]
fn legacy_timing() {
let now = 1584491078;
let mins_west = -600;
assert_eq!(
sched_timing_today_v1(1575226800, now),
SchedTimingToday {
days_elapsed: 107,
next_day_at: 1584558000
}
);
assert_eq!(
sched_timing_today_v2_legacy(1533564000, 0, now, mins_west),
SchedTimingToday {
days_elapsed: 589,
next_day_at: 1584540000
}
);
assert_eq!(
sched_timing_today_v2_legacy(1524038400, 4, now, mins_west),
SchedTimingToday {
days_elapsed: 700,
next_day_at: 1584554400
}
);
}
} }

154
rslib/src/search/cards.rs Normal file
View file

@ -0,0 +1,154 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use super::{parser::Node, sqlwriter::node_to_sql};
use crate::card::CardType;
use crate::collection::RequestContext;
use crate::config::SortKind;
use crate::err::Result;
use crate::search::parser::parse;
use crate::types::ObjID;
use rusqlite::params;
pub(crate) enum SortMode {
NoOrder,
FromConfig,
Builtin { kind: SortKind, reverse: bool },
Custom(String),
}
pub(crate) fn search_cards<'a, 'b>(
req: &'a mut RequestContext<'b>,
search: &'a str,
order: SortMode,
) -> Result<Vec<ObjID>> {
let top_node = Node::Group(parse(search)?);
let (sql, args) = node_to_sql(req, &top_node)?;
let mut sql = format!(
"select c.id from cards c, notes n where c.nid=n.id and {}",
sql
);
match order {
SortMode::NoOrder => (),
SortMode::FromConfig => {
let conf = req.storage.all_config()?;
prepare_sort(req, &conf.browser_sort_kind)?;
sql.push_str(" order by ");
write_order(&mut sql, &conf.browser_sort_kind, conf.browser_sort_reverse)?;
}
SortMode::Builtin { kind, reverse } => {
prepare_sort(req, &kind)?;
sql.push_str(" order by ");
write_order(&mut sql, &kind, reverse)?;
}
SortMode::Custom(order_clause) => {
sql.push_str(" order by ");
sql.push_str(&order_clause);
}
}
let mut stmt = req.storage.db.prepare(&sql)?;
let ids: Vec<i64> = stmt
.query_map(&args, |row| row.get(0))?
.collect::<std::result::Result<_, _>>()?;
Ok(ids)
}
/// Add the order clause to the sql.
fn write_order(sql: &mut String, kind: &SortKind, reverse: bool) -> Result<()> {
let tmp_str;
let order = match kind {
SortKind::NoteCreation => "n.id asc, c.ord asc",
SortKind::NoteMod => "n.mod asc, c.ord asc",
SortKind::NoteField => "n.sfld collate nocase asc, c.ord asc",
SortKind::CardMod => "c.mod asc",
SortKind::CardReps => "c.reps asc",
SortKind::CardDue => "c.type asc, c.due asc",
SortKind::CardEase => {
tmp_str = format!("c.type = {} asc, c.factor asc", CardType::New as i8);
&tmp_str
}
SortKind::CardLapses => "c.lapses asc",
SortKind::CardInterval => "c.ivl asc",
SortKind::NoteTags => "n.tags asc",
SortKind::CardDeck => "(select v from sort_order where k = c.did) asc",
SortKind::NoteType => "(select v from sort_order where k = n.mid) asc",
SortKind::CardTemplate => "(select v from sort_order where k1 = n.mid and k2 = c.ord) asc",
};
if order.is_empty() {
return Ok(());
}
if reverse {
sql.push_str(
&order
.to_ascii_lowercase()
.replace(" desc", "")
.replace(" asc", " desc"),
)
} else {
sql.push_str(order);
}
Ok(())
}
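For example, reversing SortKind::CardDue turns "c.type asc, c.due asc" into "c.type desc, c.due desc"; the strip-then-replace dance is a safeguard so a clause that already contained desc would not double up.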
// In the future these items should be moved from JSON into separate SQL tables,
// - for now we use a temporary table to sort them.
fn prepare_sort(req: &mut RequestContext, kind: &SortKind) -> Result<()> {
use SortKind::*;
match kind {
CardDeck | NoteType => {
prepare_sort_order_table(req)?;
let mut stmt = req
.storage
.db
.prepare("insert into sort_order (k,v) values (?,?)")?;
match kind {
CardDeck => {
for (k, v) in req.storage.all_decks()? {
stmt.execute(params![k, v.name])?;
}
}
NoteType => {
for (k, v) in req.storage.all_note_types()? {
stmt.execute(params![k, v.name])?;
}
}
_ => unreachable!(),
}
}
CardTemplate => {
prepare_sort_order_table2(req)?;
let mut stmt = req
.storage
.db
.prepare("insert into sort_order (k1,k2,v) values (?,?,?)")?;
for (ntid, nt) in req.storage.all_note_types()? {
for tmpl in nt.templates {
stmt.execute(params![ntid, tmpl.ord, tmpl.name])?;
}
}
}
_ => (),
}
Ok(())
}
fn prepare_sort_order_table(req: &mut RequestContext) -> Result<()> {
req.storage
.db
.execute_batch(include_str!("sort_order.sql"))?;
Ok(())
}
fn prepare_sort_order_table2(req: &mut RequestContext) -> Result<()> {
req.storage
.db
.execute_batch(include_str!("sort_order2.sql"))?;
Ok(())
}
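Putting it together, a sketch of driving the new search from a collection handle, using the transact pattern shown in the media check tests above (the search string and sort choice are illustrative):
let cids = col.transact(None, |ctx| {
    search_cards(
        ctx,
        "deck:current is:due",
        SortMode::Builtin {
            kind: SortKind::CardDue,
            reverse: false,
        },
    )
})?;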

7
rslib/src/search/mod.rs Normal file
View file

@ -0,0 +1,7 @@
mod cards;
mod notes;
mod parser;
mod sqlwriter;
pub(crate) use cards::{search_cards, SortMode};
pub(crate) use notes::search_notes;

28
rslib/src/search/notes.rs Normal file
View file

@ -0,0 +1,28 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use super::{parser::Node, sqlwriter::node_to_sql};
use crate::collection::RequestContext;
use crate::err::Result;
use crate::search::parser::parse;
use crate::types::ObjID;
pub(crate) fn search_notes<'a, 'b>(
req: &'a mut RequestContext<'b>,
search: &'a str,
) -> Result<Vec<ObjID>> {
let top_node = Node::Group(parse(search)?);
let (sql, args) = node_to_sql(req, &top_node)?;
let sql = format!(
"select n.id from cards c, notes n where c.nid=n.id and {}",
sql
);
let mut stmt = req.storage.db.prepare(&sql)?;
let ids: Vec<i64> = stmt
.query_map(&args, |row| row.get(0))?
.collect::<std::result::Result<_, _>>()?;
Ok(ids)
}

528
rslib/src/search/parser.rs Normal file
View file

@ -0,0 +1,528 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::err::{AnkiError, Result};
use crate::types::ObjID;
use nom::branch::alt;
use nom::bytes::complete::{escaped, is_not, tag, take_while1};
use nom::character::complete::{anychar, char, one_of};
use nom::character::is_digit;
use nom::combinator::{all_consuming, map, map_res};
use nom::sequence::{delimited, preceded, tuple};
use nom::{multi::many0, IResult};
use std::{borrow::Cow, num};
// fixme: need to preserve \ when used twice in string
struct ParseError {}
impl From<num::ParseIntError> for ParseError {
fn from(_: num::ParseIntError) -> Self {
ParseError {}
}
}
impl From<num::ParseFloatError> for ParseError {
fn from(_: num::ParseFloatError) -> Self {
ParseError {}
}
}
impl<I> From<nom::Err<(I, nom::error::ErrorKind)>> for ParseError {
fn from(_: nom::Err<(I, nom::error::ErrorKind)>) -> Self {
ParseError {}
}
}
type ParseResult<T> = std::result::Result<T, ParseError>;
#[derive(Debug, PartialEq)]
pub(super) enum Node<'a> {
And,
Or,
Not(Box<Node<'a>>),
Group(Vec<Node<'a>>),
Search(SearchNode<'a>),
}
#[derive(Debug, PartialEq)]
pub(super) enum SearchNode<'a> {
// text without a colon
UnqualifiedText(Cow<'a, str>),
// foo:bar, where foo doesn't match a term below
SingleField {
field: Cow<'a, str>,
text: Cow<'a, str>,
is_re: bool,
},
AddedInDays(u32),
CardTemplate(TemplateKind),
Deck(Cow<'a, str>),
NoteTypeID(ObjID),
NoteType(Cow<'a, str>),
Rated {
days: u32,
ease: Option<u8>,
},
Tag(Cow<'a, str>),
Duplicates {
note_type_id: ObjID,
text: String,
},
State(StateKind),
Flag(u8),
NoteIDs(Cow<'a, str>),
CardIDs(Cow<'a, str>),
Property {
operator: String,
kind: PropertyKind,
},
WholeCollection,
Regex(Cow<'a, str>),
NoCombining(Cow<'a, str>),
}
#[derive(Debug, PartialEq)]
pub(super) enum PropertyKind {
Due(i32),
Interval(u32),
Reps(u32),
Lapses(u32),
Ease(f32),
}
#[derive(Debug, PartialEq)]
pub(super) enum StateKind {
New,
Review,
Learning,
Due,
Buried,
Suspended,
}
#[derive(Debug, PartialEq)]
pub(super) enum TemplateKind {
Ordinal(u16),
Name(String),
}
/// Parse the input string into a list of nodes.
#[allow(dead_code)]
pub(super) fn parse(input: &str) -> Result<Vec<Node>> {
let input = input.trim();
if input.is_empty() {
return Ok(vec![Node::Search(SearchNode::WholeCollection)]);
}
let (_, nodes) = all_consuming(group_inner)(input)
.map_err(|_e| AnkiError::invalid_input("unable to parse search"))?;
Ok(nodes)
}
/// One or more nodes surrounded by brackets, eg (one OR two)
fn group(s: &str) -> IResult<&str, Node> {
map(delimited(char('('), group_inner, char(')')), |nodes| {
Node::Group(nodes)
})(s)
}
/// One or more nodes inside brackets, eg 'one OR two -three'
fn group_inner(input: &str) -> IResult<&str, Vec<Node>> {
let mut remaining = input;
let mut nodes = vec![];
loop {
match node(remaining) {
Ok((rem, node)) => {
remaining = rem;
if nodes.len() % 2 == 0 {
// before adding the node, if the length is even then the node
// must not be a boolean
if matches!(node, Node::And | Node::Or) {
return Err(nom::Err::Failure(("", nom::error::ErrorKind::NoneOf)));
}
} else {
// if the length is odd, the next item must be a boolean. if it's
// not, add an implicit and
if !matches!(node, Node::And | Node::Or) {
nodes.push(Node::And);
}
}
nodes.push(node);
}
Err(e) => match e {
nom::Err::Error(_) => break,
_ => return Err(e),
},
};
}
if nodes.is_empty() {
Err(nom::Err::Error((remaining, nom::error::ErrorKind::Many1)))
} else {
// chomp any trailing whitespace
let (remaining, _) = whitespace0(remaining)?;
Ok((remaining, nodes))
}
}
fn whitespace0(s: &str) -> IResult<&str, Vec<char>> {
many0(one_of(" \u{3000}"))(s)
}
/// Optional leading space, then a (negated) group or text
fn node(s: &str) -> IResult<&str, Node> {
preceded(whitespace0, alt((negated_node, group, text)))(s)
}
fn negated_node(s: &str) -> IResult<&str, Node> {
map(preceded(char('-'), alt((group, text))), |node| {
Node::Not(Box::new(node))
})(s)
}
/// Either quoted or unquoted text
fn text(s: &str) -> IResult<&str, Node> {
alt((quoted_term, partially_quoted_term, unquoted_term))(s)
}
/// Determine if text is a qualified search, and handle escaped chars.
fn search_node_for_text(s: &str) -> ParseResult<SearchNode> {
let mut it = s.splitn(2, ':');
let (head, tail) = (
unescape_quotes(it.next().unwrap()),
it.next().map(unescape_quotes),
);
if let Some(tail) = tail {
search_node_for_text_with_argument(head, tail)
} else {
Ok(SearchNode::UnqualifiedText(head))
}
}
/// \" -> "
fn unescape_quotes(s: &str) -> Cow<str> {
if s.find(r#"\""#).is_some() {
s.replace(r#"\""#, "\"").into()
} else {
s.into()
}
}
/// Unquoted text, terminated by a space or )
fn unquoted_term(s: &str) -> IResult<&str, Node> {
map_res(
take_while1(|c| c != ' ' && c != ')' && c != '"'),
|text: &str| -> ParseResult<Node> {
Ok(if text.eq_ignore_ascii_case("or") {
Node::Or
} else if text.eq_ignore_ascii_case("and") {
Node::And
} else {
Node::Search(search_node_for_text(text)?)
})
},
)(s)
}
/// Quoted text, including the outer double quotes.
fn quoted_term(s: &str) -> IResult<&str, Node> {
map_res(quoted_term_str, |o| -> ParseResult<Node> {
Ok(Node::Search(search_node_for_text(o)?))
})(s)
}
fn quoted_term_str(s: &str) -> IResult<&str, &str> {
delimited(char('"'), quoted_term_inner, char('"'))(s)
}
/// Quoted text, terminated by a non-escaped double quote
fn quoted_term_inner(s: &str) -> IResult<&str, &str> {
escaped(is_not(r#""\"#), '\\', anychar)(s)
}
/// eg deck:"foo bar" - quotes must come after the :
fn partially_quoted_term(s: &str) -> IResult<&str, Node> {
let term = take_while1(|c| c != ' ' && c != ')' && c != ':');
let (s, (term, _, quoted_val)) = tuple((term, char(':'), quoted_term_str))(s)?;
match search_node_for_text_with_argument(term.into(), quoted_val.into()) {
Ok(search) => Ok((s, Node::Search(search))),
Err(_) => Err(nom::Err::Failure((s, nom::error::ErrorKind::NoneOf))),
}
}
/// Convert a colon-separated key/val pair into the relevant search type.
fn search_node_for_text_with_argument<'a>(
key: Cow<'a, str>,
val: Cow<'a, str>,
) -> ParseResult<SearchNode<'a>> {
Ok(match key.to_ascii_lowercase().as_str() {
"added" => SearchNode::AddedInDays(val.parse()?),
"deck" => SearchNode::Deck(val),
"note" => SearchNode::NoteType(val),
"tag" => SearchNode::Tag(val),
"mid" => SearchNode::NoteTypeID(val.parse()?),
"nid" => SearchNode::NoteIDs(check_id_list(val)?),
"cid" => SearchNode::CardIDs(check_id_list(val)?),
"card" => parse_template(val.as_ref()),
"is" => parse_state(val.as_ref())?,
"flag" => parse_flag(val.as_ref())?,
"rated" => parse_rated(val.as_ref())?,
"dupes" => parse_dupes(val.as_ref())?,
"prop" => parse_prop(val.as_ref())?,
"re" => SearchNode::Regex(val),
"nc" => SearchNode::NoCombining(val),
// anything else is a field search
_ => parse_single_field(key.as_ref(), val.as_ref()),
})
}
/// ensure a list of ids contains only numbers and commas, returning it unchanged if so
/// used by nid: and cid:
fn check_id_list(s: Cow<str>) -> ParseResult<Cow<str>> {
if s.is_empty() || s.as_bytes().iter().any(|&c| !is_digit(c) && c != b',') {
Err(ParseError {})
} else {
Ok(s)
}
}
/// eg is:due
fn parse_state(s: &str) -> ParseResult<SearchNode<'static>> {
use StateKind::*;
Ok(SearchNode::State(match s {
"new" => New,
"review" => Review,
"learn" => Learning,
"due" => Due,
"buried" => Buried,
"suspended" => Suspended,
_ => return Err(ParseError {}),
}))
}
/// flag:0-4
fn parse_flag(s: &str) -> ParseResult<SearchNode<'static>> {
let n: u8 = s.parse()?;
if n > 4 {
Err(ParseError {})
} else {
Ok(SearchNode::Flag(n))
}
}
/// eg rated:3 or rated:10:2
/// second arg must be between 0-4
fn parse_rated(val: &str) -> ParseResult<SearchNode<'static>> {
let mut it = val.splitn(2, ':');
let days = it.next().unwrap().parse()?;
let ease = match it.next() {
Some(v) => {
let n: u8 = v.parse()?;
if n < 5 {
Some(n)
} else {
return Err(ParseError {});
}
}
None => None,
};
Ok(SearchNode::Rated { days, ease })
}
/// eg dupes:1231,hello
fn parse_dupes(val: &str) -> ParseResult<SearchNode<'static>> {
let mut it = val.splitn(2, ',');
let mid: ObjID = it.next().unwrap().parse()?;
let text = it.next().ok_or(ParseError {})?;
Ok(SearchNode::Duplicates {
note_type_id: mid,
text: text.into(),
})
}
/// eg prop:ivl>3, prop:ease!=2.5
fn parse_prop(val: &str) -> ParseResult<SearchNode<'static>> {
let (val, key) = alt((
tag("ivl"),
tag("due"),
tag("reps"),
tag("lapses"),
tag("ease"),
))(val)?;
let (val, operator) = alt((
tag("<="),
tag(">="),
tag("!="),
tag("="),
tag("<"),
tag(">"),
))(val)?;
let kind = if key == "ease" {
let num: f32 = val.parse()?;
PropertyKind::Ease(num)
} else if key == "due" {
let num: i32 = val.parse()?;
PropertyKind::Due(num)
} else {
let num: u32 = val.parse()?;
match key {
"ivl" => PropertyKind::Interval(num),
"reps" => PropertyKind::Reps(num),
"lapses" => PropertyKind::Lapses(num),
_ => unreachable!(),
}
};
Ok(SearchNode::Property {
operator: operator.to_string(),
kind,
})
}
fn parse_template(val: &str) -> SearchNode<'static> {
SearchNode::CardTemplate(match val.parse::<u16>() {
Ok(n) => TemplateKind::Ordinal(n.max(1) - 1),
Err(_) => TemplateKind::Name(val.into()),
})
}
fn parse_single_field(key: &str, mut val: &str) -> SearchNode<'static> {
let is_re = if val.starts_with("re:") {
val = val.trim_start_matches("re:");
true
} else {
false
};
SearchNode::SingleField {
field: key.to_string().into(),
text: val.to_string().into(),
is_re,
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn parsing() -> Result<()> {
use Node::*;
use SearchNode::*;
assert_eq!(parse("")?, vec![Search(SearchNode::WholeCollection)]);
assert_eq!(parse(" ")?, vec![Search(SearchNode::WholeCollection)]);
// leading/trailing/interspersed whitespace
assert_eq!(
parse(" t t2 ")?,
vec![
Search(UnqualifiedText("t".into())),
And,
Search(UnqualifiedText("t2".into()))
]
);
// including in groups
assert_eq!(
parse("( t t2 )")?,
vec![Group(vec![
Search(UnqualifiedText("t".into())),
And,
Search(UnqualifiedText("t2".into()))
])]
);
assert_eq!(
parse(r#"hello -(world and "foo:bar baz") OR test"#)?,
vec![
Search(UnqualifiedText("hello".into())),
And,
Not(Box::new(Group(vec![
Search(UnqualifiedText("world".into())),
And,
Search(SingleField {
field: "foo".into(),
text: "bar baz".into(),
is_re: false,
})
]))),
Or,
Search(UnqualifiedText("test".into()))
]
);
assert_eq!(
parse("foo:re:bar")?,
vec![Search(SingleField {
field: "foo".into(),
text: "bar".into(),
is_re: true
})]
);
// any character should be escapable in quotes
assert_eq!(
parse(r#""re:\btest""#)?,
vec![Search(Regex(r"\btest".into()))]
);
assert_eq!(parse("added:3")?, vec![Search(AddedInDays(3))]);
assert_eq!(
parse("card:front")?,
vec![Search(CardTemplate(TemplateKind::Name("front".into())))]
);
assert_eq!(
parse("card:3")?,
vec![Search(CardTemplate(TemplateKind::Ordinal(2)))]
);
// 0 must not cause a crash due to underflow
assert_eq!(
parse("card:0")?,
vec![Search(CardTemplate(TemplateKind::Ordinal(0)))]
);
assert_eq!(parse("deck:default")?, vec![Search(Deck("default".into()))]);
assert_eq!(
parse("deck:\"default one\"")?,
vec![Search(Deck("default one".into()))]
);
assert_eq!(parse("note:basic")?, vec![Search(NoteType("basic".into()))]);
assert_eq!(parse("tag:hard")?, vec![Search(Tag("hard".into()))]);
assert_eq!(
parse("nid:1237123712,2,3")?,
vec![Search(NoteIDs("1237123712,2,3".into()))]
);
assert!(parse("nid:1237123712_2,3").is_err());
assert_eq!(parse("is:due")?, vec![Search(State(StateKind::Due))]);
assert_eq!(parse("flag:3")?, vec![Search(Flag(3))]);
assert!(parse("flag:-1").is_err());
assert!(parse("flag:5").is_err());
assert_eq!(
parse("prop:ivl>3")?,
vec![Search(Property {
operator: ">".into(),
kind: PropertyKind::Interval(3)
})]
);
assert!(parse("prop:ivl>3.3").is_err());
assert_eq!(
parse("prop:ease<=3.3")?,
vec![Search(Property {
operator: "<=".into(),
kind: PropertyKind::Ease(3.3)
})]
);
Ok(())
}
}

2
rslib/src/search/sort_order.sql Normal file
View file

@ -0,0 +1,2 @@
drop table if exists sort_order;
create temporary table sort_order (k int primary key, v text);

2
rslib/src/search/sort_order2.sql Normal file
View file

@ -0,0 +1,2 @@
drop table if exists sort_order;
create temporary table sort_order (k1 int, k2 int, v text, primary key (k1, k2)) without rowid;

583
rslib/src/search/sqlwriter.rs Normal file
View file

@ -0,0 +1,583 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use super::parser::{Node, PropertyKind, SearchNode, StateKind, TemplateKind};
use crate::card::CardQueue;
use crate::decks::child_ids;
use crate::decks::get_deck;
use crate::err::{AnkiError, Result};
use crate::notes::field_checksum;
use crate::text::matches_wildcard;
use crate::text::without_combining;
use crate::{
collection::RequestContext, text::strip_html_preserving_image_filenames, types::ObjID,
};
use std::fmt::Write;
struct SqlWriter<'a, 'b> {
req: &'a mut RequestContext<'b>,
sql: String,
args: Vec<String>,
}
pub(super) fn node_to_sql(req: &mut RequestContext, node: &Node) -> Result<(String, Vec<String>)> {
let mut sctx = SqlWriter::new(req);
sctx.write_node_to_sql(&node)?;
Ok((sctx.sql, sctx.args))
}
impl SqlWriter<'_, '_> {
fn new<'a, 'b>(req: &'a mut RequestContext<'b>) -> SqlWriter<'a, 'b> {
let sql = String::new();
let args = vec![];
SqlWriter { req, sql, args }
}
fn write_node_to_sql(&mut self, node: &Node) -> Result<()> {
match node {
Node::And => write!(self.sql, " and ").unwrap(),
Node::Or => write!(self.sql, " or ").unwrap(),
Node::Not(node) => {
write!(self.sql, "not ").unwrap();
self.write_node_to_sql(node)?;
}
Node::Group(nodes) => {
write!(self.sql, "(").unwrap();
for node in nodes {
self.write_node_to_sql(node)?;
}
write!(self.sql, ")").unwrap();
}
Node::Search(search) => self.write_search_node_to_sql(search)?,
};
Ok(())
}
fn write_search_node_to_sql(&mut self, node: &SearchNode) -> Result<()> {
match node {
SearchNode::UnqualifiedText(text) => self.write_unqualified(text),
SearchNode::SingleField { field, text, is_re } => {
self.write_single_field(field.as_ref(), text.as_ref(), *is_re)?
}
SearchNode::AddedInDays(days) => self.write_added(*days)?,
SearchNode::CardTemplate(template) => self.write_template(template)?,
SearchNode::Deck(deck) => self.write_deck(deck.as_ref())?,
SearchNode::NoteTypeID(ntid) => {
write!(self.sql, "n.mid = {}", ntid).unwrap();
}
SearchNode::NoteType(notetype) => self.write_note_type(notetype.as_ref())?,
SearchNode::Rated { days, ease } => self.write_rated(*days, *ease)?,
SearchNode::Tag(tag) => self.write_tag(tag),
SearchNode::Duplicates { note_type_id, text } => self.write_dupes(*note_type_id, text),
SearchNode::State(state) => self.write_state(state)?,
SearchNode::Flag(flag) => {
write!(self.sql, "(c.flags & 7) == {}", flag).unwrap();
}
SearchNode::NoteIDs(nids) => {
write!(self.sql, "n.id in ({})", nids).unwrap();
}
SearchNode::CardIDs(cids) => {
write!(self.sql, "c.id in ({})", cids).unwrap();
}
SearchNode::Property { operator, kind } => self.write_prop(operator, kind)?,
SearchNode::WholeCollection => write!(self.sql, "true").unwrap(),
SearchNode::Regex(re) => self.write_regex(re.as_ref()),
SearchNode::NoCombining(text) => self.write_no_combining(text.as_ref()),
};
Ok(())
}
fn write_unqualified(&mut self, text: &str) {
// implicitly wrap in %
let text = format!("%{}%", text);
self.args.push(text);
write!(
self.sql,
"(n.sfld like ?{n} escape '\\' or n.flds like ?{n} escape '\\')",
n = self.args.len(),
)
.unwrap();
}
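The numbered ?{n} placeholder lets a single pushed argument back both like clauses: an unqualified search for "dog" produces (n.sfld like ?1 escape '\' or n.flds like ?1 escape '\') with args ["%dog%"], as the sql() test below verifies.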
fn write_no_combining(&mut self, text: &str) {
let text = format!("%{}%", without_combining(text));
self.args.push(text);
write!(
self.sql,
concat!(
"(coalesce(without_combining(cast(n.sfld as text)), n.sfld) like ?{n} escape '\\' ",
"or coalesce(without_combining(n.flds), n.flds) like ?{n} escape '\\')"
),
n = self.args.len(),
)
.unwrap();
}
fn write_tag(&mut self, text: &str) {
match text {
"none" => {
write!(self.sql, "n.tags = ''").unwrap();
}
"*" | "%" => {
write!(self.sql, "true").unwrap();
}
text => {
let tag = format!("% {} %", text.replace('*', "%"));
write!(self.sql, "n.tags like ? escape '\\'").unwrap();
self.args.push(tag);
}
}
}
fn write_rated(&mut self, days: u32, ease: Option<u8>) -> Result<()> {
let today_cutoff = self.req.storage.timing_today()?.next_day_at;
let days = days.min(365) as i64;
let target_cutoff_ms = (today_cutoff - 86_400 * days) * 1_000;
write!(
self.sql,
"c.id in (select cid from revlog where id>{}",
target_cutoff_ms
)
.unwrap();
if let Some(ease) = ease {
write!(self.sql, " and ease={})", ease).unwrap();
} else {
write!(self.sql, ")").unwrap();
}
Ok(())
}
fn write_prop(&mut self, op: &str, kind: &PropertyKind) -> Result<()> {
let timing = self.req.storage.timing_today()?;
match kind {
PropertyKind::Due(days) => {
let day = days + (timing.days_elapsed as i32);
write!(
self.sql,
"(c.queue in ({rev},{daylrn}) and due {op} {day})",
rev = CardQueue::Review as u8,
daylrn = CardQueue::DayLearn as u8,
op = op,
day = day
)
}
PropertyKind::Interval(ivl) => write!(self.sql, "ivl {} {}", op, ivl),
PropertyKind::Reps(reps) => write!(self.sql, "reps {} {}", op, reps),
PropertyKind::Lapses(days) => write!(self.sql, "lapses {} {}", op, days),
PropertyKind::Ease(ease) => {
write!(self.sql, "factor {} {}", op, (ease * 1000.0) as u32)
}
}
.unwrap();
Ok(())
}
fn write_state(&mut self, state: &StateKind) -> Result<()> {
let timing = self.req.storage.timing_today()?;
match state {
StateKind::New => write!(self.sql, "c.type = {}", CardQueue::New as i8),
StateKind::Review => write!(self.sql, "c.type = {}", CardQueue::Review as i8),
StateKind::Learning => write!(
self.sql,
"c.queue in ({},{})",
CardQueue::Learn as i8,
CardQueue::DayLearn as i8
),
StateKind::Buried => write!(
self.sql,
"c.queue in ({},{})",
CardQueue::SchedBuried as i8,
CardQueue::UserBuried as i8
),
StateKind::Suspended => write!(self.sql, "c.queue = {}", CardQueue::Suspended as i8),
StateKind::Due => write!(
self.sql,
"
(c.queue in ({rev},{daylrn}) and c.due <= {today}) or
(c.queue = {lrn} and c.due <= {daycutoff})",
rev = CardQueue::Review as i8,
daylrn = CardQueue::DayLearn as i8,
today = timing.days_elapsed,
lrn = CardQueue::Learn as i8,
daycutoff = timing.next_day_at,
),
}
.unwrap();
Ok(())
}
fn write_deck(&mut self, deck: &str) -> Result<()> {
match deck {
"*" => write!(self.sql, "true").unwrap(),
"filtered" => write!(self.sql, "c.odid > 0").unwrap(),
deck => {
let all_decks: Vec<_> = self
.req
.storage
.all_decks()?
.into_iter()
.map(|(_, v)| v)
.collect();
let dids_with_children = if deck == "current" {
let config = self.req.storage.all_config()?;
let mut dids_with_children = vec![config.current_deck_id];
let current = get_deck(&all_decks, config.current_deck_id)
.ok_or_else(|| AnkiError::invalid_input("invalid current deck"))?;
for child_did in child_ids(&all_decks, &current.name) {
dids_with_children.push(child_did);
}
dids_with_children
} else {
let mut dids_with_children = vec![];
for deck in all_decks.iter().filter(|d| matches_wildcard(&d.name, deck)) {
dids_with_children.push(deck.id);
for child_id in child_ids(&all_decks, &deck.name) {
dids_with_children.push(child_id);
}
}
dids_with_children
};
self.sql.push_str("c.did in ");
ids_to_string(&mut self.sql, &dids_with_children);
}
};
Ok(())
}
fn write_template(&mut self, template: &TemplateKind) -> Result<()> {
match template {
TemplateKind::Ordinal(n) => {
write!(self.sql, "c.ord = {}", n).unwrap();
}
TemplateKind::Name(name) => {
let note_types = self.req.storage.all_note_types()?;
let mut id_ords = vec![];
for nt in note_types.values() {
for tmpl in &nt.templates {
if matches_wildcard(&tmpl.name, name) {
id_ords.push((nt.id, tmpl.ord));
}
}
}
// sort for the benefit of unit tests
id_ords.sort();
if id_ords.is_empty() {
self.sql.push_str("false");
} else {
let v: Vec<_> = id_ords
.iter()
.map(|(ntid, ord)| format!("(n.mid = {} and c.ord = {})", ntid, ord))
.collect();
write!(self.sql, "({})", v.join(" or ")).unwrap();
}
}
};
Ok(())
}
fn write_note_type(&mut self, nt_name: &str) -> Result<()> {
let mut ntids: Vec<_> = self
.req
.storage
.all_note_types()?
.values()
.filter(|nt| matches_wildcard(&nt.name, nt_name))
.map(|nt| nt.id)
.collect();
self.sql.push_str("n.mid in ");
// sort for the benefit of unit tests
ntids.sort();
ids_to_string(&mut self.sql, &ntids);
Ok(())
}
fn write_single_field(&mut self, field_name: &str, val: &str, is_re: bool) -> Result<()> {
let note_types = self.req.storage.all_note_types()?;
let mut field_map = vec![];
for nt in note_types.values() {
for field in &nt.fields {
if matches_wildcard(&field.name, field_name) {
field_map.push((nt.id, field.ord));
}
}
}
// for now, sort the map for the benefit of unit tests
field_map.sort();
if field_map.is_empty() {
write!(self.sql, "false").unwrap();
return Ok(());
}
let cmp;
if is_re {
cmp = "regexp";
self.args.push(format!("(?i){}", val));
} else {
cmp = "like";
self.args.push(val.replace('*', "%"));
}
let arg_idx = self.args.len();
let searches: Vec<_> = field_map
.iter()
.map(|(ntid, ord)| {
format!(
"(n.mid = {mid} and field_at_index(n.flds, {ord}) {cmp} ?{n})",
mid = ntid,
ord = ord,
cmp = cmp,
n = arg_idx
)
})
.collect();
write!(self.sql, "({})", searches.join(" or ")).unwrap();
Ok(())
}
fn write_dupes(&mut self, ntid: ObjID, text: &str) {
let text_nohtml = strip_html_preserving_image_filenames(text);
let csum = field_checksum(text_nohtml.as_ref());
write!(
self.sql,
"(n.mid = {} and n.csum = {} and field_at_index(n.flds, 0) = ?",
ntid, csum
)
.unwrap();
self.args.push(text.to_string());
}
fn write_added(&mut self, days: u32) -> Result<()> {
let timing = self.req.storage.timing_today()?;
let cutoff = (timing.next_day_at - (86_400 * (days as i64))) * 1_000;
write!(self.sql, "c.id > {}", cutoff).unwrap();
Ok(())
}
fn write_regex(&mut self, word: &str) {
self.sql.push_str("n.flds regexp ?");
self.args.push(format!(r"(?i){}", word));
}
}
// Write a list of IDs as '(x,y,...)' into the provided string.
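// Note: the first ID is emitted last (see the ids_string test below); order inside
// an SQL "in" list is irrelevant.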
fn ids_to_string<T>(buf: &mut String, ids: &[T])
where
T: std::fmt::Display,
{
buf.push('(');
if !ids.is_empty() {
for id in ids.iter().skip(1) {
write!(buf, "{},", id).unwrap();
}
write!(buf, "{}", ids[0]).unwrap();
}
buf.push(')');
}
#[cfg(test)]
mod test {
use super::ids_to_string;
use crate::{collection::open_collection, i18n::I18n, log};
use std::{fs, path::PathBuf};
use tempfile::tempdir;
#[test]
fn ids_string() {
let mut s = String::new();
ids_to_string::<u8>(&mut s, &[]);
assert_eq!(s, "()");
s.clear();
ids_to_string(&mut s, &[7]);
assert_eq!(s, "(7)");
s.clear();
ids_to_string(&mut s, &[7, 6]);
assert_eq!(s, "(6,7)");
s.clear();
ids_to_string(&mut s, &[7, 6, 5]);
assert_eq!(s, "(6,5,7)");
s.clear();
}
use super::super::parser::parse;
use super::*;
// shortcut
fn s(req: &mut RequestContext, search: &str) -> (String, Vec<String>) {
let node = Node::Group(parse(search).unwrap());
node_to_sql(req, &node).unwrap()
}
#[test]
fn sql() -> Result<()> {
// re-use the mediacheck .anki2 file for now
use crate::media::check::test::MEDIACHECK_ANKI2;
let dir = tempdir().unwrap();
let col_path = dir.path().join("col.anki2");
fs::write(&col_path, MEDIACHECK_ANKI2).unwrap();
let i18n = I18n::new(&[""], "", log::terminal());
let col = open_collection(
&col_path,
&PathBuf::new(),
&PathBuf::new(),
false,
i18n,
log::terminal(),
)
.unwrap();
col.with_ctx(|ctx| {
// unqualified search
assert_eq!(
s(ctx, "test"),
(
"((n.sfld like ?1 escape '\\' or n.flds like ?1 escape '\\'))".into(),
vec!["%test%".into()]
)
);
assert_eq!(s(ctx, "te%st").1, vec!["%te%st%".to_string()]);
// user should be able to escape sql wildcards
assert_eq!(s(ctx, r#"te\%s\_t"#).1, vec!["%te\\%s\\_t%".to_string()]);
// qualified search
assert_eq!(
s(ctx, "front:te*st"),
(
concat!(
"(((n.mid = 1581236385344 and field_at_index(n.flds, 0) like ?1) or ",
"(n.mid = 1581236385345 and field_at_index(n.flds, 0) like ?1) or ",
"(n.mid = 1581236385346 and field_at_index(n.flds, 0) like ?1) or ",
"(n.mid = 1581236385347 and field_at_index(n.flds, 0) like ?1)))"
)
.into(),
vec!["te%st".into()]
)
);
// added
let timing = ctx.storage.timing_today().unwrap();
assert_eq!(
s(ctx, "added:3").0,
format!("(c.id > {})", (timing.next_day_at - (86_400 * 3)) * 1_000)
);
// deck
assert_eq!(s(ctx, "deck:default"), ("(c.did in (1))".into(), vec![],));
assert_eq!(s(ctx, "deck:current"), ("(c.did in (1))".into(), vec![],));
assert_eq!(s(ctx, "deck:missing"), ("(c.did in ())".into(), vec![],));
assert_eq!(s(ctx, "deck:d*"), ("(c.did in (1))".into(), vec![],));
assert_eq!(s(ctx, "deck:filtered"), ("(c.odid > 0)".into(), vec![],));
// card
assert_eq!(s(ctx, "card:front"), ("(false)".into(), vec![],));
assert_eq!(
s(ctx, r#""card:card 1""#),
(
concat!(
"(((n.mid = 1581236385344 and c.ord = 0) or ",
"(n.mid = 1581236385345 and c.ord = 0) or ",
"(n.mid = 1581236385346 and c.ord = 0) or ",
"(n.mid = 1581236385347 and c.ord = 0)))"
)
.into(),
vec![],
)
);
// IDs
assert_eq!(s(ctx, "mid:3"), ("(n.mid = 3)".into(), vec![]));
assert_eq!(s(ctx, "nid:3"), ("(n.id in (3))".into(), vec![]));
assert_eq!(s(ctx, "nid:3,4"), ("(n.id in (3,4))".into(), vec![]));
assert_eq!(s(ctx, "cid:3,4"), ("(c.id in (3,4))".into(), vec![]));
// flags
assert_eq!(s(ctx, "flag:2"), ("((c.flags & 7) == 2)".into(), vec![]));
assert_eq!(s(ctx, "flag:0"), ("((c.flags & 7) == 0)".into(), vec![]));
// dupes
assert_eq!(
s(ctx, "dupes:123,test"),
(
"((n.mid = 123 and n.csum = 2840236005 and field_at_index(n.flds, 0) = ?)"
.into(),
vec!["test".into()]
)
);
// tags
assert_eq!(
s(ctx, "tag:one"),
("(n.tags like ? escape '\\')".into(), vec!["% one %".into()])
);
assert_eq!(
s(ctx, "tag:o*e"),
("(n.tags like ? escape '\\')".into(), vec!["% o%e %".into()])
);
assert_eq!(s(ctx, "tag:none"), ("(n.tags = '')".into(), vec![]));
assert_eq!(s(ctx, "tag:*"), ("(true)".into(), vec![]));
// state
assert_eq!(
s(ctx, "is:suspended").0,
format!("(c.queue = {})", CardQueue::Suspended as i8)
);
assert_eq!(
s(ctx, "is:new").0,
format!("(c.type = {})", CardQueue::New as i8)
);
// rated
assert_eq!(
s(ctx, "rated:2").0,
format!(
"(c.id in (select cid from revlog where id>{}))",
(timing.next_day_at - (86_400 * 2)) * 1_000
)
);
assert_eq!(
s(ctx, "rated:400:1").0,
format!(
"(c.id in (select cid from revlog where id>{} and ease=1))",
(timing.next_day_at - (86_400 * 365)) * 1_000
)
);
// props
assert_eq!(s(ctx, "prop:lapses=3").0, "(lapses = 3)".to_string());
assert_eq!(s(ctx, "prop:ease>=2.5").0, "(factor >= 2500)".to_string());
assert_eq!(
s(ctx, "prop:due!=-1").0,
format!(
"((c.queue in (2,3) and due != {}))",
timing.days_elapsed - 1
)
);
// note types by name
assert_eq!(&s(ctx, "note:basic").0, "(n.mid in (1581236385347))");
assert_eq!(
&s(ctx, "note:basic*").0,
"(n.mid in (1581236385345,1581236385346,1581236385347,1581236385344))"
);
// regex
assert_eq!(
s(ctx, r"re:\bone"),
("(n.flds regexp ?)".into(), vec![r"(?i)\bone".into()])
);
Ok(())
})
.unwrap();
Ok(())
}
}

3
rslib/src/storage/mod.rs Normal file
View file

@ -0,0 +1,3 @@
mod sqlite;
pub(crate) use sqlite::{SqliteStorage, StorageContext};

88
rslib/src/storage/schema11.sql Normal file
View file

@ -0,0 +1,88 @@
create table col
(
id integer primary key,
crt integer not null,
mod integer not null,
scm integer not null,
ver integer not null,
dty integer not null,
usn integer not null,
ls integer not null,
conf text not null,
models text not null,
decks text not null,
dconf text not null,
tags text not null
);
create table notes
(
id integer primary key,
guid text not null,
mid integer not null,
mod integer not null,
usn integer not null,
tags text not null,
flds text not null,
sfld integer not null,
csum integer not null,
flags integer not null,
data text not null
);
create table cards
(
id integer primary key,
nid integer not null,
did integer not null,
ord integer not null,
mod integer not null,
usn integer not null,
type integer not null,
queue integer not null,
due integer not null,
ivl integer not null,
factor integer not null,
reps integer not null,
lapses integer not null,
left integer not null,
odue integer not null,
odid integer not null,
flags integer not null,
data text not null
);
create table revlog
(
id integer primary key,
cid integer not null,
usn integer not null,
ease integer not null,
ivl integer not null,
lastIvl integer not null,
factor integer not null,
time integer not null,
type integer not null
);
create table graves
(
usn integer not null,
oid integer not null,
type integer not null
);
-- syncing
create index ix_notes_usn on notes (usn);
create index ix_cards_usn on cards (usn);
create index ix_revlog_usn on revlog (usn);
-- card spacing, etc
create index ix_cards_nid on cards (nid);
-- scheduling and deck limiting
create index ix_cards_sched on cards (did, queue, due);
-- revlog by card
create index ix_revlog_cid on revlog (cid);
-- field uniqueness
create index ix_notes_csum on notes (csum);
insert into col values (1,0,0,0,0,0,0,0,'{}','{}','{}','{}','{}');
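-- illustrative only (not part of the schema): the shape of query served by
-- ix_cards_sched when gathering due cards:
--   select id from cards where did = ? and queue = ? and due <= ?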

335
rslib/src/storage/sqlite.rs Normal file
View file

@ -0,0 +1,335 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::collection::CollectionOp;
use crate::config::Config;
use crate::err::Result;
use crate::err::{AnkiError, DBErrorKind};
use crate::time::{i64_unix_millis, i64_unix_secs};
use crate::{
decks::Deck,
notetypes::NoteType,
sched::cutoff::{sched_timing_today, SchedTimingToday},
text::without_combining,
types::{ObjID, Usn},
};
use regex::Regex;
use rusqlite::{params, Connection, NO_PARAMS};
use std::cmp::Ordering;
use std::{
borrow::Cow,
collections::HashMap,
path::{Path, PathBuf},
};
use unicase::UniCase;
const SCHEMA_MIN_VERSION: u8 = 11;
const SCHEMA_MAX_VERSION: u8 = 11;
fn unicase_compare(s1: &str, s2: &str) -> Ordering {
UniCase::new(s1).cmp(&UniCase::new(s2))
}
// currently public for dbproxy
#[derive(Debug)]
pub struct SqliteStorage {
// currently crate-visible for dbproxy
pub(crate) db: Connection,
// fixme: stored in wrong location?
path: PathBuf,
}
fn open_or_create_collection_db(path: &Path) -> Result<Connection> {
let mut db = Connection::open(path)?;
if std::env::var("TRACESQL").is_ok() {
db.trace(Some(trace));
}
db.busy_timeout(std::time::Duration::from_secs(0))?;
db.pragma_update(None, "locking_mode", &"exclusive")?;
db.pragma_update(None, "page_size", &4096)?;
db.pragma_update(None, "cache_size", &(-40 * 1024))?;
db.pragma_update(None, "legacy_file_format", &false)?;
db.pragma_update(None, "journal_mode", &"wal")?;
db.pragma_update(None, "temp_store", &"memory")?;
db.set_prepared_statement_cache_capacity(50);
add_field_index_function(&db)?;
add_regexp_function(&db)?;
add_without_combining_function(&db)?;
db.create_collation("unicase", unicase_compare)?;
Ok(db)
}
/// Adds sql function field_at_index(flds, index)
/// to split provided fields and return field at zero-based index.
/// If out of range, returns empty string.
fn add_field_index_function(db: &Connection) -> rusqlite::Result<()> {
db.create_scalar_function("field_at_index", 2, true, |ctx| {
let mut fields = ctx.get_raw(0).as_str()?.split('\x1f');
let idx: u16 = ctx.get(1)?;
Ok(fields.nth(idx as usize).unwrap_or("").to_string())
})
}
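// Illustration for the function above (not in the original source): fields are
// joined with the 0x1f unit separator, so with flds = "front\x1fback",
// field_at_index(flds, 1) returns "back" and field_at_index(flds, 9) returns "".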
fn add_without_combining_function(db: &Connection) -> rusqlite::Result<()> {
db.create_scalar_function("without_combining", 1, true, |ctx| {
let text = ctx.get_raw(0).as_str()?;
Ok(match without_combining(text) {
Cow::Borrowed(_) => None,
Cow::Owned(o) => Some(o),
})
})
}
/// Adds sql function regexp(regex, string) -> is_match
/// Taken from the rusqlite docs
fn add_regexp_function(db: &Connection) -> rusqlite::Result<()> {
db.create_scalar_function("regexp", 2, true, move |ctx| {
assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
let saved_re: Option<&Regex> = ctx.get_aux(0)?;
let new_re = match saved_re {
None => {
let s = ctx.get::<String>(0)?;
match Regex::new(&s) {
Ok(r) => Some(r),
Err(err) => return Err(rusqlite::Error::UserFunctionError(Box::new(err))),
}
}
Some(_) => None,
};
let is_match = {
let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap());
let text = ctx
.get_raw(1)
.as_str()
.map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
re.is_match(text)
};
if let Some(re) = new_re {
ctx.set_aux(0, re);
}
Ok(is_match)
})
}
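// Assumed SQL usage of the function above, mirroring sqlwriter.rs:
//   select id from notes where flds regexp '(?i)\btest';
// get_aux/set_aux cache the compiled Regex, so the pattern is compiled once
// per prepared statement rather than once per candidate row.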
/// Fetch schema version from database.
/// Return (must_create, version)
fn schema_version(db: &Connection) -> Result<(bool, u8)> {
if !db
.prepare("select null from sqlite_master where type = 'table' and name = 'col'")?
.exists(NO_PARAMS)?
{
return Ok((true, SCHEMA_MAX_VERSION));
}
Ok((
false,
db.query_row("select ver from col", NO_PARAMS, |r| Ok(r.get(0)?))?,
))
}
fn trace(s: &str) {
println!("sql: {}", s)
}
impl SqliteStorage {
pub(crate) fn open_or_create(path: &Path) -> Result<Self> {
let db = open_or_create_collection_db(path)?;
let (create, ver) = schema_version(&db)?;
if create {
db.prepare_cached("begin exclusive")?.execute(NO_PARAMS)?;
db.execute_batch(include_str!("schema11.sql"))?;
db.execute("update col set crt=?, ver=?", params![i64_unix_secs(), ver])?;
db.prepare_cached("commit")?.execute(NO_PARAMS)?;
} else {
if ver > SCHEMA_MAX_VERSION {
return Err(AnkiError::DBError {
info: "".to_string(),
kind: DBErrorKind::FileTooNew,
});
}
if ver < SCHEMA_MIN_VERSION {
return Err(AnkiError::DBError {
info: "".to_string(),
kind: DBErrorKind::FileTooOld,
});
}
};
let storage = Self {
db,
path: path.to_owned(),
};
Ok(storage)
}
pub(crate) fn context(&self, server: bool) -> StorageContext {
StorageContext::new(&self.db, server)
}
}
pub(crate) struct StorageContext<'a> {
pub(crate) db: &'a Connection,
#[allow(dead_code)]
server: bool,
#[allow(dead_code)]
usn: Option<Usn>,
timing_today: Option<SchedTimingToday>,
}
impl StorageContext<'_> {
fn new(db: &Connection, server: bool) -> StorageContext {
StorageContext {
db,
server,
usn: None,
timing_today: None,
}
}
// Standard transaction start/stop
//////////////////////////////////////
pub(crate) fn begin_trx(&self) -> Result<()> {
self.db
.prepare_cached("begin exclusive")?
.execute(NO_PARAMS)?;
Ok(())
}
pub(crate) fn commit_trx(&self) -> Result<()> {
if !self.db.is_autocommit() {
self.db.prepare_cached("commit")?.execute(NO_PARAMS)?;
}
Ok(())
}
pub(crate) fn rollback_trx(&self) -> Result<()> {
if !self.db.is_autocommit() {
self.db.execute("rollback", NO_PARAMS)?;
}
Ok(())
}
// Savepoints
//////////////////////////////////////////
//
// This is necessary at the moment because Anki's current architecture uses
// long-running transactions as an undo mechanism. Once a proper undo
// mechanism has been added to all existing functionality, we could
// transition these to standard commits.
pub(crate) fn begin_rust_trx(&self) -> Result<()> {
self.db
.prepare_cached("savepoint rust")?
.execute(NO_PARAMS)?;
Ok(())
}
pub(crate) fn commit_rust_trx(&self) -> Result<()> {
self.db.prepare_cached("release rust")?.execute(NO_PARAMS)?;
Ok(())
}
pub(crate) fn commit_rust_op(&self, _op: Option<CollectionOp>) -> Result<()> {
self.commit_rust_trx()
}
pub(crate) fn rollback_rust_trx(&self) -> Result<()> {
self.db
.prepare_cached("rollback to rust")?
.execute(NO_PARAMS)?;
Ok(())
}
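// A minimal sketch of the calling pattern for the savepoint helpers above
// (hypothetical caller, not part of this diff):
//
//     ctx.storage.begin_rust_trx()?;
//     match mutate_collection(ctx) {
//         Ok(out) => { ctx.storage.commit_rust_trx()?; Ok(out) }
//         Err(e) => { ctx.storage.rollback_rust_trx()?; Err(e) }
//     }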
//////////////////////////////////////////
pub(crate) fn mark_modified(&self) -> Result<()> {
self.db
.prepare_cached("update col set mod=?")?
.execute(params![i64_unix_millis()])?;
Ok(())
}
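// On the server, the collection's usn is read once and cached for this
// context; clients always return -1, the "pending sync" marker.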
#[allow(dead_code)]
pub(crate) fn usn(&mut self) -> Result<Usn> {
if self.server {
if self.usn.is_none() {
self.usn = Some(
self.db
.prepare_cached("select usn from col")?
.query_row(NO_PARAMS, |row| row.get(0))?,
);
}
Ok(*self.usn.as_ref().unwrap())
} else {
Ok(-1)
}
}
pub(crate) fn all_decks(&self) -> Result<HashMap<ObjID, Deck>> {
self.db
.query_row_and_then("select decks from col", NO_PARAMS, |row| -> Result<_> {
Ok(serde_json::from_str(row.get_raw(0).as_str()?)?)
})
}
pub(crate) fn all_config(&self) -> Result<Config> {
self.db
.query_row_and_then("select conf from col", NO_PARAMS, |row| -> Result<_> {
Ok(serde_json::from_str(row.get_raw(0).as_str()?)?)
})
}
pub(crate) fn all_note_types(&self) -> Result<HashMap<ObjID, NoteType>> {
let mut stmt = self.db.prepare("select models from col")?;
let note_types = stmt
.query_and_then(NO_PARAMS, |row| -> Result<HashMap<ObjID, NoteType>> {
let v: HashMap<ObjID, NoteType> = serde_json::from_str(row.get_raw(0).as_str()?)?;
Ok(v)
})?
.next()
.ok_or_else(|| AnkiError::DBError {
info: "col table empty".to_string(),
kind: DBErrorKind::MissingEntity,
})??;
Ok(note_types)
}
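// timing_today() is memoized per context, so repeated calls within a single
// operation (e.g. several "added:" searches) reuse one computation.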
#[allow(dead_code)]
pub(crate) fn timing_today(&mut self) -> Result<SchedTimingToday> {
if self.timing_today.is_none() {
let crt: i64 = self
.db
.prepare_cached("select crt from col")?
.query_row(NO_PARAMS, |row| row.get(0))?;
let conf = self.all_config()?;
let now_offset = if self.server { conf.local_offset } else { None };
self.timing_today = Some(sched_timing_today(
crt,
i64_unix_secs(),
conf.creation_offset,
now_offset,
conf.rollover,
));
}
Ok(*self.timing_today.as_ref().unwrap())
}
}

View file

@ -5,7 +5,10 @@ use lazy_static::lazy_static;
 use regex::{Captures, Regex};
 use std::borrow::Cow;
 use std::ptr;
-use unicode_normalization::{is_nfc, UnicodeNormalization};
+use unicase::eq as uni_eq;
+use unicode_normalization::{
+    char::is_combining_mark, is_nfc, is_nfkd_quick, IsNormalized, UnicodeNormalization,
+};
 
 #[derive(Debug, PartialEq)]
 pub enum AVTag {
@ -219,11 +222,43 @@ pub(crate) fn normalize_to_nfc(s: &str) -> Cow<str> {
     }
 }
 
+/// True if search is equal to text, folding case.
+/// Supports '*' to match 0 or more characters.
+pub(crate) fn matches_wildcard(text: &str, search: &str) -> bool {
+    if search.contains('*') {
+        let search = format!("^(?i){}$", regex::escape(search).replace(r"\*", ".*"));
+        Regex::new(&search).unwrap().is_match(text)
+    } else {
+        uni_eq(text, search)
+    }
+}
+
+/// Convert provided string to NFKD form and strip combining characters.
+pub(crate) fn without_combining(s: &str) -> Cow<str> {
+    // if the string is already normalized
+    if matches!(is_nfkd_quick(s.chars()), IsNormalized::Yes) {
+        // and no combining characters found, return unchanged
+        if !s.chars().any(is_combining_mark) {
+            return s.into();
+        }
+    }
+
+    // we need to create a new string without the combining marks
+    s.chars()
+        .nfkd()
+        .filter(|c| !is_combining_mark(*c))
+        .collect::<String>()
+        .into()
+}
+
 #[cfg(test)]
 mod test {
+    use super::matches_wildcard;
+    use crate::text::without_combining;
     use crate::text::{
         extract_av_tags, strip_av_tags, strip_html, strip_html_preserving_image_filenames, AVTag,
     };
+    use std::borrow::Cow;
 
     #[test]
     fn stripping() {
@ -265,4 +300,19 @@ mod test {
             ]
         );
     }
+
+    #[test]
+    fn wildcard() {
+        assert_eq!(matches_wildcard("foo", "bar"), false);
+        assert_eq!(matches_wildcard("foo", "Foo"), true);
+        assert_eq!(matches_wildcard("foo", "F*"), true);
+        assert_eq!(matches_wildcard("foo", "F*oo"), true);
+        assert_eq!(matches_wildcard("foo", "b*"), false);
+    }
+
+    #[test]
+    fn combining() {
+        assert!(matches!(without_combining("test"), Cow::Borrowed(_)));
+        assert!(matches!(without_combining("Über"), Cow::Owned(_)));
+    }
 }

View file

@ -4,15 +4,35 @@
 use std::time;
 
 pub(crate) fn i64_unix_secs() -> i64 {
-    time::SystemTime::now()
-        .duration_since(time::SystemTime::UNIX_EPOCH)
-        .unwrap()
-        .as_secs() as i64
+    elapsed().as_secs() as i64
 }
 
 pub(crate) fn i64_unix_millis() -> i64 {
-    time::SystemTime::now()
-        .duration_since(time::SystemTime::UNIX_EPOCH)
-        .unwrap()
-        .as_millis() as i64
+    elapsed().as_millis() as i64
+}
+
+#[cfg(not(test))]
+fn elapsed() -> time::Duration {
+    time::SystemTime::now()
+        .duration_since(time::SystemTime::UNIX_EPOCH)
+        .unwrap()
+}
+
+// when running in CI, shift the current time away from the cutoff point
+// to accommodate unit tests that depend on the current time
+#[cfg(test)]
+fn elapsed() -> time::Duration {
+    use chrono::{Local, Timelike};
+
+    let now = Local::now();
+
+    let mut elap = time::SystemTime::now()
+        .duration_since(time::SystemTime::UNIX_EPOCH)
+        .unwrap();
+
+    if now.hour() >= 2 && now.hour() < 4 {
+        elap -= time::Duration::from_secs(60 * 60 * 2);
+    }
+
+    elap
 }

View file

@ -1,6 +1,6 @@
 [package]
 name = "ankirspy"
-version = "2.1.22" # automatically updated
+version = "2.1.24" # automatically updated
 edition = "2018"
 authors = ["Ankitects Pty Ltd and contributors"]

View file

@ -1,12 +1,11 @@
 // Copyright: Ankitects Pty Ltd and contributors
 // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
 
-use anki::backend::{
-    init_backend, init_i18n_backend, Backend as RustBackend, I18nBackend as RustI18nBackend,
-};
+use anki::backend::{init_backend, Backend as RustBackend};
+use pyo3::exceptions::Exception;
 use pyo3::prelude::*;
 use pyo3::types::PyBytes;
-use pyo3::{exceptions, wrap_pyfunction};
+use pyo3::{create_exception, exceptions, wrap_pyfunction};
 
 // Regular backend
 //////////////////////////////////
@ -16,6 +15,8 @@ struct Backend {
     backend: RustBackend,
 }
 
+create_exception!(ankirspy, DBError, Exception);
+
 #[pyfunction]
 fn buildhash() -> &'static str {
     include_str!("../../meta/buildhash").trim()
@ -70,29 +71,17 @@ impl Backend {
         self.backend.set_progress_callback(Some(Box::new(func)));
     }
-}
-
-// I18n backend
-//////////////////////////////////
-
-#[pyclass]
-struct I18nBackend {
-    backend: RustI18nBackend,
-}
-
-#[pyfunction]
-fn open_i18n(init_msg: &PyBytes) -> PyResult<I18nBackend> {
-    match init_i18n_backend(init_msg.as_bytes()) {
-        Ok(backend) => Ok(I18nBackend { backend }),
-        Err(e) => Err(exceptions::Exception::py_err(format!("{:?}", e))),
-    }
-}
-
-#[pymethods]
-impl I18nBackend {
-    fn translate(&self, input: &PyBytes) -> String {
-        let in_bytes = input.as_bytes();
-        self.backend.translate(in_bytes)
+
+    fn db_command(&mut self, py: Python, input: &PyBytes) -> PyResult<PyObject> {
+        let in_bytes = input.as_bytes();
+        let out_res = py.allow_threads(move || {
+            self.backend
+                .db_command(in_bytes)
+                .map_err(|e| DBError::py_err(e.localized_description(&self.backend.i18n())))
+        });
+        let out_string = out_res?;
+        let out_obj = PyBytes::new(py, out_string.as_bytes());
+        Ok(out_obj.into())
     }
 }
@ -104,7 +93,6 @@ fn ankirspy(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<Backend>()?;
     m.add_wrapped(wrap_pyfunction!(buildhash)).unwrap();
     m.add_wrapped(wrap_pyfunction!(open_backend)).unwrap();
-    m.add_wrapped(wrap_pyfunction!(open_i18n)).unwrap();
     Ok(())
 }
} }