emulate named args

This commit is contained in:
Damien Elmes 2020-04-06 20:09:44 +10:00
parent 62fa265213
commit 8efc09d4ef
2 changed files with 48 additions and 11 deletions

View file

@ -3,7 +3,8 @@
from __future__ import annotations
from typing import Any, Iterable, List, Optional, Sequence, Union
import re
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import anki
@ -41,34 +42,35 @@ class DBProxy:
################
def _query(
self, sql: str, *args: ValueForDB, first_row_only: bool = False
self, sql: str, *args: ValueForDB, first_row_only: bool = False, **kwargs
) -> List[Row]:
# mark modified?
s = sql.strip().lower()
for stmt in "insert", "update", "delete":
if s.startswith(stmt):
self.mod = True
sql, args2 = emulate_named_args(sql, args, kwargs)
# fetch rows
return self._backend.db_query(sql, args, first_row_only)
return self._backend.db_query(sql, args2, first_row_only)
# Query shortcuts
###################
def all(self, sql: str, *args: ValueForDB) -> List[Row]:
return self._query(sql, *args)
def all(self, sql: str, *args: ValueForDB, **kwargs) -> List[Row]:
return self._query(sql, *args, **kwargs)
def list(self, sql: str, *args: ValueForDB) -> List[ValueFromDB]:
return [x[0] for x in self._query(sql, *args)]
def list(self, sql: str, *args: ValueForDB, **kwargs) -> List[ValueFromDB]:
return [x[0] for x in self._query(sql, *args, **kwargs)]
def first(self, sql: str, *args: ValueForDB) -> Optional[Row]:
rows = self._query(sql, *args, first_row_only=True)
def first(self, sql: str, *args: ValueForDB, **kwargs) -> Optional[Row]:
rows = self._query(sql, *args, first_row_only=True, **kwargs)
if rows:
return rows[0]
else:
return None
def scalar(self, sql: str, *args: ValueForDB) -> ValueFromDB:
rows = self._query(sql, *args, first_row_only=True)
def scalar(self, sql: str, *args: ValueForDB, **kwargs) -> ValueFromDB:
rows = self._query(sql, *args, first_row_only=True, **kwargs)
if rows:
return rows[0][0]
else:
@ -88,3 +90,27 @@ class DBProxy:
else:
list_args = list(args)
self._backend.db_execute_many(sql, list_args)
# convert kwargs to list format
def emulate_named_args(
sql: str, args: Tuple, kwargs: Dict[str, Any]
) -> Tuple[str, Sequence[ValueForDB]]:
# nothing to do?
if not kwargs:
return sql, args
print("named arguments in queries will go away in the future:", sql)
# map args to numbers
arg_num = {}
args2 = list(args)
for key, val in kwargs.items():
args2.append(val)
n = len(args2)
arg_num[key] = n
# update refs
def repl(m):
arg = m.group(1)
return f"?{arg_num[arg]}"
sql = re.sub(":([a-zA-Z_0-9]+)", repl, sql)
return sql, args2

View file

@ -4,6 +4,7 @@ import os
import tempfile
from anki import Collection as aopen
from anki.dbproxy import emulate_named_args
from anki.lang import without_unicode_isolation
from anki.rsbackend import TR
from anki.stdmodels import addBasicModel, models
@ -161,3 +162,13 @@ def test_translate():
)
assert no_uni(d.tr(TR.STATISTICS_REVIEWS, reviews=1)) == "1 review"
assert no_uni(d.tr(TR.STATISTICS_REVIEWS, reviews=2)) == "2 reviews"
def test_db_named_args():
sql = "select a, 2+:test5 from b where arg =:foo and x = :test5"
args = []
kwargs = dict(test5=5, foo="blah")
s, a = emulate_named_args(sql, args, kwargs)
assert s == "select a, 2+?1 from b where arg =?2 and x = ?1"
assert a == [5, "blah"]