diff --git a/rslib/src/import_export/package/apkg/import/media.rs b/rslib/src/import_export/package/apkg/import/media.rs index 3a294107e..b1350c5f8 100644 --- a/rslib/src/import_export/package/apkg/import/media.rs +++ b/rslib/src/import_export/package/apkg/import/media.rs @@ -8,13 +8,13 @@ use zip::ZipArchive; use super::Context; use crate::{ import_export::{ - package::media::{extract_media_entries, SafeMediaEntry}, + package::{ + colpkg::export::MediaCopier, + media::{extract_media_entries, SafeMediaEntry}, + }, ImportProgress, IncrementableProgress, }, - media::{ - files::{add_hash_suffix_to_file_stem, sha1_of_reader}, - MediaManager, - }, + media::files::{add_hash_suffix_to_file_stem, sha1_of_reader}, prelude::*, }; @@ -37,7 +37,9 @@ impl Context<'_> { } let db_progress_fn = self.progress.media_db_fn(ImportProgress::MediaCheck)?; - let existing_sha1s = self.target_col.all_existing_sha1s(db_progress_fn)?; + let existing_sha1s = self + .media_manager + .all_checksums(db_progress_fn, &self.target_col.log)?; prepare_media( media_entries, @@ -49,21 +51,21 @@ impl Context<'_> { pub(super) fn copy_media(&mut self, media_map: &mut MediaUseMap) -> Result<()> { let mut incrementor = self.progress.incrementor(ImportProgress::Media); - for entry in media_map.used_entries() { - incrementor.increment()?; - entry.copy_from_archive(&mut self.archive, &self.target_col.media_folder)?; - } - Ok(()) - } -} - -impl Collection { - fn all_existing_sha1s( - &mut self, - progress_fn: impl FnMut(usize) -> bool, - ) -> Result> { - let mgr = MediaManager::new(&self.media_folder, &self.media_db)?; - mgr.all_checksums(progress_fn, &self.log) + let mut dbctx = self.media_manager.dbctx(); + let mut copier = MediaCopier::new(false); + self.media_manager.transact(&mut dbctx, |dbctx| { + for entry in media_map.used_entries() { + incrementor.increment()?; + entry.copy_and_ensure_sha1_set( + &mut self.archive, + &self.target_col.media_folder, + &mut copier, + )?; + self.media_manager + .add_entry(dbctx, &entry.name, entry.sha1.unwrap())?; + } + Ok(()) + }) } } @@ -84,8 +86,8 @@ fn prepare_media( media_map.unchecked.push(entry); } } else if let Some(other_sha1) = existing_sha1s.get(&entry.name) { - entry.with_hash_from_archive(archive)?; - if entry.sha1 != *other_sha1 { + entry.ensure_sha1_set(archive)?; + if entry.sha1.unwrap() != *other_sha1 { let original_name = entry.uniquify_name(); media_map.add_checked(original_name, entry); } @@ -108,26 +110,26 @@ impl MediaUseMap { }) } - pub(super) fn used_entries(&self) -> impl Iterator { + pub(super) fn used_entries(&mut self) -> impl Iterator { self.checked - .values() + .values_mut() .filter_map(|(used, entry)| used.then(|| entry)) - .chain(self.unchecked.iter()) + .chain(self.unchecked.iter_mut()) } } impl SafeMediaEntry { - fn with_hash_from_archive(&mut self, archive: &mut ZipArchive) -> Result<()> { - if self.sha1 == [0; 20] { + fn ensure_sha1_set(&mut self, archive: &mut ZipArchive) -> Result<()> { + if self.sha1.is_none() { let mut reader = self.fetch_file(archive)?; - self.sha1 = sha1_of_reader(&mut reader)?; + self.sha1 = Some(sha1_of_reader(&mut reader)?); } Ok(()) } /// Requires sha1 to be set. Returns old file name. fn uniquify_name(&mut self) -> String { - let new_name = add_hash_suffix_to_file_stem(&self.name, &self.sha1); + let new_name = add_hash_suffix_to_file_stem(&self.name, &self.sha1.expect("sha1 not set")); mem::replace(&mut self.name, new_name) } diff --git a/rslib/src/import_export/package/apkg/import/mod.rs b/rslib/src/import_export/package/apkg/import/mod.rs index 9581c7909..e0cc2b2d8 100644 --- a/rslib/src/import_export/package/apkg/import/mod.rs +++ b/rslib/src/import_export/package/apkg/import/mod.rs @@ -19,12 +19,14 @@ use crate::{ import_export::{ gather::ExchangeData, package::Meta, ImportProgress, IncrementableProgress, NoteLog, }, + media::MediaManager, prelude::*, search::SearchNode, }; struct Context<'a> { target_col: &'a mut Collection, + media_manager: MediaManager, archive: ZipArchive, meta: Meta, data: ExchangeData, @@ -56,6 +58,7 @@ impl<'a> Context<'a> { ) -> Result { let mut progress = IncrementableProgress::new(progress_fn); progress.call(ImportProgress::Extracting)?; + let media_manager = MediaManager::new(&target_col.media_folder, &target_col.media_db)?; let meta = Meta::from_archive(&mut archive)?; let data = ExchangeData::gather_from_archive( &mut archive, @@ -67,6 +70,7 @@ impl<'a> Context<'a> { let usn = target_col.usn()?; Ok(Self { target_col, + media_manager, archive, meta, data, diff --git a/rslib/src/import_export/package/apkg/tests.rs b/rslib/src/import_export/package/apkg/tests.rs index 812f6a963..304839f4e 100644 --- a/rslib/src/import_export/package/apkg/tests.rs +++ b/rslib/src/import_export/package/apkg/tests.rs @@ -6,7 +6,10 @@ use std::{collections::HashSet, fs::File, io::Write}; use crate::{ - media::files::sha1_of_data, prelude::*, search::SearchNode, tests::open_fs_test_collection, + media::{files::sha1_of_data, MediaManager}, + prelude::*, + search::SearchNode, + tests::open_fs_test_collection, }; const SAMPLE_JPG: &str = "sample.jpg"; @@ -132,9 +135,13 @@ impl Collection { fn assert_note_and_media(&mut self, note: &Note) { let sha1 = sha1_of_data(MP3_DATA); let new_mp3_name = format!("sample-{}.mp3", hex::encode(&sha1)); + let csums = MediaManager::new(&self.media_folder, &self.media_db) + .unwrap() + .all_checksums_as_is(); for file in [SAMPLE_JPG, SAMPLE_JS, &new_mp3_name] { - assert!(self.media_folder.join(file).exists()) + assert!(self.media_folder.join(file).exists()); + assert!(*csums.get(file).unwrap() != [0; 20]); } let imported_note = self.storage.get_note(note.id).unwrap().unwrap(); diff --git a/rslib/src/import_export/package/colpkg/export.rs b/rslib/src/import_export/package/colpkg/export.rs index 02f5bf5c0..79f383f5e 100644 --- a/rslib/src/import_export/package/colpkg/export.rs +++ b/rslib/src/import_export/package/colpkg/export.rs @@ -282,7 +282,7 @@ fn write_media_files( media_entries: &mut Vec, progress: &mut IncrementableProgress, ) -> Result<()> { - let mut copier = MediaCopier::new(meta); + let mut copier = MediaCopier::new(meta.zstd_compressed()); let mut incrementor = progress.incrementor(ExportProgress::Media); for (index, res) in media.0.enumerate() { incrementor.increment()?; @@ -315,18 +315,20 @@ fn normalized_unicode_file_name(filename: &OsStr) -> Result { .ok_or(AnkiError::MediaCheckRequired) } -/// Copies and hashes while encoding according to the targeted version. +/// Copies and hashes while optionally encoding. /// If compressing, the encoder is reused to optimize for repeated calls. -struct MediaCopier { +pub(crate) struct MediaCopier { encoding: bool, encoder: Option>, + buf: [u8; 64 * 1024], } impl MediaCopier { - fn new(meta: &Meta) -> Self { + pub(crate) fn new(encoding: bool) -> Self { Self { - encoding: meta.zstd_compressed(), + encoding, encoder: None, + buf: [0; 64 * 1024], } } @@ -339,25 +341,25 @@ impl MediaCopier { } /// Returns size and sha1 hash of the copied data. - fn copy( + pub(crate) fn copy( &mut self, reader: &mut impl Read, writer: &mut impl Write, ) -> Result<(usize, Sha1Hash)> { let mut size = 0; let mut hasher = Sha1::new(); - let mut buf = [0; 64 * 1024]; + self.buf = [0; 64 * 1024]; let mut wrapped_writer = MaybeEncodedWriter::new(writer, self.encoder()); loop { - let count = match reader.read(&mut buf) { + let count = match reader.read(&mut self.buf) { Ok(0) => break, Err(e) if e.kind() == io::ErrorKind::Interrupted => continue, result => result?, }; size += count; - hasher.update(&buf[..count]); - wrapped_writer.write(&buf[..count])?; + hasher.update(&self.buf[..count]); + wrapped_writer.write(&self.buf[..count])?; } self.encoder = wrapped_writer.finish()?; @@ -410,7 +412,7 @@ mod test { let bytes_hash = sha1_of_data(b"foo"); for meta in [Meta::new_legacy(), Meta::new()] { - let mut writer = MediaCopier::new(&meta); + let mut writer = MediaCopier::new(meta.zstd_compressed()); let mut buf = Vec::new(); let (size, hash) = writer.copy(&mut bytes.as_slice(), &mut buf).unwrap(); diff --git a/rslib/src/import_export/package/media.rs b/rslib/src/import_export/package/media.rs index 3613ffa72..0b66e24dd 100644 --- a/rslib/src/import_export/package/media.rs +++ b/rslib/src/import_export/package/media.rs @@ -14,7 +14,7 @@ use tempfile::NamedTempFile; use zip::{read::ZipFile, ZipArchive}; use zstd::stream::copy_decode; -use super::{MediaEntries, MediaEntry, Meta}; +use super::{colpkg::export::MediaCopier, MediaEntries, MediaEntry, Meta}; use crate::{ error::ImportError, io::{atomic_rename, filename_is_safe}, @@ -26,7 +26,7 @@ use crate::{ pub(super) struct SafeMediaEntry { pub(super) name: String, pub(super) size: u32, - pub(super) sha1: Sha1Hash, + pub(super) sha1: Option, pub(super) index: usize, } @@ -53,7 +53,7 @@ impl SafeMediaEntry { return Ok(Self { name: entry.name, size: entry.size, - sha1, + sha1: Some(sha1), index, }); } @@ -70,7 +70,7 @@ impl SafeMediaEntry { Ok(Self { name, size: 0, - sha1: [0; 20], + sha1: None, index: zip_filename, }) } @@ -89,21 +89,29 @@ impl SafeMediaEntry { &self, get_checksum: &mut impl FnMut(&str) -> Result>, ) -> Result { - get_checksum(&self.name).map(|opt| opt.map_or(false, |sha1| sha1 == self.sha1)) + get_checksum(&self.name) + .map(|opt| opt.map_or(false, |sha1| sha1 == self.sha1.expect("sha1 not set"))) } pub(super) fn has_size_equal_to(&self, other_path: &Path) -> bool { fs::metadata(other_path).map_or(false, |metadata| metadata.len() == self.size as u64) } - pub(super) fn copy_from_archive( - &self, + /// Copy the archived file to the target folder, setting its hash if necessary. + pub(super) fn copy_and_ensure_sha1_set( + &mut self, archive: &mut ZipArchive, target_folder: &Path, + copier: &mut MediaCopier, ) -> Result<()> { let mut file = self.fetch_file(archive)?; let mut tempfile = NamedTempFile::new_in(target_folder)?; - io::copy(&mut file, &mut tempfile)?; + if self.sha1.is_none() { + let (_, sha1) = copier.copy(&mut file, &mut tempfile)?; + self.sha1 = Some(sha1); + } else { + io::copy(&mut file, &mut tempfile)?; + } atomic_rename(tempfile, &self.file_path(target_folder), false) } } diff --git a/rslib/src/media/mod.rs b/rslib/src/media/mod.rs index de3fb73e6..2cc7044d5 100644 --- a/rslib/src/media/mod.rs +++ b/rslib/src/media/mod.rs @@ -52,31 +52,23 @@ impl MediaManager { /// appended to the name. /// /// Also notes the file in the media database. - #[allow(clippy::match_like_matches_macro)] pub fn add_file<'a>( &self, ctx: &mut MediaDatabaseContext, desired_name: &'a str, data: &[u8], ) -> Result> { - let pre_add_folder_mtime = mtime_as_i64(&self.media_folder)?; - - // add file to folder let data_hash = sha1_of_data(data); - let chosen_fname = - add_data_to_folder_uniquely(&self.media_folder, desired_name, data, data_hash)?; - let file_mtime = mtime_as_i64(self.media_folder.join(chosen_fname.as_ref()))?; - let post_add_folder_mtime = mtime_as_i64(&self.media_folder)?; - // add to the media DB - ctx.transact(|ctx| { + self.transact(ctx, |ctx| { + let chosen_fname = + add_data_to_folder_uniquely(&self.media_folder, desired_name, data, data_hash)?; + let file_mtime = mtime_as_i64(self.media_folder.join(chosen_fname.as_ref()))?; + let existing_entry = ctx.get_entry(&chosen_fname)?; let new_sha1 = Some(data_hash); - let entry_update_required = match existing_entry { - Some(existing) if existing.sha1 == new_sha1 => false, - _ => true, - }; + let entry_update_required = existing_entry.map(|e| e.sha1 != new_sha1).unwrap_or(true); if entry_update_required { ctx.set_entry(&MediaEntry { @@ -87,34 +79,16 @@ impl MediaManager { })?; } - let mut meta = ctx.get_meta()?; - if meta.folder_mtime == pre_add_folder_mtime { - // if media db was in sync with folder prior to this add, - // we can keep it in sync - meta.folder_mtime = post_add_folder_mtime; - ctx.set_meta(&meta)?; - } else { - // otherwise, leave it alone so that other pending changes - // get picked up later - } - - Ok(()) - })?; - - Ok(chosen_fname) + Ok(chosen_fname) + }) } pub fn remove_files(&self, ctx: &mut MediaDatabaseContext, filenames: &[S]) -> Result<()> where S: AsRef + std::fmt::Debug, { - let pre_remove_folder_mtime = mtime_as_i64(&self.media_folder)?; - - remove_files(&self.media_folder, filenames)?; - - let post_remove_folder_mtime = mtime_as_i64(&self.media_folder)?; - - ctx.transact(|ctx| { + self.transact(ctx, |ctx| { + remove_files(&self.media_folder, filenames)?; for fname in filenames { if let Some(mut entry) = ctx.get_entry(fname.as_ref())? { entry.sha1 = None; @@ -123,19 +97,50 @@ impl MediaManager { ctx.set_entry(&entry)?; } } + Ok(()) + }) + } + + /// Opens a transaction and manages folder mtime, so user should perform not + /// only db ops, but also all file ops inside the closure. + pub(crate) fn transact( + &self, + ctx: &mut MediaDatabaseContext, + func: impl FnOnce(&mut MediaDatabaseContext) -> Result, + ) -> Result { + let start_folder_mtime = mtime_as_i64(&self.media_folder)?; + ctx.transact(|ctx| { + let out = func(ctx)?; let mut meta = ctx.get_meta()?; - if meta.folder_mtime == pre_remove_folder_mtime { + if meta.folder_mtime == start_folder_mtime { // if media db was in sync with folder prior to this add, // we can keep it in sync - meta.folder_mtime = post_remove_folder_mtime; + meta.folder_mtime = mtime_as_i64(&self.media_folder)?; ctx.set_meta(&meta)?; } else { // otherwise, leave it alone so that other pending changes // get picked up later } - Ok(()) + Ok(out) + }) + } + + /// Set entry for a newly added file. Caller must ensure transaction. + pub(crate) fn add_entry( + &self, + ctx: &mut MediaDatabaseContext, + fname: impl Into, + sha1: [u8; 20], + ) -> Result<()> { + let fname = fname.into(); + let mtime = mtime_as_i64(self.media_folder.join(&fname))?; + ctx.set_entry(&MediaEntry { + fname, + mtime, + sha1: Some(sha1), + sync_required: true, }) } @@ -185,3 +190,16 @@ impl MediaManager { ChangeTracker::new(&self.media_folder, progress, log).register_changes(&mut self.dbctx()) } } + +#[cfg(test)] +mod test { + use super::*; + + impl MediaManager { + /// All checksums without registering changes first. + pub(crate) fn all_checksums_as_is(&self) -> HashMap { + let mut dbctx = self.dbctx(); + dbctx.all_checksums().unwrap() + } + } +}