Add apkg import/export on backend (#1743)

* Add apkg export on backend

* Filter out missing media-paths at write time

* Make TagMatcher::new() infallible

* Gather export data instead of copying directly

* Revert changes to rslib/src/tags/

* Reuse filename_is_safe/check_filename_safe()

* Accept func to produce MediaIter in export_apkg()

* Only store file folder once in MediaIter

* Use temporary tables for gathering

export_apkg() now accepts a search instead of a deck id. Decks are
gathered according to the matched notes' cards.

* Use schedule_as_new() to reset cards

* ExportData → ExchangeData

* Ignore ascii case when filtering system tags

* search_notes_cards_into_table → search_cards_of_notes_into_table

* Start on apkg importing on backend

* Fix due dates in days for apkg export

* Refactor import-export/package

- Move media and meta code into appropriate modules.
- Normalize/check for normalization when deserializing media entries.

* Add SafeMediaEntry for deserialized MediaEntries

* Prepare media based on checksums

- Ensure all existing media files are hashed.
- Hash incoming files during preparation to detect conflicts.
- Uniquify names of conflicting files with hash (not notetype id); see the sketch after this list.
- Mark media files as used while importing notes.
- Finally copy used media.
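
A hedged sketch of the conflict handling described above; the helper name, the
renaming scheme and the standalone `Sha1Hash` alias are illustrative only, not
the actual rslib code:

```rust
use std::collections::HashMap;

type Sha1Hash = [u8; 20];

fn hex(hash: &Sha1Hash) -> String {
    hash.iter().map(|b| format!("{b:02x}")).collect()
}

/// Decide which name an incoming media file is stored (and referenced) under,
/// given the checksums of the files already present in the target media folder.
fn resolve_name(existing: &HashMap<String, Sha1Hash>, name: &str, incoming: Sha1Hash) -> String {
    match existing.get(name) {
        // same name and same content: reuse the existing file
        Some(present) if *present == incoming => name.to_string(),
        // same name but different content: uniquify with the checksum
        Some(_) => format!("{}-{}", hex(&incoming), name),
        // name not taken yet: keep it
        None => name.to_string(),
    }
}

fn main() {
    let mut existing = HashMap::new();
    existing.insert("dog.jpg".to_string(), [0u8; 20]);
    // an incoming "dog.jpg" with a different checksum gets a uniquified name
    assert_ne!(resolve_name(&existing, "dog.jpg", [1u8; 20]), "dog.jpg");
}
```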

* Handle encoding in `replace_media_refs()`

* Add trait to keep down cow boilerplate

* Add notetypes immediately instead of preparing

* Move target_col into Context

* Add notes immediately instead of preparing

* Note id, not guid of conflicting notes

* Add import_decks()

* decks_configs → deck_configs

* Add import_deck_configs()

* Add import_cards(), import_revlog()

* Use dyn instead of generic for media_fn

Otherwise, callers would have to pass None with a type annotation in the
default case.
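
A minimal illustration of the inference issue (simplified signatures, not the
real export_apkg() ones):

```rust
// With a generic parameter, the compiler cannot infer `F` from a bare `None`,
// so the default case needs an explicit type annotation. A boxed trait object
// has a single concrete type, so `None` works unannotated.
fn export_generic<F>(media_fn: Option<F>)
where
    F: FnOnce(Vec<String>) -> Vec<String>,
{
    let _ = media_fn;
}

fn export_dyn(media_fn: Option<Box<dyn FnOnce(Vec<String>) -> Vec<String>>>) {
    let _ = media_fn;
}

fn main() {
    // export_generic(None); // error: type annotations needed
    export_generic(None::<fn(Vec<String>) -> Vec<String>>); // the awkward default case
    export_dyn(None); // just works
}
```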

* Fix signature of import_apkg()

* Fix search_cards_of_notes_into_table()

* Test new functions in text.rs

* Add roundtrip test for apkg (stub)

* Keep source id of imported cards (or skip)

* Keep source ids of imported revlog (or skip)

* Try to keep source ids of imported notes

* Make adding notetype with id undoable

* Wrap apkg import in transaction

* Keep source ids of imported deck configs (or skip)

* Handle card due dates and original due/did

* Fix importing cards/revlog

Card ids are manually uniquified.

* Factor out card importing

* Refactor card and revlog importing

* Factor out card importing

Also handle missing parents.

* Factor out note importing

* Factor out media importing

* Maybe upgrade scheduler of apkg

* Fix parent deck gathering

* Unconditionally import static media

* Fix deck importing edge cases

Test those edge cases, and add some global test helpers.

* Test note importing

* Let import_apkg() take a progress func

* Expand roundtrip apkg test

* Use fat pointer to avoid propagating generics

* Fix progress_fn type

* Expose apkg export/import on backend

* Return note log when importing apkg

* Fix archived collection name on apkg import

* Add CollectionOpWithBackendProgress

* Fix wrong Interrupted Exception being checked

* Add ClosedCollectionOp

* Add note ids to log and strip HTML

* Update progress when checking incoming media too

* Conditionally enable new importing in GUI

* Fix all_checksums() for media import

Entries of deleted files are nulled, not removed.

* Make apkg exporting on backend abortable

* Return number of notes imported from apkg

* Fix exception printing for QueryOp as well

* Add QueryOpWithBackendProgress

Also support backend exporting progress.

* Expose new apkg and colpkg exporting

* Open transaction in insert_data()

Was slowing down exporting by several orders of magnitude.
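
The reason is standard SQLite behaviour: outside an explicit transaction, every
INSERT is its own implicit transaction, so a bulk insert pays a commit per row.
A generic rusqlite sketch of the pattern (table and column names are made up,
this is not the rslib code):

```rust
use rusqlite::{params, Connection, Result};

// Bulk-insert rows inside a single transaction so SQLite commits once for the
// whole batch instead of once per statement.
fn insert_many(conn: &mut Connection, rows: &[(i64, String)]) -> Result<()> {
    let tx = conn.transaction()?;
    {
        let mut stmt = tx.prepare("INSERT INTO notes_tmp (id, flds) VALUES (?1, ?2)")?;
        for (id, fields) in rows {
            stmt.execute(params![id, fields])?;
        }
    } // statement dropped before commit, releasing its borrow of `tx`
    tx.commit()
}
```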

* Handle zstd-compressed apkg

* Add legacy arg to ExportAnkiPackage

Currently not exposed on the frontend

* Remove unused import in proto file

* Add symlink for typechecking of import_export_pb2

* Avoid kwargs in pb message creation, so typechecking is not lost

Protobuf's behaviour is rather subtle and I had to dig through the docs
to figure it out: set a field on a submessage to automatically assign 
the submessage to the parent, or call SetInParent() to persist a default
version of the field you specified.
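
A small Python sketch of that behaviour, using the ExportAnkiPackageRequest
message added in this PR (the field values are made up):

```python
from anki import import_export_pb2

request = import_export_pb2.ExportAnkiPackageRequest(out_path="/tmp/deck.apkg")

# Setting a field on a submessage implicitly attaches the submessage to its
# parent, so this selects the `note_ids` variant of the oneof:
request.note_ids.note_ids.extend([1650000000001, 1650000000002])
assert request.WhichOneof("selector") == "note_ids"

# An empty submessage only counts as "set" once SetInParent() persists a
# default instance of it, which is how the whole-collection variant is chosen:
request.whole_collection.SetInParent()
assert request.WhichOneof("selector") == "whole_collection"
```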

* Avoid re-exporting protobuf msgs we only use internally

* Stop after one test failure

mypy often fails much faster than pylint

* Avoid an extra allocation when extracting media checksums

* Update progress after prepare_media() finishes

Otherwise the bulk of the import ends up being shown as "Checked: 0"
in the progress window.

* Show progress of note imports

Note import is the slowest part, so showing progress here makes the UI
feel more responsive.

* Reset filtered decks at import time

Before this change, filtered decks exported with scheduling remained
filtered on import, and maybe_remove_from_filtered_deck() moved cards
into them as their home deck, leading to errors during review.

We may still want to provide a way to preserve filtered decks on import,
but to do that we'll need to ensure we don't rewrite the home decks of
cards, and we'll need to ensure the home decks are included as part of
the import (or give an error if they're not).

https://github.com/ankitects/anki/pull/1743/files#r839346423

* Fix a corner-case where due dates were shifted by a day

This issue existed in the old Python code as well. We need to include
the user's UTC offset in the exported file, or days_elapsed falls back
on the v1 cutoff calculation, which may be a day earlier or later than
the v2 calculation.
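
A toy illustration of how the offset can shift the count by one day (plain
midnight cutoffs and made-up timestamps; not the actual v1/v2 cutoff
calculations):

```rust
// Count whole days between two instants, with the day boundary taken at
// local midnight for the given UTC offset.
fn days_elapsed(now: i64, creation: i64, utc_offset_secs: i64) -> i64 {
    (now + utc_offset_secs) / 86_400 - (creation + utc_offset_secs) / 86_400
}

fn main() {
    let creation = 1_649_980_800; // a midnight, UTC
    let now = creation + 30 * 86_400 + 23 * 3_600; // 30 days later, 23:00 UTC

    assert_eq!(days_elapsed(now, creation, 0), 30); // cutoff at UTC midnight
    assert_eq!(days_elapsed(now, creation, 7_200), 31); // in UTC+2 the next local day has begun
}
```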

* Log conflicting note in remapped nt case

* take_fields() → into_fields()

* Alias `[u8; 20]` with `Sha1Hash`

* Truncate logged fields

* Rework apkg note import tests

- Use macros for more helpful errors.
- Split monolith into unit tests.
- Fix some unknown error with the previous test along the way.
(Was failing after 969484de4388d225c9f17d94534b3ba0094c3568.)

* Fix sorting of imported decks

Also adjust the test, so it fails without the patch. It was only passing
before, because the parent deck happened to come before the
inconsistently capitalised child alphabetically. But we want all parent
decks to be imported before their child decks, so their children can
adopt their capitalisation.
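
One way to guarantee that ordering is to sort the gathered decks by nesting
depth before inserting them. A sketch with made-up names; the actual patch may
differ in detail:

```rust
// Parents have fewer "::" separators than their children, so a stable sort by
// separator count inserts every parent before any of its children, regardless
// of how the child names are capitalised.
fn sort_parents_first(deck_names: &mut [String]) {
    deck_names.sort_by_key(|name| name.matches("::").count());
}

fn main() {
    let mut names = vec![
        "parent::child::grandchild".to_string(),
        "PARENT::child".to_string(),
        "Parent".to_string(),
    ];
    sort_parents_first(&mut names);
    assert_eq!(names, ["Parent", "PARENT::child", "parent::child::grandchild"]);
}
```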

* target[_id]s → existing_card[_id]s

* export_collection_extracting_media() → export_into_collection_file()

* target_already_exists → card_ordinal_already_exists

* Add search_cards_of_notes_into_table.sql

* Improve type of apkg export selector/limit

* Remove redundant call to mod_schema()

* Parent tooltips to mw

* Fix a crash when truncating note text

String::truncate() is a bit of a footgun, and I've hit this before
too :-)
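
The footgun: String::truncate() panics if the new length does not fall on a
UTF-8 character boundary. A safe variant backs up to the nearest boundary
first (a sketch, not necessarily the helper used in this patch):

```rust
fn truncate_to_char_boundary(s: &mut String, mut max: usize) {
    if max >= s.len() {
        return;
    }
    // 0 is always a char boundary, so this cannot underflow.
    while !s.is_char_boundary(max) {
        max -= 1;
    }
    s.truncate(max);
}

fn main() {
    let mut field = String::from("日本語");
    // field.truncate(4); // would panic: byte 4 is inside the second character
    truncate_to_char_boundary(&mut field, 4);
    assert_eq!(field, "日"); // backed up to the boundary at byte 3
}
```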

* Remove ExportLimit in favour of separate classes

* Remove OpWithBackendProgress and ClosedCollectionOp

Backend progress logic is now in ProgressManager. QueryOp can be used for
running on a closed collection.

Also fix aborting of colpkg exports, which slipped through in #1817.

* Tidy up import log

* Avoid QDialog.exec()

* Default to excluding scheduling for deck list deck

* Use IncrementalProgress in whole import_export code

* Compare checksums when importing colpkgs

* Avoid registering changes if hashes are not needed

* ImportProgress::Collection → ImportProgress::File

* Make downgrading apkgs depend on meta version

* Generalise IncrementableProgress

And use it in entire import_export code instead.

* Fix type complexity lint

* Take count_map for IncrementableProgress::get_inner

* Replace import/export env with Shift click

* Accept all args from update() for backend progress

* Pass fields of ProgressUpdate explicitly

* Move update_interval into IncrementableProgress

* Outsource incrementing into Incrementor

* Mutate ProgressUpdate in progress_update callback

* Switch import/export legacy toggle to profile setting

Shift would have been nice, but the existing shortcuts complicate things.
If the user triggers an import with ctrl+shift+i, shift is unlikely to
have been released by the time our code runs, meaning the user accidentally
triggers the new code. We could potentially wait a while before bringing
up the dialog, but then we're forced to guess at how long it will take the
user to release the key.

One alternative would be to use alt instead of shift, but then we need to
trigger our shortcut when that key is pressed as well, and it could
potentially cause a conflict with an add-on that already uses that
combination.

* Show extension in export dialog

* Continue to provide separate options for schema 11+18 colpkg export

* Default to colpkg export when using File>Export

* Improve appearance of combo boxes when switching between apkg/colpkg

+ Deal with long deck names

* Convert newlines to spaces when showing fields from import

Ensures each imported note appears on a separate line

* Don't separate total note count from the other summary lines

This may come down to personal preference, but I feel the other counts
are equally important, and separating them out makes them a bit easier
to ignore.

* Fix 'deck not normal' error when importing a filtered deck for the 2nd time

* Fix [Identical] being shown on first import

* Revert "Continue to provide separate options for schema 11+18 colpkg export"

This reverts commit 8f0b2c175f.

Will use a different approach

* Move legacy support into a separate exporter option; add to apkg export

* Adjust 'too new' message to also apply to .apkg import case

* Show a better message when attempting to import new apkg into old code

Previously the user could end up seeing a message like:

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xb5 in position 1: invalid start byte

Unfortunately we can't retroactively fix this for older clients.

* Hide legacy support option in older exporting screen

* Reflect change from paths to fnames in type & name

* Make imported decks normal at once

Then skip special casing in update_deck(). Also skip updating the
description if the new one is empty.

Co-authored-by: Damien Elmes <gpg@ankiweb.net>
RumovZ 2022-05-02 13:12:46 +02:00 committed by GitHub
parent d2ec004ac3
commit 5f9451f547
74 changed files with 3617 additions and 449 deletions


@ -20,6 +20,9 @@ test --aspects=@rules_rust//rust:defs.bzl%rust_clippy_aspect --output_groups=+cl
# print output when test fails
test --test_output=errors
# stop after one test failure
test --notest_keep_going
# don't add empty __init__.py files
build --incompatible_default_to_explicit_init_py


@ -5,7 +5,7 @@ exporting-anki-deck-package = Anki Deck Package
exporting-cards-in-plain-text = Cards in Plain Text
exporting-collection = collection
exporting-collection-exported = Collection exported.
exporting-colpkg-too-new = Please update to the latest Anki version, then import the .colpkg file again.
exporting-colpkg-too-new = Please update to the latest Anki version, then import the .colpkg/.apkg file again.
exporting-couldnt-save-file = Couldn't save file: { $val }
exporting-export = Export...
exporting-export-format = <b>Export format</b>:
@ -14,6 +14,7 @@ exporting-include-html-and-media-references = Include HTML and media references
exporting-include-media = Include media
exporting-include-scheduling-information = Include scheduling information
exporting-include-tags = Include tags
exporting-support-older-anki-versions = Support older Anki versions (slower/larger files)
exporting-notes-in-plain-text = Notes in Plain Text
exporting-selected-notes = Selected Notes
exporting-card-exported =


@ -78,4 +78,10 @@ importing-processed-media-file =
*[other] Imported { $count } media files
}
importing-importing-collection = Importing collection...
importing-importing-file = Importing file...
importing-failed-to-import-media-file = Failed to import media file: { $debugInfo }
importing-processed-notes =
{ $count ->
[one] Processed { $count } note...
*[other] Processed { $count } notes...
}


@ -5,6 +5,8 @@ syntax = "proto3";
package anki.import_export;
import "anki/collection.proto";
import "anki/notes.proto";
import "anki/generic.proto";
service ImportExportService {
@ -12,12 +14,16 @@ service ImportExportService {
returns (generic.Empty);
rpc ExportCollectionPackage(ExportCollectionPackageRequest)
returns (generic.Empty);
rpc ImportAnkiPackage(ImportAnkiPackageRequest)
returns (ImportAnkiPackageResponse);
rpc ExportAnkiPackage(ExportAnkiPackageRequest) returns (generic.UInt32);
}
message ImportCollectionPackageRequest {
string col_path = 1;
string backup_path = 2;
string media_folder = 3;
string media_db = 4;
}
message ExportCollectionPackageRequest {
@ -26,6 +32,37 @@ message ExportCollectionPackageRequest {
bool legacy = 3;
}
message ImportAnkiPackageRequest {
string package_path = 1;
}
message ImportAnkiPackageResponse {
message Note {
notes.NoteId id = 1;
repeated string fields = 2;
}
message Log {
repeated Note new = 1;
repeated Note updated = 2;
repeated Note duplicate = 3;
repeated Note conflicting = 4;
}
collection.OpChanges changes = 1;
Log log = 2;
}
message ExportAnkiPackageRequest {
string out_path = 1;
bool with_scheduling = 2;
bool with_media = 3;
bool legacy = 4;
oneof selector {
generic.Empty whole_collection = 5;
int64 deck_id = 6;
notes.NoteIds note_ids = 7;
}
}
message PackageMetadata {
enum Version {
VERSION_UNKNOWN = 0;


@ -22,6 +22,7 @@ ignored-classes=
CustomStudyRequest,
Cram,
ScheduleCardsAsNewRequest,
ExportAnkiPackageRequest,
[REPORTS]
output-format=colorized


@ -10,6 +10,7 @@ from anki import (
collection_pb2,
config_pb2,
generic_pb2,
import_export_pb2,
links_pb2,
search_pb2,
stats_pb2,
@ -32,6 +33,7 @@ OpChangesAfterUndo = collection_pb2.OpChangesAfterUndo
BrowserRow = search_pb2.BrowserRow
BrowserColumns = search_pb2.BrowserColumns
StripHtmlMode = card_rendering_pb2.StripHtmlRequest
ImportLogWithChanges = import_export_pb2.ImportAnkiPackageResponse
import copy
import os
@ -90,6 +92,19 @@ class LegacyCheckpoint:
LegacyUndoResult = Union[None, LegacyCheckpoint, LegacyReviewUndo]
@dataclass
class DeckIdLimit:
deck_id: DeckId
@dataclass
class NoteIdsLimit:
note_ids: Sequence[NoteId]
ExportLimit = Union[DeckIdLimit, NoteIdsLimit, None]
class Collection(DeprecatedNamesMixin):
sched: V1Scheduler | V2Scheduler | V3Scheduler
@ -259,14 +274,6 @@ class Collection(DeprecatedNamesMixin):
self._clear_caches()
self.db = None
def export_collection(
self, out_path: str, include_media: bool, legacy: bool
) -> None:
self.close_for_full_sync()
self._backend.export_collection_package(
out_path=out_path, include_media=include_media, legacy=legacy
)
def rollback(self) -> None:
self._clear_caches()
self.db.rollback()
@ -321,6 +328,15 @@ class Collection(DeprecatedNamesMixin):
else:
return -1
def legacy_checkpoint_pending(self) -> bool:
return (
self._have_outstanding_checkpoint()
and time.time() - self._last_checkpoint_at < 300
)
# Import/export
##########################################################################
def create_backup(
self,
*,
@ -353,12 +369,40 @@ class Collection(DeprecatedNamesMixin):
"Throws if backup creation failed."
self._backend.await_backup_completion()
def legacy_checkpoint_pending(self) -> bool:
return (
self._have_outstanding_checkpoint()
and time.time() - self._last_checkpoint_at < 300
def export_collection_package(
self, out_path: str, include_media: bool, legacy: bool
) -> None:
self.close_for_full_sync()
self._backend.export_collection_package(
out_path=out_path, include_media=include_media, legacy=legacy
)
def import_anki_package(self, path: str) -> ImportLogWithChanges:
return self._backend.import_anki_package(package_path=path)
def export_anki_package(
self,
*,
out_path: str,
limit: ExportLimit,
with_scheduling: bool,
with_media: bool,
legacy_support: bool,
) -> int:
request = import_export_pb2.ExportAnkiPackageRequest(
out_path=out_path,
with_scheduling=with_scheduling,
with_media=with_media,
legacy=legacy_support,
)
if isinstance(limit, DeckIdLimit):
request.deck_id = limit.deck_id
elif isinstance(limit, NoteIdsLimit):
request.note_ids.note_ids.extend(limit.note_ids)
else:
request.whole_collection.SetInParent()
return self._backend.export_anki_package(request)
# Object helpers
##########################################################################


@ -446,7 +446,7 @@ class AnkiCollectionPackageExporter(AnkiPackageExporter):
time.sleep(0.1)
threading.Thread(target=progress).start()
self.col.export_collection(path, self.includeMedia, self.LEGACY)
self.col.export_collection_package(path, self.includeMedia, self.LEGACY)
class AnkiCollectionPackage21bExporter(AnkiCollectionPackageExporter):


@ -0,0 +1 @@
../../.bazel/bin/pylib/anki/import_export_pb2.pyi


@ -25,6 +25,10 @@ class V2ImportIntoV1(Exception):
pass
class MediaMapInvalid(Exception):
pass
class Anki2Importer(Importer):
needMapper = False


@ -9,7 +9,7 @@ import unicodedata
import zipfile
from typing import Any, Optional
from anki.importing.anki2 import Anki2Importer
from anki.importing.anki2 import Anki2Importer, MediaMapInvalid
from anki.utils import tmpfile
@ -36,7 +36,11 @@ class AnkiPackageImporter(Anki2Importer):
# number to use during the import
self.nameToNum = {}
dir = self.col.media.dir()
for k, v in list(json.loads(z.read("media").decode("utf8")).items()):
try:
media_dict = json.loads(z.read("media").decode("utf8"))
except Exception as exc:
raise MediaMapInvalid() from exc
for k, v in list(media_dict.items()):
path = os.path.abspath(os.path.join(dir, v))
if os.path.commonprefix([path, dir]) != dir:
raise Exception("Invalid file")


@ -23,7 +23,8 @@ from anki.tags import MARKED_TAG
from anki.utils import is_mac
from aqt import AnkiQt, gui_hooks
from aqt.editor import Editor
from aqt.exporting import ExportDialog
from aqt.exporting import ExportDialog as LegacyExportDialog
from aqt.import_export.exporting import ExportDialog
from aqt.operations.card import set_card_deck, set_card_flag
from aqt.operations.collection import redo, undo
from aqt.operations.note import remove_notes
@ -792,8 +793,12 @@ class Browser(QMainWindow):
@no_arg_trigger
@skip_if_selection_is_empty
def _on_export_notes(self) -> None:
cids = self.selectedNotesAsCards()
ExportDialog(self.mw, cids=list(cids))
if self.mw.pm.new_import_export():
nids = self.selected_notes()
ExportDialog(self.mw, nids=nids)
else:
cids = self.selectedNotesAsCards()
LegacyExportDialog(self.mw, cids=list(cids))
# Flags & Marking
######################################################################


@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Optional, TextIO, cast
from markdown import markdown
import aqt
from anki.errors import DocumentedError, LocalizedError
from anki.errors import DocumentedError, Interrupted, LocalizedError
from aqt.qt import *
from aqt.utils import showText, showWarning, supportText, tr
@ -22,7 +22,7 @@ if TYPE_CHECKING:
def show_exception(*, parent: QWidget, exception: Exception) -> None:
"Present a caught exception to the user using a pop-up."
if isinstance(exception, InterruptedError):
if isinstance(exception, Interrupted):
# nothing to do
return
help_page = exception.help_page if isinstance(exception, DocumentedError) else None


@ -40,6 +40,7 @@ class ExportDialog(QDialog):
self.col = mw.col.weakref()
self.frm = aqt.forms.exporting.Ui_ExportDialog()
self.frm.setupUi(self)
self.frm.legacy_support.setVisible(False)
self.exporter: Exporter | None = None
self.cids = cids
disable_help_button(self)


@ -6,8 +6,8 @@
<rect>
<x>0</x>
<y>0</y>
<width>295</width>
<height>223</height>
<width>563</width>
<height>245</height>
</rect>
</property>
<property name="windowTitle">
@ -30,7 +30,14 @@
</widget>
</item>
<item row="0" column="1">
<widget class="QComboBox" name="format"/>
<widget class="QComboBox" name="format">
<property name="sizePolicy">
<sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
</widget>
</item>
<item row="1" column="0">
<widget class="QLabel" name="label_2">
@ -40,7 +47,11 @@
</widget>
</item>
<item row="1" column="1">
<widget class="QComboBox" name="deck"/>
<widget class="QComboBox" name="deck">
<property name="minimumContentsLength">
<number>50</number>
</property>
</widget>
</item>
</layout>
</item>
@ -83,6 +94,16 @@
</property>
</widget>
</item>
<item>
<widget class="QCheckBox" name="legacy_support">
<property name="text">
<string>exporting_support_older_anki_versions</string>
</property>
<property name="checked">
<bool>true</bool>
</property>
</widget>
</item>
</layout>
</item>
<item>


@ -0,0 +1,252 @@
# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
from __future__ import annotations
import os
import re
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Sequence, Type
import aqt.forms
import aqt.main
from anki.collection import DeckIdLimit, ExportLimit, NoteIdsLimit, Progress
from anki.decks import DeckId, DeckNameId
from anki.notes import NoteId
from aqt import gui_hooks
from aqt.errors import show_exception
from aqt.operations import QueryOp
from aqt.progress import ProgressUpdate
from aqt.qt import *
from aqt.utils import (
checkInvalidFilename,
disable_help_button,
getSaveFile,
showWarning,
tooltip,
tr,
)
class ExportDialog(QDialog):
def __init__(
self,
mw: aqt.main.AnkiQt,
did: DeckId | None = None,
nids: Sequence[NoteId] | None = None,
):
QDialog.__init__(self, mw, Qt.WindowType.Window)
self.mw = mw
self.col = mw.col.weakref()
self.frm = aqt.forms.exporting.Ui_ExportDialog()
self.frm.setupUi(self)
self.exporter: Type[Exporter] = None
self.nids = nids
disable_help_button(self)
self.setup(did)
self.open()
def setup(self, did: DeckId | None) -> None:
self.exporters: list[Type[Exporter]] = [ApkgExporter, ColpkgExporter]
self.frm.format.insertItems(
0, [f"{e.name()} (.{e.extension})" for e in self.exporters]
)
qconnect(self.frm.format.activated, self.exporter_changed)
if self.nids is None and not did:
# file>export defaults to colpkg
default_exporter_idx = 1
else:
default_exporter_idx = 0
self.frm.format.setCurrentIndex(default_exporter_idx)
self.exporter_changed(default_exporter_idx)
# deck list
if self.nids is None:
self.all_decks = self.col.decks.all_names_and_ids()
decks = [tr.exporting_all_decks()]
decks.extend(d.name for d in self.all_decks)
else:
decks = [tr.exporting_selected_notes()]
self.frm.deck.addItems(decks)
# save button
b = QPushButton(tr.exporting_export())
self.frm.buttonBox.addButton(b, QDialogButtonBox.ButtonRole.AcceptRole)
# set default option if accessed through deck button
if did:
name = self.mw.col.decks.get(did)["name"]
index = self.frm.deck.findText(name)
self.frm.deck.setCurrentIndex(index)
self.frm.includeSched.setChecked(False)
def exporter_changed(self, idx: int) -> None:
self.exporter = self.exporters[idx]
self.frm.includeSched.setVisible(self.exporter.show_include_scheduling)
self.frm.includeMedia.setVisible(self.exporter.show_include_media)
self.frm.includeTags.setVisible(self.exporter.show_include_tags)
self.frm.includeHTML.setVisible(self.exporter.show_include_html)
self.frm.legacy_support.setVisible(self.exporter.show_legacy_support)
self.frm.deck.setVisible(self.exporter.show_deck_list)
def accept(self) -> None:
if not (out_path := self.get_out_path()):
return
self.exporter.export(self.mw, self.options(out_path))
QDialog.reject(self)
def get_out_path(self) -> str | None:
filename = self.filename()
while True:
path = getSaveFile(
parent=self,
title=tr.actions_export(),
dir_description="export",
key=self.exporter.name(),
ext=self.exporter.extension,
fname=filename,
)
if not path:
return None
if checkInvalidFilename(os.path.basename(path), dirsep=False):
continue
path = os.path.normpath(path)
if os.path.commonprefix([self.mw.pm.base, path]) == self.mw.pm.base:
showWarning("Please choose a different export location.")
continue
break
return path
def options(self, out_path: str) -> Options:
limit: ExportLimit = None
if self.nids:
limit = NoteIdsLimit(self.nids)
elif current_deck_id := self.current_deck_id():
limit = DeckIdLimit(current_deck_id)
return Options(
out_path=out_path,
include_scheduling=self.frm.includeSched.isChecked(),
include_media=self.frm.includeMedia.isChecked(),
include_tags=self.frm.includeTags.isChecked(),
include_html=self.frm.includeHTML.isChecked(),
legacy_support=self.frm.legacy_support.isChecked(),
limit=limit,
)
def current_deck_id(self) -> DeckId | None:
return (deck := self.current_deck()) and DeckId(deck.id) or None
def current_deck(self) -> DeckNameId | None:
if self.exporter.show_deck_list:
if idx := self.frm.deck.currentIndex():
return self.all_decks[idx - 1]
return None
def filename(self) -> str:
if self.exporter.show_deck_list:
deck_name = self.frm.deck.currentText()
stem = re.sub('[\\\\/?<>:*|"^]', "_", deck_name)
else:
time_str = time.strftime("%Y-%m-%d@%H-%M-%S", time.localtime(time.time()))
stem = f"{tr.exporting_collection()}-{time_str}"
return f"{stem}.{self.exporter.extension}"
@dataclass
class Options:
out_path: str
include_scheduling: bool
include_media: bool
include_tags: bool
include_html: bool
legacy_support: bool
limit: ExportLimit
class Exporter(ABC):
extension: str
show_deck_list = False
show_include_scheduling = False
show_include_media = False
show_include_tags = False
show_include_html = False
show_legacy_support = False
@staticmethod
@abstractmethod
def export(mw: aqt.main.AnkiQt, options: Options) -> None:
pass
@staticmethod
@abstractmethod
def name() -> str:
pass
class ColpkgExporter(Exporter):
extension = "colpkg"
show_include_media = True
show_legacy_support = True
@staticmethod
def name() -> str:
return tr.exporting_anki_collection_package()
@staticmethod
def export(mw: aqt.main.AnkiQt, options: Options) -> None:
def on_success(_: None) -> None:
mw.reopen()
tooltip(tr.exporting_collection_exported(), parent=mw)
def on_failure(exception: Exception) -> None:
mw.reopen()
show_exception(parent=mw, exception=exception)
gui_hooks.collection_will_temporarily_close(mw.col)
QueryOp(
parent=mw,
op=lambda col: col.export_collection_package(
options.out_path,
include_media=options.include_media,
legacy=options.legacy_support,
),
success=on_success,
).with_backend_progress(export_progress_update).failure(
on_failure
).run_in_background()
class ApkgExporter(Exporter):
extension = "apkg"
show_deck_list = True
show_include_scheduling = True
show_include_media = True
show_legacy_support = True
@staticmethod
def name() -> str:
return tr.exporting_anki_deck_package()
@staticmethod
def export(mw: aqt.main.AnkiQt, options: Options) -> None:
QueryOp(
parent=mw,
op=lambda col: col.export_anki_package(
out_path=options.out_path,
limit=options.limit,
with_scheduling=options.include_scheduling,
with_media=options.include_media,
legacy_support=options.legacy_support,
),
success=lambda count: tooltip(
tr.exporting_note_exported(count=count), parent=mw
),
).with_backend_progress(export_progress_update).run_in_background()
def export_progress_update(progress: Progress, update: ProgressUpdate) -> None:
if not progress.HasField("exporting"):
return None
update.label = tr.exporting_exported_media_file(count=progress.exporting)
if update.user_wants_abort:
update.abort = True


@ -0,0 +1,150 @@
# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
from __future__ import annotations
from itertools import chain
import aqt.main
from anki.collection import Collection, ImportLogWithChanges, Progress
from anki.errors import Interrupted
from aqt.operations import CollectionOp, QueryOp
from aqt.progress import ProgressUpdate
from aqt.qt import *
from aqt.utils import askUser, getFile, showInfo, showText, showWarning, tooltip, tr
def import_file(mw: aqt.main.AnkiQt) -> None:
if not (path := get_file_path(mw)):
return
filename = os.path.basename(path).lower()
if filename.endswith(".anki"):
showInfo(tr.importing_anki_files_are_from_a_very())
elif filename.endswith(".anki2"):
showInfo(tr.importing_anki2_files_are_not_directly_importable())
elif is_collection_package(filename):
maybe_import_collection_package(mw, path)
elif filename.endswith(".apkg"):
import_anki_package(mw, path)
else:
raise NotImplementedError
def get_file_path(mw: aqt.main.AnkiQt) -> str | None:
if file := getFile(
mw,
tr.actions_import(),
None,
key="import",
filter=tr.importing_packaged_anki_deckcollection_apkg_colpkg_zip(),
):
return str(file)
return None
def is_collection_package(filename: str) -> bool:
return (
filename == "collection.apkg"
or (filename.startswith("backup-") and filename.endswith(".apkg"))
or filename.endswith(".colpkg")
)
def maybe_import_collection_package(mw: aqt.main.AnkiQt, path: str) -> None:
if askUser(
tr.importing_this_will_delete_your_existing_collection(),
msgfunc=QMessageBox.warning,
defaultno=True,
):
import_collection_package(mw, path)
def import_collection_package(mw: aqt.main.AnkiQt, file: str) -> None:
def on_success() -> None:
mw.loadCollection()
tooltip(tr.importing_importing_complete())
def on_failure(err: Exception) -> None:
mw.loadCollection()
if not isinstance(err, Interrupted):
showWarning(str(err))
QueryOp(
parent=mw,
op=lambda _: mw.create_backup_now(),
success=lambda _: mw.unloadCollection(
lambda: import_collection_package_op(mw, file, on_success)
.failure(on_failure)
.run_in_background()
),
).with_progress().run_in_background()
def import_collection_package_op(
mw: aqt.main.AnkiQt, path: str, success: Callable[[], None]
) -> QueryOp[None]:
def op(_: Collection) -> None:
col_path = mw.pm.collectionPath()
media_folder = os.path.join(mw.pm.profileFolder(), "collection.media")
media_db = os.path.join(mw.pm.profileFolder(), "collection.media.db2")
mw.backend.import_collection_package(
col_path=col_path,
backup_path=path,
media_folder=media_folder,
media_db=media_db,
)
return QueryOp(parent=mw, op=op, success=lambda _: success()).with_backend_progress(
import_progress_update
)
def import_anki_package(mw: aqt.main.AnkiQt, path: str) -> None:
CollectionOp(
parent=mw,
op=lambda col: col.import_anki_package(path),
).with_backend_progress(import_progress_update).success(
show_import_log
).run_in_background()
def show_import_log(log_with_changes: ImportLogWithChanges) -> None:
showText(stringify_log(log_with_changes.log), plain_text_edit=True)
def stringify_log(log: ImportLogWithChanges.Log) -> str:
total = len(log.conflicting) + len(log.updated) + len(log.new) + len(log.duplicate)
return "\n".join(
chain(
(tr.importing_notes_found_in_file(val=total),),
(
template_string(val=len(row))
for (row, template_string) in (
(log.conflicting, tr.importing_notes_that_could_not_be_imported),
(log.updated, tr.importing_notes_updated_as_file_had_newer),
(log.new, tr.importing_notes_added_from_file),
(log.duplicate, tr.importing_notes_skipped_as_theyre_already_in),
)
if row
),
("",),
*(
[f"[{action}] {', '.join(note.fields)}" for note in rows]
for (rows, action) in (
(log.conflicting, tr.importing_skipped()),
(log.updated, tr.importing_updated()),
(log.new, tr.adding_added()),
(log.duplicate, tr.importing_identical()),
)
),
)
)
def import_progress_update(progress: Progress, update: ProgressUpdate) -> None:
if not progress.HasField("importing"):
return
update.label = progress.importing
if update.user_wants_abort:
update.abort = True


@ -11,11 +11,10 @@ import anki.importing as importing
import aqt.deckchooser
import aqt.forms
import aqt.modelchooser
from anki.errors import Interrupted
from anki.importing.anki2 import V2ImportIntoV1
from anki.importing.anki2 import MediaMapInvalid, V2ImportIntoV1
from anki.importing.apkg import AnkiPackageImporter
from aqt.import_export.importing import import_collection_package
from aqt.main import AnkiQt, gui_hooks
from aqt.operations import QueryOp
from aqt.qt import *
from aqt.utils import (
HelpPage,
@ -389,6 +388,10 @@ def importFile(mw: AnkiQt, file: str) -> None:
future.result()
except zipfile.BadZipfile:
showWarning(invalidZipMsg())
except MediaMapInvalid:
showWarning(
"Unable to read file. It probably requires a newer version of Anki to import."
)
except V2ImportIntoV1:
showWarning(
"""\
@ -434,78 +437,11 @@ def setupApkgImport(mw: AnkiQt, importer: AnkiPackageImporter) -> bool:
if not full:
# adding
return True
if not askUser(
if askUser(
tr.importing_this_will_delete_your_existing_collection(),
msgfunc=QMessageBox.warning,
defaultno=True,
):
return False
import_collection_package(mw, importer.file)
full_apkg_import(mw, importer.file)
return False
def full_apkg_import(mw: AnkiQt, file: str) -> None:
def on_done(success: bool) -> None:
mw.loadCollection()
if success:
tooltip(tr.importing_importing_complete())
def after_backup(created: bool) -> None:
mw.unloadCollection(lambda: replace_with_apkg(mw, file, on_done))
QueryOp(
parent=mw, op=lambda _: mw.create_backup_now(), success=after_backup
).with_progress().run_in_background()
def replace_with_apkg(
mw: AnkiQt, filename: str, callback: Callable[[bool], None]
) -> None:
"""Tries to replace the provided collection with the provided backup,
then calls the callback. True if success.
"""
if not (dialog := mw.progress.start(immediate=True)):
print("No progress dialog during import; aborting will not work")
timer = QTimer()
timer.setSingleShot(False)
timer.setInterval(100)
def on_progress() -> None:
progress = mw.backend.latest_progress()
if not progress.HasField("importing"):
return
label = progress.importing
try:
if dialog.wantCancel:
mw.backend.set_wants_abort()
except AttributeError:
# dialog may not be active
pass
mw.taskman.run_on_main(lambda: mw.progress.update(label=label))
def do_import() -> None:
col_path = mw.pm.collectionPath()
media_folder = os.path.join(mw.pm.profileFolder(), "collection.media")
mw.backend.import_collection_package(
col_path=col_path, backup_path=filename, media_folder=media_folder
)
def on_done(future: Future) -> None:
mw.progress.finish()
timer.deleteLater()
try:
future.result()
except Exception as error:
if not isinstance(error, Interrupted):
showWarning(str(error))
callback(False)
else:
callback(True)
qconnect(timer.timeout, on_progress)
timer.start()
mw.taskman.run_in_background(do_import, on_done)


@ -47,6 +47,8 @@ from aqt.addons import DownloadLogEntry, check_and_prompt_for_updates, show_log_
from aqt.dbcheck import check_db
from aqt.emptycards import show_empty_cards
from aqt.flags import FlagManager
from aqt.import_export.exporting import ExportDialog
from aqt.import_export.importing import import_collection_package_op, import_file
from aqt.legacy import install_pylib_legacy
from aqt.mediacheck import check_media_db
from aqt.mediasync import MediaSyncer
@ -402,15 +404,12 @@ class AnkiQt(QMainWindow):
)
def _openBackup(self, path: str) -> None:
def on_done(success: bool) -> None:
if success:
self.onOpenProfile(callback=lambda: self.col.mod_schema(check=False))
import aqt.importing
self.restoring_backup = True
showInfo(tr.qt_misc_automatic_syncing_and_backups_have_been())
aqt.importing.replace_with_apkg(self, path, on_done)
import_collection_package_op(
self, path, success=self.onOpenProfile
).run_in_background()
def _on_downgrade(self) -> None:
self.progress.start()
@ -1183,12 +1182,18 @@ title="{}" {}>{}</button>""".format(
def onImport(self) -> None:
import aqt.importing
aqt.importing.onImport(self)
if self.pm.new_import_export():
import_file(self)
else:
aqt.importing.onImport(self)
def onExport(self, did: DeckId | None = None) -> None:
import aqt.exporting
aqt.exporting.ExportDialog(self, did=did)
if self.pm.new_import_export():
ExportDialog(self, did=did)
else:
aqt.exporting.ExportDialog(self, did=did)
# Installing add-ons from CLI / mimetype handler
##########################################################################


@ -8,16 +8,19 @@ from typing import Any, Callable, Generic, Protocol, TypeVar, Union
import aqt
import aqt.gui_hooks
import aqt.main
from anki.collection import (
Collection,
ImportLogWithChanges,
OpChanges,
OpChangesAfterUndo,
OpChangesWithCount,
OpChangesWithId,
Progress,
)
from aqt.errors import show_exception
from aqt.progress import ProgressUpdate
from aqt.qt import QWidget
from aqt.utils import showWarning
class HasChangesProperty(Protocol):
@ -34,6 +37,7 @@ ResultWithChanges = TypeVar(
OpChangesWithCount,
OpChangesWithId,
OpChangesAfterUndo,
ImportLogWithChanges,
HasChangesProperty,
],
)
@ -65,6 +69,7 @@ class CollectionOp(Generic[ResultWithChanges]):
_success: Callable[[ResultWithChanges], Any] | None = None
_failure: Callable[[Exception], Any] | None = None
_progress_update: Callable[[Progress, ProgressUpdate], None] | None = None
def __init__(self, parent: QWidget, op: Callable[[Collection], ResultWithChanges]):
self._parent = parent
@ -82,6 +87,12 @@ class CollectionOp(Generic[ResultWithChanges]):
self._failure = failure
return self
def with_backend_progress(
self, progress_update: Callable[[Progress, ProgressUpdate], None] | None
) -> CollectionOp[ResultWithChanges]:
self._progress_update = progress_update
return self
def run_in_background(self, *, initiator: object | None = None) -> None:
from aqt import mw
@ -113,12 +124,29 @@ class CollectionOp(Generic[ResultWithChanges]):
if self._success:
self._success(result)
finally:
mw.update_undo_actions()
mw.autosave()
# fire change hooks
self._fire_change_hooks_after_op_performed(result, initiator)
self._finish_op(mw, result, initiator)
mw.taskman.with_progress(wrapped_op, wrapped_done)
self._run(mw, wrapped_op, wrapped_done)
def _run(
self,
mw: aqt.main.AnkiQt,
op: Callable[[], ResultWithChanges],
on_done: Callable[[Future], None],
) -> None:
if self._progress_update:
mw.taskman.with_backend_progress(
op, self._progress_update, on_done=on_done, parent=self._parent
)
else:
mw.taskman.with_progress(op, on_done, parent=self._parent)
def _finish_op(
self, mw: aqt.main.AnkiQt, result: ResultWithChanges, initiator: object | None
) -> None:
mw.update_undo_actions()
mw.autosave()
self._fire_change_hooks_after_op_performed(result, initiator)
def _fire_change_hooks_after_op_performed(
self,
@ -168,6 +196,7 @@ class QueryOp(Generic[T]):
_failure: Callable[[Exception], Any] | None = None
_progress: bool | str = False
_progress_update: Callable[[Progress, ProgressUpdate], None] | None = None
def __init__(
self,
@ -192,6 +221,12 @@ class QueryOp(Generic[T]):
self._progress = label or True
return self
def with_backend_progress(
self, progress_update: Callable[[Progress, ProgressUpdate], None] | None
) -> QueryOp[T]:
self._progress_update = progress_update
return self
def run_in_background(self) -> None:
from aqt import mw
@ -201,26 +236,11 @@ class QueryOp(Generic[T]):
def wrapped_op() -> T:
assert mw
if self._progress:
label: str | None
if isinstance(self._progress, str):
label = self._progress
else:
label = None
def start_progress() -> None:
assert mw
mw.progress.start(label=label)
mw.taskman.run_on_main(start_progress)
return self._op(mw.col)
def wrapped_done(future: Future) -> None:
assert mw
if self._progress:
mw.progress.finish()
mw._decrease_background_ops()
# did something go wrong?
if exception := future.exception():
@ -228,7 +248,7 @@ class QueryOp(Generic[T]):
if self._failure:
self._failure(exception)
else:
showWarning(str(exception), self._parent)
show_exception(parent=self._parent, exception=exception)
return
else:
# BaseException like SystemExit; rethrow it
@ -236,4 +256,24 @@ class QueryOp(Generic[T]):
self._success(future.result())
mw.taskman.run_in_background(wrapped_op, wrapped_done)
self._run(mw, wrapped_op, wrapped_done)
def _run(
self,
mw: aqt.main.AnkiQt,
op: Callable[[], T],
on_done: Callable[[Future], None],
) -> None:
label = self._progress if isinstance(self._progress, str) else None
if self._progress_update:
mw.taskman.with_backend_progress(
op,
self._progress_update,
on_done=on_done,
start_label=label,
parent=self._parent,
)
elif self._progress:
mw.taskman.with_progress(op, on_done, label=label, parent=self._parent)
else:
mw.taskman.run_in_background(op, on_done)


@ -538,6 +538,12 @@ create table if not exists profiles
def dark_mode_widgets(self) -> bool:
return self.meta.get("dark_mode_widgets", False)
def new_import_export(self) -> bool:
return self.meta.get("new_import_export", False)
def set_new_import_export(self, enabled: bool) -> None:
self.meta["new_import_export"] = enabled
# Profile-specific
######################################################################


@ -3,9 +3,11 @@
from __future__ import annotations
import time
from dataclasses import dataclass
import aqt.forms
from anki._legacy import print_deprecation_warning
from anki.collection import Progress
from aqt.qt import *
from aqt.utils import disable_help_button, tr
@ -23,6 +25,7 @@ class ProgressManager:
self._busy_cursor_timer: QTimer | None = None
self._win: ProgressDialog | None = None
self._levels = 0
self._backend_timer: QTimer | None = None
# Safer timers
##########################################################################
@ -166,6 +169,34 @@ class ProgressManager:
qconnect(self._show_timer.timeout, self._on_show_timer)
return self._win
def start_with_backend_updates(
self,
progress_update: Callable[[Progress, ProgressUpdate], None],
start_label: str | None = None,
parent: QWidget | None = None,
) -> None:
self._backend_timer = QTimer()
self._backend_timer.setSingleShot(False)
self._backend_timer.setInterval(100)
if not (dialog := self.start(immediate=True, label=start_label, parent=parent)):
print("Progress dialog already running; aborting will not work")
def on_progress() -> None:
assert self.mw
user_wants_abort = dialog and dialog.wantCancel or False
update = ProgressUpdate(user_wants_abort=user_wants_abort)
progress = self.mw.backend.latest_progress()
progress_update(progress, update)
if update.abort:
self.mw.backend.set_wants_abort()
if update.has_update():
self.update(label=update.label, value=update.value, max=update.max)
qconnect(self._backend_timer.timeout, on_progress)
self._backend_timer.start()
def update(
self,
label: str | None = None,
@ -204,6 +235,9 @@ class ProgressManager:
if self._show_timer:
self._show_timer.stop()
self._show_timer = None
if self._backend_timer:
self._backend_timer.deleteLater()
self._backend_timer = None
def clear(self) -> None:
"Restore the interface after an error."
@ -295,3 +329,15 @@ class ProgressDialog(QDialog):
if evt.key() == Qt.Key.Key_Escape:
evt.ignore()
self.wantCancel = True
@dataclass
class ProgressUpdate:
label: str | None = None
value: int | None = None
max: int | None = None
user_wants_abort: bool = False
abort: bool = False
def has_update(self) -> bool:
return self.label is not None or self.value is not None or self.max is not None


@ -15,6 +15,8 @@ from threading import Lock
from typing import Any, Callable
import aqt
from anki.collection import Progress
from aqt.progress import ProgressUpdate
from aqt.qt import *
Closure = Callable[[], None]
@ -89,6 +91,27 @@ class TaskManager(QObject):
self.run_in_background(task, wrapped_done)
def with_backend_progress(
self,
task: Callable,
progress_update: Callable[[Progress, ProgressUpdate], None],
on_done: Callable[[Future], None] | None = None,
parent: QWidget | None = None,
start_label: str | None = None,
) -> None:
self.mw.progress.start_with_backend_updates(
progress_update,
parent=parent,
start_label=start_label,
)
def wrapped_done(fut: Future) -> None:
self.mw.progress.finish()
if on_done:
on_done(fut)
self.run_in_background(task, wrapped_done)
def _on_closures_pending(self) -> None:
"""Run any pending closures. This runs in the main thread."""
with self._closures_lock:


@ -69,6 +69,12 @@ impl From<pb::NoteId> for NoteId {
}
}
impl From<NoteId> for pb::NoteId {
fn from(nid: NoteId) -> Self {
pb::NoteId { nid: nid.0 }
}
}
impl From<pb::NotetypeId> for NotetypeId {
fn from(ntid: pb::NotetypeId) -> Self {
NotetypeId(ntid.ntid)


@ -1,12 +1,18 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::path::Path;
use super::{progress::Progress, Backend};
pub(super) use crate::backend_proto::importexport_service::Service as ImportExportService;
use crate::{
backend_proto::{self as pb},
import_export::{package::import_colpkg, ImportProgress},
backend_proto::{self as pb, export_anki_package_request::Selector},
import_export::{
package::{import_colpkg, NoteLog},
ImportProgress,
},
prelude::*,
search::SearchNode,
};
impl ImportExportService for Backend {
@ -38,30 +44,68 @@ impl ImportExportService for Backend {
import_colpkg(
&input.backup_path,
&input.col_path,
&input.media_folder,
Path::new(&input.media_folder),
Path::new(&input.media_db),
self.import_progress_fn(),
&self.log,
)
.map(Into::into)
}
fn import_anki_package(
&self,
input: pb::ImportAnkiPackageRequest,
) -> Result<pb::ImportAnkiPackageResponse> {
self.with_col(|col| col.import_apkg(&input.package_path, self.import_progress_fn()))
.map(Into::into)
}
fn export_anki_package(&self, input: pb::ExportAnkiPackageRequest) -> Result<pb::UInt32> {
let selector = input
.selector
.ok_or_else(|| AnkiError::invalid_input("missing oneof"))?;
self.with_col(|col| {
col.export_apkg(
&input.out_path,
SearchNode::from_selector(selector),
input.with_scheduling,
input.with_media,
input.legacy,
None,
self.export_progress_fn(),
)
})
.map(Into::into)
}
}
impl Backend {
fn import_progress_fn(&self) -> impl FnMut(ImportProgress) -> Result<()> {
let mut handler = self.new_progress_handler();
move |progress| {
let throttle = matches!(progress, ImportProgress::Media(_));
if handler.update(Progress::Import(progress), throttle) {
Ok(())
} else {
Err(AnkiError::Interrupted)
}
}
}
fn export_progress_fn(&self) -> impl FnMut(usize) {
let mut handler = self.new_progress_handler();
move |media_files| {
handler.update(Progress::Export(media_files), true);
impl SearchNode {
fn from_selector(selector: Selector) -> Self {
match selector {
Selector::WholeCollection(_) => Self::WholeCollection,
Selector::DeckId(did) => Self::from_deck_id(did, true),
Selector::NoteIds(nids) => Self::from_note_ids(nids.note_ids),
}
}
}
impl Backend {
fn import_progress_fn(&self) -> impl FnMut(ImportProgress, bool) -> bool {
let mut handler = self.new_progress_handler();
move |progress, throttle| handler.update(Progress::Import(progress), throttle)
}
fn export_progress_fn(&self) -> impl FnMut(usize, bool) -> bool {
let mut handler = self.new_progress_handler();
move |progress, throttle| handler.update(Progress::Export(progress), throttle)
}
}
impl From<OpOutput<NoteLog>> for pb::ImportAnkiPackageResponse {
fn from(output: OpOutput<NoteLog>) -> Self {
Self {
changes: Some(output.changes.into()),
log: Some(output.output),
}
}
}


@ -108,8 +108,10 @@ pub(super) fn progress_to_proto(progress: Option<Progress>, tr: &I18n) -> pb::Pr
}
Progress::Import(progress) => pb::progress::Value::Importing(
match progress {
ImportProgress::Collection => tr.importing_importing_collection(),
ImportProgress::File => tr.importing_importing_file(),
ImportProgress::Media(n) => tr.importing_processed_media_file(n),
ImportProgress::MediaCheck(n) => tr.media_check_checked(n),
ImportProgress::Notes(n) => tr.importing_processed_notes(n),
}
.into(),
),


@ -35,6 +35,14 @@ impl Collection {
Ok(())
}
pub(crate) fn add_card_if_unique_undoable(&mut self, card: &Card) -> Result<bool> {
let added = self.storage.add_card_if_unique(card)?;
if added {
self.save_undo(UndoableCardChange::Added(Box::new(card.clone())));
}
Ok(added)
}
pub(super) fn update_card_undoable(&mut self, card: &mut Card, original: Card) -> Result<()> {
if card.id.0 == 0 {
return Err(AnkiError::invalid_input("card id not set"));


@ -44,6 +44,13 @@ impl Collection {
Ok(())
}
pub(crate) fn add_deck_config_if_unique_undoable(&mut self, config: &DeckConfig) -> Result<()> {
if self.storage.add_deck_conf_if_unique(config)? {
self.save_undo(UndoableDeckConfigChange::Added(Box::new(config.clone())));
}
Ok(())
}
pub(super) fn update_deck_config_undoable(
&mut self,
config: &DeckConfig,


@ -148,7 +148,7 @@ impl Collection {
Ok(())
}
fn first_existing_parent(
pub(crate) fn first_existing_parent(
&self,
machine_name: &str,
recursion_level: usize,


@ -0,0 +1,222 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::collections::{HashMap, HashSet};
use itertools::Itertools;
use crate::{
decks::immediate_parent_name,
io::filename_is_safe,
latex::extract_latex,
prelude::*,
revlog::RevlogEntry,
text::{extract_media_refs, extract_underscored_css_imports, extract_underscored_references},
};
#[derive(Debug, Default)]
pub(super) struct ExchangeData {
pub(super) decks: Vec<Deck>,
pub(super) notes: Vec<Note>,
pub(super) cards: Vec<Card>,
pub(super) notetypes: Vec<Notetype>,
pub(super) revlog: Vec<RevlogEntry>,
pub(super) deck_configs: Vec<DeckConfig>,
pub(super) media_filenames: HashSet<String>,
pub(super) days_elapsed: u32,
pub(super) creation_utc_offset: Option<i32>,
}
impl ExchangeData {
pub(super) fn gather_data(
&mut self,
col: &mut Collection,
search: impl TryIntoSearch,
with_scheduling: bool,
) -> Result<()> {
self.days_elapsed = col.timing_today()?.days_elapsed;
self.creation_utc_offset = col.get_creation_utc_offset();
self.notes = col.gather_notes(search)?;
self.cards = col.gather_cards()?;
self.decks = col.gather_decks()?;
self.notetypes = col.gather_notetypes()?;
if with_scheduling {
self.revlog = col.gather_revlog()?;
self.deck_configs = col.gather_deck_configs(&self.decks)?;
} else {
self.remove_scheduling_information(col);
};
col.storage.clear_searched_notes_table()?;
col.storage.clear_searched_cards_table()
}
pub(super) fn gather_media_names(&mut self) {
let mut inserter = |name: String| {
if filename_is_safe(&name) {
self.media_filenames.insert(name);
}
};
let svg_getter = svg_getter(&self.notetypes);
for note in self.notes.iter() {
gather_media_names_from_note(note, &mut inserter, &svg_getter);
}
for notetype in self.notetypes.iter() {
gather_media_names_from_notetype(notetype, &mut inserter);
}
}
fn remove_scheduling_information(&mut self, col: &Collection) {
self.remove_system_tags();
self.reset_deck_config_ids();
self.reset_cards(col);
}
fn remove_system_tags(&mut self) {
const SYSTEM_TAGS: [&str; 2] = ["marked", "leech"];
for note in self.notes.iter_mut() {
note.tags = std::mem::take(&mut note.tags)
.into_iter()
.filter(|tag| !SYSTEM_TAGS.iter().any(|s| tag.eq_ignore_ascii_case(s)))
.collect();
}
}
fn reset_deck_config_ids(&mut self) {
for deck in self.decks.iter_mut() {
if let Ok(normal_mut) = deck.normal_mut() {
normal_mut.config_id = 1;
} else {
// filtered decks are reset at import time for legacy reasons
}
}
}
fn reset_cards(&mut self, col: &Collection) {
let mut position = col.get_next_card_position();
for card in self.cards.iter_mut() {
// schedule_as_new() removes cards from filtered decks, but we want to
// leave cards in their current deck, which gets converted to a regular
// deck on import
let deck_id = card.deck_id;
if card.schedule_as_new(position, true, true) {
position += 1;
}
card.flags = 0;
card.deck_id = deck_id;
}
}
}
fn gather_media_names_from_note(
note: &Note,
inserter: &mut impl FnMut(String),
svg_getter: &impl Fn(NotetypeId) -> bool,
) {
for field in note.fields() {
for media_ref in extract_media_refs(field) {
inserter(media_ref.fname_decoded.to_string());
}
for latex in extract_latex(field, svg_getter(note.notetype_id)).1 {
inserter(latex.fname);
}
}
}
fn gather_media_names_from_notetype(notetype: &Notetype, inserter: &mut impl FnMut(String)) {
for name in extract_underscored_css_imports(&notetype.config.css) {
inserter(name.to_string());
}
for template in &notetype.templates {
for template_side in [&template.config.q_format, &template.config.a_format] {
for name in extract_underscored_references(template_side) {
inserter(name.to_string());
}
}
}
}
fn svg_getter(notetypes: &[Notetype]) -> impl Fn(NotetypeId) -> bool {
let svg_map: HashMap<NotetypeId, bool> = notetypes
.iter()
.map(|nt| (nt.id, nt.config.latex_svg))
.collect();
move |nt_id| svg_map.get(&nt_id).copied().unwrap_or_default()
}
impl Collection {
fn gather_notes(&mut self, search: impl TryIntoSearch) -> Result<Vec<Note>> {
self.search_notes_into_table(search)?;
self.storage.all_searched_notes()
}
fn gather_cards(&mut self) -> Result<Vec<Card>> {
self.storage.search_cards_of_notes_into_table()?;
self.storage.all_searched_cards()
}
fn gather_decks(&mut self) -> Result<Vec<Deck>> {
let decks = self.storage.get_decks_for_search_cards()?;
let parents = self.get_parent_decks(&decks)?;
Ok(decks
.into_iter()
.filter(|deck| deck.id != DeckId(1))
.chain(parents)
.collect())
}
fn get_parent_decks(&mut self, decks: &[Deck]) -> Result<Vec<Deck>> {
let mut parent_names: HashSet<String> = decks
.iter()
.map(|deck| deck.name.as_native_str().to_owned())
.collect();
let mut parents = Vec::new();
for deck in decks {
self.add_parent_decks(deck.name.as_native_str(), &mut parent_names, &mut parents)?;
}
Ok(parents)
}
fn add_parent_decks(
&mut self,
name: &str,
parent_names: &mut HashSet<String>,
parents: &mut Vec<Deck>,
) -> Result<()> {
if let Some(parent_name) = immediate_parent_name(name) {
if parent_names.insert(parent_name.to_owned()) {
parents.push(
self.storage
.get_deck_by_name(parent_name)?
.ok_or(AnkiError::DatabaseCheckRequired)?,
);
self.add_parent_decks(parent_name, parent_names, parents)?;
}
}
Ok(())
}
fn gather_notetypes(&mut self) -> Result<Vec<Notetype>> {
self.storage.get_notetypes_for_search_notes()
}
fn gather_revlog(&mut self) -> Result<Vec<RevlogEntry>> {
self.storage.get_revlog_entries_for_searched_cards()
}
fn gather_deck_configs(&mut self, decks: &[Deck]) -> Result<Vec<DeckConfig>> {
decks
.iter()
.filter_map(|deck| deck.config_id())
.unique()
.filter(|config_id| *config_id != DeckConfigId(1))
.map(|config_id| {
self.storage
.get_deck_config(config_id)?
.ok_or(AnkiError::NotFound)
})
.collect()
}
}


@ -0,0 +1,62 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use super::gather::ExchangeData;
use crate::{prelude::*, revlog::RevlogEntry};
impl Collection {
pub(super) fn insert_data(&mut self, data: &ExchangeData) -> Result<()> {
self.transact_no_undo(|col| {
col.insert_decks(&data.decks)?;
col.insert_notes(&data.notes)?;
col.insert_cards(&data.cards)?;
col.insert_notetypes(&data.notetypes)?;
col.insert_revlog(&data.revlog)?;
col.insert_deck_configs(&data.deck_configs)
})
}
fn insert_decks(&self, decks: &[Deck]) -> Result<()> {
for deck in decks {
self.storage.add_or_update_deck_with_existing_id(deck)?;
}
Ok(())
}
fn insert_notes(&self, notes: &[Note]) -> Result<()> {
for note in notes {
self.storage.add_or_update_note(note)?;
}
Ok(())
}
fn insert_cards(&self, cards: &[Card]) -> Result<()> {
for card in cards {
self.storage.add_or_update_card(card)?;
}
Ok(())
}
fn insert_notetypes(&self, notetypes: &[Notetype]) -> Result<()> {
for notetype in notetypes {
self.storage
.add_or_update_notetype_with_existing_id(notetype)?;
}
Ok(())
}
fn insert_revlog(&self, revlog: &[RevlogEntry]) -> Result<()> {
for entry in revlog {
self.storage.add_revlog_entry(entry, false)?;
}
Ok(())
}
fn insert_deck_configs(&self, configs: &[DeckConfig]) -> Result<()> {
for config in configs {
self.storage
.add_or_update_deck_config_with_existing_id(config)?;
}
Ok(())
}
}


@ -1,10 +1,87 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
mod gather;
mod insert;
pub mod package;
use std::marker::PhantomData;
use crate::prelude::*;
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ImportProgress {
Collection,
File,
Media(usize),
MediaCheck(usize),
Notes(usize),
}
/// Wrapper around a progress function, usually passed by the [crate::backend::Backend],
/// to make repeated calls more ergonomic.
pub(crate) struct IncrementableProgress<P>(Box<dyn FnMut(P, bool) -> bool>);
impl<P> IncrementableProgress<P> {
/// `progress_fn: (progress, throttle) -> should_continue`
pub(crate) fn new(progress_fn: impl 'static + FnMut(P, bool) -> bool) -> Self {
Self(Box::new(progress_fn))
}
/// Returns an [Incrementor] with an `increment()` function for use in loops.
pub(crate) fn incrementor<'inc, 'progress: 'inc, 'map: 'inc>(
&'progress mut self,
mut count_map: impl 'map + FnMut(usize) -> P,
) -> Incrementor<'inc, impl FnMut(usize) -> Result<()> + 'inc> {
Incrementor::new(move |u| self.update(count_map(u), true))
}
/// Manually triggers an update.
/// Returns [AnkiError::Interrupted] if the operation should be cancelled.
pub(crate) fn call(&mut self, progress: P) -> Result<()> {
self.update(progress, false)
}
fn update(&mut self, progress: P, throttle: bool) -> Result<()> {
if (self.0)(progress, throttle) {
Ok(())
} else {
Err(AnkiError::Interrupted)
}
}
/// Stopgap for returning a progress fn compliant with the media code.
pub(crate) fn media_db_fn(
&mut self,
count_map: impl 'static + Fn(usize) -> P,
) -> Result<impl FnMut(usize) -> bool + '_> {
Ok(move |count| (self.0)(count_map(count), true))
}
}
pub(crate) struct Incrementor<'f, F: 'f + FnMut(usize) -> Result<()>> {
update_fn: F,
count: usize,
update_interval: usize,
_phantom: PhantomData<&'f ()>,
}
impl<'f, F: 'f + FnMut(usize) -> Result<()>> Incrementor<'f, F> {
fn new(update_fn: F) -> Self {
Self {
update_fn,
count: 0,
update_interval: 17,
_phantom: PhantomData,
}
}
/// Increments the progress counter, periodically triggering an update.
/// Returns [AnkiError::Interrupted] if the operation should be cancelled.
pub(crate) fn increment(&mut self) -> Result<()> {
self.count += 1;
if self.count % self.update_interval != 0 {
return Ok(());
}
(self.update_fn)(self.count)
}
}
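As a rough usage sketch of how these two helpers combine (the function, counts, and callback are invented for illustration; real callers receive the callback from the backend):

fn count_notes(progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool) -> Result<()> {
    // Wrap the backend-supplied callback once...
    let mut progress = IncrementableProgress::new(progress_fn);
    // ...then hand loops a counter; `ImportProgress::Notes` serves as the
    // usize -> ImportProgress mapping, since tuple variants are functions.
    let mut incrementor = progress.incrementor(ImportProgress::Notes);
    for _ in 0..1000 {
        incrementor.increment()?; // fires the callback every 17th call
    }
    // Manual, unthrottled update; returns AnkiError::Interrupted if the
    // callback asked to cancel.
    progress.call(ImportProgress::File)
}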


@ -0,0 +1,106 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{
collections::HashSet,
path::{Path, PathBuf},
};
use tempfile::NamedTempFile;
use crate::{
collection::CollectionBuilder,
import_export::{
gather::ExchangeData,
package::{
colpkg::export::{export_collection, MediaIter},
Meta,
},
IncrementableProgress,
},
io::{atomic_rename, tempfile_in_parent_of},
prelude::*,
};
impl Collection {
/// Returns number of exported notes.
#[allow(clippy::too_many_arguments)]
pub fn export_apkg(
&mut self,
out_path: impl AsRef<Path>,
search: impl TryIntoSearch,
with_scheduling: bool,
with_media: bool,
legacy: bool,
media_fn: Option<Box<dyn FnOnce(HashSet<String>) -> MediaIter>>,
progress_fn: impl 'static + FnMut(usize, bool) -> bool,
) -> Result<usize> {
let mut progress = IncrementableProgress::new(progress_fn);
let temp_apkg = tempfile_in_parent_of(out_path.as_ref())?;
let mut temp_col = NamedTempFile::new()?;
let temp_col_path = temp_col
.path()
.to_str()
.ok_or_else(|| AnkiError::IoError("tempfile with non-unicode name".into()))?;
let meta = if legacy {
Meta::new_legacy()
} else {
Meta::new()
};
let data = self.export_into_collection_file(
&meta,
temp_col_path,
search,
with_scheduling,
with_media,
)?;
let media = if let Some(media_fn) = media_fn {
media_fn(data.media_filenames)
} else {
MediaIter::from_file_list(data.media_filenames, self.media_folder.clone())
};
let col_size = temp_col.as_file().metadata()?.len() as usize;
export_collection(
meta,
temp_apkg.path(),
&mut temp_col,
col_size,
media,
&self.tr,
&mut progress,
)?;
atomic_rename(temp_apkg, out_path.as_ref(), true)?;
Ok(data.notes.len())
}
fn export_into_collection_file(
&mut self,
meta: &Meta,
path: &str,
search: impl TryIntoSearch,
with_scheduling: bool,
with_media: bool,
) -> Result<ExchangeData> {
let mut data = ExchangeData::default();
data.gather_data(self, search, with_scheduling)?;
if with_media {
data.gather_media_names();
}
let mut temp_col = Collection::new_minimal(path)?;
temp_col.insert_data(&data)?;
temp_col.set_creation_stamp(self.storage.creation_stamp()?)?;
temp_col.set_creation_utc_offset(data.creation_utc_offset)?;
temp_col.close(Some(meta.schema_version()))?;
Ok(data)
}
fn new_minimal(path: impl Into<PathBuf>) -> Result<Self> {
let col = CollectionBuilder::new(path).build()?;
col.storage.db.execute_batch("DELETE FROM notetypes")?;
Ok(col)
}
}
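A hypothetical call site for the new export API (path, deck name, and flag values are invented; `col` is an open mutable `Collection`):

// Export the notes matched by a deck search, keeping scheduling information
// and media, in the new zstd-compressed format, with the default media source.
let note_count = col.export_apkg(
    "/tmp/sample.apkg",
    SearchNode::from_deck_name("My Deck"),
    true,  // with_scheduling
    true,  // with_media
    false, // legacy
    None,  // media_fn: fall back to the collection's media folder
    |_count, _throttle| true, // progress callback; returning false aborts
)?;
println!("exported {note_count} notes");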


@ -0,0 +1,179 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{
collections::{HashMap, HashSet},
mem,
};
use super::Context;
use crate::{
card::{CardQueue, CardType},
config::SchedulerVersion,
prelude::*,
revlog::RevlogEntry,
};
type CardAsNidAndOrd = (NoteId, u16);
struct CardContext<'a> {
target_col: &'a mut Collection,
usn: Usn,
imported_notes: &'a HashMap<NoteId, NoteId>,
remapped_decks: &'a HashMap<DeckId, DeckId>,
/// The number of days the source collection is ahead of the target collection
collection_delta: i32,
scheduler_version: SchedulerVersion,
existing_cards: HashSet<CardAsNidAndOrd>,
existing_card_ids: HashSet<CardId>,
imported_cards: HashMap<CardId, CardId>,
}
impl<'c> CardContext<'c> {
fn new<'a: 'c>(
usn: Usn,
days_elapsed: u32,
target_col: &'a mut Collection,
imported_notes: &'a HashMap<NoteId, NoteId>,
imported_decks: &'a HashMap<DeckId, DeckId>,
) -> Result<Self> {
let existing_cards = target_col.storage.all_cards_as_nid_and_ord()?;
let collection_delta = target_col.collection_delta(days_elapsed)?;
let scheduler_version = target_col.scheduler_info()?.version;
let existing_card_ids = target_col.storage.get_all_card_ids()?;
Ok(Self {
target_col,
usn,
imported_notes,
remapped_decks: imported_decks,
existing_cards,
collection_delta,
scheduler_version,
existing_card_ids,
imported_cards: HashMap::new(),
})
}
}
impl Collection {
/// How much `days_elapsed` is ahead of this collection.
fn collection_delta(&mut self, days_elapsed: u32) -> Result<i32> {
Ok(days_elapsed as i32 - self.timing_today()?.days_elapsed as i32)
}
}
impl Context<'_> {
pub(super) fn import_cards_and_revlog(
&mut self,
imported_notes: &HashMap<NoteId, NoteId>,
imported_decks: &HashMap<DeckId, DeckId>,
) -> Result<()> {
let mut ctx = CardContext::new(
self.usn,
self.data.days_elapsed,
self.target_col,
imported_notes,
imported_decks,
)?;
ctx.import_cards(mem::take(&mut self.data.cards))?;
ctx.import_revlog(mem::take(&mut self.data.revlog))
}
}
impl CardContext<'_> {
fn import_cards(&mut self, mut cards: Vec<Card>) -> Result<()> {
for card in &mut cards {
if self.map_to_imported_note(card) && !self.card_ordinal_already_exists(card) {
self.add_card(card)?;
}
// TODO: could update existing card
}
Ok(())
}
fn import_revlog(&mut self, revlog: Vec<RevlogEntry>) -> Result<()> {
for mut entry in revlog {
if let Some(cid) = self.imported_cards.get(&entry.cid) {
entry.cid = *cid;
entry.usn = self.usn;
self.target_col.add_revlog_entry_if_unique_undoable(entry)?;
}
}
Ok(())
}
fn map_to_imported_note(&self, card: &mut Card) -> bool {
if let Some(nid) = self.imported_notes.get(&card.note_id) {
card.note_id = *nid;
true
} else {
false
}
}
fn card_ordinal_already_exists(&self, card: &Card) -> bool {
self.existing_cards
.contains(&(card.note_id, card.template_idx))
}
fn add_card(&mut self, card: &mut Card) -> Result<()> {
card.usn = self.usn;
self.remap_deck_id(card);
card.shift_collection_relative_dates(self.collection_delta);
card.maybe_remove_from_filtered_deck(self.scheduler_version);
let old_id = self.uniquify_card_id(card);
self.target_col.add_card_if_unique_undoable(card)?;
self.existing_card_ids.insert(card.id);
self.imported_cards.insert(old_id, card.id);
Ok(())
}
fn uniquify_card_id(&mut self, card: &mut Card) -> CardId {
let original = card.id;
while self.existing_card_ids.contains(&card.id) {
card.id.0 += 999;
}
original
}
fn remap_deck_id(&self, card: &mut Card) {
if let Some(did) = self.remapped_decks.get(&card.deck_id) {
card.deck_id = *did;
}
}
}
impl Card {
/// `delta` is the number of days the card's source collection is ahead of the
/// target collection.
fn shift_collection_relative_dates(&mut self, delta: i32) {
if self.due_in_days_since_collection_creation() {
self.due -= delta;
}
if self.original_due_in_days_since_collection_creation() && self.original_due != 0 {
self.original_due -= delta;
}
}
fn due_in_days_since_collection_creation(&self) -> bool {
matches!(self.queue, CardQueue::Review | CardQueue::DayLearn)
|| self.ctype == CardType::Review
}
fn original_due_in_days_since_collection_creation(&self) -> bool {
self.ctype == CardType::Review
}
fn maybe_remove_from_filtered_deck(&mut self, version: SchedulerVersion) {
if self.is_filtered() {
// instead of moving between decks, the deck is converted to a regular one
self.original_deck_id = self.deck_id;
self.remove_from_filtered_deck_restoring_queue(version);
}
}
}
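A small worked example of the date shift, with invented numbers and assuming `Card`'s in-crate `Default` impl:

// The source collection's days_elapsed is 120, the target's is 100, so the
// importer computes delta = 20. A review card due on source day 125 must
// become due on target day 105 to stay five days in the future.
let delta = 120 - 100;
let mut card = Card {
    due: 125,
    ctype: CardType::Review,
    ..Default::default()
};
card.shift_collection_relative_dates(delta);
assert_eq!(card.due, 105);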


@ -0,0 +1,212 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{collections::HashMap, mem};
use super::Context;
use crate::{decks::NormalDeck, prelude::*};
struct DeckContext<'d> {
target_col: &'d mut Collection,
usn: Usn,
renamed_parents: Vec<(String, String)>,
imported_decks: HashMap<DeckId, DeckId>,
unique_suffix: String,
}
impl<'d> DeckContext<'d> {
fn new<'a: 'd>(target_col: &'a mut Collection, usn: Usn) -> Self {
Self {
target_col,
usn,
renamed_parents: Vec::new(),
imported_decks: HashMap::new(),
unique_suffix: TimestampSecs::now().to_string(),
}
}
}
impl Context<'_> {
pub(super) fn import_decks_and_configs(&mut self) -> Result<HashMap<DeckId, DeckId>> {
let mut ctx = DeckContext::new(self.target_col, self.usn);
ctx.import_deck_configs(mem::take(&mut self.data.deck_configs))?;
ctx.import_decks(mem::take(&mut self.data.decks))?;
Ok(ctx.imported_decks)
}
}
impl DeckContext<'_> {
fn import_deck_configs(&mut self, mut configs: Vec<DeckConfig>) -> Result<()> {
for config in &mut configs {
config.usn = self.usn;
self.target_col.add_deck_config_if_unique_undoable(config)?;
}
Ok(())
}
fn import_decks(&mut self, mut decks: Vec<Deck>) -> Result<()> {
// ensure parents are seen before children
decks.sort_unstable_by_key(|deck| deck.level());
for deck in &mut decks {
self.prepare_deck(deck);
self.import_deck(deck)?;
}
Ok(())
}
fn prepare_deck(&mut self, deck: &mut Deck) {
self.maybe_reparent(deck);
if deck.is_filtered() {
deck.kind = DeckKind::Normal(NormalDeck {
config_id: 1,
..Default::default()
});
}
}
fn import_deck(&mut self, deck: &mut Deck) -> Result<()> {
if let Some(original) = self.get_deck_by_name(deck)? {
if original.is_filtered() {
self.uniquify_name(deck);
self.add_deck(deck)
} else {
self.update_deck(deck, original)
}
} else {
self.ensure_valid_first_existing_parent(deck)?;
self.add_deck(deck)
}
}
fn maybe_reparent(&self, deck: &mut Deck) {
if let Some(new_name) = self.reparented_name(deck.name.as_native_str()) {
deck.name = NativeDeckName::from_native_str(new_name);
}
}
fn reparented_name(&self, name: &str) -> Option<String> {
self.renamed_parents
.iter()
.find_map(|(old_parent, new_parent)| {
name.starts_with(old_parent)
.then(|| name.replacen(old_parent, new_parent, 1))
})
}
fn get_deck_by_name(&mut self, deck: &Deck) -> Result<Option<Deck>> {
self.target_col
.storage
.get_deck_by_name(deck.name.as_native_str())
}
fn uniquify_name(&mut self, deck: &mut Deck) {
let old_parent = format!("{}\x1f", deck.name.as_native_str());
deck.uniquify_name(&self.unique_suffix);
let new_parent = format!("{}\x1f", deck.name.as_native_str());
self.renamed_parents.push((old_parent, new_parent));
}
fn add_deck(&mut self, deck: &mut Deck) -> Result<()> {
let old_id = mem::take(&mut deck.id);
self.target_col.add_deck_inner(deck, self.usn)?;
self.imported_decks.insert(old_id, deck.id);
Ok(())
}
/// Caller must ensure decks are normal.
fn update_deck(&mut self, deck: &Deck, original: Deck) -> Result<()> {
let mut new_deck = original.clone();
new_deck.normal_mut()?.update_with_other(deck.normal()?);
self.imported_decks.insert(deck.id, new_deck.id);
self.target_col
.update_deck_inner(&mut new_deck, original, self.usn)
}
fn ensure_valid_first_existing_parent(&mut self, deck: &mut Deck) -> Result<()> {
if let Some(ancestor) = self
.target_col
.first_existing_parent(deck.name.as_native_str(), 0)?
{
if ancestor.is_filtered() {
self.add_unique_default_deck(ancestor.name.as_native_str())?;
self.maybe_reparent(deck);
}
}
Ok(())
}
fn add_unique_default_deck(&mut self, name: &str) -> Result<()> {
let mut deck = Deck::new_normal();
deck.name = NativeDeckName::from_native_str(name);
self.uniquify_name(&mut deck);
self.target_col.add_deck_inner(&mut deck, self.usn)
}
}
impl Deck {
fn uniquify_name(&mut self, suffix: &str) {
let new_name = format!("{} {}", self.name.as_native_str(), suffix);
self.name = NativeDeckName::from_native_str(new_name);
}
fn level(&self) -> usize {
self.name.components().count()
}
}
impl NormalDeck {
fn update_with_other(&mut self, other: &Self) {
if !other.description.is_empty() {
self.markdown_description = other.markdown_description;
self.description = other.description.clone();
}
if other.config_id != 1 {
self.config_id = other.config_id;
}
}
}
#[cfg(test)]
mod test {
use std::collections::HashSet;
use super::*;
use crate::{collection::open_test_collection, tests::new_deck_with_machine_name};
#[test]
fn parents() {
let mut col = open_test_collection();
col.add_deck_with_machine_name("filtered", true);
col.add_deck_with_machine_name("PARENT", false);
let mut ctx = DeckContext::new(&mut col, Usn(1));
ctx.unique_suffix = "★".to_string();
let imports = vec![
new_deck_with_machine_name("unknown parent\x1fchild", false),
new_deck_with_machine_name("filtered\x1fchild", false),
new_deck_with_machine_name("parent\x1fchild", false),
new_deck_with_machine_name("NEW PARENT\x1fchild", false),
new_deck_with_machine_name("new parent", false),
];
ctx.import_decks(imports).unwrap();
let existing_decks: HashSet<_> = ctx
.target_col
.get_all_deck_names(true)
.unwrap()
.into_iter()
.map(|(_, name)| name)
.collect();
// missing parents get created
assert!(existing_decks.contains("unknown parent"));
// ... and uniquified if their existing counterparts are filtered
assert!(existing_decks.contains("filtered ★"));
assert!(existing_decks.contains("filtered ★::child"));
// the case of existing parents is matched
assert!(existing_decks.contains("PARENT::child"));
// the case of imported parents is matched, regardless of pass order
assert!(existing_decks.contains("new parent::child"));
}
}
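To make the reparenting bookkeeping concrete, a hypothetical snippet along the lines of the test above (suffix and deck names invented; `\x1f` is the internal deck-name separator):

// After a clash with an existing filtered deck, "filtered" was stored as
// "filtered 1650000000" and the old→new prefix pair recorded. Any child
// gathered later is silently moved under the renamed parent.
let mut col = open_test_collection();
let mut ctx = DeckContext::new(&mut col, Usn(1));
ctx.renamed_parents
    .push(("filtered\x1f".into(), "filtered 1650000000\x1f".into()));
assert_eq!(
    ctx.reparented_name("filtered\x1fchild"),
    Some("filtered 1650000000\x1fchild".into())
);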


@ -0,0 +1,134 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{collections::HashMap, fs::File, mem};
use zip::ZipArchive;
use super::Context;
use crate::{
import_export::{
package::{
media::{extract_media_entries, SafeMediaEntry},
Meta,
},
ImportProgress, IncrementableProgress,
},
media::{
files::{add_hash_suffix_to_file_stem, sha1_of_reader},
MediaManager,
},
prelude::*,
};
/// Map of source media files that do not already exist in the target.
#[derive(Default)]
pub(super) struct MediaUseMap {
/// original, normalized filename → (referenced during import, entry with
/// possibly remapped filename)
checked: HashMap<String, (bool, SafeMediaEntry)>,
/// Static files (latex, underscored). Usage is not tracked, and if the name
/// already exists in the target, it is skipped regardless of content equality.
unchecked: Vec<SafeMediaEntry>,
}
impl Context<'_> {
pub(super) fn prepare_media(&mut self) -> Result<MediaUseMap> {
let db_progress_fn = self.progress.media_db_fn(ImportProgress::MediaCheck)?;
let existing_sha1s = self.target_col.all_existing_sha1s(db_progress_fn)?;
prepare_media(
&self.meta,
&mut self.archive,
&existing_sha1s,
&mut self.progress,
)
}
pub(super) fn copy_media(&mut self, media_map: &mut MediaUseMap) -> Result<()> {
let mut incrementor = self.progress.incrementor(ImportProgress::Media);
for entry in media_map.used_entries() {
incrementor.increment()?;
entry.copy_from_archive(&mut self.archive, &self.target_col.media_folder)?;
}
Ok(())
}
}
impl Collection {
fn all_existing_sha1s(
&mut self,
progress_fn: impl FnMut(usize) -> bool,
) -> Result<HashMap<String, Sha1Hash>> {
let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
mgr.all_checksums(progress_fn, &self.log)
}
}
fn prepare_media(
meta: &Meta,
archive: &mut ZipArchive<File>,
existing_sha1s: &HashMap<String, Sha1Hash>,
progress: &mut IncrementableProgress<ImportProgress>,
) -> Result<MediaUseMap> {
let mut media_map = MediaUseMap::default();
let mut incrementor = progress.incrementor(ImportProgress::MediaCheck);
for mut entry in extract_media_entries(meta, archive)? {
incrementor.increment()?;
if entry.is_static() {
if !existing_sha1s.contains_key(&entry.name) {
media_map.unchecked.push(entry);
}
} else if let Some(other_sha1) = existing_sha1s.get(&entry.name) {
entry.with_hash_from_archive(archive)?;
if entry.sha1 != *other_sha1 {
let original_name = entry.uniquify_name();
media_map.add_checked(original_name, entry);
}
} else {
media_map.add_checked(entry.name.clone(), entry);
}
}
Ok(media_map)
}
impl MediaUseMap {
pub(super) fn add_checked(&mut self, filename: impl Into<String>, entry: SafeMediaEntry) {
self.checked.insert(filename.into(), (false, entry));
}
pub(super) fn use_entry(&mut self, filename: &str) -> Option<&SafeMediaEntry> {
self.checked.get_mut(filename).map(|(used, entry)| {
*used = true;
&*entry
})
}
pub(super) fn used_entries(&self) -> impl Iterator<Item = &SafeMediaEntry> {
self.checked
.values()
.filter_map(|(used, entry)| used.then(|| entry))
.chain(self.unchecked.iter())
}
}
impl SafeMediaEntry {
fn with_hash_from_archive(&mut self, archive: &mut ZipArchive<File>) -> Result<()> {
if self.sha1 == [0; 20] {
let mut reader = self.fetch_file(archive)?;
self.sha1 = sha1_of_reader(&mut reader)?;
}
Ok(())
}
/// Requires sha1 to be set. Returns old file name.
fn uniquify_name(&mut self) -> String {
let new_name = add_hash_suffix_to_file_stem(&self.name, &self.sha1);
mem::replace(&mut self.name, new_name)
}
fn is_static(&self) -> bool {
self.name.starts_with('_') || self.name.starts_with("latex-")
}
}
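A hypothetical illustration of the conflict handling, using the same hash-suffix naming that the roundtrip test later in this diff checks (file name and contents invented):

// An incoming "sample.mp3" whose checksum differs from the file already in the
// target is stored under a hash-suffixed name, and note fields referencing
// "sample.mp3" are rewritten to that new name while notes are imported.
let sha1 = sha1_of_data(b"different contents");
let renamed = add_hash_suffix_to_file_stem("sample.mp3", &sha1);
assert_eq!(renamed, format!("sample-{}.mp3", hex::encode(sha1)));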


@ -0,0 +1,137 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
mod cards;
mod decks;
mod media;
mod notes;
use std::{fs::File, io, path::Path};
pub(crate) use notes::NoteMeta;
use rusqlite::OptionalExtension;
use tempfile::NamedTempFile;
use zip::ZipArchive;
use zstd::stream::copy_decode;
use crate::{
collection::CollectionBuilder,
import_export::{
gather::ExchangeData,
package::{Meta, NoteLog},
ImportProgress, IncrementableProgress,
},
prelude::*,
search::SearchNode,
};
struct Context<'a> {
target_col: &'a mut Collection,
archive: ZipArchive<File>,
meta: Meta,
data: ExchangeData,
usn: Usn,
progress: IncrementableProgress<ImportProgress>,
}
impl Collection {
pub fn import_apkg(
&mut self,
path: impl AsRef<Path>,
progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool,
) -> Result<OpOutput<NoteLog>> {
let file = File::open(path)?;
let archive = ZipArchive::new(file)?;
self.transact(Op::Import, |col| {
let mut ctx = Context::new(archive, col, progress_fn)?;
ctx.import()
})
}
}
impl<'a> Context<'a> {
fn new(
mut archive: ZipArchive<File>,
target_col: &'a mut Collection,
progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool,
) -> Result<Self> {
let progress = IncrementableProgress::new(progress_fn);
let meta = Meta::from_archive(&mut archive)?;
let data = ExchangeData::gather_from_archive(
&mut archive,
&meta,
SearchNode::WholeCollection,
true,
)?;
let usn = target_col.usn()?;
Ok(Self {
target_col,
archive,
meta,
data,
usn,
progress,
})
}
fn import(&mut self) -> Result<NoteLog> {
let mut media_map = self.prepare_media()?;
self.progress.call(ImportProgress::File)?;
let note_imports = self.import_notes_and_notetypes(&mut media_map)?;
let imported_decks = self.import_decks_and_configs()?;
self.import_cards_and_revlog(&note_imports.id_map, &imported_decks)?;
self.copy_media(&mut media_map)?;
Ok(note_imports.log)
}
}
impl ExchangeData {
fn gather_from_archive(
archive: &mut ZipArchive<File>,
meta: &Meta,
search: impl TryIntoSearch,
with_scheduling: bool,
) -> Result<Self> {
let tempfile = collection_to_tempfile(meta, archive)?;
let mut col = CollectionBuilder::new(tempfile.path()).build()?;
col.maybe_upgrade_scheduler()?;
let mut data = ExchangeData::default();
data.gather_data(&mut col, search, with_scheduling)?;
Ok(data)
}
}
fn collection_to_tempfile(meta: &Meta, archive: &mut ZipArchive<File>) -> Result<NamedTempFile> {
let mut zip_file = archive.by_name(meta.collection_filename())?;
let mut tempfile = NamedTempFile::new()?;
if meta.zstd_compressed() {
copy_decode(zip_file, &mut tempfile)
} else {
io::copy(&mut zip_file, &mut tempfile).map(|_| ())
}
.map_err(|err| AnkiError::file_io_error(err, tempfile.path()))?;
Ok(tempfile)
}
impl Collection {
fn maybe_upgrade_scheduler(&mut self) -> Result<()> {
if self.scheduling_included()? {
self.upgrade_to_v2_scheduler()?;
}
Ok(())
}
fn scheduling_included(&mut self) -> Result<bool> {
const SQL: &str = "SELECT 1 FROM cards WHERE queue != 0";
Ok(self
.storage
.db
.query_row(SQL, [], |_| Ok(()))
.optional()?
.is_some())
}
}
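A hypothetical call site for the importer (path invented; `col` is the open target `Collection`):

// Import inside an undoable transaction; the returned NoteLog groups source
// notes by what happened to them.
let output = col.import_apkg("/tmp/sample.apkg", |_progress, _throttle| true)?;
let log = output.output;
println!(
    "new: {}, updated: {}, duplicates: {}, conflicting: {}",
    log.new.len(),
    log.updated.len(),
    log.duplicate.len(),
    log.conflicting.len(),
);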


@ -0,0 +1,459 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
mem,
sync::Arc,
};
use sha1::Sha1;
use super::{media::MediaUseMap, Context};
use crate::{
import_export::{
package::{media::safe_normalized_file_name, LogNote, NoteLog},
ImportProgress, IncrementableProgress,
},
prelude::*,
text::{
newlines_to_spaces, replace_media_refs, strip_html_preserving_media_filenames,
truncate_to_char_boundary, CowMapping,
},
};
struct NoteContext<'a> {
target_col: &'a mut Collection,
usn: Usn,
normalize_notes: bool,
remapped_notetypes: HashMap<NotetypeId, NotetypeId>,
target_guids: HashMap<String, NoteMeta>,
target_ids: HashSet<NoteId>,
media_map: &'a mut MediaUseMap,
imports: NoteImports,
}
#[derive(Debug, Default)]
pub(super) struct NoteImports {
pub(super) id_map: HashMap<NoteId, NoteId>,
/// All notes from the source collection as [Vec]s of their fields, and grouped
/// by import result kind.
pub(super) log: NoteLog,
}
impl NoteImports {
fn log_new(&mut self, note: Note, source_id: NoteId) {
self.id_map.insert(source_id, note.id);
self.log.new.push(note.into_log_note());
}
fn log_updated(&mut self, note: Note, source_id: NoteId) {
self.id_map.insert(source_id, note.id);
self.log.updated.push(note.into_log_note());
}
fn log_duplicate(&mut self, mut note: Note, target_id: NoteId) {
self.id_map.insert(note.id, target_id);
// id is for looking up note in *target* collection
note.id = target_id;
self.log.duplicate.push(note.into_log_note());
}
fn log_conflicting(&mut self, note: Note) {
self.log.conflicting.push(note.into_log_note());
}
}
impl Note {
fn into_log_note(self) -> LogNote {
LogNote {
id: Some(self.id.into()),
fields: self
.into_fields()
.into_iter()
.map(|field| {
let mut reduced = strip_html_preserving_media_filenames(&field)
.map_cow(newlines_to_spaces)
.get_owned()
.unwrap_or(field);
truncate_to_char_boundary(&mut reduced, 80);
reduced
})
.collect(),
}
}
}
#[derive(Debug, Clone, Copy)]
pub(crate) struct NoteMeta {
id: NoteId,
mtime: TimestampSecs,
notetype_id: NotetypeId,
}
impl NoteMeta {
pub(crate) fn new(id: NoteId, mtime: TimestampSecs, notetype_id: NotetypeId) -> Self {
Self {
id,
mtime,
notetype_id,
}
}
}
impl Context<'_> {
pub(super) fn import_notes_and_notetypes(
&mut self,
media_map: &mut MediaUseMap,
) -> Result<NoteImports> {
let mut ctx = NoteContext::new(self.usn, self.target_col, media_map)?;
ctx.import_notetypes(mem::take(&mut self.data.notetypes))?;
ctx.import_notes(mem::take(&mut self.data.notes), &mut self.progress)?;
Ok(ctx.imports)
}
}
impl<'n> NoteContext<'n> {
fn new<'a: 'n>(
usn: Usn,
target_col: &'a mut Collection,
media_map: &'a mut MediaUseMap,
) -> Result<Self> {
let target_guids = target_col.storage.note_guid_map()?;
let normalize_notes = target_col.get_config_bool(BoolKey::NormalizeNoteText);
let target_ids = target_col.storage.get_all_note_ids()?;
Ok(Self {
target_col,
usn,
normalize_notes,
remapped_notetypes: HashMap::new(),
target_guids,
target_ids,
imports: NoteImports::default(),
media_map,
})
}
fn import_notetypes(&mut self, mut notetypes: Vec<Notetype>) -> Result<()> {
for notetype in &mut notetypes {
if let Some(existing) = self.target_col.storage.get_notetype(notetype.id)? {
self.merge_or_remap_notetype(notetype, existing)?;
} else {
self.add_notetype(notetype)?;
}
}
Ok(())
}
fn merge_or_remap_notetype(
&mut self,
incoming: &mut Notetype,
existing: Notetype,
) -> Result<()> {
if incoming.schema_hash() == existing.schema_hash() {
if incoming.mtime_secs > existing.mtime_secs {
self.update_notetype(incoming, existing)?;
}
} else {
self.add_notetype_with_remapped_id(incoming)?;
}
Ok(())
}
fn add_notetype(&mut self, notetype: &mut Notetype) -> Result<()> {
notetype.prepare_for_update(None, true)?;
self.target_col
.ensure_notetype_name_unique(notetype, self.usn)?;
notetype.usn = self.usn;
self.target_col
.add_notetype_with_unique_id_undoable(notetype)
}
fn update_notetype(&mut self, notetype: &mut Notetype, original: Notetype) -> Result<()> {
notetype.usn = self.usn;
self.target_col
.add_or_update_notetype_with_existing_id_inner(notetype, Some(original), self.usn, true)
}
fn add_notetype_with_remapped_id(&mut self, notetype: &mut Notetype) -> Result<()> {
let old_id = std::mem::take(&mut notetype.id);
notetype.usn = self.usn;
self.target_col
.add_notetype_inner(notetype, self.usn, true)?;
self.remapped_notetypes.insert(old_id, notetype.id);
Ok(())
}
fn import_notes(
&mut self,
notes: Vec<Note>,
progress: &mut IncrementableProgress<ImportProgress>,
) -> Result<()> {
let mut incrementor = progress.incrementor(ImportProgress::Notes);
for mut note in notes {
incrementor.increment()?;
if let Some(notetype_id) = self.remapped_notetypes.get(&note.notetype_id) {
if self.target_guids.contains_key(&note.guid) {
self.imports.log_conflicting(note);
} else {
note.notetype_id = *notetype_id;
self.add_note(note)?;
}
} else if let Some(&meta) = self.target_guids.get(&note.guid) {
self.maybe_update_note(note, meta)?;
} else {
self.add_note(note)?;
}
}
Ok(())
}
fn add_note(&mut self, mut note: Note) -> Result<()> {
self.munge_media(&mut note)?;
self.target_col.canonify_note_tags(&mut note, self.usn)?;
let notetype = self.get_expected_notetype(note.notetype_id)?;
note.prepare_for_update(&notetype, self.normalize_notes)?;
note.usn = self.usn;
let old_id = self.uniquify_note_id(&mut note);
self.target_col.add_note_only_with_id_undoable(&mut note)?;
self.target_ids.insert(note.id);
self.imports.log_new(note, old_id);
Ok(())
}
fn uniquify_note_id(&mut self, note: &mut Note) -> NoteId {
let original = note.id;
while self.target_ids.contains(&note.id) {
note.id.0 += 999;
}
original
}
fn get_expected_notetype(&mut self, ntid: NotetypeId) -> Result<Arc<Notetype>> {
self.target_col
.get_notetype(ntid)?
.ok_or(AnkiError::NotFound)
}
fn get_expected_note(&mut self, nid: NoteId) -> Result<Note> {
self.target_col
.storage
.get_note(nid)?
.ok_or(AnkiError::NotFound)
}
fn maybe_update_note(&mut self, note: Note, meta: NoteMeta) -> Result<()> {
if meta.mtime < note.mtime {
if meta.notetype_id == note.notetype_id {
self.update_note(note, meta.id)?;
} else {
self.imports.log_conflicting(note);
}
} else {
self.imports.log_duplicate(note, meta.id);
}
Ok(())
}
fn update_note(&mut self, mut note: Note, target_id: NoteId) -> Result<()> {
let source_id = note.id;
note.id = target_id;
self.munge_media(&mut note)?;
let original = self.get_expected_note(note.id)?;
let notetype = self.get_expected_notetype(note.notetype_id)?;
self.target_col.update_note_inner_without_cards(
&mut note,
&original,
&notetype,
self.usn,
true,
self.normalize_notes,
true,
)?;
self.imports.log_updated(note, source_id);
Ok(())
}
fn munge_media(&mut self, note: &mut Note) -> Result<()> {
for field in note.fields_mut() {
if let Some(new_field) = self.replace_media_refs(field) {
*field = new_field;
};
}
Ok(())
}
fn replace_media_refs(&mut self, field: &mut String) -> Option<String> {
replace_media_refs(field, |name| {
if let Ok(normalized) = safe_normalized_file_name(name) {
if let Some(entry) = self.media_map.use_entry(&normalized) {
if entry.name != name {
// name is not normalized, and/or remapped
return Some(entry.name.clone());
}
} else if let Cow::Owned(s) = normalized {
// no entry; might be a reference to an existing file, so ensure normalization
return Some(s);
}
}
None
})
}
}
impl Notetype {
fn schema_hash(&self) -> Sha1Hash {
let mut hasher = Sha1::new();
for field in &self.fields {
hasher.update(field.name.as_bytes());
}
for template in &self.templates {
hasher.update(template.name.as_bytes());
}
hasher.digest().bytes()
}
}
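Since the hash covers only field and template names, merge-vs-remap turns purely on those names; a rough in-module sketch (notetype contents invented, and it assumes `NoteField`'s `name` is writable from this crate):

// Same field/template names → same schema hash → merge (keep newer mtime).
// A renamed field or template → different hash → add under a new id and
// remap the incoming notes to it.
let mut a = Notetype {
    name: "basic".into(),
    ..Default::default()
};
a.add_field("Front");
a.add_field("Back");
let mut b = a.clone();
assert_eq!(a.schema_hash(), b.schema_hash());
b.fields[1].name = "Extra".into();
assert_ne!(a.schema_hash(), b.schema_hash());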
#[cfg(test)]
mod test {
use super::*;
use crate::{collection::open_test_collection, import_export::package::media::SafeMediaEntry};
/// Import [Note] into [Collection], optionally taking a [MediaUseMap],
/// or a [Notetype] remapping.
macro_rules! import_note {
($col:expr, $note:expr, $old_notetype:expr => $new_notetype:expr) => {{
let mut media_map = MediaUseMap::default();
let mut ctx = NoteContext::new(Usn(1), &mut $col, &mut media_map).unwrap();
ctx.remapped_notetypes.insert($old_notetype, $new_notetype);
let mut progress = IncrementableProgress::new(|_, _| true);
ctx.import_notes(vec![$note], &mut progress).unwrap();
ctx.imports.log
}};
($col:expr, $note:expr, $media_map:expr) => {{
let mut ctx = NoteContext::new(Usn(1), &mut $col, &mut $media_map).unwrap();
let mut progress = IncrementableProgress::new(|_, _| true);
ctx.import_notes(vec![$note], &mut progress).unwrap();
ctx.imports.log
}};
($col:expr, $note:expr) => {{
let mut media_map = MediaUseMap::default();
import_note!($col, $note, media_map)
}};
}
/// Assert that exactly one [Note] is logged, with the given state and fields.
macro_rules! assert_note_logged {
($log:expr, $state:ident, $fields:expr) => {
assert_eq!($log.$state.pop().unwrap().fields, $fields);
assert!($log.new.is_empty());
assert!($log.updated.is_empty());
assert!($log.duplicate.is_empty());
assert!($log.conflicting.is_empty());
};
}
impl Collection {
fn note_id_for_guid(&self, guid: &str) -> NoteId {
self.storage
.db
.query_row("SELECT id FROM notes WHERE guid = ?", [guid], |r| r.get(0))
.unwrap()
}
}
#[test]
fn should_add_note_with_new_id_if_guid_is_unique_and_id_is_not() {
let mut col = open_test_collection();
let mut note = col.add_new_note("basic");
note.guid = "other".to_string();
let original_id = note.id;
let mut log = import_note!(col, note);
assert_ne!(col.note_id_for_guid("other"), original_id);
assert_note_logged!(log, new, &["", ""]);
}
#[test]
fn should_skip_note_if_guid_already_exists_with_newer_mtime() {
let mut col = open_test_collection();
let mut note = col.add_new_note("basic");
note.mtime.0 -= 1;
note.fields_mut()[0] = "outdated".to_string();
let mut log = import_note!(col, note);
assert_eq!(col.get_all_notes()[0].fields()[0], "");
assert_note_logged!(log, duplicate, &["outdated", ""]);
}
#[test]
fn should_update_note_if_guid_already_exists_with_different_id() {
let mut col = open_test_collection();
let mut note = col.add_new_note("basic");
note.id.0 = 42;
note.mtime.0 += 1;
note.fields_mut()[0] = "updated".to_string();
let mut log = import_note!(col, note);
assert_eq!(col.get_all_notes()[0].fields()[0], "updated");
assert_note_logged!(log, updated, &["updated", ""]);
}
#[test]
fn should_ignore_note_if_guid_already_exists_with_different_notetype() {
let mut col = open_test_collection();
let mut note = col.add_new_note("basic");
note.notetype_id.0 = 42;
note.mtime.0 += 1;
note.fields_mut()[0] = "updated".to_string();
let mut log = import_note!(col, note);
assert_eq!(col.get_all_notes()[0].fields()[0], "");
assert_note_logged!(log, conflicting, &["updated", ""]);
}
#[test]
fn should_add_note_with_remapped_notetype_if_in_notetype_map() {
let mut col = open_test_collection();
let basic_ntid = col.get_notetype_by_name("basic").unwrap().unwrap().id;
let mut note = col.new_note("basic");
note.notetype_id.0 = 123;
let mut log = import_note!(col, note, NotetypeId(123) => basic_ntid);
assert_eq!(col.get_all_notes()[0].notetype_id, basic_ntid);
assert_note_logged!(log, new, &["", ""]);
}
#[test]
fn should_ignore_note_if_guid_already_exists_and_notetype_is_remapped() {
let mut col = open_test_collection();
let basic_ntid = col.get_notetype_by_name("basic").unwrap().unwrap().id;
let mut note = col.add_new_note("basic");
note.notetype_id.0 = 123;
note.mtime.0 += 1;
note.fields_mut()[0] = "updated".to_string();
let mut log = import_note!(col, note, NotetypeId(123) => basic_ntid);
assert_eq!(col.get_all_notes()[0].fields()[0], "");
assert_note_logged!(log, conflicting, &["updated", ""]);
}
#[test]
fn should_add_note_with_remapped_media_reference_in_field_if_in_media_map() {
let mut col = open_test_collection();
let mut note = col.new_note("basic");
note.fields_mut()[0] = "<img src='foo.jpg'>".to_string();
let mut media_map = MediaUseMap::default();
let entry = SafeMediaEntry::from_legacy(("0", "bar.jpg".to_string())).unwrap();
media_map.add_checked("foo.jpg", entry);
let mut log = import_note!(col, note, media_map);
assert_eq!(col.get_all_notes()[0].fields()[0], "<img src='bar.jpg'>");
assert_note_logged!(log, new, &[" bar.jpg ", ""]);
}
}


@ -0,0 +1,8 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
mod export;
mod import;
mod tests;
pub(crate) use import::NoteMeta;


@ -0,0 +1,150 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
#![cfg(test)]
use std::{collections::HashSet, fs::File, io::Write};
use crate::{
media::files::sha1_of_data, prelude::*, search::SearchNode, tests::open_fs_test_collection,
};
const SAMPLE_JPG: &str = "sample.jpg";
const SAMPLE_MP3: &str = "sample.mp3";
const SAMPLE_JS: &str = "_sample.js";
const JPG_DATA: &[u8] = b"1";
const MP3_DATA: &[u8] = b"2";
const JS_DATA: &[u8] = b"3";
const OTHER_MP3_DATA: &[u8] = b"4";
#[test]
fn roundtrip() {
let (mut src_col, src_tempdir) = open_fs_test_collection("src");
let (mut target_col, _target_tempdir) = open_fs_test_collection("target");
let apkg_path = src_tempdir.path().join("test.apkg");
let (main_deck, sibling_deck) = src_col.add_sample_decks();
let notetype = src_col.add_sample_notetype();
let note = src_col.add_sample_note(&main_deck, &sibling_deck, &notetype);
src_col.add_sample_media();
target_col.add_conflicting_media();
src_col
.export_apkg(
&apkg_path,
SearchNode::from_deck_name("parent::sample"),
true,
true,
true,
None,
|_, _| true,
)
.unwrap();
target_col.import_apkg(&apkg_path, |_, _| true).unwrap();
target_col.assert_decks();
target_col.assert_notetype(&notetype);
target_col.assert_note_and_media(&note);
target_col.undo().unwrap();
target_col.assert_empty();
}
impl Collection {
fn add_sample_decks(&mut self) -> (Deck, Deck) {
let sample = self.add_named_deck("parent\x1fsample");
self.add_named_deck("parent\x1fsample\x1fchild");
let siblings = self.add_named_deck("siblings");
(sample, siblings)
}
fn add_named_deck(&mut self, name: &str) -> Deck {
let mut deck = Deck::new_normal();
deck.name = NativeDeckName::from_native_str(name);
self.add_deck(&mut deck).unwrap();
deck
}
fn add_sample_notetype(&mut self) -> Notetype {
let mut nt = Notetype {
name: "sample".into(),
..Default::default()
};
nt.add_field("sample");
nt.add_template("sample1", "{{sample}}", "<script src=_sample.js></script>");
nt.add_template("sample2", "{{sample}}2", "");
self.add_notetype(&mut nt, true).unwrap();
nt
}
fn add_sample_note(
&mut self,
main_deck: &Deck,
sibling_decks: &Deck,
notetype: &Notetype,
) -> Note {
let mut sample = notetype.new_note();
sample.fields_mut()[0] = format!("<img src='{SAMPLE_JPG}'> [sound:{SAMPLE_MP3}]");
sample.tags = vec!["sample".into()];
self.add_note(&mut sample, main_deck.id).unwrap();
let card = self
.storage
.get_card_by_ordinal(sample.id, 1)
.unwrap()
.unwrap();
self.set_deck(&[card.id], sibling_decks.id).unwrap();
sample
}
fn add_sample_media(&self) {
self.add_media(&[
(SAMPLE_JPG, JPG_DATA),
(SAMPLE_MP3, MP3_DATA),
(SAMPLE_JS, JS_DATA),
]);
}
fn add_conflicting_media(&mut self) {
let mut file = File::create(self.media_folder.join(SAMPLE_MP3)).unwrap();
file.write_all(OTHER_MP3_DATA).unwrap();
}
fn assert_decks(&mut self) {
let existing_decks: HashSet<_> = self
.get_all_deck_names(true)
.unwrap()
.into_iter()
.map(|(_, name)| name)
.collect();
for deck in ["parent", "parent::sample", "siblings"] {
assert!(existing_decks.contains(deck));
}
assert!(!existing_decks.contains("parent::sample::child"));
}
fn assert_notetype(&mut self, notetype: &Notetype) {
assert!(self.get_notetype(notetype.id).unwrap().is_some());
}
fn assert_note_and_media(&mut self, note: &Note) {
let sha1 = sha1_of_data(MP3_DATA);
let new_mp3_name = format!("sample-{}.mp3", hex::encode(&sha1));
for file in [SAMPLE_JPG, SAMPLE_JS, &new_mp3_name] {
assert!(self.media_folder.join(file).exists())
}
let imported_note = self.storage.get_note(note.id).unwrap().unwrap();
assert!(imported_note.fields()[0].contains(&new_mp3_name));
}
fn assert_empty(&self) {
assert!(self.get_all_deck_names(true).unwrap().is_empty());
assert!(self.storage.get_all_note_ids().unwrap().is_empty());
assert!(self.storage.get_all_card_ids().unwrap().is_empty());
assert!(self.storage.all_tags().unwrap().is_empty());
}
}


@ -4,7 +4,8 @@
use std::{
borrow::Cow,
collections::HashMap,
fs::{DirEntry, File},
ffi::OsStr,
fs::File,
io::{self, Read, Write},
path::{Path, PathBuf},
};
@ -21,6 +22,7 @@ use zstd::{
use super::super::{MediaEntries, MediaEntry, Meta, Version};
use crate::{
collection::CollectionBuilder,
import_export::IncrementableProgress,
io::{atomic_rename, read_dir_files, tempfile_in_parent_of},
media::files::filename_if_normalized,
prelude::*,
@ -38,8 +40,9 @@ impl Collection {
out_path: impl AsRef<Path>,
include_media: bool,
legacy: bool,
progress_fn: impl FnMut(usize),
progress_fn: impl 'static + FnMut(usize, bool) -> bool,
) -> Result<()> {
let mut progress = IncrementableProgress::new(progress_fn);
let colpkg_name = out_path.as_ref();
let temp_colpkg = tempfile_in_parent_of(colpkg_name)?;
let src_path = self.col_path.clone();
@ -61,19 +64,48 @@ impl Collection {
src_media_folder,
legacy,
&tr,
progress_fn,
&mut progress,
)?;
atomic_rename(temp_colpkg, colpkg_name, true)
}
}
pub struct MediaIter(Box<dyn Iterator<Item = io::Result<PathBuf>>>);
impl MediaIter {
/// Iterator over all files in the given path, without traversing subfolders.
pub fn from_folder(path: &Path) -> Result<Self> {
Ok(Self(Box::new(
read_dir_files(path)?.map(|res| res.map(|entry| entry.path())),
)))
}
/// Iterator over all given files in the given folder.
/// Missing files are silently ignored.
pub fn from_file_list(
list: impl IntoIterator<Item = String> + 'static,
folder: PathBuf,
) -> Self {
Self(Box::new(
list.into_iter()
.map(move |file| folder.join(file))
.filter(|path| path.exists())
.map(Ok),
))
}
pub fn empty() -> Self {
Self(Box::new(std::iter::empty()))
}
}
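Two hypothetical ways to build the media source for an export (file names invented; `col` is an open `Collection`):

// Everything in the collection's media folder (colpkg-style export)...
let all_files = MediaIter::from_folder(&col.media_folder)?;
// ...or only the files referenced by the exported notes (apkg-style export).
let referenced = MediaIter::from_file_list(
    vec!["sample.jpg".to_string(), "sample.mp3".to_string()],
    col.media_folder.clone(),
);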
fn export_collection_file(
out_path: impl AsRef<Path>,
col_path: impl AsRef<Path>,
media_dir: Option<PathBuf>,
legacy: bool,
tr: &I18n,
progress_fn: impl FnMut(usize),
progress: &mut IncrementableProgress<usize>,
) -> Result<()> {
let meta = if legacy {
Meta::new_legacy()
@ -82,15 +114,13 @@ fn export_collection_file(
};
let mut col_file = File::open(col_path)?;
let col_size = col_file.metadata()?.len() as usize;
export_collection(
meta,
out_path,
&mut col_file,
col_size,
media_dir,
tr,
progress_fn,
)
let media = if let Some(path) = media_dir {
MediaIter::from_folder(&path)?
} else {
MediaIter::empty()
};
export_collection(meta, out_path, &mut col_file, col_size, media, tr, progress)
}
/// Write copied collection data without any media.
@ -105,20 +135,20 @@ pub(crate) fn export_colpkg_from_data(
out_path,
&mut col_data,
col_size,
None,
MediaIter::empty(),
tr,
|_| (),
&mut IncrementableProgress::new(|_, _| true),
)
}
fn export_collection(
pub(crate) fn export_collection(
meta: Meta,
out_path: impl AsRef<Path>,
col: &mut impl Read,
col_size: usize,
media_dir: Option<PathBuf>,
media: MediaIter,
tr: &I18n,
progress_fn: impl FnMut(usize),
progress: &mut IncrementableProgress<usize>,
) -> Result<()> {
let out_file = File::create(&out_path)?;
let mut zip = ZipWriter::new(out_file);
@ -129,7 +159,7 @@ fn export_collection(
zip.write_all(&meta_bytes)?;
write_collection(&meta, &mut zip, col, col_size)?;
write_dummy_collection(&mut zip, tr)?;
write_media(&meta, &mut zip, media_dir, progress_fn)?;
write_media(&meta, &mut zip, media, progress)?;
zip.finish()?;
Ok(())
@ -203,17 +233,12 @@ fn zstd_copy(reader: &mut impl Read, writer: &mut impl Write, size: usize) -> Re
fn write_media(
meta: &Meta,
zip: &mut ZipWriter<File>,
media_dir: Option<PathBuf>,
progress_fn: impl FnMut(usize),
media: MediaIter,
progress: &mut IncrementableProgress<usize>,
) -> Result<()> {
let mut media_entries = vec![];
if let Some(media_dir) = media_dir {
write_media_files(meta, zip, &media_dir, &mut media_entries, progress_fn)?;
}
write_media_files(meta, zip, media, &mut media_entries, progress)?;
write_media_map(meta, media_entries, zip)?;
Ok(())
}
@ -251,19 +276,23 @@ fn write_media_map(
fn write_media_files(
meta: &Meta,
zip: &mut ZipWriter<File>,
dir: &Path,
media: MediaIter,
media_entries: &mut Vec<MediaEntry>,
mut progress_fn: impl FnMut(usize),
progress: &mut IncrementableProgress<usize>,
) -> Result<()> {
let mut copier = MediaCopier::new(meta);
for (index, entry) in read_dir_files(dir)?.enumerate() {
progress_fn(index);
let mut incrementor = progress.incrementor(|u| u);
for (index, res) in media.0.enumerate() {
incrementor.increment()?;
let path = res?;
zip.start_file(index.to_string(), file_options_stored())?;
let entry = entry?;
let name = normalized_unicode_file_name(&entry)?;
let mut file = File::open(entry.path())?;
let mut file = File::open(&path)?;
let file_name = path
.file_name()
.ok_or_else(|| AnkiError::invalid_input("not a file path"))?;
let name = normalized_unicode_file_name(file_name)?;
let (size, sha1) = copier.copy(&mut file, zip)?;
media_entries.push(MediaEntry::new(name, size, sha1));
@ -272,23 +301,11 @@ fn write_media_files(
Ok(())
}
impl MediaEntry {
fn new(name: impl Into<String>, size: impl TryInto<u32>, sha1: impl Into<Vec<u8>>) -> Self {
MediaEntry {
name: name.into(),
size: size.try_into().unwrap_or_default(),
sha1: sha1.into(),
legacy_zip_filename: None,
}
}
}
fn normalized_unicode_file_name(entry: &DirEntry) -> Result<String> {
let filename = entry.file_name();
fn normalized_unicode_file_name(filename: &OsStr) -> Result<String> {
let filename = filename.to_str().ok_or_else(|| {
AnkiError::IoError(format!(
"non-unicode file name: {}",
entry.file_name().to_string_lossy()
filename.to_string_lossy()
))
})?;
filename_if_normalized(filename)
@ -324,7 +341,7 @@ impl MediaCopier {
&mut self,
reader: &mut impl Read,
writer: &mut impl Write,
) -> Result<(usize, [u8; 20])> {
) -> Result<(usize, Sha1Hash)> {
let mut size = 0;
let mut hasher = Sha1::new();
let mut buf = [0; 64 * 1024];


@ -2,64 +2,39 @@
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{
borrow::Cow,
collections::HashMap,
fs::{self, File},
io::{self, Read, Write},
path::{Component, Path, PathBuf},
fs::File,
io::{self, Write},
path::{Path, PathBuf},
};
use prost::Message;
use zip::{read::ZipFile, ZipArchive};
use zstd::{self, stream::copy_decode};
use super::super::Version;
use crate::{
collection::CollectionBuilder,
error::ImportError,
import_export::{
package::{MediaEntries, MediaEntry, Meta},
ImportProgress,
package::{
media::{extract_media_entries, SafeMediaEntry},
Meta,
},
ImportProgress, IncrementableProgress,
},
io::{atomic_rename, tempfile_in_parent_of},
media::files::normalize_filename,
media::MediaManager,
prelude::*,
};
impl Meta {
/// Extracts meta data from an archive and checks if its version is supported.
pub(super) fn from_archive(archive: &mut ZipArchive<File>) -> Result<Self> {
let meta_bytes = archive.by_name("meta").ok().and_then(|mut meta_file| {
let mut buf = vec![];
meta_file.read_to_end(&mut buf).ok()?;
Some(buf)
});
let meta = if let Some(bytes) = meta_bytes {
let meta: Meta = Message::decode(&*bytes)?;
if meta.version() == Version::Unknown {
return Err(AnkiError::ImportError(ImportError::TooNew));
}
meta
} else {
Meta {
version: if archive.by_name("collection.anki21").is_ok() {
Version::Legacy2
} else {
Version::Legacy1
} as i32,
}
};
Ok(meta)
}
}
pub fn import_colpkg(
colpkg_path: &str,
target_col_path: &str,
target_media_folder: &str,
mut progress_fn: impl FnMut(ImportProgress) -> Result<()>,
target_media_folder: &Path,
media_db: &Path,
progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool,
log: &Logger,
) -> Result<()> {
progress_fn(ImportProgress::Collection)?;
let mut progress = IncrementableProgress::new(progress_fn);
progress.call(ImportProgress::File)?;
let col_path = PathBuf::from(target_col_path);
let mut tempfile = tempfile_in_parent_of(&col_path)?;
@ -68,12 +43,18 @@ pub fn import_colpkg(
let meta = Meta::from_archive(&mut archive)?;
copy_collection(&mut archive, &mut tempfile, &meta)?;
progress_fn(ImportProgress::Collection)?;
progress.call(ImportProgress::File)?;
check_collection_and_mod_schema(tempfile.path())?;
progress_fn(ImportProgress::Collection)?;
progress.call(ImportProgress::File)?;
let media_folder = Path::new(target_media_folder);
restore_media(&meta, progress_fn, &mut archive, media_folder)?;
restore_media(
&meta,
&mut progress,
&mut archive,
target_media_folder,
media_db,
log,
)?;
atomic_rename(tempfile, &col_path, true)
}
@ -96,31 +77,25 @@ fn check_collection_and_mod_schema(col_path: &Path) -> Result<()> {
fn restore_media(
meta: &Meta,
mut progress_fn: impl FnMut(ImportProgress) -> Result<()>,
progress: &mut IncrementableProgress<ImportProgress>,
archive: &mut ZipArchive<File>,
media_folder: &Path,
media_db: &Path,
log: &Logger,
) -> Result<()> {
let media_entries = extract_media_entries(meta, archive)?;
if media_entries.is_empty() {
return Ok(());
}
std::fs::create_dir_all(media_folder)?;
let media_manager = MediaManager::new(media_folder, media_db)?;
let mut media_comparer = MediaComparer::new(meta, progress, &media_manager, log)?;
for (entry_idx, entry) in media_entries.iter().enumerate() {
if entry_idx % 10 == 0 {
progress_fn(ImportProgress::Media(entry_idx))?;
}
let zip_filename = entry
.legacy_zip_filename
.map(|n| n as usize)
.unwrap_or(entry_idx)
.to_string();
if let Ok(mut zip_file) = archive.by_name(&zip_filename) {
maybe_restore_media_file(meta, media_folder, entry, &mut zip_file)?;
} else {
return Err(AnkiError::invalid_input(&format!(
"{zip_filename} missing from archive"
)));
}
let mut incrementor = progress.incrementor(ImportProgress::Media);
for mut entry in media_entries {
incrementor.increment()?;
maybe_restore_media_file(meta, media_folder, archive, &mut entry, &mut media_comparer)?;
}
Ok(())
@ -129,13 +104,19 @@ fn restore_media(
fn maybe_restore_media_file(
meta: &Meta,
media_folder: &Path,
entry: &MediaEntry,
zip_file: &mut ZipFile,
archive: &mut ZipArchive<File>,
entry: &mut SafeMediaEntry,
media_comparer: &mut MediaComparer,
) -> Result<()> {
let file_path = entry.safe_normalized_file_path(meta, media_folder)?;
let already_exists = entry.is_equal_to(meta, zip_file, &file_path);
let file_path = entry.file_path(media_folder);
let mut zip_file = entry.fetch_file(archive)?;
if meta.media_list_is_hashmap() {
entry.size = zip_file.size() as u32;
}
let already_exists = media_comparer.entry_is_equal_to(entry, &file_path)?;
if !already_exists {
restore_media_file(meta, zip_file, &file_path)?;
restore_media_file(meta, &mut zip_file, &file_path)?;
};
Ok(())
@ -154,79 +135,6 @@ fn restore_media_file(meta: &Meta, zip_file: &mut ZipFile, path: &Path) -> Resul
atomic_rename(tempfile, path, false)
}
impl MediaEntry {
fn safe_normalized_file_path(&self, meta: &Meta, media_folder: &Path) -> Result<PathBuf> {
check_filename_safe(&self.name)?;
let normalized = maybe_normalizing(&self.name, meta.strict_media_checks())?;
Ok(media_folder.join(normalized.as_ref()))
}
fn is_equal_to(&self, meta: &Meta, self_zipped: &ZipFile, other_path: &Path) -> bool {
// TODO: checks hashs (https://github.com/ankitects/anki/pull/1723#discussion_r829653147)
let self_size = if meta.media_list_is_hashmap() {
self_zipped.size()
} else {
self.size as u64
};
fs::metadata(other_path)
.map(|metadata| metadata.len() as u64 == self_size)
.unwrap_or_default()
}
}
/// - If strict is true, return an error if not normalized.
/// - If false, return the normalized version.
fn maybe_normalizing(name: &str, strict: bool) -> Result<Cow<str>> {
let normalized = normalize_filename(name);
if strict && matches!(normalized, Cow::Owned(_)) {
// exporting code should have checked this
Err(AnkiError::ImportError(ImportError::Corrupt))
} else {
Ok(normalized)
}
}
/// Return an error if name contains any path separators.
fn check_filename_safe(name: &str) -> Result<()> {
let mut components = Path::new(name).components();
let first_element_normal = components
.next()
.map(|component| matches!(component, Component::Normal(_)))
.unwrap_or_default();
if !first_element_normal || components.next().is_some() {
Err(AnkiError::ImportError(ImportError::Corrupt))
} else {
Ok(())
}
}
fn extract_media_entries(meta: &Meta, archive: &mut ZipArchive<File>) -> Result<Vec<MediaEntry>> {
let mut file = archive.by_name("media")?;
let mut buf = Vec::new();
if meta.zstd_compressed() {
copy_decode(file, &mut buf)?;
} else {
io::copy(&mut file, &mut buf)?;
}
if meta.media_list_is_hashmap() {
let map: HashMap<&str, String> = serde_json::from_slice(&buf)?;
map.into_iter()
.map(|(idx_str, name)| {
let idx: u32 = idx_str.parse()?;
Ok(MediaEntry {
name,
size: 0,
sha1: vec![],
legacy_zip_filename: Some(idx),
})
})
.collect()
} else {
let entries: MediaEntries = Message::decode(&*buf)?;
Ok(entries.entries)
}
}
fn copy_collection(
archive: &mut ZipArchive<File>,
writer: &mut impl Write,
@ -244,29 +152,31 @@ fn copy_collection(
Ok(())
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn path_traversal() {
assert!(check_filename_safe("foo").is_ok(),);
assert!(check_filename_safe("..").is_err());
assert!(check_filename_safe("foo/bar").is_err());
assert!(check_filename_safe("/foo").is_err());
assert!(check_filename_safe("../foo").is_err());
if cfg!(windows) {
assert!(check_filename_safe("foo\\bar").is_err());
assert!(check_filename_safe("c:\\foo").is_err());
assert!(check_filename_safe("\\foo").is_err());
}
}
#[test]
fn normalization() {
assert_eq!(&maybe_normalizing("con", false).unwrap(), "con_");
assert!(&maybe_normalizing("con", true).is_err());
}
}
type GetChecksumFn<'a> = dyn FnMut(&str) -> Result<Option<Sha1Hash>> + 'a;
struct MediaComparer<'a>(Option<Box<GetChecksumFn<'a>>>);
impl<'a> MediaComparer<'a> {
fn new(
meta: &Meta,
progress: &mut IncrementableProgress<ImportProgress>,
media_manager: &'a MediaManager,
log: &Logger,
) -> Result<Self> {
Ok(Self(if meta.media_list_is_hashmap() {
None
} else {
let mut db_progress_fn = progress.media_db_fn(ImportProgress::MediaCheck)?;
media_manager.register_changes(&mut db_progress_fn, log)?;
Some(Box::new(media_manager.checksum_getter()))
}))
}
fn entry_is_equal_to(&mut self, entry: &SafeMediaEntry, other_path: &Path) -> Result<bool> {
if let Some(ref mut get_checksum) = self.0 {
Ok(entry.has_checksum_equal_to(get_checksum)?)
} else {
Ok(entry.has_size_equal_to(other_path))
}
}
}


@ -8,8 +8,8 @@ use std::path::Path;
use tempfile::tempdir;
use crate::{
collection::CollectionBuilder, import_export::package::import_colpkg, media::MediaManager,
prelude::*,
collection::CollectionBuilder, import_export::package::import_colpkg, log::terminal,
media::MediaManager, prelude::*,
};
fn collection_with_media(dir: &Path, name: &str) -> Result<Collection> {
@ -41,19 +41,26 @@ fn roundtrip() -> Result<()> {
// export to a file
let col = collection_with_media(dir, name)?;
let colpkg_name = dir.join(format!("{name}.colpkg"));
col.export_colpkg(&colpkg_name, true, legacy, |_| ())?;
col.export_colpkg(&colpkg_name, true, legacy, |_, _| true)?;
// import into a new collection
let anki2_name = dir
.join(format!("{name}.anki2"))
.to_string_lossy()
.into_owned();
let import_media_dir = dir.join(format!("{name}.media"));
std::fs::create_dir_all(&import_media_dir)?;
let import_media_db = dir.join(format!("{name}.mdb"));
MediaManager::new(&import_media_dir, &import_media_db)?;
import_colpkg(
&colpkg_name.to_string_lossy(),
&anki2_name,
import_media_dir.to_str().unwrap(),
|_| Ok(()),
&import_media_dir,
&import_media_db,
|_, _| true,
&terminal(),
)?;
// confirm collection imported
let col = CollectionBuilder::new(&anki2_name).build()?;
assert_eq!(
@ -82,7 +89,7 @@ fn normalization_check_on_export() -> Result<()> {
// manually write a file in the wrong encoding.
std::fs::write(col.media_folder.join("ぱぱ.jpg"), "nfd encoding")?;
assert_eq!(
col.export_colpkg(&colpkg_name, true, false, |_| ())
col.export_colpkg(&colpkg_name, true, false, |_, _| true,)
.unwrap_err(),
AnkiError::MediaCheckRequired
);


@ -0,0 +1,174 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{
borrow::Cow,
collections::HashMap,
fs::{self, File},
io,
path::{Path, PathBuf},
};
use prost::Message;
use tempfile::NamedTempFile;
use zip::{read::ZipFile, ZipArchive};
use zstd::stream::copy_decode;
use super::{MediaEntries, MediaEntry, Meta};
use crate::{
error::ImportError,
io::{atomic_rename, filename_is_safe},
media::files::normalize_filename,
prelude::*,
};
/// Like [MediaEntry], but with a safe filename and set zip filename.
pub(super) struct SafeMediaEntry {
pub(super) name: String,
pub(super) size: u32,
pub(super) sha1: Sha1Hash,
pub(super) index: usize,
}
impl MediaEntry {
pub(super) fn new(
name: impl Into<String>,
size: impl TryInto<u32>,
sha1: impl Into<Vec<u8>>,
) -> Self {
MediaEntry {
name: name.into(),
size: size.try_into().unwrap_or_default(),
sha1: sha1.into(),
legacy_zip_filename: None,
}
}
}
impl SafeMediaEntry {
pub(super) fn from_entry(enumerated: (usize, MediaEntry)) -> Result<Self> {
let (index, entry) = enumerated;
if let Ok(sha1) = entry.sha1.try_into() {
if !matches!(safe_normalized_file_name(&entry.name)?, Cow::Owned(_)) {
return Ok(Self {
name: entry.name,
size: entry.size,
sha1,
index,
});
}
}
Err(AnkiError::ImportError(ImportError::Corrupt))
}
pub(super) fn from_legacy(legacy_entry: (&str, String)) -> Result<Self> {
let zip_filename: usize = legacy_entry.0.parse()?;
let name = match safe_normalized_file_name(&legacy_entry.1)? {
Cow::Owned(new_name) => new_name,
Cow::Borrowed(_) => legacy_entry.1,
};
Ok(Self {
name,
size: 0,
sha1: [0; 20],
index: zip_filename,
})
}
pub(super) fn file_path(&self, media_folder: &Path) -> PathBuf {
media_folder.join(&self.name)
}
pub(super) fn fetch_file<'a>(&self, archive: &'a mut ZipArchive<File>) -> Result<ZipFile<'a>> {
archive
.by_name(&self.index.to_string())
.map_err(|_| AnkiError::invalid_input(&format!("{} missing from archive", self.index)))
}
pub(super) fn has_checksum_equal_to(
&self,
get_checksum: &mut impl FnMut(&str) -> Result<Option<Sha1Hash>>,
) -> Result<bool> {
get_checksum(&self.name).map(|opt| opt.map_or(false, |sha1| sha1 == self.sha1))
}
pub(super) fn has_size_equal_to(&self, other_path: &Path) -> bool {
fs::metadata(other_path).map_or(false, |metadata| metadata.len() == self.size as u64)
}
pub(super) fn copy_from_archive(
&self,
archive: &mut ZipArchive<File>,
target_folder: &Path,
) -> Result<()> {
let mut file = self.fetch_file(archive)?;
let mut tempfile = NamedTempFile::new_in(target_folder)?;
io::copy(&mut file, &mut tempfile)?;
atomic_rename(tempfile, &self.file_path(target_folder), false)
}
}
pub(super) fn extract_media_entries(
meta: &Meta,
archive: &mut ZipArchive<File>,
) -> Result<Vec<SafeMediaEntry>> {
let media_list_data = get_media_list_data(archive, meta)?;
if meta.media_list_is_hashmap() {
let map: HashMap<&str, String> = serde_json::from_slice(&media_list_data)?;
map.into_iter().map(SafeMediaEntry::from_legacy).collect()
} else {
MediaEntries::decode_safe_entries(&media_list_data)
}
}
pub(super) fn safe_normalized_file_name(name: &str) -> Result<Cow<str>> {
if !filename_is_safe(name) {
Err(AnkiError::ImportError(ImportError::Corrupt))
} else {
Ok(normalize_filename(name))
}
}
fn get_media_list_data(archive: &mut ZipArchive<File>, meta: &Meta) -> Result<Vec<u8>> {
let mut file = archive.by_name("media")?;
let mut buf = Vec::new();
if meta.zstd_compressed() {
copy_decode(file, &mut buf)?;
} else {
io::copy(&mut file, &mut buf)?;
}
Ok(buf)
}
impl MediaEntries {
fn decode_safe_entries(buf: &[u8]) -> Result<Vec<SafeMediaEntry>> {
let entries: Self = Message::decode(buf)?;
entries
.entries
.into_iter()
.enumerate()
.map(SafeMediaEntry::from_entry)
.collect()
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn normalization() {
// legacy entries get normalized on deserialisation
let entry = SafeMediaEntry::from_legacy(("1", "con".to_owned())).unwrap();
assert_eq!(entry.name, "con_");
// new-style entries should have been normalized on export
let mut entries = Vec::new();
MediaEntries {
entries: vec![MediaEntry::new("con", 0, Vec::new())],
}
.encode(&mut entries)
.unwrap();
assert!(MediaEntries::decode_safe_entries(&entries).is_err());
}
}
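For context, the `media` file inside a legacy package is a JSON object mapping zip entry indices to file names, which is why `SafeMediaEntry::from_legacy()` receives an `(index, name)` pair and why the test above feeds it `("1", "con")`. A minimal standalone sketch of reading such a map (assumes a `serde_json` dependency; illustrative, not code from the diff):

```rust
use std::collections::HashMap;

fn main() -> serde_json::Result<()> {
    // shape of the legacy media list: {"<zip index>": "<file name>", ...}
    let media_json = r#"{"0": "hello.mp3", "1": "con"}"#;
    let map: HashMap<String, String> = serde_json::from_str(media_json)?;
    for (zip_index, name) in &map {
        // a real importer would also normalize/validate `name`,
        // as safe_normalized_file_name() does above
        println!("zip entry {zip_index} -> {name}");
    }
    Ok(())
}
```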


@ -1,7 +1,13 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{fs::File, io::Read};
use prost::Message;
use zip::ZipArchive;
pub(super) use crate::backend_proto::{package_metadata::Version, PackageMetadata as Meta};
use crate::{error::ImportError, prelude::*, storage::SchemaVersion};
impl Version {
pub(super) fn collection_filename(&self) -> &'static str {
@ -12,6 +18,16 @@ impl Version {
Version::Latest => "collection.anki21b",
}
}
/// Latest schema version that is supported by all clients supporting
/// this package version.
pub(super) fn schema_version(&self) -> SchemaVersion {
match self {
Version::Unknown => unreachable!(),
Version::Legacy1 | Version::Legacy2 => SchemaVersion::V11,
Version::Latest => SchemaVersion::V18,
}
}
}
impl Meta {
@ -27,10 +43,41 @@ impl Meta {
}
}
/// Extracts metadata from an archive and checks if its version is supported.
pub(super) fn from_archive(archive: &mut ZipArchive<File>) -> Result<Self> {
let meta_bytes = archive.by_name("meta").ok().and_then(|mut meta_file| {
let mut buf = vec![];
meta_file.read_to_end(&mut buf).ok()?;
Some(buf)
});
let meta = if let Some(bytes) = meta_bytes {
let meta: Meta = Message::decode(&*bytes)?;
if meta.version() == Version::Unknown {
return Err(AnkiError::ImportError(ImportError::TooNew));
}
meta
} else {
Meta {
version: if archive.by_name("collection.anki21").is_ok() {
Version::Legacy2
} else {
Version::Legacy1
} as i32,
}
};
Ok(meta)
}
pub(super) fn collection_filename(&self) -> &'static str {
self.version().collection_filename()
}
/// Latest schema version that is supported by all clients supporting
/// this package version.
pub(super) fn schema_version(&self) -> SchemaVersion {
self.version().schema_version()
}
pub(super) fn zstd_compressed(&self) -> bool {
!self.is_legacy()
}
@ -39,10 +86,6 @@ impl Meta {
self.is_legacy()
}
pub(super) fn strict_media_checks(&self) -> bool {
!self.is_legacy()
}
fn is_legacy(&self) -> bool {
matches!(self.version(), Version::Legacy1 | Version::Legacy2)
}
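The fallback in `Meta::from_archive()` comes down to: trust the protobuf `meta` entry when one exists, otherwise infer the legacy version from which collection file the archive carries. A rough standalone sketch of that decision, with a hypothetical `has_entry` closure standing in for `ZipArchive::by_name()`:

```rust
#[derive(Debug, PartialEq)]
enum GuessedVersion {
    FromMetaFile, // the real code decodes the protobuf Meta message instead
    Legacy2,      // archive contains collection.anki21
    Legacy1,      // only the older collection file is present
}

fn guess_version(has_entry: impl Fn(&str) -> bool) -> GuessedVersion {
    if has_entry("meta") {
        GuessedVersion::FromMetaFile
    } else if has_entry("collection.anki21") {
        GuessedVersion::Legacy2
    } else {
        GuessedVersion::Legacy1
    }
}

fn main() {
    let entries = ["collection.anki21", "media"];
    assert_eq!(
        guess_version(|name| entries.contains(&name)),
        GuessedVersion::Legacy2
    );
}
```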


@ -1,11 +1,15 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
mod apkg;
mod colpkg;
mod media;
mod meta;
pub(crate) use apkg::NoteMeta;
pub(crate) use colpkg::export::export_colpkg_from_data;
pub use colpkg::import::import_colpkg;
pub(self) use meta::{Meta, Version};
pub use crate::backend_proto::import_anki_package_response::{Log as NoteLog, Note as LogNote};
pub(self) use crate::backend_proto::{media_entries::MediaEntry, MediaEntries};


@ -1,7 +1,7 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::path::Path;
use std::path::{Component, Path};
use tempfile::NamedTempFile;
@ -42,6 +42,17 @@ pub(crate) fn read_dir_files(path: impl AsRef<Path>) -> std::io::Result<ReadDirF
std::fs::read_dir(path).map(ReadDirFiles)
}
/// True if name is a single, normal path component: no path separators, and not a special component like `.`, `..`, or a root/drive prefix.
pub(crate) fn filename_is_safe(name: &str) -> bool {
let mut components = Path::new(name).components();
let first_element_normal = components
.next()
.map(|component| matches!(component, Component::Normal(_)))
.unwrap_or_default();
first_element_normal && components.next().is_none()
}
pub(crate) struct ReadDirFiles(std::fs::ReadDir);
impl Iterator for ReadDirFiles {
@ -60,3 +71,24 @@ impl Iterator for ReadDirFiles {
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn path_traversal() {
assert!(filename_is_safe("foo"));
assert!(!filename_is_safe(".."));
assert!(!filename_is_safe("foo/bar"));
assert!(!filename_is_safe("/foo"));
assert!(!filename_is_safe("../foo"));
if cfg!(windows) {
assert!(!filename_is_safe("foo\\bar"));
assert!(!filename_is_safe("c:\\foo"));
assert!(!filename_is_safe("\\foo"));
}
}
}
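One detail the `cfg!(windows)` guard above reflects: `std::path` only treats backslashes as separators on Windows, so `foo\bar` parses as a single (odd but harmless) component on Unix. A quick illustration:

```rust
use std::path::{Component, Path};

fn main() {
    let components: Vec<Component> = Path::new("foo\\bar").components().collect();
    if cfg!(windows) {
        assert_eq!(components.len(), 2); // "foo" and "bar"
    } else {
        assert_eq!(components.len(), 1); // one component containing a backslash
    }
}
```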


@ -40,6 +40,7 @@ mod sync;
pub mod tags;
pub mod template;
pub mod template_filters;
pub(crate) mod tests;
pub mod text;
pub mod timestamp;
pub mod types;


@ -4,7 +4,6 @@
use std::{collections::HashMap, path::Path, time};
use crate::{
error::{AnkiError, Result},
log::{debug, Logger},
media::{
database::{MediaDatabaseContext, MediaEntry},
@ -13,11 +12,12 @@ use crate::{
NONSYNCABLE_FILENAME,
},
},
prelude::*,
};
struct FilesystemEntry {
fname: String,
sha1: Option<[u8; 20]>,
sha1: Option<Sha1Hash>,
mtime: i64,
is_new: bool,
}


@ -5,7 +5,7 @@ use std::{collections::HashMap, path::Path};
use rusqlite::{params, Connection, OptionalExtension, Row, Statement};
use crate::error::Result;
use crate::prelude::*;
fn trace(s: &str) {
println!("sql: {}", s)
@ -47,7 +47,7 @@ fn initial_db_setup(db: &mut Connection) -> Result<()> {
pub struct MediaEntry {
pub fname: String,
/// If None, file has been deleted
pub sha1: Option<[u8; 20]>,
pub sha1: Option<Sha1Hash>,
// Modification time; 0 if deleted
pub mtime: i64,
/// True if changed since last sync
@ -222,6 +222,14 @@ delete from media where fname=?"
Ok(map?)
}
/// Returns all filenames and checksums, where the checksum is not null.
pub(super) fn all_checksums(&mut self) -> Result<HashMap<String, Sha1Hash>> {
self.db
.prepare("SELECT fname, csum FROM media WHERE csum IS NOT NULL")?
.query_and_then([], row_to_name_and_checksum)?
.collect()
}
pub(super) fn force_resync(&mut self) -> Result<()> {
self.db
.execute_batch("delete from media; update meta set lastUsn = 0, dirMod = 0")
@ -231,7 +239,7 @@ delete from media where fname=?"
fn row_to_entry(row: &Row) -> rusqlite::Result<MediaEntry> {
// map the string checksum into bytes
let sha1_str: Option<String> = row.get(1)?;
let sha1_str = row.get_ref(1)?.as_str_or_null()?;
let sha1_array = if let Some(s) = sha1_str {
let mut arr = [0; 20];
match hex::decode_to_slice(s, arr.as_mut()) {
@ -250,6 +258,15 @@ fn row_to_entry(row: &Row) -> rusqlite::Result<MediaEntry> {
})
}
fn row_to_name_and_checksum(row: &Row) -> Result<(String, Sha1Hash)> {
let file_name = row.get(0)?;
let sha1_str: String = row.get(1)?;
let mut sha1 = [0; 20];
hex::decode_to_slice(sha1_str, &mut sha1)
.map_err(|_| AnkiError::invalid_input(format!("bad media checksum: {file_name}")))?;
Ok((file_name, sha1))
}
#[cfg(test)]
mod test {
use tempfile::NamedTempFile;


@ -15,10 +15,7 @@ use sha1::Sha1;
use unic_ucd_category::GeneralCategory;
use unicode_normalization::{is_nfc, UnicodeNormalization};
use crate::{
error::{AnkiError, Result},
log::{debug, Logger},
};
use crate::prelude::*;
/// The maximum length we allow a filename to be. When combined
/// with the rest of the path, the full path needs to be under ~240 chars
@ -164,7 +161,7 @@ pub fn add_data_to_folder_uniquely<'a, P>(
folder: P,
desired_name: &'a str,
data: &[u8],
sha1: [u8; 20],
sha1: Sha1Hash,
) -> io::Result<Cow<'a, str>>
where
P: AsRef<Path>,
@ -194,7 +191,7 @@ where
}
/// Convert foo.jpg into foo-abcde12345679.jpg
fn add_hash_suffix_to_file_stem(fname: &str, hash: &[u8; 20]) -> String {
pub(crate) fn add_hash_suffix_to_file_stem(fname: &str, hash: &Sha1Hash) -> String {
// when appending a hash to make unique, it will be 40 bytes plus the hyphen.
let max_len = MAX_FILENAME_LENGTH - 40 - 1;
@ -244,18 +241,18 @@ fn split_and_truncate_filename(fname: &str, max_bytes: usize) -> (&str, &str) {
};
// cap extension to 10 bytes so stem_len can't be negative
ext = truncate_to_char_boundary(ext, 10);
ext = truncated_to_char_boundary(ext, 10);
// cap stem, allowing for the . and a trailing _
let stem_len = max_bytes - ext.len() - 2;
stem = truncate_to_char_boundary(stem, stem_len);
stem = truncated_to_char_boundary(stem, stem_len);
(stem, ext)
}
/// Trim a string on a valid UTF8 boundary.
/// Return a substring on a valid UTF8 boundary.
/// Based on a function in the Rust stdlib.
fn truncate_to_char_boundary(s: &str, mut max: usize) -> &str {
fn truncated_to_char_boundary(s: &str, mut max: usize) -> &str {
if max >= s.len() {
s
} else {
@ -267,7 +264,7 @@ fn truncate_to_char_boundary(s: &str, mut max: usize) -> &str {
}
/// Return the SHA1 of a file if it exists, or None.
fn existing_file_sha1(path: &Path) -> io::Result<Option<[u8; 20]>> {
fn existing_file_sha1(path: &Path) -> io::Result<Option<Sha1Hash>> {
match sha1_of_file(path) {
Ok(o) => Ok(Some(o)),
Err(e) => {
@ -281,12 +278,17 @@ fn existing_file_sha1(path: &Path) -> io::Result<Option<[u8; 20]>> {
}
/// Return the SHA1 of a file, failing if it doesn't exist.
pub(crate) fn sha1_of_file(path: &Path) -> io::Result<[u8; 20]> {
pub(crate) fn sha1_of_file(path: &Path) -> io::Result<Sha1Hash> {
let mut file = fs::File::open(path)?;
sha1_of_reader(&mut file)
}
/// Return the SHA1 of a stream.
pub(crate) fn sha1_of_reader(reader: &mut impl Read) -> io::Result<Sha1Hash> {
let mut hasher = Sha1::new();
let mut buf = [0; 64 * 1024];
loop {
match file.read(&mut buf) {
match reader.read(&mut buf) {
Ok(0) => break,
Ok(n) => hasher.update(&buf[0..n]),
Err(e) => {
@ -302,7 +304,7 @@ pub(crate) fn sha1_of_file(path: &Path) -> io::Result<[u8; 20]> {
}
/// Return the SHA1 of provided data.
pub(crate) fn sha1_of_data(data: &[u8]) -> [u8; 20] {
pub(crate) fn sha1_of_data(data: &[u8]) -> Sha1Hash {
let mut hasher = Sha1::new();
hasher.update(data);
hasher.digest().bytes()
@ -371,7 +373,7 @@ pub(super) fn trash_folder(media_folder: &Path) -> Result<PathBuf> {
pub(super) struct AddedFile {
pub fname: String,
pub sha1: [u8; 20],
pub sha1: Sha1Hash,
pub mtime: i64,
pub renamed_from: Option<String>,
}
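As a usage sketch, `sha1_of_reader()` lets the importer hash any `Read` source (a zip entry, a file, or an in-memory buffer) with the same chunked loop. A simplified standalone version, assuming the `sha1` and `hex` crate APIs already used in this file, and without the real function's error handling:

```rust
use std::io::{Cursor, Read};

use sha1::Sha1;

fn sha1_of_reader(reader: &mut impl Read) -> std::io::Result<[u8; 20]> {
    let mut hasher = Sha1::new();
    let mut buf = [0; 64 * 1024];
    loop {
        match reader.read(&mut buf)? {
            0 => break,
            n => hasher.update(&buf[0..n]),
        }
    }
    Ok(hasher.digest().bytes())
}

fn main() -> std::io::Result<()> {
    let mut reader = Cursor::new(b"some media bytes".to_vec());
    println!("{}", hex::encode(sha1_of_reader(&mut reader)?));
    Ok(())
}
```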


@ -3,19 +3,21 @@
use std::{
borrow::Cow,
collections::HashMap,
path::{Path, PathBuf},
};
use rusqlite::Connection;
use slog::Logger;
use self::changetracker::ChangeTracker;
use crate::{
error::Result,
media::{
database::{open_or_create, MediaDatabaseContext, MediaEntry},
files::{add_data_to_folder_uniquely, mtime_as_i64, remove_files, sha1_of_data},
sync::{MediaSyncProgress, MediaSyncer},
},
prelude::*,
};
pub mod changetracker;
@ -24,6 +26,8 @@ pub mod database;
pub mod files;
pub mod sync;
pub type Sha1Hash = [u8; 20];
pub struct MediaManager {
db: Connection,
media_folder: PathBuf,
@ -153,4 +157,31 @@ impl MediaManager {
pub fn dbctx(&self) -> MediaDatabaseContext {
MediaDatabaseContext::new(&self.db)
}
pub fn all_checksums(
&self,
progress: impl FnMut(usize) -> bool,
log: &Logger,
) -> Result<HashMap<String, Sha1Hash>> {
let mut dbctx = self.dbctx();
ChangeTracker::new(&self.media_folder, progress, log).register_changes(&mut dbctx)?;
dbctx.all_checksums()
}
pub fn checksum_getter(&self) -> impl FnMut(&str) -> Result<Option<Sha1Hash>> + '_ {
let mut dbctx = self.dbctx();
move |fname: &str| {
dbctx
.get_entry(fname)
.map(|opt| opt.and_then(|entry| entry.sha1))
}
}
pub fn register_changes(
&self,
progress: &mut impl FnMut(usize) -> bool,
log: &Logger,
) -> Result<()> {
ChangeTracker::new(&self.media_folder, progress, log).register_changes(&mut self.dbctx())
}
}
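`all_checksums()` and `checksum_getter()` let the importer compare incoming media against the existing folder by hash instead of by content. A standalone sketch of that comparison; the names and the `HashMap` here are illustrative, not the crate's types:

```rust
use std::collections::HashMap;

type Sha1Hash = [u8; 20];

/// A copy is only needed when no existing file of that name has the same hash.
fn needs_copy(existing: &HashMap<String, Sha1Hash>, name: &str, incoming: Sha1Hash) -> bool {
    existing.get(name) != Some(&incoming)
}

fn main() {
    let mut existing = HashMap::new();
    existing.insert("foo.jpg".to_string(), [1u8; 20]);
    assert!(!needs_copy(&existing, "foo.jpg", [1u8; 20])); // identical file: skip
    assert!(needs_copy(&existing, "foo.jpg", [2u8; 20]));  // name clash: resolve conflict
    assert!(needs_copy(&existing, "bar.jpg", [3u8; 20]));  // new file: copy
}
```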


@ -55,6 +55,10 @@ impl Note {
&self.fields
}
pub fn into_fields(self) -> Vec<String> {
self.fields
}
pub fn set_field(&mut self, idx: usize, text: impl Into<String>) -> Result<()> {
if idx >= self.fields.len() {
return Err(AnkiError::invalid_input(
@ -320,7 +324,7 @@ fn invalid_char_for_field(c: char) -> bool {
}
impl Collection {
fn canonify_note_tags(&mut self, note: &mut Note, usn: Usn) -> Result<()> {
pub(crate) fn canonify_note_tags(&mut self, note: &mut Note, usn: Usn) -> Result<()> {
if !note.tags.is_empty() {
let tags = std::mem::take(&mut note.tags);
note.tags = self.canonify_tags(tags, usn)?.0;


@ -83,13 +83,23 @@ impl Collection {
}
/// Add a note, not adding any cards.
pub(super) fn add_note_only_undoable(&mut self, note: &mut Note) -> Result<(), AnkiError> {
pub(crate) fn add_note_only_undoable(&mut self, note: &mut Note) -> Result<(), AnkiError> {
self.storage.add_note(note)?;
self.save_undo(UndoableNoteChange::Added(Box::new(note.clone())));
Ok(())
}
/// Add a note, not adding any cards. Caller guarantees id is unique.
pub(crate) fn add_note_only_with_id_undoable(&mut self, note: &mut Note) -> Result<()> {
if self.storage.add_note_if_unique(note)? {
self.save_undo(UndoableNoteChange::Added(Box::new(note.clone())));
Ok(())
} else {
Err(AnkiError::invalid_input("note id existed"))
}
}
pub(crate) fn update_note_tags_undoable(
&mut self,
tags: &NoteTags,


@ -648,7 +648,7 @@ impl Collection {
/// - Caller must set notetype as modified if appropriate.
/// - This only supports undo when an existing notetype is passed in.
fn add_or_update_notetype_with_existing_id_inner(
pub(crate) fn add_or_update_notetype_with_existing_id_inner(
&mut self,
notetype: &mut Notetype,
original: Option<Notetype>,


@ -42,6 +42,17 @@ impl Collection {
Ok(())
}
/// Caller must ensure [NotetypeId] is unique.
pub(crate) fn add_notetype_with_unique_id_undoable(
&mut self,
notetype: &Notetype,
) -> Result<()> {
self.storage
.add_or_update_notetype_with_existing_id(notetype)?;
self.save_undo(UndoableNotetypeChange::Added(Box::new(notetype.clone())));
Ok(())
}
pub(super) fn update_notetype_undoable(
&mut self,
notetype: &Notetype,


@ -17,6 +17,7 @@ pub enum Op {
CreateCustomStudy,
EmptyFilteredDeck,
FindAndReplace,
Import,
RebuildFilteredDeck,
RemoveDeck,
RemoveNote,
@ -55,6 +56,7 @@ impl Op {
Op::AnswerCard => tr.actions_answer_card(),
Op::Bury => tr.studying_bury(),
Op::CreateCustomStudy => tr.actions_custom_study(),
Op::Import => tr.actions_import(),
Op::RemoveDeck => tr.decks_delete_deck(),
Op::RemoveNote => tr.studying_delete_note(),
Op::RenameDeck => tr.actions_rename_deck(),


@ -12,6 +12,7 @@ pub use crate::{
decks::{Deck, DeckId, DeckKind, NativeDeckName},
error::{AnkiError, Result},
i18n::I18n,
media::Sha1Hash,
notes::{Note, NoteId},
notetype::{Notetype, NotetypeId},
ops::{Op, OpChanges, OpOutput},


@ -28,9 +28,17 @@ impl Collection {
/// Add the provided revlog entry, modifying the ID if it is not unique.
pub(crate) fn add_revlog_entry_undoable(&mut self, mut entry: RevlogEntry) -> Result<RevlogId> {
entry.id = self.storage.add_revlog_entry(&entry, true)?;
entry.id = self.storage.add_revlog_entry(&entry, true)?.unwrap();
let id = entry.id;
self.save_undo(UndoableRevlogChange::Added(Box::new(entry)));
Ok(id)
}
/// Add the provided revlog entry, if its ID is unique.
pub(crate) fn add_revlog_entry_if_unique_undoable(&mut self, entry: RevlogEntry) -> Result<()> {
if self.storage.add_revlog_entry(&entry, false)?.is_some() {
self.save_undo(UndoableRevlogChange::Added(Box::new(entry)));
}
Ok(())
}
}
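`add_revlog_entry_if_unique_undoable()` relies on SQLite's `INSERT OR IGNORE`: the statement reports how many rows it actually inserted, and the undo step is only recorded when that count is non-zero. A minimal standalone sketch of the pattern with rusqlite, using a toy in-memory table rather than the real revlog schema:

```rust
use rusqlite::{params, Connection, Result};

fn add_if_unique(db: &Connection, id: i64, cid: i64) -> Result<bool> {
    let changed = db.execute(
        "INSERT OR IGNORE INTO revlog (id, cid) VALUES (?1, ?2)",
        params![id, cid],
    )?;
    Ok(changed == 1)
}

fn main() -> Result<()> {
    let db = Connection::open_in_memory()?;
    db.execute_batch("CREATE TABLE revlog (id INTEGER PRIMARY KEY, cid INTEGER)")?;
    assert!(add_if_unique(&db, 1, 10)?);  // inserted
    assert!(!add_if_unique(&db, 1, 10)?); // duplicate id silently ignored
    Ok(())
}
```

The same "rows changed == 1" check backs `add_card_if_unique()`, `add_note_if_unique()` and `add_deck_conf_if_unique()` further down.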


@ -134,6 +134,14 @@ impl Default for SearchBuilder {
}
impl SearchNode {
pub fn from_deck_id(did: impl Into<DeckId>, with_children: bool) -> Self {
if with_children {
Self::DeckIdWithChildren(did.into())
} else {
Self::DeckIdWithoutChildren(did.into())
}
}
/// Construct [SearchNode] from an unescaped deck name.
pub fn from_deck_name(name: &str) -> Self {
Self::Deck(escape_anki_wildcards_for_search_node(name))
@ -158,6 +166,14 @@ impl SearchNode {
name,
)))
}
pub fn from_note_ids<I: IntoIterator<Item = N>, N: Into<NoteId>>(ids: I) -> Self {
Self::NoteIds(ids.into_iter().map(Into::into).join(","))
}
pub fn from_card_ids<I: IntoIterator<Item = C>, C: Into<CardId>>(ids: I) -> Self {
Self::CardIds(ids.into_iter().map(Into::into).join(","))
}
}
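`from_note_ids()`/`from_card_ids()` only need to join the ids into a comma-separated list for the corresponding search node; the `join()` call above is presumably itertools'. A tiny sketch of that step (the `nid:` prefix is just an illustration of where the list ends up):

```rust
use itertools::Itertools;

fn main() {
    let note_ids: [i64; 2] = [1651234567890, 1651234567891];
    let joined = note_ids.iter().join(",");
    assert_eq!(joined, "1651234567890,1651234567891");
    println!("nid:{joined}");
}
```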
impl<T: Into<SearchNode>> From<T> for Node {


@ -38,7 +38,7 @@ impl Collection {
self.storage.get_all_revlog_entries(revlog_start)?
} else {
self.storage
.get_revlog_entries_for_searched_cards(revlog_start)?
.get_pb_revlog_entries_for_searched_cards(revlog_start)?
};
self.storage.clear_searched_cards_table()?;


@ -0,0 +1,41 @@
INSERT
OR IGNORE INTO cards (
id,
nid,
did,
ord,
mod,
usn,
type,
queue,
due,
ivl,
factor,
reps,
lapses,
left,
odue,
odid,
flags,
data
)
VALUES (
?,
?,
?,
?,
?,
?,
?,
?,
?,
?,
?,
?,
?,
?,
?,
?,
?,
?
)


@ -145,6 +145,34 @@ impl super::SqliteStorage {
Ok(())
}
/// Add card if id is unique. True if card was added.
pub(crate) fn add_card_if_unique(&self, card: &Card) -> Result<bool> {
self.db
.prepare_cached(include_str!("add_card_if_unique.sql"))?
.execute(params![
card.id,
card.note_id,
card.deck_id,
card.template_idx,
card.mtime,
card.usn,
card.ctype as u8,
card.queue as i8,
card.due,
card.interval,
card.ease_factor,
card.reps,
card.lapses,
card.remaining_steps,
card.original_due,
card.original_deck_id,
card.flags,
CardData::from_card(card),
])
.map(|n_rows| n_rows == 1)
.map_err(Into::into)
}
/// Add or update card, using the provided ID. Used for syncing & undoing.
pub(crate) fn add_or_update_card(&self, card: &Card) -> Result<()> {
let mut stmt = self.db.prepare_cached(include_str!("add_or_update.sql"))?;
@ -400,6 +428,20 @@ impl super::SqliteStorage {
.collect()
}
pub(crate) fn get_all_card_ids(&self) -> Result<HashSet<CardId>> {
self.db
.prepare("SELECT id FROM cards")?
.query_and_then([], |row| Ok(row.get(0)?))?
.collect()
}
pub(crate) fn all_cards_as_nid_and_ord(&self) -> Result<HashSet<(NoteId, u16)>> {
self.db
.prepare("SELECT nid, ord FROM cards")?
.query_and_then([], |r| Ok((NoteId(r.get(0)?), r.get(1)?)))?
.collect()
}
pub(crate) fn card_ids_of_notes(&self, nids: &[NoteId]) -> Result<Vec<CardId>> {
let mut stmt = self
.db
@ -455,6 +497,16 @@ impl super::SqliteStorage {
Ok(nids)
}
/// Place the ids of cards with notes in 'search_nids' into 'search_cids'.
/// Returns number of added cards.
pub(crate) fn search_cards_of_notes_into_table(&self) -> Result<usize> {
self.setup_searched_cards_table()?;
self.db
.prepare(include_str!("search_cards_of_notes_into_table.sql"))?
.execute([])
.map_err(Into::into)
}
pub(crate) fn all_searched_cards(&self) -> Result<Vec<Card>> {
self.db
.prepare_cached(concat!(


@ -0,0 +1,7 @@
INSERT INTO search_cids
SELECT id
FROM cards
WHERE nid IN (
SELECT nid
FROM search_nids
)


@ -74,6 +74,15 @@ impl SqliteStorage {
.transpose()
}
pub(crate) fn get_deck_by_name(&self, machine_name: &str) -> Result<Option<Deck>> {
self.db
.prepare_cached(concat!(include_str!("get_deck.sql"), " WHERE name = ?"))?
.query_and_then([machine_name], row_to_deck)?
.next()
.transpose()
.map_err(Into::into)
}
pub(crate) fn get_all_decks(&self) -> Result<Vec<Deck>> {
self.db
.prepare(include_str!("get_deck.sql"))?
@ -111,6 +120,17 @@ impl SqliteStorage {
.map_err(Into::into)
}
pub(crate) fn get_decks_for_search_cards(&self) -> Result<Vec<Deck>> {
self.db
.prepare_cached(concat!(
include_str!("get_deck.sql"),
" WHERE id IN (SELECT DISTINCT did FROM cards WHERE id IN",
" (SELECT cid FROM search_cids))",
))?
.query_and_then([], row_to_deck)?
.collect()
}
// caller should ensure name unique
pub(crate) fn add_deck(&self, deck: &mut Deck) -> Result<()> {
assert!(deck.id.0 == 0);


@ -0,0 +1,3 @@
INSERT
OR IGNORE INTO deck_config (id, name, mtime_secs, usn, config)
VALUES (?, ?, ?, ?, ?);


@ -67,6 +67,22 @@ impl SqliteStorage {
Ok(())
}
pub(crate) fn add_deck_conf_if_unique(&self, conf: &DeckConfig) -> Result<bool> {
let mut conf_bytes = vec![];
conf.inner.encode(&mut conf_bytes)?;
self.db
.prepare_cached(include_str!("add_if_unique.sql"))?
.execute(params![
conf.id,
conf.name,
conf.mtime_secs,
conf.usn,
conf_bytes,
])
.map(|added| added == 1)
.map_err(Into::into)
}
pub(crate) fn update_deck_conf(&self, conf: &DeckConfig) -> Result<()> {
let mut conf_bytes = vec![];
conf.inner.encode(&mut conf_bytes)?;


@ -34,9 +34,10 @@ impl SchemaVersion {
}
/// Write a list of IDs as '(x,y,...)' into the provided string.
pub(crate) fn ids_to_string<T>(buf: &mut String, ids: &[T])
pub(crate) fn ids_to_string<D, I>(buf: &mut String, ids: I)
where
T: std::fmt::Display,
D: std::fmt::Display,
I: IntoIterator<Item = D>,
{
buf.push('(');
write_comma_separated_ids(buf, ids);
@ -44,15 +45,18 @@ where
}
/// Write a list of Ids as 'x,y,...' into the provided string.
pub(crate) fn write_comma_separated_ids<T>(buf: &mut String, ids: &[T])
pub(crate) fn write_comma_separated_ids<D, I>(buf: &mut String, ids: I)
where
T: std::fmt::Display,
D: std::fmt::Display,
I: IntoIterator<Item = D>,
{
if !ids.is_empty() {
for id in ids.iter().skip(1) {
write!(buf, "{},", id).unwrap();
}
write!(buf, "{}", ids[0]).unwrap();
let mut trailing_sep = false;
for id in ids {
write!(buf, "{},", id).unwrap();
trailing_sep = true;
}
if trailing_sep {
buf.pop();
}
}
@ -73,17 +77,17 @@ mod test {
#[test]
fn ids_string() {
let mut s = String::new();
ids_to_string::<u8>(&mut s, &[]);
ids_to_string(&mut s, &[0; 0]);
assert_eq!(s, "()");
s.clear();
ids_to_string(&mut s, &[7]);
assert_eq!(s, "(7)");
s.clear();
ids_to_string(&mut s, &[7, 6]);
assert_eq!(s, "(6,7)");
assert_eq!(s, "(7,6)");
s.clear();
ids_to_string(&mut s, &[7, 6, 5]);
assert_eq!(s, "(6,5,7)");
assert_eq!(s, "(7,6,5)");
s.clear();
}
}
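Note the behavioural tweak the updated test captures: the old slice-based `write_comma_separated_ids()` emitted the first id last (`[7, 6]` became `6,7`), while the iterator-based version writes ids in iteration order and simply pops the trailing comma. A standalone sketch of the new technique:

```rust
use std::fmt::Write;

fn write_comma_separated<D: std::fmt::Display>(buf: &mut String, ids: impl IntoIterator<Item = D>) {
    let mut trailing_sep = false;
    for id in ids {
        write!(buf, "{},", id).unwrap();
        trailing_sep = true;
    }
    if trailing_sep {
        buf.pop(); // drop the final comma
    }
}

fn main() {
    let mut buf = String::new();
    write_comma_separated(&mut buf, [7, 6, 5]);
    assert_eq!(buf, "7,6,5");
    buf.clear();
    write_comma_separated(&mut buf, std::iter::empty::<i64>());
    assert_eq!(buf, "");
}
```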


@ -0,0 +1,27 @@
INSERT
OR IGNORE INTO notes (
id,
guid,
mid,
mod,
usn,
tags,
flds,
sfld,
csum,
flags,
data
)
VALUES (
?,
?,
?,
?,
?,
?,
?,
?,
?,
0,
""
)


@ -1,12 +1,13 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use rusqlite::{params, Row};
use crate::{
error::Result,
import_export::package::NoteMeta,
notes::{Note, NoteId, NoteTags},
notetype::NotetypeId,
tags::{join_tags, split_tags},
@ -41,6 +42,13 @@ impl super::SqliteStorage {
.transpose()
}
pub fn get_all_note_ids(&self) -> Result<HashSet<NoteId>> {
self.db
.prepare("SELECT id FROM notes")?
.query_and_then([], |row| Ok(row.get(0)?))?
.collect()
}
/// If fields have been modified, caller must call note.prepare_for_update() prior to calling this.
pub(crate) fn update_note(&self, note: &Note) -> Result<()> {
assert!(note.id.0 != 0);
@ -77,6 +85,24 @@ impl super::SqliteStorage {
Ok(())
}
pub(crate) fn add_note_if_unique(&self, note: &Note) -> Result<bool> {
self.db
.prepare_cached(include_str!("add_if_unique.sql"))?
.execute(params![
note.id,
note.guid,
note.notetype_id,
note.mtime,
note.usn,
join_tags(&note.tags),
join_fields(note.fields()),
note.sort_field.as_ref().unwrap(),
note.checksum.unwrap(),
])
.map(|added| added == 1)
.map_err(Into::into)
}
/// Add or update the provided note, preserving ID. Used by the syncing code.
pub(crate) fn add_or_update_note(&self, note: &Note) -> Result<()> {
let mut stmt = self.db.prepare_cached(include_str!("add_or_update.sql"))?;
@ -210,6 +236,16 @@ impl super::SqliteStorage {
Ok(())
}
pub(crate) fn all_searched_notes(&self) -> Result<Vec<Note>> {
self.db
.prepare_cached(concat!(
include_str!("get.sql"),
" WHERE id IN (SELECT nid FROM search_nids)"
))?
.query_and_then([], |r| row_to_note(r).map_err(Into::into))?
.collect()
}
pub(crate) fn get_note_tags_by_predicate<F>(&mut self, want: F) -> Result<Vec<NoteTags>>
where
F: Fn(&str) -> bool,
@ -259,6 +295,24 @@ impl super::SqliteStorage {
Ok(())
}
pub(crate) fn note_guid_map(&mut self) -> Result<HashMap<String, NoteMeta>> {
self.db
.prepare("SELECT guid, id, mod, mid FROM notes")?
.query_and_then([], row_to_note_meta)?
.collect()
}
#[cfg(test)]
pub(crate) fn get_all_notes(&mut self) -> Vec<Note> {
self.db
.prepare("SELECT * FROM notes")
.unwrap()
.query_and_then([], row_to_note)
.unwrap()
.collect::<Result<_>>()
.unwrap()
}
}
fn row_to_note(row: &Row) -> Result<Note> {
@ -285,3 +339,10 @@ fn row_to_note_tags(row: &Row) -> Result<NoteTags> {
tags: row.get(3)?,
})
}
fn row_to_note_meta(row: &Row) -> Result<(String, NoteMeta)> {
Ok((
row.get(0)?,
NoteMeta::new(row.get(1)?, row.get(2)?, row.get(3)?),
))
}
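`note_guid_map()` gives the importer a guid-keyed view of the target collection, so each incoming note can be matched against an existing one. The exact conflict policy isn't part of this hunk; the sketch below is only one plausible shape of that decision, with the field names and the mtime rule as assumptions:

```rust
use std::collections::HashMap;

struct NoteMeta {
    id: i64,
    mtime: i64,
}

enum Action {
    AddNew,
    UpdateExisting(i64),
    Skip,
}

fn decide(existing: &HashMap<String, NoteMeta>, guid: &str, incoming_mtime: i64) -> Action {
    match existing.get(guid) {
        None => Action::AddNew,
        Some(meta) if incoming_mtime > meta.mtime => Action::UpdateExisting(meta.id),
        Some(_) => Action::Skip,
    }
}

fn main() {
    let mut existing = HashMap::new();
    existing.insert("abc".to_string(), NoteMeta { id: 1, mtime: 100 });
    assert!(matches!(decide(&existing, "abc", 200), Action::UpdateExisting(1)));
    assert!(matches!(decide(&existing, "abc", 50), Action::Skip));
    assert!(matches!(decide(&existing, "xyz", 10), Action::AddNew));
}
```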


@ -99,6 +99,23 @@ impl SqliteStorage {
.map_err(Into::into)
}
pub(crate) fn get_notetypes_for_search_notes(&self) -> Result<Vec<Notetype>> {
self.db
.prepare_cached(concat!(
include_str!("get_notetype.sql"),
" WHERE id IN (SELECT DISTINCT mid FROM notes WHERE id IN",
" (SELECT nid FROM search_nids))",
))?
.query_and_then([], |r| {
row_to_notetype_core(r).and_then(|mut nt| {
nt.fields = self.get_notetype_fields(nt.id)?;
nt.templates = self.get_notetype_templates(nt.id)?;
Ok(nt)
})
})?
.collect()
}
pub fn get_all_notetype_names(&self) -> Result<Vec<(NotetypeId, String)>> {
self.db
.prepare_cached(include_str!("get_notetype_names.sql"))?


@ -61,16 +61,19 @@ impl SqliteStorage {
Ok(())
}
/// Returns the used id, which may differ if `ensure_unique` is true.
/// Adds the entry if its id is unique. If it is not, and `uniquify` is true,
/// adds it with a new id. Returns the added id, or [None] if nothing was added
/// (i.e. the option is safe to unwrap if `uniquify` is true).
pub(crate) fn add_revlog_entry(
&self,
entry: &RevlogEntry,
ensure_unique: bool,
) -> Result<RevlogId> {
self.db
uniquify: bool,
) -> Result<Option<RevlogId>> {
let added = self
.db
.prepare_cached(include_str!("add.sql"))?
.execute(params![
ensure_unique,
uniquify,
entry.id,
entry.cid,
entry.usn,
@ -81,7 +84,7 @@ impl SqliteStorage {
entry.taken_millis,
entry.review_kind as u8
])?;
Ok(RevlogId(self.db.last_insert_rowid()))
Ok((added > 0).then(|| RevlogId(self.db.last_insert_rowid())))
}
pub(crate) fn get_revlog_entry(&self, id: RevlogId) -> Result<Option<RevlogEntry>> {
@ -107,7 +110,7 @@ impl SqliteStorage {
.collect()
}
pub(crate) fn get_revlog_entries_for_searched_cards(
pub(crate) fn get_pb_revlog_entries_for_searched_cards(
&self,
after: TimestampSecs,
) -> Result<Vec<pb::RevlogEntry>> {
@ -120,6 +123,16 @@ impl SqliteStorage {
.collect()
}
pub(crate) fn get_revlog_entries_for_searched_cards(&self) -> Result<Vec<RevlogEntry>> {
self.db
.prepare_cached(concat!(
include_str!("get.sql"),
" where cid in (select cid from search_cids)"
))?
.query_and_then([], row_to_revlog_entry)?
.collect()
}
/// This includes entries from deleted cards.
pub(crate) fn get_all_revlog_entries(
&self,

rslib/src/tests.rs (new file, 63 lines)

@ -0,0 +1,63 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
#![cfg(test)]
use tempfile::{tempdir, TempDir};
use crate::{collection::CollectionBuilder, media::MediaManager, prelude::*};
pub(crate) fn open_fs_test_collection(name: &str) -> (Collection, TempDir) {
let tempdir = tempdir().unwrap();
let dir = tempdir.path();
let media_folder = dir.join(format!("{name}.media"));
std::fs::create_dir(&media_folder).unwrap();
let col = CollectionBuilder::new(dir.join(format!("{name}.anki2")))
.set_media_paths(media_folder, dir.join(format!("{name}.mdb")))
.build()
.unwrap();
(col, tempdir)
}
impl Collection {
pub(crate) fn add_media(&self, media: &[(&str, &[u8])]) {
let mgr = MediaManager::new(&self.media_folder, &self.media_db).unwrap();
let mut ctx = mgr.dbctx();
for (name, data) in media {
mgr.add_file(&mut ctx, name, data).unwrap();
}
}
pub(crate) fn new_note(&mut self, notetype: &str) -> Note {
self.get_notetype_by_name(notetype)
.unwrap()
.unwrap()
.new_note()
}
pub(crate) fn add_new_note(&mut self, notetype: &str) -> Note {
let mut note = self.new_note(notetype);
self.add_note(&mut note, DeckId(1)).unwrap();
note
}
pub(crate) fn get_all_notes(&mut self) -> Vec<Note> {
self.storage.get_all_notes()
}
pub(crate) fn add_deck_with_machine_name(&mut self, name: &str, filtered: bool) -> Deck {
let mut deck = new_deck_with_machine_name(name, filtered);
self.add_deck_inner(&mut deck, Usn(1)).unwrap();
deck
}
}
pub(crate) fn new_deck_with_machine_name(name: &str, filtered: bool) -> Deck {
let mut deck = if filtered {
Deck::new_filtered()
} else {
Deck::new_normal()
};
deck.name = NativeDeckName::from_native_str(name);
deck
}
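A hypothetical test built on these helpers might look like the following; everything outside this file (the default "Basic" notetype, the `Note` fields touched) is assumed rather than taken from the diff:

```rust
#[test]
fn helper_usage() {
    let (mut col, _tempdir) = open_fs_test_collection("helper_usage");
    col.add_media(&[("foo.jpg", &b"fake image bytes"[..])]);
    let note = col.add_new_note("Basic");
    let all = col.get_all_notes();
    assert_eq!(all.len(), 1);
    assert_eq!(all[0].id, note.id);
}
```

The `TempDir` is held for the test's duration so the collection and media folder aren't deleted early.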


@ -31,6 +31,31 @@ impl Trimming for Cow<'_, str> {
}
}
pub(crate) trait CowMapping<'a, B: ?Sized + 'a + ToOwned> {
/// Returns [self]
/// - unchanged, if the given function returns [Cow::Borrowed]
/// - with the new value, if the given function returns [Cow::Owned]
fn map_cow(self, f: impl FnOnce(&B) -> Cow<B>) -> Self;
fn get_owned(self) -> Option<B::Owned>;
}
impl<'a, B: ?Sized + 'a + ToOwned> CowMapping<'a, B> for Cow<'a, B> {
fn map_cow(self, f: impl FnOnce(&B) -> Cow<B>) -> Self {
if let Cow::Owned(o) = f(&self) {
Cow::Owned(o)
} else {
self
}
}
fn get_owned(self) -> Option<B::Owned> {
match self {
Cow::Borrowed(_) => None,
Cow::Owned(s) => Some(s),
}
}
}
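The point of `map_cow()` is to chain borrow-preserving string transforms without repeating the `if let Cow::Owned(..)` dance that the removed version of `html_to_text_line()` further down still shows. A standalone sketch of the same idea using only the standard library:

```rust
use std::borrow::Cow;

fn map_cow<'a>(input: Cow<'a, str>, f: impl FnOnce(&str) -> Cow<str>) -> Cow<'a, str> {
    if let Cow::Owned(o) = f(&input) {
        Cow::Owned(o)
    } else {
        input
    }
}

fn strip_exclamations(s: &str) -> Cow<str> {
    if s.contains('!') {
        s.replace('!', "").into()
    } else {
        s.into()
    }
}

fn main() {
    let untouched = map_cow(Cow::Borrowed("plain"), strip_exclamations);
    assert!(matches!(untouched, Cow::Borrowed(_))); // no allocation happened
    let changed = map_cow(Cow::Borrowed("loud!!"), strip_exclamations);
    assert_eq!(changed, "loud");
}
```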
#[derive(Debug, PartialEq)]
pub enum AvTag {
SoundOrVideo(String),
@ -115,34 +140,48 @@ lazy_static! {
|
\[\[type:[^]]+\]\]
").unwrap();
/// Files included in CSS with a leading underscore.
static ref UNDERSCORED_CSS_IMPORTS: Regex = Regex::new(
r#"(?xi)
(?:@import\s+ # import statement with a bare
"(_[^"]*.css)" # double quoted
| # or
'(_[^']*.css)' # single quoted css filename
)
| # or
(?:url\(\s* # a url function with a
"(_[^"]+)" # double quoted
| # or
'(_[^']+)' # single quoted
| # or
(_.+) # unquoted filename
\s*\))
"#).unwrap();
/// Strings, src and data attributes with a leading underscore.
static ref UNDERSCORED_REFERENCES: Regex = Regex::new(
r#"(?x)
"(_[^"]+)" # double quoted
| # or
'(_[^']+)' # single quoted string
| # or
\b(?:src|data) # a 'src' or 'data' attribute
= # followed by
(_[^ >]+) # an unquoted value
"#).unwrap();
}
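Both new regexes rely on alternation, so whichever branch matched supplies the capture group holding the file name; the extractors below then walk the groups with `or_else()`. A minimal sketch of that capture-group fallback with the regex crate:

```rust
use regex::Regex;

fn main() {
    // double- or single-quoted names with a leading underscore
    let re = Regex::new(r#""(_[^"]+)"|'(_[^']+)'"#).unwrap();
    let mut found = Vec::new();
    for caps in re.captures_iter(r#"src="_a.png" data='_b.js'"#) {
        let name = caps.get(1).or_else(|| caps.get(2)).unwrap().as_str();
        found.push(name);
    }
    assert_eq!(found, vec!["_a.png", "_b.js"]);
}
```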
pub fn html_to_text_line(html: &str) -> Cow<str> {
let mut out: Cow<str> = html.into();
if let Cow::Owned(o) = PERSISTENT_HTML_SPACERS.replace_all(&out, " ") {
out = o.into();
}
if let Cow::Owned(o) = UNPRINTABLE_TAGS.replace_all(&out, "") {
out = o.into();
}
if let Cow::Owned(o) = strip_html_preserving_media_filenames(&out) {
out = o.into();
}
out.trim()
PERSISTENT_HTML_SPACERS
.replace_all(html, " ")
.map_cow(|s| UNPRINTABLE_TAGS.replace_all(s, ""))
.map_cow(strip_html_preserving_media_filenames)
.trim()
}
pub fn strip_html(html: &str) -> Cow<str> {
let mut out: Cow<str> = html.into();
if let Cow::Owned(o) = strip_html_preserving_entities(html) {
out = o.into();
}
if let Cow::Owned(o) = decode_entities(out.as_ref()) {
out = o.into();
}
out
strip_html_preserving_entities(html).map_cow(decode_entities)
}
pub fn strip_html_preserving_entities(html: &str) -> Cow<str> {
@ -161,18 +200,29 @@ pub fn decode_entities(html: &str) -> Cow<str> {
}
}
pub(crate) fn newlines_to_spaces(text: &str) -> Cow<str> {
if text.contains('\n') {
text.replace('\n', " ").into()
} else {
text.into()
}
}
pub fn strip_html_for_tts(html: &str) -> Cow<str> {
let mut out: Cow<str> = html.into();
if let Cow::Owned(o) = HTML_LINEBREAK_TAGS.replace_all(html, " ") {
out = o.into();
}
if let Cow::Owned(o) = strip_html(out.as_ref()) {
out = o.into();
}
out
HTML_LINEBREAK_TAGS
.replace_all(html, " ")
.map_cow(strip_html)
}
/// Truncate a String on a valid UTF8 boundary.
pub(crate) fn truncate_to_char_boundary(s: &mut String, mut max: usize) {
if max >= s.len() {
return;
}
while !s.is_char_boundary(max) {
max -= 1;
}
s.truncate(max);
}
#[derive(Debug)]
@ -216,6 +266,61 @@ pub(crate) fn extract_media_refs(text: &str) -> Vec<MediaRef> {
out
}
/// Calls `replacer` for every media reference in `text`, and optionally
/// replaces it with something else. [None] if no reference was found.
pub(crate) fn replace_media_refs(
text: &str,
mut replacer: impl FnMut(&str) -> Option<String>,
) -> Option<String> {
let mut rep = |caps: &Captures| {
let whole_match = caps.get(0).unwrap().as_str();
let old_name = caps.iter().skip(1).find_map(|g| g).unwrap().as_str();
let old_name_decoded = decode_entities(old_name);
if let Some(mut new_name) = replacer(&old_name_decoded) {
if matches!(old_name_decoded, Cow::Owned(_)) {
new_name = htmlescape::encode_minimal(&new_name);
}
whole_match.replace(old_name, &new_name)
} else {
whole_match.to_owned()
}
};
HTML_MEDIA_TAGS
.replace_all(text, &mut rep)
.map_cow(|s| AV_TAGS.replace_all(s, &mut rep))
.get_owned()
}
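The encoding branch above is the subtle part: the replacer sees the entity-decoded name, and if decoding actually changed it, the new name is re-encoded before being spliced back into the HTML. A small sketch of that round trip, assuming the htmlescape crate used above:

```rust
fn main() {
    let as_written_in_html = "foo&amp;bar.jpg"; // what HTML_MEDIA_TAGS captures
    let what_replacer_sees = "foo&bar.jpg";     // after entity decoding
    let new_name = "foo&bar-renamed.jpg";       // replacer output
    // re-encode so the substitution stays valid inside the HTML
    assert_eq!(
        htmlescape::encode_minimal(new_name),
        "foo&amp;bar-renamed.jpg"
    );
    let _ = (as_written_in_html, what_replacer_sees);
}
```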
pub(crate) fn extract_underscored_css_imports(text: &str) -> Vec<&str> {
UNDERSCORED_CSS_IMPORTS
.captures_iter(text)
.map(|caps| {
caps.get(1)
.or_else(|| caps.get(2))
.or_else(|| caps.get(3))
.or_else(|| caps.get(4))
.or_else(|| caps.get(5))
.unwrap()
.as_str()
})
.collect()
}
pub(crate) fn extract_underscored_references(text: &str) -> Vec<&str> {
UNDERSCORED_REFERENCES
.captures_iter(text)
.map(|caps| {
caps.get(1)
.or_else(|| caps.get(2))
.or_else(|| caps.get(3))
.unwrap()
.as_str()
})
.collect()
}
pub fn strip_html_preserving_media_filenames(html: &str) -> Cow<str> {
let without_fnames = HTML_MEDIA_TAGS.replace_all(html, r" ${1}${2}${3} ");
let without_html = strip_html(&without_fnames);
@ -463,4 +568,54 @@ mod test {
assert!(!is_glob(r"\\\_"));
assert!(glob_matcher(r"foo\*bar*")("foo*bar123"));
}
#[test]
fn extracting() {
assert_eq!(
extract_underscored_css_imports(concat!(
"@IMPORT '_foo.css'\n",
"@import \"_bar.css\"\n",
"@import '_baz.css'\n",
"@import 'nope.css'\n",
"url(_foo.css)\n",
"URL(\"_bar.css\")\n",
"@import url('_baz.css')\n",
"url('nope.css')\n",
)),
vec!["_foo.css", "_bar.css", "_baz.css", "_foo.css", "_bar.css", "_baz.css",]
);
assert_eq!(
extract_underscored_references(concat!(
"<img src=\"_foo.jpg\">",
"<object data=\"_bar\">",
"\"_baz.js\"",
"\"nope.js\"",
"<img src=_foo.jpg>",
"<object data=_bar>",
"'_baz.js'",
)),
vec!["_foo.jpg", "_bar", "_baz.js", "_foo.jpg", "_bar", "_baz.js",]
);
}
#[test]
fn replacing() {
assert_eq!(
&replace_media_refs("<img src=foo.jpg>[sound:bar.mp3]<img src=baz.jpg>", |s| {
(s != "baz.jpg").then(|| "spam".to_string())
})
.unwrap(),
"<img src=spam>[sound:spam]<img src=baz.jpg>",
);
}
#[test]
fn truncate() {
let mut s = "日本語".to_string();
truncate_to_char_boundary(&mut s, 6);
assert_eq!(&s, "日本");
let mut s = "日本語".to_string();
truncate_to_char_boundary(&mut s, 1);
assert_eq!(&s, "");
}
}