diff --git a/pylib/anki/dbproxy.py b/pylib/anki/dbproxy.py index 6ce136160..534f0e35f 100644 --- a/pylib/anki/dbproxy.py +++ b/pylib/anki/dbproxy.py @@ -3,7 +3,8 @@ from __future__ import annotations -from typing import Any, Iterable, List, Optional, Sequence, Union +import re +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union import anki @@ -41,34 +42,35 @@ class DBProxy: ################ def _query( - self, sql: str, *args: ValueForDB, first_row_only: bool = False + self, sql: str, *args: ValueForDB, first_row_only: bool = False, **kwargs ) -> List[Row]: # mark modified? s = sql.strip().lower() for stmt in "insert", "update", "delete": if s.startswith(stmt): self.mod = True + sql, args2 = emulate_named_args(sql, args, kwargs) # fetch rows - return self._backend.db_query(sql, args, first_row_only) + return self._backend.db_query(sql, args2, first_row_only) # Query shortcuts ################### - def all(self, sql: str, *args: ValueForDB) -> List[Row]: - return self._query(sql, *args) + def all(self, sql: str, *args: ValueForDB, **kwargs) -> List[Row]: + return self._query(sql, *args, **kwargs) - def list(self, sql: str, *args: ValueForDB) -> List[ValueFromDB]: - return [x[0] for x in self._query(sql, *args)] + def list(self, sql: str, *args: ValueForDB, **kwargs) -> List[ValueFromDB]: + return [x[0] for x in self._query(sql, *args, **kwargs)] - def first(self, sql: str, *args: ValueForDB) -> Optional[Row]: - rows = self._query(sql, *args, first_row_only=True) + def first(self, sql: str, *args: ValueForDB, **kwargs) -> Optional[Row]: + rows = self._query(sql, *args, first_row_only=True, **kwargs) if rows: return rows[0] else: return None - def scalar(self, sql: str, *args: ValueForDB) -> ValueFromDB: - rows = self._query(sql, *args, first_row_only=True) + def scalar(self, sql: str, *args: ValueForDB, **kwargs) -> ValueFromDB: + rows = self._query(sql, *args, first_row_only=True, **kwargs) if rows: return rows[0][0] else: @@ -88,3 +90,27 @@ class DBProxy: else: list_args = list(args) self._backend.db_execute_many(sql, list_args) + + +# convert kwargs to list format +def emulate_named_args( + sql: str, args: Tuple, kwargs: Dict[str, Any] +) -> Tuple[str, Sequence[ValueForDB]]: + # nothing to do? + if not kwargs: + return sql, args + print("named arguments in queries will go away in the future:", sql) + # map args to numbers + arg_num = {} + args2 = list(args) + for key, val in kwargs.items(): + args2.append(val) + n = len(args2) + arg_num[key] = n + # update refs + def repl(m): + arg = m.group(1) + return f"?{arg_num[arg]}" + + sql = re.sub(":([a-zA-Z_0-9]+)", repl, sql) + return sql, args2 diff --git a/pylib/tests/test_collection.py b/pylib/tests/test_collection.py index 8c0985426..ab867225a 100644 --- a/pylib/tests/test_collection.py +++ b/pylib/tests/test_collection.py @@ -4,6 +4,7 @@ import os import tempfile from anki import Collection as aopen +from anki.dbproxy import emulate_named_args from anki.lang import without_unicode_isolation from anki.rsbackend import TR from anki.stdmodels import addBasicModel, models @@ -161,3 +162,13 @@ def test_translate(): ) assert no_uni(d.tr(TR.STATISTICS_REVIEWS, reviews=1)) == "1 review" assert no_uni(d.tr(TR.STATISTICS_REVIEWS, reviews=2)) == "2 reviews" + + +def test_db_named_args(): + sql = "select a, 2+:test5 from b where arg =:foo and x = :test5" + args = [] + kwargs = dict(test5=5, foo="blah") + + s, a = emulate_named_args(sql, args, kwargs) + assert s == "select a, 2+?1 from b where arg =?2 and x = ?1" + assert a == [5, "blah"]