diff --git a/proto/backend.proto b/proto/backend.proto index f9fcbd689..a224d0e85 100644 --- a/proto/backend.proto +++ b/proto/backend.proto @@ -68,6 +68,9 @@ message DeckConfigID { /////////////////////////////////////////////////////////// service BackendService { + rpc LatestProgress (Empty) returns (Progress); + rpc SetWantsAbort (Empty) returns (Empty); + // card rendering rpc ExtractAVTags (ExtractAVTagsIn) returns (ExtractAVTagsOut); @@ -448,14 +451,6 @@ message BackendError { } } -message Progress { - oneof value { - MediaSyncProgress media_sync = 1; - string media_check = 2; - FullSyncProgress full_sync = 3; - } -} - message NetworkError { enum NetworkErrorKind { OTHER = 0; @@ -482,6 +477,18 @@ message SyncError { SyncErrorKind kind = 1; } +// Progress +/////////////////////////////////////////////////////////// + +message Progress { + oneof value { + Empty none = 1; + MediaSyncProgress media_sync = 2; + string media_check = 3; + FullSyncProgress full_sync = 4; + } +} + message MediaSyncProgress { string checked = 1; string added = 2; @@ -498,6 +505,7 @@ message MediaSyncUploadProgress { uint32 deletions = 2; } + // Messages /////////////////////////////////////////////////////////// diff --git a/pylib/anki/collection.py b/pylib/anki/collection.py index 5a582d1fb..cb04bba67 100644 --- a/pylib/anki/collection.py +++ b/pylib/anki/collection.py @@ -26,7 +26,7 @@ from anki.lang import _ from anki.media import MediaManager, media_paths_from_col_path from anki.models import ModelManager from anki.notes import Note -from anki.rsbackend import TR, DBError, FormatTimeSpanContext, RustBackend, pb +from anki.rsbackend import TR, DBError, FormatTimeSpanContext, Progress, RustBackend, pb from anki.sched import Scheduler as V1Scheduler from anki.schedv2 import Scheduler as V2Scheduler from anki.tags import TagManager @@ -91,6 +91,12 @@ class Collection: ) -> str: return self.backend.format_timespan(seconds=seconds, context=context) + # Progress + ########################################################################## + + def latest_progress(self) -> Progress: + return Progress.from_proto(self.backend.latest_progress()) + # Scheduler ########################################################################## diff --git a/pylib/anki/hooks.py b/pylib/anki/hooks.py index 90c8f0c4f..31edb5348 100644 --- a/pylib/anki/hooks.py +++ b/pylib/anki/hooks.py @@ -28,33 +28,6 @@ from anki.notes import Note # @@AUTOGEN@@ -class _BgThreadProgressCallbackFilter: - """Warning: this is called on a background thread.""" - - _hooks: List[Callable[[bool, "anki.rsbackend.Progress"], bool]] = [] - - def append(self, cb: Callable[[bool, "anki.rsbackend.Progress"], bool]) -> None: - """(proceed: bool, progress: anki.rsbackend.Progress)""" - self._hooks.append(cb) - - def remove(self, cb: Callable[[bool, "anki.rsbackend.Progress"], bool]) -> None: - if cb in self._hooks: - self._hooks.remove(cb) - - def __call__(self, proceed: bool, progress: anki.rsbackend.Progress) -> bool: - for filter in self._hooks: - try: - proceed = filter(proceed, progress) - except: - # if the hook fails, remove it - self._hooks.remove(filter) - raise - return proceed - - -bg_thread_progress_callback = _BgThreadProgressCallbackFilter() - - class _CardDidLeechHook: _hooks: List[Callable[[Card], None]] = [] diff --git a/pylib/anki/rsbackend.py b/pylib/anki/rsbackend.py index 023fec20c..481b3150f 100644 --- a/pylib/anki/rsbackend.py +++ b/pylib/anki/rsbackend.py @@ -152,31 +152,28 @@ FormatTimeSpanContext = pb.FormatTimespanIn.Context class ProgressKind(enum.Enum): - MediaSync = 0 - MediaCheck = 1 + NoProgress = 0 + MediaSync = 1 + MediaCheck = 2 + FullSync = 3 @dataclass class Progress: kind: ProgressKind - val: Union[MediaSyncProgress, str] + val: Union[MediaSyncProgress, pb.FullSyncProgress, str] - -def proto_progress_to_native(progress: pb.Progress) -> Progress: - kind = progress.WhichOneof("value") - if kind == "media_sync": - return Progress(kind=ProgressKind.MediaSync, val=progress.media_sync) - elif kind == "media_check": - return Progress(kind=ProgressKind.MediaCheck, val=progress.media_check) - else: - assert_impossible_literal(kind) - - -def _on_progress(progress_bytes: bytes) -> bool: - progress = pb.Progress() - progress.ParseFromString(progress_bytes) - native_progress = proto_progress_to_native(progress) - return hooks.bg_thread_progress_callback(True, native_progress) + @staticmethod + def from_proto(proto: pb.Progress) -> Progress: + kind = proto.WhichOneof("value") + if kind == "media_sync": + return Progress(kind=ProgressKind.MediaSync, val=proto.media_sync) + elif kind == "media_check": + return Progress(kind=ProgressKind.MediaCheck, val=proto.media_check) + elif kind == "full_sync": + return Progress(kind=ProgressKind.FullSync, val=proto.full_sync) + else: + return Progress(kind=ProgressKind.NoProgress, val="") class RustBackend(RustBackendGenerated): @@ -196,7 +193,6 @@ class RustBackend(RustBackendGenerated): locale_folder_path=ftl_folder, preferred_langs=langs, server=server, ) self._backend = ankirspy.open_backend(init_msg.SerializeToString()) - self._backend.set_progress_callback(_on_progress) def db_query( self, sql: str, args: Sequence[ValueForDB], first_row_only: bool diff --git a/pylib/tools/genhooks.py b/pylib/tools/genhooks.py index 6714706aa..f66d7ab9d 100644 --- a/pylib/tools/genhooks.py +++ b/pylib/tools/genhooks.py @@ -35,12 +35,6 @@ hooks = [ ), Hook(name="sync_stage_did_change", args=["stage: str"], legacy_hook="sync"), Hook(name="sync_progress_did_change", args=["msg: str"], legacy_hook="syncMsg"), - Hook( - name="bg_thread_progress_callback", - args=["proceed: bool", "progress: anki.rsbackend.Progress"], - return_type="bool", - doc="Warning: this is called on a background thread.", - ), Hook( name="field_filter", args=[ diff --git a/qt/aqt/mediacheck.py b/qt/aqt/mediacheck.py index 4de188e9e..643ae1ca1 100644 --- a/qt/aqt/mediacheck.py +++ b/qt/aqt/mediacheck.py @@ -9,8 +9,7 @@ from concurrent.futures import Future from typing import Iterable, List, Optional, Sequence, TypeVar import aqt -from anki import hooks -from anki.rsbackend import TR, Interrupted, Progress, ProgressKind, pb +from anki.rsbackend import TR, Interrupted, ProgressKind, pb from aqt.qt import * from aqt.utils import askUser, restoreGeom, saveGeom, showText, tooltip, tr @@ -36,28 +35,42 @@ class MediaChecker: def __init__(self, mw: aqt.AnkiQt) -> None: self.mw = mw + self._progress_timer: Optional[QTimer] = None def check(self) -> None: self.progress_dialog = self.mw.progress.start() - hooks.bg_thread_progress_callback.append(self._on_progress) + self._set_progress_enabled(True) self.mw.taskman.run_in_background(self._check, self._on_finished) - def _on_progress(self, proceed: bool, progress: Progress) -> bool: - if progress.kind != ProgressKind.MediaCheck: - return proceed + def _set_progress_enabled(self, enabled: bool) -> None: + if self._progress_timer: + self._progress_timer.stop() + self._progress_timer = None + if enabled: + self._progress_timer = self.mw.progress.timer(100, self._on_progress, True) - if self.progress_dialog.wantCancel: - return False + def _on_progress(self) -> None: + progress = self.mw.col.latest_progress() + if progress.kind != ProgressKind.MediaCheck: + return + + assert isinstance(progress.val, str) + + try: + if self.progress_dialog.wantCancel: + self.mw.col.backend.set_wants_abort() + except AttributeError: + # dialog may not be active + pass self.mw.taskman.run_on_main(lambda: self.mw.progress.update(progress.val)) - return True def _check(self) -> pb.CheckMediaOut: "Run the check on a background thread." return self.mw.col.media.check() def _on_finished(self, future: Future) -> None: - hooks.bg_thread_progress_callback.remove(self._on_progress) + self._set_progress_enabled(False) self.mw.progress.finish() self.progress_dialog = None @@ -162,14 +175,14 @@ class MediaChecker: def _on_empty_trash(self): self.progress_dialog = self.mw.progress.start() - hooks.bg_thread_progress_callback.append(self._on_progress) + self._set_progress_enabled(True) def empty_trash(): self.mw.col.backend.empty_trash() def on_done(fut: Future): self.mw.progress.finish() - hooks.bg_thread_progress_callback.remove(self._on_progress) + self._set_progress_enabled(False) # check for errors fut.result() @@ -179,14 +192,14 @@ class MediaChecker: def _on_restore_trash(self): self.progress_dialog = self.mw.progress.start() - hooks.bg_thread_progress_callback.append(self._on_progress) + self._set_progress_enabled(True) def restore_trash(): self.mw.col.backend.restore_trash() def on_done(fut: Future): self.mw.progress.finish() - hooks.bg_thread_progress_callback.remove(self._on_progress) + self._set_progress_enabled(False) # check for errors fut.result() diff --git a/qt/aqt/mediasync.py b/qt/aqt/mediasync.py index d551fb45d..6e79b3c76 100644 --- a/qt/aqt/mediasync.py +++ b/qt/aqt/mediasync.py @@ -6,23 +6,21 @@ from __future__ import annotations import time from concurrent.futures import Future from dataclasses import dataclass -from typing import List, Union +from typing import List, Optional, Union import aqt -from anki import hooks from anki.consts import SYNC_BASE from anki.rsbackend import ( TR, Interrupted, MediaSyncProgress, NetworkError, - Progress, ProgressKind, ) from anki.types import assert_impossible from anki.utils import intTime from aqt import gui_hooks -from aqt.qt import QDialog, QDialogButtonBox, QPushButton, qconnect +from aqt.qt import QDialog, QDialogButtonBox, QPushButton, QTimer, qconnect from aqt.utils import showWarning, tr LogEntry = Union[MediaSyncProgress, str] @@ -39,22 +37,19 @@ class MediaSyncer: self.mw = mw self._syncing: bool = False self._log: List[LogEntryWithTime] = [] - self._want_stop = False - hooks.bg_thread_progress_callback.append(self._on_rust_progress) + self._progress_timer: Optional[QTimer] = None gui_hooks.media_sync_did_start_or_stop.append(self._on_start_stop) - def _on_rust_progress(self, proceed: bool, progress: Progress) -> bool: + def _on_progress(self): + progress = self.mw.col.latest_progress() if progress.kind != ProgressKind.MediaSync: - return proceed + return + + print(progress.val) assert isinstance(progress.val, MediaSyncProgress) self._log_and_notify(progress.val) - if self._want_stop: - return False - else: - return proceed - def start(self) -> None: "Start media syncing in the background, if it's not already running." if self._syncing: @@ -70,7 +65,7 @@ class MediaSyncer: self._log_and_notify(tr(TR.SYNC_MEDIA_STARTING)) self._syncing = True - self._want_stop = False + self._progress_timer = self.mw.progress.timer(1000, self._on_progress, True) gui_hooks.media_sync_did_start_or_stop(True) def run() -> None: @@ -95,6 +90,9 @@ class MediaSyncer: def _on_finished(self, future: Future) -> None: self._syncing = False + if self._progress_timer: + self._progress_timer.stop() + self._progress_timer = None gui_hooks.media_sync_did_start_or_stop(False) exc = future.exception() @@ -122,7 +120,8 @@ class MediaSyncer: if not self.is_syncing(): return self._log_and_notify(tr(TR.SYNC_MEDIA_ABORTING)) - self._want_stop = True + # fixme: latter should do the former for us in the future + self.mw.col.backend.set_wants_abort() self.mw.col.backend.abort_sync() def is_syncing(self) -> bool: diff --git a/rslib/src/backend/mod.rs b/rslib/src/backend/mod.rs index 02f606f82..097c413dd 100644 --- a/rslib/src/backend/mod.rs +++ b/rslib/src/backend/mod.rs @@ -56,20 +56,45 @@ use tokio::runtime::Runtime; mod dbproxy; -pub type ProtoProgressCallback = Box) -> bool + Send>; +struct ThrottlingProgressHandler { + state: Arc>, + last_update: coarsetime::Instant, +} + +impl ThrottlingProgressHandler { + /// Returns true if should continue. + fn update(&mut self, progress: impl Into) -> bool { + let now = coarsetime::Instant::now(); + if now.duration_since(self.last_update).as_f64() < 0.1 { + return false; + } + self.last_update = now; + let mut guard = self.state.lock().unwrap(); + guard.last_progress.replace(progress.into()); + let want_abort = guard.want_abort; + guard.want_abort = false; + !want_abort + } +} + +struct ProgressState { + want_abort: bool, + last_progress: Option, +} pub struct Backend { col: Arc>>, - progress_callback: Option, i18n: I18n, server: bool, sync_abort: Option, + progress_state: Arc>, } -enum Progress<'a> { - MediaSync(&'a MediaSyncProgress), +#[derive(Clone, Copy)] +enum Progress { + MediaSync(MediaSyncProgress), MediaCheck(u32), - FullSync(&'a FullSyncProgress), + FullSync(FullSyncProgress), } /// Convert an Anki error to a protobuf error. @@ -206,6 +231,16 @@ impl From for DeckConfID { } impl BackendService for Backend { + fn latest_progress(&mut self, _input: Empty) -> BackendResult { + let progress = self.progress_state.lock().unwrap().last_progress; + Ok(progress_to_proto(progress, &self.i18n)) + } + + fn set_wants_abort(&mut self, _input: Empty) -> BackendResult { + self.progress_state.lock().unwrap().want_abort = true; + Ok(().into()) + } + // card rendering fn render_existing_card( @@ -801,13 +836,13 @@ impl BackendService for Backend { } fn empty_trash(&mut self, _input: Empty) -> BackendResult { - let callback = - |progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32)); + let mut handler = self.new_progress_handler(); + let progress_fn = move |progress| handler.update(Progress::MediaCheck(progress as u32)); self.with_col(|col| { let mgr = MediaManager::new(&col.media_folder, &col.media_db)?; col.transact(None, |ctx| { - let mut checker = MediaChecker::new(ctx, &mgr, callback); + let mut checker = MediaChecker::new(ctx, &mgr, progress_fn); checker.empty_trash() }) @@ -816,14 +851,13 @@ impl BackendService for Backend { } fn restore_trash(&mut self, _input: Empty) -> BackendResult { - let callback = - |progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32)); - + let mut handler = self.new_progress_handler(); + let progress_fn = move |progress| handler.update(Progress::MediaCheck(progress as u32)); self.with_col(|col| { let mgr = MediaManager::new(&col.media_folder, &col.media_db)?; col.transact(None, |ctx| { - let mut checker = MediaChecker::new(ctx, &mgr, callback); + let mut checker = MediaChecker::new(ctx, &mgr, progress_fn); checker.restore_trash() }) @@ -841,13 +875,12 @@ impl BackendService for Backend { } fn check_media(&mut self, _input: pb::Empty) -> Result { - let callback = - |progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32)); - + let mut handler = self.new_progress_handler(); + let progress_fn = move |progress| handler.update(Progress::MediaCheck(progress as u32)); self.with_col(|col| { let mgr = MediaManager::new(&col.media_folder, &col.media_db)?; col.transact(None, |ctx| { - let mut checker = MediaChecker::new(ctx, &mgr, callback); + let mut checker = MediaChecker::new(ctx, &mgr, progress_fn); let mut output = checker.check()?; let report = checker.summarize_output(&mut output); @@ -1100,10 +1133,13 @@ impl Backend { pub fn new(i18n: I18n, server: bool) -> Backend { Backend { col: Arc::new(Mutex::new(None)), - progress_callback: None, i18n, server, sync_abort: None, + progress_state: Arc::new(Mutex::new(ProgressState { + want_abort: false, + last_progress: None, + })), } } @@ -1140,19 +1176,14 @@ impl Backend { ) } - fn fire_progress_callback(&self, progress: Progress) -> bool { - if let Some(cb) = &self.progress_callback { - let bytes = progress_to_proto_bytes(progress, &self.i18n); - cb(bytes) - } else { - true + fn new_progress_handler(&self) -> ThrottlingProgressHandler { + self.progress_state.lock().unwrap().want_abort = false; + ThrottlingProgressHandler { + state: self.progress_state.clone(), + last_update: coarsetime::Instant::now(), } } - pub fn set_progress_callback(&mut self, progress_cb: Option) { - self.progress_callback = progress_cb; - } - fn sync_media_inner( &mut self, input: pb::SyncMediaIn, @@ -1163,13 +1194,12 @@ impl Backend { let (abort_handle, abort_reg) = AbortHandle::new_pair(); self.sync_abort = Some(abort_handle); - let callback = |progress: &MediaSyncProgress| { - self.fire_progress_callback(Progress::MediaSync(progress)) - }; + let mut handler = self.new_progress_handler(); + let progress_fn = move |progress| handler.update(progress); let mgr = MediaManager::new(&folder, &db)?; let mut rt = Runtime::new().unwrap(); - let sync_fut = mgr.sync_media(callback, &input.endpoint, &input.hkey, log); + let sync_fut = mgr.sync_media(progress_fn, &input.endpoint, &input.hkey, log); let abortable_sync = Abortable::new(sync_fut, abort_reg); let ret = match rt.block_on(abortable_sync) { Ok(sync_result) => sync_result, @@ -1249,18 +1279,17 @@ impl Backend { let media_db_path = col_inner.media_db.clone(); let logger = col_inner.log.clone(); - // FIXME: throttle - let progress_fn = |progress: &FullSyncProgress| { - self.fire_progress_callback(Progress::FullSync(progress)); + let mut handler = self.new_progress_handler(); + let progress_fn = move |progress: FullSyncProgress| { + handler.update(progress); }; let mut rt = Runtime::new().unwrap(); let result = if upload { - todo!() - // let sync_fut = col_inner.full_upload(input.into(), progress_fn); - // let abortable_sync = Abortable::new(sync_fut, abort_reg); - // rt.block_on(abortable_sync) + let sync_fut = col_inner.full_upload(input.into(), progress_fn); + let abortable_sync = Abortable::new(sync_fut, abort_reg); + rt.block_on(abortable_sync) } else { let sync_fut = col_inner.full_download(input.into(), progress_fn); let abortable_sync = Abortable::new(sync_fut, abort_reg); @@ -1337,9 +1366,9 @@ impl From for pb::RenderCardOut { } } -fn progress_to_proto_bytes(progress: Progress, i18n: &I18n) -> Vec { - let proto = pb::Progress { - value: Some(match progress { +fn progress_to_proto(progress: Option, i18n: &I18n) -> pb::Progress { + let progress = if let Some(progress) = progress { + match progress { Progress::MediaSync(p) => pb::progress::Value::MediaSync(media_sync_progress(p, i18n)), Progress::MediaCheck(n) => { let s = i18n.trn(TR::MediaCheckChecked, tr_args!["count"=>n]); @@ -1349,15 +1378,16 @@ fn progress_to_proto_bytes(progress: Progress, i18n: &I18n) -> Vec { transferred: p.transferred_bytes as u32, total: p.total_bytes as u32, }), - }), + } + } else { + pb::progress::Value::None(pb::Empty {}) }; - - let mut buf = vec![]; - proto.encode(&mut buf).expect("encode failed"); - buf + pb::Progress { + value: Some(progress), + } } -fn media_sync_progress(p: &MediaSyncProgress, i18n: &I18n) -> pb::MediaSyncProgress { +fn media_sync_progress(p: MediaSyncProgress, i18n: &I18n) -> pb::MediaSyncProgress { pb::MediaSyncProgress { checked: i18n.trn(TR::SyncMediaCheckedCount, tr_args!["count"=>p.checked]), added: i18n.trn( @@ -1480,3 +1510,15 @@ impl From for SyncAuth { } } } + +impl From for Progress { + fn from(p: FullSyncProgress) -> Self { + Progress::FullSync(p) + } +} + +impl From for Progress { + fn from(p: MediaSyncProgress) -> Self { + Progress::MediaSync(p) + } +} diff --git a/rslib/src/media/check.rs b/rslib/src/media/check.rs index 822b0a4f1..1c2a487e1 100644 --- a/rslib/src/media/check.rs +++ b/rslib/src/media/check.rs @@ -14,7 +14,6 @@ use crate::media::files::{ use crate::notes::Note; use crate::text::{normalize_to_nfc, MediaRef}; use crate::{media::MediaManager, text::extract_media_refs}; -use coarsetime::Instant; use lazy_static::lazy_static; use regex::Regex; use std::collections::{HashMap, HashSet}; @@ -52,7 +51,6 @@ where mgr: &'b MediaManager, progress_cb: P, checked: usize, - progress_updated: Instant, } impl

MediaChecker<'_, '_, P> @@ -69,7 +67,6 @@ where mgr, progress_cb, checked: 0, - progress_updated: Instant::now(), } } @@ -209,7 +206,7 @@ where self.checked += 1; if self.checked % 10 == 0 { - self.maybe_fire_progress_cb()?; + self.fire_progress_cb()?; } // if the filename is not valid unicode, skip it @@ -284,15 +281,6 @@ where } } - fn maybe_fire_progress_cb(&mut self) -> Result<()> { - let now = Instant::now(); - if now.duration_since(self.progress_updated).as_f64() < 0.15 { - return Ok(()); - } - self.progress_updated = now; - self.fire_progress_cb() - } - /// Returns the count and total size of the files in the trash folder fn files_in_trash(&mut self) -> Result<(u64, u64)> { let trash = trash_folder(&self.mgr.media_folder)?; @@ -304,7 +292,7 @@ where self.checked += 1; if self.checked % 10 == 0 { - self.maybe_fire_progress_cb()?; + self.fire_progress_cb()?; } if dentry.file_name() == ".DS_Store" { @@ -328,7 +316,7 @@ where self.checked += 1; if self.checked % 10 == 0 { - self.maybe_fire_progress_cb()?; + self.fire_progress_cb()?; } fs::remove_file(dentry.path())?; @@ -345,7 +333,7 @@ where self.checked += 1; if self.checked % 10 == 0 { - self.maybe_fire_progress_cb()?; + self.fire_progress_cb()?; } let orig_path = self.mgr.media_folder.join(dentry.file_name()); @@ -388,7 +376,7 @@ where for nid in nids { self.checked += 1; if self.checked % 10 == 0 { - self.maybe_fire_progress_cb()?; + self.fire_progress_cb()?; } let mut note = self.ctx.storage.get_note(nid)?.unwrap(); let nt = note_types diff --git a/rslib/src/media/mod.rs b/rslib/src/media/mod.rs index a3351b95c..a6274ecaf 100644 --- a/rslib/src/media/mod.rs +++ b/rslib/src/media/mod.rs @@ -135,7 +135,7 @@ impl MediaManager { log: Logger, ) -> Result<()> where - F: Fn(&MediaSyncProgress) -> bool, + F: FnMut(MediaSyncProgress) -> bool, { let mut syncer = MediaSyncer::new(self, progress, endpoint, log); syncer.sync(hkey).await diff --git a/rslib/src/media/sync.rs b/rslib/src/media/sync.rs index 798f86edc..63a8b45ff 100644 --- a/rslib/src/media/sync.rs +++ b/rslib/src/media/sync.rs @@ -10,7 +10,6 @@ use crate::media::files::{ use crate::media::MediaManager; use crate::version; use bytes::Bytes; -use coarsetime::Instant; use reqwest::{multipart, Client, Response}; use serde_derive::{Deserialize, Serialize}; use serde_tuple::Serialize_tuple; @@ -27,7 +26,7 @@ static SYNC_MAX_FILES: usize = 25; static SYNC_MAX_BYTES: usize = (2.5 * 1024.0 * 1024.0) as usize; static SYNC_SINGLE_FILE_MAX_BYTES: usize = 100 * 1024 * 1024; -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone, Copy)] pub struct MediaSyncProgress { pub checked: usize, pub downloaded_files: usize, @@ -38,7 +37,7 @@ pub struct MediaSyncProgress { pub struct MediaSyncer<'a, P> where - P: Fn(&MediaSyncProgress) -> bool, + P: FnMut(MediaSyncProgress) -> bool, { mgr: &'a MediaManager, ctx: MediaDatabaseContext<'a>, @@ -46,7 +45,6 @@ where client: Client, progress_cb: P, progress: MediaSyncProgress, - progress_updated: Instant, endpoint: &'a str, log: Logger, } @@ -136,7 +134,7 @@ struct FinalizeResponse { impl

MediaSyncer<'_, P> where - P: Fn(&MediaSyncProgress) -> bool, + P: FnMut(MediaSyncProgress) -> bool, { pub fn new<'a>( mgr: &'a MediaManager, @@ -158,7 +156,6 @@ where client, progress_cb, progress: Default::default(), - progress_updated: Instant::now(), endpoint, log, } @@ -221,18 +218,11 @@ where fn register_changes(&mut self) -> Result<()> { // make borrow checker happy let progress = &mut self.progress; - let updated = &mut self.progress_updated; - let progress_cb = &self.progress_cb; + let progress_cb = &mut self.progress_cb; let progress = |checked| { progress.checked = checked; - let now = Instant::now(); - if now.duration_since(*updated).as_secs() < 1 { - true - } else { - *updated = now; - (progress_cb)(progress) - } + (progress_cb)(*progress) }; ChangeTracker::new(self.mgr.media_folder.as_path(), progress, &self.log) @@ -276,7 +266,7 @@ where last_usn = batch.last().unwrap().usn; self.progress.checked += batch.len(); - self.maybe_fire_progress_cb()?; + self.fire_progress_cb()?; let (to_download, to_delete, to_remove_pending) = determine_required_changes(&mut self.ctx, &batch, &self.log)?; @@ -284,7 +274,7 @@ where // file removal self.mgr.remove_files(&mut self.ctx, to_delete.as_slice())?; self.progress.downloaded_deletions += to_delete.len(); - self.maybe_fire_progress_cb()?; + self.fire_progress_cb()?; // file download let mut downloaded = vec![]; @@ -307,7 +297,7 @@ where downloaded.extend(download_batch); self.progress.downloaded_files += len; - self.maybe_fire_progress_cb()?; + self.fire_progress_cb()?; } // then update the DB @@ -339,7 +329,7 @@ where let zip_data = zip_files(&mut self.ctx, &self.mgr.media_folder, &pending, &self.log)?; if zip_data.is_none() { self.progress.checked += pending.len(); - self.maybe_fire_progress_cb()?; + self.fire_progress_cb()?; // discard zip info and retry batch - not particularly efficient, // but this is a corner case continue; @@ -354,7 +344,7 @@ where self.progress.uploaded_files += processed_files.len(); self.progress.uploaded_deletions += processed_deletions.len(); - self.maybe_fire_progress_cb()?; + self.fire_progress_cb()?; let fnames: Vec<_> = processed_files .iter() @@ -407,23 +397,14 @@ where } } - fn fire_progress_cb(&self) -> Result<()> { - if (self.progress_cb)(&self.progress) { + fn fire_progress_cb(&mut self) -> Result<()> { + if (self.progress_cb)(self.progress) { Ok(()) } else { Err(AnkiError::Interrupted) } } - fn maybe_fire_progress_cb(&mut self) -> Result<()> { - let now = Instant::now(); - if now.duration_since(self.progress_updated).as_f64() < 0.15 { - return Ok(()); - } - self.progress_updated = now; - self.fire_progress_cb() - } - async fn fetch_record_batch(&self, last_usn: i32) -> Result> { let url = format!("{}mediaChanges", self.endpoint); @@ -828,7 +809,7 @@ mod test { std::fs::write(media_dir.join("test.file").as_path(), "hello")?; - let progress = |progress: &MediaSyncProgress| { + let progress = |progress: MediaSyncProgress| { println!("got progress: {:?}", progress); true }; diff --git a/rslib/src/sync/http_client.rs b/rslib/src/sync/http_client.rs index 55ca3d672..b91933c97 100644 --- a/rslib/src/sync/http_client.rs +++ b/rslib/src/sync/http_client.rs @@ -76,7 +76,7 @@ struct SanityCheckIn { struct Empty {} impl HTTPSyncClient { - pub fn new<'a>(hkey: Option, host_number: u32) -> HTTPSyncClient { + pub fn new(hkey: Option, host_number: u32) -> HTTPSyncClient { let client = Client::builder() .connect_timeout(Duration::from_secs(30)) .timeout(Duration::from_secs(60)) @@ -170,9 +170,9 @@ impl HTTPSyncClient { local_is_newer: bool, ) -> Result { let input = StartIn { - local_usn: local_usn, + local_usn, minutes_west, - local_is_newer: local_is_newer, + local_is_newer, local_graves: None, }; self.json_request_deserialized("start", &input).await @@ -232,9 +232,13 @@ impl HTTPSyncClient { /// Download collection into a temporary file, returning it. /// Caller should persist the file in the correct path after checking it. - pub(crate) async fn download

(&self, folder: &Path, progress_fn: P) -> Result + pub(crate) async fn download

( + &self, + folder: &Path, + mut progress_fn: P, + ) -> Result where - P: Fn(&FullSyncProgress), + P: FnMut(FullSyncProgress), { let mut temp_file = NamedTempFile::new_in(folder)?; let (size, mut stream) = self.download_inner().await?; @@ -246,7 +250,7 @@ impl HTTPSyncClient { let chunk = chunk?; temp_file.write_all(&chunk)?; progress.transferred_bytes += chunk.len(); - progress_fn(&progress); + progress_fn(progress); } Ok(temp_file) @@ -261,7 +265,7 @@ impl HTTPSyncClient { pub(crate) async fn upload

(&mut self, col_path: &Path, progress_fn: P) -> Result<()> where - P: Fn(&FullSyncProgress) + Send + Sync + 'static, + P: FnMut(FullSyncProgress) + Send + Sync + 'static, { let file = tokio::fs::File::open(col_path).await?; let total_bytes = file.metadata().await?.len() as usize; @@ -300,7 +304,7 @@ struct ProgressWrapper { impl Stream for ProgressWrapper where S: AsyncRead, - P: Fn(&FullSyncProgress), + P: FnMut(FullSyncProgress), { type Item = std::result::Result; @@ -312,7 +316,7 @@ where Ok(size) => { buf.resize(size, 0); this.progress.transferred_bytes += size; - (this.progress_fn)(&this.progress); + (this.progress_fn)(*this.progress); Poll::Ready(Some(Ok(Bytes::from(buf)))) } Err(e) => Poll::Ready(Some(Err(e))), diff --git a/rslib/src/sync/mod.rs b/rslib/src/sync/mod.rs index f2006e957..3f43b73a5 100644 --- a/rslib/src/sync/mod.rs +++ b/rslib/src/sync/mod.rs @@ -179,7 +179,7 @@ pub struct SanityCheckDueCounts { pub review: u32, } -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone, Copy)] pub struct FullSyncProgress { pub transferred_bytes: usize, pub total_bytes: usize, @@ -219,7 +219,7 @@ struct NormalSyncer<'a> { impl NormalSyncer<'_> { /// Create a new syncing instance. If host_number is unavailable, use 0. - pub fn new<'a>(col: &'a mut Collection, auth: SyncAuth) -> NormalSyncer<'a> { + pub fn new(col: &mut Collection, auth: SyncAuth) -> NormalSyncer<'_> { NormalSyncer { col, remote: HTTPSyncClient::new(Some(auth.hkey), auth.host_number), @@ -423,7 +423,7 @@ impl Collection { /// Upload collection to AnkiWeb. Caller must re-open afterwards. pub async fn full_upload(mut self, auth: SyncAuth, progress_fn: F) -> Result<()> where - F: Fn(&FullSyncProgress) + Send + Sync + 'static, + F: FnMut(FullSyncProgress) + Send + Sync + 'static, { self.before_upload()?; let col_path = self.col_path.clone(); @@ -436,7 +436,7 @@ impl Collection { /// Download collection from AnkiWeb. Caller must re-open afterwards. pub async fn full_download(self, auth: SyncAuth, progress_fn: F) -> Result<()> where - F: Fn(&FullSyncProgress), + F: FnMut(FullSyncProgress), { let col_path = self.col_path.clone(); let folder = col_path.parent().unwrap(); @@ -690,7 +690,7 @@ impl Collection { let mut note: Note = entry.into(); let nt = self .get_notetype(note.ntid)? - .ok_or(AnkiError::invalid_input("note missing notetype"))?; + .ok_or_else(|| AnkiError::invalid_input("note missing notetype"))?; note.prepare_for_update(&nt, false)?; self.storage.add_or_update_note(¬e)?; } diff --git a/rspy/src/lib.rs b/rspy/src/lib.rs index c423b27b0..f6cb6e9a0 100644 --- a/rspy/src/lib.rs +++ b/rspy/src/lib.rs @@ -107,6 +107,11 @@ fn want_release_gil(method: u32) -> bool { BackendMethod::NoteIsDuplicateOrEmpty => true, BackendMethod::SyncLogin => true, BackendMethod::SyncCollection => true, + BackendMethod::LatestProgress => false, + BackendMethod::SetWantsAbort => false, + BackendMethod::SyncStatus => true, + BackendMethod::FullUpload => true, + BackendMethod::FullDownload => true, } } else { false @@ -129,35 +134,6 @@ impl Backend { .map_err(|err_bytes| BackendError::py_err(err_bytes)) } - fn set_progress_callback(&mut self, callback: PyObject) { - if callback.is_none() { - self.backend.set_progress_callback(None); - } else { - let func = move |bytes: Vec| { - let gil = Python::acquire_gil(); - let py = gil.python(); - let out_bytes = PyBytes::new(py, &bytes); - let out_obj: PyObject = out_bytes.into(); - let res: PyObject = match callback.call1(py, (out_obj,)) { - Ok(res) => res, - Err(e) => { - println!("error calling callback:"); - e.print(py); - return false; - } - }; - match res.extract(py) { - Ok(cont) => cont, - Err(e) => { - println!("callback did not return bool: {:?}", e); - false - } - } - }; - self.backend.set_progress_callback(Some(Box::new(func))); - } - } - fn db_command(&mut self, py: Python, input: &PyBytes) -> PyResult { let in_bytes = input.as_bytes(); let out_res = py.allow_threads(move || {