diff --git a/proto/backend.proto b/proto/backend.proto index 6b4cddc6f..5466814a3 100644 --- a/proto/backend.proto +++ b/proto/backend.proto @@ -26,6 +26,7 @@ message BackendInput { ExtractAVTagsIn extract_av_tags = 24; string expand_clozes_to_reveal_latex = 25; AddFileToMediaFolderIn add_file_to_media_folder = 26; + SyncMediaIn sync_media = 27; } } @@ -42,6 +43,7 @@ message BackendOutput { ExtractAVTagsOut extract_av_tags = 24; string expand_clozes_to_reveal_latex = 25; string add_file_to_media_folder = 26; + Empty sync_media = 27; BackendError error = 2047; } @@ -50,7 +52,7 @@ message BackendOutput { message BackendError { oneof value { StringError invalid_input = 1; - StringError template_parse = 2; + TemplateParseError template_parse = 2; StringError io_error = 3; StringError db_error = 4; StringError network_error = 5; @@ -61,11 +63,35 @@ message BackendError { } } +message Progress { + oneof value { + MediaSyncProgress media_sync = 1; + } +} + message StringError { string info = 1; +} + +message TemplateParseError { + string info = 1; bool q_side = 2; } +message MediaSyncProgress { + oneof value { + uint32 downloaded_changes = 1; + uint32 downloaded_files = 2; + MediaSyncUploadProgress uploaded = 3; + uint32 removed_files = 4; + } +} + +message MediaSyncUploadProgress { + uint32 files = 1; + uint32 deletions = 2; +} + message TemplateRequirementsIn { repeated string template_front = 1; map field_names_to_ordinals = 2; @@ -189,4 +215,11 @@ message TTSTag { message AddFileToMediaFolderIn { string desired_name = 1; bytes data = 2; +} + +message SyncMediaIn { + string hkey = 1; + string media_folder = 2; + string media_db = 3; + string endpoint = 4; } \ No newline at end of file diff --git a/pylib/anki/hooks.py b/pylib/anki/hooks.py index d46465e4e..09498dece 100644 --- a/pylib/anki/hooks.py +++ b/pylib/anki/hooks.py @@ -360,6 +360,33 @@ class _NotesWillBeDeletedHook: notes_will_be_deleted = _NotesWillBeDeletedHook() +class _RustProgressCallbackFilter: + """Warning: this is called on a background thread.""" + + _hooks: List[Callable[[bool, "anki.rsbackend.Progress"], bool]] = [] + + def append(self, cb: Callable[[bool, "anki.rsbackend.Progress"], bool]) -> None: + """(proceed: bool, progress: anki.rsbackend.Progress)""" + self._hooks.append(cb) + + def remove(self, cb: Callable[[bool, "anki.rsbackend.Progress"], bool]) -> None: + if cb in self._hooks: + self._hooks.remove(cb) + + def __call__(self, proceed: bool, progress: anki.rsbackend.Progress) -> bool: + for filter in self._hooks: + try: + proceed = filter(proceed, progress) + except: + # if the hook fails, remove it + self._hooks.remove(filter) + raise + return proceed + + +rust_progress_callback = _RustProgressCallbackFilter() + + class _Schedv2DidAnswerReviewCardHook: _hooks: List[Callable[["anki.cards.Card", int, bool], None]] = [] diff --git a/pylib/anki/rsbackend.py b/pylib/anki/rsbackend.py index d9dcd477e..92b5cd9df 100644 --- a/pylib/anki/rsbackend.py +++ b/pylib/anki/rsbackend.py @@ -1,32 +1,78 @@ # Copyright: Ankitects Pty Ltd and contributors # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html # pylint: skip-file - +import enum from dataclasses import dataclass -from typing import Dict, List, Tuple, Union +from typing import Callable, Dict, List, NewType, NoReturn, Optional, Tuple, Union import ankirspy # pytype: disable=import-error import anki.backend_pb2 as pb import anki.buildinfo +from anki import hooks from anki.models import AllTemplateReqs from anki.sound import AVTag, SoundOrVideoTag, TTSTag +from anki.types import assert_impossible_literal assert ankirspy.buildhash() == anki.buildinfo.buildhash SchedTimingToday = pb.SchedTimingTodayOut -class BackendException(Exception): +class Interrupted(Exception): + pass + + +class StringError(Exception): def __str__(self) -> str: - err: pb.BackendError = self.args[0] # pylint: disable=unsubscriptable-object - kind = err.WhichOneof("value") - if kind == "invalid_input": - return f"invalid input: {err.invalid_input.info}" - elif kind == "template_parse": - return err.template_parse.info - else: - return f"unhandled error: {err}" + return self.args[0] # pylint: disable=unsubscriptable-object + + +class NetworkError(StringError): + pass + + +class IOError(StringError): + pass + + +class DBError(StringError): + pass + + +class TemplateError(StringError): + def q_side(self) -> bool: + return self.args[1] + + +class AnkiWebError(StringError): + pass + + +class AnkiWebAuthFailed(Exception): + pass + + +def proto_exception_to_native(err: pb.BackendError) -> Exception: + val = err.WhichOneof("value") + if val == "interrupted": + return Interrupted() + elif val == "network_error": + return NetworkError(err.network_error.info) + elif val == "io_error": + return IOError(err.io_error.info) + elif val == "db_error": + return DBError(err.db_error.info) + elif val == "template_parse": + return TemplateError(err.template_parse.info, err.template_parse.q_side) + elif val == "invalid_input": + return StringError(err.invalid_input.info) + elif val == "ankiweb_auth_failed": + return AnkiWebAuthFailed() + elif val == "ankiweb_misc_error": + return AnkiWebError(err.ankiweb_misc_error.info) + else: + assert_impossible_literal(val) def proto_template_reqs_to_legacy( @@ -71,6 +117,45 @@ class TemplateReplacement: TemplateReplacementList = List[Union[str, TemplateReplacement]] +@dataclass +class MediaSyncDownloadedChanges: + changes: int + + +@dataclass +class MediaSyncDownloadedFiles: + files: int + + +@dataclass +class MediaSyncUploaded: + files: int + deletions: int + + +@dataclass +class MediaSyncRemovedFiles: + files: int + + +MediaSyncProgress = Union[ + MediaSyncDownloadedChanges, + MediaSyncDownloadedFiles, + MediaSyncUploaded, + MediaSyncRemovedFiles, +] + + +class ProgressKind(enum.Enum): + MediaSyncProgress = 0 + + +@dataclass +class Progress: + kind: ProgressKind + val: Union[MediaSyncProgress] + + def proto_replacement_list_to_native( nodes: List[pb.RenderedTemplateNode], ) -> TemplateReplacementList: @@ -89,6 +174,36 @@ def proto_replacement_list_to_native( return results +def proto_progress_to_native(progress: pb.Progress) -> Progress: + kind = progress.WhichOneof("value") + if kind == "media_sync": + ikind = progress.media_sync.WhichOneof("value") + pkind = ProgressKind.MediaSyncProgress + if ikind == "downloaded_changes": + return Progress( + kind=pkind, + val=MediaSyncDownloadedChanges(progress.media_sync.downloaded_changes), + ) + elif ikind == "downloaded_files": + return Progress( + kind=pkind, + val=MediaSyncDownloadedFiles(progress.media_sync.downloaded_files), + ) + elif ikind == "uploaded": + up = progress.media_sync.uploaded + return Progress( + kind=pkind, + val=MediaSyncUploaded(files=up.files, deletions=up.deletions), + ) + elif ikind == "removed_files": + return Progress( + kind=pkind, val=MediaSyncRemovedFiles(progress.media_sync.removed_files) + ) + else: + assert_impossible_literal(ikind) + assert_impossible_literal(kind) + + class RustBackend: def __init__(self, col_path: str, media_folder_path: str, media_db_path: str): init_msg = pb.BackendInit( @@ -97,15 +212,24 @@ class RustBackend: media_db_path=media_db_path, ) self._backend = ankirspy.open_backend(init_msg.SerializeToString()) + self._backend.set_progress_callback(self._on_progress) - def _run_command(self, input: pb.BackendInput) -> pb.BackendOutput: + def _on_progress(self, progress_bytes: bytes) -> bool: + progress = pb.Progress() + progress.ParseFromString(progress_bytes) + native_progress = proto_progress_to_native(progress) + return hooks.rust_progress_callback(True, native_progress) + + def _run_command( + self, input: pb.BackendInput, release_gil: bool = False + ) -> pb.BackendOutput: input_bytes = input.SerializeToString() - output_bytes = self._backend.command(input_bytes) + output_bytes = self._backend.command(input_bytes, release_gil) output = pb.BackendOutput() output.ParseFromString(output_bytes) kind = output.WhichOneof("value") if kind == "error": - raise BackendException(output.error) + raise proto_exception_to_native(output.error) else: return output @@ -195,3 +319,18 @@ class RustBackend: ) ) ).add_file_to_media_folder + + def sync_media( + self, hkey: str, media_folder: str, media_db: str, endpoint: str + ) -> None: + self._run_command( + pb.BackendInput( + sync_media=pb.SyncMediaIn( + hkey=hkey, + media_folder=media_folder, + media_db=media_db, + endpoint=endpoint, + ) + ), + release_gil=True, + ) diff --git a/pylib/anki/template.py b/pylib/anki/template.py index 0e08af26a..3ba3afe80 100644 --- a/pylib/anki/template.py +++ b/pylib/anki/template.py @@ -120,10 +120,8 @@ def render_card( # render try: output = render_card_from_context(ctx) - except anki.rsbackend.BackendException as e: - # fixme: specific exception in 2.1.21 - err = e.args[0].template_parse # pylint: disable=no-member - if err.q_side: + except anki.rsbackend.TemplateError as e: + if e.q_side(): side = _("Front") else: side = _("Back") diff --git a/pylib/anki/types.py b/pylib/anki/types.py new file mode 100644 index 000000000..28445aebf --- /dev/null +++ b/pylib/anki/types.py @@ -0,0 +1,16 @@ +import enum +from typing import Any, NoReturn + + +class _Impossible(enum.Enum): + pass + + +def assert_impossible(arg: NoReturn) -> NoReturn: + raise Exception(f"unexpected arg received: {type(arg)} {arg}") + + +# mypy is not yet smart enough to do exhaustiveness checking on literal types, +# so this will fail at runtime instead of typecheck time :-( +def assert_impossible_literal(arg: Any) -> NoReturn: + raise Exception(f"unexpected arg received: {type(arg)} {arg}") diff --git a/pylib/tools/genhooks.py b/pylib/tools/genhooks.py index 9e124be77..1a4fa57fe 100644 --- a/pylib/tools/genhooks.py +++ b/pylib/tools/genhooks.py @@ -50,6 +50,12 @@ hooks = [ ), Hook(name="sync_stage_did_change", args=["stage: str"], legacy_hook="sync"), Hook(name="sync_progress_did_change", args=["msg: str"], legacy_hook="syncMsg"), + Hook( + name="rust_progress_callback", + args=["proceed: bool", "progress: anki.rsbackend.Progress"], + return_type="bool", + doc="Warning: this is called on a background thread.", + ), Hook( name="tag_added", args=["tag: str"], legacy_hook="newTag", legacy_no_args=True, ), diff --git a/qt/aqt/gui_hooks.py b/qt/aqt/gui_hooks.py index 026f0bc2f..3fa21a96f 100644 --- a/qt/aqt/gui_hooks.py +++ b/qt/aqt/gui_hooks.py @@ -697,6 +697,30 @@ class _EditorWillUseFontForFieldFilter: editor_will_use_font_for_field = _EditorWillUseFontForFieldFilter() +class _MediaSyncDidProgressHook: + _hooks: List[Callable[["aqt.mediasync.LogEntryWithTime"], None]] = [] + + def append(self, cb: Callable[["aqt.mediasync.LogEntryWithTime"], None]) -> None: + """(entry: aqt.mediasync.LogEntryWithTime)""" + self._hooks.append(cb) + + def remove(self, cb: Callable[["aqt.mediasync.LogEntryWithTime"], None]) -> None: + if cb in self._hooks: + self._hooks.remove(cb) + + def __call__(self, entry: aqt.mediasync.LogEntryWithTime) -> None: + for hook in self._hooks: + try: + hook(entry) + except: + # if the hook fails, remove it + self._hooks.remove(hook) + raise + + +media_sync_did_progress = _MediaSyncDidProgressHook() + + class _OverviewDidRefreshHook: """Allow to update the overview window. E.g. add the deck name in the title.""" diff --git a/qt/aqt/mediasync.py b/qt/aqt/mediasync.py new file mode 100644 index 000000000..aeba367d2 --- /dev/null +++ b/qt/aqt/mediasync.py @@ -0,0 +1,206 @@ +# Copyright: Ankitects Pty Ltd and contributors +# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +import time +from concurrent.futures import Future +from copy import copy +from dataclasses import dataclass +from typing import List, Optional, Union + +import anki +import aqt +from anki import hooks +from anki.lang import _ +from anki.media import media_paths_from_col_path +from anki.rsbackend import ( + Interrupted, + MediaSyncDownloadedChanges, + MediaSyncDownloadedFiles, + MediaSyncProgress, + MediaSyncRemovedFiles, + MediaSyncUploaded, + Progress, + ProgressKind, +) +from anki.types import assert_impossible +from anki.utils import intTime +from aqt import gui_hooks +from aqt.qt import QDialog, QDialogButtonBox, QPushButton, QWidget +from aqt.taskman import TaskManager + + +@dataclass +class MediaSyncState: + downloaded_changes: int = 0 + downloaded_files: int = 0 + uploaded_files: int = 0 + uploaded_removals: int = 0 + removed_files: int = 0 + + +# fixme: make sure we don't run twice +# fixme: handle auth errors +# fixme: handle network errors +# fixme: show progress in UI +# fixme: abort when closing collection/app +# fixme: handle no hkey +# fixme: shards +# fixme: dialog should be a singleton +# fixme: abort button should not be default + + +class SyncBegun: + pass + + +class SyncEnded: + pass + + +class SyncAborted: + pass + + +LogEntry = Union[MediaSyncState, SyncBegun, SyncEnded, SyncAborted] + + +@dataclass +class LogEntryWithTime: + time: int + entry: LogEntry + + +class MediaSyncer: + def __init__(self, taskman: TaskManager): + self._taskman = taskman + self._sync_state: Optional[MediaSyncState] = None + self._log: List[LogEntryWithTime] = [] + self._want_stop = False + hooks.rust_progress_callback.append(self._on_rust_progress) + + def _on_rust_progress(self, proceed: bool, progress: Progress) -> bool: + if progress.kind != ProgressKind.MediaSyncProgress: + return proceed + + self._update_state(progress.val) + self._log_and_notify(copy(self._sync_state)) + + if self._want_stop: + return False + else: + return proceed + + def _update_state(self, progress: MediaSyncProgress) -> None: + if isinstance(progress, MediaSyncDownloadedChanges): + self._sync_state.downloaded_changes += progress.changes + elif isinstance(progress, MediaSyncDownloadedFiles): + self._sync_state.downloaded_files += progress.files + elif isinstance(progress, MediaSyncUploaded): + self._sync_state.uploaded_files += progress.files + self._sync_state.uploaded_removals += progress.deletions + elif isinstance(progress, MediaSyncRemovedFiles): + self._sync_state.removed_files += progress.files + + def start( + self, col: anki.storage._Collection, hkey: str, shard: Optional[int] + ) -> None: + "Start media syncing in the background, if it's not already running." + if self._sync_state is not None: + return + + self._log_and_notify(SyncBegun()) + self._sync_state = MediaSyncState() + self._want_stop = False + + if shard is not None: + shard_str = str(shard) + else: + shard_str = "" + endpoint = f"https://sync{shard_str}ankiweb.net" + + (media_folder, media_db) = media_paths_from_col_path(col.path) + + def run() -> None: + col.backend.sync_media(hkey, media_folder, media_db, endpoint) + + self._taskman.run_in_background(run, self._on_finished) + + def _log_and_notify(self, entry: LogEntry) -> None: + entry_with_time = LogEntryWithTime(time=intTime(), entry=entry) + self._log.append(entry_with_time) + self._taskman.run_on_main( + lambda: gui_hooks.media_sync_did_progress(entry_with_time) + ) + + def _on_finished(self, future: Future) -> None: + self._sync_state = None + + exc = future.exception() + if exc is not None: + if isinstance(exc, Interrupted): + self._log_and_notify(SyncAborted()) + else: + raise exc + else: + self._log_and_notify(SyncEnded()) + + def entries(self) -> List[LogEntryWithTime]: + return self._log + + def abort(self) -> None: + self._want_stop = True + + +class MediaSyncDialog(QDialog): + def __init__(self, parent: QWidget, syncer: MediaSyncer) -> None: + super().__init__(parent) + self._syncer = syncer + self.form = aqt.forms.synclog.Ui_Dialog() + self.form.setupUi(self) + self.abort_button = QPushButton(_("Abort")) + self.abort_button.clicked.connect(self._on_abort) # type: ignore + self.form.buttonBox.addButton(self.abort_button, QDialogButtonBox.ActionRole) + + gui_hooks.media_sync_did_progress.append(self._on_log_entry) + + self.form.plainTextEdit.setPlainText( + "\n".join(self._entry_to_text(x) for x in syncer.entries()) + ) + + def _on_abort(self, *args) -> None: + self.form.plainTextEdit.appendPlainText( + self._time_and_text(intTime(), _("Aborting...")) + ) + self._syncer.abort() + self.abort_button.setHidden(True) + + def _time_and_text(self, stamp: int, text: str) -> str: + asctime = time.asctime(time.localtime(stamp)) + return f"{asctime}: {text}" + + def _entry_to_text(self, entry: LogEntryWithTime): + if isinstance(entry.entry, SyncBegun): + txt = _("Sync starting...") + elif isinstance(entry.entry, SyncEnded): + txt = _("Sync complete.") + elif isinstance(entry.entry, SyncAborted): + txt = _("Aborted.") + elif isinstance(entry.entry, MediaSyncState): + txt = self._logentry_to_text(entry.entry) + else: + assert_impossible(entry.entry) + return self._time_and_text(entry.time, txt) + + def _logentry_to_text(self, e: MediaSyncState) -> str: + return _( + "Added: %(a_up)s ↑, %(a_dwn)s ↓, Removed: %(r_up)s ↑, %(r_dwn)s ↓, Checked: %(chk)s" + ) % dict( + a_up=e.uploaded_files, + a_dwn=e.downloaded_files, + r_up=e.uploaded_removals, + r_dwn=e.removed_files, + chk=e.downloaded_changes, + ) + + def _on_log_entry(self, entry: LogEntryWithTime): + self.form.plainTextEdit.appendPlainText(self._entry_to_text(entry)) diff --git a/qt/aqt/profiles.py b/qt/aqt/profiles.py index 9b7f0c26c..50179c9f6 100644 --- a/qt/aqt/profiles.py +++ b/qt/aqt/profiles.py @@ -11,7 +11,7 @@ import locale import pickle import random import shutil -from typing import Any, Dict +from typing import Any, Dict, Optional from send2trash import send2trash @@ -502,7 +502,7 @@ please see: def set_night_mode(self, on: bool) -> None: self.meta["night_mode"] = on - # Profile-specific options + # Profile-specific ###################################################################### def interrupt_audio(self) -> bool: @@ -512,6 +512,9 @@ please see: self.profile["interrupt_audio"] = val aqt.sound.av_player.interrupt_current_audio = val + def sync_key(self) -> Optional[str]: + return self.profile.get("syncKey") + ###################################################################### def apply_profile_options(self) -> None: diff --git a/qt/designer/synclog.ui b/qt/designer/synclog.ui new file mode 100644 index 000000000..adb8120a9 --- /dev/null +++ b/qt/designer/synclog.ui @@ -0,0 +1,74 @@ + + + Dialog + + + + 0 + 0 + 557 + 295 + + + + Sync + + + + + + true + + + + + + + + + + Qt::Horizontal + + + QDialogButtonBox::Close + + + + + + + + + buttonBox + accepted() + Dialog + accept() + + + 248 + 254 + + + 157 + 274 + + + + + buttonBox + rejected() + Dialog + reject() + + + 316 + 260 + + + 286 + 274 + + + + + diff --git a/qt/tools/genhooks_gui.py b/qt/tools/genhooks_gui.py index bd8226132..4ffa76c12 100644 --- a/qt/tools/genhooks_gui.py +++ b/qt/tools/genhooks_gui.py @@ -266,6 +266,9 @@ hooks = [ return_type="str", legacy_hook="setupStyle", ), + Hook( + name="media_sync_did_progress", args=["entry: aqt.mediasync.LogEntryWithTime"], + ), # Adding cards ################### Hook( diff --git a/rslib/src/backend.rs b/rslib/src/backend.rs index c4e929555..b5139b6fe 100644 --- a/rslib/src/backend.rs +++ b/rslib/src/backend.rs @@ -3,9 +3,10 @@ use crate::backend_proto as pt; use crate::backend_proto::backend_input::Value; -use crate::backend_proto::{Empty, RenderedTemplateReplacement}; +use crate::backend_proto::{Empty, RenderedTemplateReplacement, SyncMediaIn}; use crate::cloze::expand_clozes_to_reveal_latex; use crate::err::{AnkiError, Result}; +use crate::media::sync::{sync_media, Progress as MediaSyncProgress}; use crate::media::MediaManager; use crate::sched::{local_minutes_west_for_stamp, sched_timing_today}; use crate::template::{ @@ -16,11 +17,19 @@ use crate::text::{extract_av_tags, strip_av_tags, AVTag}; use prost::Message; use std::collections::{HashMap, HashSet}; use std::path::PathBuf; +use tokio::runtime::Runtime; + +pub type ProtoProgressCallback = Box) -> bool + Send>; pub struct Backend { #[allow(dead_code)] col_path: PathBuf, media_manager: Option, + progress_callback: Option, +} + +enum Progress { + MediaSync(MediaSyncProgress), } /// Convert an Anki error to a protobuf error. @@ -77,6 +86,7 @@ impl Backend { Ok(Backend { col_path: col_path.into(), media_manager, + progress_callback: None, }) } @@ -142,9 +152,26 @@ impl Backend { Value::AddFileToMediaFolder(input) => { OValue::AddFileToMediaFolder(self.add_file_to_media_folder(input)?) } + Value::SyncMedia(input) => { + self.sync_media(input)?; + OValue::SyncMedia(Empty {}) + } }) } + fn fire_progress_callback(&self, progress: Progress) -> bool { + if let Some(cb) = &self.progress_callback { + let bytes = progress_to_proto_bytes(progress); + cb(bytes) + } else { + true + } + } + + pub fn set_progress_callback(&mut self, progress_cb: Option) { + self.progress_callback = progress_cb; + } + fn template_requirements( &self, input: pt::TemplateRequirementsIn, @@ -263,6 +290,17 @@ impl Backend { .add_file(&input.desired_name, &input.data)? .into()) } + + fn sync_media(&self, input: SyncMediaIn) -> Result<()> { + let mut mgr = MediaManager::new(&input.media_folder, &input.media_db)?; + + let callback = |progress: MediaSyncProgress| { + self.fire_progress_callback(Progress::MediaSync(progress)) + }; + + let mut rt = Runtime::new().unwrap(); + rt.block_on(sync_media(&mut mgr, &input.hkey, callback)) + } } fn ords_hash_to_set(ords: HashSet) -> Vec { @@ -292,3 +330,28 @@ fn rendered_node_to_proto(node: RenderedNode) -> pt::rendered_template_node::Val }), } } + +fn progress_to_proto_bytes(progress: Progress) -> Vec { + let proto = pt::Progress { + value: Some(match progress { + Progress::MediaSync(progress) => { + use pt::media_sync_progress::Value as V; + use MediaSyncProgress as P; + let val = match progress { + P::DownloadedChanges(n) => V::DownloadedChanges(n as u32), + P::DownloadedFiles(n) => V::DownloadedFiles(n as u32), + P::Uploaded { files, deletions } => V::Uploaded(pt::MediaSyncUploadProgress { + files: files as u32, + deletions: deletions as u32, + }), + P::RemovedFiles(n) => V::RemovedFiles(n as u32), + }; + pt::progress::Value::MediaSync(pt::MediaSyncProgress { value: Some(val) }) + } + }), + }; + + let mut buf = vec![]; + proto.encode(&mut buf).expect("encode failed"); + buf +} diff --git a/rslib/src/err.rs b/rslib/src/err.rs index a45dcf114..36006f53f 100644 --- a/rslib/src/err.rs +++ b/rslib/src/err.rs @@ -80,9 +80,11 @@ impl From for AnkiError { impl From for AnkiError { fn from(err: reqwest::Error) -> Self { - AnkiError::NetworkError { - info: format!("{:?}", err), - } + let url = err.url().map(|url| url.as_str()).unwrap_or(""); + let str_err = format!("{}", err); + // strip url from error to avoid exposing keys + let str_err = str_err.replace(url, ""); + AnkiError::NetworkError { info: str_err } } } diff --git a/rslib/src/media/sync.rs b/rslib/src/media/sync.rs index 37092be0c..563df2952 100644 --- a/rslib/src/media/sync.rs +++ b/rslib/src/media/sync.rs @@ -2,7 +2,7 @@ // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html use crate::err::{AnkiError, Result}; -use crate::media::database::{MediaDatabaseContext, MediaEntry}; +use crate::media::database::{MediaDatabaseContext, MediaDatabaseMetadata, MediaEntry}; use crate::media::files::{ add_file_from_ankiweb, data_for_file, normalize_filename, remove_files, AddedFile, }; @@ -95,8 +95,8 @@ where } } - async fn fetch_changes(&mut self, client_usn: i32) -> Result<()> { - let mut last_usn = client_usn; + async fn fetch_changes(&mut self, mut meta: MediaDatabaseMetadata) -> Result<()> { + let mut last_usn = meta.last_sync_usn; loop { debug!("fetching record batch starting from usn {}", last_usn); @@ -140,6 +140,11 @@ where record_removals(ctx, &to_delete)?; record_additions(ctx, downloaded)?; record_clean(ctx, &to_remove_pending)?; + + // update usn + meta.last_sync_usn = last_usn; + ctx.set_meta(&meta)?; + Ok(()) })?; } @@ -214,7 +219,8 @@ where // make sure media DB is up to date register_changes(&mut sctx.ctx, mgr.media_folder.as_path())?; - let client_usn = sctx.ctx.get_meta()?.last_sync_usn; + let meta = sctx.ctx.get_meta()?; + let client_usn = meta.last_sync_usn; debug!("beginning media sync"); let (sync_key, server_usn) = sctx.sync_begin(hkey).await?; @@ -226,7 +232,7 @@ where // need to fetch changes from server? if client_usn != server_usn { debug!("differs from local usn {}, fetching changes", client_usn); - sctx.fetch_changes(client_usn).await?; + sctx.fetch_changes(meta).await?; actions_performed = true; } diff --git a/rspy/Cargo.toml b/rspy/Cargo.toml index fad342b71..46e17ddc2 100644 --- a/rspy/Cargo.toml +++ b/rspy/Cargo.toml @@ -6,6 +6,9 @@ authors = ["Ankitects Pty Ltd and contributors"] [dependencies] anki = { path = "../rslib" } +log = "0.4.8" +env_logger = "0.7.1" +tokio = "0.2.11" [dependencies.pyo3] version = "0.8.0" diff --git a/rspy/src/lib.rs b/rspy/src/lib.rs index 135c1f531..79456e084 100644 --- a/rspy/src/lib.rs +++ b/rspy/src/lib.rs @@ -1,4 +1,5 @@ use anki::backend::{init_backend, Backend as RustBackend}; +use log::error; use pyo3::prelude::*; use pyo3::types::PyBytes; use pyo3::{exceptions, wrap_pyfunction}; @@ -23,11 +24,45 @@ fn open_backend(init_msg: &PyBytes) -> PyResult { #[pymethods] impl Backend { - fn command(&mut self, py: Python, input: &PyBytes) -> PyObject { - let out_bytes = self.backend.run_command_bytes(input.as_bytes()); + fn command(&mut self, py: Python, input: &PyBytes, release_gil: bool) -> PyObject { + let in_bytes = input.as_bytes(); + let out_bytes = if release_gil { + py.allow_threads(move || self.backend.run_command_bytes(in_bytes)) + } else { + self.backend.run_command_bytes(in_bytes) + }; let out_obj = PyBytes::new(py, &out_bytes); out_obj.into() } + + fn set_progress_callback(&mut self, callback: PyObject) { + if callback.is_none() { + self.backend.set_progress_callback(None); + } else { + let func = move |bytes: Vec| { + let gil = Python::acquire_gil(); + let py = gil.python(); + let out_bytes = PyBytes::new(py, &bytes); + let out_obj: PyObject = out_bytes.into(); + let res: PyObject = match callback.call1(py, (out_obj,)) { + Ok(res) => res, + Err(e) => { + error!("error calling callback:"); + e.print(py); + return false; + } + }; + match res.extract(py) { + Ok(cont) => cont, + Err(e) => { + error!("callback did not return bool: {:?}", e); + return false; + } + } + }; + self.backend.set_progress_callback(Some(Box::new(func))); + } + } } #[pymodule] @@ -36,5 +71,7 @@ fn ankirspy(_py: Python, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(buildhash)).unwrap(); m.add_wrapped(wrap_pyfunction!(open_backend)).unwrap(); + env_logger::init(); + Ok(()) }