From 5f9451f547dbdbb15b47c9b23dd002c122db748f Mon Sep 17 00:00:00 2001
From: RumovZ
Date: Mon, 2 May 2022 13:12:46 +0200
Subject: [PATCH] Add apkg import/export on backend (#1743)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add apkg export on backend
* Filter out missing media-paths at write time
* Make TagMatcher::new() infallible
* Gather export data instead of copying directly
* Revert changes to rslib/src/tags/
* Reuse filename_is_safe/check_filename_safe()
* Accept func to produce MediaIter in export_apkg()
* Only store file folder once in MediaIter
* Use temporary tables for gathering

export_apkg() now accepts a search instead of a deck id. Decks are gathered according to the matched notes' cards.

* Use schedule_as_new() to reset cards
* ExportData → ExchangeData
* Ignore ascii case when filtering system tags
* search_notes_cards_into_table → search_cards_of_notes_into_table
* Start on apkg importing on backend
* Fix due dates in days for apkg export
* Refactor import-export/package

- Move media and meta code into appropriate modules.
- Normalize/check for normalization when deserializing media entries.

* Add SafeMediaEntry for deserialized MediaEntries
* Prepare media based on checksums

- Ensure all existing media files are hashed.
- Hash incoming files during preparation to detect conflicts.
- Uniquify names of conflicting files with hash (not notetype id).
- Mark media files as used while importing notes.
- Finally copy used media.

* Handle encoding in `replace_media_refs()`
* Add trait to keep down cow boilerplate
* Add notetypes immediately instead of preparing
* Move target_col into Context
* Add notes immediately instead of preparing
* Note id, not guid of conflicting notes
* Add import_decks()
* decks_configs → deck_configs
* Add import_deck_configs()
* Add import_cards(), import_revlog()
* Use dyn instead of generic for media_fn

Otherwise, we would have to pass None with a type annotation in the default case.

* Fix signature of import_apkg()
* Fix search_cards_of_notes_into_table()
* Test new functions in text.rs
* Add roundtrip test for apkg (stub)
* Keep source id of imported cards (or skip)
* Keep source ids of imported revlog (or skip)
* Try to keep source ids of imported notes
* Make adding notetype with id undoable
* Wrap apkg import in transaction
* Keep source ids of imported deck configs (or skip)
* Handle card due dates and original due/did
* Fix importing cards/revlog

Card ids are manually uniquified.

* Factor out card importing
* Refactor card and revlog importing
* Factor out card importing

Also handle missing parents.

* Factor out note importing
* Factor out media importing
* Maybe upgrade scheduler of apkg
* Fix parent deck gathering
* Unconditionally import static media
* Fix deck importing edge cases

Test those edge cases, and add some global test helpers.

* Test note importing
* Let import_apkg() take a progress func
* Expand roundtrip apkg test
* Use fat pointer to avoid propagating generics
* Fix progress_fn type
* Expose apkg export/import on backend (a usage sketch of the new pylib API follows the change list)
* Return note log when importing apkg
* Fix archived collection name on apkg import
* Add CollectionOpWithBackendProgress
* Fix wrong Interrupted Exception being checked
* Add ClosedCollectionOp
* Add note ids to log and strip HTML
* Update progress when checking incoming media too
* Conditionally enable new importing in GUI
* Fix all_checksums() for media import

Entries of deleted files are nulled, not removed.

* Make apkg exporting on backend abortable
* Return number of notes imported from apkg
* Fix exception printing for QueryOp as well
* Add QueryOpWithBackendProgress

Also support backend exporting progress.

* Expose new apkg and colpkg exporting
* Open transaction in insert_data()

The lack of one was slowing down exporting by several orders of magnitude.

* Handle zstd-compressed apkg
* Add legacy arg to ExportAnkiPackage

Currently not exposed on the frontend.

* Remove unused import in proto file
* Add symlink for typechecking of import_export_pb2
* Avoid kwargs in pb message creation, so typechecking is not lost

Protobuf's behaviour is rather subtle and I had to dig through the docs to figure it out: set a field on a submessage to automatically assign the submessage to the parent, or call SetInParent() to persist a default version of the field you specified. A short sketch after the change list below illustrates this.

* Avoid re-exporting protobuf msgs we only use internally
* Stop after one test failure

mypy often fails much faster than pylint.

* Avoid an extra allocation when extracting media checksums
* Update progress after prepare_media() finishes

Otherwise the bulk of the import ends up being shown as "Checked: 0" in the progress window.

* Show progress of note imports

Note import is the slowest part, so showing progress here makes the UI feel more responsive.

* Reset filtered decks at import time

Before this change, filtered decks exported with scheduling remained filtered on import, and maybe_remove_from_filtered_deck() moved cards into them as their home deck, leading to errors during review. We may still want to provide a way to preserve filtered decks on import, but to do that we'll need to ensure we don't rewrite the home decks of cards, and we'll need to ensure the home decks are included as part of the import (or give an error if they're not).

https://github.com/ankitects/anki/pull/1743/files#r839346423

* Fix a corner-case where due dates were shifted by a day

This issue existed in the old Python code as well. We need to include the user's UTC offset in the exported file, or days_elapsed falls back on the v1 cutoff calculation, which may be a day earlier or later than the v2 calculation.

* Log conflicting note in remapped nt case
* take_fields() → into_fields()
* Alias `[u8; 20]` with `Sha1Hash`
* Truncate logged fields
* Rework apkg note import tests

- Use macros for more helpful errors.
- Split monolith into unit tests.
- Fix some unknown error with the previous test along the way. (Was failing after 969484de4388d225c9f17d94534b3ba0094c3568.)

* Fix sorting of imported decks

Also adjust the test, so it fails without the patch. It was only passing before because the parent deck happened to come before the inconsistently capitalised child alphabetically. But we want all parent decks to be imported before their child decks, so their children can adopt their capitalisation.

* target[_id]s → existing_card[_id]s
* export_collection_extracting_media() → ... export_into_collection_file()
* target_already_exists → card_ordinal_already_exists
* Add search_cards_of_notes_into_table.sql
* Improve type of apkg export selector/limit
* Remove redundant call to mod_schema()
* Parent tooltips to mw
* Fix a crash when truncating note text

String::truncate() is a bit of a footgun, and I've hit this before too :-)

* Remove ExportLimit in favour of separate classes
* Remove OpWithBackendProgress and ClosedCollectionOp

Backend progress logic is now in ProgressManager. QueryOp can be used for running on a closed collection.
Also fix aborting of colpkg exports, which slipped through in #1817.

* Tidy up import log
* Avoid QDialog.exec()
* Default to excluding scheduling for deck list deck
* Use IncrementalProgress in whole import_export code
* Compare checksums when importing colpkgs
* Avoid registering changes if hashes are not needed
* ImportProgress::Collection → ImportProgress::File
* Make downgrading apkgs depend on meta version
* Generalise IncrementableProgress

And use it in the entire import_export code instead.

* Fix type complexity lint
* Take count_map for IncrementableProgress::get_inner
* Replace import/export env with Shift click
* Accept all args from update() for backend progress
* Pass fields of ProgressUpdate explicitly
* Move update_interval into IncrementableProgress
* Outsource incrementing into Incrementor
* Mutate ProgressUpdate in progress_update callback
* Switch import/export legacy toggle to profile setting

Shift would have been nice, but the existing shortcuts complicate things. If the user triggers an import with ctrl+shift+i, shift is unlikely to have been released by the time our code runs, meaning the user accidentally triggers the new code. We could potentially wait a while before bringing up the dialog, but then we're forced to guess at how long it will take the user to release the key. One alternative would be to use alt instead of shift, but then we need to trigger our shortcut when that key is pressed as well, and it could potentially cause a conflict with an add-on that already uses that combination.

* Show extension in export dialog
* Continue to provide separate options for schema 11+18 colpkg export
* Default to colpkg export when using File>Export
* Improve appearance of combo boxes when switching between apkg/colpkg

+ Deal with long deck names

* Convert newlines to spaces when showing fields from import

Ensures each imported note appears on a separate line.

* Don't separate total note count from the other summary lines

This may come down to personal preference, but I feel the other counts are equally important, and separating them makes it a bit easier to ignore them.

* Fix 'deck not normal' error when importing a filtered deck for the 2nd time
* Fix [Identical] being shown on first import
* Revert "Continue to provide separate options for schema 11+18 colpkg export"

This reverts commit 8f0b2c175f4794d642823b60414d142a12768441.

Will use a different approach.

* Move legacy support into a separate exporter option; add to apkg export
* Adjust 'too new' message to also apply to .apkg import case
* Show a better message when attempting to import new apkg into old code

Previously the user could end up seeing a message like:

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xb5 in position 1: invalid start byte

Unfortunately we can't retroactively fix this for older clients.

* Hide legacy support option in older exporting screen
* Reflect change from paths to fnames in type & name
* Make imported decks normal at once

Then skip special casing in update_deck(). Also skip updating the description if the new one is empty.
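The following is an illustrative usage sketch of the new pylib entry points added in this patch, Collection.export_anki_package() and Collection.import_anki_package() (see pylib/anki/collection.py in the diff below). The collection path, output path and deck id are placeholders, not part of the change itself.

    from anki.collection import Collection, DeckIdLimit
    from anki.decks import DeckId

    col = Collection("/path/to/collection.anki2")  # placeholder path

    # Export a single deck to an .apkg, including scheduling and media.
    # Passing limit=None would export the whole collection instead.
    count = col.export_anki_package(
        out_path="/tmp/deck.apkg",      # placeholder path
        limit=DeckIdLimit(DeckId(1)),   # placeholder deck id
        with_scheduling=True,
        with_media=True,
        legacy_support=False,
    )
    print(f"exported {count} notes")

    # Import an .apkg; the response bundles OpChanges with a note-level log.
    log_with_changes = col.import_anki_package("/tmp/deck.apkg")
    log = log_with_changes.log
    print(len(log.new), len(log.updated), len(log.duplicate), len(log.conflicting))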
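And to illustrate the protobuf behaviour mentioned under "Avoid kwargs in pb message creation", here is a minimal sketch against the ExportAnkiPackageRequest message defined in proto/anki/import_export.proto below. The path and note id are placeholders.

    from anki import import_export_pb2

    request = import_export_pb2.ExportAnkiPackageRequest()
    request.out_path = "/tmp/deck.apkg"  # placeholder path

    # Assigning a scalar arm of the `selector` oneof selects it directly.
    request.deck_id = 1
    assert request.WhichOneof("selector") == "deck_id"

    # Setting a field on a message-typed arm automatically attaches that
    # submessage to the parent, switching the oneof to it.
    request.note_ids.note_ids.extend([1650000000000])  # placeholder note id
    assert request.WhichOneof("selector") == "note_ids"

    # generic.Empty has no fields to set, and merely reading
    # request.whole_collection does not attach it, so SetInParent() is used
    # to persist the default submessage and select that arm.
    request.whole_collection.SetInParent()
    assert request.WhichOneof("selector") == "whole_collection"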
Co-authored-by: Damien Elmes --- .bazelrc | 3 + ftl/core/exporting.ftl | 3 +- ftl/core/importing.ftl | 6 + proto/anki/import_export.proto | 37 ++ pylib/.pylintrc | 1 + pylib/anki/collection.py | 68 ++- pylib/anki/exporting.py | 2 +- pylib/anki/import_export_pb2.pyi | 1 + pylib/anki/importing/anki2.py | 4 + pylib/anki/importing/apkg.py | 8 +- qt/aqt/browser/browser.py | 11 +- qt/aqt/errors.py | 4 +- qt/aqt/exporting.py | 1 + qt/aqt/forms/exporting.ui | 29 +- qt/aqt/import_export/__init__.py | 0 qt/aqt/import_export/exporting.py | 252 ++++++++++ qt/aqt/import_export/importing.py | 150 ++++++ qt/aqt/importing.py | 80 +-- qt/aqt/main.py | 23 +- qt/aqt/operations/__init__.py | 86 +++- qt/aqt/profiles.py | 6 + qt/aqt/progress.py | 46 ++ qt/aqt/taskman.py | 23 + rslib/src/backend/generic.rs | 6 + rslib/src/backend/import_export.rs | 84 +++- rslib/src/backend/progress.rs | 4 +- rslib/src/card/undo.rs | 8 + rslib/src/deckconfig/undo.rs | 7 + rslib/src/decks/addupdate.rs | 2 +- rslib/src/import_export/gather.rs | 222 +++++++++ rslib/src/import_export/insert.rs | 62 +++ rslib/src/import_export/mod.rs | 79 ++- .../src/import_export/package/apkg/export.rs | 106 ++++ .../package/apkg/import/cards.rs | 179 +++++++ .../package/apkg/import/decks.rs | 212 ++++++++ .../package/apkg/import/media.rs | 134 +++++ .../import_export/package/apkg/import/mod.rs | 137 ++++++ .../package/apkg/import/notes.rs | 459 ++++++++++++++++++ rslib/src/import_export/package/apkg/mod.rs | 8 + rslib/src/import_export/package/apkg/tests.rs | 150 ++++++ .../import_export/package/colpkg/export.rs | 115 +++-- .../import_export/package/colpkg/import.rs | 232 +++------ .../src/import_export/package/colpkg/tests.rs | 19 +- rslib/src/import_export/package/media.rs | 174 +++++++ rslib/src/import_export/package/meta.rs | 51 +- rslib/src/import_export/package/mod.rs | 4 + rslib/src/io.rs | 34 +- rslib/src/lib.rs | 1 + rslib/src/media/changetracker.rs | 4 +- rslib/src/media/database.rs | 23 +- rslib/src/media/files.rs | 32 +- rslib/src/media/mod.rs | 33 +- rslib/src/notes/mod.rs | 6 +- rslib/src/notes/undo.rs | 12 +- rslib/src/notetype/mod.rs | 2 +- rslib/src/notetype/undo.rs | 11 + rslib/src/ops.rs | 2 + rslib/src/prelude.rs | 1 + rslib/src/revlog/undo.rs | 10 +- rslib/src/search/builder.rs | 16 + rslib/src/stats/graphs.rs | 2 +- rslib/src/storage/card/add_card_if_unique.sql | 41 ++ rslib/src/storage/card/mod.rs | 52 ++ .../card/search_cards_of_notes_into_table.sql | 7 + rslib/src/storage/deck/mod.rs | 20 + .../src/storage/deckconfig/add_if_unique.sql | 3 + rslib/src/storage/deckconfig/mod.rs | 16 + rslib/src/storage/mod.rs | 28 +- rslib/src/storage/note/add_if_unique.sql | 27 ++ rslib/src/storage/note/mod.rs | 63 ++- rslib/src/storage/notetype/mod.rs | 17 + rslib/src/storage/revlog/mod.rs | 27 +- rslib/src/tests.rs | 63 +++ rslib/src/text.rs | 215 ++++++-- 74 files changed, 3617 insertions(+), 449 deletions(-) create mode 120000 pylib/anki/import_export_pb2.pyi create mode 100644 qt/aqt/import_export/__init__.py create mode 100644 qt/aqt/import_export/exporting.py create mode 100644 qt/aqt/import_export/importing.py create mode 100644 rslib/src/import_export/gather.rs create mode 100644 rslib/src/import_export/insert.rs create mode 100644 rslib/src/import_export/package/apkg/export.rs create mode 100644 rslib/src/import_export/package/apkg/import/cards.rs create mode 100644 rslib/src/import_export/package/apkg/import/decks.rs create mode 100644 rslib/src/import_export/package/apkg/import/media.rs create mode 100644 
rslib/src/import_export/package/apkg/import/mod.rs create mode 100644 rslib/src/import_export/package/apkg/import/notes.rs create mode 100644 rslib/src/import_export/package/apkg/mod.rs create mode 100644 rslib/src/import_export/package/apkg/tests.rs create mode 100644 rslib/src/import_export/package/media.rs create mode 100644 rslib/src/storage/card/add_card_if_unique.sql create mode 100644 rslib/src/storage/card/search_cards_of_notes_into_table.sql create mode 100644 rslib/src/storage/deckconfig/add_if_unique.sql create mode 100644 rslib/src/storage/note/add_if_unique.sql create mode 100644 rslib/src/tests.rs diff --git a/.bazelrc b/.bazelrc index 872241350..ba35396fe 100644 --- a/.bazelrc +++ b/.bazelrc @@ -20,6 +20,9 @@ test --aspects=@rules_rust//rust:defs.bzl%rust_clippy_aspect --output_groups=+cl # print output when test fails test --test_output=errors +# stop after one test failure +test --notest_keep_going + # don't add empty __init__.py files build --incompatible_default_to_explicit_init_py diff --git a/ftl/core/exporting.ftl b/ftl/core/exporting.ftl index b396001f3..e18b921b0 100644 --- a/ftl/core/exporting.ftl +++ b/ftl/core/exporting.ftl @@ -5,7 +5,7 @@ exporting-anki-deck-package = Anki Deck Package exporting-cards-in-plain-text = Cards in Plain Text exporting-collection = collection exporting-collection-exported = Collection exported. -exporting-colpkg-too-new = Please update to the latest Anki version, then import the .colpkg file again. +exporting-colpkg-too-new = Please update to the latest Anki version, then import the .colpkg/.apkg file again. exporting-couldnt-save-file = Couldn't save file: { $val } exporting-export = Export... exporting-export-format = Export format: @@ -14,6 +14,7 @@ exporting-include-html-and-media-references = Include HTML and media references exporting-include-media = Include media exporting-include-scheduling-information = Include scheduling information exporting-include-tags = Include tags +exporting-support-older-anki-versions = Support older Anki versions (slower/larger files) exporting-notes-in-plain-text = Notes in Plain Text exporting-selected-notes = Selected Notes exporting-card-exported = diff --git a/ftl/core/importing.ftl b/ftl/core/importing.ftl index 9ecfbb3bf..27077eabc 100644 --- a/ftl/core/importing.ftl +++ b/ftl/core/importing.ftl @@ -78,4 +78,10 @@ importing-processed-media-file = *[other] Imported { $count } media files } importing-importing-collection = Importing collection... +importing-importing-file = Importing file... importing-failed-to-import-media-file = Failed to import media file: { $debugInfo } +importing-processed-notes = + { $count -> + [one] Processed { $count } note... + *[other] Processed { $count } notes... 
+ } diff --git a/proto/anki/import_export.proto b/proto/anki/import_export.proto index ea8bfe8ad..2b5bb74ba 100644 --- a/proto/anki/import_export.proto +++ b/proto/anki/import_export.proto @@ -5,6 +5,8 @@ syntax = "proto3"; package anki.import_export; +import "anki/collection.proto"; +import "anki/notes.proto"; import "anki/generic.proto"; service ImportExportService { @@ -12,12 +14,16 @@ service ImportExportService { returns (generic.Empty); rpc ExportCollectionPackage(ExportCollectionPackageRequest) returns (generic.Empty); + rpc ImportAnkiPackage(ImportAnkiPackageRequest) + returns (ImportAnkiPackageResponse); + rpc ExportAnkiPackage(ExportAnkiPackageRequest) returns (generic.UInt32); } message ImportCollectionPackageRequest { string col_path = 1; string backup_path = 2; string media_folder = 3; + string media_db = 4; } message ExportCollectionPackageRequest { @@ -26,6 +32,37 @@ message ExportCollectionPackageRequest { bool legacy = 3; } +message ImportAnkiPackageRequest { + string package_path = 1; +} + +message ImportAnkiPackageResponse { + message Note { + notes.NoteId id = 1; + repeated string fields = 2; + } + message Log { + repeated Note new = 1; + repeated Note updated = 2; + repeated Note duplicate = 3; + repeated Note conflicting = 4; + } + collection.OpChanges changes = 1; + Log log = 2; +} + +message ExportAnkiPackageRequest { + string out_path = 1; + bool with_scheduling = 2; + bool with_media = 3; + bool legacy = 4; + oneof selector { + generic.Empty whole_collection = 5; + int64 deck_id = 6; + notes.NoteIds note_ids = 7; + } +} + message PackageMetadata { enum Version { VERSION_UNKNOWN = 0; diff --git a/pylib/.pylintrc b/pylib/.pylintrc index c9d947569..12152de9e 100644 --- a/pylib/.pylintrc +++ b/pylib/.pylintrc @@ -22,6 +22,7 @@ ignored-classes= CustomStudyRequest, Cram, ScheduleCardsAsNewRequest, + ExportAnkiPackageRequest, [REPORTS] output-format=colorized diff --git a/pylib/anki/collection.py b/pylib/anki/collection.py index 3d695b9a3..ae121c8ff 100644 --- a/pylib/anki/collection.py +++ b/pylib/anki/collection.py @@ -10,6 +10,7 @@ from anki import ( collection_pb2, config_pb2, generic_pb2, + import_export_pb2, links_pb2, search_pb2, stats_pb2, @@ -32,6 +33,7 @@ OpChangesAfterUndo = collection_pb2.OpChangesAfterUndo BrowserRow = search_pb2.BrowserRow BrowserColumns = search_pb2.BrowserColumns StripHtmlMode = card_rendering_pb2.StripHtmlRequest +ImportLogWithChanges = import_export_pb2.ImportAnkiPackageResponse import copy import os @@ -90,6 +92,19 @@ class LegacyCheckpoint: LegacyUndoResult = Union[None, LegacyCheckpoint, LegacyReviewUndo] +@dataclass +class DeckIdLimit: + deck_id: DeckId + + +@dataclass +class NoteIdsLimit: + note_ids: Sequence[NoteId] + + +ExportLimit = Union[DeckIdLimit, NoteIdsLimit, None] + + class Collection(DeprecatedNamesMixin): sched: V1Scheduler | V2Scheduler | V3Scheduler @@ -259,14 +274,6 @@ class Collection(DeprecatedNamesMixin): self._clear_caches() self.db = None - def export_collection( - self, out_path: str, include_media: bool, legacy: bool - ) -> None: - self.close_for_full_sync() - self._backend.export_collection_package( - out_path=out_path, include_media=include_media, legacy=legacy - ) - def rollback(self) -> None: self._clear_caches() self.db.rollback() @@ -321,6 +328,15 @@ class Collection(DeprecatedNamesMixin): else: return -1 + def legacy_checkpoint_pending(self) -> bool: + return ( + self._have_outstanding_checkpoint() + and time.time() - self._last_checkpoint_at < 300 + ) + + # Import/export + 
########################################################################## + def create_backup( self, *, @@ -353,12 +369,40 @@ class Collection(DeprecatedNamesMixin): "Throws if backup creation failed." self._backend.await_backup_completion() - def legacy_checkpoint_pending(self) -> bool: - return ( - self._have_outstanding_checkpoint() - and time.time() - self._last_checkpoint_at < 300 + def export_collection_package( + self, out_path: str, include_media: bool, legacy: bool + ) -> None: + self.close_for_full_sync() + self._backend.export_collection_package( + out_path=out_path, include_media=include_media, legacy=legacy ) + def import_anki_package(self, path: str) -> ImportLogWithChanges: + return self._backend.import_anki_package(package_path=path) + + def export_anki_package( + self, + *, + out_path: str, + limit: ExportLimit, + with_scheduling: bool, + with_media: bool, + legacy_support: bool, + ) -> int: + request = import_export_pb2.ExportAnkiPackageRequest( + out_path=out_path, + with_scheduling=with_scheduling, + with_media=with_media, + legacy=legacy_support, + ) + if isinstance(limit, DeckIdLimit): + request.deck_id = limit.deck_id + elif isinstance(limit, NoteIdsLimit): + request.note_ids.note_ids.extend(limit.note_ids) + else: + request.whole_collection.SetInParent() + return self._backend.export_anki_package(request) + # Object helpers ########################################################################## diff --git a/pylib/anki/exporting.py b/pylib/anki/exporting.py index 8e5954d84..b17b783f6 100644 --- a/pylib/anki/exporting.py +++ b/pylib/anki/exporting.py @@ -446,7 +446,7 @@ class AnkiCollectionPackageExporter(AnkiPackageExporter): time.sleep(0.1) threading.Thread(target=progress).start() - self.col.export_collection(path, self.includeMedia, self.LEGACY) + self.col.export_collection_package(path, self.includeMedia, self.LEGACY) class AnkiCollectionPackage21bExporter(AnkiCollectionPackageExporter): diff --git a/pylib/anki/import_export_pb2.pyi b/pylib/anki/import_export_pb2.pyi new file mode 120000 index 000000000..d44638b4f --- /dev/null +++ b/pylib/anki/import_export_pb2.pyi @@ -0,0 +1 @@ +../../.bazel/bin/pylib/anki/import_export_pb2.pyi \ No newline at end of file diff --git a/pylib/anki/importing/anki2.py b/pylib/anki/importing/anki2.py index dd9ad5052..f37c089a2 100644 --- a/pylib/anki/importing/anki2.py +++ b/pylib/anki/importing/anki2.py @@ -25,6 +25,10 @@ class V2ImportIntoV1(Exception): pass +class MediaMapInvalid(Exception): + pass + + class Anki2Importer(Importer): needMapper = False diff --git a/pylib/anki/importing/apkg.py b/pylib/anki/importing/apkg.py index 049f54c37..31d1cc4fd 100644 --- a/pylib/anki/importing/apkg.py +++ b/pylib/anki/importing/apkg.py @@ -9,7 +9,7 @@ import unicodedata import zipfile from typing import Any, Optional -from anki.importing.anki2 import Anki2Importer +from anki.importing.anki2 import Anki2Importer, MediaMapInvalid from anki.utils import tmpfile @@ -36,7 +36,11 @@ class AnkiPackageImporter(Anki2Importer): # number to use during the import self.nameToNum = {} dir = self.col.media.dir() - for k, v in list(json.loads(z.read("media").decode("utf8")).items()): + try: + media_dict = json.loads(z.read("media").decode("utf8")) + except Exception as exc: + raise MediaMapInvalid() from exc + for k, v in list(media_dict.items()): path = os.path.abspath(os.path.join(dir, v)) if os.path.commonprefix([path, dir]) != dir: raise Exception("Invalid file") diff --git a/qt/aqt/browser/browser.py b/qt/aqt/browser/browser.py index 
ef996641c..a38b5bb95 100644 --- a/qt/aqt/browser/browser.py +++ b/qt/aqt/browser/browser.py @@ -23,7 +23,8 @@ from anki.tags import MARKED_TAG from anki.utils import is_mac from aqt import AnkiQt, gui_hooks from aqt.editor import Editor -from aqt.exporting import ExportDialog +from aqt.exporting import ExportDialog as LegacyExportDialog +from aqt.import_export.exporting import ExportDialog from aqt.operations.card import set_card_deck, set_card_flag from aqt.operations.collection import redo, undo from aqt.operations.note import remove_notes @@ -792,8 +793,12 @@ class Browser(QMainWindow): @no_arg_trigger @skip_if_selection_is_empty def _on_export_notes(self) -> None: - cids = self.selectedNotesAsCards() - ExportDialog(self.mw, cids=list(cids)) + if self.mw.pm.new_import_export(): + nids = self.selected_notes() + ExportDialog(self.mw, nids=nids) + else: + cids = self.selectedNotesAsCards() + LegacyExportDialog(self.mw, cids=list(cids)) # Flags & Marking ###################################################################### diff --git a/qt/aqt/errors.py b/qt/aqt/errors.py index 618549e1a..934bdff2d 100644 --- a/qt/aqt/errors.py +++ b/qt/aqt/errors.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Optional, TextIO, cast from markdown import markdown import aqt -from anki.errors import DocumentedError, LocalizedError +from anki.errors import DocumentedError, Interrupted, LocalizedError from aqt.qt import * from aqt.utils import showText, showWarning, supportText, tr @@ -22,7 +22,7 @@ if TYPE_CHECKING: def show_exception(*, parent: QWidget, exception: Exception) -> None: "Present a caught exception to the user using a pop-up." - if isinstance(exception, InterruptedError): + if isinstance(exception, Interrupted): # nothing to do return help_page = exception.help_page if isinstance(exception, DocumentedError) else None diff --git a/qt/aqt/exporting.py b/qt/aqt/exporting.py index 52ed11a21..370fcc85a 100644 --- a/qt/aqt/exporting.py +++ b/qt/aqt/exporting.py @@ -40,6 +40,7 @@ class ExportDialog(QDialog): self.col = mw.col.weakref() self.frm = aqt.forms.exporting.Ui_ExportDialog() self.frm.setupUi(self) + self.frm.legacy_support.setVisible(False) self.exporter: Exporter | None = None self.cids = cids disable_help_button(self) diff --git a/qt/aqt/forms/exporting.ui b/qt/aqt/forms/exporting.ui index b7b64436f..3d39e9416 100644 --- a/qt/aqt/forms/exporting.ui +++ b/qt/aqt/forms/exporting.ui @@ -6,8 +6,8 @@ 0 0 - 295 - 223 + 563 + 245 @@ -30,7 +30,14 @@ - + + + + 0 + 0 + + + @@ -40,7 +47,11 @@ - + + + 50 + + @@ -83,6 +94,16 @@ + + + + exporting_support_older_anki_versions + + + true + + + diff --git a/qt/aqt/import_export/__init__.py b/qt/aqt/import_export/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/qt/aqt/import_export/exporting.py b/qt/aqt/import_export/exporting.py new file mode 100644 index 000000000..81a0c1455 --- /dev/null +++ b/qt/aqt/import_export/exporting.py @@ -0,0 +1,252 @@ +# Copyright: Ankitects Pty Ltd and contributors +# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +from __future__ import annotations + +import os +import re +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Sequence, Type + +import aqt.forms +import aqt.main +from anki.collection import DeckIdLimit, ExportLimit, NoteIdsLimit, Progress +from anki.decks import DeckId, DeckNameId +from anki.notes import NoteId +from aqt import gui_hooks +from aqt.errors import show_exception +from aqt.operations import 
QueryOp +from aqt.progress import ProgressUpdate +from aqt.qt import * +from aqt.utils import ( + checkInvalidFilename, + disable_help_button, + getSaveFile, + showWarning, + tooltip, + tr, +) + + +class ExportDialog(QDialog): + def __init__( + self, + mw: aqt.main.AnkiQt, + did: DeckId | None = None, + nids: Sequence[NoteId] | None = None, + ): + QDialog.__init__(self, mw, Qt.WindowType.Window) + self.mw = mw + self.col = mw.col.weakref() + self.frm = aqt.forms.exporting.Ui_ExportDialog() + self.frm.setupUi(self) + self.exporter: Type[Exporter] = None + self.nids = nids + disable_help_button(self) + self.setup(did) + self.open() + + def setup(self, did: DeckId | None) -> None: + self.exporters: list[Type[Exporter]] = [ApkgExporter, ColpkgExporter] + self.frm.format.insertItems( + 0, [f"{e.name()} (.{e.extension})" for e in self.exporters] + ) + qconnect(self.frm.format.activated, self.exporter_changed) + if self.nids is None and not did: + # file>export defaults to colpkg + default_exporter_idx = 1 + else: + default_exporter_idx = 0 + self.frm.format.setCurrentIndex(default_exporter_idx) + self.exporter_changed(default_exporter_idx) + # deck list + if self.nids is None: + self.all_decks = self.col.decks.all_names_and_ids() + decks = [tr.exporting_all_decks()] + decks.extend(d.name for d in self.all_decks) + else: + decks = [tr.exporting_selected_notes()] + self.frm.deck.addItems(decks) + # save button + b = QPushButton(tr.exporting_export()) + self.frm.buttonBox.addButton(b, QDialogButtonBox.ButtonRole.AcceptRole) + # set default option if accessed through deck button + if did: + name = self.mw.col.decks.get(did)["name"] + index = self.frm.deck.findText(name) + self.frm.deck.setCurrentIndex(index) + self.frm.includeSched.setChecked(False) + + def exporter_changed(self, idx: int) -> None: + self.exporter = self.exporters[idx] + self.frm.includeSched.setVisible(self.exporter.show_include_scheduling) + self.frm.includeMedia.setVisible(self.exporter.show_include_media) + self.frm.includeTags.setVisible(self.exporter.show_include_tags) + self.frm.includeHTML.setVisible(self.exporter.show_include_html) + self.frm.legacy_support.setVisible(self.exporter.show_legacy_support) + self.frm.deck.setVisible(self.exporter.show_deck_list) + + def accept(self) -> None: + if not (out_path := self.get_out_path()): + return + self.exporter.export(self.mw, self.options(out_path)) + QDialog.reject(self) + + def get_out_path(self) -> str | None: + filename = self.filename() + while True: + path = getSaveFile( + parent=self, + title=tr.actions_export(), + dir_description="export", + key=self.exporter.name(), + ext=self.exporter.extension, + fname=filename, + ) + if not path: + return None + if checkInvalidFilename(os.path.basename(path), dirsep=False): + continue + path = os.path.normpath(path) + if os.path.commonprefix([self.mw.pm.base, path]) == self.mw.pm.base: + showWarning("Please choose a different export location.") + continue + break + return path + + def options(self, out_path: str) -> Options: + limit: ExportLimit = None + if self.nids: + limit = NoteIdsLimit(self.nids) + elif current_deck_id := self.current_deck_id(): + limit = DeckIdLimit(current_deck_id) + + return Options( + out_path=out_path, + include_scheduling=self.frm.includeSched.isChecked(), + include_media=self.frm.includeMedia.isChecked(), + include_tags=self.frm.includeTags.isChecked(), + include_html=self.frm.includeHTML.isChecked(), + legacy_support=self.frm.legacy_support.isChecked(), + limit=limit, + ) + + def current_deck_id(self) -> 
DeckId | None: + return (deck := self.current_deck()) and DeckId(deck.id) or None + + def current_deck(self) -> DeckNameId | None: + if self.exporter.show_deck_list: + if idx := self.frm.deck.currentIndex(): + return self.all_decks[idx - 1] + return None + + def filename(self) -> str: + if self.exporter.show_deck_list: + deck_name = self.frm.deck.currentText() + stem = re.sub('[\\\\/?<>:*|"^]', "_", deck_name) + else: + time_str = time.strftime("%Y-%m-%d@%H-%M-%S", time.localtime(time.time())) + stem = f"{tr.exporting_collection()}-{time_str}" + return f"{stem}.{self.exporter.extension}" + + +@dataclass +class Options: + out_path: str + include_scheduling: bool + include_media: bool + include_tags: bool + include_html: bool + legacy_support: bool + limit: ExportLimit + + +class Exporter(ABC): + extension: str + show_deck_list = False + show_include_scheduling = False + show_include_media = False + show_include_tags = False + show_include_html = False + show_legacy_support = False + + @staticmethod + @abstractmethod + def export(mw: aqt.main.AnkiQt, options: Options) -> None: + pass + + @staticmethod + @abstractmethod + def name() -> str: + pass + + +class ColpkgExporter(Exporter): + extension = "colpkg" + show_include_media = True + show_legacy_support = True + + @staticmethod + def name() -> str: + return tr.exporting_anki_collection_package() + + @staticmethod + def export(mw: aqt.main.AnkiQt, options: Options) -> None: + def on_success(_: None) -> None: + mw.reopen() + tooltip(tr.exporting_collection_exported(), parent=mw) + + def on_failure(exception: Exception) -> None: + mw.reopen() + show_exception(parent=mw, exception=exception) + + gui_hooks.collection_will_temporarily_close(mw.col) + QueryOp( + parent=mw, + op=lambda col: col.export_collection_package( + options.out_path, + include_media=options.include_media, + legacy=options.legacy_support, + ), + success=on_success, + ).with_backend_progress(export_progress_update).failure( + on_failure + ).run_in_background() + + +class ApkgExporter(Exporter): + extension = "apkg" + show_deck_list = True + show_include_scheduling = True + show_include_media = True + show_legacy_support = True + + @staticmethod + def name() -> str: + return tr.exporting_anki_deck_package() + + @staticmethod + def export(mw: aqt.main.AnkiQt, options: Options) -> None: + QueryOp( + parent=mw, + op=lambda col: col.export_anki_package( + out_path=options.out_path, + limit=options.limit, + with_scheduling=options.include_scheduling, + with_media=options.include_media, + legacy_support=options.legacy_support, + ), + success=lambda count: tooltip( + tr.exporting_note_exported(count=count), parent=mw + ), + ).with_backend_progress(export_progress_update).run_in_background() + + +def export_progress_update(progress: Progress, update: ProgressUpdate) -> None: + if not progress.HasField("exporting"): + return None + update.label = tr.exporting_exported_media_file(count=progress.exporting) + if update.user_wants_abort: + update.abort = True diff --git a/qt/aqt/import_export/importing.py b/qt/aqt/import_export/importing.py new file mode 100644 index 000000000..73a45dbd7 --- /dev/null +++ b/qt/aqt/import_export/importing.py @@ -0,0 +1,150 @@ +# Copyright: Ankitects Pty Ltd and contributors +# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +from __future__ import annotations + +from itertools import chain + +import aqt.main +from anki.collection import Collection, ImportLogWithChanges, Progress +from anki.errors import Interrupted +from 
aqt.operations import CollectionOp, QueryOp +from aqt.progress import ProgressUpdate +from aqt.qt import * +from aqt.utils import askUser, getFile, showInfo, showText, showWarning, tooltip, tr + + +def import_file(mw: aqt.main.AnkiQt) -> None: + if not (path := get_file_path(mw)): + return + + filename = os.path.basename(path).lower() + if filename.endswith(".anki"): + showInfo(tr.importing_anki_files_are_from_a_very()) + elif filename.endswith(".anki2"): + showInfo(tr.importing_anki2_files_are_not_directly_importable()) + elif is_collection_package(filename): + maybe_import_collection_package(mw, path) + elif filename.endswith(".apkg"): + import_anki_package(mw, path) + else: + raise NotImplementedError + + +def get_file_path(mw: aqt.main.AnkiQt) -> str | None: + if file := getFile( + mw, + tr.actions_import(), + None, + key="import", + filter=tr.importing_packaged_anki_deckcollection_apkg_colpkg_zip(), + ): + return str(file) + return None + + +def is_collection_package(filename: str) -> bool: + return ( + filename == "collection.apkg" + or (filename.startswith("backup-") and filename.endswith(".apkg")) + or filename.endswith(".colpkg") + ) + + +def maybe_import_collection_package(mw: aqt.main.AnkiQt, path: str) -> None: + if askUser( + tr.importing_this_will_delete_your_existing_collection(), + msgfunc=QMessageBox.warning, + defaultno=True, + ): + import_collection_package(mw, path) + + +def import_collection_package(mw: aqt.main.AnkiQt, file: str) -> None: + def on_success() -> None: + mw.loadCollection() + tooltip(tr.importing_importing_complete()) + + def on_failure(err: Exception) -> None: + mw.loadCollection() + if not isinstance(err, Interrupted): + showWarning(str(err)) + + QueryOp( + parent=mw, + op=lambda _: mw.create_backup_now(), + success=lambda _: mw.unloadCollection( + lambda: import_collection_package_op(mw, file, on_success) + .failure(on_failure) + .run_in_background() + ), + ).with_progress().run_in_background() + + +def import_collection_package_op( + mw: aqt.main.AnkiQt, path: str, success: Callable[[], None] +) -> QueryOp[None]: + def op(_: Collection) -> None: + col_path = mw.pm.collectionPath() + media_folder = os.path.join(mw.pm.profileFolder(), "collection.media") + media_db = os.path.join(mw.pm.profileFolder(), "collection.media.db2") + mw.backend.import_collection_package( + col_path=col_path, + backup_path=path, + media_folder=media_folder, + media_db=media_db, + ) + + return QueryOp(parent=mw, op=op, success=lambda _: success()).with_backend_progress( + import_progress_update + ) + + +def import_anki_package(mw: aqt.main.AnkiQt, path: str) -> None: + CollectionOp( + parent=mw, + op=lambda col: col.import_anki_package(path), + ).with_backend_progress(import_progress_update).success( + show_import_log + ).run_in_background() + + +def show_import_log(log_with_changes: ImportLogWithChanges) -> None: + showText(stringify_log(log_with_changes.log), plain_text_edit=True) + + +def stringify_log(log: ImportLogWithChanges.Log) -> str: + total = len(log.conflicting) + len(log.updated) + len(log.new) + len(log.duplicate) + return "\n".join( + chain( + (tr.importing_notes_found_in_file(val=total),), + ( + template_string(val=len(row)) + for (row, template_string) in ( + (log.conflicting, tr.importing_notes_that_could_not_be_imported), + (log.updated, tr.importing_notes_updated_as_file_had_newer), + (log.new, tr.importing_notes_added_from_file), + (log.duplicate, tr.importing_notes_skipped_as_theyre_already_in), + ) + if row + ), + ("",), + *( + [f"[{action}] {', 
'.join(note.fields)}" for note in rows] + for (rows, action) in ( + (log.conflicting, tr.importing_skipped()), + (log.updated, tr.importing_updated()), + (log.new, tr.adding_added()), + (log.duplicate, tr.importing_identical()), + ) + ), + ) + ) + + +def import_progress_update(progress: Progress, update: ProgressUpdate) -> None: + if not progress.HasField("importing"): + return + update.label = progress.importing + if update.user_wants_abort: + update.abort = True diff --git a/qt/aqt/importing.py b/qt/aqt/importing.py index 0168d2c9c..8d2225001 100644 --- a/qt/aqt/importing.py +++ b/qt/aqt/importing.py @@ -11,11 +11,10 @@ import anki.importing as importing import aqt.deckchooser import aqt.forms import aqt.modelchooser -from anki.errors import Interrupted -from anki.importing.anki2 import V2ImportIntoV1 +from anki.importing.anki2 import MediaMapInvalid, V2ImportIntoV1 from anki.importing.apkg import AnkiPackageImporter +from aqt.import_export.importing import import_collection_package from aqt.main import AnkiQt, gui_hooks -from aqt.operations import QueryOp from aqt.qt import * from aqt.utils import ( HelpPage, @@ -389,6 +388,10 @@ def importFile(mw: AnkiQt, file: str) -> None: future.result() except zipfile.BadZipfile: showWarning(invalidZipMsg()) + except MediaMapInvalid: + showWarning( + "Unable to read file. It probably requires a newer version of Anki to import." + ) except V2ImportIntoV1: showWarning( """\ @@ -434,78 +437,11 @@ def setupApkgImport(mw: AnkiQt, importer: AnkiPackageImporter) -> bool: if not full: # adding return True - if not askUser( + if askUser( tr.importing_this_will_delete_your_existing_collection(), msgfunc=QMessageBox.warning, defaultno=True, ): - return False + import_collection_package(mw, importer.file) - full_apkg_import(mw, importer.file) return False - - -def full_apkg_import(mw: AnkiQt, file: str) -> None: - def on_done(success: bool) -> None: - mw.loadCollection() - if success: - tooltip(tr.importing_importing_complete()) - - def after_backup(created: bool) -> None: - mw.unloadCollection(lambda: replace_with_apkg(mw, file, on_done)) - - QueryOp( - parent=mw, op=lambda _: mw.create_backup_now(), success=after_backup - ).with_progress().run_in_background() - - -def replace_with_apkg( - mw: AnkiQt, filename: str, callback: Callable[[bool], None] -) -> None: - """Tries to replace the provided collection with the provided backup, - then calls the callback. True if success. 
- """ - if not (dialog := mw.progress.start(immediate=True)): - print("No progress dialog during import; aborting will not work") - timer = QTimer() - timer.setSingleShot(False) - timer.setInterval(100) - - def on_progress() -> None: - progress = mw.backend.latest_progress() - if not progress.HasField("importing"): - return - label = progress.importing - - try: - if dialog.wantCancel: - mw.backend.set_wants_abort() - except AttributeError: - # dialog may not be active - pass - - mw.taskman.run_on_main(lambda: mw.progress.update(label=label)) - - def do_import() -> None: - col_path = mw.pm.collectionPath() - media_folder = os.path.join(mw.pm.profileFolder(), "collection.media") - mw.backend.import_collection_package( - col_path=col_path, backup_path=filename, media_folder=media_folder - ) - - def on_done(future: Future) -> None: - mw.progress.finish() - timer.deleteLater() - - try: - future.result() - except Exception as error: - if not isinstance(error, Interrupted): - showWarning(str(error)) - callback(False) - else: - callback(True) - - qconnect(timer.timeout, on_progress) - timer.start() - mw.taskman.run_in_background(do_import, on_done) diff --git a/qt/aqt/main.py b/qt/aqt/main.py index e3b1c5a2c..3d65e7b3e 100644 --- a/qt/aqt/main.py +++ b/qt/aqt/main.py @@ -47,6 +47,8 @@ from aqt.addons import DownloadLogEntry, check_and_prompt_for_updates, show_log_ from aqt.dbcheck import check_db from aqt.emptycards import show_empty_cards from aqt.flags import FlagManager +from aqt.import_export.exporting import ExportDialog +from aqt.import_export.importing import import_collection_package_op, import_file from aqt.legacy import install_pylib_legacy from aqt.mediacheck import check_media_db from aqt.mediasync import MediaSyncer @@ -402,15 +404,12 @@ class AnkiQt(QMainWindow): ) def _openBackup(self, path: str) -> None: - def on_done(success: bool) -> None: - if success: - self.onOpenProfile(callback=lambda: self.col.mod_schema(check=False)) - - import aqt.importing - self.restoring_backup = True showInfo(tr.qt_misc_automatic_syncing_and_backups_have_been()) - aqt.importing.replace_with_apkg(self, path, on_done) + + import_collection_package_op( + self, path, success=self.onOpenProfile + ).run_in_background() def _on_downgrade(self) -> None: self.progress.start() @@ -1183,12 +1182,18 @@ title="{}" {}>{}""".format( def onImport(self) -> None: import aqt.importing - aqt.importing.onImport(self) + if self.pm.new_import_export(): + import_file(self) + else: + aqt.importing.onImport(self) def onExport(self, did: DeckId | None = None) -> None: import aqt.exporting - aqt.exporting.ExportDialog(self, did=did) + if self.pm.new_import_export(): + ExportDialog(self, did=did) + else: + aqt.exporting.ExportDialog(self, did=did) # Installing add-ons from CLI / mimetype handler ########################################################################## diff --git a/qt/aqt/operations/__init__.py b/qt/aqt/operations/__init__.py index 7d7ad8b91..761c6e7ab 100644 --- a/qt/aqt/operations/__init__.py +++ b/qt/aqt/operations/__init__.py @@ -8,16 +8,19 @@ from typing import Any, Callable, Generic, Protocol, TypeVar, Union import aqt import aqt.gui_hooks +import aqt.main from anki.collection import ( Collection, + ImportLogWithChanges, OpChanges, OpChangesAfterUndo, OpChangesWithCount, OpChangesWithId, + Progress, ) from aqt.errors import show_exception +from aqt.progress import ProgressUpdate from aqt.qt import QWidget -from aqt.utils import showWarning class HasChangesProperty(Protocol): @@ -34,6 +37,7 @@ 
ResultWithChanges = TypeVar( OpChangesWithCount, OpChangesWithId, OpChangesAfterUndo, + ImportLogWithChanges, HasChangesProperty, ], ) @@ -65,6 +69,7 @@ class CollectionOp(Generic[ResultWithChanges]): _success: Callable[[ResultWithChanges], Any] | None = None _failure: Callable[[Exception], Any] | None = None + _progress_update: Callable[[Progress, ProgressUpdate], None] | None = None def __init__(self, parent: QWidget, op: Callable[[Collection], ResultWithChanges]): self._parent = parent @@ -82,6 +87,12 @@ class CollectionOp(Generic[ResultWithChanges]): self._failure = failure return self + def with_backend_progress( + self, progress_update: Callable[[Progress, ProgressUpdate], None] | None + ) -> CollectionOp[ResultWithChanges]: + self._progress_update = progress_update + return self + def run_in_background(self, *, initiator: object | None = None) -> None: from aqt import mw @@ -113,12 +124,29 @@ class CollectionOp(Generic[ResultWithChanges]): if self._success: self._success(result) finally: - mw.update_undo_actions() - mw.autosave() - # fire change hooks - self._fire_change_hooks_after_op_performed(result, initiator) + self._finish_op(mw, result, initiator) - mw.taskman.with_progress(wrapped_op, wrapped_done) + self._run(mw, wrapped_op, wrapped_done) + + def _run( + self, + mw: aqt.main.AnkiQt, + op: Callable[[], ResultWithChanges], + on_done: Callable[[Future], None], + ) -> None: + if self._progress_update: + mw.taskman.with_backend_progress( + op, self._progress_update, on_done=on_done, parent=self._parent + ) + else: + mw.taskman.with_progress(op, on_done, parent=self._parent) + + def _finish_op( + self, mw: aqt.main.AnkiQt, result: ResultWithChanges, initiator: object | None + ) -> None: + mw.update_undo_actions() + mw.autosave() + self._fire_change_hooks_after_op_performed(result, initiator) def _fire_change_hooks_after_op_performed( self, @@ -168,6 +196,7 @@ class QueryOp(Generic[T]): _failure: Callable[[Exception], Any] | None = None _progress: bool | str = False + _progress_update: Callable[[Progress, ProgressUpdate], None] | None = None def __init__( self, @@ -192,6 +221,12 @@ class QueryOp(Generic[T]): self._progress = label or True return self + def with_backend_progress( + self, progress_update: Callable[[Progress, ProgressUpdate], None] | None + ) -> QueryOp[T]: + self._progress_update = progress_update + return self + def run_in_background(self) -> None: from aqt import mw @@ -201,26 +236,11 @@ class QueryOp(Generic[T]): def wrapped_op() -> T: assert mw - if self._progress: - label: str | None - if isinstance(self._progress, str): - label = self._progress - else: - label = None - - def start_progress() -> None: - assert mw - mw.progress.start(label=label) - - mw.taskman.run_on_main(start_progress) return self._op(mw.col) def wrapped_done(future: Future) -> None: assert mw - if self._progress: - mw.progress.finish() - mw._decrease_background_ops() # did something go wrong? 
if exception := future.exception(): @@ -228,7 +248,7 @@ class QueryOp(Generic[T]): if self._failure: self._failure(exception) else: - showWarning(str(exception), self._parent) + show_exception(parent=self._parent, exception=exception) return else: # BaseException like SystemExit; rethrow it @@ -236,4 +256,24 @@ class QueryOp(Generic[T]): self._success(future.result()) - mw.taskman.run_in_background(wrapped_op, wrapped_done) + self._run(mw, wrapped_op, wrapped_done) + + def _run( + self, + mw: aqt.main.AnkiQt, + op: Callable[[], T], + on_done: Callable[[Future], None], + ) -> None: + label = self._progress if isinstance(self._progress, str) else None + if self._progress_update: + mw.taskman.with_backend_progress( + op, + self._progress_update, + on_done=on_done, + start_label=label, + parent=self._parent, + ) + elif self._progress: + mw.taskman.with_progress(op, on_done, label=label, parent=self._parent) + else: + mw.taskman.run_in_background(op, on_done) diff --git a/qt/aqt/profiles.py b/qt/aqt/profiles.py index 1d8ce51f5..6d00fb6a3 100644 --- a/qt/aqt/profiles.py +++ b/qt/aqt/profiles.py @@ -538,6 +538,12 @@ create table if not exists profiles def dark_mode_widgets(self) -> bool: return self.meta.get("dark_mode_widgets", False) + def new_import_export(self) -> bool: + return self.meta.get("new_import_export", False) + + def set_new_import_export(self, enabled: bool) -> None: + self.meta["new_import_export"] = enabled + # Profile-specific ###################################################################### diff --git a/qt/aqt/progress.py b/qt/aqt/progress.py index 3aea4f652..3a1b355a9 100644 --- a/qt/aqt/progress.py +++ b/qt/aqt/progress.py @@ -3,9 +3,11 @@ from __future__ import annotations import time +from dataclasses import dataclass import aqt.forms from anki._legacy import print_deprecation_warning +from anki.collection import Progress from aqt.qt import * from aqt.utils import disable_help_button, tr @@ -23,6 +25,7 @@ class ProgressManager: self._busy_cursor_timer: QTimer | None = None self._win: ProgressDialog | None = None self._levels = 0 + self._backend_timer: QTimer | None = None # Safer timers ########################################################################## @@ -166,6 +169,34 @@ class ProgressManager: qconnect(self._show_timer.timeout, self._on_show_timer) return self._win + def start_with_backend_updates( + self, + progress_update: Callable[[Progress, ProgressUpdate], None], + start_label: str | None = None, + parent: QWidget | None = None, + ) -> None: + self._backend_timer = QTimer() + self._backend_timer.setSingleShot(False) + self._backend_timer.setInterval(100) + + if not (dialog := self.start(immediate=True, label=start_label, parent=parent)): + print("Progress dialog already running; aborting will not work") + + def on_progress() -> None: + assert self.mw + + user_wants_abort = dialog and dialog.wantCancel or False + update = ProgressUpdate(user_wants_abort=user_wants_abort) + progress = self.mw.backend.latest_progress() + progress_update(progress, update) + if update.abort: + self.mw.backend.set_wants_abort() + if update.has_update(): + self.update(label=update.label, value=update.value, max=update.max) + + qconnect(self._backend_timer.timeout, on_progress) + self._backend_timer.start() + def update( self, label: str | None = None, @@ -204,6 +235,9 @@ class ProgressManager: if self._show_timer: self._show_timer.stop() self._show_timer = None + if self._backend_timer: + self._backend_timer.deleteLater() + self._backend_timer = None def clear(self) -> None: 
"Restore the interface after an error." @@ -295,3 +329,15 @@ class ProgressDialog(QDialog): if evt.key() == Qt.Key.Key_Escape: evt.ignore() self.wantCancel = True + + +@dataclass +class ProgressUpdate: + label: str | None = None + value: int | None = None + max: int | None = None + user_wants_abort: bool = False + abort: bool = False + + def has_update(self) -> bool: + return self.label is not None or self.value is not None or self.max is not None diff --git a/qt/aqt/taskman.py b/qt/aqt/taskman.py index b763f96d4..283153ee1 100644 --- a/qt/aqt/taskman.py +++ b/qt/aqt/taskman.py @@ -15,6 +15,8 @@ from threading import Lock from typing import Any, Callable import aqt +from anki.collection import Progress +from aqt.progress import ProgressUpdate from aqt.qt import * Closure = Callable[[], None] @@ -89,6 +91,27 @@ class TaskManager(QObject): self.run_in_background(task, wrapped_done) + def with_backend_progress( + self, + task: Callable, + progress_update: Callable[[Progress, ProgressUpdate], None], + on_done: Callable[[Future], None] | None = None, + parent: QWidget | None = None, + start_label: str | None = None, + ) -> None: + self.mw.progress.start_with_backend_updates( + progress_update, + parent=parent, + start_label=start_label, + ) + + def wrapped_done(fut: Future) -> None: + self.mw.progress.finish() + if on_done: + on_done(fut) + + self.run_in_background(task, wrapped_done) + def _on_closures_pending(self) -> None: """Run any pending closures. This runs in the main thread.""" with self._closures_lock: diff --git a/rslib/src/backend/generic.rs b/rslib/src/backend/generic.rs index 6813c141a..b86751bcb 100644 --- a/rslib/src/backend/generic.rs +++ b/rslib/src/backend/generic.rs @@ -69,6 +69,12 @@ impl From for NoteId { } } +impl From for pb::NoteId { + fn from(nid: NoteId) -> Self { + pb::NoteId { nid: nid.0 } + } +} + impl From for NotetypeId { fn from(ntid: pb::NotetypeId) -> Self { NotetypeId(ntid.ntid) diff --git a/rslib/src/backend/import_export.rs b/rslib/src/backend/import_export.rs index 47d99a456..457e40e08 100644 --- a/rslib/src/backend/import_export.rs +++ b/rslib/src/backend/import_export.rs @@ -1,12 +1,18 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +use std::path::Path; + use super::{progress::Progress, Backend}; pub(super) use crate::backend_proto::importexport_service::Service as ImportExportService; use crate::{ - backend_proto::{self as pb}, - import_export::{package::import_colpkg, ImportProgress}, + backend_proto::{self as pb, export_anki_package_request::Selector}, + import_export::{ + package::{import_colpkg, NoteLog}, + ImportProgress, + }, prelude::*, + search::SearchNode, }; impl ImportExportService for Backend { @@ -38,30 +44,68 @@ impl ImportExportService for Backend { import_colpkg( &input.backup_path, &input.col_path, - &input.media_folder, + Path::new(&input.media_folder), + Path::new(&input.media_db), self.import_progress_fn(), + &self.log, ) .map(Into::into) } + + fn import_anki_package( + &self, + input: pb::ImportAnkiPackageRequest, + ) -> Result { + self.with_col(|col| col.import_apkg(&input.package_path, self.import_progress_fn())) + .map(Into::into) + } + + fn export_anki_package(&self, input: pb::ExportAnkiPackageRequest) -> Result { + let selector = input + .selector + .ok_or_else(|| AnkiError::invalid_input("missing oneof"))?; + self.with_col(|col| { + col.export_apkg( + &input.out_path, + SearchNode::from_selector(selector), + input.with_scheduling, + 
input.with_media, + input.legacy, + None, + self.export_progress_fn(), + ) + }) + .map(Into::into) + } } -impl Backend { - fn import_progress_fn(&self) -> impl FnMut(ImportProgress) -> Result<()> { - let mut handler = self.new_progress_handler(); - move |progress| { - let throttle = matches!(progress, ImportProgress::Media(_)); - if handler.update(Progress::Import(progress), throttle) { - Ok(()) - } else { - Err(AnkiError::Interrupted) - } - } - } - - fn export_progress_fn(&self) -> impl FnMut(usize) { - let mut handler = self.new_progress_handler(); - move |media_files| { - handler.update(Progress::Export(media_files), true); +impl SearchNode { + fn from_selector(selector: Selector) -> Self { + match selector { + Selector::WholeCollection(_) => Self::WholeCollection, + Selector::DeckId(did) => Self::from_deck_id(did, true), + Selector::NoteIds(nids) => Self::from_note_ids(nids.note_ids), + } + } +} + +impl Backend { + fn import_progress_fn(&self) -> impl FnMut(ImportProgress, bool) -> bool { + let mut handler = self.new_progress_handler(); + move |progress, throttle| handler.update(Progress::Import(progress), throttle) + } + + fn export_progress_fn(&self) -> impl FnMut(usize, bool) -> bool { + let mut handler = self.new_progress_handler(); + move |progress, throttle| handler.update(Progress::Export(progress), throttle) + } +} + +impl From> for pb::ImportAnkiPackageResponse { + fn from(output: OpOutput) -> Self { + Self { + changes: Some(output.changes.into()), + log: Some(output.output), } } } diff --git a/rslib/src/backend/progress.rs b/rslib/src/backend/progress.rs index a21bede23..733d06a67 100644 --- a/rslib/src/backend/progress.rs +++ b/rslib/src/backend/progress.rs @@ -108,8 +108,10 @@ pub(super) fn progress_to_proto(progress: Option, tr: &I18n) -> pb::Pr } Progress::Import(progress) => pb::progress::Value::Importing( match progress { - ImportProgress::Collection => tr.importing_importing_collection(), + ImportProgress::File => tr.importing_importing_file(), ImportProgress::Media(n) => tr.importing_processed_media_file(n), + ImportProgress::MediaCheck(n) => tr.media_check_checked(n), + ImportProgress::Notes(n) => tr.importing_processed_notes(n), } .into(), ), diff --git a/rslib/src/card/undo.rs b/rslib/src/card/undo.rs index c0bb0a035..e83f096f7 100644 --- a/rslib/src/card/undo.rs +++ b/rslib/src/card/undo.rs @@ -35,6 +35,14 @@ impl Collection { Ok(()) } + pub(crate) fn add_card_if_unique_undoable(&mut self, card: &Card) -> Result { + let added = self.storage.add_card_if_unique(card)?; + if added { + self.save_undo(UndoableCardChange::Added(Box::new(card.clone()))); + } + Ok(added) + } + pub(super) fn update_card_undoable(&mut self, card: &mut Card, original: Card) -> Result<()> { if card.id.0 == 0 { return Err(AnkiError::invalid_input("card id not set")); diff --git a/rslib/src/deckconfig/undo.rs b/rslib/src/deckconfig/undo.rs index cf14c8586..0a0ed7a2d 100644 --- a/rslib/src/deckconfig/undo.rs +++ b/rslib/src/deckconfig/undo.rs @@ -44,6 +44,13 @@ impl Collection { Ok(()) } + pub(crate) fn add_deck_config_if_unique_undoable(&mut self, config: &DeckConfig) -> Result<()> { + if self.storage.add_deck_conf_if_unique(config)? 
{ + self.save_undo(UndoableDeckConfigChange::Added(Box::new(config.clone()))); + } + Ok(()) + } + pub(super) fn update_deck_config_undoable( &mut self, config: &DeckConfig, diff --git a/rslib/src/decks/addupdate.rs b/rslib/src/decks/addupdate.rs index cf2975478..159c57abd 100644 --- a/rslib/src/decks/addupdate.rs +++ b/rslib/src/decks/addupdate.rs @@ -148,7 +148,7 @@ impl Collection { Ok(()) } - fn first_existing_parent( + pub(crate) fn first_existing_parent( &self, machine_name: &str, recursion_level: usize, diff --git a/rslib/src/import_export/gather.rs b/rslib/src/import_export/gather.rs new file mode 100644 index 000000000..af3b71006 --- /dev/null +++ b/rslib/src/import_export/gather.rs @@ -0,0 +1,222 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::collections::{HashMap, HashSet}; + +use itertools::Itertools; + +use crate::{ + decks::immediate_parent_name, + io::filename_is_safe, + latex::extract_latex, + prelude::*, + revlog::RevlogEntry, + text::{extract_media_refs, extract_underscored_css_imports, extract_underscored_references}, +}; + +#[derive(Debug, Default)] +pub(super) struct ExchangeData { + pub(super) decks: Vec, + pub(super) notes: Vec, + pub(super) cards: Vec, + pub(super) notetypes: Vec, + pub(super) revlog: Vec, + pub(super) deck_configs: Vec, + pub(super) media_filenames: HashSet, + pub(super) days_elapsed: u32, + pub(super) creation_utc_offset: Option, +} + +impl ExchangeData { + pub(super) fn gather_data( + &mut self, + col: &mut Collection, + search: impl TryIntoSearch, + with_scheduling: bool, + ) -> Result<()> { + self.days_elapsed = col.timing_today()?.days_elapsed; + self.creation_utc_offset = col.get_creation_utc_offset(); + self.notes = col.gather_notes(search)?; + self.cards = col.gather_cards()?; + self.decks = col.gather_decks()?; + self.notetypes = col.gather_notetypes()?; + + if with_scheduling { + self.revlog = col.gather_revlog()?; + self.deck_configs = col.gather_deck_configs(&self.decks)?; + } else { + self.remove_scheduling_information(col); + }; + + col.storage.clear_searched_notes_table()?; + col.storage.clear_searched_cards_table() + } + + pub(super) fn gather_media_names(&mut self) { + let mut inserter = |name: String| { + if filename_is_safe(&name) { + self.media_filenames.insert(name); + } + }; + let svg_getter = svg_getter(&self.notetypes); + for note in self.notes.iter() { + gather_media_names_from_note(note, &mut inserter, &svg_getter); + } + for notetype in self.notetypes.iter() { + gather_media_names_from_notetype(notetype, &mut inserter); + } + } + + fn remove_scheduling_information(&mut self, col: &Collection) { + self.remove_system_tags(); + self.reset_deck_config_ids(); + self.reset_cards(col); + } + + fn remove_system_tags(&mut self) { + const SYSTEM_TAGS: [&str; 2] = ["marked", "leech"]; + for note in self.notes.iter_mut() { + note.tags = std::mem::take(&mut note.tags) + .into_iter() + .filter(|tag| !SYSTEM_TAGS.iter().any(|s| tag.eq_ignore_ascii_case(s))) + .collect(); + } + } + + fn reset_deck_config_ids(&mut self) { + for deck in self.decks.iter_mut() { + if let Ok(normal_mut) = deck.normal_mut() { + normal_mut.config_id = 1; + } else { + // filtered decks are reset at import time for legacy reasons + } + } + } + + fn reset_cards(&mut self, col: &Collection) { + let mut position = col.get_next_card_position(); + for card in self.cards.iter_mut() { + // schedule_as_new() removes cards from filtered decks, but we want to + // leave cards 
in their current deck, which gets converted to a regular + // deck on import + let deck_id = card.deck_id; + if card.schedule_as_new(position, true, true) { + position += 1; + } + card.flags = 0; + card.deck_id = deck_id; + } + } +} + +fn gather_media_names_from_note( + note: &Note, + inserter: &mut impl FnMut(String), + svg_getter: &impl Fn(NotetypeId) -> bool, +) { + for field in note.fields() { + for media_ref in extract_media_refs(field) { + inserter(media_ref.fname_decoded.to_string()); + } + + for latex in extract_latex(field, svg_getter(note.notetype_id)).1 { + inserter(latex.fname); + } + } +} + +fn gather_media_names_from_notetype(notetype: &Notetype, inserter: &mut impl FnMut(String)) { + for name in extract_underscored_css_imports(¬etype.config.css) { + inserter(name.to_string()); + } + for template in ¬etype.templates { + for template_side in [&template.config.q_format, &template.config.a_format] { + for name in extract_underscored_references(template_side) { + inserter(name.to_string()); + } + } + } +} + +fn svg_getter(notetypes: &[Notetype]) -> impl Fn(NotetypeId) -> bool { + let svg_map: HashMap = notetypes + .iter() + .map(|nt| (nt.id, nt.config.latex_svg)) + .collect(); + move |nt_id| svg_map.get(&nt_id).copied().unwrap_or_default() +} + +impl Collection { + fn gather_notes(&mut self, search: impl TryIntoSearch) -> Result> { + self.search_notes_into_table(search)?; + self.storage.all_searched_notes() + } + + fn gather_cards(&mut self) -> Result> { + self.storage.search_cards_of_notes_into_table()?; + self.storage.all_searched_cards() + } + + fn gather_decks(&mut self) -> Result> { + let decks = self.storage.get_decks_for_search_cards()?; + let parents = self.get_parent_decks(&decks)?; + Ok(decks + .into_iter() + .filter(|deck| deck.id != DeckId(1)) + .chain(parents) + .collect()) + } + + fn get_parent_decks(&mut self, decks: &[Deck]) -> Result> { + let mut parent_names: HashSet = decks + .iter() + .map(|deck| deck.name.as_native_str().to_owned()) + .collect(); + let mut parents = Vec::new(); + for deck in decks { + self.add_parent_decks(deck.name.as_native_str(), &mut parent_names, &mut parents)?; + } + Ok(parents) + } + + fn add_parent_decks( + &mut self, + name: &str, + parent_names: &mut HashSet, + parents: &mut Vec, + ) -> Result<()> { + if let Some(parent_name) = immediate_parent_name(name) { + if parent_names.insert(parent_name.to_owned()) { + parents.push( + self.storage + .get_deck_by_name(parent_name)? + .ok_or(AnkiError::DatabaseCheckRequired)?, + ); + self.add_parent_decks(parent_name, parent_names, parents)?; + } + } + Ok(()) + } + + fn gather_notetypes(&mut self) -> Result> { + self.storage.get_notetypes_for_search_notes() + } + + fn gather_revlog(&mut self) -> Result> { + self.storage.get_revlog_entries_for_searched_cards() + } + + fn gather_deck_configs(&mut self, decks: &[Deck]) -> Result> { + decks + .iter() + .filter_map(|deck| deck.config_id()) + .unique() + .filter(|config_id| *config_id != DeckConfigId(1)) + .map(|config_id| { + self.storage + .get_deck_config(config_id)? 
+ .ok_or(AnkiError::NotFound) + }) + .collect() + } +} diff --git a/rslib/src/import_export/insert.rs b/rslib/src/import_export/insert.rs new file mode 100644 index 000000000..21ab1be1c --- /dev/null +++ b/rslib/src/import_export/insert.rs @@ -0,0 +1,62 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use super::gather::ExchangeData; +use crate::{prelude::*, revlog::RevlogEntry}; + +impl Collection { + pub(super) fn insert_data(&mut self, data: &ExchangeData) -> Result<()> { + self.transact_no_undo(|col| { + col.insert_decks(&data.decks)?; + col.insert_notes(&data.notes)?; + col.insert_cards(&data.cards)?; + col.insert_notetypes(&data.notetypes)?; + col.insert_revlog(&data.revlog)?; + col.insert_deck_configs(&data.deck_configs) + }) + } + + fn insert_decks(&self, decks: &[Deck]) -> Result<()> { + for deck in decks { + self.storage.add_or_update_deck_with_existing_id(deck)?; + } + Ok(()) + } + + fn insert_notes(&self, notes: &[Note]) -> Result<()> { + for note in notes { + self.storage.add_or_update_note(note)?; + } + Ok(()) + } + + fn insert_cards(&self, cards: &[Card]) -> Result<()> { + for card in cards { + self.storage.add_or_update_card(card)?; + } + Ok(()) + } + + fn insert_notetypes(&self, notetypes: &[Notetype]) -> Result<()> { + for notetype in notetypes { + self.storage + .add_or_update_notetype_with_existing_id(notetype)?; + } + Ok(()) + } + + fn insert_revlog(&self, revlog: &[RevlogEntry]) -> Result<()> { + for entry in revlog { + self.storage.add_revlog_entry(entry, false)?; + } + Ok(()) + } + + fn insert_deck_configs(&self, configs: &[DeckConfig]) -> Result<()> { + for config in configs { + self.storage + .add_or_update_deck_config_with_existing_id(config)?; + } + Ok(()) + } +} diff --git a/rslib/src/import_export/mod.rs b/rslib/src/import_export/mod.rs index 994d93101..b1e2944af 100644 --- a/rslib/src/import_export/mod.rs +++ b/rslib/src/import_export/mod.rs @@ -1,10 +1,87 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +mod gather; +mod insert; pub mod package; +use std::marker::PhantomData; + +use crate::prelude::*; + #[derive(Debug, Clone, Copy, PartialEq)] pub enum ImportProgress { - Collection, + File, Media(usize), + MediaCheck(usize), + Notes(usize), +} + +/// Wrapper around a progress function, usually passed by the [crate::backend::Backend], +/// to make repeated calls more ergonomic. +pub(crate) struct IncrementableProgress

<P>(Box<dyn FnMut(P, bool) -> bool>); + +impl<P> IncrementableProgress<P>

{ + /// `progress_fn: (progress, throttle) -> should_continue` + pub(crate) fn new(progress_fn: impl 'static + FnMut(P, bool) -> bool) -> Self { + Self(Box::new(progress_fn)) + } + + /// Returns an [Incrementor] with an `increment()` function for use in loops. + pub(crate) fn incrementor<'inc, 'progress: 'inc, 'map: 'inc>( + &'progress mut self, + mut count_map: impl 'map + FnMut(usize) -> P, + ) -> Incrementor<'inc, impl FnMut(usize) -> Result<()> + 'inc> { + Incrementor::new(move |u| self.update(count_map(u), true)) + } + + /// Manually triggers an update. + /// Returns [AnkiError::Interrupted] if the operation should be cancelled. + pub(crate) fn call(&mut self, progress: P) -> Result<()> { + self.update(progress, false) + } + + fn update(&mut self, progress: P, throttle: bool) -> Result<()> { + if (self.0)(progress, throttle) { + Ok(()) + } else { + Err(AnkiError::Interrupted) + } + } + + /// Stopgap for returning a progress fn compliant with the media code. + pub(crate) fn media_db_fn( + &mut self, + count_map: impl 'static + Fn(usize) -> P, + ) -> Result bool + '_> { + Ok(move |count| (self.0)(count_map(count), true)) + } +} + +pub(crate) struct Incrementor<'f, F: 'f + FnMut(usize) -> Result<()>> { + update_fn: F, + count: usize, + update_interval: usize, + _phantom: PhantomData<&'f ()>, +} + +impl<'f, F: 'f + FnMut(usize) -> Result<()>> Incrementor<'f, F> { + fn new(update_fn: F) -> Self { + Self { + update_fn, + count: 0, + update_interval: 17, + _phantom: PhantomData, + } + } + + /// Increments the progress counter, periodically triggering an update. + /// Returns [AnkiError::Interrupted] if the operation should be cancelled. + pub(crate) fn increment(&mut self) -> Result<()> { + self.count += 1; + if self.count % self.update_interval != 0 { + return Ok(()); + } + (self.update_fn)(self.count) + } } diff --git a/rslib/src/import_export/package/apkg/export.rs b/rslib/src/import_export/package/apkg/export.rs new file mode 100644 index 000000000..9b797893d --- /dev/null +++ b/rslib/src/import_export/package/apkg/export.rs @@ -0,0 +1,106 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::{ + collections::HashSet, + path::{Path, PathBuf}, +}; + +use tempfile::NamedTempFile; + +use crate::{ + collection::CollectionBuilder, + import_export::{ + gather::ExchangeData, + package::{ + colpkg::export::{export_collection, MediaIter}, + Meta, + }, + IncrementableProgress, + }, + io::{atomic_rename, tempfile_in_parent_of}, + prelude::*, +}; + +impl Collection { + /// Returns number of exported notes. 
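    // Hypothetical caller, as a sketch (not taken from this patch): `progress_fn`
    // is wrapped in an IncrementableProgress below, so it receives the running
    // count of exported media files plus a throttle flag, and returning false
    // aborts the export with AnkiError::Interrupted:
    //
    //   col.export_apkg(path, SearchNode::WholeCollection, true, true, false, None,
    //       |count, _throttle| { println!("exported {count} media files"); true })?;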
+ #[allow(clippy::too_many_arguments)] + pub fn export_apkg( + &mut self, + out_path: impl AsRef, + search: impl TryIntoSearch, + with_scheduling: bool, + with_media: bool, + legacy: bool, + media_fn: Option) -> MediaIter>>, + progress_fn: impl 'static + FnMut(usize, bool) -> bool, + ) -> Result { + let mut progress = IncrementableProgress::new(progress_fn); + let temp_apkg = tempfile_in_parent_of(out_path.as_ref())?; + let mut temp_col = NamedTempFile::new()?; + let temp_col_path = temp_col + .path() + .to_str() + .ok_or_else(|| AnkiError::IoError("tempfile with non-unicode name".into()))?; + let meta = if legacy { + Meta::new_legacy() + } else { + Meta::new() + }; + let data = self.export_into_collection_file( + &meta, + temp_col_path, + search, + with_scheduling, + with_media, + )?; + + let media = if let Some(media_fn) = media_fn { + media_fn(data.media_filenames) + } else { + MediaIter::from_file_list(data.media_filenames, self.media_folder.clone()) + }; + let col_size = temp_col.as_file().metadata()?.len() as usize; + + export_collection( + meta, + temp_apkg.path(), + &mut temp_col, + col_size, + media, + &self.tr, + &mut progress, + )?; + atomic_rename(temp_apkg, out_path.as_ref(), true)?; + Ok(data.notes.len()) + } + + fn export_into_collection_file( + &mut self, + meta: &Meta, + path: &str, + search: impl TryIntoSearch, + with_scheduling: bool, + with_media: bool, + ) -> Result { + let mut data = ExchangeData::default(); + data.gather_data(self, search, with_scheduling)?; + if with_media { + data.gather_media_names(); + } + + let mut temp_col = Collection::new_minimal(path)?; + temp_col.insert_data(&data)?; + temp_col.set_creation_stamp(self.storage.creation_stamp()?)?; + temp_col.set_creation_utc_offset(data.creation_utc_offset)?; + temp_col.close(Some(meta.schema_version()))?; + + Ok(data) + } + + fn new_minimal(path: impl Into) -> Result { + let col = CollectionBuilder::new(path).build()?; + col.storage.db.execute_batch("DELETE FROM notetypes")?; + Ok(col) + } +} diff --git a/rslib/src/import_export/package/apkg/import/cards.rs b/rslib/src/import_export/package/apkg/import/cards.rs new file mode 100644 index 000000000..6a87510c1 --- /dev/null +++ b/rslib/src/import_export/package/apkg/import/cards.rs @@ -0,0 +1,179 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::{ + collections::{HashMap, HashSet}, + mem, +}; + +use super::Context; +use crate::{ + card::{CardQueue, CardType}, + config::SchedulerVersion, + prelude::*, + revlog::RevlogEntry, +}; + +type CardAsNidAndOrd = (NoteId, u16); + +struct CardContext<'a> { + target_col: &'a mut Collection, + usn: Usn, + + imported_notes: &'a HashMap, + remapped_decks: &'a HashMap, + + /// The number of days the source collection is ahead of the target collection + collection_delta: i32, + scheduler_version: SchedulerVersion, + existing_cards: HashSet, + existing_card_ids: HashSet, + + imported_cards: HashMap, +} + +impl<'c> CardContext<'c> { + fn new<'a: 'c>( + usn: Usn, + days_elapsed: u32, + target_col: &'a mut Collection, + imported_notes: &'a HashMap, + imported_decks: &'a HashMap, + ) -> Result { + let existing_cards = target_col.storage.all_cards_as_nid_and_ord()?; + let collection_delta = target_col.collection_delta(days_elapsed)?; + let scheduler_version = target_col.scheduler_info()?.version; + let existing_card_ids = target_col.storage.get_all_card_ids()?; + Ok(Self { + target_col, + usn, + imported_notes, + remapped_decks: 
imported_decks, + existing_cards, + collection_delta, + scheduler_version, + existing_card_ids, + imported_cards: HashMap::new(), + }) + } +} + +impl Collection { + /// How much `days_elapsed` is ahead of this collection. + fn collection_delta(&mut self, days_elapsed: u32) -> Result { + Ok(days_elapsed as i32 - self.timing_today()?.days_elapsed as i32) + } +} + +impl Context<'_> { + pub(super) fn import_cards_and_revlog( + &mut self, + imported_notes: &HashMap, + imported_decks: &HashMap, + ) -> Result<()> { + let mut ctx = CardContext::new( + self.usn, + self.data.days_elapsed, + self.target_col, + imported_notes, + imported_decks, + )?; + ctx.import_cards(mem::take(&mut self.data.cards))?; + ctx.import_revlog(mem::take(&mut self.data.revlog)) + } +} + +impl CardContext<'_> { + fn import_cards(&mut self, mut cards: Vec) -> Result<()> { + for card in &mut cards { + if self.map_to_imported_note(card) && !self.card_ordinal_already_exists(card) { + self.add_card(card)?; + } + // TODO: could update existing card + } + Ok(()) + } + + fn import_revlog(&mut self, revlog: Vec) -> Result<()> { + for mut entry in revlog { + if let Some(cid) = self.imported_cards.get(&entry.cid) { + entry.cid = *cid; + entry.usn = self.usn; + self.target_col.add_revlog_entry_if_unique_undoable(entry)?; + } + } + Ok(()) + } + + fn map_to_imported_note(&self, card: &mut Card) -> bool { + if let Some(nid) = self.imported_notes.get(&card.note_id) { + card.note_id = *nid; + true + } else { + false + } + } + + fn card_ordinal_already_exists(&self, card: &Card) -> bool { + self.existing_cards + .contains(&(card.note_id, card.template_idx)) + } + + fn add_card(&mut self, card: &mut Card) -> Result<()> { + card.usn = self.usn; + self.remap_deck_id(card); + card.shift_collection_relative_dates(self.collection_delta); + card.maybe_remove_from_filtered_deck(self.scheduler_version); + let old_id = self.uniquify_card_id(card); + + self.target_col.add_card_if_unique_undoable(card)?; + self.existing_card_ids.insert(card.id); + self.imported_cards.insert(old_id, card.id); + + Ok(()) + } + + fn uniquify_card_id(&mut self, card: &mut Card) -> CardId { + let original = card.id; + while self.existing_card_ids.contains(&card.id) { + card.id.0 += 999; + } + original + } + + fn remap_deck_id(&self, card: &mut Card) { + if let Some(did) = self.remapped_decks.get(&card.deck_id) { + card.deck_id = *did; + } + } +} + +impl Card { + /// `delta` is the number days the card's source collection is ahead of the + /// target collection. 
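    // Worked example: if the source collection's days_elapsed is 100 and the
    // target's is 90, delta is 10; a review card stored as due on source day 105
    // (due in 5 days) is shifted to day 95, so it stays due in 5 days relative to
    // the target collection's creation.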
+ fn shift_collection_relative_dates(&mut self, delta: i32) { + if self.due_in_days_since_collection_creation() { + self.due -= delta; + } + if self.original_due_in_days_since_collection_creation() && self.original_due != 0 { + self.original_due -= delta; + } + } + + fn due_in_days_since_collection_creation(&self) -> bool { + matches!(self.queue, CardQueue::Review | CardQueue::DayLearn) + || self.ctype == CardType::Review + } + + fn original_due_in_days_since_collection_creation(&self) -> bool { + self.ctype == CardType::Review + } + + fn maybe_remove_from_filtered_deck(&mut self, version: SchedulerVersion) { + if self.is_filtered() { + // instead of moving between decks, the deck is converted to a regular one + self.original_deck_id = self.deck_id; + self.remove_from_filtered_deck_restoring_queue(version); + } + } +} diff --git a/rslib/src/import_export/package/apkg/import/decks.rs b/rslib/src/import_export/package/apkg/import/decks.rs new file mode 100644 index 000000000..c9717963f --- /dev/null +++ b/rslib/src/import_export/package/apkg/import/decks.rs @@ -0,0 +1,212 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::{collections::HashMap, mem}; + +use super::Context; +use crate::{decks::NormalDeck, prelude::*}; + +struct DeckContext<'d> { + target_col: &'d mut Collection, + usn: Usn, + renamed_parents: Vec<(String, String)>, + imported_decks: HashMap, + unique_suffix: String, +} + +impl<'d> DeckContext<'d> { + fn new<'a: 'd>(target_col: &'a mut Collection, usn: Usn) -> Self { + Self { + target_col, + usn, + renamed_parents: Vec::new(), + imported_decks: HashMap::new(), + unique_suffix: TimestampSecs::now().to_string(), + } + } +} + +impl Context<'_> { + pub(super) fn import_decks_and_configs(&mut self) -> Result> { + let mut ctx = DeckContext::new(self.target_col, self.usn); + ctx.import_deck_configs(mem::take(&mut self.data.deck_configs))?; + ctx.import_decks(mem::take(&mut self.data.decks))?; + Ok(ctx.imported_decks) + } +} + +impl DeckContext<'_> { + fn import_deck_configs(&mut self, mut configs: Vec) -> Result<()> { + for config in &mut configs { + config.usn = self.usn; + self.target_col.add_deck_config_if_unique_undoable(config)?; + } + Ok(()) + } + + fn import_decks(&mut self, mut decks: Vec) -> Result<()> { + // ensure parents are seen before children + decks.sort_unstable_by_key(|deck| deck.level()); + for deck in &mut decks { + self.prepare_deck(deck); + self.import_deck(deck)?; + } + Ok(()) + } + + fn prepare_deck(&mut self, deck: &mut Deck) { + self.maybe_reparent(deck); + if deck.is_filtered() { + deck.kind = DeckKind::Normal(NormalDeck { + config_id: 1, + ..Default::default() + }); + } + } + + fn import_deck(&mut self, deck: &mut Deck) -> Result<()> { + if let Some(original) = self.get_deck_by_name(deck)? 
{ + if original.is_filtered() { + self.uniquify_name(deck); + self.add_deck(deck) + } else { + self.update_deck(deck, original) + } + } else { + self.ensure_valid_first_existing_parent(deck)?; + self.add_deck(deck) + } + } + + fn maybe_reparent(&self, deck: &mut Deck) { + if let Some(new_name) = self.reparented_name(deck.name.as_native_str()) { + deck.name = NativeDeckName::from_native_str(new_name); + } + } + + fn reparented_name(&self, name: &str) -> Option { + self.renamed_parents + .iter() + .find_map(|(old_parent, new_parent)| { + name.starts_with(old_parent) + .then(|| name.replacen(old_parent, new_parent, 1)) + }) + } + + fn get_deck_by_name(&mut self, deck: &Deck) -> Result> { + self.target_col + .storage + .get_deck_by_name(deck.name.as_native_str()) + } + + fn uniquify_name(&mut self, deck: &mut Deck) { + let old_parent = format!("{}\x1f", deck.name.as_native_str()); + deck.uniquify_name(&self.unique_suffix); + let new_parent = format!("{}\x1f", deck.name.as_native_str()); + self.renamed_parents.push((old_parent, new_parent)); + } + + fn add_deck(&mut self, deck: &mut Deck) -> Result<()> { + let old_id = mem::take(&mut deck.id); + self.target_col.add_deck_inner(deck, self.usn)?; + self.imported_decks.insert(old_id, deck.id); + Ok(()) + } + + /// Caller must ensure decks are normal. + fn update_deck(&mut self, deck: &Deck, original: Deck) -> Result<()> { + let mut new_deck = original.clone(); + new_deck.normal_mut()?.update_with_other(deck.normal()?); + self.imported_decks.insert(deck.id, new_deck.id); + self.target_col + .update_deck_inner(&mut new_deck, original, self.usn) + } + + fn ensure_valid_first_existing_parent(&mut self, deck: &mut Deck) -> Result<()> { + if let Some(ancestor) = self + .target_col + .first_existing_parent(deck.name.as_native_str(), 0)? 
+ { + if ancestor.is_filtered() { + self.add_unique_default_deck(ancestor.name.as_native_str())?; + self.maybe_reparent(deck); + } + } + Ok(()) + } + + fn add_unique_default_deck(&mut self, name: &str) -> Result<()> { + let mut deck = Deck::new_normal(); + deck.name = NativeDeckName::from_native_str(name); + self.uniquify_name(&mut deck); + self.target_col.add_deck_inner(&mut deck, self.usn) + } +} + +impl Deck { + fn uniquify_name(&mut self, suffix: &str) { + let new_name = format!("{} {}", self.name.as_native_str(), suffix); + self.name = NativeDeckName::from_native_str(new_name); + } + + fn level(&self) -> usize { + self.name.components().count() + } +} + +impl NormalDeck { + fn update_with_other(&mut self, other: &Self) { + if !other.description.is_empty() { + self.markdown_description = other.markdown_description; + self.description = other.description.clone(); + } + if other.config_id != 1 { + self.config_id = other.config_id; + } + } +} + +#[cfg(test)] +mod test { + use std::collections::HashSet; + + use super::*; + use crate::{collection::open_test_collection, tests::new_deck_with_machine_name}; + + #[test] + fn parents() { + let mut col = open_test_collection(); + + col.add_deck_with_machine_name("filtered", true); + col.add_deck_with_machine_name("PARENT", false); + + let mut ctx = DeckContext::new(&mut col, Usn(1)); + ctx.unique_suffix = "★".to_string(); + + let imports = vec![ + new_deck_with_machine_name("unknown parent\x1fchild", false), + new_deck_with_machine_name("filtered\x1fchild", false), + new_deck_with_machine_name("parent\x1fchild", false), + new_deck_with_machine_name("NEW PARENT\x1fchild", false), + new_deck_with_machine_name("new parent", false), + ]; + ctx.import_decks(imports).unwrap(); + let existing_decks: HashSet<_> = ctx + .target_col + .get_all_deck_names(true) + .unwrap() + .into_iter() + .map(|(_, name)| name) + .collect(); + + // missing parents get created + assert!(existing_decks.contains("unknown parent")); + // ... and uniquified if their existing counterparts are filtered + assert!(existing_decks.contains("filtered ★")); + assert!(existing_decks.contains("filtered ★::child")); + // the case of existing parents is matched + assert!(existing_decks.contains("PARENT::child")); + // the case of imported parents is matched, regardless of pass order + assert!(existing_decks.contains("new parent::child")); + } +} diff --git a/rslib/src/import_export/package/apkg/import/media.rs b/rslib/src/import_export/package/apkg/import/media.rs new file mode 100644 index 000000000..bffd28f2f --- /dev/null +++ b/rslib/src/import_export/package/apkg/import/media.rs @@ -0,0 +1,134 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::{collections::HashMap, fs::File, mem}; + +use zip::ZipArchive; + +use super::Context; +use crate::{ + import_export::{ + package::{ + media::{extract_media_entries, SafeMediaEntry}, + Meta, + }, + ImportProgress, IncrementableProgress, + }, + media::{ + files::{add_hash_suffix_to_file_stem, sha1_of_reader}, + MediaManager, + }, + prelude::*, +}; + +/// Map of source media files, that do not already exist in the target. +#[derive(Default)] +pub(super) struct MediaUseMap { + /// original, normalized filename → (refererenced on import material, + /// entry with possibly remapped filename) + checked: HashMap, + /// Static files (latex, underscored). 
Usage is not tracked, and if the name + /// already exists in the target, it is skipped regardless of content equality. + unchecked: Vec, +} + +impl Context<'_> { + pub(super) fn prepare_media(&mut self) -> Result { + let db_progress_fn = self.progress.media_db_fn(ImportProgress::MediaCheck)?; + let existing_sha1s = self.target_col.all_existing_sha1s(db_progress_fn)?; + prepare_media( + &self.meta, + &mut self.archive, + &existing_sha1s, + &mut self.progress, + ) + } + + pub(super) fn copy_media(&mut self, media_map: &mut MediaUseMap) -> Result<()> { + let mut incrementor = self.progress.incrementor(ImportProgress::Media); + for entry in media_map.used_entries() { + incrementor.increment()?; + entry.copy_from_archive(&mut self.archive, &self.target_col.media_folder)?; + } + Ok(()) + } +} + +impl Collection { + fn all_existing_sha1s( + &mut self, + progress_fn: impl FnMut(usize) -> bool, + ) -> Result> { + let mgr = MediaManager::new(&self.media_folder, &self.media_db)?; + mgr.all_checksums(progress_fn, &self.log) + } +} + +fn prepare_media( + meta: &Meta, + archive: &mut ZipArchive, + existing_sha1s: &HashMap, + progress: &mut IncrementableProgress, +) -> Result { + let mut media_map = MediaUseMap::default(); + let mut incrementor = progress.incrementor(ImportProgress::MediaCheck); + + for mut entry in extract_media_entries(meta, archive)? { + incrementor.increment()?; + + if entry.is_static() { + if !existing_sha1s.contains_key(&entry.name) { + media_map.unchecked.push(entry); + } + } else if let Some(other_sha1) = existing_sha1s.get(&entry.name) { + entry.with_hash_from_archive(archive)?; + if entry.sha1 != *other_sha1 { + let original_name = entry.uniquify_name(); + media_map.add_checked(original_name, entry); + } + } else { + media_map.add_checked(entry.name.clone(), entry); + } + } + Ok(media_map) +} + +impl MediaUseMap { + pub(super) fn add_checked(&mut self, filename: impl Into, entry: SafeMediaEntry) { + self.checked.insert(filename.into(), (false, entry)); + } + + pub(super) fn use_entry(&mut self, filename: &str) -> Option<&SafeMediaEntry> { + self.checked.get_mut(filename).map(|(used, entry)| { + *used = true; + &*entry + }) + } + + pub(super) fn used_entries(&self) -> impl Iterator { + self.checked + .values() + .filter_map(|(used, entry)| used.then(|| entry)) + .chain(self.unchecked.iter()) + } +} + +impl SafeMediaEntry { + fn with_hash_from_archive(&mut self, archive: &mut ZipArchive) -> Result<()> { + if self.sha1 == [0; 20] { + let mut reader = self.fetch_file(archive)?; + self.sha1 = sha1_of_reader(&mut reader)?; + } + Ok(()) + } + + /// Requires sha1 to be set. Returns old file name. 
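    // For example, if the target already holds a different "sample.mp3", the
    // incoming entry is renamed by appending the hex SHA1 of its data to the file
    // stem, i.e. roughly format!("sample-{}.mp3", hex::encode(&self.sha1)); the
    // apkg roundtrip test later asserts exactly this name.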
+ fn uniquify_name(&mut self) -> String { + let new_name = add_hash_suffix_to_file_stem(&self.name, &self.sha1); + mem::replace(&mut self.name, new_name) + } + + fn is_static(&self) -> bool { + self.name.starts_with('_') || self.name.starts_with("latex-") + } +} diff --git a/rslib/src/import_export/package/apkg/import/mod.rs b/rslib/src/import_export/package/apkg/import/mod.rs new file mode 100644 index 000000000..810325b7e --- /dev/null +++ b/rslib/src/import_export/package/apkg/import/mod.rs @@ -0,0 +1,137 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +mod cards; +mod decks; +mod media; +mod notes; + +use std::{fs::File, io, path::Path}; + +pub(crate) use notes::NoteMeta; +use rusqlite::OptionalExtension; +use tempfile::NamedTempFile; +use zip::ZipArchive; +use zstd::stream::copy_decode; + +use crate::{ + collection::CollectionBuilder, + import_export::{ + gather::ExchangeData, + package::{Meta, NoteLog}, + ImportProgress, IncrementableProgress, + }, + prelude::*, + search::SearchNode, +}; + +struct Context<'a> { + target_col: &'a mut Collection, + archive: ZipArchive, + meta: Meta, + data: ExchangeData, + usn: Usn, + progress: IncrementableProgress, +} + +impl Collection { + pub fn import_apkg( + &mut self, + path: impl AsRef, + progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool, + ) -> Result> { + let file = File::open(path)?; + let archive = ZipArchive::new(file)?; + + self.transact(Op::Import, |col| { + let mut ctx = Context::new(archive, col, progress_fn)?; + ctx.import() + }) + } +} + +impl<'a> Context<'a> { + fn new( + mut archive: ZipArchive, + target_col: &'a mut Collection, + progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool, + ) -> Result { + let progress = IncrementableProgress::new(progress_fn); + let meta = Meta::from_archive(&mut archive)?; + let data = ExchangeData::gather_from_archive( + &mut archive, + &meta, + SearchNode::WholeCollection, + true, + )?; + let usn = target_col.usn()?; + Ok(Self { + target_col, + archive, + meta, + data, + usn, + progress, + }) + } + + fn import(&mut self) -> Result { + let mut media_map = self.prepare_media()?; + self.progress.call(ImportProgress::File)?; + let note_imports = self.import_notes_and_notetypes(&mut media_map)?; + let imported_decks = self.import_decks_and_configs()?; + self.import_cards_and_revlog(¬e_imports.id_map, &imported_decks)?; + self.copy_media(&mut media_map)?; + Ok(note_imports.log) + } +} + +impl ExchangeData { + fn gather_from_archive( + archive: &mut ZipArchive, + meta: &Meta, + search: impl TryIntoSearch, + with_scheduling: bool, + ) -> Result { + let tempfile = collection_to_tempfile(meta, archive)?; + let mut col = CollectionBuilder::new(tempfile.path()).build()?; + col.maybe_upgrade_scheduler()?; + + let mut data = ExchangeData::default(); + data.gather_data(&mut col, search, with_scheduling)?; + + Ok(data) + } +} + +fn collection_to_tempfile(meta: &Meta, archive: &mut ZipArchive) -> Result { + let mut zip_file = archive.by_name(meta.collection_filename())?; + let mut tempfile = NamedTempFile::new()?; + if meta.zstd_compressed() { + copy_decode(zip_file, &mut tempfile) + } else { + io::copy(&mut zip_file, &mut tempfile).map(|_| ()) + } + .map_err(|err| AnkiError::file_io_error(err, tempfile.path()))?; + + Ok(tempfile) +} + +impl Collection { + fn maybe_upgrade_scheduler(&mut self) -> Result<()> { + if self.scheduling_included()? 
{ + self.upgrade_to_v2_scheduler()?; + } + Ok(()) + } + + fn scheduling_included(&mut self) -> Result { + const SQL: &str = "SELECT 1 FROM cards WHERE queue != 0"; + Ok(self + .storage + .db + .query_row(SQL, [], |_| Ok(())) + .optional()? + .is_some()) + } +} diff --git a/rslib/src/import_export/package/apkg/import/notes.rs b/rslib/src/import_export/package/apkg/import/notes.rs new file mode 100644 index 000000000..4fb8431b9 --- /dev/null +++ b/rslib/src/import_export/package/apkg/import/notes.rs @@ -0,0 +1,459 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::{ + borrow::Cow, + collections::{HashMap, HashSet}, + mem, + sync::Arc, +}; + +use sha1::Sha1; + +use super::{media::MediaUseMap, Context}; +use crate::{ + import_export::{ + package::{media::safe_normalized_file_name, LogNote, NoteLog}, + ImportProgress, IncrementableProgress, + }, + prelude::*, + text::{ + newlines_to_spaces, replace_media_refs, strip_html_preserving_media_filenames, + truncate_to_char_boundary, CowMapping, + }, +}; + +struct NoteContext<'a> { + target_col: &'a mut Collection, + usn: Usn, + normalize_notes: bool, + remapped_notetypes: HashMap, + target_guids: HashMap, + target_ids: HashSet, + media_map: &'a mut MediaUseMap, + imports: NoteImports, +} + +#[derive(Debug, Default)] +pub(super) struct NoteImports { + pub(super) id_map: HashMap, + /// All notes from the source collection as [Vec]s of their fields, and grouped + /// by import result kind. + pub(super) log: NoteLog, +} + +impl NoteImports { + fn log_new(&mut self, note: Note, source_id: NoteId) { + self.id_map.insert(source_id, note.id); + self.log.new.push(note.into_log_note()); + } + + fn log_updated(&mut self, note: Note, source_id: NoteId) { + self.id_map.insert(source_id, note.id); + self.log.updated.push(note.into_log_note()); + } + + fn log_duplicate(&mut self, mut note: Note, target_id: NoteId) { + self.id_map.insert(note.id, target_id); + // id is for looking up note in *target* collection + note.id = target_id; + self.log.duplicate.push(note.into_log_note()); + } + + fn log_conflicting(&mut self, note: Note) { + self.log.conflicting.push(note.into_log_note()); + } +} + +impl Note { + fn into_log_note(self) -> LogNote { + LogNote { + id: Some(self.id.into()), + fields: self + .into_fields() + .into_iter() + .map(|field| { + let mut reduced = strip_html_preserving_media_filenames(&field) + .map_cow(newlines_to_spaces) + .get_owned() + .unwrap_or(field); + truncate_to_char_boundary(&mut reduced, 80); + reduced + }) + .collect(), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub(crate) struct NoteMeta { + id: NoteId, + mtime: TimestampSecs, + notetype_id: NotetypeId, +} + +impl NoteMeta { + pub(crate) fn new(id: NoteId, mtime: TimestampSecs, notetype_id: NotetypeId) -> Self { + Self { + id, + mtime, + notetype_id, + } + } +} + +impl Context<'_> { + pub(super) fn import_notes_and_notetypes( + &mut self, + media_map: &mut MediaUseMap, + ) -> Result { + let mut ctx = NoteContext::new(self.usn, self.target_col, media_map)?; + ctx.import_notetypes(mem::take(&mut self.data.notetypes))?; + ctx.import_notes(mem::take(&mut self.data.notes), &mut self.progress)?; + Ok(ctx.imports) + } +} + +impl<'n> NoteContext<'n> { + fn new<'a: 'n>( + usn: Usn, + target_col: &'a mut Collection, + media_map: &'a mut MediaUseMap, + ) -> Result { + let target_guids = target_col.storage.note_guid_map()?; + let normalize_notes = target_col.get_config_bool(BoolKey::NormalizeNoteText); 
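        // The guid map above determines how each incoming note is handled (added,
        // updated, or logged as a duplicate/conflict); the id set below is only
        // used to keep the ids of newly added notes unique.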
+ let target_ids = target_col.storage.get_all_note_ids()?; + Ok(Self { + target_col, + usn, + normalize_notes, + remapped_notetypes: HashMap::new(), + target_guids, + target_ids, + imports: NoteImports::default(), + media_map, + }) + } + + fn import_notetypes(&mut self, mut notetypes: Vec) -> Result<()> { + for notetype in &mut notetypes { + if let Some(existing) = self.target_col.storage.get_notetype(notetype.id)? { + self.merge_or_remap_notetype(notetype, existing)?; + } else { + self.add_notetype(notetype)?; + } + } + Ok(()) + } + + fn merge_or_remap_notetype( + &mut self, + incoming: &mut Notetype, + existing: Notetype, + ) -> Result<()> { + if incoming.schema_hash() == existing.schema_hash() { + if incoming.mtime_secs > existing.mtime_secs { + self.update_notetype(incoming, existing)?; + } + } else { + self.add_notetype_with_remapped_id(incoming)?; + } + Ok(()) + } + + fn add_notetype(&mut self, notetype: &mut Notetype) -> Result<()> { + notetype.prepare_for_update(None, true)?; + self.target_col + .ensure_notetype_name_unique(notetype, self.usn)?; + notetype.usn = self.usn; + self.target_col + .add_notetype_with_unique_id_undoable(notetype) + } + + fn update_notetype(&mut self, notetype: &mut Notetype, original: Notetype) -> Result<()> { + notetype.usn = self.usn; + self.target_col + .add_or_update_notetype_with_existing_id_inner(notetype, Some(original), self.usn, true) + } + + fn add_notetype_with_remapped_id(&mut self, notetype: &mut Notetype) -> Result<()> { + let old_id = std::mem::take(&mut notetype.id); + notetype.usn = self.usn; + self.target_col + .add_notetype_inner(notetype, self.usn, true)?; + self.remapped_notetypes.insert(old_id, notetype.id); + Ok(()) + } + + fn import_notes( + &mut self, + notes: Vec, + progress: &mut IncrementableProgress, + ) -> Result<()> { + let mut incrementor = progress.incrementor(ImportProgress::Notes); + + for mut note in notes { + incrementor.increment()?; + if let Some(notetype_id) = self.remapped_notetypes.get(¬e.notetype_id) { + if self.target_guids.contains_key(¬e.guid) { + self.imports.log_conflicting(note); + } else { + note.notetype_id = *notetype_id; + self.add_note(note)?; + } + } else if let Some(&meta) = self.target_guids.get(¬e.guid) { + self.maybe_update_note(note, meta)?; + } else { + self.add_note(note)?; + } + } + Ok(()) + } + + fn add_note(&mut self, mut note: Note) -> Result<()> { + self.munge_media(&mut note)?; + self.target_col.canonify_note_tags(&mut note, self.usn)?; + let notetype = self.get_expected_notetype(note.notetype_id)?; + note.prepare_for_update(¬etype, self.normalize_notes)?; + note.usn = self.usn; + let old_id = self.uniquify_note_id(&mut note); + + self.target_col.add_note_only_with_id_undoable(&mut note)?; + self.target_ids.insert(note.id); + self.imports.log_new(note, old_id); + + Ok(()) + } + + fn uniquify_note_id(&mut self, note: &mut Note) -> NoteId { + let original = note.id; + while self.target_ids.contains(¬e.id) { + note.id.0 += 999; + } + original + } + + fn get_expected_notetype(&mut self, ntid: NotetypeId) -> Result> { + self.target_col + .get_notetype(ntid)? + .ok_or(AnkiError::NotFound) + } + + fn get_expected_note(&mut self, nid: NoteId) -> Result { + self.target_col + .storage + .get_note(nid)? 
+ .ok_or(AnkiError::NotFound) + } + + fn maybe_update_note(&mut self, note: Note, meta: NoteMeta) -> Result<()> { + if meta.mtime < note.mtime { + if meta.notetype_id == note.notetype_id { + self.update_note(note, meta.id)?; + } else { + self.imports.log_conflicting(note); + } + } else { + self.imports.log_duplicate(note, meta.id); + } + Ok(()) + } + + fn update_note(&mut self, mut note: Note, target_id: NoteId) -> Result<()> { + let source_id = note.id; + note.id = target_id; + self.munge_media(&mut note)?; + let original = self.get_expected_note(note.id)?; + let notetype = self.get_expected_notetype(note.notetype_id)?; + self.target_col.update_note_inner_without_cards( + &mut note, + &original, + ¬etype, + self.usn, + true, + self.normalize_notes, + true, + )?; + self.imports.log_updated(note, source_id); + Ok(()) + } + + fn munge_media(&mut self, note: &mut Note) -> Result<()> { + for field in note.fields_mut() { + if let Some(new_field) = self.replace_media_refs(field) { + *field = new_field; + }; + } + Ok(()) + } + + fn replace_media_refs(&mut self, field: &mut String) -> Option { + replace_media_refs(field, |name| { + if let Ok(normalized) = safe_normalized_file_name(name) { + if let Some(entry) = self.media_map.use_entry(&normalized) { + if entry.name != name { + // name is not normalized, and/or remapped + return Some(entry.name.clone()); + } + } else if let Cow::Owned(s) = normalized { + // no entry; might be a reference to an existing file, so ensure normalization + return Some(s); + } + } + None + }) + } +} + +impl Notetype { + fn schema_hash(&self) -> Sha1Hash { + let mut hasher = Sha1::new(); + for field in &self.fields { + hasher.update(field.name.as_bytes()); + } + for template in &self.templates { + hasher.update(template.name.as_bytes()); + } + hasher.digest().bytes() + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{collection::open_test_collection, import_export::package::media::SafeMediaEntry}; + + /// Import [Note] into [Collection], optionally taking a [MediaUseMap], + /// or a [Notetype] remapping. + macro_rules! import_note { + ($col:expr, $note:expr, $old_notetype:expr => $new_notetype:expr) => {{ + let mut media_map = MediaUseMap::default(); + let mut ctx = NoteContext::new(Usn(1), &mut $col, &mut media_map).unwrap(); + ctx.remapped_notetypes.insert($old_notetype, $new_notetype); + let mut progress = IncrementableProgress::new(|_, _| true); + ctx.import_notes(vec![$note], &mut progress).unwrap(); + ctx.imports.log + }}; + ($col:expr, $note:expr, $media_map:expr) => {{ + let mut ctx = NoteContext::new(Usn(1), &mut $col, &mut $media_map).unwrap(); + let mut progress = IncrementableProgress::new(|_, _| true); + ctx.import_notes(vec![$note], &mut progress).unwrap(); + ctx.imports.log + }}; + ($col:expr, $note:expr) => {{ + let mut media_map = MediaUseMap::default(); + import_note!($col, $note, media_map) + }}; + } + + /// Assert that exactly one [Note] is logged, and that with the given state and fields. + macro_rules! 
assert_note_logged { + ($log:expr, $state:ident, $fields:expr) => { + assert_eq!($log.$state.pop().unwrap().fields, $fields); + assert!($log.new.is_empty()); + assert!($log.updated.is_empty()); + assert!($log.duplicate.is_empty()); + assert!($log.conflicting.is_empty()); + }; + } + + impl Collection { + fn note_id_for_guid(&self, guid: &str) -> NoteId { + self.storage + .db + .query_row("SELECT id FROM notes WHERE guid = ?", [guid], |r| r.get(0)) + .unwrap() + } + } + + #[test] + fn should_add_note_with_new_id_if_guid_is_unique_and_id_is_not() { + let mut col = open_test_collection(); + let mut note = col.add_new_note("basic"); + note.guid = "other".to_string(); + let original_id = note.id; + + let mut log = import_note!(col, note); + assert_ne!(col.note_id_for_guid("other"), original_id); + assert_note_logged!(log, new, &["", ""]); + } + + #[test] + fn should_skip_note_if_guid_already_exists_with_newer_mtime() { + let mut col = open_test_collection(); + let mut note = col.add_new_note("basic"); + note.mtime.0 -= 1; + note.fields_mut()[0] = "outdated".to_string(); + + let mut log = import_note!(col, note); + assert_eq!(col.get_all_notes()[0].fields()[0], ""); + assert_note_logged!(log, duplicate, &["outdated", ""]); + } + + #[test] + fn should_update_note_if_guid_already_exists_with_different_id() { + let mut col = open_test_collection(); + let mut note = col.add_new_note("basic"); + note.id.0 = 42; + note.mtime.0 += 1; + note.fields_mut()[0] = "updated".to_string(); + + let mut log = import_note!(col, note); + assert_eq!(col.get_all_notes()[0].fields()[0], "updated"); + assert_note_logged!(log, updated, &["updated", ""]); + } + + #[test] + fn should_ignore_note_if_guid_already_exists_with_different_notetype() { + let mut col = open_test_collection(); + let mut note = col.add_new_note("basic"); + note.notetype_id.0 = 42; + note.mtime.0 += 1; + note.fields_mut()[0] = "updated".to_string(); + + let mut log = import_note!(col, note); + assert_eq!(col.get_all_notes()[0].fields()[0], ""); + assert_note_logged!(log, conflicting, &["updated", ""]); + } + + #[test] + fn should_add_note_with_remapped_notetype_if_in_notetype_map() { + let mut col = open_test_collection(); + let basic_ntid = col.get_notetype_by_name("basic").unwrap().unwrap().id; + let mut note = col.new_note("basic"); + note.notetype_id.0 = 123; + + let mut log = import_note!(col, note, NotetypeId(123) => basic_ntid); + assert_eq!(col.get_all_notes()[0].notetype_id, basic_ntid); + assert_note_logged!(log, new, &["", ""]); + } + + #[test] + fn should_ignore_note_if_guid_already_exists_and_notetype_is_remapped() { + let mut col = open_test_collection(); + let basic_ntid = col.get_notetype_by_name("basic").unwrap().unwrap().id; + let mut note = col.add_new_note("basic"); + note.notetype_id.0 = 123; + note.mtime.0 += 1; + note.fields_mut()[0] = "updated".to_string(); + + let mut log = import_note!(col, note, NotetypeId(123) => basic_ntid); + assert_eq!(col.get_all_notes()[0].fields()[0], ""); + assert_note_logged!(log, conflicting, &["updated", ""]); + } + + #[test] + fn should_add_note_with_remapped_media_reference_in_field_if_in_media_map() { + let mut col = open_test_collection(); + let mut note = col.new_note("basic"); + note.fields_mut()[0] = "".to_string(); + + let mut media_map = MediaUseMap::default(); + let entry = SafeMediaEntry::from_legacy(("0", "bar.jpg".to_string())).unwrap(); + media_map.add_checked("foo.jpg", entry); + + let mut log = import_note!(col, note, media_map); + assert_eq!(col.get_all_notes()[0].fields()[0], 
""); + assert_note_logged!(log, new, &[" bar.jpg ", ""]); + } +} diff --git a/rslib/src/import_export/package/apkg/mod.rs b/rslib/src/import_export/package/apkg/mod.rs new file mode 100644 index 000000000..0ac21fac7 --- /dev/null +++ b/rslib/src/import_export/package/apkg/mod.rs @@ -0,0 +1,8 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +mod export; +mod import; +mod tests; + +pub(crate) use import::NoteMeta; diff --git a/rslib/src/import_export/package/apkg/tests.rs b/rslib/src/import_export/package/apkg/tests.rs new file mode 100644 index 000000000..812f6a963 --- /dev/null +++ b/rslib/src/import_export/package/apkg/tests.rs @@ -0,0 +1,150 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +#![cfg(test)] + +use std::{collections::HashSet, fs::File, io::Write}; + +use crate::{ + media::files::sha1_of_data, prelude::*, search::SearchNode, tests::open_fs_test_collection, +}; + +const SAMPLE_JPG: &str = "sample.jpg"; +const SAMPLE_MP3: &str = "sample.mp3"; +const SAMPLE_JS: &str = "_sample.js"; +const JPG_DATA: &[u8] = b"1"; +const MP3_DATA: &[u8] = b"2"; +const JS_DATA: &[u8] = b"3"; +const OTHER_MP3_DATA: &[u8] = b"4"; + +#[test] +fn roundtrip() { + let (mut src_col, src_tempdir) = open_fs_test_collection("src"); + let (mut target_col, _target_tempdir) = open_fs_test_collection("target"); + let apkg_path = src_tempdir.path().join("test.apkg"); + + let (main_deck, sibling_deck) = src_col.add_sample_decks(); + let notetype = src_col.add_sample_notetype(); + let note = src_col.add_sample_note(&main_deck, &sibling_deck, ¬etype); + src_col.add_sample_media(); + target_col.add_conflicting_media(); + + src_col + .export_apkg( + &apkg_path, + SearchNode::from_deck_name("parent::sample"), + true, + true, + true, + None, + |_, _| true, + ) + .unwrap(); + target_col.import_apkg(&apkg_path, |_, _| true).unwrap(); + + target_col.assert_decks(); + target_col.assert_notetype(¬etype); + target_col.assert_note_and_media(¬e); + + target_col.undo().unwrap(); + target_col.assert_empty(); +} + +impl Collection { + fn add_sample_decks(&mut self) -> (Deck, Deck) { + let sample = self.add_named_deck("parent\x1fsample"); + self.add_named_deck("parent\x1fsample\x1fchild"); + let siblings = self.add_named_deck("siblings"); + + (sample, siblings) + } + + fn add_named_deck(&mut self, name: &str) -> Deck { + let mut deck = Deck::new_normal(); + deck.name = NativeDeckName::from_native_str(name); + self.add_deck(&mut deck).unwrap(); + deck + } + + fn add_sample_notetype(&mut self) -> Notetype { + let mut nt = Notetype { + name: "sample".into(), + ..Default::default() + }; + nt.add_field("sample"); + nt.add_template("sample1", "{{sample}}", ""); + nt.add_template("sample2", "{{sample}}2", ""); + self.add_notetype(&mut nt, true).unwrap(); + nt + } + + fn add_sample_note( + &mut self, + main_deck: &Deck, + sibling_decks: &Deck, + notetype: &Notetype, + ) -> Note { + let mut sample = notetype.new_note(); + sample.fields_mut()[0] = format!(" [sound:{SAMPLE_MP3}]"); + sample.tags = vec!["sample".into()]; + self.add_note(&mut sample, main_deck.id).unwrap(); + + let card = self + .storage + .get_card_by_ordinal(sample.id, 1) + .unwrap() + .unwrap(); + self.set_deck(&[card.id], sibling_decks.id).unwrap(); + + sample + } + + fn add_sample_media(&self) { + self.add_media(&[ + (SAMPLE_JPG, JPG_DATA), + (SAMPLE_MP3, MP3_DATA), + (SAMPLE_JS, JS_DATA), + ]); + } + 
+ fn add_conflicting_media(&mut self) { + let mut file = File::create(self.media_folder.join(SAMPLE_MP3)).unwrap(); + file.write_all(OTHER_MP3_DATA).unwrap(); + } + + fn assert_decks(&mut self) { + let existing_decks: HashSet<_> = self + .get_all_deck_names(true) + .unwrap() + .into_iter() + .map(|(_, name)| name) + .collect(); + for deck in ["parent", "parent::sample", "siblings"] { + assert!(existing_decks.contains(deck)); + } + assert!(!existing_decks.contains("parent::sample::child")); + } + + fn assert_notetype(&mut self, notetype: &Notetype) { + assert!(self.get_notetype(notetype.id).unwrap().is_some()); + } + + fn assert_note_and_media(&mut self, note: &Note) { + let sha1 = sha1_of_data(MP3_DATA); + let new_mp3_name = format!("sample-{}.mp3", hex::encode(&sha1)); + + for file in [SAMPLE_JPG, SAMPLE_JS, &new_mp3_name] { + assert!(self.media_folder.join(file).exists()) + } + + let imported_note = self.storage.get_note(note.id).unwrap().unwrap(); + assert!(imported_note.fields()[0].contains(&new_mp3_name)); + } + + fn assert_empty(&self) { + assert!(self.get_all_deck_names(true).unwrap().is_empty()); + assert!(self.storage.get_all_note_ids().unwrap().is_empty()); + assert!(self.storage.get_all_card_ids().unwrap().is_empty()); + assert!(self.storage.all_tags().unwrap().is_empty()); + } +} diff --git a/rslib/src/import_export/package/colpkg/export.rs b/rslib/src/import_export/package/colpkg/export.rs index 6dc5de69f..d4fc0352d 100644 --- a/rslib/src/import_export/package/colpkg/export.rs +++ b/rslib/src/import_export/package/colpkg/export.rs @@ -4,7 +4,8 @@ use std::{ borrow::Cow, collections::HashMap, - fs::{DirEntry, File}, + ffi::OsStr, + fs::File, io::{self, Read, Write}, path::{Path, PathBuf}, }; @@ -21,6 +22,7 @@ use zstd::{ use super::super::{MediaEntries, MediaEntry, Meta, Version}; use crate::{ collection::CollectionBuilder, + import_export::IncrementableProgress, io::{atomic_rename, read_dir_files, tempfile_in_parent_of}, media::files::filename_if_normalized, prelude::*, @@ -38,8 +40,9 @@ impl Collection { out_path: impl AsRef, include_media: bool, legacy: bool, - progress_fn: impl FnMut(usize), + progress_fn: impl 'static + FnMut(usize, bool) -> bool, ) -> Result<()> { + let mut progress = IncrementableProgress::new(progress_fn); let colpkg_name = out_path.as_ref(); let temp_colpkg = tempfile_in_parent_of(colpkg_name)?; let src_path = self.col_path.clone(); @@ -61,19 +64,48 @@ impl Collection { src_media_folder, legacy, &tr, - progress_fn, + &mut progress, )?; atomic_rename(temp_colpkg, colpkg_name, true) } } +pub struct MediaIter(Box>>); + +impl MediaIter { + /// Iterator over all files in the given path, without traversing subfolders. + pub fn from_folder(path: &Path) -> Result { + Ok(Self(Box::new( + read_dir_files(path)?.map(|res| res.map(|entry| entry.path())), + ))) + } + + /// Iterator over all given files in the given folder. + /// Missing files are silently ignored. 
+ pub fn from_file_list( + list: impl IntoIterator + 'static, + folder: PathBuf, + ) -> Self { + Self(Box::new( + list.into_iter() + .map(move |file| folder.join(file)) + .filter(|path| path.exists()) + .map(Ok), + )) + } + + pub fn empty() -> Self { + Self(Box::new(std::iter::empty())) + } +} + fn export_collection_file( out_path: impl AsRef, col_path: impl AsRef, media_dir: Option, legacy: bool, tr: &I18n, - progress_fn: impl FnMut(usize), + progress: &mut IncrementableProgress, ) -> Result<()> { let meta = if legacy { Meta::new_legacy() @@ -82,15 +114,13 @@ fn export_collection_file( }; let mut col_file = File::open(col_path)?; let col_size = col_file.metadata()?.len() as usize; - export_collection( - meta, - out_path, - &mut col_file, - col_size, - media_dir, - tr, - progress_fn, - ) + let media = if let Some(path) = media_dir { + MediaIter::from_folder(&path)? + } else { + MediaIter::empty() + }; + + export_collection(meta, out_path, &mut col_file, col_size, media, tr, progress) } /// Write copied collection data without any media. @@ -105,20 +135,20 @@ pub(crate) fn export_colpkg_from_data( out_path, &mut col_data, col_size, - None, + MediaIter::empty(), tr, - |_| (), + &mut IncrementableProgress::new(|_, _| true), ) } -fn export_collection( +pub(crate) fn export_collection( meta: Meta, out_path: impl AsRef, col: &mut impl Read, col_size: usize, - media_dir: Option, + media: MediaIter, tr: &I18n, - progress_fn: impl FnMut(usize), + progress: &mut IncrementableProgress, ) -> Result<()> { let out_file = File::create(&out_path)?; let mut zip = ZipWriter::new(out_file); @@ -129,7 +159,7 @@ fn export_collection( zip.write_all(&meta_bytes)?; write_collection(&meta, &mut zip, col, col_size)?; write_dummy_collection(&mut zip, tr)?; - write_media(&meta, &mut zip, media_dir, progress_fn)?; + write_media(&meta, &mut zip, media, progress)?; zip.finish()?; Ok(()) @@ -203,17 +233,12 @@ fn zstd_copy(reader: &mut impl Read, writer: &mut impl Write, size: usize) -> Re fn write_media( meta: &Meta, zip: &mut ZipWriter, - media_dir: Option, - progress_fn: impl FnMut(usize), + media: MediaIter, + progress: &mut IncrementableProgress, ) -> Result<()> { let mut media_entries = vec![]; - - if let Some(media_dir) = media_dir { - write_media_files(meta, zip, &media_dir, &mut media_entries, progress_fn)?; - } - + write_media_files(meta, zip, media, &mut media_entries, progress)?; write_media_map(meta, media_entries, zip)?; - Ok(()) } @@ -251,19 +276,23 @@ fn write_media_map( fn write_media_files( meta: &Meta, zip: &mut ZipWriter, - dir: &Path, + media: MediaIter, media_entries: &mut Vec, - mut progress_fn: impl FnMut(usize), + progress: &mut IncrementableProgress, ) -> Result<()> { let mut copier = MediaCopier::new(meta); - for (index, entry) in read_dir_files(dir)?.enumerate() { - progress_fn(index); + let mut incrementor = progress.incrementor(|u| u); + for (index, res) in media.0.enumerate() { + incrementor.increment()?; + let path = res?; zip.start_file(index.to_string(), file_options_stored())?; - let entry = entry?; - let name = normalized_unicode_file_name(&entry)?; - let mut file = File::open(entry.path())?; + let mut file = File::open(&path)?; + let file_name = path + .file_name() + .ok_or_else(|| AnkiError::invalid_input("not a file path"))?; + let name = normalized_unicode_file_name(file_name)?; let (size, sha1) = copier.copy(&mut file, zip)?; media_entries.push(MediaEntry::new(name, size, sha1)); @@ -272,23 +301,11 @@ fn write_media_files( Ok(()) } -impl MediaEntry { - fn new(name: impl Into, size: 
impl TryInto, sha1: impl Into>) -> Self { - MediaEntry { - name: name.into(), - size: size.try_into().unwrap_or_default(), - sha1: sha1.into(), - legacy_zip_filename: None, - } - } -} - -fn normalized_unicode_file_name(entry: &DirEntry) -> Result { - let filename = entry.file_name(); +fn normalized_unicode_file_name(filename: &OsStr) -> Result { let filename = filename.to_str().ok_or_else(|| { AnkiError::IoError(format!( "non-unicode file name: {}", - entry.file_name().to_string_lossy() + filename.to_string_lossy() )) })?; filename_if_normalized(filename) @@ -324,7 +341,7 @@ impl MediaCopier { &mut self, reader: &mut impl Read, writer: &mut impl Write, - ) -> Result<(usize, [u8; 20])> { + ) -> Result<(usize, Sha1Hash)> { let mut size = 0; let mut hasher = Sha1::new(); let mut buf = [0; 64 * 1024]; diff --git a/rslib/src/import_export/package/colpkg/import.rs b/rslib/src/import_export/package/colpkg/import.rs index 3cdffcba7..0cc38f791 100644 --- a/rslib/src/import_export/package/colpkg/import.rs +++ b/rslib/src/import_export/package/colpkg/import.rs @@ -2,64 +2,39 @@ // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html use std::{ - borrow::Cow, - collections::HashMap, - fs::{self, File}, - io::{self, Read, Write}, - path::{Component, Path, PathBuf}, + fs::File, + io::{self, Write}, + path::{Path, PathBuf}, }; -use prost::Message; use zip::{read::ZipFile, ZipArchive}; use zstd::{self, stream::copy_decode}; -use super::super::Version; use crate::{ collection::CollectionBuilder, error::ImportError, import_export::{ - package::{MediaEntries, MediaEntry, Meta}, - ImportProgress, + package::{ + media::{extract_media_entries, SafeMediaEntry}, + Meta, + }, + ImportProgress, IncrementableProgress, }, io::{atomic_rename, tempfile_in_parent_of}, - media::files::normalize_filename, + media::MediaManager, prelude::*, }; -impl Meta { - /// Extracts meta data from an archive and checks if its version is supported. 
- pub(super) fn from_archive(archive: &mut ZipArchive) -> Result { - let meta_bytes = archive.by_name("meta").ok().and_then(|mut meta_file| { - let mut buf = vec![]; - meta_file.read_to_end(&mut buf).ok()?; - Some(buf) - }); - let meta = if let Some(bytes) = meta_bytes { - let meta: Meta = Message::decode(&*bytes)?; - if meta.version() == Version::Unknown { - return Err(AnkiError::ImportError(ImportError::TooNew)); - } - meta - } else { - Meta { - version: if archive.by_name("collection.anki21").is_ok() { - Version::Legacy2 - } else { - Version::Legacy1 - } as i32, - } - }; - Ok(meta) - } -} - pub fn import_colpkg( colpkg_path: &str, target_col_path: &str, - target_media_folder: &str, - mut progress_fn: impl FnMut(ImportProgress) -> Result<()>, + target_media_folder: &Path, + media_db: &Path, + progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool, + log: &Logger, ) -> Result<()> { - progress_fn(ImportProgress::Collection)?; + let mut progress = IncrementableProgress::new(progress_fn); + progress.call(ImportProgress::File)?; let col_path = PathBuf::from(target_col_path); let mut tempfile = tempfile_in_parent_of(&col_path)?; @@ -68,12 +43,18 @@ pub fn import_colpkg( let meta = Meta::from_archive(&mut archive)?; copy_collection(&mut archive, &mut tempfile, &meta)?; - progress_fn(ImportProgress::Collection)?; + progress.call(ImportProgress::File)?; check_collection_and_mod_schema(tempfile.path())?; - progress_fn(ImportProgress::Collection)?; + progress.call(ImportProgress::File)?; - let media_folder = Path::new(target_media_folder); - restore_media(&meta, progress_fn, &mut archive, media_folder)?; + restore_media( + &meta, + &mut progress, + &mut archive, + target_media_folder, + media_db, + log, + )?; atomic_rename(tempfile, &col_path, true) } @@ -96,31 +77,25 @@ fn check_collection_and_mod_schema(col_path: &Path) -> Result<()> { fn restore_media( meta: &Meta, - mut progress_fn: impl FnMut(ImportProgress) -> Result<()>, + progress: &mut IncrementableProgress, archive: &mut ZipArchive, media_folder: &Path, + media_db: &Path, + log: &Logger, ) -> Result<()> { let media_entries = extract_media_entries(meta, archive)?; + if media_entries.is_empty() { + return Ok(()); + } + std::fs::create_dir_all(media_folder)?; + let media_manager = MediaManager::new(media_folder, media_db)?; + let mut media_comparer = MediaComparer::new(meta, progress, &media_manager, log)?; - for (entry_idx, entry) in media_entries.iter().enumerate() { - if entry_idx % 10 == 0 { - progress_fn(ImportProgress::Media(entry_idx))?; - } - - let zip_filename = entry - .legacy_zip_filename - .map(|n| n as usize) - .unwrap_or(entry_idx) - .to_string(); - - if let Ok(mut zip_file) = archive.by_name(&zip_filename) { - maybe_restore_media_file(meta, media_folder, entry, &mut zip_file)?; - } else { - return Err(AnkiError::invalid_input(&format!( - "{zip_filename} missing from archive" - ))); - } + let mut incrementor = progress.incrementor(ImportProgress::Media); + for mut entry in media_entries { + incrementor.increment()?; + maybe_restore_media_file(meta, media_folder, archive, &mut entry, &mut media_comparer)?; } Ok(()) @@ -129,13 +104,19 @@ fn restore_media( fn maybe_restore_media_file( meta: &Meta, media_folder: &Path, - entry: &MediaEntry, - zip_file: &mut ZipFile, + archive: &mut ZipArchive, + entry: &mut SafeMediaEntry, + media_comparer: &mut MediaComparer, ) -> Result<()> { - let file_path = entry.safe_normalized_file_path(meta, media_folder)?; - let already_exists = entry.is_equal_to(meta, zip_file, &file_path); + let 
file_path = entry.file_path(media_folder); + let mut zip_file = entry.fetch_file(archive)?; + if meta.media_list_is_hashmap() { + entry.size = zip_file.size() as u32; + } + + let already_exists = media_comparer.entry_is_equal_to(entry, &file_path)?; if !already_exists { - restore_media_file(meta, zip_file, &file_path)?; + restore_media_file(meta, &mut zip_file, &file_path)?; }; Ok(()) @@ -154,79 +135,6 @@ fn restore_media_file(meta: &Meta, zip_file: &mut ZipFile, path: &Path) -> Resul atomic_rename(tempfile, path, false) } -impl MediaEntry { - fn safe_normalized_file_path(&self, meta: &Meta, media_folder: &Path) -> Result { - check_filename_safe(&self.name)?; - let normalized = maybe_normalizing(&self.name, meta.strict_media_checks())?; - Ok(media_folder.join(normalized.as_ref())) - } - - fn is_equal_to(&self, meta: &Meta, self_zipped: &ZipFile, other_path: &Path) -> bool { - // TODO: checks hashs (https://github.com/ankitects/anki/pull/1723#discussion_r829653147) - let self_size = if meta.media_list_is_hashmap() { - self_zipped.size() - } else { - self.size as u64 - }; - fs::metadata(other_path) - .map(|metadata| metadata.len() as u64 == self_size) - .unwrap_or_default() - } -} - -/// - If strict is true, return an error if not normalized. -/// - If false, return the normalized version. -fn maybe_normalizing(name: &str, strict: bool) -> Result> { - let normalized = normalize_filename(name); - if strict && matches!(normalized, Cow::Owned(_)) { - // exporting code should have checked this - Err(AnkiError::ImportError(ImportError::Corrupt)) - } else { - Ok(normalized) - } -} - -/// Return an error if name contains any path separators. -fn check_filename_safe(name: &str) -> Result<()> { - let mut components = Path::new(name).components(); - let first_element_normal = components - .next() - .map(|component| matches!(component, Component::Normal(_))) - .unwrap_or_default(); - if !first_element_normal || components.next().is_some() { - Err(AnkiError::ImportError(ImportError::Corrupt)) - } else { - Ok(()) - } -} - -fn extract_media_entries(meta: &Meta, archive: &mut ZipArchive) -> Result> { - let mut file = archive.by_name("media")?; - let mut buf = Vec::new(); - if meta.zstd_compressed() { - copy_decode(file, &mut buf)?; - } else { - io::copy(&mut file, &mut buf)?; - } - if meta.media_list_is_hashmap() { - let map: HashMap<&str, String> = serde_json::from_slice(&buf)?; - map.into_iter() - .map(|(idx_str, name)| { - let idx: u32 = idx_str.parse()?; - Ok(MediaEntry { - name, - size: 0, - sha1: vec![], - legacy_zip_filename: Some(idx), - }) - }) - .collect() - } else { - let entries: MediaEntries = Message::decode(&*buf)?; - Ok(entries.entries) - } -} - fn copy_collection( archive: &mut ZipArchive, writer: &mut impl Write, @@ -244,29 +152,31 @@ fn copy_collection( Ok(()) } -#[cfg(test)] -mod test { - use super::*; +type GetChecksumFn<'a> = dyn FnMut(&str) -> Result> + 'a; - #[test] - fn path_traversal() { - assert!(check_filename_safe("foo").is_ok(),); +struct MediaComparer<'a>(Option>>); - assert!(check_filename_safe("..").is_err()); - assert!(check_filename_safe("foo/bar").is_err()); - assert!(check_filename_safe("/foo").is_err()); - assert!(check_filename_safe("../foo").is_err()); +impl<'a> MediaComparer<'a> { + fn new( + meta: &Meta, + progress: &mut IncrementableProgress, + media_manager: &'a MediaManager, + log: &Logger, + ) -> Result { + Ok(Self(if meta.media_list_is_hashmap() { + None + } else { + let mut db_progress_fn = progress.media_db_fn(ImportProgress::MediaCheck)?; + 
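+            // bring the media DB up to date with the media folder before reading checksums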
media_manager.register_changes(&mut db_progress_fn, log)?; + Some(Box::new(media_manager.checksum_getter())) + })) + } - if cfg!(windows) { - assert!(check_filename_safe("foo\\bar").is_err()); - assert!(check_filename_safe("c:\\foo").is_err()); - assert!(check_filename_safe("\\foo").is_err()); + fn entry_is_equal_to(&mut self, entry: &SafeMediaEntry, other_path: &Path) -> Result { + if let Some(ref mut get_checksum) = self.0 { + Ok(entry.has_checksum_equal_to(get_checksum)?) + } else { + Ok(entry.has_size_equal_to(other_path)) } } - - #[test] - fn normalization() { - assert_eq!(&maybe_normalizing("con", false).unwrap(), "con_"); - assert!(&maybe_normalizing("con", true).is_err()); - } } diff --git a/rslib/src/import_export/package/colpkg/tests.rs b/rslib/src/import_export/package/colpkg/tests.rs index 08c84012f..b07ef084c 100644 --- a/rslib/src/import_export/package/colpkg/tests.rs +++ b/rslib/src/import_export/package/colpkg/tests.rs @@ -8,8 +8,8 @@ use std::path::Path; use tempfile::tempdir; use crate::{ - collection::CollectionBuilder, import_export::package::import_colpkg, media::MediaManager, - prelude::*, + collection::CollectionBuilder, import_export::package::import_colpkg, log::terminal, + media::MediaManager, prelude::*, }; fn collection_with_media(dir: &Path, name: &str) -> Result { @@ -41,19 +41,26 @@ fn roundtrip() -> Result<()> { // export to a file let col = collection_with_media(dir, name)?; let colpkg_name = dir.join(format!("{name}.colpkg")); - col.export_colpkg(&colpkg_name, true, legacy, |_| ())?; + col.export_colpkg(&colpkg_name, true, legacy, |_, _| true)?; + // import into a new collection let anki2_name = dir .join(format!("{name}.anki2")) .to_string_lossy() .into_owned(); let import_media_dir = dir.join(format!("{name}.media")); + std::fs::create_dir_all(&import_media_dir)?; + let import_media_db = dir.join(format!("{name}.mdb")); + MediaManager::new(&import_media_dir, &import_media_db)?; import_colpkg( &colpkg_name.to_string_lossy(), &anki2_name, - import_media_dir.to_str().unwrap(), - |_| Ok(()), + &import_media_dir, + &import_media_db, + |_, _| true, + &terminal(), )?; + // confirm collection imported let col = CollectionBuilder::new(&anki2_name).build()?; assert_eq!( @@ -82,7 +89,7 @@ fn normalization_check_on_export() -> Result<()> { // manually write a file in the wrong encoding. std::fs::write(col.media_folder.join("ぱぱ.jpg"), "nfd encoding")?; assert_eq!( - col.export_colpkg(&colpkg_name, true, false, |_| ()) + col.export_colpkg(&colpkg_name, true, false, |_, _| true,) .unwrap_err(), AnkiError::MediaCheckRequired ); diff --git a/rslib/src/import_export/package/media.rs b/rslib/src/import_export/package/media.rs new file mode 100644 index 000000000..3613ffa72 --- /dev/null +++ b/rslib/src/import_export/package/media.rs @@ -0,0 +1,174 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::{ + borrow::Cow, + collections::HashMap, + fs::{self, File}, + io, + path::{Path, PathBuf}, +}; + +use prost::Message; +use tempfile::NamedTempFile; +use zip::{read::ZipFile, ZipArchive}; +use zstd::stream::copy_decode; + +use super::{MediaEntries, MediaEntry, Meta}; +use crate::{ + error::ImportError, + io::{atomic_rename, filename_is_safe}, + media::files::normalize_filename, + prelude::*, +}; + +/// Like [MediaEntry], but with a safe filename and set zip filename. 
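+/// Names are guaranteed to contain no path separators and to already be in
+/// normalized form by the time an entry is constructed.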
+pub(super) struct SafeMediaEntry { + pub(super) name: String, + pub(super) size: u32, + pub(super) sha1: Sha1Hash, + pub(super) index: usize, +} + +impl MediaEntry { + pub(super) fn new( + name: impl Into, + size: impl TryInto, + sha1: impl Into>, + ) -> Self { + MediaEntry { + name: name.into(), + size: size.try_into().unwrap_or_default(), + sha1: sha1.into(), + legacy_zip_filename: None, + } + } +} + +impl SafeMediaEntry { + pub(super) fn from_entry(enumerated: (usize, MediaEntry)) -> Result { + let (index, entry) = enumerated; + if let Ok(sha1) = entry.sha1.try_into() { + if !matches!(safe_normalized_file_name(&entry.name)?, Cow::Owned(_)) { + return Ok(Self { + name: entry.name, + size: entry.size, + sha1, + index, + }); + } + } + Err(AnkiError::ImportError(ImportError::Corrupt)) + } + + pub(super) fn from_legacy(legacy_entry: (&str, String)) -> Result { + let zip_filename: usize = legacy_entry.0.parse()?; + let name = match safe_normalized_file_name(&legacy_entry.1)? { + Cow::Owned(new_name) => new_name, + Cow::Borrowed(_) => legacy_entry.1, + }; + Ok(Self { + name, + size: 0, + sha1: [0; 20], + index: zip_filename, + }) + } + + pub(super) fn file_path(&self, media_folder: &Path) -> PathBuf { + media_folder.join(&self.name) + } + + pub(super) fn fetch_file<'a>(&self, archive: &'a mut ZipArchive) -> Result> { + archive + .by_name(&self.index.to_string()) + .map_err(|_| AnkiError::invalid_input(&format!("{} missing from archive", self.index))) + } + + pub(super) fn has_checksum_equal_to( + &self, + get_checksum: &mut impl FnMut(&str) -> Result>, + ) -> Result { + get_checksum(&self.name).map(|opt| opt.map_or(false, |sha1| sha1 == self.sha1)) + } + + pub(super) fn has_size_equal_to(&self, other_path: &Path) -> bool { + fs::metadata(other_path).map_or(false, |metadata| metadata.len() == self.size as u64) + } + + pub(super) fn copy_from_archive( + &self, + archive: &mut ZipArchive, + target_folder: &Path, + ) -> Result<()> { + let mut file = self.fetch_file(archive)?; + let mut tempfile = NamedTempFile::new_in(target_folder)?; + io::copy(&mut file, &mut tempfile)?; + atomic_rename(tempfile, &self.file_path(target_folder), false) + } +} + +pub(super) fn extract_media_entries( + meta: &Meta, + archive: &mut ZipArchive, +) -> Result> { + let media_list_data = get_media_list_data(archive, meta)?; + if meta.media_list_is_hashmap() { + let map: HashMap<&str, String> = serde_json::from_slice(&media_list_data)?; + map.into_iter().map(SafeMediaEntry::from_legacy).collect() + } else { + MediaEntries::decode_safe_entries(&media_list_data) + } +} + +pub(super) fn safe_normalized_file_name(name: &str) -> Result> { + if !filename_is_safe(name) { + Err(AnkiError::ImportError(ImportError::Corrupt)) + } else { + Ok(normalize_filename(name)) + } +} + +fn get_media_list_data(archive: &mut ZipArchive, meta: &Meta) -> Result> { + let mut file = archive.by_name("media")?; + let mut buf = Vec::new(); + if meta.zstd_compressed() { + copy_decode(file, &mut buf)?; + } else { + io::copy(&mut file, &mut buf)?; + } + Ok(buf) +} + +impl MediaEntries { + fn decode_safe_entries(buf: &[u8]) -> Result> { + let entries: Self = Message::decode(buf)?; + entries + .entries + .into_iter() + .enumerate() + .map(SafeMediaEntry::from_entry) + .collect() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn normalization() { + // legacy entries get normalized on deserialisation + let entry = SafeMediaEntry::from_legacy(("1", "con".to_owned())).unwrap(); + assert_eq!(entry.name, "con_"); + + // new-style entries 
should have been normalized on export + let mut entries = Vec::new(); + MediaEntries { + entries: vec![MediaEntry::new("con", 0, Vec::new())], + } + .encode(&mut entries) + .unwrap(); + assert!(MediaEntries::decode_safe_entries(&entries).is_err()); + } +} diff --git a/rslib/src/import_export/package/meta.rs b/rslib/src/import_export/package/meta.rs index c2ac4e80c..62e8dbda4 100644 --- a/rslib/src/import_export/package/meta.rs +++ b/rslib/src/import_export/package/meta.rs @@ -1,7 +1,13 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +use std::{fs::File, io::Read}; + +use prost::Message; +use zip::ZipArchive; + pub(super) use crate::backend_proto::{package_metadata::Version, PackageMetadata as Meta}; +use crate::{error::ImportError, prelude::*, storage::SchemaVersion}; impl Version { pub(super) fn collection_filename(&self) -> &'static str { @@ -12,6 +18,16 @@ impl Version { Version::Latest => "collection.anki21b", } } + + /// Latest schema version that is supported by all clients supporting + /// this package version. + pub(super) fn schema_version(&self) -> SchemaVersion { + match self { + Version::Unknown => unreachable!(), + Version::Legacy1 | Version::Legacy2 => SchemaVersion::V11, + Version::Latest => SchemaVersion::V18, + } + } } impl Meta { @@ -27,10 +43,41 @@ impl Meta { } } + /// Extracts meta data from an archive and checks if its version is supported. + pub(super) fn from_archive(archive: &mut ZipArchive) -> Result { + let meta_bytes = archive.by_name("meta").ok().and_then(|mut meta_file| { + let mut buf = vec![]; + meta_file.read_to_end(&mut buf).ok()?; + Some(buf) + }); + let meta = if let Some(bytes) = meta_bytes { + let meta: Meta = Message::decode(&*bytes)?; + if meta.version() == Version::Unknown { + return Err(AnkiError::ImportError(ImportError::TooNew)); + } + meta + } else { + Meta { + version: if archive.by_name("collection.anki21").is_ok() { + Version::Legacy2 + } else { + Version::Legacy1 + } as i32, + } + }; + Ok(meta) + } + pub(super) fn collection_filename(&self) -> &'static str { self.version().collection_filename() } + /// Latest schema version that is supported by all clients supporting + /// this package version. 
+ pub(super) fn schema_version(&self) -> SchemaVersion { + self.version().schema_version() + } + pub(super) fn zstd_compressed(&self) -> bool { !self.is_legacy() } @@ -39,10 +86,6 @@ impl Meta { self.is_legacy() } - pub(super) fn strict_media_checks(&self) -> bool { - !self.is_legacy() - } - fn is_legacy(&self) -> bool { matches!(self.version(), Version::Legacy1 | Version::Legacy2) } diff --git a/rslib/src/import_export/package/mod.rs b/rslib/src/import_export/package/mod.rs index 66d3ca14e..f999d84ce 100644 --- a/rslib/src/import_export/package/mod.rs +++ b/rslib/src/import_export/package/mod.rs @@ -1,11 +1,15 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +mod apkg; mod colpkg; +mod media; mod meta; +pub(crate) use apkg::NoteMeta; pub(crate) use colpkg::export::export_colpkg_from_data; pub use colpkg::import::import_colpkg; pub(self) use meta::{Meta, Version}; +pub use crate::backend_proto::import_anki_package_response::{Log as NoteLog, Note as LogNote}; pub(self) use crate::backend_proto::{media_entries::MediaEntry, MediaEntries}; diff --git a/rslib/src/io.rs b/rslib/src/io.rs index 8a68ffdb3..fc86c5db6 100644 --- a/rslib/src/io.rs +++ b/rslib/src/io.rs @@ -1,7 +1,7 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -use std::path::Path; +use std::path::{Component, Path}; use tempfile::NamedTempFile; @@ -42,6 +42,17 @@ pub(crate) fn read_dir_files(path: impl AsRef) -> std::io::Result bool { + let mut components = Path::new(name).components(); + let first_element_normal = components + .next() + .map(|component| matches!(component, Component::Normal(_))) + .unwrap_or_default(); + + first_element_normal && components.next().is_none() +} + pub(crate) struct ReadDirFiles(std::fs::ReadDir); impl Iterator for ReadDirFiles { @@ -60,3 +71,24 @@ impl Iterator for ReadDirFiles { } } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn path_traversal() { + assert!(filename_is_safe("foo")); + + assert!(!filename_is_safe("..")); + assert!(!filename_is_safe("foo/bar")); + assert!(!filename_is_safe("/foo")); + assert!(!filename_is_safe("../foo")); + + if cfg!(windows) { + assert!(!filename_is_safe("foo\\bar")); + assert!(!filename_is_safe("c:\\foo")); + assert!(!filename_is_safe("\\foo")); + } + } +} diff --git a/rslib/src/lib.rs b/rslib/src/lib.rs index 9e03edc38..dc43db87e 100644 --- a/rslib/src/lib.rs +++ b/rslib/src/lib.rs @@ -40,6 +40,7 @@ mod sync; pub mod tags; pub mod template; pub mod template_filters; +pub(crate) mod tests; pub mod text; pub mod timestamp; pub mod types; diff --git a/rslib/src/media/changetracker.rs b/rslib/src/media/changetracker.rs index 6db3e94f9..dcf04da85 100644 --- a/rslib/src/media/changetracker.rs +++ b/rslib/src/media/changetracker.rs @@ -4,7 +4,6 @@ use std::{collections::HashMap, path::Path, time}; use crate::{ - error::{AnkiError, Result}, log::{debug, Logger}, media::{ database::{MediaDatabaseContext, MediaEntry}, @@ -13,11 +12,12 @@ use crate::{ NONSYNCABLE_FILENAME, }, }, + prelude::*, }; struct FilesystemEntry { fname: String, - sha1: Option<[u8; 20]>, + sha1: Option, mtime: i64, is_new: bool, } diff --git a/rslib/src/media/database.rs b/rslib/src/media/database.rs index 858035c33..927c1e6fd 100644 --- a/rslib/src/media/database.rs +++ b/rslib/src/media/database.rs @@ -5,7 +5,7 @@ use std::{collections::HashMap, path::Path}; use rusqlite::{params, Connection, OptionalExtension, Row, 
Statement}; -use crate::error::Result; +use crate::prelude::*; fn trace(s: &str) { println!("sql: {}", s) @@ -47,7 +47,7 @@ fn initial_db_setup(db: &mut Connection) -> Result<()> { pub struct MediaEntry { pub fname: String, /// If None, file has been deleted - pub sha1: Option<[u8; 20]>, + pub sha1: Option, // Modification time; 0 if deleted pub mtime: i64, /// True if changed since last sync @@ -222,6 +222,14 @@ delete from media where fname=?" Ok(map?) } + /// Returns all filenames and checksums, where the checksum is not null. + pub(super) fn all_checksums(&mut self) -> Result> { + self.db + .prepare("SELECT fname, csum FROM media WHERE csum IS NOT NULL")? + .query_and_then([], row_to_name_and_checksum)? + .collect() + } + pub(super) fn force_resync(&mut self) -> Result<()> { self.db .execute_batch("delete from media; update meta set lastUsn = 0, dirMod = 0") @@ -231,7 +239,7 @@ delete from media where fname=?" fn row_to_entry(row: &Row) -> rusqlite::Result { // map the string checksum into bytes - let sha1_str: Option = row.get(1)?; + let sha1_str = row.get_ref(1)?.as_str_or_null()?; let sha1_array = if let Some(s) = sha1_str { let mut arr = [0; 20]; match hex::decode_to_slice(s, arr.as_mut()) { @@ -250,6 +258,15 @@ fn row_to_entry(row: &Row) -> rusqlite::Result { }) } +fn row_to_name_and_checksum(row: &Row) -> Result<(String, Sha1Hash)> { + let file_name = row.get(0)?; + let sha1_str: String = row.get(1)?; + let mut sha1 = [0; 20]; + hex::decode_to_slice(sha1_str, &mut sha1) + .map_err(|_| AnkiError::invalid_input(format!("bad media checksum: {file_name}")))?; + Ok((file_name, sha1)) +} + #[cfg(test)] mod test { use tempfile::NamedTempFile; diff --git a/rslib/src/media/files.rs b/rslib/src/media/files.rs index b56b1baa3..077e4ce93 100644 --- a/rslib/src/media/files.rs +++ b/rslib/src/media/files.rs @@ -15,10 +15,7 @@ use sha1::Sha1; use unic_ucd_category::GeneralCategory; use unicode_normalization::{is_nfc, UnicodeNormalization}; -use crate::{ - error::{AnkiError, Result}, - log::{debug, Logger}, -}; +use crate::prelude::*; /// The maximum length we allow a filename to be. When combined /// with the rest of the path, the full path needs to be under ~240 chars @@ -164,7 +161,7 @@ pub fn add_data_to_folder_uniquely<'a, P>( folder: P, desired_name: &'a str, data: &[u8], - sha1: [u8; 20], + sha1: Sha1Hash, ) -> io::Result> where P: AsRef, @@ -194,7 +191,7 @@ where } /// Convert foo.jpg into foo-abcde12345679.jpg -fn add_hash_suffix_to_file_stem(fname: &str, hash: &[u8; 20]) -> String { +pub(crate) fn add_hash_suffix_to_file_stem(fname: &str, hash: &Sha1Hash) -> String { // when appending a hash to make unique, it will be 40 bytes plus the hyphen. let max_len = MAX_FILENAME_LENGTH - 40 - 1; @@ -244,18 +241,18 @@ fn split_and_truncate_filename(fname: &str, max_bytes: usize) -> (&str, &str) { }; // cap extension to 10 bytes so stem_len can't be negative - ext = truncate_to_char_boundary(ext, 10); + ext = truncated_to_char_boundary(ext, 10); // cap stem, allowing for the . and a trailing _ let stem_len = max_bytes - ext.len() - 2; - stem = truncate_to_char_boundary(stem, stem_len); + stem = truncated_to_char_boundary(stem, stem_len); (stem, ext) } -/// Trim a string on a valid UTF8 boundary. +/// Return a substring on a valid UTF8 boundary. /// Based on a funtion in the Rust stdlib. 
-fn truncate_to_char_boundary(s: &str, mut max: usize) -> &str { +fn truncated_to_char_boundary(s: &str, mut max: usize) -> &str { if max >= s.len() { s } else { @@ -267,7 +264,7 @@ fn truncate_to_char_boundary(s: &str, mut max: usize) -> &str { } /// Return the SHA1 of a file if it exists, or None. -fn existing_file_sha1(path: &Path) -> io::Result> { +fn existing_file_sha1(path: &Path) -> io::Result> { match sha1_of_file(path) { Ok(o) => Ok(Some(o)), Err(e) => { @@ -281,12 +278,17 @@ fn existing_file_sha1(path: &Path) -> io::Result> { } /// Return the SHA1 of a file, failing if it doesn't exist. -pub(crate) fn sha1_of_file(path: &Path) -> io::Result<[u8; 20]> { +pub(crate) fn sha1_of_file(path: &Path) -> io::Result { let mut file = fs::File::open(path)?; + sha1_of_reader(&mut file) +} + +/// Return the SHA1 of a stream. +pub(crate) fn sha1_of_reader(reader: &mut impl Read) -> io::Result { let mut hasher = Sha1::new(); let mut buf = [0; 64 * 1024]; loop { - match file.read(&mut buf) { + match reader.read(&mut buf) { Ok(0) => break, Ok(n) => hasher.update(&buf[0..n]), Err(e) => { @@ -302,7 +304,7 @@ pub(crate) fn sha1_of_file(path: &Path) -> io::Result<[u8; 20]> { } /// Return the SHA1 of provided data. -pub(crate) fn sha1_of_data(data: &[u8]) -> [u8; 20] { +pub(crate) fn sha1_of_data(data: &[u8]) -> Sha1Hash { let mut hasher = Sha1::new(); hasher.update(data); hasher.digest().bytes() @@ -371,7 +373,7 @@ pub(super) fn trash_folder(media_folder: &Path) -> Result { pub(super) struct AddedFile { pub fname: String, - pub sha1: [u8; 20], + pub sha1: Sha1Hash, pub mtime: i64, pub renamed_from: Option, } diff --git a/rslib/src/media/mod.rs b/rslib/src/media/mod.rs index 351befa1f..de3fb73e6 100644 --- a/rslib/src/media/mod.rs +++ b/rslib/src/media/mod.rs @@ -3,19 +3,21 @@ use std::{ borrow::Cow, + collections::HashMap, path::{Path, PathBuf}, }; use rusqlite::Connection; use slog::Logger; +use self::changetracker::ChangeTracker; use crate::{ - error::Result, media::{ database::{open_or_create, MediaDatabaseContext, MediaEntry}, files::{add_data_to_folder_uniquely, mtime_as_i64, remove_files, sha1_of_data}, sync::{MediaSyncProgress, MediaSyncer}, }, + prelude::*, }; pub mod changetracker; @@ -24,6 +26,8 @@ pub mod database; pub mod files; pub mod sync; +pub type Sha1Hash = [u8; 20]; + pub struct MediaManager { db: Connection, media_folder: PathBuf, @@ -153,4 +157,31 @@ impl MediaManager { pub fn dbctx(&self) -> MediaDatabaseContext { MediaDatabaseContext::new(&self.db) } + + pub fn all_checksums( + &self, + progress: impl FnMut(usize) -> bool, + log: &Logger, + ) -> Result> { + let mut dbctx = self.dbctx(); + ChangeTracker::new(&self.media_folder, progress, log).register_changes(&mut dbctx)?; + dbctx.all_checksums() + } + + pub fn checksum_getter(&self) -> impl FnMut(&str) -> Result> + '_ { + let mut dbctx = self.dbctx(); + move |fname: &str| { + dbctx + .get_entry(fname) + .map(|opt| opt.and_then(|entry| entry.sha1)) + } + } + + pub fn register_changes( + &self, + progress: &mut impl FnMut(usize) -> bool, + log: &Logger, + ) -> Result<()> { + ChangeTracker::new(&self.media_folder, progress, log).register_changes(&mut self.dbctx()) + } } diff --git a/rslib/src/notes/mod.rs b/rslib/src/notes/mod.rs index 5eed822f4..ed8ec29fa 100644 --- a/rslib/src/notes/mod.rs +++ b/rslib/src/notes/mod.rs @@ -55,6 +55,10 @@ impl Note { &self.fields } + pub fn into_fields(self) -> Vec { + self.fields + } + pub fn set_field(&mut self, idx: usize, text: impl Into) -> Result<()> { if idx >= self.fields.len() { return 
Err(AnkiError::invalid_input( @@ -320,7 +324,7 @@ fn invalid_char_for_field(c: char) -> bool { } impl Collection { - fn canonify_note_tags(&mut self, note: &mut Note, usn: Usn) -> Result<()> { + pub(crate) fn canonify_note_tags(&mut self, note: &mut Note, usn: Usn) -> Result<()> { if !note.tags.is_empty() { let tags = std::mem::take(&mut note.tags); note.tags = self.canonify_tags(tags, usn)?.0; diff --git a/rslib/src/notes/undo.rs b/rslib/src/notes/undo.rs index 709778810..89e22dc48 100644 --- a/rslib/src/notes/undo.rs +++ b/rslib/src/notes/undo.rs @@ -83,13 +83,23 @@ impl Collection { } /// Add a note, not adding any cards. - pub(super) fn add_note_only_undoable(&mut self, note: &mut Note) -> Result<(), AnkiError> { + pub(crate) fn add_note_only_undoable(&mut self, note: &mut Note) -> Result<(), AnkiError> { self.storage.add_note(note)?; self.save_undo(UndoableNoteChange::Added(Box::new(note.clone()))); Ok(()) } + /// Add a note, not adding any cards. Caller guarantees id is unique. + pub(crate) fn add_note_only_with_id_undoable(&mut self, note: &mut Note) -> Result<()> { + if self.storage.add_note_if_unique(note)? { + self.save_undo(UndoableNoteChange::Added(Box::new(note.clone()))); + Ok(()) + } else { + Err(AnkiError::invalid_input("note id existed")) + } + } + pub(crate) fn update_note_tags_undoable( &mut self, tags: &NoteTags, diff --git a/rslib/src/notetype/mod.rs b/rslib/src/notetype/mod.rs index eb00fc10d..84ebb3e24 100644 --- a/rslib/src/notetype/mod.rs +++ b/rslib/src/notetype/mod.rs @@ -648,7 +648,7 @@ impl Collection { /// - Caller must set notetype as modified if appropriate. /// - This only supports undo when an existing notetype is passed in. - fn add_or_update_notetype_with_existing_id_inner( + pub(crate) fn add_or_update_notetype_with_existing_id_inner( &mut self, notetype: &mut Notetype, original: Option, diff --git a/rslib/src/notetype/undo.rs b/rslib/src/notetype/undo.rs index 74f6d34a6..5780ed9b2 100644 --- a/rslib/src/notetype/undo.rs +++ b/rslib/src/notetype/undo.rs @@ -42,6 +42,17 @@ impl Collection { Ok(()) } + /// Caller must ensure [NotetypeId] is unique. 
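+    /// Otherwise, an existing notetype with the same id will be overwritten.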
+ pub(crate) fn add_notetype_with_unique_id_undoable( + &mut self, + notetype: &Notetype, + ) -> Result<()> { + self.storage + .add_or_update_notetype_with_existing_id(notetype)?; + self.save_undo(UndoableNotetypeChange::Added(Box::new(notetype.clone()))); + Ok(()) + } + pub(super) fn update_notetype_undoable( &mut self, notetype: &Notetype, diff --git a/rslib/src/ops.rs b/rslib/src/ops.rs index 5daa2c582..c3d8f4a28 100644 --- a/rslib/src/ops.rs +++ b/rslib/src/ops.rs @@ -17,6 +17,7 @@ pub enum Op { CreateCustomStudy, EmptyFilteredDeck, FindAndReplace, + Import, RebuildFilteredDeck, RemoveDeck, RemoveNote, @@ -55,6 +56,7 @@ impl Op { Op::AnswerCard => tr.actions_answer_card(), Op::Bury => tr.studying_bury(), Op::CreateCustomStudy => tr.actions_custom_study(), + Op::Import => tr.actions_import(), Op::RemoveDeck => tr.decks_delete_deck(), Op::RemoveNote => tr.studying_delete_note(), Op::RenameDeck => tr.actions_rename_deck(), diff --git a/rslib/src/prelude.rs b/rslib/src/prelude.rs index 721aaaf52..042080415 100644 --- a/rslib/src/prelude.rs +++ b/rslib/src/prelude.rs @@ -12,6 +12,7 @@ pub use crate::{ decks::{Deck, DeckId, DeckKind, NativeDeckName}, error::{AnkiError, Result}, i18n::I18n, + media::Sha1Hash, notes::{Note, NoteId}, notetype::{Notetype, NotetypeId}, ops::{Op, OpChanges, OpOutput}, diff --git a/rslib/src/revlog/undo.rs b/rslib/src/revlog/undo.rs index 24ea6f61a..71602f0cc 100644 --- a/rslib/src/revlog/undo.rs +++ b/rslib/src/revlog/undo.rs @@ -28,9 +28,17 @@ impl Collection { /// Add the provided revlog entry, modifying the ID if it is not unique. pub(crate) fn add_revlog_entry_undoable(&mut self, mut entry: RevlogEntry) -> Result { - entry.id = self.storage.add_revlog_entry(&entry, true)?; + entry.id = self.storage.add_revlog_entry(&entry, true)?.unwrap(); let id = entry.id; self.save_undo(UndoableRevlogChange::Added(Box::new(entry))); Ok(id) } + + /// Add the provided revlog entry, if its ID is unique. + pub(crate) fn add_revlog_entry_if_unique_undoable(&mut self, entry: RevlogEntry) -> Result<()> { + if self.storage.add_revlog_entry(&entry, false)?.is_some() { + self.save_undo(UndoableRevlogChange::Added(Box::new(entry))); + } + Ok(()) + } } diff --git a/rslib/src/search/builder.rs b/rslib/src/search/builder.rs index c51c50ee9..dab57edfa 100644 --- a/rslib/src/search/builder.rs +++ b/rslib/src/search/builder.rs @@ -134,6 +134,14 @@ impl Default for SearchBuilder { } impl SearchNode { + pub fn from_deck_id(did: impl Into, with_children: bool) -> Self { + if with_children { + Self::DeckIdWithChildren(did.into()) + } else { + Self::DeckIdWithoutChildren(did.into()) + } + } + /// Construct [SearchNode] from an unescaped deck name. pub fn from_deck_name(name: &str) -> Self { Self::Deck(escape_anki_wildcards_for_search_node(name)) @@ -158,6 +166,14 @@ impl SearchNode { name, ))) } + + pub fn from_note_ids, N: Into>(ids: I) -> Self { + Self::NoteIds(ids.into_iter().map(Into::into).join(",")) + } + + pub fn from_card_ids, C: Into>(ids: I) -> Self { + Self::CardIds(ids.into_iter().map(Into::into).join(",")) + } } impl> From for Node { diff --git a/rslib/src/stats/graphs.rs b/rslib/src/stats/graphs.rs index 0a158eb66..eb5236e66 100644 --- a/rslib/src/stats/graphs.rs +++ b/rslib/src/stats/graphs.rs @@ -38,7 +38,7 @@ impl Collection { self.storage.get_all_revlog_entries(revlog_start)? } else { self.storage - .get_revlog_entries_for_searched_cards(revlog_start)? + .get_pb_revlog_entries_for_searched_cards(revlog_start)? 
}; self.storage.clear_searched_cards_table()?; diff --git a/rslib/src/storage/card/add_card_if_unique.sql b/rslib/src/storage/card/add_card_if_unique.sql new file mode 100644 index 000000000..80b3a6394 --- /dev/null +++ b/rslib/src/storage/card/add_card_if_unique.sql @@ -0,0 +1,41 @@ +INSERT + OR IGNORE INTO cards ( + id, + nid, + did, + ord, + mod, + usn, + type, + queue, + due, + ivl, + factor, + reps, + lapses, + left, + odue, + odid, + flags, + data + ) +VALUES ( + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ? + ) \ No newline at end of file diff --git a/rslib/src/storage/card/mod.rs b/rslib/src/storage/card/mod.rs index f62053ae8..96b314486 100644 --- a/rslib/src/storage/card/mod.rs +++ b/rslib/src/storage/card/mod.rs @@ -145,6 +145,34 @@ impl super::SqliteStorage { Ok(()) } + /// Add card if id is unique. True if card was added. + pub(crate) fn add_card_if_unique(&self, card: &Card) -> Result { + self.db + .prepare_cached(include_str!("add_card_if_unique.sql"))? + .execute(params![ + card.id, + card.note_id, + card.deck_id, + card.template_idx, + card.mtime, + card.usn, + card.ctype as u8, + card.queue as i8, + card.due, + card.interval, + card.ease_factor, + card.reps, + card.lapses, + card.remaining_steps, + card.original_due, + card.original_deck_id, + card.flags, + CardData::from_card(card), + ]) + .map(|n_rows| n_rows == 1) + .map_err(Into::into) + } + /// Add or update card, using the provided ID. Used for syncing & undoing. pub(crate) fn add_or_update_card(&self, card: &Card) -> Result<()> { let mut stmt = self.db.prepare_cached(include_str!("add_or_update.sql"))?; @@ -400,6 +428,20 @@ impl super::SqliteStorage { .collect() } + pub(crate) fn get_all_card_ids(&self) -> Result> { + self.db + .prepare("SELECT id FROM cards")? + .query_and_then([], |row| Ok(row.get(0)?))? + .collect() + } + + pub(crate) fn all_cards_as_nid_and_ord(&self) -> Result> { + self.db + .prepare("SELECT nid, ord FROM cards")? + .query_and_then([], |r| Ok((NoteId(r.get(0)?), r.get(1)?)))? + .collect() + } + pub(crate) fn card_ids_of_notes(&self, nids: &[NoteId]) -> Result> { let mut stmt = self .db @@ -455,6 +497,16 @@ impl super::SqliteStorage { Ok(nids) } + /// Place the ids of cards with notes in 'search_nids' into 'search_cids'. + /// Returns number of added cards. + pub(crate) fn search_cards_of_notes_into_table(&self) -> Result { + self.setup_searched_cards_table()?; + self.db + .prepare(include_str!("search_cards_of_notes_into_table.sql"))? + .execute([]) + .map_err(Into::into) + } + pub(crate) fn all_searched_cards(&self) -> Result> { self.db .prepare_cached(concat!( diff --git a/rslib/src/storage/card/search_cards_of_notes_into_table.sql b/rslib/src/storage/card/search_cards_of_notes_into_table.sql new file mode 100644 index 000000000..0387fd566 --- /dev/null +++ b/rslib/src/storage/card/search_cards_of_notes_into_table.sql @@ -0,0 +1,7 @@ +INSERT INTO search_cids +SELECT id +FROM cards +WHERE nid IN ( + SELECT nid + FROM search_nids + ) \ No newline at end of file diff --git a/rslib/src/storage/deck/mod.rs b/rslib/src/storage/deck/mod.rs index da1c5d0f7..f99467143 100644 --- a/rslib/src/storage/deck/mod.rs +++ b/rslib/src/storage/deck/mod.rs @@ -74,6 +74,15 @@ impl SqliteStorage { .transpose() } + pub(crate) fn get_deck_by_name(&self, machine_name: &str) -> Result> { + self.db + .prepare_cached(concat!(include_str!("get_deck.sql"), " WHERE name = ?"))? + .query_and_then([machine_name], row_to_deck)? 
+ .next() + .transpose() + .map_err(Into::into) + } + pub(crate) fn get_all_decks(&self) -> Result> { self.db .prepare(include_str!("get_deck.sql"))? @@ -111,6 +120,17 @@ impl SqliteStorage { .map_err(Into::into) } + pub(crate) fn get_decks_for_search_cards(&self) -> Result> { + self.db + .prepare_cached(concat!( + include_str!("get_deck.sql"), + " WHERE id IN (SELECT DISTINCT did FROM cards WHERE id IN", + " (SELECT cid FROM search_cids))", + ))? + .query_and_then([], row_to_deck)? + .collect() + } + // caller should ensure name unique pub(crate) fn add_deck(&self, deck: &mut Deck) -> Result<()> { assert!(deck.id.0 == 0); diff --git a/rslib/src/storage/deckconfig/add_if_unique.sql b/rslib/src/storage/deckconfig/add_if_unique.sql new file mode 100644 index 000000000..516024466 --- /dev/null +++ b/rslib/src/storage/deckconfig/add_if_unique.sql @@ -0,0 +1,3 @@ +INSERT + OR IGNORE INTO deck_config (id, name, mtime_secs, usn, config) +VALUES (?, ?, ?, ?, ?); \ No newline at end of file diff --git a/rslib/src/storage/deckconfig/mod.rs b/rslib/src/storage/deckconfig/mod.rs index aeffc0151..c7b31dc92 100644 --- a/rslib/src/storage/deckconfig/mod.rs +++ b/rslib/src/storage/deckconfig/mod.rs @@ -67,6 +67,22 @@ impl SqliteStorage { Ok(()) } + pub(crate) fn add_deck_conf_if_unique(&self, conf: &DeckConfig) -> Result { + let mut conf_bytes = vec![]; + conf.inner.encode(&mut conf_bytes)?; + self.db + .prepare_cached(include_str!("add_if_unique.sql"))? + .execute(params![ + conf.id, + conf.name, + conf.mtime_secs, + conf.usn, + conf_bytes, + ]) + .map(|added| added == 1) + .map_err(Into::into) + } + pub(crate) fn update_deck_conf(&self, conf: &DeckConfig) -> Result<()> { let mut conf_bytes = vec![]; conf.inner.encode(&mut conf_bytes)?; diff --git a/rslib/src/storage/mod.rs b/rslib/src/storage/mod.rs index f2099d5b7..5ab5aa42d 100644 --- a/rslib/src/storage/mod.rs +++ b/rslib/src/storage/mod.rs @@ -34,9 +34,10 @@ impl SchemaVersion { } /// Write a list of IDs as '(x,y,...)' into the provided string. -pub(crate) fn ids_to_string(buf: &mut String, ids: &[T]) +pub(crate) fn ids_to_string(buf: &mut String, ids: I) where - T: std::fmt::Display, + D: std::fmt::Display, + I: IntoIterator, { buf.push('('); write_comma_separated_ids(buf, ids); @@ -44,15 +45,18 @@ where } /// Write a list of Ids as 'x,y,...' into the provided string. 
-pub(crate) fn write_comma_separated_ids(buf: &mut String, ids: &[T]) +pub(crate) fn write_comma_separated_ids(buf: &mut String, ids: I) where - T: std::fmt::Display, + D: std::fmt::Display, + I: IntoIterator, { - if !ids.is_empty() { - for id in ids.iter().skip(1) { - write!(buf, "{},", id).unwrap(); - } - write!(buf, "{}", ids[0]).unwrap(); + let mut trailing_sep = false; + for id in ids { + write!(buf, "{},", id).unwrap(); + trailing_sep = true; + } + if trailing_sep { + buf.pop(); } } @@ -73,17 +77,17 @@ mod test { #[test] fn ids_string() { let mut s = String::new(); - ids_to_string::(&mut s, &[]); + ids_to_string(&mut s, &[0; 0]); assert_eq!(s, "()"); s.clear(); ids_to_string(&mut s, &[7]); assert_eq!(s, "(7)"); s.clear(); ids_to_string(&mut s, &[7, 6]); - assert_eq!(s, "(6,7)"); + assert_eq!(s, "(7,6)"); s.clear(); ids_to_string(&mut s, &[7, 6, 5]); - assert_eq!(s, "(6,5,7)"); + assert_eq!(s, "(7,6,5)"); s.clear(); } } diff --git a/rslib/src/storage/note/add_if_unique.sql b/rslib/src/storage/note/add_if_unique.sql new file mode 100644 index 000000000..1dd408bbe --- /dev/null +++ b/rslib/src/storage/note/add_if_unique.sql @@ -0,0 +1,27 @@ +INSERT + OR IGNORE INTO notes ( + id, + guid, + mid, + mod, + usn, + tags, + flds, + sfld, + csum, + flags, + data + ) +VALUES ( + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + 0, + "" + ) \ No newline at end of file diff --git a/rslib/src/storage/note/mod.rs b/rslib/src/storage/note/mod.rs index 4c7c20bc8..72d88e4e0 100644 --- a/rslib/src/storage/note/mod.rs +++ b/rslib/src/storage/note/mod.rs @@ -1,12 +1,13 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use rusqlite::{params, Row}; use crate::{ error::Result, + import_export::package::NoteMeta, notes::{Note, NoteId, NoteTags}, notetype::NotetypeId, tags::{join_tags, split_tags}, @@ -41,6 +42,13 @@ impl super::SqliteStorage { .transpose() } + pub fn get_all_note_ids(&self) -> Result> { + self.db + .prepare("SELECT id FROM notes")? + .query_and_then([], |row| Ok(row.get(0)?))? + .collect() + } + /// If fields have been modified, caller must call note.prepare_for_update() prior to calling this. pub(crate) fn update_note(&self, note: &Note) -> Result<()> { assert!(note.id.0 != 0); @@ -77,6 +85,24 @@ impl super::SqliteStorage { Ok(()) } + pub(crate) fn add_note_if_unique(&self, note: &Note) -> Result { + self.db + .prepare_cached(include_str!("add_if_unique.sql"))? + .execute(params![ + note.id, + note.guid, + note.notetype_id, + note.mtime, + note.usn, + join_tags(¬e.tags), + join_fields(note.fields()), + note.sort_field.as_ref().unwrap(), + note.checksum.unwrap(), + ]) + .map(|added| added == 1) + .map_err(Into::into) + } + /// Add or update the provided note, preserving ID. Used by the syncing code. pub(crate) fn add_or_update_note(&self, note: &Note) -> Result<()> { let mut stmt = self.db.prepare_cached(include_str!("add_or_update.sql"))?; @@ -210,6 +236,16 @@ impl super::SqliteStorage { Ok(()) } + pub(crate) fn all_searched_notes(&self) -> Result> { + self.db + .prepare_cached(concat!( + include_str!("get.sql"), + " WHERE id IN (SELECT nid FROM search_nids)" + ))? + .query_and_then([], |r| row_to_note(r).map_err(Into::into))? 
+ .collect() + } + pub(crate) fn get_note_tags_by_predicate(&mut self, want: F) -> Result> where F: Fn(&str) -> bool, @@ -259,6 +295,24 @@ impl super::SqliteStorage { Ok(()) } + + pub(crate) fn note_guid_map(&mut self) -> Result> { + self.db + .prepare("SELECT guid, id, mod, mid FROM notes")? + .query_and_then([], row_to_note_meta)? + .collect() + } + + #[cfg(test)] + pub(crate) fn get_all_notes(&mut self) -> Vec { + self.db + .prepare("SELECT * FROM notes") + .unwrap() + .query_and_then([], row_to_note) + .unwrap() + .collect::>() + .unwrap() + } } fn row_to_note(row: &Row) -> Result { @@ -285,3 +339,10 @@ fn row_to_note_tags(row: &Row) -> Result { tags: row.get(3)?, }) } + +fn row_to_note_meta(row: &Row) -> Result<(String, NoteMeta)> { + Ok(( + row.get(0)?, + NoteMeta::new(row.get(1)?, row.get(2)?, row.get(3)?), + )) +} diff --git a/rslib/src/storage/notetype/mod.rs b/rslib/src/storage/notetype/mod.rs index 94b551a0a..3ada340a8 100644 --- a/rslib/src/storage/notetype/mod.rs +++ b/rslib/src/storage/notetype/mod.rs @@ -99,6 +99,23 @@ impl SqliteStorage { .map_err(Into::into) } + pub(crate) fn get_notetypes_for_search_notes(&self) -> Result> { + self.db + .prepare_cached(concat!( + include_str!("get_notetype.sql"), + " WHERE id IN (SELECT DISTINCT mid FROM notes WHERE id IN", + " (SELECT nid FROM search_nids))", + ))? + .query_and_then([], |r| { + row_to_notetype_core(r).and_then(|mut nt| { + nt.fields = self.get_notetype_fields(nt.id)?; + nt.templates = self.get_notetype_templates(nt.id)?; + Ok(nt) + }) + })? + .collect() + } + pub fn get_all_notetype_names(&self) -> Result> { self.db .prepare_cached(include_str!("get_notetype_names.sql"))? diff --git a/rslib/src/storage/revlog/mod.rs b/rslib/src/storage/revlog/mod.rs index 4284b25be..d18d321e4 100644 --- a/rslib/src/storage/revlog/mod.rs +++ b/rslib/src/storage/revlog/mod.rs @@ -61,16 +61,19 @@ impl SqliteStorage { Ok(()) } - /// Returns the used id, which may differ if `ensure_unique` is true. + /// Adds the entry, if its id is unique. If it is not, and `uniquify` is true, + /// adds it with a new id. Returns the added id. + /// (I.e., the option is safe to unwrap, if `uniquify` is true.) pub(crate) fn add_revlog_entry( &self, entry: &RevlogEntry, - ensure_unique: bool, - ) -> Result { - self.db + uniquify: bool, + ) -> Result> { + let added = self + .db .prepare_cached(include_str!("add.sql"))? .execute(params![ - ensure_unique, + uniquify, entry.id, entry.cid, entry.usn, @@ -81,7 +84,7 @@ impl SqliteStorage { entry.taken_millis, entry.review_kind as u8 ])?; - Ok(RevlogId(self.db.last_insert_rowid())) + Ok((added > 0).then(|| RevlogId(self.db.last_insert_rowid()))) } pub(crate) fn get_revlog_entry(&self, id: RevlogId) -> Result> { @@ -107,7 +110,7 @@ impl SqliteStorage { .collect() } - pub(crate) fn get_revlog_entries_for_searched_cards( + pub(crate) fn get_pb_revlog_entries_for_searched_cards( &self, after: TimestampSecs, ) -> Result> { @@ -120,6 +123,16 @@ impl SqliteStorage { .collect() } + pub(crate) fn get_revlog_entries_for_searched_cards(&self) -> Result> { + self.db + .prepare_cached(concat!( + include_str!("get.sql"), + " where cid in (select cid from search_cids)" + ))? + .query_and_then([], row_to_revlog_entry)? + .collect() + } + /// This includes entries from deleted cards. 
pub(crate) fn get_all_revlog_entries( &self, diff --git a/rslib/src/tests.rs b/rslib/src/tests.rs new file mode 100644 index 000000000..1880fb059 --- /dev/null +++ b/rslib/src/tests.rs @@ -0,0 +1,63 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +#![cfg(test)] + +use tempfile::{tempdir, TempDir}; + +use crate::{collection::CollectionBuilder, media::MediaManager, prelude::*}; + +pub(crate) fn open_fs_test_collection(name: &str) -> (Collection, TempDir) { + let tempdir = tempdir().unwrap(); + let dir = tempdir.path(); + let media_folder = dir.join(format!("{name}.media")); + std::fs::create_dir(&media_folder).unwrap(); + let col = CollectionBuilder::new(dir.join(format!("{name}.anki2"))) + .set_media_paths(media_folder, dir.join(format!("{name}.mdb"))) + .build() + .unwrap(); + (col, tempdir) +} + +impl Collection { + pub(crate) fn add_media(&self, media: &[(&str, &[u8])]) { + let mgr = MediaManager::new(&self.media_folder, &self.media_db).unwrap(); + let mut ctx = mgr.dbctx(); + for (name, data) in media { + mgr.add_file(&mut ctx, name, data).unwrap(); + } + } + + pub(crate) fn new_note(&mut self, notetype: &str) -> Note { + self.get_notetype_by_name(notetype) + .unwrap() + .unwrap() + .new_note() + } + + pub(crate) fn add_new_note(&mut self, notetype: &str) -> Note { + let mut note = self.new_note(notetype); + self.add_note(&mut note, DeckId(1)).unwrap(); + note + } + + pub(crate) fn get_all_notes(&mut self) -> Vec { + self.storage.get_all_notes() + } + + pub(crate) fn add_deck_with_machine_name(&mut self, name: &str, filtered: bool) -> Deck { + let mut deck = new_deck_with_machine_name(name, filtered); + self.add_deck_inner(&mut deck, Usn(1)).unwrap(); + deck + } +} + +pub(crate) fn new_deck_with_machine_name(name: &str, filtered: bool) -> Deck { + let mut deck = if filtered { + Deck::new_filtered() + } else { + Deck::new_normal() + }; + deck.name = NativeDeckName::from_native_str(name); + deck +} diff --git a/rslib/src/text.rs b/rslib/src/text.rs index c285f87f4..d8d1d0df8 100644 --- a/rslib/src/text.rs +++ b/rslib/src/text.rs @@ -31,6 +31,31 @@ impl Trimming for Cow<'_, str> { } } +pub(crate) trait CowMapping<'a, B: ?Sized + 'a + ToOwned> { + /// Returns [self] + /// - unchanged, if the given function returns [Cow::Borrowed] + /// - with the new value, if the given function returns [Cow::Owned] + fn map_cow(self, f: impl FnOnce(&B) -> Cow) -> Self; + fn get_owned(self) -> Option; +} + +impl<'a, B: ?Sized + 'a + ToOwned> CowMapping<'a, B> for Cow<'a, B> { + fn map_cow(self, f: impl FnOnce(&B) -> Cow) -> Self { + if let Cow::Owned(o) = f(&self) { + Cow::Owned(o) + } else { + self + } + } + + fn get_owned(self) -> Option { + match self { + Cow::Borrowed(_) => None, + Cow::Owned(s) => Some(s), + } + } +} + #[derive(Debug, PartialEq)] pub enum AvTag { SoundOrVideo(String), @@ -115,34 +140,48 @@ lazy_static! { | \[\[type:[^]]+\]\] ").unwrap(); + + /// Files included in CSS with a leading underscore. + static ref UNDERSCORED_CSS_IMPORTS: Regex = Regex::new( + r#"(?xi) + (?:@import\s+ # import statement with a bare + "(_[^"]*.css)" # double quoted + | # or + '(_[^']*.css)' # single quoted css filename + ) + | # or + (?:url\(\s* # a url function with a + "(_[^"]+)" # double quoted + | # or + '(_[^']+)' # single quoted + | # or + (_.+) # unquoted filename + \s*\)) + "#).unwrap(); + + /// Strings, src and data attributes with a leading underscore. 
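+    /// E.g. "_foo.jpg", '_bar.css', or src=_baz.js.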
+ static ref UNDERSCORED_REFERENCES: Regex = Regex::new( + r#"(?x) + "(_[^"]+)" # double quoted + | # or + '(_[^']+)' # single quoted string + | # or + \b(?:src|data) # a 'src' or 'data' attribute + = # followed by + (_[^ >]+) # an unquoted value + "#).unwrap(); } pub fn html_to_text_line(html: &str) -> Cow { - let mut out: Cow = html.into(); - if let Cow::Owned(o) = PERSISTENT_HTML_SPACERS.replace_all(&out, " ") { - out = o.into(); - } - if let Cow::Owned(o) = UNPRINTABLE_TAGS.replace_all(&out, "") { - out = o.into(); - } - if let Cow::Owned(o) = strip_html_preserving_media_filenames(&out) { - out = o.into(); - } - out.trim() + PERSISTENT_HTML_SPACERS + .replace_all(html, " ") + .map_cow(|s| UNPRINTABLE_TAGS.replace_all(s, "")) + .map_cow(strip_html_preserving_media_filenames) + .trim() } pub fn strip_html(html: &str) -> Cow { - let mut out: Cow = html.into(); - - if let Cow::Owned(o) = strip_html_preserving_entities(html) { - out = o.into(); - } - - if let Cow::Owned(o) = decode_entities(out.as_ref()) { - out = o.into(); - } - - out + strip_html_preserving_entities(html).map_cow(decode_entities) } pub fn strip_html_preserving_entities(html: &str) -> Cow { @@ -161,18 +200,29 @@ pub fn decode_entities(html: &str) -> Cow { } } +pub(crate) fn newlines_to_spaces(text: &str) -> Cow { + if text.contains('\n') { + text.replace('\n', " ").into() + } else { + text.into() + } +} + pub fn strip_html_for_tts(html: &str) -> Cow { - let mut out: Cow = html.into(); + HTML_LINEBREAK_TAGS + .replace_all(html, " ") + .map_cow(strip_html) +} - if let Cow::Owned(o) = HTML_LINEBREAK_TAGS.replace_all(html, " ") { - out = o.into(); +/// Truncate a String on a valid UTF8 boundary. +pub(crate) fn truncate_to_char_boundary(s: &mut String, mut max: usize) { + if max >= s.len() { + return; } - - if let Cow::Owned(o) = strip_html(out.as_ref()) { - out = o.into(); + while !s.is_char_boundary(max) { + max -= 1; } - - out + s.truncate(max); } #[derive(Debug)] @@ -216,6 +266,61 @@ pub(crate) fn extract_media_refs(text: &str) -> Vec { out } +/// Calls `replacer` for every media reference in `text`, and optionally +/// replaces it with something else. [None] if no reference was found. 
+pub(crate) fn replace_media_refs( + text: &str, + mut replacer: impl FnMut(&str) -> Option, +) -> Option { + let mut rep = |caps: &Captures| { + let whole_match = caps.get(0).unwrap().as_str(); + let old_name = caps.iter().skip(1).find_map(|g| g).unwrap().as_str(); + let old_name_decoded = decode_entities(old_name); + + if let Some(mut new_name) = replacer(&old_name_decoded) { + if matches!(old_name_decoded, Cow::Owned(_)) { + new_name = htmlescape::encode_minimal(&new_name); + } + whole_match.replace(old_name, &new_name) + } else { + whole_match.to_owned() + } + }; + + HTML_MEDIA_TAGS + .replace_all(text, &mut rep) + .map_cow(|s| AV_TAGS.replace_all(s, &mut rep)) + .get_owned() +} + +pub(crate) fn extract_underscored_css_imports(text: &str) -> Vec<&str> { + UNDERSCORED_CSS_IMPORTS + .captures_iter(text) + .map(|caps| { + caps.get(1) + .or_else(|| caps.get(2)) + .or_else(|| caps.get(3)) + .or_else(|| caps.get(4)) + .or_else(|| caps.get(5)) + .unwrap() + .as_str() + }) + .collect() +} + +pub(crate) fn extract_underscored_references(text: &str) -> Vec<&str> { + UNDERSCORED_REFERENCES + .captures_iter(text) + .map(|caps| { + caps.get(1) + .or_else(|| caps.get(2)) + .or_else(|| caps.get(3)) + .unwrap() + .as_str() + }) + .collect() +} + pub fn strip_html_preserving_media_filenames(html: &str) -> Cow { let without_fnames = HTML_MEDIA_TAGS.replace_all(html, r" ${1}${2}${3} "); let without_html = strip_html(&without_fnames); @@ -463,4 +568,54 @@ mod test { assert!(!is_glob(r"\\\_")); assert!(glob_matcher(r"foo\*bar*")("foo*bar123")); } + + #[test] + fn extracting() { + assert_eq!( + extract_underscored_css_imports(concat!( + "@IMPORT '_foo.css'\n", + "@import \"_bar.css\"\n", + "@import '_baz.css'\n", + "@import 'nope.css'\n", + "url(_foo.css)\n", + "URL(\"_bar.css\")\n", + "@import url('_baz.css')\n", + "url('nope.css')\n", + )), + vec!["_foo.css", "_bar.css", "_baz.css", "_foo.css", "_bar.css", "_baz.css",] + ); + assert_eq!( + extract_underscored_references(concat!( + "", + "", + "\"_baz.js\"", + "\"nope.js\"", + "", + "", + "'_baz.js'", + )), + vec!["_foo.jpg", "_bar", "_baz.js", "_foo.jpg", "_bar", "_baz.js",] + ); + } + + #[test] + fn replacing() { + assert_eq!( + &replace_media_refs("[sound:bar.mp3]", |s| { + (s != "baz.jpg").then(|| "spam".to_string()) + }) + .unwrap(), + "[sound:spam]", + ); + } + + #[test] + fn truncate() { + let mut s = "日本語".to_string(); + truncate_to_char_boundary(&mut s, 6); + assert_eq!(&s, "日本"); + let mut s = "日本語".to_string(); + truncate_to_char_boundary(&mut s, 1); + assert_eq!(&s, ""); + } }
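For readers following the media-import changes above, the comparison performed by MediaComparer::entry_is_equal_to() and SafeMediaEntry::has_size_equal_to() reduces to the check below. This is only an illustrative, self-contained sketch; the function name can_skip and its parameters are invented for the example and are not part of the patch.

    type Sha1Hash = [u8; 20];

    /// Decide whether an incoming media file can be skipped during import.
    /// New-style packages store a checksum per entry, so it is compared against
    /// the checksum recorded in the media DB; legacy packages carry no checksums,
    /// so the only option is a size comparison against the file already on disk.
    fn can_skip(
        incoming_size: u64,
        incoming_sha1: Option<Sha1Hash>, // None for legacy (hashmap) media lists
        db_sha1: Option<Sha1Hash>,       // from the media DB's checksum getter, if any
        size_on_disk: Option<u64>,       // from fs::metadata, if the file exists
    ) -> bool {
        match incoming_sha1 {
            Some(sha1) => db_sha1 == Some(sha1),
            None => size_on_disk == Some(incoming_size),
        }
    }

Comparing against checksums already stored in the media DB avoids re-hashing every file on disk during import, while legacy packages fall back to the size check the previous code used for all entries.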