# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html

import io
import json
import os
import pathlib
import re
import sys
import traceback
import unicodedata
import urllib.error
import urllib.parse
import urllib.request
import zipfile
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from anki.consts import *
from anki.db import DB, DBError
from anki.lang import _
from anki.latex import mungeQA
from anki.utils import checksum, isMac, isWin


class MediaManager:
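    """Manage the collection.media folder and its tracking database.

    Media references are extracted from note fields with the regexps
    below; on-disk changes are mirrored into a small SQLite DB so that
    syncing can work incrementally.
    """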

    soundRegexps = [r"(?i)(\[sound:(?P<fname>[^]]+)\])"]
    imgRegexps = [
        # src element quoted case
        r"(?i)(<img[^>]* src=(?P<str>[\"'])(?P<fname>[^>]+?)(?P=str)[^>]*>)",
        # unquoted case
        r"(?i)(<img[^>]* src=(?!['\"])(?P<fname>[^ >]+)[^>]*?>)",
    ]
    regexps = soundRegexps + imgRegexps

    db: Optional[DB]

    def __init__(self, col, server: bool) -> None:
        self.col = col
        if server:
            self._dir = None
            return
        # media directory
        self._dir = re.sub(r"(?i)\.(anki2)$", ".media", self.col.path)
        if not os.path.exists(self._dir):
            os.makedirs(self._dir)
        try:
            self._oldcwd = os.getcwd()
        except OSError:
            # cwd doesn't exist
            self._oldcwd = None
        try:
            os.chdir(self._dir)
        except OSError:
            raise Exception("invalidTempFolder")
        # change database
        self.connect()

    def connect(self) -> None:
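        """Open the media DB (a .db2 file beside the media folder),
        creating the schema on first use and upgrading any old DB."""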
        if self.col.server:
            return
        path = self.dir() + ".db2"
        create = not os.path.exists(path)
        os.chdir(self._dir)
        self.db = DB(path)
        if create:
            self._initDB()
        self.maybeUpgrade()

    def _initDB(self) -> None:
        self.db.executescript(
            """
create table media (
 fname text not null primary key,
 csum text, -- null indicates deleted file
 mtime int not null, -- zero if deleted
 dirty int not null
);

create index idx_media_dirty on media (dirty);

create table meta (dirMod int, lastUsn int); insert into meta values (0, 0);
"""
        )

    def maybeUpgrade(self) -> None:
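        """If an old-format collection.media.db exists, import its data
        into the new DB, then rename the old file out of the way."""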
        oldpath = self.dir() + ".db"
        if os.path.exists(oldpath):
            self.db.execute('attach "../collection.media.db" as old')
            try:
                self.db.execute(
                    """
insert into media
 select m.fname, csum, mod, ifnull((select 1 from log l2 where l2.fname=m.fname), 0) as dirty
 from old.media m
 left outer join old.log l using (fname)
 union
 select fname, null, 0, 1 from old.log where type=1;"""
                )
                self.db.execute("delete from meta")
                self.db.execute(
                    """
insert into meta select dirMod, usn from old.meta
"""
                )
                self.db.commit()
            except Exception:
                # if we couldn't import the old db for some reason, just start
                # anew
                self.col.log("failed to import old media db:" + traceback.format_exc())
            self.db.execute("detach old")
            npath = "../collection.media.db.old"
            if os.path.exists(npath):
                os.unlink(npath)
            os.rename("../collection.media.db", npath)

    def close(self) -> None:
        if self.col.server:
            return
        self.db.close()
        self.db = None
        # change cwd back to old location
        if self._oldcwd:
            try:
                os.chdir(self._oldcwd)
            except:
                # may have been deleted
                pass

    def _deleteDB(self) -> None:
        path = self.db._path
        self.close()
        os.unlink(path)
        self.connect()

    def dir(self) -> Any:
        return self._dir

    def _isFAT32(self) -> bool:
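        """True if the media folder is on a FAT-family volume (Windows only),
        where _changed() can't use the directory-mtime shortcut."""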
        if not isWin:
            return False
        # pylint: disable=import-error
        import win32api, win32file  # pytype: disable=import-error

        try:
            name = win32file.GetVolumeNameForVolumeMountPoint(self._dir[:3])
        except:
            # mapped & unmapped network drive; pray that it's not vfat
            return False
        if win32api.GetVolumeInformation(name)[4].lower().startswith("fat"):
            return True
        return False

    # Adding media
    ##########################################################################
    # opath must be in unicode

    def addFile(self, opath: str) -> Any:
        with open(opath, "rb") as f:
            return self.writeData(opath, f.read())

    def writeData(self, opath: str, data: bytes, typeHint: Optional[str] = None) -> Any:
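        """Write data into the media folder and return the filename to use.

        An existing file is reused when its checksum matches; otherwise
        " (n)" is appended to the name until a free slot is found."""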
        # if fname is a full path, use only the basename
        fname = os.path.basename(opath)

        # if it's missing an extension and a type hint was provided, use that
        if not os.path.splitext(fname)[1] and typeHint:
            # mimetypes is returning '.jpe' even after calling .init(), so we'll do
            # it manually instead
            typeMap = {
                "image/jpeg": ".jpg",
                "image/png": ".png",
            }
            if typeHint in typeMap:
                fname += typeMap[typeHint]

        # make sure we write it in NFC form (pre-APFS Macs will autoconvert to NFD),
        # and return an NFC-encoded reference
        fname = unicodedata.normalize("NFC", fname)
        # ensure it's a valid filename
        base = self.cleanFilename(fname)
        (root, ext) = os.path.splitext(base)

        def repl(match):
            n = int(match.group(1))
            return " (%d)" % (n + 1)

        # find the first available name
        csum = checksum(data)
        while True:
            fname = root + ext
            path = os.path.join(self.dir(), fname)
            # if it doesn't exist, copy it directly
            if not os.path.exists(path):
                with open(path, "wb") as f:
                    f.write(data)
                return fname
            # if it's identical, reuse
            with open(path, "rb") as f:
                if checksum(f.read()) == csum:
                    return fname
            # otherwise, increment the index in the filename
            reg = r" \((\d+)\)$"
            if not re.search(reg, root):
                root = root + " (1)"
            else:
                root = re.sub(reg, repl, root)

    # String manipulation
    ##########################################################################

    def filesInStr(
        self, mid: Union[int, str], string: str, includeRemote: bool = False
    ) -> List[str]:
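        """Return the media filenames referenced in a field string.

        Clozes are expanded first so LaTeX can be rendered; remote
        (http/https/ftp) references are skipped unless includeRemote."""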
        l = []
        model = self.col.models.get(mid)
        strings: List[str] = []
        if model["type"] == MODEL_CLOZE and "{{c" in string:
            # if the field has clozes in it, we'll need to expand the
            # possibilities so we can render latex
            strings = self._expandClozes(string)
        else:
            strings = [string]
        for string in strings:
            # handle latex
            string = mungeQA(string, None, None, model, None, self.col)
            # extract filenames
            for reg in self.regexps:
                for match in re.finditer(reg, string):
                    fname = match.group("fname")
                    isLocal = not re.match("(https?|ftp)://", fname.lower())
                    if isLocal or includeRemote:
                        l.append(fname)
        return l

    def _expandClozes(self, string: str) -> List[str]:
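        """Render the string once per cloze ordinal (question form), plus
        once with every cloze replaced by its content (answer form)."""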
        ords = set(re.findall(r"{{c(\d+)::.+?}}", string))
        strings = []
        from anki.template.template import (
            clozeReg,
            CLOZE_REGEX_MATCH_GROUP_HINT,
            CLOZE_REGEX_MATCH_GROUP_CONTENT,
        )

        def qrepl(m):
            if m.group(CLOZE_REGEX_MATCH_GROUP_HINT):
                return "[%s]" % m.group(CLOZE_REGEX_MATCH_GROUP_HINT)
            else:
                return "[...]"

        def arepl(m):
            return m.group(CLOZE_REGEX_MATCH_GROUP_CONTENT)

        for ord in ords:
            s = re.sub(clozeReg % ord, qrepl, string)
            s = re.sub(clozeReg % ".+?", arepl, s)
            strings.append(s)
        strings.append(re.sub(clozeReg % ".+?", arepl, string))
        return strings

    def transformNames(self, txt: str, func: Callable) -> Any:
        for reg in self.regexps:
            txt = re.sub(reg, func, txt)
        return txt

    def strip(self, txt: str) -> str:
        for reg in self.regexps:
            txt = re.sub(reg, "", txt)
        return txt

    def escapeImages(self, string: str, unescape: bool = False) -> str:
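        """Percent-encode local image filenames in img tags (or decode
        them when unescape is true); remote URLs are left untouched."""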
        fn: Callable
        if unescape:
            fn = urllib.parse.unquote
        else:
            fn = urllib.parse.quote

        def repl(match):
            tag = match.group(0)
            fname = match.group("fname")
            if re.match("(https?|ftp)://", fname):
                return tag
            return tag.replace(fname, fn(fname))

        for reg in self.imgRegexps:
            string = re.sub(reg, repl, string)
        return string

    # Rebuilding DB
    ##########################################################################

    def check(
        self, local: Optional[List[str]] = None
    ) -> Tuple[List[str], List[str], List[str]]:
        "Return (missingFiles, unusedFiles, warnings)."
        mdir = self.dir()
        # gather all media references in NFC form
        allRefs = set()
        for nid, mid, flds in self.col.db.execute("select id, mid, flds from notes"):
            noteRefs = self.filesInStr(mid, flds)
            # check the refs are in NFC
            for f in noteRefs:
                # if they're not, we'll need to fix them first
                if f != unicodedata.normalize("NFC", f):
                    self._normalizeNoteRefs(nid)
                    noteRefs = self.filesInStr(mid, flds)
                    break
            allRefs.update(noteRefs)
        # loop through media folder
        unused = []
        if local is None:
            files = os.listdir(mdir)
        else:
            files = local
        renamedFiles = False
        dirFound = False
        warnings = []
        for file in files:
            if not local:
                if not os.path.isfile(file):
                    # ignore directories
                    dirFound = True
                    continue
            if file.startswith("_"):
                # leading _ says to ignore file
                continue

            if self.hasIllegal(file):
                name = file.encode(sys.getfilesystemencoding(), errors="replace")
                name = str(name, sys.getfilesystemencoding())
                warnings.append(_("Invalid file name, please rename: %s") % name)
                continue

            nfcFile = unicodedata.normalize("NFC", file)
            # we enforce NFC fs encoding on non-macs
            if not isMac and not local:
                if file != nfcFile:
                    # delete if we already have the NFC form, otherwise rename
                    if os.path.exists(nfcFile):
                        os.unlink(file)
                        renamedFiles = True
                    else:
                        os.rename(file, nfcFile)
                        renamedFiles = True
                    file = nfcFile
            # compare
            if nfcFile not in allRefs:
                unused.append(file)
            else:
                allRefs.discard(nfcFile)
        # if we renamed any files to nfc format, we must rerun the check
        # to make sure the renamed files are not marked as unused
        if renamedFiles:
            return self.check(local=local)
        nohave = [x for x in allRefs if not x.startswith("_")]
        # make sure the media DB is valid
        try:
            self.findChanges()
        except DBError:
            self._deleteDB()

        if dirFound:
            warnings.append(
                _(
                    "Anki does not support files in subfolders of the collection.media folder."
                )
            )
        return (nohave, unused, warnings)

    def _normalizeNoteRefs(self, nid) -> None:
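        """Rewrite a note's fields into NFC form, so its media references
        match the NFC-normalized filenames on disk."""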
        note = self.col.getNote(nid)
        for c, fld in enumerate(note.fields):
            nfc = unicodedata.normalize("NFC", fld)
            if nfc != fld:
                note.fields[c] = nfc
        note.flush()

    # Copying on import
    ##########################################################################

    def have(self, fname: str) -> bool:
        return os.path.exists(os.path.join(self.dir(), fname))

    # Illegal characters and paths
    ##########################################################################

    _illegalCharReg = re.compile(r'[][><:"/?*^\\|\0\r\n]')

    def stripIllegal(self, str: str) -> str:
        return re.sub(self._illegalCharReg, "", str)

    def hasIllegal(self, s: str) -> bool:
        if re.search(self._illegalCharReg, s):
            return True
        try:
            s.encode(sys.getfilesystemencoding())
        except UnicodeEncodeError:
            return True
        return False

    def cleanFilename(self, fname: str) -> str:
        fname = self.stripIllegal(fname)
        fname = self._cleanWin32Filename(fname)
        fname = self._cleanLongFilename(fname)
        if not fname:
            fname = "renamed"

        return fname

    def _cleanWin32Filename(self, fname: str) -> str:
        if not isWin:
            return fname

        # deal with things like con/prn/etc
        p = pathlib.WindowsPath(fname)
        if p.is_reserved():
            fname = "renamed" + fname
            assert not pathlib.WindowsPath(fname).is_reserved()

        return fname

    def _cleanLongFilename(self, fname: str) -> Any:
        # a fairly safe limit that should work on typical windows
        # paths and on eCryptfs partitions, even with a duplicate
        # suffix appended
        namemax = 136

        if isWin:
            pathmax = 240
        else:
            pathmax = 1024

        # cap namemax based on absolute path
        dirlen = len(os.path.dirname(os.path.abspath(fname)))
        remaining = pathmax - dirlen
        namemax = min(remaining, namemax)
        assert namemax > 0

        if len(fname) > namemax:
            head, ext = os.path.splitext(fname)
            headmax = namemax - len(ext)
            head = head[0:headmax]
            fname = head + ext
            assert len(fname) <= namemax

        return fname

    # Tracking changes
    ##########################################################################

    def findChanges(self) -> None:
        "Scan the media folder if it's changed, and note any changes."
        if self._changed():
            self._logChanges()

    def haveDirty(self) -> Any:
        return self.db.scalar("select 1 from media where dirty=1 limit 1")

    def _mtime(self, path: str) -> int:
        return int(os.stat(path).st_mtime)

    def _checksum(self, path: str) -> str:
        with open(path, "rb") as f:
            return checksum(f.read())

    def _changed(self) -> int:
        "Return dir mtime if it has changed since the last findChanges()"
        # doesn't track edits, but user can add or remove a file to update
        mod = self.db.scalar("select dirMod from meta")
        mtime = self._mtime(self.dir())
        if not self._isFAT32() and mod and mod == mtime:
            return False
        return mtime

    def _logChanges(self) -> None:
        (added, removed) = self._changes()
        media = []
        for f, mtime in added:
            media.append((f, self._checksum(f), mtime, 1))
        for f in removed:
            media.append((f, None, 0, 1))
        # update media db
        self.db.executemany("insert or replace into media values (?,?,?,?)", media)
        self.db.execute("update meta set dirMod = ?", self._mtime(self.dir()))
        self.db.commit()

    def _changes(self) -> Tuple[List[Tuple[str, int]], List[str]]:
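        """Return (added, removed) files since the last scan.

        Compares the DB's cached entries against the media folder,
        using mtime as a cheap filter and checksums to confirm."""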
        self.cache: Dict[str, Any] = {}
        for (name, csum, mod) in self.db.execute(
            "select fname, csum, mtime from media where csum is not null"
        ):
            # previous entries may not have been in NFC form
            normname = unicodedata.normalize("NFC", name)
            self.cache[normname] = [csum, mod, False]
        added = []
        removed = []
        # loop through on-disk files
        with os.scandir(self.dir()) as it:
            for f in it:
                # ignore folders and thumbs.db
                if f.is_dir():
                    continue
                if f.name.lower() == "thumbs.db":
                    continue
                # and files with invalid chars
                if self.hasIllegal(f.name):
                    continue
                # empty files are invalid; clean them up and continue
                sz = f.stat().st_size
                if not sz:
                    os.unlink(f.name)
                    continue
                if sz > 100 * 1024 * 1024:
                    self.col.log("ignoring file over 100MB", f.name)
                    continue
                # check encoding
                normname = unicodedata.normalize("NFC", f.name)
                if not isMac:
                    if f.name != normname:
                        # wrong filename encoding which will cause sync errors
                        if os.path.exists(normname):
                            os.unlink(f.name)
                        else:
                            os.rename(f.name, normname)
                else:
                    # on Macs we can access the file using any normalization
                    pass

                # newly added?
                mtime = int(f.stat().st_mtime)
                if normname not in self.cache:
                    added.append((normname, mtime))
                else:
                    # modified since last time?
                    if mtime != self.cache[normname][1]:
                        # and has different checksum?
                        if self._checksum(normname) != self.cache[normname][0]:
                            added.append((normname, mtime))
                    # mark as used
                    self.cache[normname][2] = True
        # look for any entries in the cache that no longer exist on disk
        for (k, v) in list(self.cache.items()):
            if not v[2]:
                removed.append(k)
        return added, removed

    # Syncing-related
    ##########################################################################

    def lastUsn(self) -> Any:
        return self.db.scalar("select lastUsn from meta")

    def setLastUsn(self, usn) -> None:
        self.db.execute("update meta set lastUsn = ?", usn)
        self.db.commit()

    def syncInfo(self, fname) -> Any:
        ret = self.db.first("select csum, dirty from media where fname=?", fname)
        return ret or (None, 0)

    def markClean(self, fnames) -> None:
        for fname in fnames:
            self.db.execute("update media set dirty=0 where fname=?", fname)

    def syncDelete(self, fname) -> None:
        if os.path.exists(fname):
            os.unlink(fname)
        self.db.execute("delete from media where fname=?", fname)

    def mediaCount(self) -> Any:
        return self.db.scalar("select count() from media where csum is not null")

    def dirtyCount(self) -> Any:
        return self.db.scalar("select count() from media where dirty=1")

    def forceResync(self) -> None:
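        """Clear all local media state, so the next sync rescans and
        re-compares every file."""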
self.db.execute("delete from media")
|
|
self.db.execute("update meta set lastUsn=0,dirMod=0")
|
|
self.db.commit()
|
|
self.db.setAutocommit(True)
|
|
self.db.execute("vacuum")
|
|
self.db.execute("analyze")
|
|
self.db.setAutocommit(False)
|
|
|
|

    # Media syncing: zips
    ##########################################################################

    def mediaChangesZip(self) -> Tuple[bytes, list]:
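        """Zip up outgoing changes and return (zipData, filenames).

        Bundles up to SYNC_ZIP_COUNT dirty files, stopping early once
        SYNC_ZIP_SIZE bytes have been written; the zip's _meta entry
        maps NFC filenames to zip member names."""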
        f = io.BytesIO()
        z = zipfile.ZipFile(f, "w", compression=zipfile.ZIP_DEFLATED)

        fnames = []
        # meta is list of (fname, zipname), where a zipname of ""
        # is a deleted file
        meta = []
        sz = 0

        for c, (fname, csum) in enumerate(
            self.db.execute(
                "select fname, csum from media where dirty=1"
                " limit %d" % SYNC_ZIP_COUNT
            )
        ):

            fnames.append(fname)
            normname = unicodedata.normalize("NFC", fname)

            if csum:
                self.col.log("+media zip", fname)
                z.write(fname, str(c))
                meta.append((normname, str(c)))
                sz += os.path.getsize(fname)
            else:
                self.col.log("-media zip", fname)
                meta.append((normname, ""))

            if sz >= SYNC_ZIP_SIZE:
                break

        z.writestr("_meta", json.dumps(meta))
        z.close()
        return f.getvalue(), fnames

    def addFilesFromZip(self, zipData) -> int:
        "Extract zip data; return the number of files extracted."
        f = io.BytesIO(zipData)
        z = zipfile.ZipFile(f, "r")
        media = []
        # get meta info first
        meta = json.loads(z.read("_meta").decode("utf8"))
        # then loop through all files
        cnt = 0
        for i in z.infolist():
            if i.filename == "_meta":
                # ignore previously-retrieved meta
                continue
            else:
                data = z.read(i)
                csum = checksum(data)
                name = meta[i.filename]
                # normalize name
                name = unicodedata.normalize("NFC", name)
                # save file
                with open(name, "wb") as f:  # type: ignore
                    f.write(data)
                # update db
                media.append((name, csum, self._mtime(name), 0))
                cnt += 1
        if media:
            self.db.executemany("insert or replace into media values (?,?,?,?)", media)
        return cnt