From f20b5b8db6f2be4605a78593050b9b05ee48edcc Mon Sep 17 00:00:00 2001
From: Damien Elmes
Date: Sat, 1 Feb 2020 19:05:22 +1000
Subject: [PATCH] media sync working, but unpolished

---
 proto/backend.proto         |   3 +
 rslib/Cargo.toml            |  15 +-
 rslib/src/backend.rs        |   5 +-
 rslib/src/err.rs            |  33 ++
 rslib/src/media/database.rs |  19 +-
 rslib/src/media/files.rs    |  84 ++++-
 rslib/src/media/mod.rs      |  47 +--
 rslib/src/media/sync.rs     | 636 ++++++++++++++++++++++++++++++++++++
 8 files changed, 782 insertions(+), 60 deletions(-)
 create mode 100644 rslib/src/media/sync.rs

diff --git a/proto/backend.proto b/proto/backend.proto
index b5dfe9562..a51b1f3c2 100644
--- a/proto/backend.proto
+++ b/proto/backend.proto
@@ -53,6 +53,9 @@ message BackendError {
         StringError template_parse = 2;
         StringError io_error = 3;
         StringError db_error = 4;
+        StringError network_error = 5;
+        Empty ankiweb_auth_failed = 6;
+        StringError ankiweb_misc_error = 7;
     }
 }
 
diff --git a/rslib/Cargo.toml b/rslib/Cargo.toml
index 972d1dd3d..b18299111 100644
--- a/rslib/Cargo.toml
+++ b/rslib/Cargo.toml
@@ -8,8 +8,8 @@ license = "AGPL-3.0-or-later"
 [dependencies]
 nom = "5.0.1"
 failure = "0.1.6"
-prost = "0.5.0"
-bytes = "0.4"
+prost = "0.6.1"
+bytes = "0.5.4"
 chrono = "0.4.10"
 lazy_static = "1.4.0"
 regex = "1.3.3"
@@ -19,7 +19,16 @@ htmlescape = "0.3.1"
 sha1 = "0.6.0"
 unicode-normalization = "0.1.12"
 tempfile = "3.1.0"
-rusqlite = "0.21.0"
+rusqlite = { version = "0.21.0", features = ["trace"] }
+reqwest = { version = "0.10.1", features = ["json"] }
+serde = "1.0.104"
+serde_json = "1.0.45"
+tokio = "0.2.11"
+serde_derive = "1.0.104"
+env_logger = "0.7.1"
+zip = "0.5.4"
+log = "0.4.8"
+serde_tuple = "0.4.0"
 
 [build-dependencies]
 prost-build = "0.5.0"
diff --git a/rslib/src/backend.rs b/rslib/src/backend.rs
index 30c63db1d..c60b201f6 100644
--- a/rslib/src/backend.rs
+++ b/rslib/src/backend.rs
@@ -3,7 +3,7 @@
 use crate::backend_proto as pt;
 use crate::backend_proto::backend_input::Value;
-use crate::backend_proto::RenderedTemplateReplacement;
+use crate::backend_proto::{Empty, RenderedTemplateReplacement};
 use crate::cloze::expand_clozes_to_reveal_latex;
 use crate::err::{AnkiError, Result};
 use crate::media::MediaManager;
@@ -34,6 +34,9 @@ impl std::convert::From<AnkiError> for pt::BackendError {
             },
             AnkiError::IOError { info } => V::IoError(pt::StringError { info }),
             AnkiError::DBError { info } => V::DbError(pt::StringError { info }),
+            AnkiError::NetworkError { info } => V::NetworkError(pt::StringError { info }),
+            AnkiError::AnkiWebAuthenticationFailed => V::AnkiwebAuthFailed(Empty {}),
+            AnkiError::AnkiWebMiscError { info } => V::AnkiwebMiscError(pt::StringError { info }),
         };
 
         pt::BackendError { value: Some(value) }
diff --git a/rslib/src/err.rs b/rslib/src/err.rs
index 63bd0a931..68afb506a 100644
--- a/rslib/src/err.rs
+++ b/rslib/src/err.rs
@@ -19,6 +19,15 @@ pub enum AnkiError {
 
     #[fail(display = "DB error: {}", info)]
     DBError { info: String },
+
+    #[fail(display = "Network error: {}", info)]
+    NetworkError { info: String },
+
+    #[fail(display = "AnkiWeb authentication failed.")]
+    AnkiWebAuthenticationFailed,
+
+    #[fail(display = "AnkiWeb error: {}", info)]
+    AnkiWebMiscError { info: String },
 }
 
 // error helpers
@@ -65,3 +74,27 @@ impl From<rusqlite::Error> for AnkiError {
         }
     }
 }
+
+impl From<reqwest::Error> for AnkiError {
+    fn from(err: reqwest::Error) -> Self {
+        AnkiError::NetworkError {
+            info: format!("{:?}", err),
+        }
+    }
+}
+
+impl From<zip::result::ZipError> for AnkiError {
+    fn from(err: zip::result::ZipError) -> Self {
+        AnkiError::AnkiWebMiscError {
+            info: format!("{:?}", err),
+        }
+    }
+}
+
+impl From<serde_json::Error> for AnkiError {
+    fn from(err: serde_json::Error) -> Self {
+        AnkiError::AnkiWebMiscError {
+            info: format!("{:?}", err),
+        }
+    }
+}
diff --git a/rslib/src/media/database.rs b/rslib/src/media/database.rs
index ae808dac0..78f39380e 100644
--- a/rslib/src/media/database.rs
+++ b/rslib/src/media/database.rs
@@ -2,13 +2,22 @@
 // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
 
 use crate::err::Result;
+use log::debug;
 use rusqlite::{params, Connection, OptionalExtension, Statement, NO_PARAMS};
 use std::collections::HashMap;
 use std::path::Path;
 
+fn trace(s: &str) {
+    debug!("sql: {}", s)
+}
+
 pub(super) fn open_or_create<P: AsRef<Path>>(path: P) -> Result<Connection> {
     let mut db = Connection::open(path)?;
 
+    if std::env::var("TRACESQL").is_ok() {
+        db.trace(Some(trace));
+    }
+
     db.pragma_update(None, "page_size", &4096)?;
     db.pragma_update(None, "legacy_file_format", &false)?;
     db.pragma_update(None, "journal", &"wal")?;
@@ -218,16 +227,6 @@ delete from media where fname=?"
         .map_err(Into::into)
     }
 
-    pub(super) fn changes_pending(&mut self) -> Result<u32> {
-        self.db
-            .query_row(
-                "select count(*) from media where dirty=1",
-                NO_PARAMS,
-                |row| Ok(row.get(0)?),
-            )
-            .map_err(Into::into)
-    }
-
     pub(super) fn count(&mut self) -> Result<u32> {
         self.db
             .query_row(
diff --git a/rslib/src/media/files.rs b/rslib/src/media/files.rs
index 285a9b989..aec113554 100644
--- a/rslib/src/media/files.rs
+++ b/rslib/src/media/files.rs
@@ -1,14 +1,16 @@
 // Copyright: Ankitects Pty Ltd and contributors
 // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
 
+use crate::err::{AnkiError, Result};
 use lazy_static::lazy_static;
+use log::debug;
 use regex::Regex;
 use sha1::Sha1;
 use std::borrow::Cow;
 use std::io::Read;
 use std::path::Path;
 use std::{fs, io, time};
-use unicode_normalization::{is_nfc_quick, IsNormalized, UnicodeNormalization};
+use unicode_normalization::{is_nfc, UnicodeNormalization};
 
 /// The maximum length we allow a filename to be. When combined
 /// with the rest of the path, the full path needs to be under ~240 chars
@@ -61,10 +63,10 @@ fn disallowed_char(char: char) -> bool {
 /// - Any problem characters are removed.
 /// - Windows device names like CON and PRN have '_' appended
 /// - The filename is limited to 120 bytes.
-fn normalize_filename(fname: &str) -> Cow<str> {
+pub(crate) fn normalize_filename(fname: &str) -> Cow<str> {
     let mut output = Cow::Borrowed(fname);
 
-    if is_nfc_quick(output.chars()) != IsNormalized::Yes {
+    if !is_nfc(output.as_ref()) {
         output = output.chars().nfc().collect::<String>().into();
     }
 
@@ -232,6 +234,76 @@ pub(super) fn mtime_as_i64<P: AsRef<Path>>(path: P) -> io::Result<i64> {
         .as_secs() as i64)
 }
 
+pub(super) fn remove_files(media_folder: &Path, files: &[&String]) -> Result<()> {
+    for &file in files {
+        debug!("removing {}", file);
+        let path = media_folder.join(file.as_str());
+        fs::remove_file(path).map_err(|e| AnkiError::IOError {
+            info: format!("Error removing {}: {}", file, e),
+        })?;
+    }
+
+    Ok(())
+}
+
+pub(super) struct AddedFile {
+    pub fname: String,
+    pub sha1: [u8; 20],
+    pub mtime: i64,
+    pub renamed_from: Option<String>,
+}
+
+/// Add a file received from AnkiWeb into the media folder.
+///
+/// Because AnkiWeb did not previously enforce file name limits and invalid
+/// characters, we'll need to rename the file if it is not valid.
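+///
+/// For example, a name like "con.jpg" is written out as "con_.jpg" (see the
+/// normalize() test in this file), and the original name is passed back via
+/// `renamed_from` so the caller knows the file was renamed.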
+pub(super) fn add_file_from_ankiweb(
+    media_folder: &Path,
+    fname: &str,
+    data: &[u8],
+) -> Result<AddedFile> {
+    let normalized = normalize_filename(fname);
+
+    let path = media_folder.join(normalized.as_ref());
+    fs::write(&path, data)?;
+
+    let sha1 = sha1_of_data(data);
+
+    let mtime = mtime_as_i64(path)?;
+
+    // fixme: could we use the info sent from the server for the hash instead
+    // of hashing it here and returning hash?
+
+    let renamed_from = if let Cow::Borrowed(_) = normalized {
+        None
+    } else {
+        Some(fname.to_string())
+    };
+
+    Ok(AddedFile {
+        fname: normalized.to_string(),
+        sha1,
+        mtime,
+        renamed_from,
+    })
+}
+
+pub(super) fn data_for_file(media_folder: &Path, fname: &str) -> Result<Option<Vec<u8>>> {
+    let mut file = match fs::File::open(&media_folder.join(fname)) {
+        Ok(file) => file,
+        Err(e) => {
+            if e.kind() == io::ErrorKind::NotFound {
+                return Ok(None);
+            } else {
+                return Err(e.into());
+            }
+        }
+    };
+    let mut buf = vec![];
+    file.read_to_end(&mut buf)?;
+    Ok(Some(buf))
+}
+
 #[cfg(test)]
 mod test {
     use crate::media::files::{
@@ -242,7 +314,7 @@ mod test {
     use tempfile::tempdir;
 
     #[test]
-    fn test_normalize() {
+    fn normalize() {
         assert_eq!(normalize_filename("foo.jpg"), Cow::Borrowed("foo.jpg"));
         assert_eq!(
             normalize_filename("con.jpg[]><:\"/?*^\\|\0\r\n").as_ref(),
@@ -257,7 +329,7 @@
     }
 
     #[test]
-    fn test_add_hash_suffix() {
+    fn add_hash_suffix() {
         let hash = sha1_of_data("hello".as_bytes());
         assert_eq!(
             add_hash_suffix_to_file_stem("test.jpg", &hash).as_str(),
@@ -266,7 +338,7 @@
     }
 
     #[test]
-    fn test_adding() {
+    fn adding() {
         let dir = tempdir().unwrap();
         let dpath = dir.path();
 
diff --git a/rslib/src/media/mod.rs b/rslib/src/media/mod.rs
index 8ac46a616..1f64ebc09 100644
--- a/rslib/src/media/mod.rs
+++ b/rslib/src/media/mod.rs
@@ -15,6 +15,7 @@ use std::time;
 
 pub mod database;
 pub mod files;
+pub mod sync;
 
 pub struct MediaManager {
     db: Connection,
@@ -94,7 +95,7 @@ impl MediaManager {
     }
 
     /// Note any added/changed/deleted files.
-    pub fn register_changes(&mut self) -> Result<()> {
+    fn register_changes(&mut self) -> Result<()> {
         self.transact(|ctx| {
             // folder mtime unchanged?
             let dirmod = mtime_as_i64(&self.media_folder)?;
@@ -119,45 +120,11 @@ impl MediaManager {
         })
     }
 
-    // syncDelete
-    pub fn remove_entry(&mut self, fname: &str) -> Result<()> {
-        self.transact(|ctx| ctx.remove_entry(fname))
-    }
-
     // forceResync
     pub fn clear(&mut self) -> Result<()> {
         self.transact(|ctx| ctx.clear())
     }
 
-    // lastUsn
-    pub fn get_last_usn(&mut self) -> Result<i32> {
-        self.query(|ctx| Ok(ctx.get_meta()?.last_sync_usn))
-    }
-
-    // setLastUsn
-    pub fn set_last_usn(&mut self, usn: i32) -> Result<()> {
-        self.transact(|ctx| {
-            let mut meta = ctx.get_meta()?;
-            meta.last_sync_usn = usn;
-            ctx.set_meta(&meta)
-        })
-    }
-
-    // dirtyCount
-    pub fn changes_pending(&mut self) -> Result<u32> {
-        self.query(|ctx| ctx.changes_pending())
-    }
-
-    // mediaCount
-    pub fn count(&mut self) -> Result<u32> {
-        self.query(|ctx| ctx.count())
-    }
-
-    // mediaChangesZip
-    pub fn get_pending_uploads(&mut self, max_entries: u32) -> Result<Vec<MediaEntry>> {
-        self.query(|ctx| ctx.get_pending_uploads(max_entries))
-    }
-
     // db helpers
 
     pub(super) fn query<F, R>(&self, func: F) -> Result<R>
@@ -334,7 +301,7 @@ mod test {
 
         let mut entry = mgr.transact(|ctx| {
             assert_eq!(ctx.count()?, 1);
-            assert_eq!(ctx.changes_pending()?, 1);
+            assert!(!ctx.get_pending_uploads(1)?.is_empty());
             let mut entry = ctx.get_entry("file.jpg")?.unwrap();
             assert_eq!(
                 entry,
@@ -354,7 +321,7 @@ mod test {
             // mark it as unmodified
            entry.sync_required = false;
             ctx.set_entry(&entry)?;
-            assert_eq!(ctx.changes_pending()?, 0);
+            assert!(ctx.get_pending_uploads(1)?.is_empty());
 
             // modify it
             fs::write(&f1, "hello1")?;
@@ -369,7 +336,7 @@ mod test {
 
         mgr.transact(|ctx| {
             assert_eq!(ctx.count()?, 1);
-            assert_eq!(ctx.changes_pending()?, 1);
+            assert!(!ctx.get_pending_uploads(1)?.is_empty());
             assert_eq!(
                 ctx.get_entry("file.jpg")?.unwrap(),
                 MediaEntry {
@@ -388,7 +355,7 @@ mod test {
             // mark it as unmodified
             entry.sync_required = false;
             ctx.set_entry(&entry)?;
-            assert_eq!(ctx.changes_pending()?, 0);
+            assert!(ctx.get_pending_uploads(1)?.is_empty());
 
             Ok(())
         })?;
@@ -401,7 +368,7 @@ mod test {
 
         mgr.query(|ctx| {
             assert_eq!(ctx.count()?, 0);
-            assert_eq!(ctx.changes_pending()?, 1);
+            assert!(!ctx.get_pending_uploads(1)?.is_empty());
             assert_eq!(
                 ctx.get_entry("file.jpg")?.unwrap(),
                 MediaEntry {
diff --git a/rslib/src/media/sync.rs b/rslib/src/media/sync.rs
new file mode 100644
index 000000000..3e2df7316
--- /dev/null
+++ b/rslib/src/media/sync.rs
@@ -0,0 +1,636 @@
+// Copyright: Ankitects Pty Ltd and contributors
+// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
+
+use crate::err::{AnkiError, Result};
+use crate::media::database::{MediaDatabaseContext, MediaEntry};
+use crate::media::files::{
+    add_file_from_ankiweb, data_for_file, normalize_filename, remove_files, AddedFile,
+};
+use crate::media::MediaManager;
+use bytes::Bytes;
+use log::debug;
+use reqwest;
+use reqwest::{multipart, Client, Response, StatusCode};
+use serde_derive::{Deserialize, Serialize};
+use serde_tuple::Serialize_tuple;
+use std::borrow::Cow;
+use std::collections::HashMap;
+use std::io::{Read, Write};
+use std::path::Path;
+use std::{io, time};
+
+// fixme: callback using PyEval_SaveThread();
+// fixme: runCommand() could be releasing GIL, but perhaps overkill for all commands?
+// fixme: sync url
+// fixme: version string
+// fixme: shards
+
+// fixme: refactor into a struct
+
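+// Overview of the steps implemented below:
+//   1. sync_begin()     GET  .../begin          - obtain session key and server usn
+//   2. fetch_changes()  POST .../mediaChanges   - pull changed records in batches, then
+//                       POST .../downloadFiles  - fetch the wanted files as zips
+//   3. send_changes()   POST .../uploadChanges  - push pending local changes as zips
+//   4. finalize_sync()  POST .../mediaSanity    - check the local file count against the server
+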
+static SYNC_URL: &str = "https://sync.ankiweb.net/msync/";
+
+static SYNC_MAX_FILES: usize = 25;
+static SYNC_MAX_BYTES: usize = (2.5 * 1024.0 * 1024.0) as usize;
+
+#[allow(clippy::useless_let_if_seq)]
+pub async fn sync_media(mgr: &mut MediaManager, hkey: &str) -> Result<()> {
+    // make sure media DB is up to date
+    mgr.register_changes()?;
+
+    let client_usn = mgr.query(|ctx| Ok(ctx.get_meta()?.last_sync_usn))?;
+
+    let client = Client::builder()
+        .connect_timeout(time::Duration::from_secs(30))
+        .build()?;
+
+    debug!("beginning media sync");
+    let (sync_key, server_usn) = sync_begin(&client, hkey).await?;
+    debug!("server usn was {}", server_usn);
+
+    let mut actions_performed = false;
+
+    // need to fetch changes from server?
+    if client_usn != server_usn {
+        debug!("differs from local usn {}, fetching changes", client_usn);
+        fetch_changes(mgr, &client, &sync_key, client_usn).await?;
+        actions_performed = true;
+    }
+
+    // need to send changes to server?
+    let changes_pending = mgr.query(|ctx| Ok(!ctx.get_pending_uploads(1)?.is_empty()))?;
+    if changes_pending {
+        send_changes(mgr, &client, &sync_key).await?;
+        actions_performed = true;
+    }
+
+    if actions_performed {
+        finalize_sync(mgr, &client, &sync_key).await?;
+    }
+
+    debug!("media sync complete");
+
+    Ok(())
+}
+
+#[derive(Debug, Deserialize)]
+struct SyncBeginResult {
+    data: Option<SyncBeginResponse>,
+    err: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct SyncBeginResponse {
+    #[serde(rename = "sk")]
+    sync_key: String,
+    usn: i32,
+}
+
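+// The reply from /begin decodes into the structs above, i.e. roughly
+// {"data": {"sk": "<session key>", "usn": 123}, "err": ""} (illustrative values).
+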
+fn rewrite_forbidden(err: reqwest::Error) -> AnkiError {
+    if err.is_status() && err.status().unwrap() == StatusCode::FORBIDDEN {
+        AnkiError::AnkiWebAuthenticationFailed
+    } else {
+        err.into()
+    }
+}
+
+async fn sync_begin(client: &Client, hkey: &str) -> Result<(String, i32)> {
+    let url = format!("{}/begin", SYNC_URL);
+
+    let resp = client
+        .get(&url)
+        .query(&[("k", hkey), ("v", "ankidesktop,2.1.19,mac")])
+        .send()
+        .await?
+        .error_for_status()
+        .map_err(rewrite_forbidden)?;
+
+    let reply: SyncBeginResult = resp.json().await?;
+
+    if let Some(data) = reply.data {
+        Ok((data.sync_key, data.usn))
+    } else {
+        Err(AnkiError::AnkiWebMiscError { info: reply.err })
+    }
+}
+
+async fn fetch_changes(
+    mgr: &mut MediaManager,
+    client: &Client,
+    skey: &str,
+    client_usn: i32,
+) -> Result<()> {
+    let mut last_usn = client_usn;
+    loop {
+        debug!("fetching record batch starting from usn {}", last_usn);
+        let batch = fetch_record_batch(client, skey, last_usn).await?;
+        if batch.is_empty() {
+            debug!("empty batch, done");
+            break;
+        }
+        last_usn = batch.last().unwrap().usn;
+
+        let (to_download, to_delete, to_remove_pending) = determine_required_changes(mgr, &batch)?;
+
+        // do file removal and additions first
+        remove_files(mgr.media_folder.as_path(), to_delete.as_slice())?;
+        let downloaded = download_files(
+            mgr.media_folder.as_path(),
+            client,
+            skey,
+            to_download.as_slice(),
+        )
+        .await?;
+
+        // then update the DB
+        mgr.transact(|ctx| {
+            record_removals(ctx, &to_delete)?;
+            record_additions(ctx, downloaded)?;
+            record_clean(ctx, &to_remove_pending)?;
+            Ok(())
+        })?;
+    }
+    Ok(())
+}
+
+#[derive(Debug, Clone, Copy)]
+enum LocalState {
+    NotInDB,
+    InDBNotPending,
+    InDBAndPending,
+}
+
+#[derive(PartialEq, Debug)]
+enum RequiredChange {
+    // `None` also covers the case where we'll upload later
+    None,
+    Download,
+    Delete,
+    RemovePending,
+}
+
+fn determine_required_change(
+    local_sha1: &str,
+    remote_sha1: &str,
+    local_state: LocalState,
+) -> RequiredChange {
+    use LocalState as L;
+    use RequiredChange as R;
+
+    match (local_sha1, remote_sha1, local_state) {
+        // both deleted, not in local DB
+        ("", "", L::NotInDB) => R::None,
+        // both deleted, in local DB
+        ("", "", _) => R::Delete,
+        // added on server, add even if local deletion pending
+        ("", _, _) => R::Download,
+        // deleted on server but added locally; upload later
+        (_, "", L::InDBAndPending) => R::None,
+        // deleted on server and not pending sync
+        (_, "", _) => R::Delete,
+        // if pending but the same as server, don't need to upload
+        (lsum, rsum, L::InDBAndPending) if lsum == rsum => R::RemovePending,
+        (lsum, rsum, _) => {
+            if lsum == rsum {
+                // not pending and same as server, nothing to do
+                R::None
+            } else {
+                // differs from server, favour server
+                R::Download
+            }
+        }
+    }
+}
+
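+// Summarising the match above:
+//   local "", remote ""  -> Delete if an entry is still in the DB, otherwise None
+//   local "", remote set -> Download
+//   local set, remote "" -> None if a local upload is pending, otherwise Delete
+//   sums equal           -> RemovePending if pending, otherwise None
+//   sums differ          -> Download (the server copy wins)
+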
+/// Get a list of server filenames and the actions required on them.
+/// Returns filenames in (to_download, to_delete, to_remove_pending).
+fn determine_required_changes<'a>(
+    mgr: &mut MediaManager,
+    records: &'a [ServerMediaRecord],
+) -> Result<(Vec<&'a String>, Vec<&'a String>, Vec<&'a String>)> {
+    mgr.query(|ctx| {
+        let mut to_download = vec![];
+        let mut to_delete = vec![];
+        let mut to_remove_pending = vec![];
+
+        for remote in records {
+            let (local_sha1, local_state) = match ctx.get_entry(&remote.fname)? {
+                Some(entry) => (
+                    match entry.sha1 {
+                        Some(arr) => hex::encode(arr),
+                        None => "".to_string(),
+                    },
+                    if entry.sync_required {
+                        LocalState::InDBAndPending
+                    } else {
+                        LocalState::InDBNotPending
+                    },
+                ),
+                None => ("".to_string(), LocalState::NotInDB),
+            };
+
+            let req_change = determine_required_change(&local_sha1, &remote.sha1, local_state);
+            debug!(
+                "for {}, lsha={} rsha={} lstate={:?} -> {:?}",
+                remote.fname,
+                local_sha1.chars().take(8).collect::<String>(),
+                remote.sha1.chars().take(8).collect::<String>(),
+                local_state,
+                req_change
+            );
+            match req_change {
+                RequiredChange::Download => to_download.push(&remote.fname),
+                RequiredChange::Delete => to_delete.push(&remote.fname),
+                RequiredChange::RemovePending => to_remove_pending.push(&remote.fname),
+                RequiredChange::None => (),
+            };
+        }
+
+        Ok((to_download, to_delete, to_remove_pending))
+    })
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct RecordBatchRequest {
+    last_usn: i32,
+}
+
+#[derive(Debug, Deserialize)]
+struct RecordBatchResult {
+    data: Option<Vec<ServerMediaRecord>>,
+    err: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct ServerMediaRecord {
+    fname: String,
+    usn: i32,
+    sha1: String,
+}
+
+async fn ankiweb_json_request<T>(
+    client: &Client,
+    url: &str,
+    json: &T,
+    skey: &str,
+) -> Result<Response>
+where
+    T: serde::Serialize,
+{
+    let req_json = serde_json::to_string(json)?;
+    let part = multipart::Part::text(req_json);
+    ankiweb_request(client, url, part, skey).await
+}
+
+async fn ankiweb_bytes_request(
+    client: &Client,
+    url: &str,
+    bytes: Vec<u8>,
+    skey: &str,
+) -> Result<Response> {
+    let part = multipart::Part::bytes(bytes);
+    ankiweb_request(client, url, part, skey).await
+}
+
+async fn ankiweb_request(
+    client: &Client,
+    url: &str,
+    data_part: multipart::Part,
+    skey: &str,
+) -> Result<Response> {
+    let data_part = data_part.file_name("data");
+
+    let form = multipart::Form::new()
+        .part("data", data_part)
+        .text("sk", skey.to_string());
+
+    client
+        .post(url)
+        .multipart(form)
+        .send()
+        .await?
+        .error_for_status()
+        .map_err(rewrite_forbidden)
+}
+
+async fn fetch_record_batch(
+    client: &Client,
+    skey: &str,
+    last_usn: i32,
+) -> Result<Vec<ServerMediaRecord>> {
+    let url = format!("{}/mediaChanges", SYNC_URL);
+
+    let req = RecordBatchRequest { last_usn };
+    let resp = ankiweb_json_request(client, &url, &req, skey).await?;
+    let res: RecordBatchResult = resp.json().await?;
+
+    if let Some(batch) = res.data {
+        Ok(batch)
+    } else {
+        Err(AnkiError::AnkiWebMiscError { info: res.err })
+    }
+}
+
+async fn download_files(
+    media_folder: &Path,
+    client: &Client,
+    skey: &str,
+    mut fnames: &[&String],
+) -> Result<Vec<AddedFile>> {
+    let mut downloaded = vec![];
+    while !fnames.is_empty() {
+        let batch: Vec<_> = fnames
+            .iter()
+            .take(SYNC_MAX_FILES)
+            .map(ToOwned::to_owned)
+            .collect();
+        let zip_data = fetch_zip(client, skey, batch.as_slice()).await?;
+        let download_batch = extract_into_media_folder(media_folder, zip_data)?.into_iter();
+        let len = download_batch.len();
+        fnames = &fnames[len..];
+        downloaded.extend(download_batch);
+    }
+
+    Ok(downloaded)
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct ZipRequest<'a> {
+    files: &'a [&'a String],
+}
+
+async fn fetch_zip(client: &Client, skey: &str, files: &[&String]) -> Result<Bytes> {
+    let url = format!("{}/downloadFiles", SYNC_URL);
+
+    debug!("requesting files: {:?}", files);
+
+    let req = ZipRequest { files };
+    let resp = ankiweb_json_request(client, &url, &req, skey).await?;
+    resp.bytes().await.map_err(Into::into)
+}
+
+fn extract_into_media_folder(media_folder: &Path, zip: Bytes) -> Result<Vec<AddedFile>> {
+    let reader = io::Cursor::new(zip);
+    let mut zip = zip::ZipArchive::new(reader)?;
+
+    let meta_file = zip.by_name("_meta")?;
+    let fmap: HashMap<String, String> = serde_json::from_reader(meta_file)?;
+    let mut output = Vec::with_capacity(fmap.len());
+
+    for i in 0..zip.len() {
+        let mut file = zip.by_index(i)?;
+        let name = file.name();
+        if name == "_meta" {
+            continue;
+        }
+
+        let real_name = fmap.get(name).ok_or(AnkiError::AnkiWebMiscError {
+            info: "malformed zip received".into(),
+        })?;
+
+        let mut data = Vec::with_capacity(file.size() as usize);
+        file.read_to_end(&mut data)?;
+
+        debug!("writing {}", real_name);
+
+        let added = add_file_from_ankiweb(media_folder, real_name, &data)?;
+
+        output.push(added);
+    }
+
+    Ok(output)
+}
+
+fn record_removals(ctx: &mut MediaDatabaseContext, removals: &[&String]) -> Result<()> {
+    for &fname in removals {
+        debug!("marking removed: {}", fname);
+        ctx.remove_entry(fname)?;
+    }
+
+    Ok(())
+}
+
+fn record_additions(ctx: &mut MediaDatabaseContext, additions: Vec<AddedFile>) -> Result<()> {
+    for file in additions {
+        let entry = MediaEntry {
+            fname: file.fname.to_string(),
+            sha1: Some(file.sha1),
+            mtime: file.mtime,
+            sync_required: false,
+        };
+        debug!(
+            "marking added: {} {}",
+            entry.fname,
+            hex::encode(entry.sha1.as_ref().unwrap())
+        );
+        ctx.set_entry(&entry)?;
+    }
+
+    Ok(())
+}
+
+fn record_clean(ctx: &mut MediaDatabaseContext, clean: &[&String]) -> Result<()> {
+    for &fname in clean {
+        if let Some(mut entry) = ctx.get_entry(fname)? {
+            if entry.sync_required {
+                entry.sync_required = false;
+                debug!("marking clean: {}", entry.fname);
+                ctx.set_entry(&entry)?;
+            }
+        }
+    }
+
+    Ok(())
+}
+
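+// Uploads are batched: each zip holds at most SYNC_MAX_FILES files and roughly
+// SYNC_MAX_BYTES of data. Files are stored under their index ("0", "1", ...), and
+// a "_meta" entry holds a JSON list of [filename, zip name] pairs, where a null
+// zip name marks a deletion (see zip_files() below).
+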
+async fn send_changes(mgr: &mut MediaManager, client: &Client, skey: &str) -> Result<()> {
+    loop {
+        let pending: Vec<MediaEntry> = mgr.query(|ctx: &mut MediaDatabaseContext| {
+            ctx.get_pending_uploads(SYNC_MAX_FILES as u32)
+        })?;
+        if pending.is_empty() {
+            break;
+        }
+
+        let zip_data = zip_files(&mgr.media_folder, &pending)?;
+        send_zip_data(client, skey, zip_data).await?;
+
+        let fnames: Vec<_> = pending.iter().map(|e| &e.fname).collect();
+        mgr.transact(|ctx| record_clean(ctx, fnames.as_slice()))?;
+    }
+
+    Ok(())
+}
+
+#[derive(Serialize_tuple)]
+struct UploadEntry<'a> {
+    fname: &'a str,
+    in_zip_name: Option<String>,
+}
+
+fn zip_files(media_folder: &Path, files: &[MediaEntry]) -> Result<Vec<u8>> {
+    let buf = vec![];
+
+    let w = std::io::Cursor::new(buf);
+    let mut zip = zip::ZipWriter::new(w);
+
+    let options =
+        zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
+
+    let mut accumulated_size = 0;
+    let mut entries = vec![];
+
+    for (idx, file) in files.iter().enumerate() {
+        if accumulated_size > SYNC_MAX_BYTES {
+            break;
+        }
+
+        let normalized = normalize_filename(&file.fname);
+        if let Cow::Owned(_) = normalized {
+            // fixme: non-string err, or should ignore instead
+            return Err(AnkiError::AnkiWebMiscError {
+                info: "Invalid filename found. Please use the Check Media function.".to_owned(),
+            });
+        }
+
+        let file_data = data_for_file(media_folder, &file.fname)?;
+
+        if let Some(data) = &file_data {
+            if data.is_empty() {
+                // fixme: should ignore these, not error
+                return Err(AnkiError::AnkiWebMiscError {
+                    info: "0 byte file found".to_owned(),
+                });
+            }
+            accumulated_size += data.len();
+            zip.start_file(format!("{}", idx), options)?;
+            zip.write_all(data)?;
+        }
+
+        debug!(
+            "will upload {} as {}",
+            file.fname,
+            if file_data.is_some() {
+                "addition "
+            } else {
+                "removal"
+            }
+        );
+
+        entries.push(UploadEntry {
+            fname: &file.fname,
+            in_zip_name: if file_data.is_some() {
+                Some(format!("{}", idx))
+            } else {
+                None
+            },
+        });
+    }
+
+    let meta = serde_json::to_string(&entries)?;
+    zip.start_file("_meta", options)?;
+    zip.write_all(meta.as_bytes())?;
+
+    let w = zip.finish()?;
+
+    Ok(w.into_inner())
+}
+
+async fn send_zip_data(client: &Client, skey: &str, data: Vec<u8>) -> Result<()> {
+    let url = format!("{}/uploadChanges", SYNC_URL);
+
+    ankiweb_bytes_request(client, &url, data, skey).await?;
+
+    Ok(())
+}
+
+#[derive(Serialize)]
+struct FinalizeRequest {
+    local: u32,
+}
+
+#[derive(Debug, Deserialize)]
+struct FinalizeResponse {
+    data: Option<String>,
+    err: String,
+}
+
+async fn finalize_sync(mgr: &mut MediaManager, client: &Client, skey: &str) -> Result<()> {
+    let url = format!("{}/mediaSanity", SYNC_URL);
+    let local = mgr.query(|ctx| ctx.count())?;
+
+    let obj = FinalizeRequest { local };
+    let resp = ankiweb_json_request(client, &url, &obj, skey).await?;
+    let resp: FinalizeResponse = resp.json().await?;
+
+    if let Some(data) = resp.data {
+        if data == "OK" {
+            Ok(())
+        } else {
+            // fixme: force resync
+            Err(AnkiError::AnkiWebMiscError {
+                info: "resync required ".into(),
+            })
+        }
+    } else {
+        Err(AnkiError::AnkiWebMiscError {
+            info: format!("finalize failed: {}", resp.err),
+        })
+    }
+}
+
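+// The sync test below only runs when TEST_HKEY is set to a valid AnkiWeb host
+// key. Debug output can be enabled through env_logger (e.g. RUST_LOG=debug),
+// and SQL tracing via TRACESQL (see media/database.rs).
+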
+#[cfg(test)]
+mod test {
+    use crate::err::Result;
+    use crate::media::sync::{determine_required_change, sync_media, LocalState, RequiredChange};
+    use crate::media::MediaManager;
+    use tempfile::tempdir;
+    use tokio::runtime::Runtime;
+
+    async fn test_sync(hkey: &str) -> Result<()> {
+        let dir = tempdir()?;
+        let media_dir = dir.path().join("media");
+        std::fs::create_dir(&media_dir)?;
+        let media_db = dir.path().join("media.db");
+
+        std::fs::write(media_dir.join("test.file").as_path(), "hello")?;
+
+        let mut mgr = MediaManager::new(&media_dir, &media_db)?;
+
+        sync_media(&mut mgr, hkey).await?;
+
+        Ok(())
+    }
+
+    #[test]
+    fn sync() {
+        env_logger::init();
+
+        let hkey = match std::env::var("TEST_HKEY") {
+            Ok(s) => s,
+            Err(_) => {
+                return;
+            }
+        };
+
+        let mut rt = Runtime::new().unwrap();
+        rt.block_on(test_sync(&hkey)).unwrap()
+    }
+
+    #[test]
+    fn required_change() {
+        use determine_required_change as d;
+        use LocalState as L;
+        use RequiredChange as R;
+        assert_eq!(d("", "", L::NotInDB), R::None);
+        assert_eq!(d("", "", L::InDBNotPending), R::Delete);
+        assert_eq!(d("", "1", L::InDBAndPending), R::Download);
+        assert_eq!(d("1", "", L::InDBAndPending), R::None);
+        assert_eq!(d("1", "", L::InDBNotPending), R::Delete);
+        assert_eq!(d("1", "1", L::InDBNotPending), R::None);
+        assert_eq!(d("1", "1", L::InDBAndPending), R::RemovePending);
+        assert_eq!(d("a", "b", L::InDBAndPending), R::Download);
+        assert_eq!(d("a", "b", L::InDBNotPending), R::Download);
+    }
+}