New type-safe approach to hooks/filters

Still todo:
- Add separate module for GUI hooks
- Update the remaining runHook/runFilter() calls
- Document the changes, including defensive registration
This commit is contained in:
Damien Elmes 2020-01-13 13:57:51 +10:00
parent b42912e639
commit dd61389319
11 changed files with 243 additions and 31 deletions

View file

@ -27,7 +27,11 @@ PROTODEPS := $(wildcard ../proto/*.proto)
protoc --proto_path=../proto --python_out=anki --mypy_out=anki $(PROTODEPS) protoc --proto_path=../proto --python_out=anki --mypy_out=anki $(PROTODEPS)
@touch $@ @touch $@
BUILD_STEPS := .build/run-deps .build/dev-deps .build/py-proto anki/buildinfo.py .build/hooks: tools/genhooks.py
python tools/genhooks.py
@touch $@
BUILD_STEPS := .build/run-deps .build/dev-deps .build/py-proto .build/hooks anki/buildinfo.py
# Checking # Checking
###################### ######################

View file

@ -8,8 +8,8 @@ import time
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
import anki # pylint: disable=unused-import import anki # pylint: disable=unused-import
from anki import hooks
from anki.consts import * from anki.consts import *
from anki.hooks import runHook
from anki.notes import Note from anki.notes import Note
from anki.utils import intTime, joinFields, timestampID from anki.utils import intTime, joinFields, timestampID
@ -87,7 +87,7 @@ class Card:
self.usn = self.col.usn() self.usn = self.col.usn()
# bug check # bug check
if self.queue == 2 and self.odue and not self.col.decks.isDyn(self.did): if self.queue == 2 and self.odue and not self.col.decks.isDyn(self.did):
runHook("odueInvalid") hooks.run_odue_invalid_hook()
assert self.due < 4294967296 assert self.due < 4294967296
self.col.db.execute( self.col.db.execute(
""" """
@ -119,7 +119,7 @@ insert or replace into cards values
self.usn = self.col.usn() self.usn = self.col.usn()
# bug checks # bug checks
if self.queue == 2 and self.odue and not self.col.decks.isDyn(self.did): if self.queue == 2 and self.odue and not self.col.decks.isDyn(self.did):
runHook("odueInvalid") hooks.run_odue_invalid_hook()
assert self.due < 4294967296 assert self.due < 4294967296
self.col.db.execute( self.col.db.execute(
"""update cards set """update cards set

View file

@ -16,6 +16,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import anki.find import anki.find
import anki.latex # sets up hook import anki.latex # sets up hook
import anki.template import anki.template
from anki import hooks
from anki.cards import Card from anki.cards import Card
from anki.consts import * from anki.consts import *
from anki.db import DB from anki.db import DB
@ -271,7 +272,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
def modSchema(self, check: bool) -> None: def modSchema(self, check: bool) -> None:
"Mark schema modified. Call this first so user can abort if necessary." "Mark schema modified. Call this first so user can abort if necessary."
if not self.schemaChanged(): if not self.schemaChanged():
if check and not runFilter("modSchema", True): if check and not hooks.run_mod_schema_filter(proceed=True):
raise AnkiError("abortSchemaMod") raise AnkiError("abortSchemaMod")
self.scm = intTime(1000) self.scm = intTime(1000)
self.setMod() self.setMod()

View file

@ -1,22 +1,77 @@
# Copyright: Ankitects Pty Ltd and contributors # Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
"""\
Hooks - hook management and tools for extending Anki
==============================================================================
To find available hooks, grep for runHook and runFilter in the source code.
Instrumenting allows you to modify functions that don't have hooks available.
If you call wrap() with pos='around', the original function will not be called
automatically but can be called with _old().
""" """
Tools for extending Anki.
A hook takes a function that does not return a value.
A filter takes a function that returns its first argument, optionally
modifying it.
"""
from __future__ import annotations
from typing import Any, Callable, Dict, List from typing import Any, Callable, Dict, List
import decorator import decorator
# Hooks from anki.cards import Card
# New hook/filter handling
##############################################################################
# The code in this section is automatically generated - any edits you make
# will be lost. To add new hooks, see ../tools/genhooks.py
#
# To use an existing hook such as leech_hook, you would call the following
# in your code:
#
# from anki import hooks
# hooks.leech_hook.append(myfunc)
#
# @@AUTOGEN@@
leech_hook: List[Callable[[Card], None]] = []
mod_schema_filter: List[Callable[[bool], bool]] = []
odue_invalid_hook: List[Callable[[], None]] = []
def run_leech_hook(card: Card) -> None:
for hook in leech_hook:
try:
hook(card)
except:
# if the hook fails, remove it
leech_hook.remove(hook)
raise
# legacy support
runHook("leech", card)
def run_mod_schema_filter(proceed: bool) -> bool:
for filter in mod_schema_filter:
try:
proceed = filter(proceed)
except:
# if the hook fails, remove it
mod_schema_filter.remove(filter)
raise
return proceed
def run_odue_invalid_hook() -> None:
for hook in odue_invalid_hook:
try:
hook()
except:
# if the hook fails, remove it
odue_invalid_hook.remove(hook)
raise
# @@AUTOGEN@@
# Legacy hook handling
############################################################################## ##############################################################################
_hooks: Dict[str, List[Callable[..., Any]]] = {} _hooks: Dict[str, List[Callable[..., Any]]] = {}
@ -61,10 +116,14 @@ def remHook(hook, func) -> None:
hook.remove(func) hook.remove(func)
# Instrumenting # Monkey patching
############################################################################## ##############################################################################
# Please only use this for prototyping or for when hooks are not practical,
# as add-ons that use monkey patching are more likely to break when Anki is
# updated.
#
# If you call wrap() with pos='around', the original function will not be called
# automatically but can be called with _old().
def wrap(old, new, pos="after") -> Callable: def wrap(old, new, pos="after") -> Callable:
"Override an existing function." "Override an existing function."

View file

@ -7,8 +7,8 @@ import time
from heapq import * from heapq import *
from operator import itemgetter from operator import itemgetter
from anki import hooks
from anki.consts import * from anki.consts import *
from anki.hooks import runHook
from anki.lang import _ from anki.lang import _
# from anki.cards import Card # from anki.cards import Card
@ -1150,7 +1150,7 @@ did = ?, queue = %s, due = ?, usn = ? where id = ?"""
card.odue = card.odid = 0 card.odue = card.odid = 0
card.queue = -1 card.queue = -1
# notify UI # notify UI
runHook("leech", card) hooks.run_leech_hook(card)
return True return True
# Tools # Tools

View file

@ -14,9 +14,9 @@ from operator import itemgetter
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import anki # pylint: disable=unused-import import anki # pylint: disable=unused-import
from anki import hooks
from anki.cards import Card from anki.cards import Card
from anki.consts import * from anki.consts import *
from anki.hooks import runHook
from anki.lang import _ from anki.lang import _
from anki.rsbackend import SchedTimingToday from anki.rsbackend import SchedTimingToday
from anki.utils import fmtTimeSpan, ids2str, intTime from anki.utils import fmtTimeSpan, ids2str, intTime
@ -1270,7 +1270,7 @@ where id = ?
if a == 0: if a == 0:
card.queue = -1 card.queue = -1
# notify UI # notify UI
runHook("leech", card) hooks.run_leech_hook(card)
return True return True
return None return None

View file

@ -3,8 +3,8 @@
import copy import copy
import time import time
from anki import hooks
from anki.consts import STARTING_FACTOR from anki.consts import STARTING_FACTOR
from anki.hooks import addHook
from anki.utils import intTime from anki.utils import intTime
from tests.shared import getEmptyCol as getEmptyColOrig from tests.shared import getEmptyCol as getEmptyColOrig
@ -373,7 +373,7 @@ def test_reviews():
def onLeech(card): def onLeech(card):
hooked.append(1) hooked.append(1)
addHook("leech", onLeech) hooks.leech_hook.append(onLeech)
d.sched.answerCard(c, 1) d.sched.answerCard(c, 1)
assert hooked assert hooked
assert c.queue == -1 assert c.queue == -1

View file

@ -3,8 +3,8 @@
import copy import copy
import time import time
from anki import hooks
from anki.consts import STARTING_FACTOR from anki.consts import STARTING_FACTOR
from anki.hooks import addHook
from anki.utils import intTime from anki.utils import intTime
from tests.shared import getEmptyCol as getEmptyColOrig from tests.shared import getEmptyCol as getEmptyColOrig
@ -395,7 +395,7 @@ def test_reviews():
def onLeech(card): def onLeech(card):
hooked.append(1) hooked.append(1)
addHook("leech", onLeech) hooks.leech_hook.append(onLeech)
d.sched.answerCard(c, 1) d.sched.answerCard(c, 1)
assert hooked assert hooked
assert c.queue == -1 assert c.queue == -1

146
pylib/tools/genhooks.py Normal file
View file

@ -0,0 +1,146 @@
# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
"""
Generate code for hook handling, and insert it into anki/hooks.py.
To add a new hook:
- update the hooks list below
- run 'make develop'
- send a pull request that includes the changes to this file and hooks.py
"""
import os
import re
from dataclasses import dataclass
from operator import attrgetter
from typing import Optional, List
@dataclass
class Hook:
# the name of the hook. _filter or _hook is appending automatically.
name: str
# string of the typed arguments passed to the callback, eg
# "kind: str, val: int"
cb_args: str = ""
# string of the return type. if set, hook is a filter.
return_type: Optional[str] = None
# if add-ons may be relying on the legacy hook name, add it here
legacy_hook: Optional[str] = None
def callable(self) -> str:
"Convert args into a Callable."
types = []
for arg in self.cb_args.split(","):
if not arg:
continue
(name, type) = arg.split(":")
types.append(type.strip())
types_str = ", ".join(types)
return f"Callable[[{types_str}], {self.return_type or 'None'}]"
def arg_names(self) -> List[str]:
names = []
for arg in self.cb_args.split(","):
if not arg:
continue
(name, type) = arg.split(":")
names.append(name.strip())
return names
def full_name(self) -> str:
return f"{self.name}_{self.kind()}"
def kind(self) -> str:
if self.return_type is not None:
return "filter"
else:
return "hook"
def list_code(self) -> str:
return f"""\
{self.full_name()}: List[{self.callable()}] = []
"""
def fire_code(self) -> str:
if self.return_type is not None:
# filter
return self.filter_fire_code()
else:
# hook
return self.hook_fire_code()
def hook_fire_code(self) -> str:
arg_names = self.arg_names()
out = f"""\
def run_{self.full_name()}({self.cb_args}) -> None:
for hook in {self.full_name()}:
try:
hook({", ".join(arg_names)})
except:
# if the hook fails, remove it
{self.full_name()}.remove(hook)
raise
"""
if self.legacy_hook:
args = ", ".join([f'"{self.legacy_hook}"'] + arg_names)
out += f"""\
# legacy support
runHook({args})
"""
return out + "\n\n"
def filter_fire_code(self) -> str:
arg_names = self.arg_names()
out = f"""\
def run_{self.full_name()}({self.cb_args}) -> {self.return_type}:
for filter in {self.full_name()}:
try:
{arg_names[0]} = filter({", ".join(arg_names)})
except:
# if the hook fails, remove it
{self.full_name()}.remove(filter)
raise
"""
if self.legacy_hook:
args = ", ".join([f'"{self.legacy_hook}"'] + arg_names)
out += f"""\
# legacy support
runFilter({args})
"""
out += f"""\
return {arg_names[0]}
"""
return out + "\n\n"
# Hook list
######################################################################
hooks = [
Hook(name="leech", cb_args="card: Card", legacy_hook="leech"),
Hook(name="odue_invalid"),
Hook(name="mod_schema", cb_args="proceed: bool", return_type="bool")
]
hooks.sort(key=attrgetter("name"))
######################################################################
tools_dir = os.path.dirname(__file__)
hooks_py = os.path.join(tools_dir, "..", "anki", "hooks.py")
code = ""
for hook in hooks:
code += hook.list_code()
code += "\n\n"
for hook in hooks:
code += hook.fire_code()
orig = open(hooks_py).read()
new = re.sub("(?s)# @@AUTOGEN@@.*?# @@AUTOGEN@@\n", f"# @@AUTOGEN@@\n\n{code}# @@AUTOGEN@@\n", orig)
open(hooks_py, "wb").write(new.encode("utf8"))
print("Updated hooks.py")

View file

@ -24,6 +24,7 @@ import aqt.sound
import aqt.stats import aqt.stats
import aqt.toolbar import aqt.toolbar
import aqt.webview import aqt.webview
from anki import hooks
from anki.collection import _Collection from anki.collection import _Collection
from anki.hooks import addHook, runFilter, runHook from anki.hooks import addHook, runFilter, runHook
from anki.lang import _, ngettext from anki.lang import _, ngettext
@ -1153,9 +1154,9 @@ Difference to correct time: %s."""
########################################################################## ##########################################################################
def setupHooks(self) -> None: def setupHooks(self) -> None:
addHook("modSchema", self.onSchemaMod) hooks.mod_schema_filter.append(self.onSchemaMod)
addHook("remNotes", self.onRemNotes) addHook("remNotes", self.onRemNotes)
addHook("odueInvalid", self.onOdueInvalid) hooks.odue_invalid_hook.append(self.onOdueInvalid)
addHook("mpvWillPlay", self.onMpvWillPlay) addHook("mpvWillPlay", self.onMpvWillPlay)
addHook("mpvIdleHook", self.onMpvIdle) addHook("mpvIdleHook", self.onMpvIdle)

View file

@ -11,8 +11,9 @@ import unicodedata as ucd
from typing import List from typing import List
import aqt import aqt
from anki import hooks
from anki.cards import Card from anki.cards import Card
from anki.hooks import addHook, runFilter, runHook from anki.hooks import runFilter, runHook
from anki.lang import _, ngettext from anki.lang import _, ngettext
from anki.utils import bodyClass, stripHTML from anki.utils import bodyClass, stripHTML
from aqt import AnkiQt from aqt import AnkiQt
@ -30,7 +31,7 @@ from aqt.utils import (
class Reviewer: class Reviewer:
"Manage reviews. Maintains a separate state." "Manage reviews. Maintains a separate state."
def __init__(self, mw: AnkiQt): def __init__(self, mw: AnkiQt) -> None:
self.mw = mw self.mw = mw
self.web = mw.web self.web = mw.web
self.card = None self.card = None
@ -41,7 +42,7 @@ class Reviewer:
self.typeCorrect = None # web init happens before this is set self.typeCorrect = None # web init happens before this is set
self.state = None self.state = None
self.bottom = aqt.toolbar.BottomBar(mw, mw.bottomWeb) self.bottom = aqt.toolbar.BottomBar(mw, mw.bottomWeb)
addHook("leech", self.onLeech) hooks.leech_hook.append(self.onLeech)
def show(self): def show(self):
self.mw.col.reset() self.mw.col.reset()