# -*- coding: utf-8 -*-
# Copyright: Damien Elmes <anki@ichi2.net>
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html

import os
import re
import urllib
import unicodedata
import sys
import zipfile
from cStringIO import StringIO

import send2trash
from anki.utils import checksum, isWin, isMac, json
from anki.db import DB
from anki.consts import *
from anki.latex import mungeQA

class MediaManager(object):

    soundRegexps = ["(?i)(\[sound:(?P<fname>[^]]+)\])"]
    imgRegexps = [
        # src element quoted case
        "(?i)(<img[^>]+src=(?P<str>[\"'])(?P<fname>[^>]+?)(?P=str)[^>]*>)",
        # unquoted case
        "(?i)(<img[^>]+src=(?!['\"])(?P<fname>[^ >]+)[^>]*?>)",
    ]
    regexps = soundRegexps + imgRegexps

    def __init__(self, col, server):
        self.col = col
        if server:
            self._dir = None
            return
        # media directory
        self._dir = re.sub("(?i)\.(anki2)$", ".media", self.col.path)
        # convert dir to unicode if it's not already
        if isinstance(self._dir, str):
            self._dir = unicode(self._dir, sys.getfilesystemencoding())
        if not os.path.exists(self._dir):
            os.makedirs(self._dir)
        try:
            self._oldcwd = os.getcwd()
        except OSError:
            # cwd doesn't exist
            self._oldcwd = None
        try:
            os.chdir(self._dir)
        except OSError:
            raise Exception("invalidTempFolder")
        # change database
        self.connect()

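    # Construction sketch (hypothetical `col`, an open collection whose
    # .path ends in ".anki2"): the media folder name is derived from the
    # collection path, and construction chdir()s into it, so the relative
    # file operations throughout this module assume the media folder is
    # the current working directory.
    #
    #   mm = MediaManager(col, server=False)
    #   mm.dir()   # e.g. u".../collection.media"
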
    def connect(self):
        if self.col.server:
            return
        path = self.dir()+".db"
        create = not os.path.exists(path)
        os.chdir(self._dir)
        self.db = DB(path)
        if create:
            self._initDB()

    def close(self):
        if self.col.server:
            return
        self.db.close()
        self.db = None
        # change cwd back to old location
        if self._oldcwd:
            try:
                os.chdir(self._oldcwd)
            except OSError:
                # may have been deleted
                pass

    def dir(self):
        return self._dir

    def _isFAT32(self):
        if not isWin:
            return
        import win32api, win32file
        try:
            name = win32file.GetVolumeNameForVolumeMountPoint(self._dir[:3])
        except:
            # mapped & unmapped network drive; pray that it's not vfat
            return
        if win32api.GetVolumeInformation(name)[4].lower().startswith("fat"):
            return True

    # Adding media
    ##########################################################################
    # opath must be in unicode

    def addFile(self, opath):
        return self.writeData(opath, open(opath, "rb").read())

    def writeData(self, opath, data):
        # if fname is a full path, use only the basename
        fname = os.path.basename(opath)
        # make sure we write it in NFC form (on mac will autoconvert to NFD),
        # and return an NFC-encoded reference
        fname = unicodedata.normalize("NFC", fname)
        # remove any dangerous characters
        base = self.stripIllegal(fname)
        (root, ext) = os.path.splitext(base)
        def repl(match):
            n = int(match.group(1))
            return " (%d)" % (n+1)
        # find the first available name
        csum = checksum(data)
        while True:
            fname = root + ext
            path = os.path.join(self.dir(), fname)
            # if it doesn't exist, copy it directly
            if not os.path.exists(path):
                open(path, "wb").write(data)
                return fname
            # if it's identical, reuse
            if checksum(open(path, "rb").read()) == csum:
                return fname
            # otherwise, increment the index in the filename
            reg = " \((\d+)\)$"
            if not re.search(reg, root):
                root = root + " (1)"
            else:
                root = re.sub(reg, repl, root)

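    # Collision sketch (hypothetical `col.media` handle and data): if
    # u"photo.jpg" already exists with different contents, writeData()
    # walks u"photo (1).jpg", u"photo (2).jpg", ... until it finds a free
    # name or an identical copy to reuse:
    #
    #   col.media.writeData(u"photo.jpg", data1)   # -> u"photo.jpg"
    #   col.media.writeData(u"photo.jpg", data2)   # -> u"photo (1).jpg"
    #   col.media.writeData(u"photo.jpg", data1)   # -> u"photo.jpg" (reused)
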
    # String manipulation
    ##########################################################################

    def filesInStr(self, mid, string, includeRemote=False):
        l = []
        model = self.col.models.get(mid)
        strings = []
        if model['type'] == MODEL_CLOZE and "{{c" in string:
            # if the field has clozes in it, we'll need to expand the
            # possibilities so we can render latex
            strings = self._expandClozes(string)
        else:
            strings = [string]
        for string in strings:
            # handle latex
            string = mungeQA(string, None, None, model, None, self.col)
            # extract filenames
            for reg in self.regexps:
                for match in re.finditer(reg, string):
                    fname = match.group("fname")
                    isLocal = not re.match("(https?|ftp)://", fname.lower())
                    if isLocal or includeRemote:
                        l.append(fname)
        return l

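    # Extraction sketch: given a field like
    #   u'<img src="cat.jpg"> [sound:meow.mp3] <img src=http://x/y.png>'
    # filesInStr() returns [u"meow.mp3", u"cat.jpg"] by default (sound
    # references are matched before image references), and includes
    # u"http://x/y.png" as well when includeRemote=True. The mid is only
    # needed for cloze expansion and latex handling.
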
    def _expandClozes(self, string):
        ords = set(re.findall("{{c(\d+)::.+?}}", string))
        strings = []
        from anki.template.template import clozeReg
        def qrepl(m):
            if m.group(3):
                return "[%s]" % m.group(3)
            else:
                return "[...]"
        def arepl(m):
            return m.group(1)
        for ord in ords:
            s = re.sub(clozeReg%ord, qrepl, string)
            s = re.sub(clozeReg%".+?", "\\1", s)
            strings.append(s)
        strings.append(re.sub(clozeReg%".+?", arepl, string))
        return strings

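    # Expansion sketch (assuming clozeReg captures the hidden text as group
    # 1 and the optional hint as group 3): for u"{{c1::foo::hint}} {{c2::bar}}"
    # this returns roughly
    #   [u"[hint] bar", u"foo [...]", u"foo bar"]
    # i.e. one question-side string per cloze ordinal plus one answer-side
    # string, so latex hidden under any cloze still gets rendered and scanned.
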
    def transformNames(self, txt, func):
        for reg in self.regexps:
            txt = re.sub(reg, func, txt)
        return txt

    def strip(self, txt):
        for reg in self.regexps:
            txt = re.sub(reg, "", txt)
        return txt

    def escapeImages(self, string):
        def repl(match):
            tag = match.group(0)
            fname = match.group("fname")
            if re.match("(https?|ftp)://", fname):
                return tag
            return tag.replace(
                fname, urllib.quote(fname.encode("utf-8")))
        for reg in self.imgRegexps:
            string = re.sub(reg, repl, string)
        return string

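    # Escaping sketch: local filenames are percent-encoded so the webview
    # can load them, while remote URLs pass through untouched:
    #
    #   escapeImages(u'<img src="my cat.jpg">')
    #   # -> u'<img src="my%20cat.jpg">'
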
    # Rebuilding DB
    ##########################################################################

    def check(self, local=None):
        "Return (missingFiles, unusedFiles)."
        mdir = self.dir()
        # gather all media references in NFC form
        allRefs = set()
        for nid, mid, flds in self.col.db.execute("select id, mid, flds from notes"):
            noteRefs = self.filesInStr(mid, flds)
            # check the refs are in NFC
            for f in noteRefs:
                # if they're not, we'll need to fix them first
                if f != unicodedata.normalize("NFC", f):
                    self._normalizeNoteRefs(nid)
                    noteRefs = self.filesInStr(mid, flds)
                    break
            allRefs.update(noteRefs)
        # loop through media folder
        unused = []
        if local is None:
            files = os.listdir(mdir)
        else:
            files = local
        for file in files:
            if not local:
                if not os.path.isfile(file):
                    # ignore directories
                    continue
            if file.startswith("_"):
                # leading _ says to ignore file
                continue
            nfcFile = unicodedata.normalize("NFC", file)
            # we enforce NFC fs encoding on non-macs; on macs we'll have gotten
            # NFD so we use the above variable for comparing references
            if not isMac and not local:
                if file != nfcFile:
                    # delete if we already have the NFC form, otherwise rename
                    if os.path.exists(nfcFile):
                        os.unlink(file)
                    else:
                        os.rename(file, nfcFile)
                    file = nfcFile
            # compare
            if nfcFile not in allRefs:
                unused.append(file)
            else:
                allRefs.discard(nfcFile)
        nohave = [x for x in allRefs if not x.startswith("_")]
        return (nohave, unused)

    def _normalizeNoteRefs(self, nid):
        note = self.col.getNote(nid)
        for c, fld in enumerate(note.fields):
            nfc = unicodedata.normalize("NFC", fld)
            if nfc != fld:
                note.fields[c] = nfc
        note.flush()

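    # Usage sketch (hypothetical call): check() diffs note references
    # against the media folder in one pass:
    #
    #   missing, unused = col.media.check()
    #   # missing: referenced by notes but not present on disk
    #   # unused: present on disk but referenced by no note
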
    # Copying on import
    ##########################################################################

    def have(self, fname):
        return os.path.exists(os.path.join(self.dir(), fname))

    # Media syncing - changes and removal
    ##########################################################################

    def hasChanged(self):
        return self.db.scalar("select 1 from log limit 1")

    def removed(self):
        return self.db.list("select * from log where type = ?", MEDIA_REM)

    def syncRemove(self, fnames):
        # remove provided deletions
        for f in fnames:
            if os.path.exists(f):
                send2trash.send2trash(f)
            self.db.execute("delete from log where fname = ?", f)
            self.db.execute("delete from media where fname = ?", f)
        # and all locally-logged deletions, as server has acked them
        self.db.execute("delete from log where type = ?", MEDIA_REM)
        self.db.commit()

    # Media syncing - unbundling zip files from server
    ##########################################################################

    def syncAdd(self, zipData):
        "Extract zip data; true if finished."
        f = StringIO(zipData)
        z = zipfile.ZipFile(f, "r")
        finished = False
        meta = None
        media = []
        # get meta info first
        meta = json.loads(z.read("_meta"))
        nextUsn = int(z.read("_usn"))
        # then loop through all files
        for i in z.infolist():
            if i.filename == "_meta" or i.filename == "_usn":
                # ignore previously-retrieved meta
                continue
            elif i.filename == "_finished":
                # last zip in set
                finished = True
            else:
                data = z.read(i)
                csum = checksum(data)
                name = meta[i.filename]
                if not isinstance(name, unicode):
                    name = unicode(name, "utf8")
                # normalize name for platform
                if isMac:
                    name = unicodedata.normalize("NFD", name)
                else:
                    name = unicodedata.normalize("NFC", name)
                # save file
                open(name, "wb").write(data)
                # update db
                media.append((name, csum, self._mtime(name)))
                # remove entries from local log
                self.db.execute("delete from log where fname = ?", name)
        # update media db and note new starting usn
        if media:
            self.db.executemany(
                "insert or replace into media values (?,?,?)", media)
        self.setUsn(nextUsn) # commits
        # if we have finished adding, we need to record the new folder mtime
        # so that we don't trigger a needless scan
        if finished:
            self.syncMod()
        return finished

    # Illegal characters
    ##########################################################################

    _illegalCharReg = re.compile(r'[][><:"/?*^\\|\0]')

    def stripIllegal(self, str):
        return re.sub(self._illegalCharReg, "", str)

    def hasIllegal(self, str):
        return bool(re.search(self._illegalCharReg, str))

    # Media syncing - bundling zip files to send to server
    ##########################################################################
    # Because there's no standard filename encoding for zips, and because not
    # all zip clients support retrieving mtime, we store the files as ascii
    # and place a json file in the zip with the necessary information.

    def zipAdded(self):
        "Add files to a zip until over SYNC_ZIP_SIZE/COUNT. Return zip data."
        f = StringIO()
        z = zipfile.ZipFile(f, "w", compression=zipfile.ZIP_DEFLATED)
        sz = 0
        cnt = 0
        files = {}
        cur = self.db.execute(
            "select fname from log where type = ?", MEDIA_ADD)
        fnames = []
        while True:
            fname = cur.fetchone()
            if not fname:
                # add a flag so the server knows it can clean up
                z.writestr("_finished", "")
                break
            fname = fname[0]
            # we add it as a one-element array simply to make
            # the later forgetAdded() call easier
            fnames.append([fname])
            z.write(fname, str(cnt))
            files[str(cnt)] = unicodedata.normalize("NFC", fname)
            sz += os.path.getsize(fname)
            if sz > SYNC_ZIP_SIZE or cnt > SYNC_ZIP_COUNT:
                break
            cnt += 1
        z.writestr("_meta", json.dumps(files))
        z.close()
        return f.getvalue(), fnames

    def forgetAdded(self, fnames):
        if not fnames:
            return
        self.db.executemany("delete from log where fname = ?", fnames)
        self.db.commit()

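    # Payload sketch: a zip produced by zipAdded() stores files under ascii
    # member names "0", "1", ... plus a "_meta" member mapping those names
    # back to the real (NFC) filenames, e.g.
    #
    #   _meta     -> {"0": "cat.jpg", "1": "dog (1).png"}
    #   _finished -> "" (present only in the last zip of the set)
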
    # Tracking changes (private)
    ##########################################################################

    def _initDB(self):
        self.db.executescript("""
create table media (fname text primary key, csum text, mod int);
create table meta (dirMod int, usn int); insert into meta values (0, 0);
create table log (fname text primary key, type int);
""")

    def _mtime(self, path):
        return int(os.stat(path).st_mtime)

    def _checksum(self, path):
        return checksum(open(path, "rb").read())

    def usn(self):
        return self.db.scalar("select usn from meta")

    def setUsn(self, usn):
        self.db.execute("update meta set usn = ?", usn)
        self.db.commit()

    def syncMod(self):
        self.db.execute("update meta set dirMod = ?", self._mtime(self.dir()))
        self.db.commit()

    def _changed(self):
        "Return dir mtime if it has changed since the last findChanges()"
        # doesn't track edits, but user can add or remove a file to update
        mod = self.db.scalar("select dirMod from meta")
        mtime = self._mtime(self.dir())
        if not self._isFAT32() and mod and mod == mtime:
            return False
        return mtime

    def findChanges(self):
        "Scan the media folder if it's changed, and note any changes."
        if self._changed():
            self._logChanges()

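    # Change-detection sketch: the media folder's mtime is cached in
    # meta.dirMod; adding or removing a file bumps the directory mtime,
    # which makes _changed() return it and findChanges() rescan. On FAT32
    # the directory mtime is unreliable, so _isFAT32() forces a scan
    # every time.
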
    def _logChanges(self):
        (added, removed) = self._changes()
        log = []
        media = []
        mediaRem = []
        for f in added:
            mt = self._mtime(f)
            media.append((f, self._checksum(f), mt))
            log.append((f, MEDIA_ADD))
        for f in removed:
            mediaRem.append((f,))
            log.append((f, MEDIA_REM))
        # update media db
        self.db.executemany("insert or replace into media values (?,?,?)",
                            media)
        if mediaRem:
            self.db.executemany("delete from media where fname = ?",
                                mediaRem)
        self.db.execute("update meta set dirMod = ?", self._mtime(self.dir()))
        # and logs
        self.db.executemany("insert or replace into log values (?,?)", log)
        self.db.commit()

    def _changes(self):
        self.cache = {}
        for (name, csum, mod) in self.db.execute(
            "select * from media"):
            self.cache[name] = [csum, mod, False]
        added = []
        removed = []
        # loop through on-disk files
        for f in os.listdir(self.dir()):
            # ignore folders and thumbs.db
            if os.path.isdir(f):
                continue
            if f.lower() == "thumbs.db":
                continue
            # and files with invalid chars
            if self.hasIllegal(f):
                continue
            # empty files are invalid; clean them up and continue
            if not os.path.getsize(f):
                os.unlink(f)
                continue
            # newly added?
            if f not in self.cache:
                added.append(f)
            else:
                # modified since last time?
                if self._mtime(f) != self.cache[f][1]:
                    # and has different checksum?
                    if self._checksum(f) != self.cache[f][0]:
                        added.append(f)
                # mark as used
                self.cache[f][2] = True
        # look for any entries in the cache that no longer exist on disk
        for (k, v) in self.cache.items():
            if not v[2]:
                removed.append(k)
        return added, removed

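    # Scan sketch: each cache entry is [csum, mtime, seenOnDisk]. A file is
    # reported as added when it is new, or when its mtime changed *and* its
    # checksum differs (the mtime check is just a cheap filter before
    # re-hashing). Cache entries never seen during the scan are reported
    # as removed.
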
    def sanityCheck(self):
        assert not self.db.scalar("select count() from log")
        cnt = self.db.scalar("select count() from media")
        return cnt

    def forceResync(self):
        self.db.execute("delete from media")
        self.db.execute("delete from log")
        self.db.execute("update meta set usn = 0, dirMod = 0")
        self.db.commit()

    def removeExisting(self, files):
        "Remove files from list of files to sync, and return missing files."
        need = []
        remove = []
        for f in files:
            if self.db.scalar("select 1 from log where fname=?", f):
                remove.append((f,))
            else:
                need.append(f)
        self.db.executemany("delete from log where fname=?", remove)
        self.db.commit()
        # if we need all the server files, it's faster to pass None than
        # the full list
        if need and len(files) == len(need):
            return None
        return need