diff --git a/.bazelrc b/.bazelrc index 872241350..ba35396fe 100644 --- a/.bazelrc +++ b/.bazelrc @@ -20,6 +20,9 @@ test --aspects=@rules_rust//rust:defs.bzl%rust_clippy_aspect --output_groups=+cl # print output when test fails test --test_output=errors +# stop after one test failure +test --notest_keep_going + # don't add empty __init__.py files build --incompatible_default_to_explicit_init_py diff --git a/ftl/core/exporting.ftl b/ftl/core/exporting.ftl index b396001f3..e18b921b0 100644 --- a/ftl/core/exporting.ftl +++ b/ftl/core/exporting.ftl @@ -5,7 +5,7 @@ exporting-anki-deck-package = Anki Deck Package exporting-cards-in-plain-text = Cards in Plain Text exporting-collection = collection exporting-collection-exported = Collection exported. -exporting-colpkg-too-new = Please update to the latest Anki version, then import the .colpkg file again. +exporting-colpkg-too-new = Please update to the latest Anki version, then import the .colpkg/.apkg file again. exporting-couldnt-save-file = Couldn't save file: { $val } exporting-export = Export... exporting-export-format = Export format: @@ -14,6 +14,7 @@ exporting-include-html-and-media-references = Include HTML and media references exporting-include-media = Include media exporting-include-scheduling-information = Include scheduling information exporting-include-tags = Include tags +exporting-support-older-anki-versions = Support older Anki versions (slower/larger files) exporting-notes-in-plain-text = Notes in Plain Text exporting-selected-notes = Selected Notes exporting-card-exported = diff --git a/ftl/core/importing.ftl b/ftl/core/importing.ftl index 9ecfbb3bf..27077eabc 100644 --- a/ftl/core/importing.ftl +++ b/ftl/core/importing.ftl @@ -78,4 +78,10 @@ importing-processed-media-file = *[other] Imported { $count } media files } importing-importing-collection = Importing collection... +importing-importing-file = Importing file... importing-failed-to-import-media-file = Failed to import media file: { $debugInfo } +importing-processed-notes = + { $count -> + [one] Processed { $count } note... + *[other] Processed { $count } notes... 
+ } diff --git a/proto/anki/import_export.proto b/proto/anki/import_export.proto index ea8bfe8ad..2b5bb74ba 100644 --- a/proto/anki/import_export.proto +++ b/proto/anki/import_export.proto @@ -5,6 +5,8 @@ syntax = "proto3"; package anki.import_export; +import "anki/collection.proto"; +import "anki/notes.proto"; import "anki/generic.proto"; service ImportExportService { @@ -12,12 +14,16 @@ service ImportExportService { returns (generic.Empty); rpc ExportCollectionPackage(ExportCollectionPackageRequest) returns (generic.Empty); + rpc ImportAnkiPackage(ImportAnkiPackageRequest) + returns (ImportAnkiPackageResponse); + rpc ExportAnkiPackage(ExportAnkiPackageRequest) returns (generic.UInt32); } message ImportCollectionPackageRequest { string col_path = 1; string backup_path = 2; string media_folder = 3; + string media_db = 4; } message ExportCollectionPackageRequest { @@ -26,6 +32,37 @@ message ExportCollectionPackageRequest { bool legacy = 3; } +message ImportAnkiPackageRequest { + string package_path = 1; +} + +message ImportAnkiPackageResponse { + message Note { + notes.NoteId id = 1; + repeated string fields = 2; + } + message Log { + repeated Note new = 1; + repeated Note updated = 2; + repeated Note duplicate = 3; + repeated Note conflicting = 4; + } + collection.OpChanges changes = 1; + Log log = 2; +} + +message ExportAnkiPackageRequest { + string out_path = 1; + bool with_scheduling = 2; + bool with_media = 3; + bool legacy = 4; + oneof selector { + generic.Empty whole_collection = 5; + int64 deck_id = 6; + notes.NoteIds note_ids = 7; + } +} + message PackageMetadata { enum Version { VERSION_UNKNOWN = 0; diff --git a/pylib/.pylintrc b/pylib/.pylintrc index c9d947569..12152de9e 100644 --- a/pylib/.pylintrc +++ b/pylib/.pylintrc @@ -22,6 +22,7 @@ ignored-classes= CustomStudyRequest, Cram, ScheduleCardsAsNewRequest, + ExportAnkiPackageRequest, [REPORTS] output-format=colorized diff --git a/pylib/anki/collection.py b/pylib/anki/collection.py index 3d695b9a3..ae121c8ff 100644 --- a/pylib/anki/collection.py +++ b/pylib/anki/collection.py @@ -10,6 +10,7 @@ from anki import ( collection_pb2, config_pb2, generic_pb2, + import_export_pb2, links_pb2, search_pb2, stats_pb2, @@ -32,6 +33,7 @@ OpChangesAfterUndo = collection_pb2.OpChangesAfterUndo BrowserRow = search_pb2.BrowserRow BrowserColumns = search_pb2.BrowserColumns StripHtmlMode = card_rendering_pb2.StripHtmlRequest +ImportLogWithChanges = import_export_pb2.ImportAnkiPackageResponse import copy import os @@ -90,6 +92,19 @@ class LegacyCheckpoint: LegacyUndoResult = Union[None, LegacyCheckpoint, LegacyReviewUndo] +@dataclass +class DeckIdLimit: + deck_id: DeckId + + +@dataclass +class NoteIdsLimit: + note_ids: Sequence[NoteId] + + +ExportLimit = Union[DeckIdLimit, NoteIdsLimit, None] + + class Collection(DeprecatedNamesMixin): sched: V1Scheduler | V2Scheduler | V3Scheduler @@ -259,14 +274,6 @@ class Collection(DeprecatedNamesMixin): self._clear_caches() self.db = None - def export_collection( - self, out_path: str, include_media: bool, legacy: bool - ) -> None: - self.close_for_full_sync() - self._backend.export_collection_package( - out_path=out_path, include_media=include_media, legacy=legacy - ) - def rollback(self) -> None: self._clear_caches() self.db.rollback() @@ -321,6 +328,15 @@ class Collection(DeprecatedNamesMixin): else: return -1 + def legacy_checkpoint_pending(self) -> bool: + return ( + self._have_outstanding_checkpoint() + and time.time() - self._last_checkpoint_at < 300 + ) + + # Import/export + 
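# A minimal usage sketch of the import/export methods added below
# (illustrative only; `col` is an open Collection and `did` an assumed existing DeckId):
#
#     from anki.collection import DeckIdLimit
#     count = col.export_anki_package(
#         out_path="/tmp/deck.apkg",
#         limit=DeckIdLimit(did),   # NoteIdsLimit([...]) or None (whole collection) also work
#         with_scheduling=True,
#         with_media=True,
#         legacy_support=False,
#     )
#     log = col.import_anki_package("/tmp/deck.apkg").log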
########################################################################## + def create_backup( self, *, @@ -353,12 +369,40 @@ class Collection(DeprecatedNamesMixin): "Throws if backup creation failed." self._backend.await_backup_completion() - def legacy_checkpoint_pending(self) -> bool: - return ( - self._have_outstanding_checkpoint() - and time.time() - self._last_checkpoint_at < 300 + def export_collection_package( + self, out_path: str, include_media: bool, legacy: bool + ) -> None: + self.close_for_full_sync() + self._backend.export_collection_package( + out_path=out_path, include_media=include_media, legacy=legacy ) + def import_anki_package(self, path: str) -> ImportLogWithChanges: + return self._backend.import_anki_package(package_path=path) + + def export_anki_package( + self, + *, + out_path: str, + limit: ExportLimit, + with_scheduling: bool, + with_media: bool, + legacy_support: bool, + ) -> int: + request = import_export_pb2.ExportAnkiPackageRequest( + out_path=out_path, + with_scheduling=with_scheduling, + with_media=with_media, + legacy=legacy_support, + ) + if isinstance(limit, DeckIdLimit): + request.deck_id = limit.deck_id + elif isinstance(limit, NoteIdsLimit): + request.note_ids.note_ids.extend(limit.note_ids) + else: + request.whole_collection.SetInParent() + return self._backend.export_anki_package(request) + # Object helpers ########################################################################## diff --git a/pylib/anki/exporting.py b/pylib/anki/exporting.py index 8e5954d84..b17b783f6 100644 --- a/pylib/anki/exporting.py +++ b/pylib/anki/exporting.py @@ -446,7 +446,7 @@ class AnkiCollectionPackageExporter(AnkiPackageExporter): time.sleep(0.1) threading.Thread(target=progress).start() - self.col.export_collection(path, self.includeMedia, self.LEGACY) + self.col.export_collection_package(path, self.includeMedia, self.LEGACY) class AnkiCollectionPackage21bExporter(AnkiCollectionPackageExporter): diff --git a/pylib/anki/import_export_pb2.pyi b/pylib/anki/import_export_pb2.pyi new file mode 120000 index 000000000..d44638b4f --- /dev/null +++ b/pylib/anki/import_export_pb2.pyi @@ -0,0 +1 @@ +../../.bazel/bin/pylib/anki/import_export_pb2.pyi \ No newline at end of file diff --git a/pylib/anki/importing/anki2.py b/pylib/anki/importing/anki2.py index dd9ad5052..f37c089a2 100644 --- a/pylib/anki/importing/anki2.py +++ b/pylib/anki/importing/anki2.py @@ -25,6 +25,10 @@ class V2ImportIntoV1(Exception): pass +class MediaMapInvalid(Exception): + pass + + class Anki2Importer(Importer): needMapper = False diff --git a/pylib/anki/importing/apkg.py b/pylib/anki/importing/apkg.py index 049f54c37..31d1cc4fd 100644 --- a/pylib/anki/importing/apkg.py +++ b/pylib/anki/importing/apkg.py @@ -9,7 +9,7 @@ import unicodedata import zipfile from typing import Any, Optional -from anki.importing.anki2 import Anki2Importer +from anki.importing.anki2 import Anki2Importer, MediaMapInvalid from anki.utils import tmpfile @@ -36,7 +36,11 @@ class AnkiPackageImporter(Anki2Importer): # number to use during the import self.nameToNum = {} dir = self.col.media.dir() - for k, v in list(json.loads(z.read("media").decode("utf8")).items()): + try: + media_dict = json.loads(z.read("media").decode("utf8")) + except Exception as exc: + raise MediaMapInvalid() from exc + for k, v in list(media_dict.items()): path = os.path.abspath(os.path.join(dir, v)) if os.path.commonprefix([path, dir]) != dir: raise Exception("Invalid file") diff --git a/qt/aqt/browser/browser.py b/qt/aqt/browser/browser.py index 
ef996641c..a38b5bb95 100644 --- a/qt/aqt/browser/browser.py +++ b/qt/aqt/browser/browser.py @@ -23,7 +23,8 @@ from anki.tags import MARKED_TAG from anki.utils import is_mac from aqt import AnkiQt, gui_hooks from aqt.editor import Editor -from aqt.exporting import ExportDialog +from aqt.exporting import ExportDialog as LegacyExportDialog +from aqt.import_export.exporting import ExportDialog from aqt.operations.card import set_card_deck, set_card_flag from aqt.operations.collection import redo, undo from aqt.operations.note import remove_notes @@ -792,8 +793,12 @@ class Browser(QMainWindow): @no_arg_trigger @skip_if_selection_is_empty def _on_export_notes(self) -> None: - cids = self.selectedNotesAsCards() - ExportDialog(self.mw, cids=list(cids)) + if self.mw.pm.new_import_export(): + nids = self.selected_notes() + ExportDialog(self.mw, nids=nids) + else: + cids = self.selectedNotesAsCards() + LegacyExportDialog(self.mw, cids=list(cids)) # Flags & Marking ###################################################################### diff --git a/qt/aqt/errors.py b/qt/aqt/errors.py index 618549e1a..934bdff2d 100644 --- a/qt/aqt/errors.py +++ b/qt/aqt/errors.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Optional, TextIO, cast from markdown import markdown import aqt -from anki.errors import DocumentedError, LocalizedError +from anki.errors import DocumentedError, Interrupted, LocalizedError from aqt.qt import * from aqt.utils import showText, showWarning, supportText, tr @@ -22,7 +22,7 @@ if TYPE_CHECKING: def show_exception(*, parent: QWidget, exception: Exception) -> None: "Present a caught exception to the user using a pop-up." - if isinstance(exception, InterruptedError): + if isinstance(exception, Interrupted): # nothing to do return help_page = exception.help_page if isinstance(exception, DocumentedError) else None diff --git a/qt/aqt/exporting.py b/qt/aqt/exporting.py index 52ed11a21..370fcc85a 100644 --- a/qt/aqt/exporting.py +++ b/qt/aqt/exporting.py @@ -40,6 +40,7 @@ class ExportDialog(QDialog): self.col = mw.col.weakref() self.frm = aqt.forms.exporting.Ui_ExportDialog() self.frm.setupUi(self) + self.frm.legacy_support.setVisible(False) self.exporter: Exporter | None = None self.cids = cids disable_help_button(self) diff --git a/qt/aqt/forms/exporting.ui b/qt/aqt/forms/exporting.ui index b7b64436f..3d39e9416 100644 --- a/qt/aqt/forms/exporting.ui +++ b/qt/aqt/forms/exporting.ui @@ -6,8 +6,8 @@ 0 0 - 295 - 223 + 563 + 245 @@ -30,7 +30,14 @@ - + + + + 0 + 0 + + + @@ -40,7 +47,11 @@ - + + + 50 + + @@ -83,6 +94,16 @@ + + + + exporting_support_older_anki_versions + + + true + + + diff --git a/qt/aqt/import_export/__init__.py b/qt/aqt/import_export/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/qt/aqt/import_export/exporting.py b/qt/aqt/import_export/exporting.py new file mode 100644 index 000000000..81a0c1455 --- /dev/null +++ b/qt/aqt/import_export/exporting.py @@ -0,0 +1,252 @@ +# Copyright: Ankitects Pty Ltd and contributors +# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +from __future__ import annotations + +import os +import re +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Sequence, Type + +import aqt.forms +import aqt.main +from anki.collection import DeckIdLimit, ExportLimit, NoteIdsLimit, Progress +from anki.decks import DeckId, DeckNameId +from anki.notes import NoteId +from aqt import gui_hooks +from aqt.errors import show_exception +from aqt.operations import 
QueryOp +from aqt.progress import ProgressUpdate +from aqt.qt import * +from aqt.utils import ( + checkInvalidFilename, + disable_help_button, + getSaveFile, + showWarning, + tooltip, + tr, +) + + +class ExportDialog(QDialog): + def __init__( + self, + mw: aqt.main.AnkiQt, + did: DeckId | None = None, + nids: Sequence[NoteId] | None = None, + ): + QDialog.__init__(self, mw, Qt.WindowType.Window) + self.mw = mw + self.col = mw.col.weakref() + self.frm = aqt.forms.exporting.Ui_ExportDialog() + self.frm.setupUi(self) + self.exporter: Type[Exporter] = None + self.nids = nids + disable_help_button(self) + self.setup(did) + self.open() + + def setup(self, did: DeckId | None) -> None: + self.exporters: list[Type[Exporter]] = [ApkgExporter, ColpkgExporter] + self.frm.format.insertItems( + 0, [f"{e.name()} (.{e.extension})" for e in self.exporters] + ) + qconnect(self.frm.format.activated, self.exporter_changed) + if self.nids is None and not did: + # file>export defaults to colpkg + default_exporter_idx = 1 + else: + default_exporter_idx = 0 + self.frm.format.setCurrentIndex(default_exporter_idx) + self.exporter_changed(default_exporter_idx) + # deck list + if self.nids is None: + self.all_decks = self.col.decks.all_names_and_ids() + decks = [tr.exporting_all_decks()] + decks.extend(d.name for d in self.all_decks) + else: + decks = [tr.exporting_selected_notes()] + self.frm.deck.addItems(decks) + # save button + b = QPushButton(tr.exporting_export()) + self.frm.buttonBox.addButton(b, QDialogButtonBox.ButtonRole.AcceptRole) + # set default option if accessed through deck button + if did: + name = self.mw.col.decks.get(did)["name"] + index = self.frm.deck.findText(name) + self.frm.deck.setCurrentIndex(index) + self.frm.includeSched.setChecked(False) + + def exporter_changed(self, idx: int) -> None: + self.exporter = self.exporters[idx] + self.frm.includeSched.setVisible(self.exporter.show_include_scheduling) + self.frm.includeMedia.setVisible(self.exporter.show_include_media) + self.frm.includeTags.setVisible(self.exporter.show_include_tags) + self.frm.includeHTML.setVisible(self.exporter.show_include_html) + self.frm.legacy_support.setVisible(self.exporter.show_legacy_support) + self.frm.deck.setVisible(self.exporter.show_deck_list) + + def accept(self) -> None: + if not (out_path := self.get_out_path()): + return + self.exporter.export(self.mw, self.options(out_path)) + QDialog.reject(self) + + def get_out_path(self) -> str | None: + filename = self.filename() + while True: + path = getSaveFile( + parent=self, + title=tr.actions_export(), + dir_description="export", + key=self.exporter.name(), + ext=self.exporter.extension, + fname=filename, + ) + if not path: + return None + if checkInvalidFilename(os.path.basename(path), dirsep=False): + continue + path = os.path.normpath(path) + if os.path.commonprefix([self.mw.pm.base, path]) == self.mw.pm.base: + showWarning("Please choose a different export location.") + continue + break + return path + + def options(self, out_path: str) -> Options: + limit: ExportLimit = None + if self.nids: + limit = NoteIdsLimit(self.nids) + elif current_deck_id := self.current_deck_id(): + limit = DeckIdLimit(current_deck_id) + + return Options( + out_path=out_path, + include_scheduling=self.frm.includeSched.isChecked(), + include_media=self.frm.includeMedia.isChecked(), + include_tags=self.frm.includeTags.isChecked(), + include_html=self.frm.includeHTML.isChecked(), + legacy_support=self.frm.legacy_support.isChecked(), + limit=limit, + ) + + def current_deck_id(self) -> 
DeckId | None: + return (deck := self.current_deck()) and DeckId(deck.id) or None + + def current_deck(self) -> DeckNameId | None: + if self.exporter.show_deck_list: + if idx := self.frm.deck.currentIndex(): + return self.all_decks[idx - 1] + return None + + def filename(self) -> str: + if self.exporter.show_deck_list: + deck_name = self.frm.deck.currentText() + stem = re.sub('[\\\\/?<>:*|"^]', "_", deck_name) + else: + time_str = time.strftime("%Y-%m-%d@%H-%M-%S", time.localtime(time.time())) + stem = f"{tr.exporting_collection()}-{time_str}" + return f"{stem}.{self.exporter.extension}" + + +@dataclass +class Options: + out_path: str + include_scheduling: bool + include_media: bool + include_tags: bool + include_html: bool + legacy_support: bool + limit: ExportLimit + + +class Exporter(ABC): + extension: str + show_deck_list = False + show_include_scheduling = False + show_include_media = False + show_include_tags = False + show_include_html = False + show_legacy_support = False + + @staticmethod + @abstractmethod + def export(mw: aqt.main.AnkiQt, options: Options) -> None: + pass + + @staticmethod + @abstractmethod + def name() -> str: + pass + + +class ColpkgExporter(Exporter): + extension = "colpkg" + show_include_media = True + show_legacy_support = True + + @staticmethod + def name() -> str: + return tr.exporting_anki_collection_package() + + @staticmethod + def export(mw: aqt.main.AnkiQt, options: Options) -> None: + def on_success(_: None) -> None: + mw.reopen() + tooltip(tr.exporting_collection_exported(), parent=mw) + + def on_failure(exception: Exception) -> None: + mw.reopen() + show_exception(parent=mw, exception=exception) + + gui_hooks.collection_will_temporarily_close(mw.col) + QueryOp( + parent=mw, + op=lambda col: col.export_collection_package( + options.out_path, + include_media=options.include_media, + legacy=options.legacy_support, + ), + success=on_success, + ).with_backend_progress(export_progress_update).failure( + on_failure + ).run_in_background() + + +class ApkgExporter(Exporter): + extension = "apkg" + show_deck_list = True + show_include_scheduling = True + show_include_media = True + show_legacy_support = True + + @staticmethod + def name() -> str: + return tr.exporting_anki_deck_package() + + @staticmethod + def export(mw: aqt.main.AnkiQt, options: Options) -> None: + QueryOp( + parent=mw, + op=lambda col: col.export_anki_package( + out_path=options.out_path, + limit=options.limit, + with_scheduling=options.include_scheduling, + with_media=options.include_media, + legacy_support=options.legacy_support, + ), + success=lambda count: tooltip( + tr.exporting_note_exported(count=count), parent=mw + ), + ).with_backend_progress(export_progress_update).run_in_background() + + +def export_progress_update(progress: Progress, update: ProgressUpdate) -> None: + if not progress.HasField("exporting"): + return None + update.label = tr.exporting_exported_media_file(count=progress.exporting) + if update.user_wants_abort: + update.abort = True diff --git a/qt/aqt/import_export/importing.py b/qt/aqt/import_export/importing.py new file mode 100644 index 000000000..73a45dbd7 --- /dev/null +++ b/qt/aqt/import_export/importing.py @@ -0,0 +1,150 @@ +# Copyright: Ankitects Pty Ltd and contributors +# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +from __future__ import annotations + +from itertools import chain + +import aqt.main +from anki.collection import Collection, ImportLogWithChanges, Progress +from anki.errors import Interrupted +from 
aqt.operations import CollectionOp, QueryOp +from aqt.progress import ProgressUpdate +from aqt.qt import * +from aqt.utils import askUser, getFile, showInfo, showText, showWarning, tooltip, tr + + +def import_file(mw: aqt.main.AnkiQt) -> None: + if not (path := get_file_path(mw)): + return + + filename = os.path.basename(path).lower() + if filename.endswith(".anki"): + showInfo(tr.importing_anki_files_are_from_a_very()) + elif filename.endswith(".anki2"): + showInfo(tr.importing_anki2_files_are_not_directly_importable()) + elif is_collection_package(filename): + maybe_import_collection_package(mw, path) + elif filename.endswith(".apkg"): + import_anki_package(mw, path) + else: + raise NotImplementedError + + +def get_file_path(mw: aqt.main.AnkiQt) -> str | None: + if file := getFile( + mw, + tr.actions_import(), + None, + key="import", + filter=tr.importing_packaged_anki_deckcollection_apkg_colpkg_zip(), + ): + return str(file) + return None + + +def is_collection_package(filename: str) -> bool: + return ( + filename == "collection.apkg" + or (filename.startswith("backup-") and filename.endswith(".apkg")) + or filename.endswith(".colpkg") + ) + + +def maybe_import_collection_package(mw: aqt.main.AnkiQt, path: str) -> None: + if askUser( + tr.importing_this_will_delete_your_existing_collection(), + msgfunc=QMessageBox.warning, + defaultno=True, + ): + import_collection_package(mw, path) + + +def import_collection_package(mw: aqt.main.AnkiQt, file: str) -> None: + def on_success() -> None: + mw.loadCollection() + tooltip(tr.importing_importing_complete()) + + def on_failure(err: Exception) -> None: + mw.loadCollection() + if not isinstance(err, Interrupted): + showWarning(str(err)) + + QueryOp( + parent=mw, + op=lambda _: mw.create_backup_now(), + success=lambda _: mw.unloadCollection( + lambda: import_collection_package_op(mw, file, on_success) + .failure(on_failure) + .run_in_background() + ), + ).with_progress().run_in_background() + + +def import_collection_package_op( + mw: aqt.main.AnkiQt, path: str, success: Callable[[], None] +) -> QueryOp[None]: + def op(_: Collection) -> None: + col_path = mw.pm.collectionPath() + media_folder = os.path.join(mw.pm.profileFolder(), "collection.media") + media_db = os.path.join(mw.pm.profileFolder(), "collection.media.db2") + mw.backend.import_collection_package( + col_path=col_path, + backup_path=path, + media_folder=media_folder, + media_db=media_db, + ) + + return QueryOp(parent=mw, op=op, success=lambda _: success()).with_backend_progress( + import_progress_update + ) + + +def import_anki_package(mw: aqt.main.AnkiQt, path: str) -> None: + CollectionOp( + parent=mw, + op=lambda col: col.import_anki_package(path), + ).with_backend_progress(import_progress_update).success( + show_import_log + ).run_in_background() + + +def show_import_log(log_with_changes: ImportLogWithChanges) -> None: + showText(stringify_log(log_with_changes.log), plain_text_edit=True) + + +def stringify_log(log: ImportLogWithChanges.Log) -> str: + total = len(log.conflicting) + len(log.updated) + len(log.new) + len(log.duplicate) + return "\n".join( + chain( + (tr.importing_notes_found_in_file(val=total),), + ( + template_string(val=len(row)) + for (row, template_string) in ( + (log.conflicting, tr.importing_notes_that_could_not_be_imported), + (log.updated, tr.importing_notes_updated_as_file_had_newer), + (log.new, tr.importing_notes_added_from_file), + (log.duplicate, tr.importing_notes_skipped_as_theyre_already_in), + ) + if row + ), + ("",), + *( + [f"[{action}] {', 
'.join(note.fields)}" for note in rows] + for (rows, action) in ( + (log.conflicting, tr.importing_skipped()), + (log.updated, tr.importing_updated()), + (log.new, tr.adding_added()), + (log.duplicate, tr.importing_identical()), + ) + ), + ) + ) + + +def import_progress_update(progress: Progress, update: ProgressUpdate) -> None: + if not progress.HasField("importing"): + return + update.label = progress.importing + if update.user_wants_abort: + update.abort = True diff --git a/qt/aqt/importing.py b/qt/aqt/importing.py index 0168d2c9c..8d2225001 100644 --- a/qt/aqt/importing.py +++ b/qt/aqt/importing.py @@ -11,11 +11,10 @@ import anki.importing as importing import aqt.deckchooser import aqt.forms import aqt.modelchooser -from anki.errors import Interrupted -from anki.importing.anki2 import V2ImportIntoV1 +from anki.importing.anki2 import MediaMapInvalid, V2ImportIntoV1 from anki.importing.apkg import AnkiPackageImporter +from aqt.import_export.importing import import_collection_package from aqt.main import AnkiQt, gui_hooks -from aqt.operations import QueryOp from aqt.qt import * from aqt.utils import ( HelpPage, @@ -389,6 +388,10 @@ def importFile(mw: AnkiQt, file: str) -> None: future.result() except zipfile.BadZipfile: showWarning(invalidZipMsg()) + except MediaMapInvalid: + showWarning( + "Unable to read file. It probably requires a newer version of Anki to import." + ) except V2ImportIntoV1: showWarning( """\ @@ -434,78 +437,11 @@ def setupApkgImport(mw: AnkiQt, importer: AnkiPackageImporter) -> bool: if not full: # adding return True - if not askUser( + if askUser( tr.importing_this_will_delete_your_existing_collection(), msgfunc=QMessageBox.warning, defaultno=True, ): - return False + import_collection_package(mw, importer.file) - full_apkg_import(mw, importer.file) return False - - -def full_apkg_import(mw: AnkiQt, file: str) -> None: - def on_done(success: bool) -> None: - mw.loadCollection() - if success: - tooltip(tr.importing_importing_complete()) - - def after_backup(created: bool) -> None: - mw.unloadCollection(lambda: replace_with_apkg(mw, file, on_done)) - - QueryOp( - parent=mw, op=lambda _: mw.create_backup_now(), success=after_backup - ).with_progress().run_in_background() - - -def replace_with_apkg( - mw: AnkiQt, filename: str, callback: Callable[[bool], None] -) -> None: - """Tries to replace the provided collection with the provided backup, - then calls the callback. True if success. 
- """ - if not (dialog := mw.progress.start(immediate=True)): - print("No progress dialog during import; aborting will not work") - timer = QTimer() - timer.setSingleShot(False) - timer.setInterval(100) - - def on_progress() -> None: - progress = mw.backend.latest_progress() - if not progress.HasField("importing"): - return - label = progress.importing - - try: - if dialog.wantCancel: - mw.backend.set_wants_abort() - except AttributeError: - # dialog may not be active - pass - - mw.taskman.run_on_main(lambda: mw.progress.update(label=label)) - - def do_import() -> None: - col_path = mw.pm.collectionPath() - media_folder = os.path.join(mw.pm.profileFolder(), "collection.media") - mw.backend.import_collection_package( - col_path=col_path, backup_path=filename, media_folder=media_folder - ) - - def on_done(future: Future) -> None: - mw.progress.finish() - timer.deleteLater() - - try: - future.result() - except Exception as error: - if not isinstance(error, Interrupted): - showWarning(str(error)) - callback(False) - else: - callback(True) - - qconnect(timer.timeout, on_progress) - timer.start() - mw.taskman.run_in_background(do_import, on_done) diff --git a/qt/aqt/main.py b/qt/aqt/main.py index e3b1c5a2c..3d65e7b3e 100644 --- a/qt/aqt/main.py +++ b/qt/aqt/main.py @@ -47,6 +47,8 @@ from aqt.addons import DownloadLogEntry, check_and_prompt_for_updates, show_log_ from aqt.dbcheck import check_db from aqt.emptycards import show_empty_cards from aqt.flags import FlagManager +from aqt.import_export.exporting import ExportDialog +from aqt.import_export.importing import import_collection_package_op, import_file from aqt.legacy import install_pylib_legacy from aqt.mediacheck import check_media_db from aqt.mediasync import MediaSyncer @@ -402,15 +404,12 @@ class AnkiQt(QMainWindow): ) def _openBackup(self, path: str) -> None: - def on_done(success: bool) -> None: - if success: - self.onOpenProfile(callback=lambda: self.col.mod_schema(check=False)) - - import aqt.importing - self.restoring_backup = True showInfo(tr.qt_misc_automatic_syncing_and_backups_have_been()) - aqt.importing.replace_with_apkg(self, path, on_done) + + import_collection_package_op( + self, path, success=self.onOpenProfile + ).run_in_background() def _on_downgrade(self) -> None: self.progress.start() @@ -1183,12 +1182,18 @@ title="{}" {}>{}""".format( def onImport(self) -> None: import aqt.importing - aqt.importing.onImport(self) + if self.pm.new_import_export(): + import_file(self) + else: + aqt.importing.onImport(self) def onExport(self, did: DeckId | None = None) -> None: import aqt.exporting - aqt.exporting.ExportDialog(self, did=did) + if self.pm.new_import_export(): + ExportDialog(self, did=did) + else: + aqt.exporting.ExportDialog(self, did=did) # Installing add-ons from CLI / mimetype handler ########################################################################## diff --git a/qt/aqt/operations/__init__.py b/qt/aqt/operations/__init__.py index 7d7ad8b91..761c6e7ab 100644 --- a/qt/aqt/operations/__init__.py +++ b/qt/aqt/operations/__init__.py @@ -8,16 +8,19 @@ from typing import Any, Callable, Generic, Protocol, TypeVar, Union import aqt import aqt.gui_hooks +import aqt.main from anki.collection import ( Collection, + ImportLogWithChanges, OpChanges, OpChangesAfterUndo, OpChangesWithCount, OpChangesWithId, + Progress, ) from aqt.errors import show_exception +from aqt.progress import ProgressUpdate from aqt.qt import QWidget -from aqt.utils import showWarning class HasChangesProperty(Protocol): @@ -34,6 +37,7 @@ 
ResultWithChanges = TypeVar( OpChangesWithCount, OpChangesWithId, OpChangesAfterUndo, + ImportLogWithChanges, HasChangesProperty, ], ) @@ -65,6 +69,7 @@ class CollectionOp(Generic[ResultWithChanges]): _success: Callable[[ResultWithChanges], Any] | None = None _failure: Callable[[Exception], Any] | None = None + _progress_update: Callable[[Progress, ProgressUpdate], None] | None = None def __init__(self, parent: QWidget, op: Callable[[Collection], ResultWithChanges]): self._parent = parent @@ -82,6 +87,12 @@ class CollectionOp(Generic[ResultWithChanges]): self._failure = failure return self + def with_backend_progress( + self, progress_update: Callable[[Progress, ProgressUpdate], None] | None + ) -> CollectionOp[ResultWithChanges]: + self._progress_update = progress_update + return self + def run_in_background(self, *, initiator: object | None = None) -> None: from aqt import mw @@ -113,12 +124,29 @@ class CollectionOp(Generic[ResultWithChanges]): if self._success: self._success(result) finally: - mw.update_undo_actions() - mw.autosave() - # fire change hooks - self._fire_change_hooks_after_op_performed(result, initiator) + self._finish_op(mw, result, initiator) - mw.taskman.with_progress(wrapped_op, wrapped_done) + self._run(mw, wrapped_op, wrapped_done) + + def _run( + self, + mw: aqt.main.AnkiQt, + op: Callable[[], ResultWithChanges], + on_done: Callable[[Future], None], + ) -> None: + if self._progress_update: + mw.taskman.with_backend_progress( + op, self._progress_update, on_done=on_done, parent=self._parent + ) + else: + mw.taskman.with_progress(op, on_done, parent=self._parent) + + def _finish_op( + self, mw: aqt.main.AnkiQt, result: ResultWithChanges, initiator: object | None + ) -> None: + mw.update_undo_actions() + mw.autosave() + self._fire_change_hooks_after_op_performed(result, initiator) def _fire_change_hooks_after_op_performed( self, @@ -168,6 +196,7 @@ class QueryOp(Generic[T]): _failure: Callable[[Exception], Any] | None = None _progress: bool | str = False + _progress_update: Callable[[Progress, ProgressUpdate], None] | None = None def __init__( self, @@ -192,6 +221,12 @@ class QueryOp(Generic[T]): self._progress = label or True return self + def with_backend_progress( + self, progress_update: Callable[[Progress, ProgressUpdate], None] | None + ) -> QueryOp[T]: + self._progress_update = progress_update + return self + def run_in_background(self) -> None: from aqt import mw @@ -201,26 +236,11 @@ class QueryOp(Generic[T]): def wrapped_op() -> T: assert mw - if self._progress: - label: str | None - if isinstance(self._progress, str): - label = self._progress - else: - label = None - - def start_progress() -> None: - assert mw - mw.progress.start(label=label) - - mw.taskman.run_on_main(start_progress) return self._op(mw.col) def wrapped_done(future: Future) -> None: assert mw - if self._progress: - mw.progress.finish() - mw._decrease_background_ops() # did something go wrong? 
if exception := future.exception(): @@ -228,7 +248,7 @@ class QueryOp(Generic[T]): if self._failure: self._failure(exception) else: - showWarning(str(exception), self._parent) + show_exception(parent=self._parent, exception=exception) return else: # BaseException like SystemExit; rethrow it @@ -236,4 +256,24 @@ class QueryOp(Generic[T]): self._success(future.result()) - mw.taskman.run_in_background(wrapped_op, wrapped_done) + self._run(mw, wrapped_op, wrapped_done) + + def _run( + self, + mw: aqt.main.AnkiQt, + op: Callable[[], T], + on_done: Callable[[Future], None], + ) -> None: + label = self._progress if isinstance(self._progress, str) else None + if self._progress_update: + mw.taskman.with_backend_progress( + op, + self._progress_update, + on_done=on_done, + start_label=label, + parent=self._parent, + ) + elif self._progress: + mw.taskman.with_progress(op, on_done, label=label, parent=self._parent) + else: + mw.taskman.run_in_background(op, on_done) diff --git a/qt/aqt/profiles.py b/qt/aqt/profiles.py index 1d8ce51f5..6d00fb6a3 100644 --- a/qt/aqt/profiles.py +++ b/qt/aqt/profiles.py @@ -538,6 +538,12 @@ create table if not exists profiles def dark_mode_widgets(self) -> bool: return self.meta.get("dark_mode_widgets", False) + def new_import_export(self) -> bool: + return self.meta.get("new_import_export", False) + + def set_new_import_export(self, enabled: bool) -> None: + self.meta["new_import_export"] = enabled + # Profile-specific ###################################################################### diff --git a/qt/aqt/progress.py b/qt/aqt/progress.py index 3aea4f652..3a1b355a9 100644 --- a/qt/aqt/progress.py +++ b/qt/aqt/progress.py @@ -3,9 +3,11 @@ from __future__ import annotations import time +from dataclasses import dataclass import aqt.forms from anki._legacy import print_deprecation_warning +from anki.collection import Progress from aqt.qt import * from aqt.utils import disable_help_button, tr @@ -23,6 +25,7 @@ class ProgressManager: self._busy_cursor_timer: QTimer | None = None self._win: ProgressDialog | None = None self._levels = 0 + self._backend_timer: QTimer | None = None # Safer timers ########################################################################## @@ -166,6 +169,34 @@ class ProgressManager: qconnect(self._show_timer.timeout, self._on_show_timer) return self._win + def start_with_backend_updates( + self, + progress_update: Callable[[Progress, ProgressUpdate], None], + start_label: str | None = None, + parent: QWidget | None = None, + ) -> None: + self._backend_timer = QTimer() + self._backend_timer.setSingleShot(False) + self._backend_timer.setInterval(100) + + if not (dialog := self.start(immediate=True, label=start_label, parent=parent)): + print("Progress dialog already running; aborting will not work") + + def on_progress() -> None: + assert self.mw + + user_wants_abort = dialog and dialog.wantCancel or False + update = ProgressUpdate(user_wants_abort=user_wants_abort) + progress = self.mw.backend.latest_progress() + progress_update(progress, update) + if update.abort: + self.mw.backend.set_wants_abort() + if update.has_update(): + self.update(label=update.label, value=update.value, max=update.max) + + qconnect(self._backend_timer.timeout, on_progress) + self._backend_timer.start() + def update( self, label: str | None = None, @@ -204,6 +235,9 @@ class ProgressManager: if self._show_timer: self._show_timer.stop() self._show_timer = None + if self._backend_timer: + self._backend_timer.deleteLater() + self._backend_timer = None def clear(self) -> None: 
"Restore the interface after an error." @@ -295,3 +329,15 @@ class ProgressDialog(QDialog): if evt.key() == Qt.Key.Key_Escape: evt.ignore() self.wantCancel = True + + +@dataclass +class ProgressUpdate: + label: str | None = None + value: int | None = None + max: int | None = None + user_wants_abort: bool = False + abort: bool = False + + def has_update(self) -> bool: + return self.label is not None or self.value is not None or self.max is not None diff --git a/qt/aqt/taskman.py b/qt/aqt/taskman.py index b763f96d4..283153ee1 100644 --- a/qt/aqt/taskman.py +++ b/qt/aqt/taskman.py @@ -15,6 +15,8 @@ from threading import Lock from typing import Any, Callable import aqt +from anki.collection import Progress +from aqt.progress import ProgressUpdate from aqt.qt import * Closure = Callable[[], None] @@ -89,6 +91,27 @@ class TaskManager(QObject): self.run_in_background(task, wrapped_done) + def with_backend_progress( + self, + task: Callable, + progress_update: Callable[[Progress, ProgressUpdate], None], + on_done: Callable[[Future], None] | None = None, + parent: QWidget | None = None, + start_label: str | None = None, + ) -> None: + self.mw.progress.start_with_backend_updates( + progress_update, + parent=parent, + start_label=start_label, + ) + + def wrapped_done(fut: Future) -> None: + self.mw.progress.finish() + if on_done: + on_done(fut) + + self.run_in_background(task, wrapped_done) + def _on_closures_pending(self) -> None: """Run any pending closures. This runs in the main thread.""" with self._closures_lock: diff --git a/rslib/src/backend/generic.rs b/rslib/src/backend/generic.rs index 6813c141a..b86751bcb 100644 --- a/rslib/src/backend/generic.rs +++ b/rslib/src/backend/generic.rs @@ -69,6 +69,12 @@ impl From for NoteId { } } +impl From for pb::NoteId { + fn from(nid: NoteId) -> Self { + pb::NoteId { nid: nid.0 } + } +} + impl From for NotetypeId { fn from(ntid: pb::NotetypeId) -> Self { NotetypeId(ntid.ntid) diff --git a/rslib/src/backend/import_export.rs b/rslib/src/backend/import_export.rs index 47d99a456..457e40e08 100644 --- a/rslib/src/backend/import_export.rs +++ b/rslib/src/backend/import_export.rs @@ -1,12 +1,18 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +use std::path::Path; + use super::{progress::Progress, Backend}; pub(super) use crate::backend_proto::importexport_service::Service as ImportExportService; use crate::{ - backend_proto::{self as pb}, - import_export::{package::import_colpkg, ImportProgress}, + backend_proto::{self as pb, export_anki_package_request::Selector}, + import_export::{ + package::{import_colpkg, NoteLog}, + ImportProgress, + }, prelude::*, + search::SearchNode, }; impl ImportExportService for Backend { @@ -38,30 +44,68 @@ impl ImportExportService for Backend { import_colpkg( &input.backup_path, &input.col_path, - &input.media_folder, + Path::new(&input.media_folder), + Path::new(&input.media_db), self.import_progress_fn(), + &self.log, ) .map(Into::into) } + + fn import_anki_package( + &self, + input: pb::ImportAnkiPackageRequest, + ) -> Result { + self.with_col(|col| col.import_apkg(&input.package_path, self.import_progress_fn())) + .map(Into::into) + } + + fn export_anki_package(&self, input: pb::ExportAnkiPackageRequest) -> Result { + let selector = input + .selector + .ok_or_else(|| AnkiError::invalid_input("missing oneof"))?; + self.with_col(|col| { + col.export_apkg( + &input.out_path, + SearchNode::from_selector(selector), + input.with_scheduling, + 
input.with_media, + input.legacy, + None, + self.export_progress_fn(), + ) + }) + .map(Into::into) + } } -impl Backend { - fn import_progress_fn(&self) -> impl FnMut(ImportProgress) -> Result<()> { - let mut handler = self.new_progress_handler(); - move |progress| { - let throttle = matches!(progress, ImportProgress::Media(_)); - if handler.update(Progress::Import(progress), throttle) { - Ok(()) - } else { - Err(AnkiError::Interrupted) - } - } - } - - fn export_progress_fn(&self) -> impl FnMut(usize) { - let mut handler = self.new_progress_handler(); - move |media_files| { - handler.update(Progress::Export(media_files), true); +impl SearchNode { + fn from_selector(selector: Selector) -> Self { + match selector { + Selector::WholeCollection(_) => Self::WholeCollection, + Selector::DeckId(did) => Self::from_deck_id(did, true), + Selector::NoteIds(nids) => Self::from_note_ids(nids.note_ids), + } + } +} + +impl Backend { + fn import_progress_fn(&self) -> impl FnMut(ImportProgress, bool) -> bool { + let mut handler = self.new_progress_handler(); + move |progress, throttle| handler.update(Progress::Import(progress), throttle) + } + + fn export_progress_fn(&self) -> impl FnMut(usize, bool) -> bool { + let mut handler = self.new_progress_handler(); + move |progress, throttle| handler.update(Progress::Export(progress), throttle) + } +} + +impl From> for pb::ImportAnkiPackageResponse { + fn from(output: OpOutput) -> Self { + Self { + changes: Some(output.changes.into()), + log: Some(output.output), } } } diff --git a/rslib/src/backend/progress.rs b/rslib/src/backend/progress.rs index a21bede23..733d06a67 100644 --- a/rslib/src/backend/progress.rs +++ b/rslib/src/backend/progress.rs @@ -108,8 +108,10 @@ pub(super) fn progress_to_proto(progress: Option, tr: &I18n) -> pb::Pr } Progress::Import(progress) => pb::progress::Value::Importing( match progress { - ImportProgress::Collection => tr.importing_importing_collection(), + ImportProgress::File => tr.importing_importing_file(), ImportProgress::Media(n) => tr.importing_processed_media_file(n), + ImportProgress::MediaCheck(n) => tr.media_check_checked(n), + ImportProgress::Notes(n) => tr.importing_processed_notes(n), } .into(), ), diff --git a/rslib/src/card/undo.rs b/rslib/src/card/undo.rs index c0bb0a035..e83f096f7 100644 --- a/rslib/src/card/undo.rs +++ b/rslib/src/card/undo.rs @@ -35,6 +35,14 @@ impl Collection { Ok(()) } + pub(crate) fn add_card_if_unique_undoable(&mut self, card: &Card) -> Result { + let added = self.storage.add_card_if_unique(card)?; + if added { + self.save_undo(UndoableCardChange::Added(Box::new(card.clone()))); + } + Ok(added) + } + pub(super) fn update_card_undoable(&mut self, card: &mut Card, original: Card) -> Result<()> { if card.id.0 == 0 { return Err(AnkiError::invalid_input("card id not set")); diff --git a/rslib/src/deckconfig/undo.rs b/rslib/src/deckconfig/undo.rs index cf14c8586..0a0ed7a2d 100644 --- a/rslib/src/deckconfig/undo.rs +++ b/rslib/src/deckconfig/undo.rs @@ -44,6 +44,13 @@ impl Collection { Ok(()) } + pub(crate) fn add_deck_config_if_unique_undoable(&mut self, config: &DeckConfig) -> Result<()> { + if self.storage.add_deck_conf_if_unique(config)? 
{ + self.save_undo(UndoableDeckConfigChange::Added(Box::new(config.clone()))); + } + Ok(()) + } + pub(super) fn update_deck_config_undoable( &mut self, config: &DeckConfig, diff --git a/rslib/src/decks/addupdate.rs b/rslib/src/decks/addupdate.rs index cf2975478..159c57abd 100644 --- a/rslib/src/decks/addupdate.rs +++ b/rslib/src/decks/addupdate.rs @@ -148,7 +148,7 @@ impl Collection { Ok(()) } - fn first_existing_parent( + pub(crate) fn first_existing_parent( &self, machine_name: &str, recursion_level: usize, diff --git a/rslib/src/import_export/gather.rs b/rslib/src/import_export/gather.rs new file mode 100644 index 000000000..af3b71006 --- /dev/null +++ b/rslib/src/import_export/gather.rs @@ -0,0 +1,222 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::collections::{HashMap, HashSet}; + +use itertools::Itertools; + +use crate::{ + decks::immediate_parent_name, + io::filename_is_safe, + latex::extract_latex, + prelude::*, + revlog::RevlogEntry, + text::{extract_media_refs, extract_underscored_css_imports, extract_underscored_references}, +}; + +#[derive(Debug, Default)] +pub(super) struct ExchangeData { + pub(super) decks: Vec, + pub(super) notes: Vec, + pub(super) cards: Vec, + pub(super) notetypes: Vec, + pub(super) revlog: Vec, + pub(super) deck_configs: Vec, + pub(super) media_filenames: HashSet, + pub(super) days_elapsed: u32, + pub(super) creation_utc_offset: Option, +} + +impl ExchangeData { + pub(super) fn gather_data( + &mut self, + col: &mut Collection, + search: impl TryIntoSearch, + with_scheduling: bool, + ) -> Result<()> { + self.days_elapsed = col.timing_today()?.days_elapsed; + self.creation_utc_offset = col.get_creation_utc_offset(); + self.notes = col.gather_notes(search)?; + self.cards = col.gather_cards()?; + self.decks = col.gather_decks()?; + self.notetypes = col.gather_notetypes()?; + + if with_scheduling { + self.revlog = col.gather_revlog()?; + self.deck_configs = col.gather_deck_configs(&self.decks)?; + } else { + self.remove_scheduling_information(col); + }; + + col.storage.clear_searched_notes_table()?; + col.storage.clear_searched_cards_table() + } + + pub(super) fn gather_media_names(&mut self) { + let mut inserter = |name: String| { + if filename_is_safe(&name) { + self.media_filenames.insert(name); + } + }; + let svg_getter = svg_getter(&self.notetypes); + for note in self.notes.iter() { + gather_media_names_from_note(note, &mut inserter, &svg_getter); + } + for notetype in self.notetypes.iter() { + gather_media_names_from_notetype(notetype, &mut inserter); + } + } + + fn remove_scheduling_information(&mut self, col: &Collection) { + self.remove_system_tags(); + self.reset_deck_config_ids(); + self.reset_cards(col); + } + + fn remove_system_tags(&mut self) { + const SYSTEM_TAGS: [&str; 2] = ["marked", "leech"]; + for note in self.notes.iter_mut() { + note.tags = std::mem::take(&mut note.tags) + .into_iter() + .filter(|tag| !SYSTEM_TAGS.iter().any(|s| tag.eq_ignore_ascii_case(s))) + .collect(); + } + } + + fn reset_deck_config_ids(&mut self) { + for deck in self.decks.iter_mut() { + if let Ok(normal_mut) = deck.normal_mut() { + normal_mut.config_id = 1; + } else { + // filtered decks are reset at import time for legacy reasons + } + } + } + + fn reset_cards(&mut self, col: &Collection) { + let mut position = col.get_next_card_position(); + for card in self.cards.iter_mut() { + // schedule_as_new() removes cards from filtered decks, but we want to + // leave cards 
in their current deck, which gets converted to a regular + // deck on import + let deck_id = card.deck_id; + if card.schedule_as_new(position, true, true) { + position += 1; + } + card.flags = 0; + card.deck_id = deck_id; + } + } +} + +fn gather_media_names_from_note( + note: &Note, + inserter: &mut impl FnMut(String), + svg_getter: &impl Fn(NotetypeId) -> bool, +) { + for field in note.fields() { + for media_ref in extract_media_refs(field) { + inserter(media_ref.fname_decoded.to_string()); + } + + for latex in extract_latex(field, svg_getter(note.notetype_id)).1 { + inserter(latex.fname); + } + } +} + +fn gather_media_names_from_notetype(notetype: &Notetype, inserter: &mut impl FnMut(String)) { + for name in extract_underscored_css_imports(¬etype.config.css) { + inserter(name.to_string()); + } + for template in ¬etype.templates { + for template_side in [&template.config.q_format, &template.config.a_format] { + for name in extract_underscored_references(template_side) { + inserter(name.to_string()); + } + } + } +} + +fn svg_getter(notetypes: &[Notetype]) -> impl Fn(NotetypeId) -> bool { + let svg_map: HashMap = notetypes + .iter() + .map(|nt| (nt.id, nt.config.latex_svg)) + .collect(); + move |nt_id| svg_map.get(&nt_id).copied().unwrap_or_default() +} + +impl Collection { + fn gather_notes(&mut self, search: impl TryIntoSearch) -> Result> { + self.search_notes_into_table(search)?; + self.storage.all_searched_notes() + } + + fn gather_cards(&mut self) -> Result> { + self.storage.search_cards_of_notes_into_table()?; + self.storage.all_searched_cards() + } + + fn gather_decks(&mut self) -> Result> { + let decks = self.storage.get_decks_for_search_cards()?; + let parents = self.get_parent_decks(&decks)?; + Ok(decks + .into_iter() + .filter(|deck| deck.id != DeckId(1)) + .chain(parents) + .collect()) + } + + fn get_parent_decks(&mut self, decks: &[Deck]) -> Result> { + let mut parent_names: HashSet = decks + .iter() + .map(|deck| deck.name.as_native_str().to_owned()) + .collect(); + let mut parents = Vec::new(); + for deck in decks { + self.add_parent_decks(deck.name.as_native_str(), &mut parent_names, &mut parents)?; + } + Ok(parents) + } + + fn add_parent_decks( + &mut self, + name: &str, + parent_names: &mut HashSet, + parents: &mut Vec, + ) -> Result<()> { + if let Some(parent_name) = immediate_parent_name(name) { + if parent_names.insert(parent_name.to_owned()) { + parents.push( + self.storage + .get_deck_by_name(parent_name)? + .ok_or(AnkiError::DatabaseCheckRequired)?, + ); + self.add_parent_decks(parent_name, parent_names, parents)?; + } + } + Ok(()) + } + + fn gather_notetypes(&mut self) -> Result> { + self.storage.get_notetypes_for_search_notes() + } + + fn gather_revlog(&mut self) -> Result> { + self.storage.get_revlog_entries_for_searched_cards() + } + + fn gather_deck_configs(&mut self, decks: &[Deck]) -> Result> { + decks + .iter() + .filter_map(|deck| deck.config_id()) + .unique() + .filter(|config_id| *config_id != DeckConfigId(1)) + .map(|config_id| { + self.storage + .get_deck_config(config_id)? 
+ .ok_or(AnkiError::NotFound) + }) + .collect() + } +} diff --git a/rslib/src/import_export/insert.rs b/rslib/src/import_export/insert.rs new file mode 100644 index 000000000..21ab1be1c --- /dev/null +++ b/rslib/src/import_export/insert.rs @@ -0,0 +1,62 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use super::gather::ExchangeData; +use crate::{prelude::*, revlog::RevlogEntry}; + +impl Collection { + pub(super) fn insert_data(&mut self, data: &ExchangeData) -> Result<()> { + self.transact_no_undo(|col| { + col.insert_decks(&data.decks)?; + col.insert_notes(&data.notes)?; + col.insert_cards(&data.cards)?; + col.insert_notetypes(&data.notetypes)?; + col.insert_revlog(&data.revlog)?; + col.insert_deck_configs(&data.deck_configs) + }) + } + + fn insert_decks(&self, decks: &[Deck]) -> Result<()> { + for deck in decks { + self.storage.add_or_update_deck_with_existing_id(deck)?; + } + Ok(()) + } + + fn insert_notes(&self, notes: &[Note]) -> Result<()> { + for note in notes { + self.storage.add_or_update_note(note)?; + } + Ok(()) + } + + fn insert_cards(&self, cards: &[Card]) -> Result<()> { + for card in cards { + self.storage.add_or_update_card(card)?; + } + Ok(()) + } + + fn insert_notetypes(&self, notetypes: &[Notetype]) -> Result<()> { + for notetype in notetypes { + self.storage + .add_or_update_notetype_with_existing_id(notetype)?; + } + Ok(()) + } + + fn insert_revlog(&self, revlog: &[RevlogEntry]) -> Result<()> { + for entry in revlog { + self.storage.add_revlog_entry(entry, false)?; + } + Ok(()) + } + + fn insert_deck_configs(&self, configs: &[DeckConfig]) -> Result<()> { + for config in configs { + self.storage + .add_or_update_deck_config_with_existing_id(config)?; + } + Ok(()) + } +} diff --git a/rslib/src/import_export/mod.rs b/rslib/src/import_export/mod.rs index 994d93101..b1e2944af 100644 --- a/rslib/src/import_export/mod.rs +++ b/rslib/src/import_export/mod.rs @@ -1,10 +1,87 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +mod gather; +mod insert; pub mod package; +use std::marker::PhantomData; + +use crate::prelude::*; + #[derive(Debug, Clone, Copy, PartialEq)] pub enum ImportProgress { - Collection, + File, Media(usize), + MediaCheck(usize), + Notes(usize), +} + +/// Wrapper around a progress function, usually passed by the [crate::backend::Backend], +/// to make repeated calls more ergonomic. +pub(crate) struct IncrementableProgress

<P>(Box<dyn FnMut(P, bool) -> bool>); + +impl<P> IncrementableProgress<P>
{ + /// `progress_fn: (progress, throttle) -> should_continue` + pub(crate) fn new(progress_fn: impl 'static + FnMut(P, bool) -> bool) -> Self { + Self(Box::new(progress_fn)) + } + + /// Returns an [Incrementor] with an `increment()` function for use in loops. + pub(crate) fn incrementor<'inc, 'progress: 'inc, 'map: 'inc>( + &'progress mut self, + mut count_map: impl 'map + FnMut(usize) -> P, + ) -> Incrementor<'inc, impl FnMut(usize) -> Result<()> + 'inc> { + Incrementor::new(move |u| self.update(count_map(u), true)) + } + + /// Manually triggers an update. + /// Returns [AnkiError::Interrupted] if the operation should be cancelled. + pub(crate) fn call(&mut self, progress: P) -> Result<()> { + self.update(progress, false) + } + + fn update(&mut self, progress: P, throttle: bool) -> Result<()> { + if (self.0)(progress, throttle) { + Ok(()) + } else { + Err(AnkiError::Interrupted) + } + } + + /// Stopgap for returning a progress fn compliant with the media code. + pub(crate) fn media_db_fn( + &mut self, + count_map: impl 'static + Fn(usize) -> P, + ) -> Result bool + '_> { + Ok(move |count| (self.0)(count_map(count), true)) + } +} + +pub(crate) struct Incrementor<'f, F: 'f + FnMut(usize) -> Result<()>> { + update_fn: F, + count: usize, + update_interval: usize, + _phantom: PhantomData<&'f ()>, +} + +impl<'f, F: 'f + FnMut(usize) -> Result<()>> Incrementor<'f, F> { + fn new(update_fn: F) -> Self { + Self { + update_fn, + count: 0, + update_interval: 17, + _phantom: PhantomData, + } + } + + /// Increments the progress counter, periodically triggering an update. + /// Returns [AnkiError::Interrupted] if the operation should be cancelled. + pub(crate) fn increment(&mut self) -> Result<()> { + self.count += 1; + if self.count % self.update_interval != 0 { + return Ok(()); + } + (self.update_fn)(self.count) + } } diff --git a/rslib/src/import_export/package/apkg/export.rs b/rslib/src/import_export/package/apkg/export.rs new file mode 100644 index 000000000..9b797893d --- /dev/null +++ b/rslib/src/import_export/package/apkg/export.rs @@ -0,0 +1,106 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::{ + collections::HashSet, + path::{Path, PathBuf}, +}; + +use tempfile::NamedTempFile; + +use crate::{ + collection::CollectionBuilder, + import_export::{ + gather::ExchangeData, + package::{ + colpkg::export::{export_collection, MediaIter}, + Meta, + }, + IncrementableProgress, + }, + io::{atomic_rename, tempfile_in_parent_of}, + prelude::*, +}; + +impl Collection { + /// Returns number of exported notes. 
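+    /// `search` selects which notes are exported; `media_fn` can supply a custom media
+    /// iterator, otherwise referenced files are read from the collection's media folder.
+    /// The selected data is gathered into an `ExchangeData`, written to a temporary
+    /// collection file (legacy or current schema depending on `legacy`), then zipped
+    /// together with the media and renamed into place at `out_path`.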
+ #[allow(clippy::too_many_arguments)] + pub fn export_apkg( + &mut self, + out_path: impl AsRef, + search: impl TryIntoSearch, + with_scheduling: bool, + with_media: bool, + legacy: bool, + media_fn: Option) -> MediaIter>>, + progress_fn: impl 'static + FnMut(usize, bool) -> bool, + ) -> Result { + let mut progress = IncrementableProgress::new(progress_fn); + let temp_apkg = tempfile_in_parent_of(out_path.as_ref())?; + let mut temp_col = NamedTempFile::new()?; + let temp_col_path = temp_col + .path() + .to_str() + .ok_or_else(|| AnkiError::IoError("tempfile with non-unicode name".into()))?; + let meta = if legacy { + Meta::new_legacy() + } else { + Meta::new() + }; + let data = self.export_into_collection_file( + &meta, + temp_col_path, + search, + with_scheduling, + with_media, + )?; + + let media = if let Some(media_fn) = media_fn { + media_fn(data.media_filenames) + } else { + MediaIter::from_file_list(data.media_filenames, self.media_folder.clone()) + }; + let col_size = temp_col.as_file().metadata()?.len() as usize; + + export_collection( + meta, + temp_apkg.path(), + &mut temp_col, + col_size, + media, + &self.tr, + &mut progress, + )?; + atomic_rename(temp_apkg, out_path.as_ref(), true)?; + Ok(data.notes.len()) + } + + fn export_into_collection_file( + &mut self, + meta: &Meta, + path: &str, + search: impl TryIntoSearch, + with_scheduling: bool, + with_media: bool, + ) -> Result { + let mut data = ExchangeData::default(); + data.gather_data(self, search, with_scheduling)?; + if with_media { + data.gather_media_names(); + } + + let mut temp_col = Collection::new_minimal(path)?; + temp_col.insert_data(&data)?; + temp_col.set_creation_stamp(self.storage.creation_stamp()?)?; + temp_col.set_creation_utc_offset(data.creation_utc_offset)?; + temp_col.close(Some(meta.schema_version()))?; + + Ok(data) + } + + fn new_minimal(path: impl Into) -> Result { + let col = CollectionBuilder::new(path).build()?; + col.storage.db.execute_batch("DELETE FROM notetypes")?; + Ok(col) + } +} diff --git a/rslib/src/import_export/package/apkg/import/cards.rs b/rslib/src/import_export/package/apkg/import/cards.rs new file mode 100644 index 000000000..6a87510c1 --- /dev/null +++ b/rslib/src/import_export/package/apkg/import/cards.rs @@ -0,0 +1,179 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::{ + collections::{HashMap, HashSet}, + mem, +}; + +use super::Context; +use crate::{ + card::{CardQueue, CardType}, + config::SchedulerVersion, + prelude::*, + revlog::RevlogEntry, +}; + +type CardAsNidAndOrd = (NoteId, u16); + +struct CardContext<'a> { + target_col: &'a mut Collection, + usn: Usn, + + imported_notes: &'a HashMap, + remapped_decks: &'a HashMap, + + /// The number of days the source collection is ahead of the target collection + collection_delta: i32, + scheduler_version: SchedulerVersion, + existing_cards: HashSet, + existing_card_ids: HashSet, + + imported_cards: HashMap, +} + +impl<'c> CardContext<'c> { + fn new<'a: 'c>( + usn: Usn, + days_elapsed: u32, + target_col: &'a mut Collection, + imported_notes: &'a HashMap, + imported_decks: &'a HashMap, + ) -> Result { + let existing_cards = target_col.storage.all_cards_as_nid_and_ord()?; + let collection_delta = target_col.collection_delta(days_elapsed)?; + let scheduler_version = target_col.scheduler_info()?.version; + let existing_card_ids = target_col.storage.get_all_card_ids()?; + Ok(Self { + target_col, + usn, + imported_notes, + remapped_decks: 
+            existing_cards,
+            collection_delta,
+            scheduler_version,
+            existing_card_ids,
+            imported_cards: HashMap::new(),
+        })
+    }
+}
+
+impl Collection {
+    /// How much `days_elapsed` is ahead of this collection.
+    fn collection_delta(&mut self, days_elapsed: u32) -> Result<i32> {
+        Ok(days_elapsed as i32 - self.timing_today()?.days_elapsed as i32)
+    }
+}
+
+impl Context<'_> {
+    pub(super) fn import_cards_and_revlog(
+        &mut self,
+        imported_notes: &HashMap<NoteId, NoteId>,
+        imported_decks: &HashMap<DeckId, DeckId>,
+    ) -> Result<()> {
+        let mut ctx = CardContext::new(
+            self.usn,
+            self.data.days_elapsed,
+            self.target_col,
+            imported_notes,
+            imported_decks,
+        )?;
+        ctx.import_cards(mem::take(&mut self.data.cards))?;
+        ctx.import_revlog(mem::take(&mut self.data.revlog))
+    }
+}
+
+impl CardContext<'_> {
+    fn import_cards(&mut self, mut cards: Vec<Card>) -> Result<()> {
+        for card in &mut cards {
+            if self.map_to_imported_note(card) && !self.card_ordinal_already_exists(card) {
+                self.add_card(card)?;
+            }
+            // TODO: could update existing card
+        }
+        Ok(())
+    }
+
+    fn import_revlog(&mut self, revlog: Vec<RevlogEntry>) -> Result<()> {
+        for mut entry in revlog {
+            if let Some(cid) = self.imported_cards.get(&entry.cid) {
+                entry.cid = *cid;
+                entry.usn = self.usn;
+                self.target_col.add_revlog_entry_if_unique_undoable(entry)?;
+            }
+        }
+        Ok(())
+    }
+
+    fn map_to_imported_note(&self, card: &mut Card) -> bool {
+        if let Some(nid) = self.imported_notes.get(&card.note_id) {
+            card.note_id = *nid;
+            true
+        } else {
+            false
+        }
+    }
+
+    fn card_ordinal_already_exists(&self, card: &Card) -> bool {
+        self.existing_cards
+            .contains(&(card.note_id, card.template_idx))
+    }
+
+    fn add_card(&mut self, card: &mut Card) -> Result<()> {
+        card.usn = self.usn;
+        self.remap_deck_id(card);
+        card.shift_collection_relative_dates(self.collection_delta);
+        card.maybe_remove_from_filtered_deck(self.scheduler_version);
+        let old_id = self.uniquify_card_id(card);
+
+        self.target_col.add_card_if_unique_undoable(card)?;
+        self.existing_card_ids.insert(card.id);
+        self.imported_cards.insert(old_id, card.id);
+
+        Ok(())
+    }
+
+    fn uniquify_card_id(&mut self, card: &mut Card) -> CardId {
+        let original = card.id;
+        while self.existing_card_ids.contains(&card.id) {
+            card.id.0 += 999;
+        }
+        original
+    }
+
+    fn remap_deck_id(&self, card: &mut Card) {
+        if let Some(did) = self.remapped_decks.get(&card.deck_id) {
+            card.deck_id = *did;
+        }
+    }
+}
+
+impl Card {
+    /// `delta` is the number of days the card's source collection is ahead of the
+    /// target collection.
+    fn shift_collection_relative_dates(&mut self, delta: i32) {
+        if self.due_in_days_since_collection_creation() {
+            self.due -= delta;
+        }
+        if self.original_due_in_days_since_collection_creation() && self.original_due != 0 {
+            self.original_due -= delta;
+        }
+    }
+
+    fn due_in_days_since_collection_creation(&self) -> bool {
+        matches!(self.queue, CardQueue::Review | CardQueue::DayLearn)
+            || self.ctype == CardType::Review
+    }
+
+    fn original_due_in_days_since_collection_creation(&self) -> bool {
+        self.ctype == CardType::Review
+    }
+
+    fn maybe_remove_from_filtered_deck(&mut self, version: SchedulerVersion) {
+        if self.is_filtered() {
+            // instead of moving between decks, the deck is converted to a regular one
+            self.original_deck_id = self.deck_id;
+            self.remove_from_filtered_deck_restoring_queue(version);
+        }
+    }
+}
diff --git a/rslib/src/import_export/package/apkg/import/decks.rs b/rslib/src/import_export/package/apkg/import/decks.rs
new file mode 100644
index 000000000..c9717963f
--- /dev/null
+++ b/rslib/src/import_export/package/apkg/import/decks.rs
@@ -0,0 +1,212 @@
+// Copyright: Ankitects Pty Ltd and contributors
+// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
+
+use std::{collections::HashMap, mem};
+
+use super::Context;
+use crate::{decks::NormalDeck, prelude::*};
+
+struct DeckContext<'d> {
+    target_col: &'d mut Collection,
+    usn: Usn,
+    renamed_parents: Vec<(String, String)>,
+    imported_decks: HashMap<DeckId, DeckId>,
+    unique_suffix: String,
+}
+
+impl<'d> DeckContext<'d> {
+    fn new<'a: 'd>(target_col: &'a mut Collection, usn: Usn) -> Self {
+        Self {
+            target_col,
+            usn,
+            renamed_parents: Vec::new(),
+            imported_decks: HashMap::new(),
+            unique_suffix: TimestampSecs::now().to_string(),
+        }
+    }
+}
+
+impl Context<'_> {
+    pub(super) fn import_decks_and_configs(&mut self) -> Result<HashMap<DeckId, DeckId>> {
+        let mut ctx = DeckContext::new(self.target_col, self.usn);
+        ctx.import_deck_configs(mem::take(&mut self.data.deck_configs))?;
+        ctx.import_decks(mem::take(&mut self.data.decks))?;
+        Ok(ctx.imported_decks)
+    }
+}
+
+impl DeckContext<'_> {
+    fn import_deck_configs(&mut self, mut configs: Vec<DeckConfig>) -> Result<()> {
+        for config in &mut configs {
+            config.usn = self.usn;
+            self.target_col.add_deck_config_if_unique_undoable(config)?;
+        }
+        Ok(())
+    }
+
+    fn import_decks(&mut self, mut decks: Vec<Deck>) -> Result<()> {
+        // ensure parents are seen before children
+        decks.sort_unstable_by_key(|deck| deck.level());
+        for deck in &mut decks {
+            self.prepare_deck(deck);
+            self.import_deck(deck)?;
+        }
+        Ok(())
+    }
+
+    fn prepare_deck(&mut self, deck: &mut Deck) {
+        self.maybe_reparent(deck);
+        if deck.is_filtered() {
+            deck.kind = DeckKind::Normal(NormalDeck {
+                config_id: 1,
+                ..Default::default()
+            });
+        }
+    }
+
+    fn import_deck(&mut self, deck: &mut Deck) -> Result<()> {
+        if let Some(original) = self.get_deck_by_name(deck)?
{ + if original.is_filtered() { + self.uniquify_name(deck); + self.add_deck(deck) + } else { + self.update_deck(deck, original) + } + } else { + self.ensure_valid_first_existing_parent(deck)?; + self.add_deck(deck) + } + } + + fn maybe_reparent(&self, deck: &mut Deck) { + if let Some(new_name) = self.reparented_name(deck.name.as_native_str()) { + deck.name = NativeDeckName::from_native_str(new_name); + } + } + + fn reparented_name(&self, name: &str) -> Option { + self.renamed_parents + .iter() + .find_map(|(old_parent, new_parent)| { + name.starts_with(old_parent) + .then(|| name.replacen(old_parent, new_parent, 1)) + }) + } + + fn get_deck_by_name(&mut self, deck: &Deck) -> Result> { + self.target_col + .storage + .get_deck_by_name(deck.name.as_native_str()) + } + + fn uniquify_name(&mut self, deck: &mut Deck) { + let old_parent = format!("{}\x1f", deck.name.as_native_str()); + deck.uniquify_name(&self.unique_suffix); + let new_parent = format!("{}\x1f", deck.name.as_native_str()); + self.renamed_parents.push((old_parent, new_parent)); + } + + fn add_deck(&mut self, deck: &mut Deck) -> Result<()> { + let old_id = mem::take(&mut deck.id); + self.target_col.add_deck_inner(deck, self.usn)?; + self.imported_decks.insert(old_id, deck.id); + Ok(()) + } + + /// Caller must ensure decks are normal. + fn update_deck(&mut self, deck: &Deck, original: Deck) -> Result<()> { + let mut new_deck = original.clone(); + new_deck.normal_mut()?.update_with_other(deck.normal()?); + self.imported_decks.insert(deck.id, new_deck.id); + self.target_col + .update_deck_inner(&mut new_deck, original, self.usn) + } + + fn ensure_valid_first_existing_parent(&mut self, deck: &mut Deck) -> Result<()> { + if let Some(ancestor) = self + .target_col + .first_existing_parent(deck.name.as_native_str(), 0)? 
+ { + if ancestor.is_filtered() { + self.add_unique_default_deck(ancestor.name.as_native_str())?; + self.maybe_reparent(deck); + } + } + Ok(()) + } + + fn add_unique_default_deck(&mut self, name: &str) -> Result<()> { + let mut deck = Deck::new_normal(); + deck.name = NativeDeckName::from_native_str(name); + self.uniquify_name(&mut deck); + self.target_col.add_deck_inner(&mut deck, self.usn) + } +} + +impl Deck { + fn uniquify_name(&mut self, suffix: &str) { + let new_name = format!("{} {}", self.name.as_native_str(), suffix); + self.name = NativeDeckName::from_native_str(new_name); + } + + fn level(&self) -> usize { + self.name.components().count() + } +} + +impl NormalDeck { + fn update_with_other(&mut self, other: &Self) { + if !other.description.is_empty() { + self.markdown_description = other.markdown_description; + self.description = other.description.clone(); + } + if other.config_id != 1 { + self.config_id = other.config_id; + } + } +} + +#[cfg(test)] +mod test { + use std::collections::HashSet; + + use super::*; + use crate::{collection::open_test_collection, tests::new_deck_with_machine_name}; + + #[test] + fn parents() { + let mut col = open_test_collection(); + + col.add_deck_with_machine_name("filtered", true); + col.add_deck_with_machine_name("PARENT", false); + + let mut ctx = DeckContext::new(&mut col, Usn(1)); + ctx.unique_suffix = "★".to_string(); + + let imports = vec![ + new_deck_with_machine_name("unknown parent\x1fchild", false), + new_deck_with_machine_name("filtered\x1fchild", false), + new_deck_with_machine_name("parent\x1fchild", false), + new_deck_with_machine_name("NEW PARENT\x1fchild", false), + new_deck_with_machine_name("new parent", false), + ]; + ctx.import_decks(imports).unwrap(); + let existing_decks: HashSet<_> = ctx + .target_col + .get_all_deck_names(true) + .unwrap() + .into_iter() + .map(|(_, name)| name) + .collect(); + + // missing parents get created + assert!(existing_decks.contains("unknown parent")); + // ... and uniquified if their existing counterparts are filtered + assert!(existing_decks.contains("filtered ★")); + assert!(existing_decks.contains("filtered ★::child")); + // the case of existing parents is matched + assert!(existing_decks.contains("PARENT::child")); + // the case of imported parents is matched, regardless of pass order + assert!(existing_decks.contains("new parent::child")); + } +} diff --git a/rslib/src/import_export/package/apkg/import/media.rs b/rslib/src/import_export/package/apkg/import/media.rs new file mode 100644 index 000000000..bffd28f2f --- /dev/null +++ b/rslib/src/import_export/package/apkg/import/media.rs @@ -0,0 +1,134 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::{collections::HashMap, fs::File, mem}; + +use zip::ZipArchive; + +use super::Context; +use crate::{ + import_export::{ + package::{ + media::{extract_media_entries, SafeMediaEntry}, + Meta, + }, + ImportProgress, IncrementableProgress, + }, + media::{ + files::{add_hash_suffix_to_file_stem, sha1_of_reader}, + MediaManager, + }, + prelude::*, +}; + +/// Map of source media files, that do not already exist in the target. +#[derive(Default)] +pub(super) struct MediaUseMap { + /// original, normalized filename → (refererenced on import material, + /// entry with possibly remapped filename) + checked: HashMap, + /// Static files (latex, underscored). 
Usage is not tracked, and if the name
+    /// already exists in the target, it is skipped regardless of content equality.
+    unchecked: Vec<SafeMediaEntry>,
+}
+
+impl Context<'_> {
+    pub(super) fn prepare_media(&mut self) -> Result<MediaUseMap> {
+        let db_progress_fn = self.progress.media_db_fn(ImportProgress::MediaCheck)?;
+        let existing_sha1s = self.target_col.all_existing_sha1s(db_progress_fn)?;
+        prepare_media(
+            &self.meta,
+            &mut self.archive,
+            &existing_sha1s,
+            &mut self.progress,
+        )
+    }
+
+    pub(super) fn copy_media(&mut self, media_map: &mut MediaUseMap) -> Result<()> {
+        let mut incrementor = self.progress.incrementor(ImportProgress::Media);
+        for entry in media_map.used_entries() {
+            incrementor.increment()?;
+            entry.copy_from_archive(&mut self.archive, &self.target_col.media_folder)?;
+        }
+        Ok(())
+    }
+}
+
+impl Collection {
+    fn all_existing_sha1s(
+        &mut self,
+        progress_fn: impl FnMut(usize) -> bool,
+    ) -> Result<HashMap<String, Sha1Hash>> {
+        let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
+        mgr.all_checksums(progress_fn, &self.log)
+    }
+}
+
+fn prepare_media(
+    meta: &Meta,
+    archive: &mut ZipArchive<File>,
+    existing_sha1s: &HashMap<String, Sha1Hash>,
+    progress: &mut IncrementableProgress<ImportProgress>,
+) -> Result<MediaUseMap> {
+    let mut media_map = MediaUseMap::default();
+    let mut incrementor = progress.incrementor(ImportProgress::MediaCheck);
+
+    for mut entry in extract_media_entries(meta, archive)? {
+        incrementor.increment()?;
+
+        if entry.is_static() {
+            if !existing_sha1s.contains_key(&entry.name) {
+                media_map.unchecked.push(entry);
+            }
+        } else if let Some(other_sha1) = existing_sha1s.get(&entry.name) {
+            entry.with_hash_from_archive(archive)?;
+            if entry.sha1 != *other_sha1 {
+                let original_name = entry.uniquify_name();
+                media_map.add_checked(original_name, entry);
+            }
+        } else {
+            media_map.add_checked(entry.name.clone(), entry);
+        }
+    }
+    Ok(media_map)
+}
+
+impl MediaUseMap {
+    pub(super) fn add_checked(&mut self, filename: impl Into<String>, entry: SafeMediaEntry) {
+        self.checked.insert(filename.into(), (false, entry));
+    }
+
+    pub(super) fn use_entry(&mut self, filename: &str) -> Option<&SafeMediaEntry> {
+        self.checked.get_mut(filename).map(|(used, entry)| {
+            *used = true;
+            &*entry
+        })
+    }
+
+    pub(super) fn used_entries(&self) -> impl Iterator<Item = &SafeMediaEntry> {
+        self.checked
+            .values()
+            .filter_map(|(used, entry)| used.then(|| entry))
+            .chain(self.unchecked.iter())
+    }
+}
+
+impl SafeMediaEntry {
+    fn with_hash_from_archive(&mut self, archive: &mut ZipArchive<File>) -> Result<()> {
+        if self.sha1 == [0; 20] {
+            let mut reader = self.fetch_file(archive)?;
+            self.sha1 = sha1_of_reader(&mut reader)?;
+        }
+        Ok(())
+    }
+
+    /// Requires sha1 to be set. Returns old file name.
+ fn uniquify_name(&mut self) -> String { + let new_name = add_hash_suffix_to_file_stem(&self.name, &self.sha1); + mem::replace(&mut self.name, new_name) + } + + fn is_static(&self) -> bool { + self.name.starts_with('_') || self.name.starts_with("latex-") + } +} diff --git a/rslib/src/import_export/package/apkg/import/mod.rs b/rslib/src/import_export/package/apkg/import/mod.rs new file mode 100644 index 000000000..810325b7e --- /dev/null +++ b/rslib/src/import_export/package/apkg/import/mod.rs @@ -0,0 +1,137 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +mod cards; +mod decks; +mod media; +mod notes; + +use std::{fs::File, io, path::Path}; + +pub(crate) use notes::NoteMeta; +use rusqlite::OptionalExtension; +use tempfile::NamedTempFile; +use zip::ZipArchive; +use zstd::stream::copy_decode; + +use crate::{ + collection::CollectionBuilder, + import_export::{ + gather::ExchangeData, + package::{Meta, NoteLog}, + ImportProgress, IncrementableProgress, + }, + prelude::*, + search::SearchNode, +}; + +struct Context<'a> { + target_col: &'a mut Collection, + archive: ZipArchive, + meta: Meta, + data: ExchangeData, + usn: Usn, + progress: IncrementableProgress, +} + +impl Collection { + pub fn import_apkg( + &mut self, + path: impl AsRef, + progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool, + ) -> Result> { + let file = File::open(path)?; + let archive = ZipArchive::new(file)?; + + self.transact(Op::Import, |col| { + let mut ctx = Context::new(archive, col, progress_fn)?; + ctx.import() + }) + } +} + +impl<'a> Context<'a> { + fn new( + mut archive: ZipArchive, + target_col: &'a mut Collection, + progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool, + ) -> Result { + let progress = IncrementableProgress::new(progress_fn); + let meta = Meta::from_archive(&mut archive)?; + let data = ExchangeData::gather_from_archive( + &mut archive, + &meta, + SearchNode::WholeCollection, + true, + )?; + let usn = target_col.usn()?; + Ok(Self { + target_col, + archive, + meta, + data, + usn, + progress, + }) + } + + fn import(&mut self) -> Result { + let mut media_map = self.prepare_media()?; + self.progress.call(ImportProgress::File)?; + let note_imports = self.import_notes_and_notetypes(&mut media_map)?; + let imported_decks = self.import_decks_and_configs()?; + self.import_cards_and_revlog(¬e_imports.id_map, &imported_decks)?; + self.copy_media(&mut media_map)?; + Ok(note_imports.log) + } +} + +impl ExchangeData { + fn gather_from_archive( + archive: &mut ZipArchive, + meta: &Meta, + search: impl TryIntoSearch, + with_scheduling: bool, + ) -> Result { + let tempfile = collection_to_tempfile(meta, archive)?; + let mut col = CollectionBuilder::new(tempfile.path()).build()?; + col.maybe_upgrade_scheduler()?; + + let mut data = ExchangeData::default(); + data.gather_data(&mut col, search, with_scheduling)?; + + Ok(data) + } +} + +fn collection_to_tempfile(meta: &Meta, archive: &mut ZipArchive) -> Result { + let mut zip_file = archive.by_name(meta.collection_filename())?; + let mut tempfile = NamedTempFile::new()?; + if meta.zstd_compressed() { + copy_decode(zip_file, &mut tempfile) + } else { + io::copy(&mut zip_file, &mut tempfile).map(|_| ()) + } + .map_err(|err| AnkiError::file_io_error(err, tempfile.path()))?; + + Ok(tempfile) +} + +impl Collection { + fn maybe_upgrade_scheduler(&mut self) -> Result<()> { + if self.scheduling_included()? 
{ + self.upgrade_to_v2_scheduler()?; + } + Ok(()) + } + + fn scheduling_included(&mut self) -> Result { + const SQL: &str = "SELECT 1 FROM cards WHERE queue != 0"; + Ok(self + .storage + .db + .query_row(SQL, [], |_| Ok(())) + .optional()? + .is_some()) + } +} diff --git a/rslib/src/import_export/package/apkg/import/notes.rs b/rslib/src/import_export/package/apkg/import/notes.rs new file mode 100644 index 000000000..4fb8431b9 --- /dev/null +++ b/rslib/src/import_export/package/apkg/import/notes.rs @@ -0,0 +1,459 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::{ + borrow::Cow, + collections::{HashMap, HashSet}, + mem, + sync::Arc, +}; + +use sha1::Sha1; + +use super::{media::MediaUseMap, Context}; +use crate::{ + import_export::{ + package::{media::safe_normalized_file_name, LogNote, NoteLog}, + ImportProgress, IncrementableProgress, + }, + prelude::*, + text::{ + newlines_to_spaces, replace_media_refs, strip_html_preserving_media_filenames, + truncate_to_char_boundary, CowMapping, + }, +}; + +struct NoteContext<'a> { + target_col: &'a mut Collection, + usn: Usn, + normalize_notes: bool, + remapped_notetypes: HashMap, + target_guids: HashMap, + target_ids: HashSet, + media_map: &'a mut MediaUseMap, + imports: NoteImports, +} + +#[derive(Debug, Default)] +pub(super) struct NoteImports { + pub(super) id_map: HashMap, + /// All notes from the source collection as [Vec]s of their fields, and grouped + /// by import result kind. + pub(super) log: NoteLog, +} + +impl NoteImports { + fn log_new(&mut self, note: Note, source_id: NoteId) { + self.id_map.insert(source_id, note.id); + self.log.new.push(note.into_log_note()); + } + + fn log_updated(&mut self, note: Note, source_id: NoteId) { + self.id_map.insert(source_id, note.id); + self.log.updated.push(note.into_log_note()); + } + + fn log_duplicate(&mut self, mut note: Note, target_id: NoteId) { + self.id_map.insert(note.id, target_id); + // id is for looking up note in *target* collection + note.id = target_id; + self.log.duplicate.push(note.into_log_note()); + } + + fn log_conflicting(&mut self, note: Note) { + self.log.conflicting.push(note.into_log_note()); + } +} + +impl Note { + fn into_log_note(self) -> LogNote { + LogNote { + id: Some(self.id.into()), + fields: self + .into_fields() + .into_iter() + .map(|field| { + let mut reduced = strip_html_preserving_media_filenames(&field) + .map_cow(newlines_to_spaces) + .get_owned() + .unwrap_or(field); + truncate_to_char_boundary(&mut reduced, 80); + reduced + }) + .collect(), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub(crate) struct NoteMeta { + id: NoteId, + mtime: TimestampSecs, + notetype_id: NotetypeId, +} + +impl NoteMeta { + pub(crate) fn new(id: NoteId, mtime: TimestampSecs, notetype_id: NotetypeId) -> Self { + Self { + id, + mtime, + notetype_id, + } + } +} + +impl Context<'_> { + pub(super) fn import_notes_and_notetypes( + &mut self, + media_map: &mut MediaUseMap, + ) -> Result { + let mut ctx = NoteContext::new(self.usn, self.target_col, media_map)?; + ctx.import_notetypes(mem::take(&mut self.data.notetypes))?; + ctx.import_notes(mem::take(&mut self.data.notes), &mut self.progress)?; + Ok(ctx.imports) + } +} + +impl<'n> NoteContext<'n> { + fn new<'a: 'n>( + usn: Usn, + target_col: &'a mut Collection, + media_map: &'a mut MediaUseMap, + ) -> Result { + let target_guids = target_col.storage.note_guid_map()?; + let normalize_notes = target_col.get_config_bool(BoolKey::NormalizeNoteText); 
+ let target_ids = target_col.storage.get_all_note_ids()?; + Ok(Self { + target_col, + usn, + normalize_notes, + remapped_notetypes: HashMap::new(), + target_guids, + target_ids, + imports: NoteImports::default(), + media_map, + }) + } + + fn import_notetypes(&mut self, mut notetypes: Vec) -> Result<()> { + for notetype in &mut notetypes { + if let Some(existing) = self.target_col.storage.get_notetype(notetype.id)? { + self.merge_or_remap_notetype(notetype, existing)?; + } else { + self.add_notetype(notetype)?; + } + } + Ok(()) + } + + fn merge_or_remap_notetype( + &mut self, + incoming: &mut Notetype, + existing: Notetype, + ) -> Result<()> { + if incoming.schema_hash() == existing.schema_hash() { + if incoming.mtime_secs > existing.mtime_secs { + self.update_notetype(incoming, existing)?; + } + } else { + self.add_notetype_with_remapped_id(incoming)?; + } + Ok(()) + } + + fn add_notetype(&mut self, notetype: &mut Notetype) -> Result<()> { + notetype.prepare_for_update(None, true)?; + self.target_col + .ensure_notetype_name_unique(notetype, self.usn)?; + notetype.usn = self.usn; + self.target_col + .add_notetype_with_unique_id_undoable(notetype) + } + + fn update_notetype(&mut self, notetype: &mut Notetype, original: Notetype) -> Result<()> { + notetype.usn = self.usn; + self.target_col + .add_or_update_notetype_with_existing_id_inner(notetype, Some(original), self.usn, true) + } + + fn add_notetype_with_remapped_id(&mut self, notetype: &mut Notetype) -> Result<()> { + let old_id = std::mem::take(&mut notetype.id); + notetype.usn = self.usn; + self.target_col + .add_notetype_inner(notetype, self.usn, true)?; + self.remapped_notetypes.insert(old_id, notetype.id); + Ok(()) + } + + fn import_notes( + &mut self, + notes: Vec, + progress: &mut IncrementableProgress, + ) -> Result<()> { + let mut incrementor = progress.incrementor(ImportProgress::Notes); + + for mut note in notes { + incrementor.increment()?; + if let Some(notetype_id) = self.remapped_notetypes.get(¬e.notetype_id) { + if self.target_guids.contains_key(¬e.guid) { + self.imports.log_conflicting(note); + } else { + note.notetype_id = *notetype_id; + self.add_note(note)?; + } + } else if let Some(&meta) = self.target_guids.get(¬e.guid) { + self.maybe_update_note(note, meta)?; + } else { + self.add_note(note)?; + } + } + Ok(()) + } + + fn add_note(&mut self, mut note: Note) -> Result<()> { + self.munge_media(&mut note)?; + self.target_col.canonify_note_tags(&mut note, self.usn)?; + let notetype = self.get_expected_notetype(note.notetype_id)?; + note.prepare_for_update(¬etype, self.normalize_notes)?; + note.usn = self.usn; + let old_id = self.uniquify_note_id(&mut note); + + self.target_col.add_note_only_with_id_undoable(&mut note)?; + self.target_ids.insert(note.id); + self.imports.log_new(note, old_id); + + Ok(()) + } + + fn uniquify_note_id(&mut self, note: &mut Note) -> NoteId { + let original = note.id; + while self.target_ids.contains(¬e.id) { + note.id.0 += 999; + } + original + } + + fn get_expected_notetype(&mut self, ntid: NotetypeId) -> Result> { + self.target_col + .get_notetype(ntid)? + .ok_or(AnkiError::NotFound) + } + + fn get_expected_note(&mut self, nid: NoteId) -> Result { + self.target_col + .storage + .get_note(nid)? 
+ .ok_or(AnkiError::NotFound) + } + + fn maybe_update_note(&mut self, note: Note, meta: NoteMeta) -> Result<()> { + if meta.mtime < note.mtime { + if meta.notetype_id == note.notetype_id { + self.update_note(note, meta.id)?; + } else { + self.imports.log_conflicting(note); + } + } else { + self.imports.log_duplicate(note, meta.id); + } + Ok(()) + } + + fn update_note(&mut self, mut note: Note, target_id: NoteId) -> Result<()> { + let source_id = note.id; + note.id = target_id; + self.munge_media(&mut note)?; + let original = self.get_expected_note(note.id)?; + let notetype = self.get_expected_notetype(note.notetype_id)?; + self.target_col.update_note_inner_without_cards( + &mut note, + &original, + ¬etype, + self.usn, + true, + self.normalize_notes, + true, + )?; + self.imports.log_updated(note, source_id); + Ok(()) + } + + fn munge_media(&mut self, note: &mut Note) -> Result<()> { + for field in note.fields_mut() { + if let Some(new_field) = self.replace_media_refs(field) { + *field = new_field; + }; + } + Ok(()) + } + + fn replace_media_refs(&mut self, field: &mut String) -> Option { + replace_media_refs(field, |name| { + if let Ok(normalized) = safe_normalized_file_name(name) { + if let Some(entry) = self.media_map.use_entry(&normalized) { + if entry.name != name { + // name is not normalized, and/or remapped + return Some(entry.name.clone()); + } + } else if let Cow::Owned(s) = normalized { + // no entry; might be a reference to an existing file, so ensure normalization + return Some(s); + } + } + None + }) + } +} + +impl Notetype { + fn schema_hash(&self) -> Sha1Hash { + let mut hasher = Sha1::new(); + for field in &self.fields { + hasher.update(field.name.as_bytes()); + } + for template in &self.templates { + hasher.update(template.name.as_bytes()); + } + hasher.digest().bytes() + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{collection::open_test_collection, import_export::package::media::SafeMediaEntry}; + + /// Import [Note] into [Collection], optionally taking a [MediaUseMap], + /// or a [Notetype] remapping. + macro_rules! import_note { + ($col:expr, $note:expr, $old_notetype:expr => $new_notetype:expr) => {{ + let mut media_map = MediaUseMap::default(); + let mut ctx = NoteContext::new(Usn(1), &mut $col, &mut media_map).unwrap(); + ctx.remapped_notetypes.insert($old_notetype, $new_notetype); + let mut progress = IncrementableProgress::new(|_, _| true); + ctx.import_notes(vec![$note], &mut progress).unwrap(); + ctx.imports.log + }}; + ($col:expr, $note:expr, $media_map:expr) => {{ + let mut ctx = NoteContext::new(Usn(1), &mut $col, &mut $media_map).unwrap(); + let mut progress = IncrementableProgress::new(|_, _| true); + ctx.import_notes(vec![$note], &mut progress).unwrap(); + ctx.imports.log + }}; + ($col:expr, $note:expr) => {{ + let mut media_map = MediaUseMap::default(); + import_note!($col, $note, media_map) + }}; + } + + /// Assert that exactly one [Note] is logged, and that with the given state and fields. + macro_rules! 
assert_note_logged { + ($log:expr, $state:ident, $fields:expr) => { + assert_eq!($log.$state.pop().unwrap().fields, $fields); + assert!($log.new.is_empty()); + assert!($log.updated.is_empty()); + assert!($log.duplicate.is_empty()); + assert!($log.conflicting.is_empty()); + }; + } + + impl Collection { + fn note_id_for_guid(&self, guid: &str) -> NoteId { + self.storage + .db + .query_row("SELECT id FROM notes WHERE guid = ?", [guid], |r| r.get(0)) + .unwrap() + } + } + + #[test] + fn should_add_note_with_new_id_if_guid_is_unique_and_id_is_not() { + let mut col = open_test_collection(); + let mut note = col.add_new_note("basic"); + note.guid = "other".to_string(); + let original_id = note.id; + + let mut log = import_note!(col, note); + assert_ne!(col.note_id_for_guid("other"), original_id); + assert_note_logged!(log, new, &["", ""]); + } + + #[test] + fn should_skip_note_if_guid_already_exists_with_newer_mtime() { + let mut col = open_test_collection(); + let mut note = col.add_new_note("basic"); + note.mtime.0 -= 1; + note.fields_mut()[0] = "outdated".to_string(); + + let mut log = import_note!(col, note); + assert_eq!(col.get_all_notes()[0].fields()[0], ""); + assert_note_logged!(log, duplicate, &["outdated", ""]); + } + + #[test] + fn should_update_note_if_guid_already_exists_with_different_id() { + let mut col = open_test_collection(); + let mut note = col.add_new_note("basic"); + note.id.0 = 42; + note.mtime.0 += 1; + note.fields_mut()[0] = "updated".to_string(); + + let mut log = import_note!(col, note); + assert_eq!(col.get_all_notes()[0].fields()[0], "updated"); + assert_note_logged!(log, updated, &["updated", ""]); + } + + #[test] + fn should_ignore_note_if_guid_already_exists_with_different_notetype() { + let mut col = open_test_collection(); + let mut note = col.add_new_note("basic"); + note.notetype_id.0 = 42; + note.mtime.0 += 1; + note.fields_mut()[0] = "updated".to_string(); + + let mut log = import_note!(col, note); + assert_eq!(col.get_all_notes()[0].fields()[0], ""); + assert_note_logged!(log, conflicting, &["updated", ""]); + } + + #[test] + fn should_add_note_with_remapped_notetype_if_in_notetype_map() { + let mut col = open_test_collection(); + let basic_ntid = col.get_notetype_by_name("basic").unwrap().unwrap().id; + let mut note = col.new_note("basic"); + note.notetype_id.0 = 123; + + let mut log = import_note!(col, note, NotetypeId(123) => basic_ntid); + assert_eq!(col.get_all_notes()[0].notetype_id, basic_ntid); + assert_note_logged!(log, new, &["", ""]); + } + + #[test] + fn should_ignore_note_if_guid_already_exists_and_notetype_is_remapped() { + let mut col = open_test_collection(); + let basic_ntid = col.get_notetype_by_name("basic").unwrap().unwrap().id; + let mut note = col.add_new_note("basic"); + note.notetype_id.0 = 123; + note.mtime.0 += 1; + note.fields_mut()[0] = "updated".to_string(); + + let mut log = import_note!(col, note, NotetypeId(123) => basic_ntid); + assert_eq!(col.get_all_notes()[0].fields()[0], ""); + assert_note_logged!(log, conflicting, &["updated", ""]); + } + + #[test] + fn should_add_note_with_remapped_media_reference_in_field_if_in_media_map() { + let mut col = open_test_collection(); + let mut note = col.new_note("basic"); + note.fields_mut()[0] = "".to_string(); + + let mut media_map = MediaUseMap::default(); + let entry = SafeMediaEntry::from_legacy(("0", "bar.jpg".to_string())).unwrap(); + media_map.add_checked("foo.jpg", entry); + + let mut log = import_note!(col, note, media_map); + assert_eq!(col.get_all_notes()[0].fields()[0], 
""); + assert_note_logged!(log, new, &[" bar.jpg ", ""]); + } +} diff --git a/rslib/src/import_export/package/apkg/mod.rs b/rslib/src/import_export/package/apkg/mod.rs new file mode 100644 index 000000000..0ac21fac7 --- /dev/null +++ b/rslib/src/import_export/package/apkg/mod.rs @@ -0,0 +1,8 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +mod export; +mod import; +mod tests; + +pub(crate) use import::NoteMeta; diff --git a/rslib/src/import_export/package/apkg/tests.rs b/rslib/src/import_export/package/apkg/tests.rs new file mode 100644 index 000000000..812f6a963 --- /dev/null +++ b/rslib/src/import_export/package/apkg/tests.rs @@ -0,0 +1,150 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +#![cfg(test)] + +use std::{collections::HashSet, fs::File, io::Write}; + +use crate::{ + media::files::sha1_of_data, prelude::*, search::SearchNode, tests::open_fs_test_collection, +}; + +const SAMPLE_JPG: &str = "sample.jpg"; +const SAMPLE_MP3: &str = "sample.mp3"; +const SAMPLE_JS: &str = "_sample.js"; +const JPG_DATA: &[u8] = b"1"; +const MP3_DATA: &[u8] = b"2"; +const JS_DATA: &[u8] = b"3"; +const OTHER_MP3_DATA: &[u8] = b"4"; + +#[test] +fn roundtrip() { + let (mut src_col, src_tempdir) = open_fs_test_collection("src"); + let (mut target_col, _target_tempdir) = open_fs_test_collection("target"); + let apkg_path = src_tempdir.path().join("test.apkg"); + + let (main_deck, sibling_deck) = src_col.add_sample_decks(); + let notetype = src_col.add_sample_notetype(); + let note = src_col.add_sample_note(&main_deck, &sibling_deck, ¬etype); + src_col.add_sample_media(); + target_col.add_conflicting_media(); + + src_col + .export_apkg( + &apkg_path, + SearchNode::from_deck_name("parent::sample"), + true, + true, + true, + None, + |_, _| true, + ) + .unwrap(); + target_col.import_apkg(&apkg_path, |_, _| true).unwrap(); + + target_col.assert_decks(); + target_col.assert_notetype(¬etype); + target_col.assert_note_and_media(¬e); + + target_col.undo().unwrap(); + target_col.assert_empty(); +} + +impl Collection { + fn add_sample_decks(&mut self) -> (Deck, Deck) { + let sample = self.add_named_deck("parent\x1fsample"); + self.add_named_deck("parent\x1fsample\x1fchild"); + let siblings = self.add_named_deck("siblings"); + + (sample, siblings) + } + + fn add_named_deck(&mut self, name: &str) -> Deck { + let mut deck = Deck::new_normal(); + deck.name = NativeDeckName::from_native_str(name); + self.add_deck(&mut deck).unwrap(); + deck + } + + fn add_sample_notetype(&mut self) -> Notetype { + let mut nt = Notetype { + name: "sample".into(), + ..Default::default() + }; + nt.add_field("sample"); + nt.add_template("sample1", "{{sample}}", ""); + nt.add_template("sample2", "{{sample}}2", ""); + self.add_notetype(&mut nt, true).unwrap(); + nt + } + + fn add_sample_note( + &mut self, + main_deck: &Deck, + sibling_decks: &Deck, + notetype: &Notetype, + ) -> Note { + let mut sample = notetype.new_note(); + sample.fields_mut()[0] = format!(" [sound:{SAMPLE_MP3}]"); + sample.tags = vec!["sample".into()]; + self.add_note(&mut sample, main_deck.id).unwrap(); + + let card = self + .storage + .get_card_by_ordinal(sample.id, 1) + .unwrap() + .unwrap(); + self.set_deck(&[card.id], sibling_decks.id).unwrap(); + + sample + } + + fn add_sample_media(&self) { + self.add_media(&[ + (SAMPLE_JPG, JPG_DATA), + (SAMPLE_MP3, MP3_DATA), + (SAMPLE_JS, JS_DATA), + ]); + } + 
+ fn add_conflicting_media(&mut self) { + let mut file = File::create(self.media_folder.join(SAMPLE_MP3)).unwrap(); + file.write_all(OTHER_MP3_DATA).unwrap(); + } + + fn assert_decks(&mut self) { + let existing_decks: HashSet<_> = self + .get_all_deck_names(true) + .unwrap() + .into_iter() + .map(|(_, name)| name) + .collect(); + for deck in ["parent", "parent::sample", "siblings"] { + assert!(existing_decks.contains(deck)); + } + assert!(!existing_decks.contains("parent::sample::child")); + } + + fn assert_notetype(&mut self, notetype: &Notetype) { + assert!(self.get_notetype(notetype.id).unwrap().is_some()); + } + + fn assert_note_and_media(&mut self, note: &Note) { + let sha1 = sha1_of_data(MP3_DATA); + let new_mp3_name = format!("sample-{}.mp3", hex::encode(&sha1)); + + for file in [SAMPLE_JPG, SAMPLE_JS, &new_mp3_name] { + assert!(self.media_folder.join(file).exists()) + } + + let imported_note = self.storage.get_note(note.id).unwrap().unwrap(); + assert!(imported_note.fields()[0].contains(&new_mp3_name)); + } + + fn assert_empty(&self) { + assert!(self.get_all_deck_names(true).unwrap().is_empty()); + assert!(self.storage.get_all_note_ids().unwrap().is_empty()); + assert!(self.storage.get_all_card_ids().unwrap().is_empty()); + assert!(self.storage.all_tags().unwrap().is_empty()); + } +} diff --git a/rslib/src/import_export/package/colpkg/export.rs b/rslib/src/import_export/package/colpkg/export.rs index 6dc5de69f..d4fc0352d 100644 --- a/rslib/src/import_export/package/colpkg/export.rs +++ b/rslib/src/import_export/package/colpkg/export.rs @@ -4,7 +4,8 @@ use std::{ borrow::Cow, collections::HashMap, - fs::{DirEntry, File}, + ffi::OsStr, + fs::File, io::{self, Read, Write}, path::{Path, PathBuf}, }; @@ -21,6 +22,7 @@ use zstd::{ use super::super::{MediaEntries, MediaEntry, Meta, Version}; use crate::{ collection::CollectionBuilder, + import_export::IncrementableProgress, io::{atomic_rename, read_dir_files, tempfile_in_parent_of}, media::files::filename_if_normalized, prelude::*, @@ -38,8 +40,9 @@ impl Collection { out_path: impl AsRef, include_media: bool, legacy: bool, - progress_fn: impl FnMut(usize), + progress_fn: impl 'static + FnMut(usize, bool) -> bool, ) -> Result<()> { + let mut progress = IncrementableProgress::new(progress_fn); let colpkg_name = out_path.as_ref(); let temp_colpkg = tempfile_in_parent_of(colpkg_name)?; let src_path = self.col_path.clone(); @@ -61,19 +64,48 @@ impl Collection { src_media_folder, legacy, &tr, - progress_fn, + &mut progress, )?; atomic_rename(temp_colpkg, colpkg_name, true) } } +pub struct MediaIter(Box>>); + +impl MediaIter { + /// Iterator over all files in the given path, without traversing subfolders. + pub fn from_folder(path: &Path) -> Result { + Ok(Self(Box::new( + read_dir_files(path)?.map(|res| res.map(|entry| entry.path())), + ))) + } + + /// Iterator over all given files in the given folder. + /// Missing files are silently ignored. 
+ pub fn from_file_list( + list: impl IntoIterator + 'static, + folder: PathBuf, + ) -> Self { + Self(Box::new( + list.into_iter() + .map(move |file| folder.join(file)) + .filter(|path| path.exists()) + .map(Ok), + )) + } + + pub fn empty() -> Self { + Self(Box::new(std::iter::empty())) + } +} + fn export_collection_file( out_path: impl AsRef, col_path: impl AsRef, media_dir: Option, legacy: bool, tr: &I18n, - progress_fn: impl FnMut(usize), + progress: &mut IncrementableProgress, ) -> Result<()> { let meta = if legacy { Meta::new_legacy() @@ -82,15 +114,13 @@ fn export_collection_file( }; let mut col_file = File::open(col_path)?; let col_size = col_file.metadata()?.len() as usize; - export_collection( - meta, - out_path, - &mut col_file, - col_size, - media_dir, - tr, - progress_fn, - ) + let media = if let Some(path) = media_dir { + MediaIter::from_folder(&path)? + } else { + MediaIter::empty() + }; + + export_collection(meta, out_path, &mut col_file, col_size, media, tr, progress) } /// Write copied collection data without any media. @@ -105,20 +135,20 @@ pub(crate) fn export_colpkg_from_data( out_path, &mut col_data, col_size, - None, + MediaIter::empty(), tr, - |_| (), + &mut IncrementableProgress::new(|_, _| true), ) } -fn export_collection( +pub(crate) fn export_collection( meta: Meta, out_path: impl AsRef, col: &mut impl Read, col_size: usize, - media_dir: Option, + media: MediaIter, tr: &I18n, - progress_fn: impl FnMut(usize), + progress: &mut IncrementableProgress, ) -> Result<()> { let out_file = File::create(&out_path)?; let mut zip = ZipWriter::new(out_file); @@ -129,7 +159,7 @@ fn export_collection( zip.write_all(&meta_bytes)?; write_collection(&meta, &mut zip, col, col_size)?; write_dummy_collection(&mut zip, tr)?; - write_media(&meta, &mut zip, media_dir, progress_fn)?; + write_media(&meta, &mut zip, media, progress)?; zip.finish()?; Ok(()) @@ -203,17 +233,12 @@ fn zstd_copy(reader: &mut impl Read, writer: &mut impl Write, size: usize) -> Re fn write_media( meta: &Meta, zip: &mut ZipWriter, - media_dir: Option, - progress_fn: impl FnMut(usize), + media: MediaIter, + progress: &mut IncrementableProgress, ) -> Result<()> { let mut media_entries = vec![]; - - if let Some(media_dir) = media_dir { - write_media_files(meta, zip, &media_dir, &mut media_entries, progress_fn)?; - } - + write_media_files(meta, zip, media, &mut media_entries, progress)?; write_media_map(meta, media_entries, zip)?; - Ok(()) } @@ -251,19 +276,23 @@ fn write_media_map( fn write_media_files( meta: &Meta, zip: &mut ZipWriter, - dir: &Path, + media: MediaIter, media_entries: &mut Vec, - mut progress_fn: impl FnMut(usize), + progress: &mut IncrementableProgress, ) -> Result<()> { let mut copier = MediaCopier::new(meta); - for (index, entry) in read_dir_files(dir)?.enumerate() { - progress_fn(index); + let mut incrementor = progress.incrementor(|u| u); + for (index, res) in media.0.enumerate() { + incrementor.increment()?; + let path = res?; zip.start_file(index.to_string(), file_options_stored())?; - let entry = entry?; - let name = normalized_unicode_file_name(&entry)?; - let mut file = File::open(entry.path())?; + let mut file = File::open(&path)?; + let file_name = path + .file_name() + .ok_or_else(|| AnkiError::invalid_input("not a file path"))?; + let name = normalized_unicode_file_name(file_name)?; let (size, sha1) = copier.copy(&mut file, zip)?; media_entries.push(MediaEntry::new(name, size, sha1)); @@ -272,23 +301,11 @@ fn write_media_files( Ok(()) } -impl MediaEntry { - fn new(name: impl Into, size: 
impl TryInto, sha1: impl Into>) -> Self { - MediaEntry { - name: name.into(), - size: size.try_into().unwrap_or_default(), - sha1: sha1.into(), - legacy_zip_filename: None, - } - } -} - -fn normalized_unicode_file_name(entry: &DirEntry) -> Result { - let filename = entry.file_name(); +fn normalized_unicode_file_name(filename: &OsStr) -> Result { let filename = filename.to_str().ok_or_else(|| { AnkiError::IoError(format!( "non-unicode file name: {}", - entry.file_name().to_string_lossy() + filename.to_string_lossy() )) })?; filename_if_normalized(filename) @@ -324,7 +341,7 @@ impl MediaCopier { &mut self, reader: &mut impl Read, writer: &mut impl Write, - ) -> Result<(usize, [u8; 20])> { + ) -> Result<(usize, Sha1Hash)> { let mut size = 0; let mut hasher = Sha1::new(); let mut buf = [0; 64 * 1024]; diff --git a/rslib/src/import_export/package/colpkg/import.rs b/rslib/src/import_export/package/colpkg/import.rs index 3cdffcba7..0cc38f791 100644 --- a/rslib/src/import_export/package/colpkg/import.rs +++ b/rslib/src/import_export/package/colpkg/import.rs @@ -2,64 +2,39 @@ // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html use std::{ - borrow::Cow, - collections::HashMap, - fs::{self, File}, - io::{self, Read, Write}, - path::{Component, Path, PathBuf}, + fs::File, + io::{self, Write}, + path::{Path, PathBuf}, }; -use prost::Message; use zip::{read::ZipFile, ZipArchive}; use zstd::{self, stream::copy_decode}; -use super::super::Version; use crate::{ collection::CollectionBuilder, error::ImportError, import_export::{ - package::{MediaEntries, MediaEntry, Meta}, - ImportProgress, + package::{ + media::{extract_media_entries, SafeMediaEntry}, + Meta, + }, + ImportProgress, IncrementableProgress, }, io::{atomic_rename, tempfile_in_parent_of}, - media::files::normalize_filename, + media::MediaManager, prelude::*, }; -impl Meta { - /// Extracts meta data from an archive and checks if its version is supported. 
- pub(super) fn from_archive(archive: &mut ZipArchive) -> Result { - let meta_bytes = archive.by_name("meta").ok().and_then(|mut meta_file| { - let mut buf = vec![]; - meta_file.read_to_end(&mut buf).ok()?; - Some(buf) - }); - let meta = if let Some(bytes) = meta_bytes { - let meta: Meta = Message::decode(&*bytes)?; - if meta.version() == Version::Unknown { - return Err(AnkiError::ImportError(ImportError::TooNew)); - } - meta - } else { - Meta { - version: if archive.by_name("collection.anki21").is_ok() { - Version::Legacy2 - } else { - Version::Legacy1 - } as i32, - } - }; - Ok(meta) - } -} - pub fn import_colpkg( colpkg_path: &str, target_col_path: &str, - target_media_folder: &str, - mut progress_fn: impl FnMut(ImportProgress) -> Result<()>, + target_media_folder: &Path, + media_db: &Path, + progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool, + log: &Logger, ) -> Result<()> { - progress_fn(ImportProgress::Collection)?; + let mut progress = IncrementableProgress::new(progress_fn); + progress.call(ImportProgress::File)?; let col_path = PathBuf::from(target_col_path); let mut tempfile = tempfile_in_parent_of(&col_path)?; @@ -68,12 +43,18 @@ pub fn import_colpkg( let meta = Meta::from_archive(&mut archive)?; copy_collection(&mut archive, &mut tempfile, &meta)?; - progress_fn(ImportProgress::Collection)?; + progress.call(ImportProgress::File)?; check_collection_and_mod_schema(tempfile.path())?; - progress_fn(ImportProgress::Collection)?; + progress.call(ImportProgress::File)?; - let media_folder = Path::new(target_media_folder); - restore_media(&meta, progress_fn, &mut archive, media_folder)?; + restore_media( + &meta, + &mut progress, + &mut archive, + target_media_folder, + media_db, + log, + )?; atomic_rename(tempfile, &col_path, true) } @@ -96,31 +77,25 @@ fn check_collection_and_mod_schema(col_path: &Path) -> Result<()> { fn restore_media( meta: &Meta, - mut progress_fn: impl FnMut(ImportProgress) -> Result<()>, + progress: &mut IncrementableProgress, archive: &mut ZipArchive, media_folder: &Path, + media_db: &Path, + log: &Logger, ) -> Result<()> { let media_entries = extract_media_entries(meta, archive)?; + if media_entries.is_empty() { + return Ok(()); + } + std::fs::create_dir_all(media_folder)?; + let media_manager = MediaManager::new(media_folder, media_db)?; + let mut media_comparer = MediaComparer::new(meta, progress, &media_manager, log)?; - for (entry_idx, entry) in media_entries.iter().enumerate() { - if entry_idx % 10 == 0 { - progress_fn(ImportProgress::Media(entry_idx))?; - } - - let zip_filename = entry - .legacy_zip_filename - .map(|n| n as usize) - .unwrap_or(entry_idx) - .to_string(); - - if let Ok(mut zip_file) = archive.by_name(&zip_filename) { - maybe_restore_media_file(meta, media_folder, entry, &mut zip_file)?; - } else { - return Err(AnkiError::invalid_input(&format!( - "{zip_filename} missing from archive" - ))); - } + let mut incrementor = progress.incrementor(ImportProgress::Media); + for mut entry in media_entries { + incrementor.increment()?; + maybe_restore_media_file(meta, media_folder, archive, &mut entry, &mut media_comparer)?; } Ok(()) @@ -129,13 +104,19 @@ fn restore_media( fn maybe_restore_media_file( meta: &Meta, media_folder: &Path, - entry: &MediaEntry, - zip_file: &mut ZipFile, + archive: &mut ZipArchive, + entry: &mut SafeMediaEntry, + media_comparer: &mut MediaComparer, ) -> Result<()> { - let file_path = entry.safe_normalized_file_path(meta, media_folder)?; - let already_exists = entry.is_equal_to(meta, zip_file, &file_path); + let 
file_path = entry.file_path(media_folder); + let mut zip_file = entry.fetch_file(archive)?; + if meta.media_list_is_hashmap() { + entry.size = zip_file.size() as u32; + } + + let already_exists = media_comparer.entry_is_equal_to(entry, &file_path)?; if !already_exists { - restore_media_file(meta, zip_file, &file_path)?; + restore_media_file(meta, &mut zip_file, &file_path)?; }; Ok(()) @@ -154,79 +135,6 @@ fn restore_media_file(meta: &Meta, zip_file: &mut ZipFile, path: &Path) -> Resul atomic_rename(tempfile, path, false) } -impl MediaEntry { - fn safe_normalized_file_path(&self, meta: &Meta, media_folder: &Path) -> Result { - check_filename_safe(&self.name)?; - let normalized = maybe_normalizing(&self.name, meta.strict_media_checks())?; - Ok(media_folder.join(normalized.as_ref())) - } - - fn is_equal_to(&self, meta: &Meta, self_zipped: &ZipFile, other_path: &Path) -> bool { - // TODO: checks hashs (https://github.com/ankitects/anki/pull/1723#discussion_r829653147) - let self_size = if meta.media_list_is_hashmap() { - self_zipped.size() - } else { - self.size as u64 - }; - fs::metadata(other_path) - .map(|metadata| metadata.len() as u64 == self_size) - .unwrap_or_default() - } -} - -/// - If strict is true, return an error if not normalized. -/// - If false, return the normalized version. -fn maybe_normalizing(name: &str, strict: bool) -> Result> { - let normalized = normalize_filename(name); - if strict && matches!(normalized, Cow::Owned(_)) { - // exporting code should have checked this - Err(AnkiError::ImportError(ImportError::Corrupt)) - } else { - Ok(normalized) - } -} - -/// Return an error if name contains any path separators. -fn check_filename_safe(name: &str) -> Result<()> { - let mut components = Path::new(name).components(); - let first_element_normal = components - .next() - .map(|component| matches!(component, Component::Normal(_))) - .unwrap_or_default(); - if !first_element_normal || components.next().is_some() { - Err(AnkiError::ImportError(ImportError::Corrupt)) - } else { - Ok(()) - } -} - -fn extract_media_entries(meta: &Meta, archive: &mut ZipArchive) -> Result> { - let mut file = archive.by_name("media")?; - let mut buf = Vec::new(); - if meta.zstd_compressed() { - copy_decode(file, &mut buf)?; - } else { - io::copy(&mut file, &mut buf)?; - } - if meta.media_list_is_hashmap() { - let map: HashMap<&str, String> = serde_json::from_slice(&buf)?; - map.into_iter() - .map(|(idx_str, name)| { - let idx: u32 = idx_str.parse()?; - Ok(MediaEntry { - name, - size: 0, - sha1: vec![], - legacy_zip_filename: Some(idx), - }) - }) - .collect() - } else { - let entries: MediaEntries = Message::decode(&*buf)?; - Ok(entries.entries) - } -} - fn copy_collection( archive: &mut ZipArchive, writer: &mut impl Write, @@ -244,29 +152,31 @@ fn copy_collection( Ok(()) } -#[cfg(test)] -mod test { - use super::*; +type GetChecksumFn<'a> = dyn FnMut(&str) -> Result> + 'a; - #[test] - fn path_traversal() { - assert!(check_filename_safe("foo").is_ok(),); +struct MediaComparer<'a>(Option>>); - assert!(check_filename_safe("..").is_err()); - assert!(check_filename_safe("foo/bar").is_err()); - assert!(check_filename_safe("/foo").is_err()); - assert!(check_filename_safe("../foo").is_err()); +impl<'a> MediaComparer<'a> { + fn new( + meta: &Meta, + progress: &mut IncrementableProgress, + media_manager: &'a MediaManager, + log: &Logger, + ) -> Result { + Ok(Self(if meta.media_list_is_hashmap() { + None + } else { + let mut db_progress_fn = progress.media_db_fn(ImportProgress::MediaCheck)?; + 
media_manager.register_changes(&mut db_progress_fn, log)?; + Some(Box::new(media_manager.checksum_getter())) + })) + } - if cfg!(windows) { - assert!(check_filename_safe("foo\\bar").is_err()); - assert!(check_filename_safe("c:\\foo").is_err()); - assert!(check_filename_safe("\\foo").is_err()); + fn entry_is_equal_to(&mut self, entry: &SafeMediaEntry, other_path: &Path) -> Result { + if let Some(ref mut get_checksum) = self.0 { + Ok(entry.has_checksum_equal_to(get_checksum)?) + } else { + Ok(entry.has_size_equal_to(other_path)) } } - - #[test] - fn normalization() { - assert_eq!(&maybe_normalizing("con", false).unwrap(), "con_"); - assert!(&maybe_normalizing("con", true).is_err()); - } } diff --git a/rslib/src/import_export/package/colpkg/tests.rs b/rslib/src/import_export/package/colpkg/tests.rs index 08c84012f..b07ef084c 100644 --- a/rslib/src/import_export/package/colpkg/tests.rs +++ b/rslib/src/import_export/package/colpkg/tests.rs @@ -8,8 +8,8 @@ use std::path::Path; use tempfile::tempdir; use crate::{ - collection::CollectionBuilder, import_export::package::import_colpkg, media::MediaManager, - prelude::*, + collection::CollectionBuilder, import_export::package::import_colpkg, log::terminal, + media::MediaManager, prelude::*, }; fn collection_with_media(dir: &Path, name: &str) -> Result { @@ -41,19 +41,26 @@ fn roundtrip() -> Result<()> { // export to a file let col = collection_with_media(dir, name)?; let colpkg_name = dir.join(format!("{name}.colpkg")); - col.export_colpkg(&colpkg_name, true, legacy, |_| ())?; + col.export_colpkg(&colpkg_name, true, legacy, |_, _| true)?; + // import into a new collection let anki2_name = dir .join(format!("{name}.anki2")) .to_string_lossy() .into_owned(); let import_media_dir = dir.join(format!("{name}.media")); + std::fs::create_dir_all(&import_media_dir)?; + let import_media_db = dir.join(format!("{name}.mdb")); + MediaManager::new(&import_media_dir, &import_media_db)?; import_colpkg( &colpkg_name.to_string_lossy(), &anki2_name, - import_media_dir.to_str().unwrap(), - |_| Ok(()), + &import_media_dir, + &import_media_db, + |_, _| true, + &terminal(), )?; + // confirm collection imported let col = CollectionBuilder::new(&anki2_name).build()?; assert_eq!( @@ -82,7 +89,7 @@ fn normalization_check_on_export() -> Result<()> { // manually write a file in the wrong encoding. std::fs::write(col.media_folder.join("ぱぱ.jpg"), "nfd encoding")?; assert_eq!( - col.export_colpkg(&colpkg_name, true, false, |_| ()) + col.export_colpkg(&colpkg_name, true, false, |_, _| true,) .unwrap_err(), AnkiError::MediaCheckRequired ); diff --git a/rslib/src/import_export/package/media.rs b/rslib/src/import_export/package/media.rs new file mode 100644 index 000000000..3613ffa72 --- /dev/null +++ b/rslib/src/import_export/package/media.rs @@ -0,0 +1,174 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use std::{ + borrow::Cow, + collections::HashMap, + fs::{self, File}, + io, + path::{Path, PathBuf}, +}; + +use prost::Message; +use tempfile::NamedTempFile; +use zip::{read::ZipFile, ZipArchive}; +use zstd::stream::copy_decode; + +use super::{MediaEntries, MediaEntry, Meta}; +use crate::{ + error::ImportError, + io::{atomic_rename, filename_is_safe}, + media::files::normalize_filename, + prelude::*, +}; + +/// Like [MediaEntry], but with a safe filename and set zip filename. 
+pub(super) struct SafeMediaEntry { + pub(super) name: String, + pub(super) size: u32, + pub(super) sha1: Sha1Hash, + pub(super) index: usize, +} + +impl MediaEntry { + pub(super) fn new( + name: impl Into, + size: impl TryInto, + sha1: impl Into>, + ) -> Self { + MediaEntry { + name: name.into(), + size: size.try_into().unwrap_or_default(), + sha1: sha1.into(), + legacy_zip_filename: None, + } + } +} + +impl SafeMediaEntry { + pub(super) fn from_entry(enumerated: (usize, MediaEntry)) -> Result { + let (index, entry) = enumerated; + if let Ok(sha1) = entry.sha1.try_into() { + if !matches!(safe_normalized_file_name(&entry.name)?, Cow::Owned(_)) { + return Ok(Self { + name: entry.name, + size: entry.size, + sha1, + index, + }); + } + } + Err(AnkiError::ImportError(ImportError::Corrupt)) + } + + pub(super) fn from_legacy(legacy_entry: (&str, String)) -> Result { + let zip_filename: usize = legacy_entry.0.parse()?; + let name = match safe_normalized_file_name(&legacy_entry.1)? { + Cow::Owned(new_name) => new_name, + Cow::Borrowed(_) => legacy_entry.1, + }; + Ok(Self { + name, + size: 0, + sha1: [0; 20], + index: zip_filename, + }) + } + + pub(super) fn file_path(&self, media_folder: &Path) -> PathBuf { + media_folder.join(&self.name) + } + + pub(super) fn fetch_file<'a>(&self, archive: &'a mut ZipArchive) -> Result> { + archive + .by_name(&self.index.to_string()) + .map_err(|_| AnkiError::invalid_input(&format!("{} missing from archive", self.index))) + } + + pub(super) fn has_checksum_equal_to( + &self, + get_checksum: &mut impl FnMut(&str) -> Result>, + ) -> Result { + get_checksum(&self.name).map(|opt| opt.map_or(false, |sha1| sha1 == self.sha1)) + } + + pub(super) fn has_size_equal_to(&self, other_path: &Path) -> bool { + fs::metadata(other_path).map_or(false, |metadata| metadata.len() == self.size as u64) + } + + pub(super) fn copy_from_archive( + &self, + archive: &mut ZipArchive, + target_folder: &Path, + ) -> Result<()> { + let mut file = self.fetch_file(archive)?; + let mut tempfile = NamedTempFile::new_in(target_folder)?; + io::copy(&mut file, &mut tempfile)?; + atomic_rename(tempfile, &self.file_path(target_folder), false) + } +} + +pub(super) fn extract_media_entries( + meta: &Meta, + archive: &mut ZipArchive, +) -> Result> { + let media_list_data = get_media_list_data(archive, meta)?; + if meta.media_list_is_hashmap() { + let map: HashMap<&str, String> = serde_json::from_slice(&media_list_data)?; + map.into_iter().map(SafeMediaEntry::from_legacy).collect() + } else { + MediaEntries::decode_safe_entries(&media_list_data) + } +} + +pub(super) fn safe_normalized_file_name(name: &str) -> Result> { + if !filename_is_safe(name) { + Err(AnkiError::ImportError(ImportError::Corrupt)) + } else { + Ok(normalize_filename(name)) + } +} + +fn get_media_list_data(archive: &mut ZipArchive, meta: &Meta) -> Result> { + let mut file = archive.by_name("media")?; + let mut buf = Vec::new(); + if meta.zstd_compressed() { + copy_decode(file, &mut buf)?; + } else { + io::copy(&mut file, &mut buf)?; + } + Ok(buf) +} + +impl MediaEntries { + fn decode_safe_entries(buf: &[u8]) -> Result> { + let entries: Self = Message::decode(buf)?; + entries + .entries + .into_iter() + .enumerate() + .map(SafeMediaEntry::from_entry) + .collect() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn normalization() { + // legacy entries get normalized on deserialisation + let entry = SafeMediaEntry::from_legacy(("1", "con".to_owned())).unwrap(); + assert_eq!(entry.name, "con_"); + + // new-style entries 
should have been normalized on export + let mut entries = Vec::new(); + MediaEntries { + entries: vec![MediaEntry::new("con", 0, Vec::new())], + } + .encode(&mut entries) + .unwrap(); + assert!(MediaEntries::decode_safe_entries(&entries).is_err()); + } +} diff --git a/rslib/src/import_export/package/meta.rs b/rslib/src/import_export/package/meta.rs index c2ac4e80c..62e8dbda4 100644 --- a/rslib/src/import_export/package/meta.rs +++ b/rslib/src/import_export/package/meta.rs @@ -1,7 +1,13 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +use std::{fs::File, io::Read}; + +use prost::Message; +use zip::ZipArchive; + pub(super) use crate::backend_proto::{package_metadata::Version, PackageMetadata as Meta}; +use crate::{error::ImportError, prelude::*, storage::SchemaVersion}; impl Version { pub(super) fn collection_filename(&self) -> &'static str { @@ -12,6 +18,16 @@ impl Version { Version::Latest => "collection.anki21b", } } + + /// Latest schema version that is supported by all clients supporting + /// this package version. + pub(super) fn schema_version(&self) -> SchemaVersion { + match self { + Version::Unknown => unreachable!(), + Version::Legacy1 | Version::Legacy2 => SchemaVersion::V11, + Version::Latest => SchemaVersion::V18, + } + } } impl Meta { @@ -27,10 +43,41 @@ impl Meta { } } + /// Extracts meta data from an archive and checks if its version is supported. + pub(super) fn from_archive(archive: &mut ZipArchive) -> Result { + let meta_bytes = archive.by_name("meta").ok().and_then(|mut meta_file| { + let mut buf = vec![]; + meta_file.read_to_end(&mut buf).ok()?; + Some(buf) + }); + let meta = if let Some(bytes) = meta_bytes { + let meta: Meta = Message::decode(&*bytes)?; + if meta.version() == Version::Unknown { + return Err(AnkiError::ImportError(ImportError::TooNew)); + } + meta + } else { + Meta { + version: if archive.by_name("collection.anki21").is_ok() { + Version::Legacy2 + } else { + Version::Legacy1 + } as i32, + } + }; + Ok(meta) + } + pub(super) fn collection_filename(&self) -> &'static str { self.version().collection_filename() } + /// Latest schema version that is supported by all clients supporting + /// this package version. 
+ pub(super) fn schema_version(&self) -> SchemaVersion { + self.version().schema_version() + } + pub(super) fn zstd_compressed(&self) -> bool { !self.is_legacy() } @@ -39,10 +86,6 @@ impl Meta { self.is_legacy() } - pub(super) fn strict_media_checks(&self) -> bool { - !self.is_legacy() - } - fn is_legacy(&self) -> bool { matches!(self.version(), Version::Legacy1 | Version::Legacy2) } diff --git a/rslib/src/import_export/package/mod.rs b/rslib/src/import_export/package/mod.rs index 66d3ca14e..f999d84ce 100644 --- a/rslib/src/import_export/package/mod.rs +++ b/rslib/src/import_export/package/mod.rs @@ -1,11 +1,15 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +mod apkg; mod colpkg; +mod media; mod meta; +pub(crate) use apkg::NoteMeta; pub(crate) use colpkg::export::export_colpkg_from_data; pub use colpkg::import::import_colpkg; pub(self) use meta::{Meta, Version}; +pub use crate::backend_proto::import_anki_package_response::{Log as NoteLog, Note as LogNote}; pub(self) use crate::backend_proto::{media_entries::MediaEntry, MediaEntries}; diff --git a/rslib/src/io.rs b/rslib/src/io.rs index 8a68ffdb3..fc86c5db6 100644 --- a/rslib/src/io.rs +++ b/rslib/src/io.rs @@ -1,7 +1,7 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -use std::path::Path; +use std::path::{Component, Path}; use tempfile::NamedTempFile; @@ -42,6 +42,17 @@ pub(crate) fn read_dir_files(path: impl AsRef) -> std::io::Result bool { + let mut components = Path::new(name).components(); + let first_element_normal = components + .next() + .map(|component| matches!(component, Component::Normal(_))) + .unwrap_or_default(); + + first_element_normal && components.next().is_none() +} + pub(crate) struct ReadDirFiles(std::fs::ReadDir); impl Iterator for ReadDirFiles { @@ -60,3 +71,24 @@ impl Iterator for ReadDirFiles { } } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn path_traversal() { + assert!(filename_is_safe("foo")); + + assert!(!filename_is_safe("..")); + assert!(!filename_is_safe("foo/bar")); + assert!(!filename_is_safe("/foo")); + assert!(!filename_is_safe("../foo")); + + if cfg!(windows) { + assert!(!filename_is_safe("foo\\bar")); + assert!(!filename_is_safe("c:\\foo")); + assert!(!filename_is_safe("\\foo")); + } + } +} diff --git a/rslib/src/lib.rs b/rslib/src/lib.rs index 9e03edc38..dc43db87e 100644 --- a/rslib/src/lib.rs +++ b/rslib/src/lib.rs @@ -40,6 +40,7 @@ mod sync; pub mod tags; pub mod template; pub mod template_filters; +pub(crate) mod tests; pub mod text; pub mod timestamp; pub mod types; diff --git a/rslib/src/media/changetracker.rs b/rslib/src/media/changetracker.rs index 6db3e94f9..dcf04da85 100644 --- a/rslib/src/media/changetracker.rs +++ b/rslib/src/media/changetracker.rs @@ -4,7 +4,6 @@ use std::{collections::HashMap, path::Path, time}; use crate::{ - error::{AnkiError, Result}, log::{debug, Logger}, media::{ database::{MediaDatabaseContext, MediaEntry}, @@ -13,11 +12,12 @@ use crate::{ NONSYNCABLE_FILENAME, }, }, + prelude::*, }; struct FilesystemEntry { fname: String, - sha1: Option<[u8; 20]>, + sha1: Option, mtime: i64, is_new: bool, } diff --git a/rslib/src/media/database.rs b/rslib/src/media/database.rs index 858035c33..927c1e6fd 100644 --- a/rslib/src/media/database.rs +++ b/rslib/src/media/database.rs @@ -5,7 +5,7 @@ use std::{collections::HashMap, path::Path}; use rusqlite::{params, Connection, OptionalExtension, Row, 
Statement}; -use crate::error::Result; +use crate::prelude::*; fn trace(s: &str) { println!("sql: {}", s) @@ -47,7 +47,7 @@ fn initial_db_setup(db: &mut Connection) -> Result<()> { pub struct MediaEntry { pub fname: String, /// If None, file has been deleted - pub sha1: Option<[u8; 20]>, + pub sha1: Option, // Modification time; 0 if deleted pub mtime: i64, /// True if changed since last sync @@ -222,6 +222,14 @@ delete from media where fname=?" Ok(map?) } + /// Returns all filenames and checksums, where the checksum is not null. + pub(super) fn all_checksums(&mut self) -> Result> { + self.db + .prepare("SELECT fname, csum FROM media WHERE csum IS NOT NULL")? + .query_and_then([], row_to_name_and_checksum)? + .collect() + } + pub(super) fn force_resync(&mut self) -> Result<()> { self.db .execute_batch("delete from media; update meta set lastUsn = 0, dirMod = 0") @@ -231,7 +239,7 @@ delete from media where fname=?" fn row_to_entry(row: &Row) -> rusqlite::Result { // map the string checksum into bytes - let sha1_str: Option = row.get(1)?; + let sha1_str = row.get_ref(1)?.as_str_or_null()?; let sha1_array = if let Some(s) = sha1_str { let mut arr = [0; 20]; match hex::decode_to_slice(s, arr.as_mut()) { @@ -250,6 +258,15 @@ fn row_to_entry(row: &Row) -> rusqlite::Result { }) } +fn row_to_name_and_checksum(row: &Row) -> Result<(String, Sha1Hash)> { + let file_name = row.get(0)?; + let sha1_str: String = row.get(1)?; + let mut sha1 = [0; 20]; + hex::decode_to_slice(sha1_str, &mut sha1) + .map_err(|_| AnkiError::invalid_input(format!("bad media checksum: {file_name}")))?; + Ok((file_name, sha1)) +} + #[cfg(test)] mod test { use tempfile::NamedTempFile; diff --git a/rslib/src/media/files.rs b/rslib/src/media/files.rs index b56b1baa3..077e4ce93 100644 --- a/rslib/src/media/files.rs +++ b/rslib/src/media/files.rs @@ -15,10 +15,7 @@ use sha1::Sha1; use unic_ucd_category::GeneralCategory; use unicode_normalization::{is_nfc, UnicodeNormalization}; -use crate::{ - error::{AnkiError, Result}, - log::{debug, Logger}, -}; +use crate::prelude::*; /// The maximum length we allow a filename to be. When combined /// with the rest of the path, the full path needs to be under ~240 chars @@ -164,7 +161,7 @@ pub fn add_data_to_folder_uniquely<'a, P>( folder: P, desired_name: &'a str, data: &[u8], - sha1: [u8; 20], + sha1: Sha1Hash, ) -> io::Result> where P: AsRef, @@ -194,7 +191,7 @@ where } /// Convert foo.jpg into foo-abcde12345679.jpg -fn add_hash_suffix_to_file_stem(fname: &str, hash: &[u8; 20]) -> String { +pub(crate) fn add_hash_suffix_to_file_stem(fname: &str, hash: &Sha1Hash) -> String { // when appending a hash to make unique, it will be 40 bytes plus the hyphen. let max_len = MAX_FILENAME_LENGTH - 40 - 1; @@ -244,18 +241,18 @@ fn split_and_truncate_filename(fname: &str, max_bytes: usize) -> (&str, &str) { }; // cap extension to 10 bytes so stem_len can't be negative - ext = truncate_to_char_boundary(ext, 10); + ext = truncated_to_char_boundary(ext, 10); // cap stem, allowing for the . and a trailing _ let stem_len = max_bytes - ext.len() - 2; - stem = truncate_to_char_boundary(stem, stem_len); + stem = truncated_to_char_boundary(stem, stem_len); (stem, ext) } -/// Trim a string on a valid UTF8 boundary. +/// Return a substring on a valid UTF8 boundary. /// Based on a funtion in the Rust stdlib. 
-fn truncate_to_char_boundary(s: &str, mut max: usize) -> &str { +fn truncated_to_char_boundary(s: &str, mut max: usize) -> &str { if max >= s.len() { s } else { @@ -267,7 +264,7 @@ fn truncate_to_char_boundary(s: &str, mut max: usize) -> &str { } /// Return the SHA1 of a file if it exists, or None. -fn existing_file_sha1(path: &Path) -> io::Result> { +fn existing_file_sha1(path: &Path) -> io::Result> { match sha1_of_file(path) { Ok(o) => Ok(Some(o)), Err(e) => { @@ -281,12 +278,17 @@ fn existing_file_sha1(path: &Path) -> io::Result> { } /// Return the SHA1 of a file, failing if it doesn't exist. -pub(crate) fn sha1_of_file(path: &Path) -> io::Result<[u8; 20]> { +pub(crate) fn sha1_of_file(path: &Path) -> io::Result { let mut file = fs::File::open(path)?; + sha1_of_reader(&mut file) +} + +/// Return the SHA1 of a stream. +pub(crate) fn sha1_of_reader(reader: &mut impl Read) -> io::Result { let mut hasher = Sha1::new(); let mut buf = [0; 64 * 1024]; loop { - match file.read(&mut buf) { + match reader.read(&mut buf) { Ok(0) => break, Ok(n) => hasher.update(&buf[0..n]), Err(e) => { @@ -302,7 +304,7 @@ pub(crate) fn sha1_of_file(path: &Path) -> io::Result<[u8; 20]> { } /// Return the SHA1 of provided data. -pub(crate) fn sha1_of_data(data: &[u8]) -> [u8; 20] { +pub(crate) fn sha1_of_data(data: &[u8]) -> Sha1Hash { let mut hasher = Sha1::new(); hasher.update(data); hasher.digest().bytes() @@ -371,7 +373,7 @@ pub(super) fn trash_folder(media_folder: &Path) -> Result { pub(super) struct AddedFile { pub fname: String, - pub sha1: [u8; 20], + pub sha1: Sha1Hash, pub mtime: i64, pub renamed_from: Option, } diff --git a/rslib/src/media/mod.rs b/rslib/src/media/mod.rs index 351befa1f..de3fb73e6 100644 --- a/rslib/src/media/mod.rs +++ b/rslib/src/media/mod.rs @@ -3,19 +3,21 @@ use std::{ borrow::Cow, + collections::HashMap, path::{Path, PathBuf}, }; use rusqlite::Connection; use slog::Logger; +use self::changetracker::ChangeTracker; use crate::{ - error::Result, media::{ database::{open_or_create, MediaDatabaseContext, MediaEntry}, files::{add_data_to_folder_uniquely, mtime_as_i64, remove_files, sha1_of_data}, sync::{MediaSyncProgress, MediaSyncer}, }, + prelude::*, }; pub mod changetracker; @@ -24,6 +26,8 @@ pub mod database; pub mod files; pub mod sync; +pub type Sha1Hash = [u8; 20]; + pub struct MediaManager { db: Connection, media_folder: PathBuf, @@ -153,4 +157,31 @@ impl MediaManager { pub fn dbctx(&self) -> MediaDatabaseContext { MediaDatabaseContext::new(&self.db) } + + pub fn all_checksums( + &self, + progress: impl FnMut(usize) -> bool, + log: &Logger, + ) -> Result> { + let mut dbctx = self.dbctx(); + ChangeTracker::new(&self.media_folder, progress, log).register_changes(&mut dbctx)?; + dbctx.all_checksums() + } + + pub fn checksum_getter(&self) -> impl FnMut(&str) -> Result> + '_ { + let mut dbctx = self.dbctx(); + move |fname: &str| { + dbctx + .get_entry(fname) + .map(|opt| opt.and_then(|entry| entry.sha1)) + } + } + + pub fn register_changes( + &self, + progress: &mut impl FnMut(usize) -> bool, + log: &Logger, + ) -> Result<()> { + ChangeTracker::new(&self.media_folder, progress, log).register_changes(&mut self.dbctx()) + } } diff --git a/rslib/src/notes/mod.rs b/rslib/src/notes/mod.rs index 5eed822f4..ed8ec29fa 100644 --- a/rslib/src/notes/mod.rs +++ b/rslib/src/notes/mod.rs @@ -55,6 +55,10 @@ impl Note { &self.fields } + pub fn into_fields(self) -> Vec { + self.fields + } + pub fn set_field(&mut self, idx: usize, text: impl Into) -> Result<()> { if idx >= self.fields.len() { return 
Err(AnkiError::invalid_input( @@ -320,7 +324,7 @@ fn invalid_char_for_field(c: char) -> bool { } impl Collection { - fn canonify_note_tags(&mut self, note: &mut Note, usn: Usn) -> Result<()> { + pub(crate) fn canonify_note_tags(&mut self, note: &mut Note, usn: Usn) -> Result<()> { if !note.tags.is_empty() { let tags = std::mem::take(&mut note.tags); note.tags = self.canonify_tags(tags, usn)?.0; diff --git a/rslib/src/notes/undo.rs b/rslib/src/notes/undo.rs index 709778810..89e22dc48 100644 --- a/rslib/src/notes/undo.rs +++ b/rslib/src/notes/undo.rs @@ -83,13 +83,23 @@ impl Collection { } /// Add a note, not adding any cards. - pub(super) fn add_note_only_undoable(&mut self, note: &mut Note) -> Result<(), AnkiError> { + pub(crate) fn add_note_only_undoable(&mut self, note: &mut Note) -> Result<(), AnkiError> { self.storage.add_note(note)?; self.save_undo(UndoableNoteChange::Added(Box::new(note.clone()))); Ok(()) } + /// Add a note, not adding any cards. Caller guarantees id is unique. + pub(crate) fn add_note_only_with_id_undoable(&mut self, note: &mut Note) -> Result<()> { + if self.storage.add_note_if_unique(note)? { + self.save_undo(UndoableNoteChange::Added(Box::new(note.clone()))); + Ok(()) + } else { + Err(AnkiError::invalid_input("note id existed")) + } + } + pub(crate) fn update_note_tags_undoable( &mut self, tags: &NoteTags, diff --git a/rslib/src/notetype/mod.rs b/rslib/src/notetype/mod.rs index eb00fc10d..84ebb3e24 100644 --- a/rslib/src/notetype/mod.rs +++ b/rslib/src/notetype/mod.rs @@ -648,7 +648,7 @@ impl Collection { /// - Caller must set notetype as modified if appropriate. /// - This only supports undo when an existing notetype is passed in. - fn add_or_update_notetype_with_existing_id_inner( + pub(crate) fn add_or_update_notetype_with_existing_id_inner( &mut self, notetype: &mut Notetype, original: Option, diff --git a/rslib/src/notetype/undo.rs b/rslib/src/notetype/undo.rs index 74f6d34a6..5780ed9b2 100644 --- a/rslib/src/notetype/undo.rs +++ b/rslib/src/notetype/undo.rs @@ -42,6 +42,17 @@ impl Collection { Ok(()) } + /// Caller must ensure [NotetypeId] is unique. 
+ pub(crate) fn add_notetype_with_unique_id_undoable( + &mut self, + notetype: &Notetype, + ) -> Result<()> { + self.storage + .add_or_update_notetype_with_existing_id(notetype)?; + self.save_undo(UndoableNotetypeChange::Added(Box::new(notetype.clone()))); + Ok(()) + } + pub(super) fn update_notetype_undoable( &mut self, notetype: &Notetype, diff --git a/rslib/src/ops.rs b/rslib/src/ops.rs index 5daa2c582..c3d8f4a28 100644 --- a/rslib/src/ops.rs +++ b/rslib/src/ops.rs @@ -17,6 +17,7 @@ pub enum Op { CreateCustomStudy, EmptyFilteredDeck, FindAndReplace, + Import, RebuildFilteredDeck, RemoveDeck, RemoveNote, @@ -55,6 +56,7 @@ impl Op { Op::AnswerCard => tr.actions_answer_card(), Op::Bury => tr.studying_bury(), Op::CreateCustomStudy => tr.actions_custom_study(), + Op::Import => tr.actions_import(), Op::RemoveDeck => tr.decks_delete_deck(), Op::RemoveNote => tr.studying_delete_note(), Op::RenameDeck => tr.actions_rename_deck(), diff --git a/rslib/src/prelude.rs b/rslib/src/prelude.rs index 721aaaf52..042080415 100644 --- a/rslib/src/prelude.rs +++ b/rslib/src/prelude.rs @@ -12,6 +12,7 @@ pub use crate::{ decks::{Deck, DeckId, DeckKind, NativeDeckName}, error::{AnkiError, Result}, i18n::I18n, + media::Sha1Hash, notes::{Note, NoteId}, notetype::{Notetype, NotetypeId}, ops::{Op, OpChanges, OpOutput}, diff --git a/rslib/src/revlog/undo.rs b/rslib/src/revlog/undo.rs index 24ea6f61a..71602f0cc 100644 --- a/rslib/src/revlog/undo.rs +++ b/rslib/src/revlog/undo.rs @@ -28,9 +28,17 @@ impl Collection { /// Add the provided revlog entry, modifying the ID if it is not unique. pub(crate) fn add_revlog_entry_undoable(&mut self, mut entry: RevlogEntry) -> Result { - entry.id = self.storage.add_revlog_entry(&entry, true)?; + entry.id = self.storage.add_revlog_entry(&entry, true)?.unwrap(); let id = entry.id; self.save_undo(UndoableRevlogChange::Added(Box::new(entry))); Ok(id) } + + /// Add the provided revlog entry, if its ID is unique. + pub(crate) fn add_revlog_entry_if_unique_undoable(&mut self, entry: RevlogEntry) -> Result<()> { + if self.storage.add_revlog_entry(&entry, false)?.is_some() { + self.save_undo(UndoableRevlogChange::Added(Box::new(entry))); + } + Ok(()) + } } diff --git a/rslib/src/search/builder.rs b/rslib/src/search/builder.rs index c51c50ee9..dab57edfa 100644 --- a/rslib/src/search/builder.rs +++ b/rslib/src/search/builder.rs @@ -134,6 +134,14 @@ impl Default for SearchBuilder { } impl SearchNode { + pub fn from_deck_id(did: impl Into, with_children: bool) -> Self { + if with_children { + Self::DeckIdWithChildren(did.into()) + } else { + Self::DeckIdWithoutChildren(did.into()) + } + } + /// Construct [SearchNode] from an unescaped deck name. pub fn from_deck_name(name: &str) -> Self { Self::Deck(escape_anki_wildcards_for_search_node(name)) @@ -158,6 +166,14 @@ impl SearchNode { name, ))) } + + pub fn from_note_ids, N: Into>(ids: I) -> Self { + Self::NoteIds(ids.into_iter().map(Into::into).join(",")) + } + + pub fn from_card_ids, C: Into>(ids: I) -> Self { + Self::CardIds(ids.into_iter().map(Into::into).join(",")) + } } impl> From for Node { diff --git a/rslib/src/stats/graphs.rs b/rslib/src/stats/graphs.rs index 0a158eb66..eb5236e66 100644 --- a/rslib/src/stats/graphs.rs +++ b/rslib/src/stats/graphs.rs @@ -38,7 +38,7 @@ impl Collection { self.storage.get_all_revlog_entries(revlog_start)? } else { self.storage - .get_revlog_entries_for_searched_cards(revlog_start)? + .get_pb_revlog_entries_for_searched_cards(revlog_start)? 
}; self.storage.clear_searched_cards_table()?; diff --git a/rslib/src/storage/card/add_card_if_unique.sql b/rslib/src/storage/card/add_card_if_unique.sql new file mode 100644 index 000000000..80b3a6394 --- /dev/null +++ b/rslib/src/storage/card/add_card_if_unique.sql @@ -0,0 +1,41 @@ +INSERT + OR IGNORE INTO cards ( + id, + nid, + did, + ord, + mod, + usn, + type, + queue, + due, + ivl, + factor, + reps, + lapses, + left, + odue, + odid, + flags, + data + ) +VALUES ( + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ? + ) \ No newline at end of file diff --git a/rslib/src/storage/card/mod.rs b/rslib/src/storage/card/mod.rs index f62053ae8..96b314486 100644 --- a/rslib/src/storage/card/mod.rs +++ b/rslib/src/storage/card/mod.rs @@ -145,6 +145,34 @@ impl super::SqliteStorage { Ok(()) } + /// Add card if id is unique. True if card was added. + pub(crate) fn add_card_if_unique(&self, card: &Card) -> Result { + self.db + .prepare_cached(include_str!("add_card_if_unique.sql"))? + .execute(params![ + card.id, + card.note_id, + card.deck_id, + card.template_idx, + card.mtime, + card.usn, + card.ctype as u8, + card.queue as i8, + card.due, + card.interval, + card.ease_factor, + card.reps, + card.lapses, + card.remaining_steps, + card.original_due, + card.original_deck_id, + card.flags, + CardData::from_card(card), + ]) + .map(|n_rows| n_rows == 1) + .map_err(Into::into) + } + /// Add or update card, using the provided ID. Used for syncing & undoing. pub(crate) fn add_or_update_card(&self, card: &Card) -> Result<()> { let mut stmt = self.db.prepare_cached(include_str!("add_or_update.sql"))?; @@ -400,6 +428,20 @@ impl super::SqliteStorage { .collect() } + pub(crate) fn get_all_card_ids(&self) -> Result> { + self.db + .prepare("SELECT id FROM cards")? + .query_and_then([], |row| Ok(row.get(0)?))? + .collect() + } + + pub(crate) fn all_cards_as_nid_and_ord(&self) -> Result> { + self.db + .prepare("SELECT nid, ord FROM cards")? + .query_and_then([], |r| Ok((NoteId(r.get(0)?), r.get(1)?)))? + .collect() + } + pub(crate) fn card_ids_of_notes(&self, nids: &[NoteId]) -> Result> { let mut stmt = self .db @@ -455,6 +497,16 @@ impl super::SqliteStorage { Ok(nids) } + /// Place the ids of cards with notes in 'search_nids' into 'search_cids'. + /// Returns number of added cards. + pub(crate) fn search_cards_of_notes_into_table(&self) -> Result { + self.setup_searched_cards_table()?; + self.db + .prepare(include_str!("search_cards_of_notes_into_table.sql"))? + .execute([]) + .map_err(Into::into) + } + pub(crate) fn all_searched_cards(&self) -> Result> { self.db .prepare_cached(concat!( diff --git a/rslib/src/storage/card/search_cards_of_notes_into_table.sql b/rslib/src/storage/card/search_cards_of_notes_into_table.sql new file mode 100644 index 000000000..0387fd566 --- /dev/null +++ b/rslib/src/storage/card/search_cards_of_notes_into_table.sql @@ -0,0 +1,7 @@ +INSERT INTO search_cids +SELECT id +FROM cards +WHERE nid IN ( + SELECT nid + FROM search_nids + ) \ No newline at end of file diff --git a/rslib/src/storage/deck/mod.rs b/rslib/src/storage/deck/mod.rs index da1c5d0f7..f99467143 100644 --- a/rslib/src/storage/deck/mod.rs +++ b/rslib/src/storage/deck/mod.rs @@ -74,6 +74,15 @@ impl SqliteStorage { .transpose() } + pub(crate) fn get_deck_by_name(&self, machine_name: &str) -> Result> { + self.db + .prepare_cached(concat!(include_str!("get_deck.sql"), " WHERE name = ?"))? + .query_and_then([machine_name], row_to_deck)? 
+ .next() + .transpose() + .map_err(Into::into) + } + pub(crate) fn get_all_decks(&self) -> Result> { self.db .prepare(include_str!("get_deck.sql"))? @@ -111,6 +120,17 @@ impl SqliteStorage { .map_err(Into::into) } + pub(crate) fn get_decks_for_search_cards(&self) -> Result> { + self.db + .prepare_cached(concat!( + include_str!("get_deck.sql"), + " WHERE id IN (SELECT DISTINCT did FROM cards WHERE id IN", + " (SELECT cid FROM search_cids))", + ))? + .query_and_then([], row_to_deck)? + .collect() + } + // caller should ensure name unique pub(crate) fn add_deck(&self, deck: &mut Deck) -> Result<()> { assert!(deck.id.0 == 0); diff --git a/rslib/src/storage/deckconfig/add_if_unique.sql b/rslib/src/storage/deckconfig/add_if_unique.sql new file mode 100644 index 000000000..516024466 --- /dev/null +++ b/rslib/src/storage/deckconfig/add_if_unique.sql @@ -0,0 +1,3 @@ +INSERT + OR IGNORE INTO deck_config (id, name, mtime_secs, usn, config) +VALUES (?, ?, ?, ?, ?); \ No newline at end of file diff --git a/rslib/src/storage/deckconfig/mod.rs b/rslib/src/storage/deckconfig/mod.rs index aeffc0151..c7b31dc92 100644 --- a/rslib/src/storage/deckconfig/mod.rs +++ b/rslib/src/storage/deckconfig/mod.rs @@ -67,6 +67,22 @@ impl SqliteStorage { Ok(()) } + pub(crate) fn add_deck_conf_if_unique(&self, conf: &DeckConfig) -> Result { + let mut conf_bytes = vec![]; + conf.inner.encode(&mut conf_bytes)?; + self.db + .prepare_cached(include_str!("add_if_unique.sql"))? + .execute(params![ + conf.id, + conf.name, + conf.mtime_secs, + conf.usn, + conf_bytes, + ]) + .map(|added| added == 1) + .map_err(Into::into) + } + pub(crate) fn update_deck_conf(&self, conf: &DeckConfig) -> Result<()> { let mut conf_bytes = vec![]; conf.inner.encode(&mut conf_bytes)?; diff --git a/rslib/src/storage/mod.rs b/rslib/src/storage/mod.rs index f2099d5b7..5ab5aa42d 100644 --- a/rslib/src/storage/mod.rs +++ b/rslib/src/storage/mod.rs @@ -34,9 +34,10 @@ impl SchemaVersion { } /// Write a list of IDs as '(x,y,...)' into the provided string. -pub(crate) fn ids_to_string(buf: &mut String, ids: &[T]) +pub(crate) fn ids_to_string(buf: &mut String, ids: I) where - T: std::fmt::Display, + D: std::fmt::Display, + I: IntoIterator, { buf.push('('); write_comma_separated_ids(buf, ids); @@ -44,15 +45,18 @@ where } /// Write a list of Ids as 'x,y,...' into the provided string. 
-pub(crate) fn write_comma_separated_ids(buf: &mut String, ids: &[T]) +pub(crate) fn write_comma_separated_ids(buf: &mut String, ids: I) where - T: std::fmt::Display, + D: std::fmt::Display, + I: IntoIterator, { - if !ids.is_empty() { - for id in ids.iter().skip(1) { - write!(buf, "{},", id).unwrap(); - } - write!(buf, "{}", ids[0]).unwrap(); + let mut trailing_sep = false; + for id in ids { + write!(buf, "{},", id).unwrap(); + trailing_sep = true; + } + if trailing_sep { + buf.pop(); } } @@ -73,17 +77,17 @@ mod test { #[test] fn ids_string() { let mut s = String::new(); - ids_to_string::(&mut s, &[]); + ids_to_string(&mut s, &[0; 0]); assert_eq!(s, "()"); s.clear(); ids_to_string(&mut s, &[7]); assert_eq!(s, "(7)"); s.clear(); ids_to_string(&mut s, &[7, 6]); - assert_eq!(s, "(6,7)"); + assert_eq!(s, "(7,6)"); s.clear(); ids_to_string(&mut s, &[7, 6, 5]); - assert_eq!(s, "(6,5,7)"); + assert_eq!(s, "(7,6,5)"); s.clear(); } } diff --git a/rslib/src/storage/note/add_if_unique.sql b/rslib/src/storage/note/add_if_unique.sql new file mode 100644 index 000000000..1dd408bbe --- /dev/null +++ b/rslib/src/storage/note/add_if_unique.sql @@ -0,0 +1,27 @@ +INSERT + OR IGNORE INTO notes ( + id, + guid, + mid, + mod, + usn, + tags, + flds, + sfld, + csum, + flags, + data + ) +VALUES ( + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + ?, + 0, + "" + ) \ No newline at end of file diff --git a/rslib/src/storage/note/mod.rs b/rslib/src/storage/note/mod.rs index 4c7c20bc8..72d88e4e0 100644 --- a/rslib/src/storage/note/mod.rs +++ b/rslib/src/storage/note/mod.rs @@ -1,12 +1,13 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use rusqlite::{params, Row}; use crate::{ error::Result, + import_export::package::NoteMeta, notes::{Note, NoteId, NoteTags}, notetype::NotetypeId, tags::{join_tags, split_tags}, @@ -41,6 +42,13 @@ impl super::SqliteStorage { .transpose() } + pub fn get_all_note_ids(&self) -> Result> { + self.db + .prepare("SELECT id FROM notes")? + .query_and_then([], |row| Ok(row.get(0)?))? + .collect() + } + /// If fields have been modified, caller must call note.prepare_for_update() prior to calling this. pub(crate) fn update_note(&self, note: &Note) -> Result<()> { assert!(note.id.0 != 0); @@ -77,6 +85,24 @@ impl super::SqliteStorage { Ok(()) } + pub(crate) fn add_note_if_unique(&self, note: &Note) -> Result { + self.db + .prepare_cached(include_str!("add_if_unique.sql"))? + .execute(params![ + note.id, + note.guid, + note.notetype_id, + note.mtime, + note.usn, + join_tags(¬e.tags), + join_fields(note.fields()), + note.sort_field.as_ref().unwrap(), + note.checksum.unwrap(), + ]) + .map(|added| added == 1) + .map_err(Into::into) + } + /// Add or update the provided note, preserving ID. Used by the syncing code. pub(crate) fn add_or_update_note(&self, note: &Note) -> Result<()> { let mut stmt = self.db.prepare_cached(include_str!("add_or_update.sql"))?; @@ -210,6 +236,16 @@ impl super::SqliteStorage { Ok(()) } + pub(crate) fn all_searched_notes(&self) -> Result> { + self.db + .prepare_cached(concat!( + include_str!("get.sql"), + " WHERE id IN (SELECT nid FROM search_nids)" + ))? + .query_and_then([], |r| row_to_note(r).map_err(Into::into))? 
+ .collect() + } + pub(crate) fn get_note_tags_by_predicate(&mut self, want: F) -> Result> where F: Fn(&str) -> bool, @@ -259,6 +295,24 @@ impl super::SqliteStorage { Ok(()) } + + pub(crate) fn note_guid_map(&mut self) -> Result> { + self.db + .prepare("SELECT guid, id, mod, mid FROM notes")? + .query_and_then([], row_to_note_meta)? + .collect() + } + + #[cfg(test)] + pub(crate) fn get_all_notes(&mut self) -> Vec { + self.db + .prepare("SELECT * FROM notes") + .unwrap() + .query_and_then([], row_to_note) + .unwrap() + .collect::>() + .unwrap() + } } fn row_to_note(row: &Row) -> Result { @@ -285,3 +339,10 @@ fn row_to_note_tags(row: &Row) -> Result { tags: row.get(3)?, }) } + +fn row_to_note_meta(row: &Row) -> Result<(String, NoteMeta)> { + Ok(( + row.get(0)?, + NoteMeta::new(row.get(1)?, row.get(2)?, row.get(3)?), + )) +} diff --git a/rslib/src/storage/notetype/mod.rs b/rslib/src/storage/notetype/mod.rs index 94b551a0a..3ada340a8 100644 --- a/rslib/src/storage/notetype/mod.rs +++ b/rslib/src/storage/notetype/mod.rs @@ -99,6 +99,23 @@ impl SqliteStorage { .map_err(Into::into) } + pub(crate) fn get_notetypes_for_search_notes(&self) -> Result> { + self.db + .prepare_cached(concat!( + include_str!("get_notetype.sql"), + " WHERE id IN (SELECT DISTINCT mid FROM notes WHERE id IN", + " (SELECT nid FROM search_nids))", + ))? + .query_and_then([], |r| { + row_to_notetype_core(r).and_then(|mut nt| { + nt.fields = self.get_notetype_fields(nt.id)?; + nt.templates = self.get_notetype_templates(nt.id)?; + Ok(nt) + }) + })? + .collect() + } + pub fn get_all_notetype_names(&self) -> Result> { self.db .prepare_cached(include_str!("get_notetype_names.sql"))? diff --git a/rslib/src/storage/revlog/mod.rs b/rslib/src/storage/revlog/mod.rs index 4284b25be..d18d321e4 100644 --- a/rslib/src/storage/revlog/mod.rs +++ b/rslib/src/storage/revlog/mod.rs @@ -61,16 +61,19 @@ impl SqliteStorage { Ok(()) } - /// Returns the used id, which may differ if `ensure_unique` is true. + /// Adds the entry, if its id is unique. If it is not, and `uniquify` is true, + /// adds it with a new id. Returns the added id. + /// (I.e., the option is safe to unwrap, if `uniquify` is true.) pub(crate) fn add_revlog_entry( &self, entry: &RevlogEntry, - ensure_unique: bool, - ) -> Result { - self.db + uniquify: bool, + ) -> Result> { + let added = self + .db .prepare_cached(include_str!("add.sql"))? .execute(params![ - ensure_unique, + uniquify, entry.id, entry.cid, entry.usn, @@ -81,7 +84,7 @@ impl SqliteStorage { entry.taken_millis, entry.review_kind as u8 ])?; - Ok(RevlogId(self.db.last_insert_rowid())) + Ok((added > 0).then(|| RevlogId(self.db.last_insert_rowid()))) } pub(crate) fn get_revlog_entry(&self, id: RevlogId) -> Result> { @@ -107,7 +110,7 @@ impl SqliteStorage { .collect() } - pub(crate) fn get_revlog_entries_for_searched_cards( + pub(crate) fn get_pb_revlog_entries_for_searched_cards( &self, after: TimestampSecs, ) -> Result> { @@ -120,6 +123,16 @@ impl SqliteStorage { .collect() } + pub(crate) fn get_revlog_entries_for_searched_cards(&self) -> Result> { + self.db + .prepare_cached(concat!( + include_str!("get.sql"), + " where cid in (select cid from search_cids)" + ))? + .query_and_then([], row_to_revlog_entry)? + .collect() + } + /// This includes entries from deleted cards. 
pub(crate) fn get_all_revlog_entries( &self, diff --git a/rslib/src/tests.rs b/rslib/src/tests.rs new file mode 100644 index 000000000..1880fb059 --- /dev/null +++ b/rslib/src/tests.rs @@ -0,0 +1,63 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +#![cfg(test)] + +use tempfile::{tempdir, TempDir}; + +use crate::{collection::CollectionBuilder, media::MediaManager, prelude::*}; + +pub(crate) fn open_fs_test_collection(name: &str) -> (Collection, TempDir) { + let tempdir = tempdir().unwrap(); + let dir = tempdir.path(); + let media_folder = dir.join(format!("{name}.media")); + std::fs::create_dir(&media_folder).unwrap(); + let col = CollectionBuilder::new(dir.join(format!("{name}.anki2"))) + .set_media_paths(media_folder, dir.join(format!("{name}.mdb"))) + .build() + .unwrap(); + (col, tempdir) +} + +impl Collection { + pub(crate) fn add_media(&self, media: &[(&str, &[u8])]) { + let mgr = MediaManager::new(&self.media_folder, &self.media_db).unwrap(); + let mut ctx = mgr.dbctx(); + for (name, data) in media { + mgr.add_file(&mut ctx, name, data).unwrap(); + } + } + + pub(crate) fn new_note(&mut self, notetype: &str) -> Note { + self.get_notetype_by_name(notetype) + .unwrap() + .unwrap() + .new_note() + } + + pub(crate) fn add_new_note(&mut self, notetype: &str) -> Note { + let mut note = self.new_note(notetype); + self.add_note(&mut note, DeckId(1)).unwrap(); + note + } + + pub(crate) fn get_all_notes(&mut self) -> Vec { + self.storage.get_all_notes() + } + + pub(crate) fn add_deck_with_machine_name(&mut self, name: &str, filtered: bool) -> Deck { + let mut deck = new_deck_with_machine_name(name, filtered); + self.add_deck_inner(&mut deck, Usn(1)).unwrap(); + deck + } +} + +pub(crate) fn new_deck_with_machine_name(name: &str, filtered: bool) -> Deck { + let mut deck = if filtered { + Deck::new_filtered() + } else { + Deck::new_normal() + }; + deck.name = NativeDeckName::from_native_str(name); + deck +} diff --git a/rslib/src/text.rs b/rslib/src/text.rs index c285f87f4..d8d1d0df8 100644 --- a/rslib/src/text.rs +++ b/rslib/src/text.rs @@ -31,6 +31,31 @@ impl Trimming for Cow<'_, str> { } } +pub(crate) trait CowMapping<'a, B: ?Sized + 'a + ToOwned> { + /// Returns [self] + /// - unchanged, if the given function returns [Cow::Borrowed] + /// - with the new value, if the given function returns [Cow::Owned] + fn map_cow(self, f: impl FnOnce(&B) -> Cow) -> Self; + fn get_owned(self) -> Option; +} + +impl<'a, B: ?Sized + 'a + ToOwned> CowMapping<'a, B> for Cow<'a, B> { + fn map_cow(self, f: impl FnOnce(&B) -> Cow) -> Self { + if let Cow::Owned(o) = f(&self) { + Cow::Owned(o) + } else { + self + } + } + + fn get_owned(self) -> Option { + match self { + Cow::Borrowed(_) => None, + Cow::Owned(s) => Some(s), + } + } +} + #[derive(Debug, PartialEq)] pub enum AvTag { SoundOrVideo(String), @@ -115,34 +140,48 @@ lazy_static! { | \[\[type:[^]]+\]\] ").unwrap(); + + /// Files included in CSS with a leading underscore. + static ref UNDERSCORED_CSS_IMPORTS: Regex = Regex::new( + r#"(?xi) + (?:@import\s+ # import statement with a bare + "(_[^"]*.css)" # double quoted + | # or + '(_[^']*.css)' # single quoted css filename + ) + | # or + (?:url\(\s* # a url function with a + "(_[^"]+)" # double quoted + | # or + '(_[^']+)' # single quoted + | # or + (_.+) # unquoted filename + \s*\)) + "#).unwrap(); + + /// Strings, src and data attributes with a leading underscore. 
+ static ref UNDERSCORED_REFERENCES: Regex = Regex::new( + r#"(?x) + "(_[^"]+)" # double quoted + | # or + '(_[^']+)' # single quoted string + | # or + \b(?:src|data) # a 'src' or 'data' attribute + = # followed by + (_[^ >]+) # an unquoted value + "#).unwrap(); } pub fn html_to_text_line(html: &str) -> Cow { - let mut out: Cow = html.into(); - if let Cow::Owned(o) = PERSISTENT_HTML_SPACERS.replace_all(&out, " ") { - out = o.into(); - } - if let Cow::Owned(o) = UNPRINTABLE_TAGS.replace_all(&out, "") { - out = o.into(); - } - if let Cow::Owned(o) = strip_html_preserving_media_filenames(&out) { - out = o.into(); - } - out.trim() + PERSISTENT_HTML_SPACERS + .replace_all(html, " ") + .map_cow(|s| UNPRINTABLE_TAGS.replace_all(s, "")) + .map_cow(strip_html_preserving_media_filenames) + .trim() } pub fn strip_html(html: &str) -> Cow { - let mut out: Cow = html.into(); - - if let Cow::Owned(o) = strip_html_preserving_entities(html) { - out = o.into(); - } - - if let Cow::Owned(o) = decode_entities(out.as_ref()) { - out = o.into(); - } - - out + strip_html_preserving_entities(html).map_cow(decode_entities) } pub fn strip_html_preserving_entities(html: &str) -> Cow { @@ -161,18 +200,29 @@ pub fn decode_entities(html: &str) -> Cow { } } +pub(crate) fn newlines_to_spaces(text: &str) -> Cow { + if text.contains('\n') { + text.replace('\n', " ").into() + } else { + text.into() + } +} + pub fn strip_html_for_tts(html: &str) -> Cow { - let mut out: Cow = html.into(); + HTML_LINEBREAK_TAGS + .replace_all(html, " ") + .map_cow(strip_html) +} - if let Cow::Owned(o) = HTML_LINEBREAK_TAGS.replace_all(html, " ") { - out = o.into(); +/// Truncate a String on a valid UTF8 boundary. +pub(crate) fn truncate_to_char_boundary(s: &mut String, mut max: usize) { + if max >= s.len() { + return; } - - if let Cow::Owned(o) = strip_html(out.as_ref()) { - out = o.into(); + while !s.is_char_boundary(max) { + max -= 1; } - - out + s.truncate(max); } #[derive(Debug)] @@ -216,6 +266,61 @@ pub(crate) fn extract_media_refs(text: &str) -> Vec { out } +/// Calls `replacer` for every media reference in `text`, and optionally +/// replaces it with something else. [None] if no reference was found. 
+pub(crate) fn replace_media_refs( + text: &str, + mut replacer: impl FnMut(&str) -> Option, +) -> Option { + let mut rep = |caps: &Captures| { + let whole_match = caps.get(0).unwrap().as_str(); + let old_name = caps.iter().skip(1).find_map(|g| g).unwrap().as_str(); + let old_name_decoded = decode_entities(old_name); + + if let Some(mut new_name) = replacer(&old_name_decoded) { + if matches!(old_name_decoded, Cow::Owned(_)) { + new_name = htmlescape::encode_minimal(&new_name); + } + whole_match.replace(old_name, &new_name) + } else { + whole_match.to_owned() + } + }; + + HTML_MEDIA_TAGS + .replace_all(text, &mut rep) + .map_cow(|s| AV_TAGS.replace_all(s, &mut rep)) + .get_owned() +} + +pub(crate) fn extract_underscored_css_imports(text: &str) -> Vec<&str> { + UNDERSCORED_CSS_IMPORTS + .captures_iter(text) + .map(|caps| { + caps.get(1) + .or_else(|| caps.get(2)) + .or_else(|| caps.get(3)) + .or_else(|| caps.get(4)) + .or_else(|| caps.get(5)) + .unwrap() + .as_str() + }) + .collect() +} + +pub(crate) fn extract_underscored_references(text: &str) -> Vec<&str> { + UNDERSCORED_REFERENCES + .captures_iter(text) + .map(|caps| { + caps.get(1) + .or_else(|| caps.get(2)) + .or_else(|| caps.get(3)) + .unwrap() + .as_str() + }) + .collect() +} + pub fn strip_html_preserving_media_filenames(html: &str) -> Cow { let without_fnames = HTML_MEDIA_TAGS.replace_all(html, r" ${1}${2}${3} "); let without_html = strip_html(&without_fnames); @@ -463,4 +568,54 @@ mod test { assert!(!is_glob(r"\\\_")); assert!(glob_matcher(r"foo\*bar*")("foo*bar123")); } + + #[test] + fn extracting() { + assert_eq!( + extract_underscored_css_imports(concat!( + "@IMPORT '_foo.css'\n", + "@import \"_bar.css\"\n", + "@import '_baz.css'\n", + "@import 'nope.css'\n", + "url(_foo.css)\n", + "URL(\"_bar.css\")\n", + "@import url('_baz.css')\n", + "url('nope.css')\n", + )), + vec!["_foo.css", "_bar.css", "_baz.css", "_foo.css", "_bar.css", "_baz.css",] + ); + assert_eq!( + extract_underscored_references(concat!( + "", + "", + "\"_baz.js\"", + "\"nope.js\"", + "", + "", + "'_baz.js'", + )), + vec!["_foo.jpg", "_bar", "_baz.js", "_foo.jpg", "_bar", "_baz.js",] + ); + } + + #[test] + fn replacing() { + assert_eq!( + &replace_media_refs("[sound:bar.mp3]", |s| { + (s != "baz.jpg").then(|| "spam".to_string()) + }) + .unwrap(), + "[sound:spam]", + ); + } + + #[test] + fn truncate() { + let mut s = "日本語".to_string(); + truncate_to_char_boundary(&mut s, 6); + assert_eq!(&s, "日本"); + let mut s = "日本語".to_string(); + truncate_to_char_boundary(&mut s, 1); + assert_eq!(&s, ""); + } }
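
The `filename_is_safe` guard added to rslib/src/io.rs above is what keeps a crafted package from writing outside the media folder when media entries are extracted: a name is only accepted if it is a single "normal" path component. The following is a minimal standalone sketch of the same check; the `main` function and its sample names are illustrative only, not part of the patch.

// Standalone sketch of the single-component filename check used when
// extracting media from an untrusted archive.
use std::path::{Component, Path};

fn filename_is_safe(name: &str) -> bool {
    let mut components = Path::new(name).components();
    // the first component must be a plain name (not "..", "/", a drive, etc.)
    let first_element_normal = components
        .next()
        .map(|component| matches!(component, Component::Normal(_)))
        .unwrap_or_default();
    // and there must be no second component at all
    first_element_normal && components.next().is_none()
}

fn main() {
    assert!(filename_is_safe("foo.jpg"));
    assert!(!filename_is_safe("../foo.jpg"));
    assert!(!filename_is_safe("/etc/passwd"));
    assert!(!filename_is_safe("sub/dir.jpg"));
    println!("all path-traversal checks passed");
}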
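
The reworked `ids_to_string`/`write_comma_separated_ids` helpers in rslib/src/storage/mod.rs now accept any `IntoIterator` of `Display` values and preserve the caller's ordering, which is why the test expectations change from "(6,5,7)" to "(7,6,5)". A standalone sketch follows; the generic parameter lists did not survive extraction above, so `<D, I>` with `I: IntoIterator<Item = D>` is an assumption reconstructed from the where-clause that is still visible.

// Standalone sketch of the order-preserving ID joiners.
use std::fmt::Write;

fn write_comma_separated_ids<D, I>(buf: &mut String, ids: I)
where
    D: std::fmt::Display,
    I: IntoIterator<Item = D>,
{
    let mut trailing_sep = false;
    for id in ids {
        write!(buf, "{},", id).unwrap();
        trailing_sep = true;
    }
    if trailing_sep {
        // drop the trailing comma left by the loop above
        buf.pop();
    }
}

fn ids_to_string<D, I>(buf: &mut String, ids: I)
where
    D: std::fmt::Display,
    I: IntoIterator<Item = D>,
{
    buf.push('(');
    write_comma_separated_ids(buf, ids);
    buf.push(')');
}

fn main() {
    let mut s = String::new();
    ids_to_string(&mut s, [7, 6, 5]);
    assert_eq!(s, "(7,6,5)"); // input order is kept, matching the updated test
    let mut s = String::new();
    ids_to_string(&mut s, std::iter::empty::<i64>());
    assert_eq!(s, "()");
}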
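
The `CowMapping` trait added to rslib/src/text.rs lets `html_to_text_line`, `strip_html` and `replace_media_refs` chain several rewrite passes while only allocating when a pass actually changed the text. Below is a simplified, `str`-only sketch of the same idea; the example passes (`keep_as_is`, `underscore_spaces`) are made up for illustration and are not part of the patch.

// Simplified sketch of the Cow-chaining helper.
use std::borrow::Cow;

trait CowMapping {
    fn map_cow(self, f: impl FnOnce(&str) -> Cow<str>) -> Self;
    fn get_owned(self) -> Option<String>;
}

impl CowMapping for Cow<'_, str> {
    fn map_cow(self, f: impl FnOnce(&str) -> Cow<str>) -> Self {
        if let Cow::Owned(o) = f(&self) {
            Cow::Owned(o)
        } else {
            // the pass returned its input unchanged, so keep the original
            self
        }
    }

    fn get_owned(self) -> Option<String> {
        match self {
            Cow::Borrowed(_) => None,
            Cow::Owned(s) => Some(s),
        }
    }
}

// Passes are plain functions so the callback's lifetime elision works out.
fn keep_as_is(s: &str) -> Cow<str> {
    Cow::Borrowed(s)
}

fn underscore_spaces(s: &str) -> Cow<str> {
    if s.contains(' ') {
        Cow::Owned(s.replace(' ', "_"))
    } else {
        Cow::Borrowed(s)
    }
}

fn main() {
    let unchanged = Cow::Borrowed("no-spaces")
        .map_cow(keep_as_is)
        .map_cow(underscore_spaces);
    assert!(matches!(unchanged, Cow::Borrowed(_))); // no allocation happened
    assert_eq!(unchanged.get_owned(), None);

    let changed = Cow::Borrowed("a b c").map_cow(underscore_spaces);
    assert_eq!(changed.get_owned(), Some("a_b_c".to_string()));
}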
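
`truncate_to_char_boundary` in rslib/src/text.rs truncates a `String` in place without ever splitting a multi-byte character: it walks the requested byte limit back until it lands on a char boundary. A standalone sketch mirroring the patch's own `truncate` test:

// Standalone sketch of the in-place UTF-8-safe truncation helper.
fn truncate_to_char_boundary(s: &mut String, mut max: usize) {
    if max >= s.len() {
        return;
    }
    while !s.is_char_boundary(max) {
        max -= 1;
    }
    s.truncate(max);
}

fn main() {
    let mut s = "日本語".to_string(); // 9 bytes, 3 characters
    truncate_to_char_boundary(&mut s, 6);
    assert_eq!(s, "日本"); // 6 bytes is a boundary, so two characters remain
    let mut s = "日本語".to_string();
    truncate_to_char_boundary(&mut s, 1); // 1 byte falls inside the first character
    assert_eq!(s, "");
}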