rework progress handling

- client now polls status instead of backend pushing it
- supports multiple threads
- update throttling happens in one place
This commit is contained in:
Damien Elmes 2020-05-29 19:59:50 +10:00
parent b254b1f722
commit ee6d7f82e7
14 changed files with 211 additions and 231 deletions

View file

@ -68,6 +68,9 @@ message DeckConfigID {
///////////////////////////////////////////////////////////
service BackendService {
rpc LatestProgress (Empty) returns (Progress);
rpc SetWantsAbort (Empty) returns (Empty);
// card rendering
rpc ExtractAVTags (ExtractAVTagsIn) returns (ExtractAVTagsOut);
@ -448,14 +451,6 @@ message BackendError {
}
}
message Progress {
oneof value {
MediaSyncProgress media_sync = 1;
string media_check = 2;
FullSyncProgress full_sync = 3;
}
}
message NetworkError {
enum NetworkErrorKind {
OTHER = 0;
@ -482,6 +477,18 @@ message SyncError {
SyncErrorKind kind = 1;
}
// Progress
///////////////////////////////////////////////////////////
message Progress {
oneof value {
Empty none = 1;
MediaSyncProgress media_sync = 2;
string media_check = 3;
FullSyncProgress full_sync = 4;
}
}
message MediaSyncProgress {
string checked = 1;
string added = 2;
@ -498,6 +505,7 @@ message MediaSyncUploadProgress {
uint32 deletions = 2;
}
// Messages
///////////////////////////////////////////////////////////

View file

@ -26,7 +26,7 @@ from anki.lang import _
from anki.media import MediaManager, media_paths_from_col_path
from anki.models import ModelManager
from anki.notes import Note
from anki.rsbackend import TR, DBError, FormatTimeSpanContext, RustBackend, pb
from anki.rsbackend import TR, DBError, FormatTimeSpanContext, Progress, RustBackend, pb
from anki.sched import Scheduler as V1Scheduler
from anki.schedv2 import Scheduler as V2Scheduler
from anki.tags import TagManager
@ -91,6 +91,12 @@ class Collection:
) -> str:
return self.backend.format_timespan(seconds=seconds, context=context)
# Progress
##########################################################################
def latest_progress(self) -> Progress:
return Progress.from_proto(self.backend.latest_progress())
# Scheduler
##########################################################################

View file

@ -28,33 +28,6 @@ from anki.notes import Note
# @@AUTOGEN@@
class _BgThreadProgressCallbackFilter:
"""Warning: this is called on a background thread."""
_hooks: List[Callable[[bool, "anki.rsbackend.Progress"], bool]] = []
def append(self, cb: Callable[[bool, "anki.rsbackend.Progress"], bool]) -> None:
"""(proceed: bool, progress: anki.rsbackend.Progress)"""
self._hooks.append(cb)
def remove(self, cb: Callable[[bool, "anki.rsbackend.Progress"], bool]) -> None:
if cb in self._hooks:
self._hooks.remove(cb)
def __call__(self, proceed: bool, progress: anki.rsbackend.Progress) -> bool:
for filter in self._hooks:
try:
proceed = filter(proceed, progress)
except:
# if the hook fails, remove it
self._hooks.remove(filter)
raise
return proceed
bg_thread_progress_callback = _BgThreadProgressCallbackFilter()
class _CardDidLeechHook:
_hooks: List[Callable[[Card], None]] = []

View file

@ -152,31 +152,28 @@ FormatTimeSpanContext = pb.FormatTimespanIn.Context
class ProgressKind(enum.Enum):
MediaSync = 0
MediaCheck = 1
NoProgress = 0
MediaSync = 1
MediaCheck = 2
FullSync = 3
@dataclass
class Progress:
kind: ProgressKind
val: Union[MediaSyncProgress, str]
val: Union[MediaSyncProgress, pb.FullSyncProgress, str]
def proto_progress_to_native(progress: pb.Progress) -> Progress:
kind = progress.WhichOneof("value")
@staticmethod
def from_proto(proto: pb.Progress) -> Progress:
kind = proto.WhichOneof("value")
if kind == "media_sync":
return Progress(kind=ProgressKind.MediaSync, val=progress.media_sync)
return Progress(kind=ProgressKind.MediaSync, val=proto.media_sync)
elif kind == "media_check":
return Progress(kind=ProgressKind.MediaCheck, val=progress.media_check)
return Progress(kind=ProgressKind.MediaCheck, val=proto.media_check)
elif kind == "full_sync":
return Progress(kind=ProgressKind.FullSync, val=proto.full_sync)
else:
assert_impossible_literal(kind)
def _on_progress(progress_bytes: bytes) -> bool:
progress = pb.Progress()
progress.ParseFromString(progress_bytes)
native_progress = proto_progress_to_native(progress)
return hooks.bg_thread_progress_callback(True, native_progress)
return Progress(kind=ProgressKind.NoProgress, val="")
class RustBackend(RustBackendGenerated):
@ -196,7 +193,6 @@ class RustBackend(RustBackendGenerated):
locale_folder_path=ftl_folder, preferred_langs=langs, server=server,
)
self._backend = ankirspy.open_backend(init_msg.SerializeToString())
self._backend.set_progress_callback(_on_progress)
def db_query(
self, sql: str, args: Sequence[ValueForDB], first_row_only: bool

View file

@ -35,12 +35,6 @@ hooks = [
),
Hook(name="sync_stage_did_change", args=["stage: str"], legacy_hook="sync"),
Hook(name="sync_progress_did_change", args=["msg: str"], legacy_hook="syncMsg"),
Hook(
name="bg_thread_progress_callback",
args=["proceed: bool", "progress: anki.rsbackend.Progress"],
return_type="bool",
doc="Warning: this is called on a background thread.",
),
Hook(
name="field_filter",
args=[

View file

@ -9,8 +9,7 @@ from concurrent.futures import Future
from typing import Iterable, List, Optional, Sequence, TypeVar
import aqt
from anki import hooks
from anki.rsbackend import TR, Interrupted, Progress, ProgressKind, pb
from anki.rsbackend import TR, Interrupted, ProgressKind, pb
from aqt.qt import *
from aqt.utils import askUser, restoreGeom, saveGeom, showText, tooltip, tr
@ -36,28 +35,42 @@ class MediaChecker:
def __init__(self, mw: aqt.AnkiQt) -> None:
self.mw = mw
self._progress_timer: Optional[QTimer] = None
def check(self) -> None:
self.progress_dialog = self.mw.progress.start()
hooks.bg_thread_progress_callback.append(self._on_progress)
self._set_progress_enabled(True)
self.mw.taskman.run_in_background(self._check, self._on_finished)
def _on_progress(self, proceed: bool, progress: Progress) -> bool:
if progress.kind != ProgressKind.MediaCheck:
return proceed
def _set_progress_enabled(self, enabled: bool) -> None:
if self._progress_timer:
self._progress_timer.stop()
self._progress_timer = None
if enabled:
self._progress_timer = self.mw.progress.timer(100, self._on_progress, True)
def _on_progress(self) -> None:
progress = self.mw.col.latest_progress()
if progress.kind != ProgressKind.MediaCheck:
return
assert isinstance(progress.val, str)
try:
if self.progress_dialog.wantCancel:
return False
self.mw.col.backend.set_wants_abort()
except AttributeError:
# dialog may not be active
pass
self.mw.taskman.run_on_main(lambda: self.mw.progress.update(progress.val))
return True
def _check(self) -> pb.CheckMediaOut:
"Run the check on a background thread."
return self.mw.col.media.check()
def _on_finished(self, future: Future) -> None:
hooks.bg_thread_progress_callback.remove(self._on_progress)
self._set_progress_enabled(False)
self.mw.progress.finish()
self.progress_dialog = None
@ -162,14 +175,14 @@ class MediaChecker:
def _on_empty_trash(self):
self.progress_dialog = self.mw.progress.start()
hooks.bg_thread_progress_callback.append(self._on_progress)
self._set_progress_enabled(True)
def empty_trash():
self.mw.col.backend.empty_trash()
def on_done(fut: Future):
self.mw.progress.finish()
hooks.bg_thread_progress_callback.remove(self._on_progress)
self._set_progress_enabled(False)
# check for errors
fut.result()
@ -179,14 +192,14 @@ class MediaChecker:
def _on_restore_trash(self):
self.progress_dialog = self.mw.progress.start()
hooks.bg_thread_progress_callback.append(self._on_progress)
self._set_progress_enabled(True)
def restore_trash():
self.mw.col.backend.restore_trash()
def on_done(fut: Future):
self.mw.progress.finish()
hooks.bg_thread_progress_callback.remove(self._on_progress)
self._set_progress_enabled(False)
# check for errors
fut.result()

View file

@ -6,23 +6,21 @@ from __future__ import annotations
import time
from concurrent.futures import Future
from dataclasses import dataclass
from typing import List, Union
from typing import List, Optional, Union
import aqt
from anki import hooks
from anki.consts import SYNC_BASE
from anki.rsbackend import (
TR,
Interrupted,
MediaSyncProgress,
NetworkError,
Progress,
ProgressKind,
)
from anki.types import assert_impossible
from anki.utils import intTime
from aqt import gui_hooks
from aqt.qt import QDialog, QDialogButtonBox, QPushButton, qconnect
from aqt.qt import QDialog, QDialogButtonBox, QPushButton, QTimer, qconnect
from aqt.utils import showWarning, tr
LogEntry = Union[MediaSyncProgress, str]
@ -39,22 +37,19 @@ class MediaSyncer:
self.mw = mw
self._syncing: bool = False
self._log: List[LogEntryWithTime] = []
self._want_stop = False
hooks.bg_thread_progress_callback.append(self._on_rust_progress)
self._progress_timer: Optional[QTimer] = None
gui_hooks.media_sync_did_start_or_stop.append(self._on_start_stop)
def _on_rust_progress(self, proceed: bool, progress: Progress) -> bool:
def _on_progress(self):
progress = self.mw.col.latest_progress()
if progress.kind != ProgressKind.MediaSync:
return proceed
return
print(progress.val)
assert isinstance(progress.val, MediaSyncProgress)
self._log_and_notify(progress.val)
if self._want_stop:
return False
else:
return proceed
def start(self) -> None:
"Start media syncing in the background, if it's not already running."
if self._syncing:
@ -70,7 +65,7 @@ class MediaSyncer:
self._log_and_notify(tr(TR.SYNC_MEDIA_STARTING))
self._syncing = True
self._want_stop = False
self._progress_timer = self.mw.progress.timer(1000, self._on_progress, True)
gui_hooks.media_sync_did_start_or_stop(True)
def run() -> None:
@ -95,6 +90,9 @@ class MediaSyncer:
def _on_finished(self, future: Future) -> None:
self._syncing = False
if self._progress_timer:
self._progress_timer.stop()
self._progress_timer = None
gui_hooks.media_sync_did_start_or_stop(False)
exc = future.exception()
@ -122,7 +120,8 @@ class MediaSyncer:
if not self.is_syncing():
return
self._log_and_notify(tr(TR.SYNC_MEDIA_ABORTING))
self._want_stop = True
# fixme: latter should do the former for us in the future
self.mw.col.backend.set_wants_abort()
self.mw.col.backend.abort_sync()
def is_syncing(self) -> bool:

View file

@ -56,20 +56,45 @@ use tokio::runtime::Runtime;
mod dbproxy;
pub type ProtoProgressCallback = Box<dyn Fn(Vec<u8>) -> bool + Send>;
struct ThrottlingProgressHandler {
state: Arc<Mutex<ProgressState>>,
last_update: coarsetime::Instant,
}
impl ThrottlingProgressHandler {
/// Returns true if should continue.
fn update(&mut self, progress: impl Into<Progress>) -> bool {
let now = coarsetime::Instant::now();
if now.duration_since(self.last_update).as_f64() < 0.1 {
return false;
}
self.last_update = now;
let mut guard = self.state.lock().unwrap();
guard.last_progress.replace(progress.into());
let want_abort = guard.want_abort;
guard.want_abort = false;
!want_abort
}
}
struct ProgressState {
want_abort: bool,
last_progress: Option<Progress>,
}
pub struct Backend {
col: Arc<Mutex<Option<Collection>>>,
progress_callback: Option<ProtoProgressCallback>,
i18n: I18n,
server: bool,
sync_abort: Option<AbortHandle>,
progress_state: Arc<Mutex<ProgressState>>,
}
enum Progress<'a> {
MediaSync(&'a MediaSyncProgress),
#[derive(Clone, Copy)]
enum Progress {
MediaSync(MediaSyncProgress),
MediaCheck(u32),
FullSync(&'a FullSyncProgress),
FullSync(FullSyncProgress),
}
/// Convert an Anki error to a protobuf error.
@ -206,6 +231,16 @@ impl From<pb::DeckConfigId> for DeckConfID {
}
impl BackendService for Backend {
fn latest_progress(&mut self, _input: Empty) -> BackendResult<pb::Progress> {
let progress = self.progress_state.lock().unwrap().last_progress;
Ok(progress_to_proto(progress, &self.i18n))
}
fn set_wants_abort(&mut self, _input: Empty) -> BackendResult<Empty> {
self.progress_state.lock().unwrap().want_abort = true;
Ok(().into())
}
// card rendering
fn render_existing_card(
@ -801,13 +836,13 @@ impl BackendService for Backend {
}
fn empty_trash(&mut self, _input: Empty) -> BackendResult<Empty> {
let callback =
|progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32));
let mut handler = self.new_progress_handler();
let progress_fn = move |progress| handler.update(Progress::MediaCheck(progress as u32));
self.with_col(|col| {
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, callback);
let mut checker = MediaChecker::new(ctx, &mgr, progress_fn);
checker.empty_trash()
})
@ -816,14 +851,13 @@ impl BackendService for Backend {
}
fn restore_trash(&mut self, _input: Empty) -> BackendResult<Empty> {
let callback =
|progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32));
let mut handler = self.new_progress_handler();
let progress_fn = move |progress| handler.update(Progress::MediaCheck(progress as u32));
self.with_col(|col| {
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, callback);
let mut checker = MediaChecker::new(ctx, &mgr, progress_fn);
checker.restore_trash()
})
@ -841,13 +875,12 @@ impl BackendService for Backend {
}
fn check_media(&mut self, _input: pb::Empty) -> Result<pb::CheckMediaOut> {
let callback =
|progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32));
let mut handler = self.new_progress_handler();
let progress_fn = move |progress| handler.update(Progress::MediaCheck(progress as u32));
self.with_col(|col| {
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
col.transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, callback);
let mut checker = MediaChecker::new(ctx, &mgr, progress_fn);
let mut output = checker.check()?;
let report = checker.summarize_output(&mut output);
@ -1100,10 +1133,13 @@ impl Backend {
pub fn new(i18n: I18n, server: bool) -> Backend {
Backend {
col: Arc::new(Mutex::new(None)),
progress_callback: None,
i18n,
server,
sync_abort: None,
progress_state: Arc::new(Mutex::new(ProgressState {
want_abort: false,
last_progress: None,
})),
}
}
@ -1140,19 +1176,14 @@ impl Backend {
)
}
fn fire_progress_callback(&self, progress: Progress) -> bool {
if let Some(cb) = &self.progress_callback {
let bytes = progress_to_proto_bytes(progress, &self.i18n);
cb(bytes)
} else {
true
fn new_progress_handler(&self) -> ThrottlingProgressHandler {
self.progress_state.lock().unwrap().want_abort = false;
ThrottlingProgressHandler {
state: self.progress_state.clone(),
last_update: coarsetime::Instant::now(),
}
}
pub fn set_progress_callback(&mut self, progress_cb: Option<ProtoProgressCallback>) {
self.progress_callback = progress_cb;
}
fn sync_media_inner(
&mut self,
input: pb::SyncMediaIn,
@ -1163,13 +1194,12 @@ impl Backend {
let (abort_handle, abort_reg) = AbortHandle::new_pair();
self.sync_abort = Some(abort_handle);
let callback = |progress: &MediaSyncProgress| {
self.fire_progress_callback(Progress::MediaSync(progress))
};
let mut handler = self.new_progress_handler();
let progress_fn = move |progress| handler.update(progress);
let mgr = MediaManager::new(&folder, &db)?;
let mut rt = Runtime::new().unwrap();
let sync_fut = mgr.sync_media(callback, &input.endpoint, &input.hkey, log);
let sync_fut = mgr.sync_media(progress_fn, &input.endpoint, &input.hkey, log);
let abortable_sync = Abortable::new(sync_fut, abort_reg);
let ret = match rt.block_on(abortable_sync) {
Ok(sync_result) => sync_result,
@ -1249,18 +1279,17 @@ impl Backend {
let media_db_path = col_inner.media_db.clone();
let logger = col_inner.log.clone();
// FIXME: throttle
let progress_fn = |progress: &FullSyncProgress| {
self.fire_progress_callback(Progress::FullSync(progress));
let mut handler = self.new_progress_handler();
let progress_fn = move |progress: FullSyncProgress| {
handler.update(progress);
};
let mut rt = Runtime::new().unwrap();
let result = if upload {
todo!()
// let sync_fut = col_inner.full_upload(input.into(), progress_fn);
// let abortable_sync = Abortable::new(sync_fut, abort_reg);
// rt.block_on(abortable_sync)
let sync_fut = col_inner.full_upload(input.into(), progress_fn);
let abortable_sync = Abortable::new(sync_fut, abort_reg);
rt.block_on(abortable_sync)
} else {
let sync_fut = col_inner.full_download(input.into(), progress_fn);
let abortable_sync = Abortable::new(sync_fut, abort_reg);
@ -1337,9 +1366,9 @@ impl From<RenderCardOutput> for pb::RenderCardOut {
}
}
fn progress_to_proto_bytes(progress: Progress, i18n: &I18n) -> Vec<u8> {
let proto = pb::Progress {
value: Some(match progress {
fn progress_to_proto(progress: Option<Progress>, i18n: &I18n) -> pb::Progress {
let progress = if let Some(progress) = progress {
match progress {
Progress::MediaSync(p) => pb::progress::Value::MediaSync(media_sync_progress(p, i18n)),
Progress::MediaCheck(n) => {
let s = i18n.trn(TR::MediaCheckChecked, tr_args!["count"=>n]);
@ -1349,15 +1378,16 @@ fn progress_to_proto_bytes(progress: Progress, i18n: &I18n) -> Vec<u8> {
transferred: p.transferred_bytes as u32,
total: p.total_bytes as u32,
}),
}),
}
} else {
pb::progress::Value::None(pb::Empty {})
};
let mut buf = vec![];
proto.encode(&mut buf).expect("encode failed");
buf
pb::Progress {
value: Some(progress),
}
}
fn media_sync_progress(p: &MediaSyncProgress, i18n: &I18n) -> pb::MediaSyncProgress {
fn media_sync_progress(p: MediaSyncProgress, i18n: &I18n) -> pb::MediaSyncProgress {
pb::MediaSyncProgress {
checked: i18n.trn(TR::SyncMediaCheckedCount, tr_args!["count"=>p.checked]),
added: i18n.trn(
@ -1480,3 +1510,15 @@ impl From<pb::SyncAuth> for SyncAuth {
}
}
}
impl From<FullSyncProgress> for Progress {
fn from(p: FullSyncProgress) -> Self {
Progress::FullSync(p)
}
}
impl From<MediaSyncProgress> for Progress {
fn from(p: MediaSyncProgress) -> Self {
Progress::MediaSync(p)
}
}

View file

@ -14,7 +14,6 @@ use crate::media::files::{
use crate::notes::Note;
use crate::text::{normalize_to_nfc, MediaRef};
use crate::{media::MediaManager, text::extract_media_refs};
use coarsetime::Instant;
use lazy_static::lazy_static;
use regex::Regex;
use std::collections::{HashMap, HashSet};
@ -52,7 +51,6 @@ where
mgr: &'b MediaManager,
progress_cb: P,
checked: usize,
progress_updated: Instant,
}
impl<P> MediaChecker<'_, '_, P>
@ -69,7 +67,6 @@ where
mgr,
progress_cb,
checked: 0,
progress_updated: Instant::now(),
}
}
@ -209,7 +206,7 @@ where
self.checked += 1;
if self.checked % 10 == 0 {
self.maybe_fire_progress_cb()?;
self.fire_progress_cb()?;
}
// if the filename is not valid unicode, skip it
@ -284,15 +281,6 @@ where
}
}
fn maybe_fire_progress_cb(&mut self) -> Result<()> {
let now = Instant::now();
if now.duration_since(self.progress_updated).as_f64() < 0.15 {
return Ok(());
}
self.progress_updated = now;
self.fire_progress_cb()
}
/// Returns the count and total size of the files in the trash folder
fn files_in_trash(&mut self) -> Result<(u64, u64)> {
let trash = trash_folder(&self.mgr.media_folder)?;
@ -304,7 +292,7 @@ where
self.checked += 1;
if self.checked % 10 == 0 {
self.maybe_fire_progress_cb()?;
self.fire_progress_cb()?;
}
if dentry.file_name() == ".DS_Store" {
@ -328,7 +316,7 @@ where
self.checked += 1;
if self.checked % 10 == 0 {
self.maybe_fire_progress_cb()?;
self.fire_progress_cb()?;
}
fs::remove_file(dentry.path())?;
@ -345,7 +333,7 @@ where
self.checked += 1;
if self.checked % 10 == 0 {
self.maybe_fire_progress_cb()?;
self.fire_progress_cb()?;
}
let orig_path = self.mgr.media_folder.join(dentry.file_name());
@ -388,7 +376,7 @@ where
for nid in nids {
self.checked += 1;
if self.checked % 10 == 0 {
self.maybe_fire_progress_cb()?;
self.fire_progress_cb()?;
}
let mut note = self.ctx.storage.get_note(nid)?.unwrap();
let nt = note_types

View file

@ -135,7 +135,7 @@ impl MediaManager {
log: Logger,
) -> Result<()>
where
F: Fn(&MediaSyncProgress) -> bool,
F: FnMut(MediaSyncProgress) -> bool,
{
let mut syncer = MediaSyncer::new(self, progress, endpoint, log);
syncer.sync(hkey).await

View file

@ -10,7 +10,6 @@ use crate::media::files::{
use crate::media::MediaManager;
use crate::version;
use bytes::Bytes;
use coarsetime::Instant;
use reqwest::{multipart, Client, Response};
use serde_derive::{Deserialize, Serialize};
use serde_tuple::Serialize_tuple;
@ -27,7 +26,7 @@ static SYNC_MAX_FILES: usize = 25;
static SYNC_MAX_BYTES: usize = (2.5 * 1024.0 * 1024.0) as usize;
static SYNC_SINGLE_FILE_MAX_BYTES: usize = 100 * 1024 * 1024;
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone, Copy)]
pub struct MediaSyncProgress {
pub checked: usize,
pub downloaded_files: usize,
@ -38,7 +37,7 @@ pub struct MediaSyncProgress {
pub struct MediaSyncer<'a, P>
where
P: Fn(&MediaSyncProgress) -> bool,
P: FnMut(MediaSyncProgress) -> bool,
{
mgr: &'a MediaManager,
ctx: MediaDatabaseContext<'a>,
@ -46,7 +45,6 @@ where
client: Client,
progress_cb: P,
progress: MediaSyncProgress,
progress_updated: Instant,
endpoint: &'a str,
log: Logger,
}
@ -136,7 +134,7 @@ struct FinalizeResponse {
impl<P> MediaSyncer<'_, P>
where
P: Fn(&MediaSyncProgress) -> bool,
P: FnMut(MediaSyncProgress) -> bool,
{
pub fn new<'a>(
mgr: &'a MediaManager,
@ -158,7 +156,6 @@ where
client,
progress_cb,
progress: Default::default(),
progress_updated: Instant::now(),
endpoint,
log,
}
@ -221,18 +218,11 @@ where
fn register_changes(&mut self) -> Result<()> {
// make borrow checker happy
let progress = &mut self.progress;
let updated = &mut self.progress_updated;
let progress_cb = &self.progress_cb;
let progress_cb = &mut self.progress_cb;
let progress = |checked| {
progress.checked = checked;
let now = Instant::now();
if now.duration_since(*updated).as_secs() < 1 {
true
} else {
*updated = now;
(progress_cb)(progress)
}
(progress_cb)(*progress)
};
ChangeTracker::new(self.mgr.media_folder.as_path(), progress, &self.log)
@ -276,7 +266,7 @@ where
last_usn = batch.last().unwrap().usn;
self.progress.checked += batch.len();
self.maybe_fire_progress_cb()?;
self.fire_progress_cb()?;
let (to_download, to_delete, to_remove_pending) =
determine_required_changes(&mut self.ctx, &batch, &self.log)?;
@ -284,7 +274,7 @@ where
// file removal
self.mgr.remove_files(&mut self.ctx, to_delete.as_slice())?;
self.progress.downloaded_deletions += to_delete.len();
self.maybe_fire_progress_cb()?;
self.fire_progress_cb()?;
// file download
let mut downloaded = vec![];
@ -307,7 +297,7 @@ where
downloaded.extend(download_batch);
self.progress.downloaded_files += len;
self.maybe_fire_progress_cb()?;
self.fire_progress_cb()?;
}
// then update the DB
@ -339,7 +329,7 @@ where
let zip_data = zip_files(&mut self.ctx, &self.mgr.media_folder, &pending, &self.log)?;
if zip_data.is_none() {
self.progress.checked += pending.len();
self.maybe_fire_progress_cb()?;
self.fire_progress_cb()?;
// discard zip info and retry batch - not particularly efficient,
// but this is a corner case
continue;
@ -354,7 +344,7 @@ where
self.progress.uploaded_files += processed_files.len();
self.progress.uploaded_deletions += processed_deletions.len();
self.maybe_fire_progress_cb()?;
self.fire_progress_cb()?;
let fnames: Vec<_> = processed_files
.iter()
@ -407,23 +397,14 @@ where
}
}
fn fire_progress_cb(&self) -> Result<()> {
if (self.progress_cb)(&self.progress) {
fn fire_progress_cb(&mut self) -> Result<()> {
if (self.progress_cb)(self.progress) {
Ok(())
} else {
Err(AnkiError::Interrupted)
}
}
fn maybe_fire_progress_cb(&mut self) -> Result<()> {
let now = Instant::now();
if now.duration_since(self.progress_updated).as_f64() < 0.15 {
return Ok(());
}
self.progress_updated = now;
self.fire_progress_cb()
}
async fn fetch_record_batch(&self, last_usn: i32) -> Result<Vec<ServerMediaRecord>> {
let url = format!("{}mediaChanges", self.endpoint);
@ -828,7 +809,7 @@ mod test {
std::fs::write(media_dir.join("test.file").as_path(), "hello")?;
let progress = |progress: &MediaSyncProgress| {
let progress = |progress: MediaSyncProgress| {
println!("got progress: {:?}", progress);
true
};

View file

@ -76,7 +76,7 @@ struct SanityCheckIn {
struct Empty {}
impl HTTPSyncClient {
pub fn new<'a>(hkey: Option<String>, host_number: u32) -> HTTPSyncClient {
pub fn new(hkey: Option<String>, host_number: u32) -> HTTPSyncClient {
let client = Client::builder()
.connect_timeout(Duration::from_secs(30))
.timeout(Duration::from_secs(60))
@ -170,9 +170,9 @@ impl HTTPSyncClient {
local_is_newer: bool,
) -> Result<Graves> {
let input = StartIn {
local_usn: local_usn,
local_usn,
minutes_west,
local_is_newer: local_is_newer,
local_is_newer,
local_graves: None,
};
self.json_request_deserialized("start", &input).await
@ -232,9 +232,13 @@ impl HTTPSyncClient {
/// Download collection into a temporary file, returning it.
/// Caller should persist the file in the correct path after checking it.
pub(crate) async fn download<P>(&self, folder: &Path, progress_fn: P) -> Result<NamedTempFile>
pub(crate) async fn download<P>(
&self,
folder: &Path,
mut progress_fn: P,
) -> Result<NamedTempFile>
where
P: Fn(&FullSyncProgress),
P: FnMut(FullSyncProgress),
{
let mut temp_file = NamedTempFile::new_in(folder)?;
let (size, mut stream) = self.download_inner().await?;
@ -246,7 +250,7 @@ impl HTTPSyncClient {
let chunk = chunk?;
temp_file.write_all(&chunk)?;
progress.transferred_bytes += chunk.len();
progress_fn(&progress);
progress_fn(progress);
}
Ok(temp_file)
@ -261,7 +265,7 @@ impl HTTPSyncClient {
pub(crate) async fn upload<P>(&mut self, col_path: &Path, progress_fn: P) -> Result<()>
where
P: Fn(&FullSyncProgress) + Send + Sync + 'static,
P: FnMut(FullSyncProgress) + Send + Sync + 'static,
{
let file = tokio::fs::File::open(col_path).await?;
let total_bytes = file.metadata().await?.len() as usize;
@ -300,7 +304,7 @@ struct ProgressWrapper<S, P> {
impl<S, P> Stream for ProgressWrapper<S, P>
where
S: AsyncRead,
P: Fn(&FullSyncProgress),
P: FnMut(FullSyncProgress),
{
type Item = std::result::Result<Bytes, std::io::Error>;
@ -312,7 +316,7 @@ where
Ok(size) => {
buf.resize(size, 0);
this.progress.transferred_bytes += size;
(this.progress_fn)(&this.progress);
(this.progress_fn)(*this.progress);
Poll::Ready(Some(Ok(Bytes::from(buf))))
}
Err(e) => Poll::Ready(Some(Err(e))),

View file

@ -179,7 +179,7 @@ pub struct SanityCheckDueCounts {
pub review: u32,
}
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone, Copy)]
pub struct FullSyncProgress {
pub transferred_bytes: usize,
pub total_bytes: usize,
@ -219,7 +219,7 @@ struct NormalSyncer<'a> {
impl NormalSyncer<'_> {
/// Create a new syncing instance. If host_number is unavailable, use 0.
pub fn new<'a>(col: &'a mut Collection, auth: SyncAuth) -> NormalSyncer<'a> {
pub fn new(col: &mut Collection, auth: SyncAuth) -> NormalSyncer<'_> {
NormalSyncer {
col,
remote: HTTPSyncClient::new(Some(auth.hkey), auth.host_number),
@ -423,7 +423,7 @@ impl Collection {
/// Upload collection to AnkiWeb. Caller must re-open afterwards.
pub async fn full_upload<F>(mut self, auth: SyncAuth, progress_fn: F) -> Result<()>
where
F: Fn(&FullSyncProgress) + Send + Sync + 'static,
F: FnMut(FullSyncProgress) + Send + Sync + 'static,
{
self.before_upload()?;
let col_path = self.col_path.clone();
@ -436,7 +436,7 @@ impl Collection {
/// Download collection from AnkiWeb. Caller must re-open afterwards.
pub async fn full_download<F>(self, auth: SyncAuth, progress_fn: F) -> Result<()>
where
F: Fn(&FullSyncProgress),
F: FnMut(FullSyncProgress),
{
let col_path = self.col_path.clone();
let folder = col_path.parent().unwrap();
@ -690,7 +690,7 @@ impl Collection {
let mut note: Note = entry.into();
let nt = self
.get_notetype(note.ntid)?
.ok_or(AnkiError::invalid_input("note missing notetype"))?;
.ok_or_else(|| AnkiError::invalid_input("note missing notetype"))?;
note.prepare_for_update(&nt, false)?;
self.storage.add_or_update_note(&note)?;
}

View file

@ -107,6 +107,11 @@ fn want_release_gil(method: u32) -> bool {
BackendMethod::NoteIsDuplicateOrEmpty => true,
BackendMethod::SyncLogin => true,
BackendMethod::SyncCollection => true,
BackendMethod::LatestProgress => false,
BackendMethod::SetWantsAbort => false,
BackendMethod::SyncStatus => true,
BackendMethod::FullUpload => true,
BackendMethod::FullDownload => true,
}
} else {
false
@ -129,35 +134,6 @@ impl Backend {
.map_err(|err_bytes| BackendError::py_err(err_bytes))
}
fn set_progress_callback(&mut self, callback: PyObject) {
if callback.is_none() {
self.backend.set_progress_callback(None);
} else {
let func = move |bytes: Vec<u8>| {
let gil = Python::acquire_gil();
let py = gil.python();
let out_bytes = PyBytes::new(py, &bytes);
let out_obj: PyObject = out_bytes.into();
let res: PyObject = match callback.call1(py, (out_obj,)) {
Ok(res) => res,
Err(e) => {
println!("error calling callback:");
e.print(py);
return false;
}
};
match res.extract(py) {
Ok(cont) => cont,
Err(e) => {
println!("callback did not return bool: {:?}", e);
false
}
}
};
self.backend.set_progress_callback(Some(Box::new(func)));
}
}
fn db_command(&mut self, py: Python, input: &PyBytes) -> PyResult<PyObject> {
let in_bytes = input.as_bytes();
let out_res = py.allow_threads(move || {