mirror of
https://github.com/ankitects/anki.git
synced 2025-09-23 08:22:24 -04:00
emulate named args
This commit is contained in:
parent
62fa265213
commit
8efc09d4ef
2 changed files with 48 additions and 11 deletions
|
@ -3,7 +3,8 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Iterable, List, Optional, Sequence, Union
|
import re
|
||||||
|
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import anki
|
import anki
|
||||||
|
|
||||||
|
@ -41,34 +42,35 @@ class DBProxy:
|
||||||
################
|
################
|
||||||
|
|
||||||
def _query(
|
def _query(
|
||||||
self, sql: str, *args: ValueForDB, first_row_only: bool = False
|
self, sql: str, *args: ValueForDB, first_row_only: bool = False, **kwargs
|
||||||
) -> List[Row]:
|
) -> List[Row]:
|
||||||
# mark modified?
|
# mark modified?
|
||||||
s = sql.strip().lower()
|
s = sql.strip().lower()
|
||||||
for stmt in "insert", "update", "delete":
|
for stmt in "insert", "update", "delete":
|
||||||
if s.startswith(stmt):
|
if s.startswith(stmt):
|
||||||
self.mod = True
|
self.mod = True
|
||||||
|
sql, args2 = emulate_named_args(sql, args, kwargs)
|
||||||
# fetch rows
|
# fetch rows
|
||||||
return self._backend.db_query(sql, args, first_row_only)
|
return self._backend.db_query(sql, args2, first_row_only)
|
||||||
|
|
||||||
# Query shortcuts
|
# Query shortcuts
|
||||||
###################
|
###################
|
||||||
|
|
||||||
def all(self, sql: str, *args: ValueForDB) -> List[Row]:
|
def all(self, sql: str, *args: ValueForDB, **kwargs) -> List[Row]:
|
||||||
return self._query(sql, *args)
|
return self._query(sql, *args, **kwargs)
|
||||||
|
|
||||||
def list(self, sql: str, *args: ValueForDB) -> List[ValueFromDB]:
|
def list(self, sql: str, *args: ValueForDB, **kwargs) -> List[ValueFromDB]:
|
||||||
return [x[0] for x in self._query(sql, *args)]
|
return [x[0] for x in self._query(sql, *args, **kwargs)]
|
||||||
|
|
||||||
def first(self, sql: str, *args: ValueForDB) -> Optional[Row]:
|
def first(self, sql: str, *args: ValueForDB, **kwargs) -> Optional[Row]:
|
||||||
rows = self._query(sql, *args, first_row_only=True)
|
rows = self._query(sql, *args, first_row_only=True, **kwargs)
|
||||||
if rows:
|
if rows:
|
||||||
return rows[0]
|
return rows[0]
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def scalar(self, sql: str, *args: ValueForDB) -> ValueFromDB:
|
def scalar(self, sql: str, *args: ValueForDB, **kwargs) -> ValueFromDB:
|
||||||
rows = self._query(sql, *args, first_row_only=True)
|
rows = self._query(sql, *args, first_row_only=True, **kwargs)
|
||||||
if rows:
|
if rows:
|
||||||
return rows[0][0]
|
return rows[0][0]
|
||||||
else:
|
else:
|
||||||
|
@ -88,3 +90,27 @@ class DBProxy:
|
||||||
else:
|
else:
|
||||||
list_args = list(args)
|
list_args = list(args)
|
||||||
self._backend.db_execute_many(sql, list_args)
|
self._backend.db_execute_many(sql, list_args)
|
||||||
|
|
||||||
|
|
||||||
|
# convert kwargs to list format
|
||||||
|
def emulate_named_args(
|
||||||
|
sql: str, args: Tuple, kwargs: Dict[str, Any]
|
||||||
|
) -> Tuple[str, Sequence[ValueForDB]]:
|
||||||
|
# nothing to do?
|
||||||
|
if not kwargs:
|
||||||
|
return sql, args
|
||||||
|
print("named arguments in queries will go away in the future:", sql)
|
||||||
|
# map args to numbers
|
||||||
|
arg_num = {}
|
||||||
|
args2 = list(args)
|
||||||
|
for key, val in kwargs.items():
|
||||||
|
args2.append(val)
|
||||||
|
n = len(args2)
|
||||||
|
arg_num[key] = n
|
||||||
|
# update refs
|
||||||
|
def repl(m):
|
||||||
|
arg = m.group(1)
|
||||||
|
return f"?{arg_num[arg]}"
|
||||||
|
|
||||||
|
sql = re.sub(":([a-zA-Z_0-9]+)", repl, sql)
|
||||||
|
return sql, args2
|
||||||
|
|
|
@ -4,6 +4,7 @@ import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
from anki import Collection as aopen
|
from anki import Collection as aopen
|
||||||
|
from anki.dbproxy import emulate_named_args
|
||||||
from anki.lang import without_unicode_isolation
|
from anki.lang import without_unicode_isolation
|
||||||
from anki.rsbackend import TR
|
from anki.rsbackend import TR
|
||||||
from anki.stdmodels import addBasicModel, models
|
from anki.stdmodels import addBasicModel, models
|
||||||
|
@ -161,3 +162,13 @@ def test_translate():
|
||||||
)
|
)
|
||||||
assert no_uni(d.tr(TR.STATISTICS_REVIEWS, reviews=1)) == "1 review"
|
assert no_uni(d.tr(TR.STATISTICS_REVIEWS, reviews=1)) == "1 review"
|
||||||
assert no_uni(d.tr(TR.STATISTICS_REVIEWS, reviews=2)) == "2 reviews"
|
assert no_uni(d.tr(TR.STATISTICS_REVIEWS, reviews=2)) == "2 reviews"
|
||||||
|
|
||||||
|
|
||||||
|
def test_db_named_args():
|
||||||
|
sql = "select a, 2+:test5 from b where arg =:foo and x = :test5"
|
||||||
|
args = []
|
||||||
|
kwargs = dict(test5=5, foo="blah")
|
||||||
|
|
||||||
|
s, a = emulate_named_args(sql, args, kwargs)
|
||||||
|
assert s == "select a, 2+?1 from b where arg =?2 and x = ?1"
|
||||||
|
assert a == [5, "blah"]
|
||||||
|
|
Loading…
Reference in a new issue