mirror of
https://github.com/ankitects/anki.git
synced 2025-09-23 16:26:40 -04:00
pass in a progress callback instead of http_data_did_[send|receive]
If a request is happening on a background thread, the GUI code receiving requests on that thread can lead to a crash Add-on downloading still to do.
This commit is contained in:
parent
0d1a25eb5b
commit
3287e8c057
4 changed files with 32 additions and 82 deletions
|
@ -201,54 +201,6 @@ class _FieldFilterFilter:
|
|||
field_filter = _FieldFilterFilter()
|
||||
|
||||
|
||||
class _HttpDataDidReceiveHook:
|
||||
_hooks: List[Callable[[int], None]] = []
|
||||
|
||||
def append(self, cb: Callable[[int], None]) -> None:
|
||||
"""(bytes: int)"""
|
||||
self._hooks.append(cb)
|
||||
|
||||
def remove(self, cb: Callable[[int], None]) -> None:
|
||||
if cb in self._hooks:
|
||||
self._hooks.remove(cb)
|
||||
|
||||
def __call__(self, bytes: int) -> None:
|
||||
for hook in self._hooks:
|
||||
try:
|
||||
hook(bytes)
|
||||
except:
|
||||
# if the hook fails, remove it
|
||||
self._hooks.remove(hook)
|
||||
raise
|
||||
|
||||
|
||||
http_data_did_receive = _HttpDataDidReceiveHook()
|
||||
|
||||
|
||||
class _HttpDataDidSendHook:
|
||||
_hooks: List[Callable[[int], None]] = []
|
||||
|
||||
def append(self, cb: Callable[[int], None]) -> None:
|
||||
"""(bytes: int)"""
|
||||
self._hooks.append(cb)
|
||||
|
||||
def remove(self, cb: Callable[[int], None]) -> None:
|
||||
if cb in self._hooks:
|
||||
self._hooks.remove(cb)
|
||||
|
||||
def __call__(self, bytes: int) -> None:
|
||||
for hook in self._hooks:
|
||||
try:
|
||||
hook(bytes)
|
||||
except:
|
||||
# if the hook fails, remove it
|
||||
self._hooks.remove(hook)
|
||||
raise
|
||||
|
||||
|
||||
http_data_did_send = _HttpDataDidSendHook()
|
||||
|
||||
|
||||
class _MediaFilesDidExportHook:
|
||||
_hooks: List[Callable[[int], None]] = []
|
||||
|
||||
|
|
|
@ -2,21 +2,16 @@
|
|||
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
|
||||
|
||||
"""
|
||||
Wrapper for requests that adds hooks for tracking upload/download progress.
|
||||
|
||||
The hooks http_data_did_send and http_data_did_receive will be called for each
|
||||
chunk or partial read, on the thread that is running the request.
|
||||
Wrapper for requests that adds a callback for tracking upload/download progress.
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import requests
|
||||
from requests import Response
|
||||
|
||||
from anki import hooks
|
||||
|
||||
HTTP_BUF_SIZE = 64 * 1024
|
||||
|
||||
|
||||
|
@ -24,12 +19,19 @@ class AnkiRequestsClient:
|
|||
|
||||
verify = True
|
||||
timeout = 60
|
||||
# args are (upload_bytes_in_chunk, download_bytes_in_chunk)
|
||||
progress_hook: Optional[Callable[[int, int], None]] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(
|
||||
self, progress_hook: Optional[Callable[[int, int], None]] = None
|
||||
) -> None:
|
||||
self.progress_hook = progress_hook
|
||||
self.session = requests.Session()
|
||||
|
||||
def post(self, url: str, data: Any, headers: Optional[Dict[str, str]]) -> Response:
|
||||
data = _MonitoringFile(data) # pytype: disable=wrong-arg-types
|
||||
data = _MonitoringFile(
|
||||
data, hook=self.progress_hook
|
||||
) # pytype: disable=wrong-arg-types
|
||||
headers["User-Agent"] = self._agentName()
|
||||
return self.session.post(
|
||||
url,
|
||||
|
@ -53,7 +55,8 @@ class AnkiRequestsClient:
|
|||
|
||||
buf = io.BytesIO()
|
||||
for chunk in resp.iter_content(chunk_size=HTTP_BUF_SIZE):
|
||||
hooks.http_data_did_receive(len(chunk))
|
||||
if self.progress_hook:
|
||||
self.progress_hook(0, len(chunk))
|
||||
buf.write(chunk)
|
||||
return buf.getvalue()
|
||||
|
||||
|
@ -73,7 +76,12 @@ if os.environ.get("ANKI_NOVERIFYSSL"):
|
|||
|
||||
|
||||
class _MonitoringFile(io.BufferedReader):
|
||||
def __init__(self, raw: io.RawIOBase, hook: Optional[Callable[[int, int], None]]):
|
||||
io.BufferedReader.__init__(self, raw)
|
||||
self.hook = hook
|
||||
|
||||
def read(self, size=-1) -> bytes:
|
||||
data = io.BufferedReader.read(self, HTTP_BUF_SIZE)
|
||||
hooks.http_data_did_send(len(data))
|
||||
if self.hook:
|
||||
self.hook(len(data), 0)
|
||||
return data
|
||||
|
|
|
@ -50,8 +50,6 @@ hooks = [
|
|||
),
|
||||
Hook(name="sync_stage_did_change", args=["stage: str"], legacy_hook="sync"),
|
||||
Hook(name="sync_progress_did_change", args=["msg: str"], legacy_hook="syncMsg"),
|
||||
Hook(name="http_data_did_send", args=["bytes: int"]),
|
||||
Hook(name="http_data_did_receive", args=["bytes: int"]),
|
||||
Hook(
|
||||
name="tag_added", args=["tag: str"], legacy_hook="newTag", legacy_no_args=True,
|
||||
),
|
||||
|
|
|
@ -46,6 +46,7 @@ class SyncManager(QObject):
|
|||
hostNum=self.pm.profile.get("hostNum"),
|
||||
)
|
||||
t._event.connect(self.onEvent)
|
||||
t.progress_event.connect(self.on_progress)
|
||||
self.label = _("Connecting...")
|
||||
prog = self.mw.progress.start(immediate=True, label=self.label)
|
||||
self.sentBytes = self.recvBytes = 0
|
||||
|
@ -93,6 +94,12 @@ automatically."""
|
|||
)
|
||||
)
|
||||
|
||||
def on_progress(self, upload: int, download: int) -> None:
|
||||
# posted events not guaranteed to arrive in order; don't go backwards
|
||||
self.sentBytes = max(self.sentBytes, upload)
|
||||
self.recvBytes = max(self.recvBytes, download)
|
||||
self._updateLabel()
|
||||
|
||||
def onEvent(self, evt, *args):
|
||||
pu = self.mw.progress.update
|
||||
if evt == "badAuth":
|
||||
|
@ -165,13 +172,6 @@ sync again to correct the issue."""
|
|||
"Your AnkiWeb collection does not contain any cards. Please sync again and choose 'Upload' instead."
|
||||
)
|
||||
)
|
||||
elif evt == "send":
|
||||
# posted events not guaranteed to arrive in order
|
||||
self.sentBytes = max(self.sentBytes, int(args[0]))
|
||||
self._updateLabel()
|
||||
elif evt == "recv":
|
||||
self.recvBytes = max(self.recvBytes, int(args[0]))
|
||||
self._updateLabel()
|
||||
|
||||
def _rewriteError(self, err):
|
||||
if "Errno 61" in err:
|
||||
|
@ -356,6 +356,7 @@ Check Database, then sync again."""
|
|||
class SyncThread(QThread):
|
||||
|
||||
_event = pyqtSignal(str, str)
|
||||
progress_event = pyqtSignal(int, int)
|
||||
|
||||
def __init__(self, path, hkey, auth=None, media=True, hostNum=None):
|
||||
QThread.__init__(self)
|
||||
|
@ -390,26 +391,19 @@ class SyncThread(QThread):
|
|||
def syncMsg(msg):
|
||||
self.fireEvent("syncMsg", msg)
|
||||
|
||||
def sendEvent(bytes):
|
||||
def http_progress(upload: int, download: int) -> None:
|
||||
if not self._abort:
|
||||
self.sentTotal += bytes
|
||||
self.fireEvent("send", str(self.sentTotal))
|
||||
self.sentTotal += upload
|
||||
self.recvTotal += download
|
||||
self.progress_event.emit(self.sentTotal, self.recvTotal) # type: ignore
|
||||
elif self._abort == 1:
|
||||
self._abort = 2
|
||||
raise Exception("sync cancelled")
|
||||
|
||||
def recvEvent(bytes):
|
||||
if not self._abort:
|
||||
self.recvTotal += bytes
|
||||
self.fireEvent("recv", str(self.recvTotal))
|
||||
elif self._abort == 1:
|
||||
self._abort = 2
|
||||
raise Exception("sync cancelled")
|
||||
self.server.client.progress_hook = http_progress
|
||||
|
||||
hooks.sync_stage_did_change.append(syncEvent)
|
||||
hooks.sync_progress_did_change.append(syncMsg)
|
||||
hooks.http_data_did_send.append(sendEvent)
|
||||
hooks.http_data_did_receive.append(recvEvent)
|
||||
# run sync and catch any errors
|
||||
try:
|
||||
self._sync()
|
||||
|
@ -421,8 +415,6 @@ class SyncThread(QThread):
|
|||
self.col.close(save=False)
|
||||
hooks.sync_stage_did_change.remove(syncEvent)
|
||||
hooks.sync_progress_did_change.remove(syncMsg)
|
||||
hooks.http_data_did_send.remove(sendEvent)
|
||||
hooks.http_data_did_receive.remove(recvEvent)
|
||||
|
||||
def _abortingSync(self):
|
||||
try:
|
||||
|
|
Loading…
Reference in a new issue