pass in a progress callback instead of http_data_did_[send|receive]

If a request is happening on a background thread, the GUI code
receiving requests on that thread can lead to a crash

Add-on downloading still to do.
This commit is contained in:
Damien Elmes 2020-01-18 17:49:59 +10:00
parent 0d1a25eb5b
commit 3287e8c057
4 changed files with 32 additions and 82 deletions

View file

@ -201,54 +201,6 @@ class _FieldFilterFilter:
field_filter = _FieldFilterFilter()
class _HttpDataDidReceiveHook:
_hooks: List[Callable[[int], None]] = []
def append(self, cb: Callable[[int], None]) -> None:
"""(bytes: int)"""
self._hooks.append(cb)
def remove(self, cb: Callable[[int], None]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, bytes: int) -> None:
for hook in self._hooks:
try:
hook(bytes)
except:
# if the hook fails, remove it
self._hooks.remove(hook)
raise
http_data_did_receive = _HttpDataDidReceiveHook()
class _HttpDataDidSendHook:
_hooks: List[Callable[[int], None]] = []
def append(self, cb: Callable[[int], None]) -> None:
"""(bytes: int)"""
self._hooks.append(cb)
def remove(self, cb: Callable[[int], None]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, bytes: int) -> None:
for hook in self._hooks:
try:
hook(bytes)
except:
# if the hook fails, remove it
self._hooks.remove(hook)
raise
http_data_did_send = _HttpDataDidSendHook()
class _MediaFilesDidExportHook:
_hooks: List[Callable[[int], None]] = []

View file

@ -2,21 +2,16 @@
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
"""
Wrapper for requests that adds hooks for tracking upload/download progress.
The hooks http_data_did_send and http_data_did_receive will be called for each
chunk or partial read, on the thread that is running the request.
Wrapper for requests that adds a callback for tracking upload/download progress.
"""
import io
import os
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional
import requests
from requests import Response
from anki import hooks
HTTP_BUF_SIZE = 64 * 1024
@ -24,12 +19,19 @@ class AnkiRequestsClient:
verify = True
timeout = 60
# args are (upload_bytes_in_chunk, download_bytes_in_chunk)
progress_hook: Optional[Callable[[int, int], None]] = None
def __init__(self) -> None:
def __init__(
self, progress_hook: Optional[Callable[[int, int], None]] = None
) -> None:
self.progress_hook = progress_hook
self.session = requests.Session()
def post(self, url: str, data: Any, headers: Optional[Dict[str, str]]) -> Response:
data = _MonitoringFile(data) # pytype: disable=wrong-arg-types
data = _MonitoringFile(
data, hook=self.progress_hook
) # pytype: disable=wrong-arg-types
headers["User-Agent"] = self._agentName()
return self.session.post(
url,
@ -53,7 +55,8 @@ class AnkiRequestsClient:
buf = io.BytesIO()
for chunk in resp.iter_content(chunk_size=HTTP_BUF_SIZE):
hooks.http_data_did_receive(len(chunk))
if self.progress_hook:
self.progress_hook(0, len(chunk))
buf.write(chunk)
return buf.getvalue()
@ -73,7 +76,12 @@ if os.environ.get("ANKI_NOVERIFYSSL"):
class _MonitoringFile(io.BufferedReader):
def __init__(self, raw: io.RawIOBase, hook: Optional[Callable[[int, int], None]]):
io.BufferedReader.__init__(self, raw)
self.hook = hook
def read(self, size=-1) -> bytes:
data = io.BufferedReader.read(self, HTTP_BUF_SIZE)
hooks.http_data_did_send(len(data))
if self.hook:
self.hook(len(data), 0)
return data

View file

@ -50,8 +50,6 @@ hooks = [
),
Hook(name="sync_stage_did_change", args=["stage: str"], legacy_hook="sync"),
Hook(name="sync_progress_did_change", args=["msg: str"], legacy_hook="syncMsg"),
Hook(name="http_data_did_send", args=["bytes: int"]),
Hook(name="http_data_did_receive", args=["bytes: int"]),
Hook(
name="tag_added", args=["tag: str"], legacy_hook="newTag", legacy_no_args=True,
),

View file

@ -46,6 +46,7 @@ class SyncManager(QObject):
hostNum=self.pm.profile.get("hostNum"),
)
t._event.connect(self.onEvent)
t.progress_event.connect(self.on_progress)
self.label = _("Connecting...")
prog = self.mw.progress.start(immediate=True, label=self.label)
self.sentBytes = self.recvBytes = 0
@ -93,6 +94,12 @@ automatically."""
)
)
def on_progress(self, upload: int, download: int) -> None:
# posted events not guaranteed to arrive in order; don't go backwards
self.sentBytes = max(self.sentBytes, upload)
self.recvBytes = max(self.recvBytes, download)
self._updateLabel()
def onEvent(self, evt, *args):
pu = self.mw.progress.update
if evt == "badAuth":
@ -165,13 +172,6 @@ sync again to correct the issue."""
"Your AnkiWeb collection does not contain any cards. Please sync again and choose 'Upload' instead."
)
)
elif evt == "send":
# posted events not guaranteed to arrive in order
self.sentBytes = max(self.sentBytes, int(args[0]))
self._updateLabel()
elif evt == "recv":
self.recvBytes = max(self.recvBytes, int(args[0]))
self._updateLabel()
def _rewriteError(self, err):
if "Errno 61" in err:
@ -356,6 +356,7 @@ Check Database, then sync again."""
class SyncThread(QThread):
_event = pyqtSignal(str, str)
progress_event = pyqtSignal(int, int)
def __init__(self, path, hkey, auth=None, media=True, hostNum=None):
QThread.__init__(self)
@ -390,26 +391,19 @@ class SyncThread(QThread):
def syncMsg(msg):
self.fireEvent("syncMsg", msg)
def sendEvent(bytes):
def http_progress(upload: int, download: int) -> None:
if not self._abort:
self.sentTotal += bytes
self.fireEvent("send", str(self.sentTotal))
self.sentTotal += upload
self.recvTotal += download
self.progress_event.emit(self.sentTotal, self.recvTotal) # type: ignore
elif self._abort == 1:
self._abort = 2
raise Exception("sync cancelled")
def recvEvent(bytes):
if not self._abort:
self.recvTotal += bytes
self.fireEvent("recv", str(self.recvTotal))
elif self._abort == 1:
self._abort = 2
raise Exception("sync cancelled")
self.server.client.progress_hook = http_progress
hooks.sync_stage_did_change.append(syncEvent)
hooks.sync_progress_did_change.append(syncMsg)
hooks.http_data_did_send.append(sendEvent)
hooks.http_data_did_receive.append(recvEvent)
# run sync and catch any errors
try:
self._sync()
@ -421,8 +415,6 @@ class SyncThread(QThread):
self.col.close(save=False)
hooks.sync_stage_did_change.remove(syncEvent)
hooks.sync_progress_did_change.remove(syncMsg)
hooks.http_data_did_send.remove(sendEvent)
hooks.http_data_did_receive.remove(recvEvent)
def _abortingSync(self):
try: