Anki/qt/aqt/import_export/importing.py
2022-05-09 11:48:40 +02:00

177 lines
5.8 KiB
Python

# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
from __future__ import annotations
from itertools import chain
import aqt.main
from anki.collection import Collection, ImportLogWithChanges, Progress
from anki.errors import Interrupted
from anki.foreign_data import mnemosyne
from aqt.operations import CollectionOp, QueryOp
from aqt.progress import ProgressUpdate
from aqt.qt import *
from aqt.utils import askUser, getFile, showInfo, showText, showWarning, tooltip, tr
def import_file(mw: aqt.main.AnkiQt, path: str) -> None:
    """Dispatch `path` to the importer that matches its file name.

    The base name is lowercased before matching, so extensions are compared
    case-insensitively. `.anki` and `.anki2` files are not imported; an
    explanatory message is shown instead. Collection packages go through a
    confirmation prompt. `.apkg`/`.zip` files and `.db` (Mnemosyne) files have
    dedicated importers. Anything else opens the generic import dialog.
    """
    name = os.path.basename(path).lower()

    if name.endswith(".anki"):
        showInfo(tr.importing_anki_files_are_from_a_very())
        return
    if name.endswith(".anki2"):
        showInfo(tr.importing_anki2_files_are_not_directly_importable())
        return
    # Checked before the generic ".apkg" case because some collection packages
    # ("collection.apkg", "backup-*.apkg") also carry the ".apkg" extension.
    if is_collection_package(name):
        maybe_import_collection_package(mw, path)
        return
    if name.endswith((".apkg", ".zip")):
        import_anki_package(mw, path)
        return
    if name.endswith(".db"):
        import_mnemosyne(mw, path)
        return

    # Imported here, only on the fallback path that actually needs the dialog.
    import aqt.import_export.import_dialog

    aqt.import_export.import_dialog.ImportDialog(mw, path)
def prompt_for_file_then_import(mw: aqt.main.AnkiQt) -> None:
    """Let the user pick a file and import it; do nothing if no path is returned."""
    path = get_file_path(mw)
    if not path:
        return
    import_file(mw, path)
def get_file_path(mw: aqt.main.AnkiQt) -> str | None:
    """Show the import file picker and return the chosen path as a string.

    Returns None when `getFile` yields a falsy result (presumably the user
    cancelled the dialog).
    """
    chosen = getFile(
        mw,
        tr.actions_import(),
        None,
        key="import",
        filter=file_filter(),
    )
    return str(chosen) if chosen else None
def file_filter() -> str:
    """Return the file-dialog filter string listing the supported import formats.

    Individual filters are joined with ";;", the separator Qt file dialogs
    use between multiple name filters.
    """
    filters = [
        tr.importing_packaged_anki_deckcollection_apkg_colpkg_zip(),
        tr.importing_text_separated_by_tabs_or_semicolons(),
        tr.importing_mnemosyne_20_deck_db(),
    ]
    return ";;".join(filters)
def is_collection_package(filename: str) -> bool:
    """Return True if `filename` names a whole-collection package.

    Matches any `*.colpkg`, the exact name `collection.apkg`, and
    `backup-*.apkg`. Matching is case-sensitive; `import_file` lowercases the
    name before calling this.
    """
    if filename.endswith(".colpkg"):
        return True
    if filename == "collection.apkg":
        return True
    return filename.startswith("backup-") and filename.endswith(".apkg")
def maybe_import_collection_package(mw: aqt.main.AnkiQt, path: str) -> None:
    """Ask for confirmation, then import the collection package at `path`.

    The warning prompt states that the existing collection will be deleted and
    defaults to "no" (`defaultno=True`). The import only starts when `askUser`
    returns a truthy answer.
    """
    confirmed = askUser(
        tr.importing_this_will_delete_your_existing_collection(),
        msgfunc=QMessageBox.warning,
        defaultno=True,
    )
    if not confirmed:
        return
    import_collection_package(mw, path)
def import_collection_package(mw: aqt.main.AnkiQt, file: str) -> None:
    """Back up the current collection, then replace it with the package `file`.

    Sequence:
    1. Run `mw.create_backup_now()` in a background QueryOp with a progress window.
    2. When the backup succeeds, unload the collection and, from the unload
       callback, start the package import in the background.
    3. Whether the import succeeds or fails, load the collection again. On
       success a "complete" tooltip is shown. On failure a warning is shown,
       unless the error is `Interrupted`.
    """

    def reload_and_report_success() -> None:
        mw.loadCollection()
        tooltip(tr.importing_importing_complete())

    def reload_and_report_failure(err: Exception) -> None:
        mw.loadCollection()
        # Interrupted is not surfaced to the user; presumably it signals a
        # user-requested abort (see import_progress_update) — confirm.
        if isinstance(err, Interrupted):
            return
        showWarning(str(err))

    def start_import() -> None:
        restore = import_collection_package_op(mw, file, reload_and_report_success)
        restore.failure(reload_and_report_failure).run_in_background()

    def after_backup(_: object) -> None:
        mw.unloadCollection(start_import)

    backup = QueryOp(
        parent=mw,
        op=lambda _: mw.create_backup_now(),
        success=after_backup,
    )
    backup.with_progress().run_in_background()
def import_collection_package_op(
    mw: aqt.main.AnkiQt, path: str, success: Callable[[], None]
) -> QueryOp[None]:
    """Build, without starting, a QueryOp that imports the collection package at `path`.

    When run, the op calls the backend's `import_collection_package`. It passes
    the current profile's collection path, its `collection.media` folder, and
    its `collection.media.db2` file. Backend progress is routed through
    `import_progress_update`. On success, `success` is called with no arguments.
    """

    def restore(_: Collection) -> None:
        collection_file = mw.pm.collectionPath()
        media_dir = os.path.join(mw.pm.profileFolder(), "collection.media")
        media_database = os.path.join(mw.pm.profileFolder(), "collection.media.db2")
        mw.backend.import_collection_package(
            col_path=collection_file,
            backup_path=path,
            media_folder=media_dir,
            media_db=media_database,
        )

    query = QueryOp(parent=mw, op=restore, success=lambda _: success())
    return query.with_backend_progress(import_progress_update)
def import_anki_package(mw: aqt.main.AnkiQt, path: str) -> None:
    """Import the Anki package at `path` into the open collection, in the background.

    Progress is reported via `import_progress_update`. On success, the
    resulting import log is shown with `show_import_log`.
    """
    collection_op = CollectionOp(
        parent=mw,
        op=lambda col: col.import_anki_package(path),
    )
    collection_op = collection_op.with_backend_progress(import_progress_update)
    collection_op = collection_op.success(show_import_log)
    collection_op.run_in_background()
def import_mnemosyne(mw: aqt.main.AnkiQt, path: str) -> None:
    """Import the Mnemosyne database at `path`.

    `mnemosyne.serialize(path)` runs in a background QueryOp with a progress
    window. Its string result is then handed to `import_json`.
    """

    def on_serialized(serialized: str) -> None:
        import_json(mw, serialized)

    QueryOp(
        parent=mw,
        op=lambda _: mnemosyne.serialize(path),
        success=on_serialized,
    ).with_progress().run_in_background()
def import_json(mw: aqt.main.AnkiQt, json: str) -> None:
    """Import the JSON payload `json` into the open collection, in the background.

    Progress is reported via `import_progress_update`. On success, the
    resulting import log is shown with `show_import_log`.
    """
    collection_op = CollectionOp(parent=mw, op=lambda col: col.import_json(json))
    collection_op = collection_op.with_backend_progress(import_progress_update)
    collection_op.success(show_import_log).run_in_background()
def show_import_log(log_with_changes: ImportLogWithChanges) -> None:
    """Show the import log's text rendering in a plain-text window."""
    report = stringify_log(log_with_changes.log)
    showText(report, plain_text_edit=True)
def stringify_log(log: ImportLogWithChanges.Log) -> str:
    """Render an import log as newline-separated text.

    Output layout:
    - one header line with the total number of notes across the four
      categories (conflicting, updated, new, duplicate);
    - one summary count line for each non-empty category;
    - an empty line;
    - one "[action] field1, field2, ..." line per note, grouped by category
      in the order conflicting, updated, new, duplicate.
    """
    total = sum(
        len(notes) for notes in (log.conflicting, log.updated, log.new, log.duplicate)
    )
    lines = [tr.importing_notes_found_in_file(val=total)]

    summaries = (
        (log.conflicting, tr.importing_notes_that_could_not_be_imported),
        (log.updated, tr.importing_notes_updated_as_file_had_newer),
        (log.new, tr.importing_notes_added_from_file),
        (log.duplicate, tr.importing_notes_skipped_as_theyre_already_in),
    )
    for notes, summarize in summaries:
        # Categories with no notes get no summary line.
        if notes:
            lines.append(summarize(val=len(notes)))

    lines.append("")

    details = (
        (log.conflicting, tr.importing_skipped()),
        (log.updated, tr.importing_updated()),
        (log.new, tr.adding_added()),
        (log.duplicate, tr.importing_identical()),
    )
    for notes, action in details:
        for note in notes:
            lines.append(f"[{action}] {', '.join(note.fields)}")

    return "\n".join(lines)
def import_progress_update(progress: Progress, update: ProgressUpdate) -> None:
    """Copy backend import progress into a progress-window update.

    Progress messages that do not carry an `importing` field are ignored.
    Otherwise the update's label becomes the importing text. If the user has
    asked to abort, `update.abort` is set to True.
    """
    if progress.HasField("importing"):
        update.label = progress.importing
        if update.user_wants_abort:
            update.abort = True