pass in a progress callback instead of http_data_did_[send|receive]

If a request is happening on a background thread, the GUI code
receiving requests on that thread can lead to a crash

Add-on downloading still to do.
This commit is contained in:
Damien Elmes 2020-01-18 17:49:59 +10:00
parent 0d1a25eb5b
commit 3287e8c057
4 changed files with 32 additions and 82 deletions

View file

@ -201,54 +201,6 @@ class _FieldFilterFilter:
field_filter = _FieldFilterFilter() field_filter = _FieldFilterFilter()
class _HttpDataDidReceiveHook:
_hooks: List[Callable[[int], None]] = []
def append(self, cb: Callable[[int], None]) -> None:
"""(bytes: int)"""
self._hooks.append(cb)
def remove(self, cb: Callable[[int], None]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, bytes: int) -> None:
for hook in self._hooks:
try:
hook(bytes)
except:
# if the hook fails, remove it
self._hooks.remove(hook)
raise
http_data_did_receive = _HttpDataDidReceiveHook()
class _HttpDataDidSendHook:
_hooks: List[Callable[[int], None]] = []
def append(self, cb: Callable[[int], None]) -> None:
"""(bytes: int)"""
self._hooks.append(cb)
def remove(self, cb: Callable[[int], None]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, bytes: int) -> None:
for hook in self._hooks:
try:
hook(bytes)
except:
# if the hook fails, remove it
self._hooks.remove(hook)
raise
http_data_did_send = _HttpDataDidSendHook()
class _MediaFilesDidExportHook: class _MediaFilesDidExportHook:
_hooks: List[Callable[[int], None]] = [] _hooks: List[Callable[[int], None]] = []

View file

@ -2,21 +2,16 @@
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
""" """
Wrapper for requests that adds hooks for tracking upload/download progress. Wrapper for requests that adds a callback for tracking upload/download progress.
The hooks http_data_did_send and http_data_did_receive will be called for each
chunk or partial read, on the thread that is running the request.
""" """
import io import io
import os import os
from typing import Any, Dict, Optional from typing import Any, Callable, Dict, Optional
import requests import requests
from requests import Response from requests import Response
from anki import hooks
HTTP_BUF_SIZE = 64 * 1024 HTTP_BUF_SIZE = 64 * 1024
@ -24,12 +19,19 @@ class AnkiRequestsClient:
verify = True verify = True
timeout = 60 timeout = 60
# args are (upload_bytes_in_chunk, download_bytes_in_chunk)
progress_hook: Optional[Callable[[int, int], None]] = None
def __init__(self) -> None: def __init__(
self, progress_hook: Optional[Callable[[int, int], None]] = None
) -> None:
self.progress_hook = progress_hook
self.session = requests.Session() self.session = requests.Session()
def post(self, url: str, data: Any, headers: Optional[Dict[str, str]]) -> Response: def post(self, url: str, data: Any, headers: Optional[Dict[str, str]]) -> Response:
data = _MonitoringFile(data) # pytype: disable=wrong-arg-types data = _MonitoringFile(
data, hook=self.progress_hook
) # pytype: disable=wrong-arg-types
headers["User-Agent"] = self._agentName() headers["User-Agent"] = self._agentName()
return self.session.post( return self.session.post(
url, url,
@ -53,7 +55,8 @@ class AnkiRequestsClient:
buf = io.BytesIO() buf = io.BytesIO()
for chunk in resp.iter_content(chunk_size=HTTP_BUF_SIZE): for chunk in resp.iter_content(chunk_size=HTTP_BUF_SIZE):
hooks.http_data_did_receive(len(chunk)) if self.progress_hook:
self.progress_hook(0, len(chunk))
buf.write(chunk) buf.write(chunk)
return buf.getvalue() return buf.getvalue()
@ -73,7 +76,12 @@ if os.environ.get("ANKI_NOVERIFYSSL"):
class _MonitoringFile(io.BufferedReader): class _MonitoringFile(io.BufferedReader):
def __init__(self, raw: io.RawIOBase, hook: Optional[Callable[[int, int], None]]):
io.BufferedReader.__init__(self, raw)
self.hook = hook
def read(self, size=-1) -> bytes: def read(self, size=-1) -> bytes:
data = io.BufferedReader.read(self, HTTP_BUF_SIZE) data = io.BufferedReader.read(self, HTTP_BUF_SIZE)
hooks.http_data_did_send(len(data)) if self.hook:
self.hook(len(data), 0)
return data return data

View file

@ -50,8 +50,6 @@ hooks = [
), ),
Hook(name="sync_stage_did_change", args=["stage: str"], legacy_hook="sync"), Hook(name="sync_stage_did_change", args=["stage: str"], legacy_hook="sync"),
Hook(name="sync_progress_did_change", args=["msg: str"], legacy_hook="syncMsg"), Hook(name="sync_progress_did_change", args=["msg: str"], legacy_hook="syncMsg"),
Hook(name="http_data_did_send", args=["bytes: int"]),
Hook(name="http_data_did_receive", args=["bytes: int"]),
Hook( Hook(
name="tag_added", args=["tag: str"], legacy_hook="newTag", legacy_no_args=True, name="tag_added", args=["tag: str"], legacy_hook="newTag", legacy_no_args=True,
), ),

View file

@ -46,6 +46,7 @@ class SyncManager(QObject):
hostNum=self.pm.profile.get("hostNum"), hostNum=self.pm.profile.get("hostNum"),
) )
t._event.connect(self.onEvent) t._event.connect(self.onEvent)
t.progress_event.connect(self.on_progress)
self.label = _("Connecting...") self.label = _("Connecting...")
prog = self.mw.progress.start(immediate=True, label=self.label) prog = self.mw.progress.start(immediate=True, label=self.label)
self.sentBytes = self.recvBytes = 0 self.sentBytes = self.recvBytes = 0
@ -93,6 +94,12 @@ automatically."""
) )
) )
def on_progress(self, upload: int, download: int) -> None:
# posted events not guaranteed to arrive in order; don't go backwards
self.sentBytes = max(self.sentBytes, upload)
self.recvBytes = max(self.recvBytes, download)
self._updateLabel()
def onEvent(self, evt, *args): def onEvent(self, evt, *args):
pu = self.mw.progress.update pu = self.mw.progress.update
if evt == "badAuth": if evt == "badAuth":
@ -165,13 +172,6 @@ sync again to correct the issue."""
"Your AnkiWeb collection does not contain any cards. Please sync again and choose 'Upload' instead." "Your AnkiWeb collection does not contain any cards. Please sync again and choose 'Upload' instead."
) )
) )
elif evt == "send":
# posted events not guaranteed to arrive in order
self.sentBytes = max(self.sentBytes, int(args[0]))
self._updateLabel()
elif evt == "recv":
self.recvBytes = max(self.recvBytes, int(args[0]))
self._updateLabel()
def _rewriteError(self, err): def _rewriteError(self, err):
if "Errno 61" in err: if "Errno 61" in err:
@ -356,6 +356,7 @@ Check Database, then sync again."""
class SyncThread(QThread): class SyncThread(QThread):
_event = pyqtSignal(str, str) _event = pyqtSignal(str, str)
progress_event = pyqtSignal(int, int)
def __init__(self, path, hkey, auth=None, media=True, hostNum=None): def __init__(self, path, hkey, auth=None, media=True, hostNum=None):
QThread.__init__(self) QThread.__init__(self)
@ -390,26 +391,19 @@ class SyncThread(QThread):
def syncMsg(msg): def syncMsg(msg):
self.fireEvent("syncMsg", msg) self.fireEvent("syncMsg", msg)
def sendEvent(bytes): def http_progress(upload: int, download: int) -> None:
if not self._abort: if not self._abort:
self.sentTotal += bytes self.sentTotal += upload
self.fireEvent("send", str(self.sentTotal)) self.recvTotal += download
self.progress_event.emit(self.sentTotal, self.recvTotal) # type: ignore
elif self._abort == 1: elif self._abort == 1:
self._abort = 2 self._abort = 2
raise Exception("sync cancelled") raise Exception("sync cancelled")
def recvEvent(bytes): self.server.client.progress_hook = http_progress
if not self._abort:
self.recvTotal += bytes
self.fireEvent("recv", str(self.recvTotal))
elif self._abort == 1:
self._abort = 2
raise Exception("sync cancelled")
hooks.sync_stage_did_change.append(syncEvent) hooks.sync_stage_did_change.append(syncEvent)
hooks.sync_progress_did_change.append(syncMsg) hooks.sync_progress_did_change.append(syncMsg)
hooks.http_data_did_send.append(sendEvent)
hooks.http_data_did_receive.append(recvEvent)
# run sync and catch any errors # run sync and catch any errors
try: try:
self._sync() self._sync()
@ -421,8 +415,6 @@ class SyncThread(QThread):
self.col.close(save=False) self.col.close(save=False)
hooks.sync_stage_did_change.remove(syncEvent) hooks.sync_stage_did_change.remove(syncEvent)
hooks.sync_progress_did_change.remove(syncMsg) hooks.sync_progress_did_change.remove(syncMsg)
hooks.http_data_did_send.remove(sendEvent)
hooks.http_data_did_receive.remove(recvEvent)
def _abortingSync(self): def _abortingSync(self):
try: try: