diff --git a/proto/backend.proto b/proto/backend.proto index a224d0e85..d45dd55e1 100644 --- a/proto/backend.proto +++ b/proto/backend.proto @@ -159,7 +159,7 @@ service BackendService { // sync - rpc SyncMedia (SyncMediaIn) returns (Empty); + rpc SyncMedia (SyncAuth) returns (Empty); rpc AbortSync (Empty) returns (Empty); rpc BeforeUpload (Empty) returns (Empty); rpc SyncLogin (SyncLoginIn) returns (SyncAuth); @@ -610,11 +610,6 @@ message AddMediaFileIn { bytes data = 2; } -message SyncMediaIn { - string hkey = 1; - string endpoint = 2; -} - message CheckMediaOut { repeated string unused = 1; repeated string missing = 2; diff --git a/pylib/anki/collection.py b/pylib/anki/collection.py index cb04bba67..024a658ba 100644 --- a/pylib/anki/collection.py +++ b/pylib/anki/collection.py @@ -239,11 +239,20 @@ class Collection: self.media.close() self._closeLog() + def close_for_full_sync(self) -> None: + # save and cleanup, but backend will take care of collection close + if self.db: + self.save(trx=False) + self.models._clear_cache() + self.db = None + self.media.close() + self._closeLog() + def rollback(self) -> None: self.db.rollback() self.db.begin() - def reopen(self) -> None: + def reopen(self, after_full_sync=False) -> None: assert not self.db assert self.path.endswith(".anki2") @@ -255,12 +264,13 @@ class Collection: log_path = self.path.replace(".anki2", "2.log") # connect - self.backend.open_collection( - collection_path=self.path, - media_folder_path=media_dir, - media_db_path=media_db, - log_path=log_path, - ) + if not after_full_sync: + self.backend.open_collection( + collection_path=self.path, + media_folder_path=media_dir, + media_db_path=media_db, + log_path=log_path, + ) self.db = DBProxy(weakref.proxy(self.backend)) self.db.begin() diff --git a/pylib/anki/consts.py b/pylib/anki/consts.py index 921a1a702..4a2ab6e66 100644 --- a/pylib/anki/consts.py +++ b/pylib/anki/consts.py @@ -66,13 +66,6 @@ MODEL_CLOZE = 1 STARTING_FACTOR = 2500 -# deck schema & syncing vars -SCHEMA_VERSION = 11 -SYNC_ZIP_SIZE = int(2.5 * 1024 * 1024) -SYNC_ZIP_COUNT = 25 -SYNC_BASE = "https://sync%s.ankiweb.net/" -SYNC_VER = 10 - HELP_SITE = "https://apps.ankiweb.net/docs/manual.html" # Leech actions diff --git a/pylib/anki/hooks.py b/pylib/anki/hooks.py index 31edb5348..9aac5568d 100644 --- a/pylib/anki/hooks.py +++ b/pylib/anki/hooks.py @@ -466,6 +466,8 @@ schema_will_change = _SchemaWillChangeFilter() class _SyncProgressDidChangeHook: + """Obsolete, do not use.""" + _hooks: List[Callable[[str], None]] = [] def append(self, cb: Callable[[str], None]) -> None: @@ -484,14 +486,14 @@ class _SyncProgressDidChangeHook: # if the hook fails, remove it self._hooks.remove(hook) raise - # legacy support - runHook("syncMsg", msg) sync_progress_did_change = _SyncProgressDidChangeHook() class _SyncStageDidChangeHook: + """Obsolete, do not use.""" + _hooks: List[Callable[[str], None]] = [] def append(self, cb: Callable[[str], None]) -> None: @@ -510,8 +512,6 @@ class _SyncStageDidChangeHook: # if the hook fails, remove it self._hooks.remove(hook) raise - # legacy support - runHook("sync", stage) sync_stage_did_change = _SyncStageDidChangeHook() diff --git a/pylib/anki/rsbackend.py b/pylib/anki/rsbackend.py index 481b3150f..eed465e00 100644 --- a/pylib/anki/rsbackend.py +++ b/pylib/anki/rsbackend.py @@ -52,6 +52,8 @@ TagUsnTuple = pb.TagUsnTuple NoteType = pb.NoteType DeckTreeNode = pb.DeckTreeNode StockNoteType = pb.StockNoteType +SyncAuth = pb.SyncAuth +SyncOutput = pb.SyncCollectionOut try: import orjson @@ -147,6 +149,7 @@ def proto_exception_to_native(err: pb.BackendError) -> Exception: MediaSyncProgress = pb.MediaSyncProgress +FullSyncProgress = pb.FullSyncProgress FormatTimeSpanContext = pb.FormatTimespanIn.Context @@ -254,4 +257,6 @@ def translate_string_in( # temporarily force logging of media handling if "RUST_LOG" not in os.environ: - os.environ["RUST_LOG"] = "warn,anki::media=debug,anki::dbcheck=debug" + os.environ[ + "RUST_LOG" + ] = "warn,anki::media=debug,anki::sync=debug,anki::dbcheck=debug" diff --git a/pylib/anki/sync.py b/pylib/anki/sync.py index 8d925802e..6aef01114 100644 --- a/pylib/anki/sync.py +++ b/pylib/anki/sync.py @@ -1,674 +1,11 @@ # Copyright: Ankitects Pty Ltd and contributors # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -from __future__ import annotations - -import gzip -import io -import json -import os -import random -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union - -import anki -from anki.consts import * -from anki.db import DB -from anki.utils import checksum, ids2str, intTime, platDesc, versionWithBuild - -from . import hooks from .httpclient import HttpClient -# add-on compat -from .rsbackend import from_json_bytes, to_json_bytes - AnkiRequestsClient = HttpClient -class UnexpectedSchemaChange(Exception): - pass - - -# Incremental syncing -########################################################################## - - class Syncer: - chunkRows: Optional[List[Sequence]] - - def __init__(self, col: anki.collection.Collection, server=None) -> None: - self.col = col.weakref() - self.server = server - - # these are set later; provide dummy values for type checking - self.lnewer = False - self.maxUsn = 0 - self.tablesLeft: List[str] = [] - def sync(self) -> str: - "Returns 'noChanges', 'fullSync', 'success', etc" - self.syncMsg = "" - self.uname = "" - # if the deck has any pending changes, flush them first and bump mod - # time - self.col.save() - - # step 1: login & metadata - hooks.sync_stage_did_change("login") - meta = self.server.meta() - self.col.log("rmeta", meta) - if not meta: - return "badAuth" - # server requested abort? - self.syncMsg = meta["msg"] - if not meta["cont"]: - return "serverAbort" - else: - # don't abort, but if 'msg' is not blank, gui should show 'msg' - # after sync finishes and wait for confirmation before hiding - pass - rscm = meta["scm"] - rts = meta["ts"] - self.rmod = meta["mod"] - self.maxUsn = meta["usn"] - self.uname = meta.get("uname", "") - self.hostNum = meta.get("hostNum") - meta = self.meta() - self.col.log("lmeta", meta) - self.lmod = meta["mod"] - self.minUsn = meta["usn"] - lscm = meta["scm"] - lts = meta["ts"] - if abs(rts - lts) > 300: - self.col.log("clock off") - return "clockOff" - if self.lmod == self.rmod: - self.col.log("no changes") - return "noChanges" - elif lscm != rscm: - self.col.log("schema diff") - return "fullSync" - self.lnewer = self.lmod > self.rmod - # step 1.5: check collection is valid - if not self.col.basicCheck(): - self.col.log("basic check") - return "basicCheckFailed" - # step 2: startup and deletions - hooks.sync_stage_did_change("meta") - rrem = self.server.start( - minUsn=self.minUsn, lnewer=self.lnewer, offset=self.col.localOffset() - ) - - # apply deletions to server - lgraves = self.removed() - while lgraves: - gchunk, lgraves = self._gravesChunk(lgraves) - self.server.applyGraves(chunk=gchunk) - - # then apply server deletions here - self.remove(rrem) - - # ...and small objects - lchg = self.changes() - rchg = self.server.applyChanges(changes=lchg) - try: - self.mergeChanges(lchg, rchg) - except UnexpectedSchemaChange: - self.server.abort() - return self._forceFullSync() - # step 3: stream large tables from server - hooks.sync_stage_did_change("server") - while 1: - hooks.sync_stage_did_change("stream") - chunk = self.server.chunk() - self.col.log("server chunk", chunk) - self.applyChunk(chunk=chunk) - if chunk["done"]: - break - # step 4: stream to server - hooks.sync_stage_did_change("client") - while 1: - hooks.sync_stage_did_change("stream") - chunk = self.chunk() - self.col.log("client chunk", chunk) - self.server.applyChunk(chunk=chunk) - if chunk["done"]: - break - # step 5: sanity check - hooks.sync_stage_did_change("sanity") - c = self.sanityCheck() - ret = self.server.sanityCheck2(client=c) - if ret["status"] != "ok": - return self._forceFullSync() - # finalize - hooks.sync_stage_did_change("finalize") - mod = self.server.finish() - self.finish(mod) - return "success" - - def _forceFullSync(self) -> str: - # roll back and force full sync - self.col.rollback() - self.col.modSchema(False) - self.col.save() - return "sanityCheckFailed" - - def _gravesChunk(self, graves: Dict) -> Tuple[Dict, Optional[Dict]]: - lim = 250 - chunk: Dict[str, Any] = dict(notes=[], cards=[], decks=[]) - for cat in "notes", "cards", "decks": - if lim and graves[cat]: - chunk[cat] = graves[cat][:lim] - graves[cat] = graves[cat][lim:] - lim -= len(chunk[cat]) - - # anything remaining? - if graves["notes"] or graves["cards"] or graves["decks"]: - return chunk, graves - return chunk, None - - def meta(self) -> dict: - return dict( - mod=self.col.mod, - scm=self.col.scm, - usn=self.col._usn, - ts=intTime(), - musn=0, - msg="", - cont=True, - ) - - def changes(self) -> dict: - "Bundle up small objects." - d: Dict[str, Any] = dict( - models=self.getModels(), decks=self.getDecks(), tags=self.getTags() - ) - if self.lnewer: - d["conf"] = self.getConf() - d["crt"] = self.col.crt - return d - - def mergeChanges(self, lchg, rchg) -> None: - # then the other objects - self.mergeModels(rchg["models"]) - self.mergeDecks(rchg["decks"]) - self.mergeTags(rchg["tags"]) - if "conf" in rchg: - self.mergeConf(rchg["conf"]) - # this was left out of earlier betas - if "crt" in rchg: - self.col.crt = rchg["crt"] - self.prepareToChunk() - - def sanityCheck(self) -> Union[list, str]: - if not self.col.basicCheck(): - return "failed basic check" - for t in "cards", "notes", "revlog", "graves": - if self.col.db.scalar("select count() from %s where usn = -1" % t): - return "%s had usn = -1" % t - for g in self.col.decks.all(): - if g["usn"] == -1: - return "deck had usn = -1" - for tup in self.col.backend.all_tags(): - if tup.usn == -1: - return "tag had usn = -1" - found = False - for m in self.col.models.all(): - if m["usn"] == -1: - return "model had usn = -1" - if found: - self.col.models.save() - self.col.sched.reset() - # return summary of deck - return [ - list(self.col.sched.counts()), - self.col.db.scalar("select count() from cards"), - self.col.db.scalar("select count() from notes"), - self.col.db.scalar("select count() from revlog"), - self.col.db.scalar("select count() from graves"), - len(self.col.models.all()), - len(self.col.decks.all()), - len(self.col.decks.allConf()), - ] - - def usnLim(self) -> str: - return "usn = -1" - - def finish(self, mod: int) -> int: - self.col.ls = mod - self.col._usn = self.maxUsn + 1 - # ensure we save the mod time even if no changes made - self.col.db.mod = True - self.col.save(mod=mod) - return mod - - # Chunked syncing - ########################################################################## - - def prepareToChunk(self) -> None: - self.tablesLeft = ["revlog", "cards", "notes"] - self.chunkRows = None - - def getChunkRows(self, table) -> List[Sequence]: - lim = self.usnLim() - x = self.col.db.all - d = (self.maxUsn, lim) - if table == "revlog": - return x( - """ -select id, cid, %d, ease, ivl, lastIvl, factor, time, type -from revlog where %s""" - % d - ) - elif table == "cards": - return x( - """ -select id, nid, did, ord, mod, %d, type, queue, due, ivl, factor, reps, -lapses, left, odue, odid, flags, data from cards where %s""" - % d - ) - else: - return x( - """ -select id, guid, mid, mod, %d, tags, flds, '', '', flags, data -from notes where %s""" - % d - ) - - def chunk(self) -> dict: - buf: Dict[str, Any] = dict(done=False) - lim = 250 - while self.tablesLeft and lim: - curTable = self.tablesLeft[0] - if self.chunkRows is None: - self.chunkRows = self.getChunkRows(curTable) - rows = self.chunkRows[:lim] - self.chunkRows = self.chunkRows[lim:] - fetched = len(rows) - if fetched != lim: - # table is empty - self.tablesLeft.pop(0) - self.chunkRows = None - # mark the objects as having been sent - self.col.db.execute( - "update %s set usn=? where usn=-1" % curTable, self.maxUsn - ) - buf[curTable] = rows - lim -= fetched - if not self.tablesLeft: - buf["done"] = True - return buf - - def applyChunk(self, chunk) -> None: - if "revlog" in chunk: - self.mergeRevlog(chunk["revlog"]) - if "cards" in chunk: - self.mergeCards(chunk["cards"]) - if "notes" in chunk: - self.mergeNotes(chunk["notes"]) - - # Deletions - ########################################################################## - - def removed(self) -> dict: - cards = [] - notes = [] - decks = [] - - curs = self.col.db.execute("select oid, type from graves where usn = -1") - - for oid, type in curs: - if type == REM_CARD: - cards.append(oid) - elif type == REM_NOTE: - notes.append(oid) - else: - decks.append(oid) - - self.col.db.execute("update graves set usn=? where usn=-1", self.maxUsn) - - return dict(cards=cards, notes=notes, decks=decks) - - def remove(self, graves) -> None: - # pretend to be the server so we don't set usn = -1 - self.col.server = True - - # notes first, so we don't end up with duplicate graves - self.col._remNotes(graves["notes"]) - # then cards - self.col.remCards(graves["cards"], notes=False) - # and decks - for oid in graves["decks"]: - self.col.decks.rem(oid, childrenToo=False) - - self.col.server = False - - # Models - ########################################################################## - - def getModels(self) -> List: - mods = [m for m in self.col.models.all() if m["usn"] == -1] - for m in mods: - m["usn"] = self.maxUsn - self.col.models.update(m, preserve_usn=True) - return mods - - def mergeModels(self, rchg) -> None: - for r in rchg: - l = self.col.models.get(r["id"]) - # if missing locally or server is newer, update - if not l or r["mod"] > l["mod"]: - # This is a hack to detect when the note type has been altered - # in an import without a full sync being forced. A future - # syncing algorithm should handle this in a better way. - if l: - if len(l["flds"]) != len(r["flds"]): - raise UnexpectedSchemaChange() - if len(l["tmpls"]) != len(r["tmpls"]): - raise UnexpectedSchemaChange() - self.col.models.update(r, preserve_usn=True) - - # Decks - ########################################################################## - - def getDecks(self) -> List[list]: - decks = [g for g in self.col.decks.all() if g["usn"] == -1] - for g in decks: - g["usn"] = self.maxUsn - self.col.decks.update(g, preserve_usn=True) - dconf = [g for g in self.col.decks.all_config() if g["usn"] == -1] - for g in dconf: - g["usn"] = self.maxUsn - self.col.decks.update_config(g, preserve_usn=True) - return [decks, dconf] - - def mergeDecks(self, rchg) -> None: - for r in rchg[0]: - l = self.col.decks.get(r["id"], False) - # work around mod time being stored as string - if l and not isinstance(l["mod"], int): - l["mod"] = int(l["mod"]) - - # if missing locally or server is newer, update - if not l or r["mod"] > l["mod"]: - self.col.decks.update(r, preserve_usn=True) - for r in rchg[1]: - try: - l = self.col.decks.get_config(r["id"]) - except KeyError: - l = None - # if missing locally or server is newer, update - if not l or r["mod"] > l["mod"]: - self.col.decks.update_config(r, preserve_usn=True) - - # Tags - ########################################################################## - - def getTags(self) -> List: - return list(self.col.backend.get_changed_tags(self.maxUsn)) - - def mergeTags(self, tags) -> None: - self.col.tags.register(tags, usn=self.maxUsn) - - # Cards/notes/revlog - ########################################################################## - - def mergeRevlog(self, logs) -> None: - self.col.db.executemany( - "insert or ignore into revlog values (?,?,?,?,?,?,?,?,?)", logs - ) - - def newerRows(self, data, table, modIdx) -> List: - ids = (r[0] for r in data) - lmods = {} - for id, mod in self.col.db.execute( - "select id, mod from %s where id in %s and %s" - % (table, ids2str(ids), self.usnLim()) - ): - lmods[id] = mod - update = [] - for r in data: - if r[0] not in lmods or lmods[r[0]] < r[modIdx]: - update.append(r) - self.col.log(table, data) - return update - - def mergeCards(self, cards) -> None: - self.col.db.executemany( - "insert or replace into cards values " - "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", - self.newerRows(cards, "cards", 4), - ) - - def mergeNotes(self, notes) -> None: - rows = self.newerRows(notes, "notes", 3) - self.col.db.executemany( - "insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)", rows - ) - self.col.updateFieldCache([f[0] for f in rows]) - - # Col config - ########################################################################## - - def getConf(self) -> Dict[str, Any]: - return from_json_bytes(self.col.backend.get_all_config()) - - def mergeConf(self, conf: Dict[str, Any]) -> None: - self.col.backend.set_all_config(to_json_bytes(conf)) - - -# HTTP syncing tools -########################################################################## - - -class HttpSyncer: - def __init__(self, hkey=None, client=None, hostNum=None) -> None: - self.hkey = hkey - self.skey = checksum(str(random.random()))[:8] - self.client = client or HttpClient() - self.postVars: Dict[str, str] = {} - self.hostNum = hostNum - self.prefix = "sync/" - - def syncURL(self) -> str: - url = SYNC_BASE % (self.hostNum or "") - return url + self.prefix - - def assertOk(self, resp) -> None: - # not using raise_for_status() as aqt expects this error msg - if resp.status_code != 200: - raise Exception("Unknown response code: %s" % resp.status_code) - - # Posting data as a file - ###################################################################### - # We don't want to post the payload as a form var, as the percent-encoding is - # costly. We could send it as a raw post, but more HTTP clients seem to - # support file uploading, so this is the more compatible choice. - - def _buildPostData(self, fobj, comp) -> Tuple[Dict[str, str], io.BytesIO]: - BOUNDARY = b"Anki-sync-boundary" - bdry = b"--" + BOUNDARY - buf = io.BytesIO() - # post vars - self.postVars["c"] = "1" if comp else "0" - for (key, value) in list(self.postVars.items()): - buf.write(bdry + b"\r\n") - buf.write( - ( - 'Content-Disposition: form-data; name="%s"\r\n\r\n%s\r\n' - % (key, value) - ).encode("utf8") - ) - # payload as raw data or json - rawSize = 0 - if fobj: - # header - buf.write(bdry + b"\r\n") - buf.write( - b"""\ -Content-Disposition: form-data; name="data"; filename="data"\r\n\ -Content-Type: application/octet-stream\r\n\r\n""" - ) - # write file into buffer, optionally compressing - if comp: - tgt = gzip.GzipFile(mode="wb", fileobj=buf, compresslevel=comp) - else: - tgt = buf # type: ignore - while 1: - data = fobj.read(65536) - if not data: - if comp: - tgt.close() - break - rawSize += len(data) - tgt.write(data) - buf.write(b"\r\n") - buf.write(bdry + b"--\r\n") - size = buf.tell() - # connection headers - headers = { - "Content-Type": "multipart/form-data; boundary=%s" - % BOUNDARY.decode("utf8"), - "Content-Length": str(size), - } - buf.seek(0) - - if size >= 100 * 1024 * 1024 or rawSize >= 250 * 1024 * 1024: - raise Exception("Collection too large to upload to AnkiWeb.") - - return headers, buf - - def req(self, method, fobj=None, comp=6, badAuthRaises=True) -> Any: - headers, body = self._buildPostData(fobj, comp) - - r = self.client.post(self.syncURL() + method, data=body, headers=headers) - if not badAuthRaises and r.status_code == 403: - return False - self.assertOk(r) - - buf = self.client.streamContent(r) - return buf - - -# Incremental sync over HTTP -###################################################################### - - -class RemoteServer(HttpSyncer): - def __init__(self, hkey, hostNum) -> None: - HttpSyncer.__init__(self, hkey, hostNum=hostNum) - - def hostKey(self, user, pw) -> Any: - "Returns hkey or none if user/pw incorrect." - self.postVars = dict() - ret = self.req( - "hostKey", - io.BytesIO(json.dumps(dict(u=user, p=pw)).encode("utf8")), - badAuthRaises=False, - ) - if not ret: - # invalid auth - return - self.hkey = json.loads(ret.decode("utf8"))["key"] - return self.hkey - - def meta(self) -> Any: - self.postVars = dict(k=self.hkey, s=self.skey,) - ret = self.req( - "meta", - io.BytesIO( - json.dumps( - dict( - v=SYNC_VER, - cv="ankidesktop,%s,%s" % (versionWithBuild(), platDesc()), - ) - ).encode("utf8") - ), - badAuthRaises=False, - ) - if not ret: - # invalid auth - return - return json.loads(ret.decode("utf8")) - - def applyGraves(self, **kw) -> Any: - return self._run("applyGraves", kw) - - def applyChanges(self, **kw) -> Any: - return self._run("applyChanges", kw) - - def start(self, **kw) -> Any: - return self._run("start", kw) - - def chunk(self, **kw) -> Any: - return self._run("chunk", kw) - - def applyChunk(self, **kw) -> Any: - return self._run("applyChunk", kw) - - def sanityCheck2(self, **kw) -> Any: - return self._run("sanityCheck2", kw) - - def finish(self, **kw) -> Any: - return self._run("finish", kw) - - def abort(self, **kw) -> Any: - return self._run("abort", kw) - - def _run(self, cmd: str, data: Any) -> Any: - return json.loads( - self.req(cmd, io.BytesIO(json.dumps(data).encode("utf8"))).decode("utf8") - ) - - -# Full syncing -########################################################################## - - -class FullSyncer(HttpSyncer): - def __init__(self, col, hkey, client, hostNum) -> None: - HttpSyncer.__init__(self, hkey, client, hostNum=hostNum) - self.postVars = dict( - k=self.hkey, v="ankidesktop,%s,%s" % (anki.version, platDesc()), - ) - self.col = col.weakref() - - def download(self) -> Optional[str]: - hooks.sync_stage_did_change("download") - localNotEmpty = self.col.db.scalar("select 1 from cards") - self.col.close(downgrade=False) - cont = self.req("download") - tpath = self.col.path + ".tmp" - if cont == "upgradeRequired": - hooks.sync_stage_did_change("upgradeRequired") - return None - with open(tpath, "wb") as file: - file.write(cont) - # check the received file is ok - d = DB(tpath) - assert d.scalar("pragma integrity_check") == "ok" - remoteEmpty = not d.scalar("select 1 from cards") - d.close() - # accidental clobber? - if localNotEmpty and remoteEmpty: - os.unlink(tpath) - return "downloadClobber" - # overwrite existing collection - os.unlink(self.col.path) - os.rename(tpath, self.col.path) - self.col = None - return None - - def upload(self) -> bool: - "True if upload successful." - hooks.sync_stage_did_change("upload") - # make sure it's ok before we try to upload - if self.col.db.scalar("pragma integrity_check") != "ok": - return False - if not self.col.basicCheck(): - return False - # apply some adjustments, then upload - self.col.beforeUpload() - with open(self.col.path, "rb") as file: - if self.req("upload", file) != b"OK": - return False - return True + pass diff --git a/pylib/tools/genhooks.py b/pylib/tools/genhooks.py index f66d7ab9d..fb1fe14b9 100644 --- a/pylib/tools/genhooks.py +++ b/pylib/tools/genhooks.py @@ -33,8 +33,6 @@ hooks = [ args=["exporters: List[Tuple[str, Any]]"], legacy_hook="exportersList", ), - Hook(name="sync_stage_did_change", args=["stage: str"], legacy_hook="sync"), - Hook(name="sync_progress_did_change", args=["msg: str"], legacy_hook="syncMsg"), Hook( name="field_filter", args=[ @@ -92,6 +90,12 @@ hooks = [ args=["notetype: Dict[str, Any]"], doc="Obsolete, do not use.", ), + Hook( + name="sync_stage_did_change", args=["stage: str"], doc="Obsolete, do not use.", + ), + Hook( + name="sync_progress_did_change", args=["msg: str"], doc="Obsolete, do not use.", + ), ] if __name__ == "__main__": diff --git a/qt/aqt/main.py b/qt/aqt/main.py index e4064f952..3f521aaa3 100644 --- a/qt/aqt/main.py +++ b/qt/aqt/main.py @@ -875,7 +875,8 @@ title="%s" %s>%s""" % ( if self.media_syncer.is_syncing(): self.media_syncer.show_sync_log() else: - self.unloadCollection(self._onSync) + self.temp_sync() + # self.unloadCollection(self._onSync) def _onSync(self): self._sync() @@ -910,6 +911,11 @@ title="%s" %s>%s""" % ( self.syncer.sync() self.app.setQuitOnLastWindowClosed(True) + def temp_sync(self): + from aqt.sync import sync + + sync(self) + # Tools ########################################################################## diff --git a/qt/aqt/mediasync.py b/qt/aqt/mediasync.py index 6e79b3c76..ab2dd8286 100644 --- a/qt/aqt/mediasync.py +++ b/qt/aqt/mediasync.py @@ -9,7 +9,6 @@ from dataclasses import dataclass from typing import List, Optional, Union import aqt -from anki.consts import SYNC_BASE from anki.rsbackend import ( TR, Interrupted, @@ -45,8 +44,6 @@ class MediaSyncer: if progress.kind != ProgressKind.MediaSync: return - print(progress.val) - assert isinstance(progress.val, MediaSyncProgress) self._log_and_notify(progress.val) @@ -55,32 +52,24 @@ class MediaSyncer: if self._syncing: return - hkey = self.mw.pm.sync_key() - if hkey is None: - return - if not self.mw.pm.media_syncing_enabled(): self._log_and_notify(tr(TR.SYNC_MEDIA_DISABLED)) return + auth = self.mw.pm.sync_auth() + if auth is None: + return + self._log_and_notify(tr(TR.SYNC_MEDIA_STARTING)) self._syncing = True self._progress_timer = self.mw.progress.timer(1000, self._on_progress, True) gui_hooks.media_sync_did_start_or_stop(True) def run() -> None: - self.mw.col.backend.sync_media(hkey=hkey, endpoint=self._endpoint()) + self.mw.col.backend.sync_media(auth) self.mw.taskman.run_in_background(run, self._on_finished) - def _endpoint(self) -> str: - shard = self.mw.pm.sync_shard() - if shard is not None: - shard_str = str(shard) - else: - shard_str = "" - return f"{SYNC_BASE % shard_str}msync/" - def _log_and_notify(self, entry: LogEntry) -> None: entry_with_time = LogEntryWithTime(time=intTime(), entry=entry) self._log.append(entry_with_time) diff --git a/qt/aqt/profiles.py b/qt/aqt/profiles.py index 4bd22e631..7eaa01311 100644 --- a/qt/aqt/profiles.py +++ b/qt/aqt/profiles.py @@ -21,6 +21,7 @@ import aqt.sound from anki import Collection from anki.db import DB from anki.lang import _, without_unicode_isolation +from anki.rsbackend import SyncAuth from anki.utils import intTime, isMac, isWin from aqt import appHelpSite from aqt.qt import * @@ -605,17 +606,23 @@ create table if not exists profiles self.profile["interrupt_audio"] = val aqt.sound.av_player.interrupt_current_audio = val - def sync_key(self) -> Optional[str]: - return self.profile.get("syncKey") - def set_sync_key(self, val: Optional[str]) -> None: self.profile["syncKey"] = val + def set_sync_username(self, val: Optional[str]) -> None: + self.profile["syncUser"] = val + + def set_host_number(self, val: Optional[int]) -> None: + self.profile["hostNum"] = val or 0 + def media_syncing_enabled(self) -> bool: return self.profile["syncMedia"] - def sync_shard(self) -> Optional[int]: - return self.profile.get("hostNum") + def sync_auth(self) -> Optional[SyncAuth]: + hkey = self.profile.get("syncKey") + if not hkey: + return None + return SyncAuth(hkey=hkey, host_number=self.profile.get("hostNum", 0)) ###################################################################### diff --git a/qt/aqt/progress.py b/qt/aqt/progress.py index 29202287c..9096c7dca 100644 --- a/qt/aqt/progress.py +++ b/qt/aqt/progress.py @@ -87,7 +87,14 @@ class ProgressManager: qconnect(self._show_timer.timeout, self._on_show_timer) return self._win - def update(self, label=None, value=None, process=True, maybeShow=True) -> None: + def update( + self, + label=None, + value=None, + process=True, + maybeShow=True, + max: Optional[int] = None, + ) -> None: # print self._min, self._counter, self._max, label, time.time() - self._lastTime if not self.mw.inMainThread(): print("progress.update() called on wrong thread") @@ -101,7 +108,9 @@ class ProgressManager: elapsed = time.time() - self._lastUpdate if label: self._win.form.label.setText(label) + self._max = max if self._max: + self._win.form.progressBar.setMaximum(max) self._counter = value or (self._counter + 1) self._win.form.progressBar.setValue(self._counter) if process and elapsed >= 0.2: @@ -170,6 +179,13 @@ class ProgressManager: self._show_timer = None self._showWin() + def want_cancel(self) -> bool: + win = self._win + if win: + return win.wantCancel + else: + return False + class ProgressDialog(QDialog): def __init__(self, parent): diff --git a/qt/aqt/sync.py b/qt/aqt/sync.py index 2c805b6a7..fe7281c9b 100644 --- a/qt/aqt/sync.py +++ b/qt/aqt/sync.py @@ -1,496 +1,253 @@ # Copyright: Ankitects Pty Ltd and contributors # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -import gc -import time +from __future__ import annotations -from anki import hooks -from anki.lang import _ -from anki.storage import Collection -from anki.sync import FullSyncer, RemoteServer, Syncer -from aqt.qt import * -from aqt.utils import askUserDialog, showInfo, showText, showWarning, tooltip +import enum +from typing import Callable, Tuple -# Sync manager -###################################################################### +import aqt +from anki.rsbackend import ( + TR, + FullSyncProgress, + ProgressKind, + SyncError, + SyncErrorKind, + SyncOutput, +) +from aqt.qt import ( + QDialog, + QDialogButtonBox, + QGridLayout, + QLabel, + QLineEdit, + Qt, + QTimer, + QVBoxLayout, + qconnect, +) +from aqt.utils import askUser, askUserDialog, showWarning, tr -class SyncManager(QObject): - def __init__(self, mw, pm): - QObject.__init__(self, mw) - self.mw = mw - self.pm = pm +class FullSyncChoice(enum.Enum): + CANCEL = 0 + UPLOAD = 1 + DOWNLOAD = 2 - def sync(self): - if not self.pm.profile["syncKey"]: - auth = self._getUserPass() - if not auth: + +def get_sync_status(mw: aqt.main.AnkiQt, callback: Callable[[SyncOutput], None]): + auth = mw.pm.sync_auth() + if not auth: + return + + def on_done(fut): + callback(fut.result()) + + mw.taskman.run_in_background(lambda: mw.col.backend.sync_status(auth), on_done) + + +def sync(mw: aqt.main.AnkiQt) -> None: + auth = mw.pm.sync_auth() + if not auth: + login(mw, on_success=lambda: sync(mw)) + return + + def on_done(fut): + mw.col.db.begin() + out: SyncOutput = fut.result() + mw.pm.set_host_number(out.host_number) + if out.required == out.NO_CHANGES: + # all done + return + else: + full_sync(mw, out) + + if not mw.col.basicCheck(): + showWarning("Please use Tools>Check Database") + return + + mw.col.save(trx=False) + mw.taskman.with_progress( + lambda: mw.col.backend.sync_collection(auth), + on_done, + label=tr(TR.SYNC_CHECKING), + ) + + +def full_sync(mw: aqt.main.AnkiQt, out: SyncOutput) -> None: + if out.required == out.FULL_DOWNLOAD: + confirm_full_download(mw) + elif out.required == out.FULL_UPLOAD: + full_upload(mw) + else: + choice = ask_user_to_decide_direction() + if choice == FullSyncChoice.UPLOAD: + full_upload(mw) + elif choice == FullSyncChoice.DOWNLOAD: + full_download(mw) + + +def confirm_full_download(mw: aqt.main.AnkiQt) -> None: + # confirmation step required, as some users customize their notetypes + # in an empty collection, then want to upload them + if not askUser(tr(TR.SYNC_CONFIRM_EMPTY_DOWNLOAD)): + return + else: + mw.closeAllWindows(lambda: full_download(mw)) + + +def on_full_sync_timer(mw: aqt.main.AnkiQt) -> None: + progress = mw.col.latest_progress() + if progress.kind != ProgressKind.FullSync: + return + + assert isinstance(progress, FullSyncProgress) + mw.progress.update(value=progress.val.transferred, max=progress.val.total) + + if mw.progress.want_cancel(): + mw.col.backend.abort_sync() + + +def full_download(mw: aqt.main.AnkiQt) -> None: + mw.col.close_for_full_sync() + + def on_timer(): + on_full_sync_timer(mw) + + timer = QTimer(mw) + qconnect(timer.timeout, on_timer) + timer.start(150) + + def on_done(fut): + timer.stop() + mw.col.reopen(after_full_sync=True) + mw.reset() + try: + fut.result() + except Exception as e: + showWarning(str(e)) + return + + mw.taskman.with_progress( + lambda: mw.col.backend.full_download(mw.pm.sync_auth()), + on_done, + label=tr(TR.SYNC_DOWNLOADING_FROM_ANKIWEB), + ) + + +def full_upload(mw: aqt.main.AnkiQt) -> None: + mw.col.close_for_full_sync() + + def on_timer(): + on_full_sync_timer(mw) + + timer = QTimer(mw) + qconnect(timer.timeout, on_timer) + timer.start(150) + + def on_done(fut): + timer.stop() + mw.col.reopen(after_full_sync=True) + mw.reset() + try: + fut.result() + except Exception as e: + showWarning(str(e)) + return + + mw.taskman.with_progress( + lambda: mw.col.backend.full_upload(mw.pm.sync_auth()), + on_done, + label=tr(TR.SYNC_UPLOADING_TO_ANKIWEB), + ) + + +def login( + mw: aqt.main.AnkiQt, on_success: Callable[[], None], username="", password="" +) -> None: + while True: + (username, password) = get_id_and_pass_from_user(mw, username, password) + if not username and not password: + return + if username and password: + break + + def on_done(fut): + try: + auth = fut.result() + except SyncError as e: + if e.kind() == SyncErrorKind.AUTH_FAILED: + showWarning(str(e)) + login(mw, on_success, username, password) return - self.pm.profile["syncUser"] = auth[0] - self._sync(auth) - else: - self._sync() - - def _sync(self, auth=None): - # to avoid gui widgets being garbage collected in the worker thread, - # run gc in advance - self._didFullUp = False - self._didError = False - gc.collect() - # create the thread, setup signals and start running - t = self.thread = SyncThread( - self.pm.collectionPath(), - self.pm.profile["syncKey"], - auth=auth, - hostNum=self.pm.profile.get("hostNum"), - ) - qconnect(t._event, self.onEvent) - qconnect(t.progress_event, self.on_progress) - self.label = _("Connecting...") - prog = self.mw.progress.start(immediate=True, label=self.label) - self.sentBytes = self.recvBytes = 0 - self._updateLabel() - self.thread.start() - while not self.thread.isFinished(): - if prog.wantCancel: - self.thread.flagAbort() - # make sure we don't display 'upload success' msg - self._didFullUp = False - # abort may take a while - self.mw.progress.update(_("Stopping...")) - self.mw.app.processEvents() - self.thread.wait(100) - self.mw.progress.finish() - if self.thread.syncMsg: - showText(self.thread.syncMsg) - if self.thread.uname: - self.pm.profile["syncUser"] = self.thread.uname - self.pm.profile["hostNum"] = self.thread.hostNum - - def delayedInfo(): - if self._didFullUp and not self._didError: - showInfo( - _( - """\ -Your collection was successfully uploaded to AnkiWeb. - -If you use any other devices, please sync them now, and choose \ -to download the collection you have just uploaded from this computer. \ -After doing so, future reviews and added cards will be merged \ -automatically.""" - ) - ) - - self.mw.progress.timer(1000, delayedInfo, False, requiresCollection=False) - - def _updateLabel(self): - self.mw.progress.update( - label="%s\n%s" - % ( - self.label, - _("%(a)0.1fkB up, %(b)0.1fkB down") - % dict(a=self.sentBytes / 1024, b=self.recvBytes / 1024), - ) - ) - - def on_progress(self, upload: int, download: int) -> None: - # posted events not guaranteed to arrive in order; don't go backwards - self.sentBytes = max(self.sentBytes, upload) - self.recvBytes = max(self.recvBytes, download) - self._updateLabel() - - def onEvent(self, evt, *args): - pu = self.mw.progress.update - if evt == "badAuth": - tooltip( - _("AnkiWeb ID or password was incorrect; please try again."), - parent=self.mw, - ) - # blank the key so we prompt user again - self.pm.profile["syncKey"] = None - self.pm.save() - elif evt == "corrupt": - pass - elif evt == "newKey": - self.pm.profile["syncKey"] = args[0] - self.pm.save() - elif evt == "offline": - tooltip(_("Syncing failed; internet offline.")) - elif evt == "upbad": - self._didFullUp = False - self._checkFailed() - elif evt == "sync": - m = None - t = args[0] - if t == "login": - m = _("Syncing...") - elif t == "upload": - self._didFullUp = True - m = _("Uploading to AnkiWeb...") - elif t == "download": - m = _("Downloading from AnkiWeb...") - elif t == "sanity": - m = _("Checking...") - elif t == "upgradeRequired": - showText( - _( - """\ -Please visit AnkiWeb, upgrade your deck, then try again.""" - ) - ) - if m: - self.label = m - self._updateLabel() - elif evt == "syncMsg": - self.label = args[0] - self._updateLabel() - elif evt == "error": - self._didError = True - showText(_("Syncing failed:\n%s") % self._rewriteError(args[0])) - elif evt == "clockOff": - self._clockOff() - elif evt == "checkFailed": - self._checkFailed() - elif evt == "noChanges": - pass - elif evt == "fullSync": - self._confirmFullSync() - elif evt == "downloadClobber": - showInfo( - _( - "Your AnkiWeb collection does not contain any cards. Please sync again and choose 'Upload' instead." - ) - ) - - def _rewriteError(self, err): - if "Errno 61" in err: - return _( - """\ -Couldn't connect to AnkiWeb. Please check your network connection \ -and try again.""" - ) - elif "timed out" in err or "10060" in err: - return _( - """\ -The connection to AnkiWeb timed out. Please check your network \ -connection and try again.""" - ) - elif "code: 500" in err: - return _( - """\ -AnkiWeb encountered an error. Please try again in a few minutes, and if \ -the problem persists, please file a bug report.""" - ) - elif "code: 501" in err: - return _( - """\ -Please upgrade to the latest version of Anki.""" - ) - # 502 is technically due to the server restarting, but we reuse the - # error message - elif "code: 502" in err: - return _("AnkiWeb is under maintenance. Please try again in a few minutes.") - elif "code: 503" in err: - return _( - """\ -AnkiWeb is too busy at the moment. Please try again in a few minutes.""" - ) - elif "code: 504" in err: - return _( - "504 gateway timeout error received. Please try temporarily disabling your antivirus." - ) - elif "code: 409" in err: - return _( - "Only one client can access AnkiWeb at a time. If a previous sync failed, please try again in a few minutes." - ) - elif "10061" in err or "10013" in err or "10053" in err: - return _( - "Antivirus or firewall software is preventing Anki from connecting to the internet." - ) - elif "10054" in err or "Broken pipe" in err: - return _( - "Connection timed out. Either your internet connection is experiencing problems, or you have a very large file in your media folder." - ) - elif "Unable to find the server" in err or "socket.gaierror" in err: - return _( - "Server not found. Either your connection is down, or antivirus/firewall " - "software is blocking Anki from connecting to the internet." - ) - elif "code: 407" in err: - return _("Proxy authentication required.") - elif "code: 413" in err: - return _("Your collection or a media file is too large to sync.") - elif "EOF occurred in violation of protocol" in err: - return ( - _( - "Error establishing a secure connection. This is usually caused by antivirus, firewall or VPN software, or problems with your ISP." - ) - + " (eof)" - ) - elif "certificate verify failed" in err: - return ( - _( - "Error establishing a secure connection. This is usually caused by antivirus, firewall or VPN software, or problems with your ISP." - ) - + " (invalid cert)" - ) - return err - - def _getUserPass(self): - d = QDialog(self.mw) - d.setWindowTitle("Anki") - d.setWindowModality(Qt.WindowModal) - vbox = QVBoxLayout() - l = QLabel( - _( - """\ -