diff --git a/pylib/anki/hooks.py b/pylib/anki/hooks.py index 41de717d0..6a7caaad8 100644 --- a/pylib/anki/hooks.py +++ b/pylib/anki/hooks.py @@ -201,54 +201,6 @@ class _FieldFilterFilter: field_filter = _FieldFilterFilter() -class _HttpDataDidReceiveHook: - _hooks: List[Callable[[int], None]] = [] - - def append(self, cb: Callable[[int], None]) -> None: - """(bytes: int)""" - self._hooks.append(cb) - - def remove(self, cb: Callable[[int], None]) -> None: - if cb in self._hooks: - self._hooks.remove(cb) - - def __call__(self, bytes: int) -> None: - for hook in self._hooks: - try: - hook(bytes) - except: - # if the hook fails, remove it - self._hooks.remove(hook) - raise - - -http_data_did_receive = _HttpDataDidReceiveHook() - - -class _HttpDataDidSendHook: - _hooks: List[Callable[[int], None]] = [] - - def append(self, cb: Callable[[int], None]) -> None: - """(bytes: int)""" - self._hooks.append(cb) - - def remove(self, cb: Callable[[int], None]) -> None: - if cb in self._hooks: - self._hooks.remove(cb) - - def __call__(self, bytes: int) -> None: - for hook in self._hooks: - try: - hook(bytes) - except: - # if the hook fails, remove it - self._hooks.remove(hook) - raise - - -http_data_did_send = _HttpDataDidSendHook() - - class _MediaFilesDidExportHook: _hooks: List[Callable[[int], None]] = [] diff --git a/pylib/anki/httpclient.py b/pylib/anki/httpclient.py index 95bb82d49..ace66d4df 100644 --- a/pylib/anki/httpclient.py +++ b/pylib/anki/httpclient.py @@ -2,21 +2,16 @@ # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html """ -Wrapper for requests that adds hooks for tracking upload/download progress. - -The hooks http_data_did_send and http_data_did_receive will be called for each -chunk or partial read, on the thread that is running the request. +Wrapper for requests that adds a callback for tracking upload/download progress. """ import io import os -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional import requests from requests import Response -from anki import hooks - HTTP_BUF_SIZE = 64 * 1024 @@ -24,12 +19,19 @@ class AnkiRequestsClient: verify = True timeout = 60 + # args are (upload_bytes_in_chunk, download_bytes_in_chunk) + progress_hook: Optional[Callable[[int, int], None]] = None - def __init__(self) -> None: + def __init__( + self, progress_hook: Optional[Callable[[int, int], None]] = None + ) -> None: + self.progress_hook = progress_hook self.session = requests.Session() def post(self, url: str, data: Any, headers: Optional[Dict[str, str]]) -> Response: - data = _MonitoringFile(data) # pytype: disable=wrong-arg-types + data = _MonitoringFile( + data, hook=self.progress_hook + ) # pytype: disable=wrong-arg-types headers["User-Agent"] = self._agentName() return self.session.post( url, @@ -53,7 +55,8 @@ class AnkiRequestsClient: buf = io.BytesIO() for chunk in resp.iter_content(chunk_size=HTTP_BUF_SIZE): - hooks.http_data_did_receive(len(chunk)) + if self.progress_hook: + self.progress_hook(0, len(chunk)) buf.write(chunk) return buf.getvalue() @@ -73,7 +76,12 @@ if os.environ.get("ANKI_NOVERIFYSSL"): class _MonitoringFile(io.BufferedReader): + def __init__(self, raw: io.RawIOBase, hook: Optional[Callable[[int, int], None]]): + io.BufferedReader.__init__(self, raw) + self.hook = hook + def read(self, size=-1) -> bytes: data = io.BufferedReader.read(self, HTTP_BUF_SIZE) - hooks.http_data_did_send(len(data)) + if self.hook: + self.hook(len(data), 0) return data diff --git a/pylib/tools/genhooks.py b/pylib/tools/genhooks.py index 68be0a515..0152fc85a 100644 --- a/pylib/tools/genhooks.py +++ b/pylib/tools/genhooks.py @@ -50,8 +50,6 @@ hooks = [ ), Hook(name="sync_stage_did_change", args=["stage: str"], legacy_hook="sync"), Hook(name="sync_progress_did_change", args=["msg: str"], legacy_hook="syncMsg"), - Hook(name="http_data_did_send", args=["bytes: int"]), - Hook(name="http_data_did_receive", args=["bytes: int"]), Hook( name="tag_added", args=["tag: str"], legacy_hook="newTag", legacy_no_args=True, ), diff --git a/qt/aqt/sync.py b/qt/aqt/sync.py index 60ce4c0b4..78f7fe00b 100644 --- a/qt/aqt/sync.py +++ b/qt/aqt/sync.py @@ -46,6 +46,7 @@ class SyncManager(QObject): hostNum=self.pm.profile.get("hostNum"), ) t._event.connect(self.onEvent) + t.progress_event.connect(self.on_progress) self.label = _("Connecting...") prog = self.mw.progress.start(immediate=True, label=self.label) self.sentBytes = self.recvBytes = 0 @@ -93,6 +94,12 @@ automatically.""" ) ) + def on_progress(self, upload: int, download: int) -> None: + # posted events not guaranteed to arrive in order; don't go backwards + self.sentBytes = max(self.sentBytes, upload) + self.recvBytes = max(self.recvBytes, download) + self._updateLabel() + def onEvent(self, evt, *args): pu = self.mw.progress.update if evt == "badAuth": @@ -165,13 +172,6 @@ sync again to correct the issue.""" "Your AnkiWeb collection does not contain any cards. Please sync again and choose 'Upload' instead." ) ) - elif evt == "send": - # posted events not guaranteed to arrive in order - self.sentBytes = max(self.sentBytes, int(args[0])) - self._updateLabel() - elif evt == "recv": - self.recvBytes = max(self.recvBytes, int(args[0])) - self._updateLabel() def _rewriteError(self, err): if "Errno 61" in err: @@ -356,6 +356,7 @@ Check Database, then sync again.""" class SyncThread(QThread): _event = pyqtSignal(str, str) + progress_event = pyqtSignal(int, int) def __init__(self, path, hkey, auth=None, media=True, hostNum=None): QThread.__init__(self) @@ -390,26 +391,19 @@ class SyncThread(QThread): def syncMsg(msg): self.fireEvent("syncMsg", msg) - def sendEvent(bytes): + def http_progress(upload: int, download: int) -> None: if not self._abort: - self.sentTotal += bytes - self.fireEvent("send", str(self.sentTotal)) + self.sentTotal += upload + self.recvTotal += download + self.progress_event.emit(self.sentTotal, self.recvTotal) # type: ignore elif self._abort == 1: self._abort = 2 raise Exception("sync cancelled") - def recvEvent(bytes): - if not self._abort: - self.recvTotal += bytes - self.fireEvent("recv", str(self.recvTotal)) - elif self._abort == 1: - self._abort = 2 - raise Exception("sync cancelled") + self.server.client.progress_hook = http_progress hooks.sync_stage_did_change.append(syncEvent) hooks.sync_progress_did_change.append(syncMsg) - hooks.http_data_did_send.append(sendEvent) - hooks.http_data_did_receive.append(recvEvent) # run sync and catch any errors try: self._sync() @@ -421,8 +415,6 @@ class SyncThread(QThread): self.col.close(save=False) hooks.sync_stage_did_change.remove(syncEvent) hooks.sync_progress_did_change.remove(syncMsg) - hooks.http_data_did_send.remove(sendEvent) - hooks.http_data_did_receive.remove(recvEvent) def _abortingSync(self): try: