simplify code by accumulating stats at the source

This commit is contained in:
Damien Elmes 2020-02-06 18:16:39 +10:00
parent 4289f7a02a
commit 5fe1bfc5b4
6 changed files with 75 additions and 139 deletions

View file

@ -104,12 +104,11 @@ message SyncError {
}
message MediaSyncProgress {
oneof value {
uint32 downloaded_changes = 1;
uint32 downloaded_files = 2;
MediaSyncUploadProgress uploaded = 3;
uint32 removed_files = 4;
}
uint32 downloaded_meta = 1;
uint32 downloaded_files = 2;
uint32 downloaded_deletions = 3;
uint32 uploaded_files = 4;
uint32 uploaded_deletions = 5;
}
message MediaSyncUploadProgress {

View file

@ -121,33 +121,7 @@ class TemplateReplacement:
TemplateReplacementList = List[Union[str, TemplateReplacement]]
@dataclass
class MediaSyncDownloadedChanges:
changes: int
@dataclass
class MediaSyncDownloadedFiles:
files: int
@dataclass
class MediaSyncUploaded:
files: int
deletions: int
@dataclass
class MediaSyncRemovedFiles:
files: int
MediaSyncProgress = Union[
MediaSyncDownloadedChanges,
MediaSyncDownloadedFiles,
MediaSyncUploaded,
MediaSyncRemovedFiles,
]
MediaSyncProgress = pb.MediaSyncProgress
class ProgressKind(enum.Enum):
@ -181,31 +155,9 @@ def proto_replacement_list_to_native(
def proto_progress_to_native(progress: pb.Progress) -> Progress:
kind = progress.WhichOneof("value")
if kind == "media_sync":
ikind = progress.media_sync.WhichOneof("value")
pkind = ProgressKind.MediaSyncProgress
if ikind == "downloaded_changes":
return Progress(
kind=pkind,
val=MediaSyncDownloadedChanges(progress.media_sync.downloaded_changes),
)
elif ikind == "downloaded_files":
return Progress(
kind=pkind,
val=MediaSyncDownloadedFiles(progress.media_sync.downloaded_files),
)
elif ikind == "uploaded":
up = progress.media_sync.uploaded
return Progress(
kind=pkind,
val=MediaSyncUploaded(files=up.files, deletions=up.deletions),
)
elif ikind == "removed_files":
return Progress(
kind=pkind, val=MediaSyncRemovedFiles(progress.media_sync.removed_files)
)
else:
assert_impossible_literal(ikind)
assert_impossible_literal(kind)
return Progress(kind=ProgressKind.MediaSyncProgress, val=progress.media_sync)
else:
assert_impossible_literal(kind)
class RustBackend:

View file

@ -5,9 +5,8 @@ from __future__ import annotations
import time
from concurrent.futures import Future
from copy import copy
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import List, Union
import aqt
from anki import hooks
@ -16,11 +15,7 @@ from anki.media import media_paths_from_col_path
from anki.rsbackend import (
DBError,
Interrupted,
MediaSyncDownloadedChanges,
MediaSyncDownloadedFiles,
MediaSyncProgress,
MediaSyncRemovedFiles,
MediaSyncUploaded,
NetworkError,
NetworkErrorKind,
Progress,
@ -34,17 +29,7 @@ from aqt import gui_hooks
from aqt.qt import QDialog, QDialogButtonBox, QPushButton
from aqt.utils import showWarning
@dataclass
class MediaSyncState:
downloaded_changes: int = 0
downloaded_files: int = 0
uploaded_files: int = 0
uploaded_removals: int = 0
removed_files: int = 0
LogEntry = Union[MediaSyncState, str]
LogEntry = Union[MediaSyncProgress, str]
@dataclass
@ -56,7 +41,7 @@ class LogEntryWithTime:
class MediaSyncer:
def __init__(self, mw: aqt.main.AnkiQt):
self.mw = mw
self._sync_state: Optional[MediaSyncState] = None
self._syncing: bool = False
self._log: List[LogEntryWithTime] = []
self._want_stop = False
hooks.rust_progress_callback.append(self._on_rust_progress)
@ -66,28 +51,16 @@ class MediaSyncer:
if progress.kind != ProgressKind.MediaSyncProgress:
return proceed
self._update_state(progress.val)
self._log_and_notify(copy(self._sync_state))
self._log_and_notify(progress.val)
if self._want_stop:
return False
else:
return proceed
def _update_state(self, progress: MediaSyncProgress) -> None:
if isinstance(progress, MediaSyncDownloadedChanges):
self._sync_state.downloaded_changes += progress.changes
elif isinstance(progress, MediaSyncDownloadedFiles):
self._sync_state.downloaded_files += progress.files
elif isinstance(progress, MediaSyncUploaded):
self._sync_state.uploaded_files += progress.files
self._sync_state.uploaded_removals += progress.deletions
elif isinstance(progress, MediaSyncRemovedFiles):
self._sync_state.removed_files += progress.files
def start(self) -> None:
"Start media syncing in the background, if it's not already running."
if self._sync_state is not None:
if self._syncing:
return
hkey = self.mw.pm.sync_key()
@ -99,7 +72,7 @@ class MediaSyncer:
return
self._log_and_notify(_("Media sync starting..."))
self._sync_state = MediaSyncState()
self._syncing = True
self._want_stop = False
gui_hooks.media_sync_did_start_or_stop(True)
@ -128,7 +101,7 @@ class MediaSyncer:
)
def _on_finished(self, future: Future) -> None:
self._sync_state = None
self._syncing = False
gui_hooks.media_sync_did_start_or_stop(False)
exc = future.exception()
@ -191,7 +164,7 @@ class MediaSyncer:
self._want_stop = True
def is_syncing(self) -> bool:
return self._sync_state is not None
return self._syncing
def _on_start_stop(self, running: bool):
self.mw.toolbar.set_sync_active(running) # type: ignore
@ -267,21 +240,21 @@ class MediaSyncDialog(QDialog):
def _entry_to_text(self, entry: LogEntryWithTime):
if isinstance(entry.entry, str):
txt = entry.entry
elif isinstance(entry.entry, MediaSyncState):
elif isinstance(entry.entry, MediaSyncProgress):
txt = self._logentry_to_text(entry.entry)
else:
assert_impossible(entry.entry)
return self._time_and_text(entry.time, txt)
def _logentry_to_text(self, e: MediaSyncState) -> str:
def _logentry_to_text(self, e: MediaSyncProgress) -> str:
return _(
"Added: %(a_up)s ↑, %(a_dwn)s ↓, Removed: %(r_up)s ↑, %(r_dwn)s ↓, Checked: %(chk)s"
"Added: %(a_up)s↑, %(a_dwn)s↓, Removed: %(r_up)s↑, %(r_dwn)s↓, Checked: %(chk)s"
) % dict(
a_up=e.uploaded_files,
a_dwn=e.downloaded_files,
r_up=e.uploaded_removals,
r_dwn=e.removed_files,
chk=e.downloaded_changes,
r_up=e.uploaded_deletions,
r_dwn=e.downloaded_deletions,
chk=e.downloaded_meta,
)
def _on_log_entry(self, entry: LogEntryWithTime):

View file

@ -6,7 +6,7 @@ use crate::backend_proto::backend_input::Value;
use crate::backend_proto::{Empty, RenderedTemplateReplacement, SyncMediaIn};
use crate::cloze::expand_clozes_to_reveal_latex;
use crate::err::{AnkiError, NetworkErrorKind, Result, SyncErrorKind};
use crate::media::sync::Progress as MediaSyncProgress;
use crate::media::sync::MediaSyncProgress;
use crate::media::MediaManager;
use crate::sched::{local_minutes_west_for_stamp, sched_timing_today};
use crate::template::{
@ -29,8 +29,8 @@ pub struct Backend {
progress_callback: Option<ProtoProgressCallback>,
}
enum Progress {
MediaSync(MediaSyncProgress),
enum Progress<'a> {
MediaSync(&'a MediaSyncProgress),
}
/// Convert an Anki error to a protobuf error.
@ -320,7 +320,7 @@ impl Backend {
fn sync_media(&self, input: SyncMediaIn) -> Result<()> {
let mgr = MediaManager::new(&input.media_folder, &input.media_db)?;
let callback = |progress: MediaSyncProgress| {
let callback = |progress: &MediaSyncProgress| {
self.fire_progress_callback(Progress::MediaSync(progress))
};
@ -360,20 +360,13 @@ fn rendered_node_to_proto(node: RenderedNode) -> pt::rendered_template_node::Val
fn progress_to_proto_bytes(progress: Progress) -> Vec<u8> {
let proto = pt::Progress {
value: Some(match progress {
Progress::MediaSync(progress) => {
use pt::media_sync_progress::Value as V;
use MediaSyncProgress as P;
let val = match progress {
P::DownloadedChanges(n) => V::DownloadedChanges(n as u32),
P::DownloadedFiles(n) => V::DownloadedFiles(n as u32),
P::Uploaded { files, deletions } => V::Uploaded(pt::MediaSyncUploadProgress {
files: files as u32,
deletions: deletions as u32,
}),
P::RemovedFiles(n) => V::RemovedFiles(n as u32),
};
pt::progress::Value::MediaSync(pt::MediaSyncProgress { value: Some(val) })
}
Progress::MediaSync(p) => pt::progress::Value::MediaSync(pt::MediaSyncProgress {
downloaded_meta: p.downloaded_meta as u32,
downloaded_files: p.downloaded_files as u32,
downloaded_deletions: p.downloaded_deletions as u32,
uploaded_files: p.uploaded_files as u32,
uploaded_deletions: p.uploaded_deletions as u32,
}),
}),
};

View file

@ -7,7 +7,7 @@ use crate::media::files::{
add_data_to_folder_uniquely, mtime_as_i64, sha1_of_data, sha1_of_file,
MEDIA_SYNC_FILESIZE_LIMIT, NONSYNCABLE_FILENAME,
};
use crate::media::sync::{MediaSyncer, Progress};
use crate::media::sync::{MediaSyncProgress, MediaSyncer};
use rusqlite::Connection;
use std::borrow::Cow;
use std::collections::HashMap;
@ -98,7 +98,7 @@ impl MediaManager {
/// Sync media.
pub async fn sync_media<F>(&self, progress: F, endpoint: &str, hkey: &str) -> Result<()>
where
F: Fn(Progress) -> bool,
F: Fn(&MediaSyncProgress) -> bool,
{
let mut syncer = MediaSyncer::new(self, progress, endpoint);
syncer.sync(hkey).await

View file

@ -24,24 +24,27 @@ static SYNC_MAX_FILES: usize = 25;
static SYNC_MAX_BYTES: usize = (2.5 * 1024.0 * 1024.0) as usize;
static SYNC_SINGLE_FILE_MAX_BYTES: usize = 100 * 1024 * 1024;
/// The counts are not cumulative - the progress hook should accumulate them.
#[derive(Debug)]
pub enum Progress {
DownloadedChanges(usize),
DownloadedFiles(usize),
Uploaded { files: usize, deletions: usize },
RemovedFiles(usize),
#[derive(Debug, Default)]
pub struct MediaSyncProgress {
pub downloaded_meta: usize,
pub downloaded_files: usize,
pub downloaded_deletions: usize,
pub uploaded_files: usize,
pub uploaded_deletions: usize,
}
pub struct MediaSyncer<'a, P>
where
P: Fn(Progress) -> bool,
P: Fn(&MediaSyncProgress) -> bool,
{
mgr: &'a MediaManager,
ctx: MediaDatabaseContext<'a>,
skey: Option<String>,
client: Client,
progress_cb: P,
progress: MediaSyncProgress,
progress_updated: u64,
endpoint: &'a str,
}
@ -130,7 +133,7 @@ struct FinalizeResponse {
impl<P> MediaSyncer<'_, P>
where
P: Fn(Progress) -> bool,
P: Fn(&MediaSyncProgress) -> bool,
{
pub fn new<'a>(mgr: &'a MediaManager, progress_cb: P, endpoint: &'a str) -> MediaSyncer<'a, P> {
let client = Client::builder()
@ -145,6 +148,8 @@ where
skey: None,
client,
progress_cb,
progress: Default::default(),
progress_updated: 0,
endpoint,
}
}
@ -223,14 +228,16 @@ where
}
last_usn = batch.last().unwrap().usn;
self.progress(Progress::DownloadedChanges(batch.len()))?;
self.progress.downloaded_meta += batch.len();
self.notify_progress()?;
let (to_download, to_delete, to_remove_pending) =
determine_required_changes(&mut self.ctx, &batch)?;
// file removal
remove_files(self.mgr.media_folder.as_path(), to_delete.as_slice())?;
self.progress(Progress::RemovedFiles(to_delete.len()))?;
self.progress.downloaded_deletions += to_delete.len();
self.notify_progress()?;
// file download
let mut downloaded = vec![];
@ -248,7 +255,9 @@ where
let len = download_batch.len();
dl_fnames = &dl_fnames[len..];
downloaded.extend(download_batch);
self.progress(Progress::DownloadedFiles(len))?;
self.progress.downloaded_files += len;
self.notify_progress()?;
}
// then update the DB
@ -284,10 +293,9 @@ where
.take(reply.processed)
.partition(|e| e.sha1.is_some());
self.progress(Progress::Uploaded {
files: processed_files.len(),
deletions: processed_deletions.len(),
})?;
self.progress.uploaded_files += processed_files.len();
self.progress.uploaded_deletions += processed_deletions.len();
self.notify_progress()?;
let fnames: Vec<_> = processed_files
.iter()
@ -338,8 +346,17 @@ where
}
}
fn progress(&self, progress: Progress) -> Result<()> {
if (self.progress_cb)(progress) {
fn notify_progress(&mut self) -> Result<()> {
let now = time::SystemTime::now()
.duration_since(time::UNIX_EPOCH)
.unwrap()
.as_secs();
if now - self.progress_updated < 1 {
return Ok(());
}
if (self.progress_cb)(&self.progress) {
self.progress_updated = now;
Ok(())
} else {
Err(AnkiError::Interrupted)
@ -685,7 +702,9 @@ fn media_check_required() -> AnkiError {
#[cfg(test)]
mod test {
use crate::err::Result;
use crate::media::sync::{determine_required_change, LocalState, RequiredChange};
use crate::media::sync::{
determine_required_change, LocalState, MediaSyncProgress, RequiredChange,
};
use crate::media::MediaManager;
use tempfile::tempdir;
use tokio::runtime::Runtime;
@ -698,7 +717,7 @@ mod test {
std::fs::write(media_dir.join("test.file").as_path(), "hello")?;
let progress = |progress| {
let progress = |progress: &MediaSyncProgress| {
println!("got progress: {:?}", progress);
true
};