mirror of
https://github.com/ankitects/anki.git
synced 2025-09-21 15:32:23 -04:00
tweak db type hints
This commit is contained in:
parent
b5c6134d80
commit
77cf7dd4b7
3 changed files with 23 additions and 13 deletions
|
@ -5,7 +5,15 @@
|
||||||
# fixme: progress
|
# fixme: progress
|
||||||
|
|
||||||
from sqlite3 import dbapi2 as sqlite
|
from sqlite3 import dbapi2 as sqlite
|
||||||
from typing import Any, Iterable, List, Optional
|
from typing import Any, Iterable, List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
# DBValue is actually Union[str, int, float, None], but if defined
|
||||||
|
# that way, every call site needs to do a type check prior to using
|
||||||
|
# the return values.
|
||||||
|
ValueFromDB = Any
|
||||||
|
Row = Sequence[ValueFromDB]
|
||||||
|
|
||||||
|
ValueForDB = Union[str, int, float, None]
|
||||||
|
|
||||||
|
|
||||||
class DBProxy:
|
class DBProxy:
|
||||||
|
@ -38,7 +46,9 @@ class DBProxy:
|
||||||
# Querying
|
# Querying
|
||||||
################
|
################
|
||||||
|
|
||||||
def _query(self, sql: str, *args, first_row_only: bool = False) -> List[List]:
|
def _query(
|
||||||
|
self, sql: str, *args: ValueForDB, first_row_only: bool = False
|
||||||
|
) -> List[Row]:
|
||||||
# mark modified?
|
# mark modified?
|
||||||
s = sql.strip().lower()
|
s = sql.strip().lower()
|
||||||
for stmt in "insert", "update", "delete":
|
for stmt in "insert", "update", "delete":
|
||||||
|
@ -59,20 +69,20 @@ class DBProxy:
|
||||||
# Query shortcuts
|
# Query shortcuts
|
||||||
###################
|
###################
|
||||||
|
|
||||||
def all(self, sql: str, *args) -> List:
|
def all(self, sql: str, *args: ValueForDB) -> List[Row]:
|
||||||
return self._query(sql, *args)
|
return self._query(sql, *args)
|
||||||
|
|
||||||
def list(self, sql: str, *args) -> List:
|
def list(self, sql: str, *args: ValueForDB) -> List[ValueFromDB]:
|
||||||
return [x[0] for x in self._query(sql, *args)]
|
return [x[0] for x in self._query(sql, *args)]
|
||||||
|
|
||||||
def first(self, sql: str, *args) -> Optional[List]:
|
def first(self, sql: str, *args: ValueForDB) -> Optional[Row]:
|
||||||
rows = self._query(sql, *args, first_row_only=True)
|
rows = self._query(sql, *args, first_row_only=True)
|
||||||
if rows:
|
if rows:
|
||||||
return rows[0]
|
return rows[0]
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def scalar(self, sql: str, *args) -> Optional[Any]:
|
def scalar(self, sql: str, *args: ValueForDB) -> ValueFromDB:
|
||||||
rows = self._query(sql, *args, first_row_only=True)
|
rows = self._query(sql, *args, first_row_only=True)
|
||||||
if rows:
|
if rows:
|
||||||
return rows[0][0]
|
return rows[0][0]
|
||||||
|
@ -86,7 +96,7 @@ class DBProxy:
|
||||||
# Updates
|
# Updates
|
||||||
################
|
################
|
||||||
|
|
||||||
def executemany(self, sql: str, args: Iterable) -> None:
|
def executemany(self, sql: str, args: Iterable[Iterable[ValueForDB]]) -> None:
|
||||||
self.mod = True
|
self.mod = True
|
||||||
self._db.executemany(sql, args)
|
self._db.executemany(sql, args)
|
||||||
|
|
||||||
|
|
|
@ -138,8 +138,8 @@ class Scheduler:
|
||||||
|
|
||||||
def dueForecast(self, days: int = 7) -> List[Any]:
|
def dueForecast(self, days: int = 7) -> List[Any]:
|
||||||
"Return counts over next DAYS. Includes today."
|
"Return counts over next DAYS. Includes today."
|
||||||
daysd = dict(
|
daysd: Dict[int, int] = dict(
|
||||||
self.col.db.all(
|
self.col.db.all( # type: ignore
|
||||||
f"""
|
f"""
|
||||||
select due, count() from cards
|
select due, count() from cards
|
||||||
where did in %s and queue = {QUEUE_TYPE_REV}
|
where did in %s and queue = {QUEUE_TYPE_REV}
|
||||||
|
@ -542,7 +542,7 @@ select count() from cards where did in %s and queue = {QUEUE_TYPE_PREVIEW}
|
||||||
if self._lrnQueue:
|
if self._lrnQueue:
|
||||||
return True
|
return True
|
||||||
cutoff = intTime() + self.col.conf["collapseTime"]
|
cutoff = intTime() + self.col.conf["collapseTime"]
|
||||||
self._lrnQueue = self.col.db.all(
|
self._lrnQueue = self.col.db.all( # type: ignore
|
||||||
f"""
|
f"""
|
||||||
select due, id from cards where
|
select due, id from cards where
|
||||||
did in %s and queue in ({QUEUE_TYPE_LRN},{QUEUE_TYPE_PREVIEW}) and due < ?
|
did in %s and queue in ({QUEUE_TYPE_LRN},{QUEUE_TYPE_PREVIEW}) and due < ?
|
||||||
|
|
|
@ -8,7 +8,7 @@ import io
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import anki
|
import anki
|
||||||
from anki.consts import *
|
from anki.consts import *
|
||||||
|
@ -31,7 +31,7 @@ class UnexpectedSchemaChange(Exception):
|
||||||
|
|
||||||
|
|
||||||
class Syncer:
|
class Syncer:
|
||||||
chunkRows: Optional[List[List]]
|
chunkRows: Optional[List[Sequence]]
|
||||||
|
|
||||||
def __init__(self, col: anki.storage._Collection, server=None) -> None:
|
def __init__(self, col: anki.storage._Collection, server=None) -> None:
|
||||||
self.col = col.weakref()
|
self.col = col.weakref()
|
||||||
|
@ -248,7 +248,7 @@ class Syncer:
|
||||||
self.tablesLeft = ["revlog", "cards", "notes"]
|
self.tablesLeft = ["revlog", "cards", "notes"]
|
||||||
self.chunkRows = None
|
self.chunkRows = None
|
||||||
|
|
||||||
def getChunkRows(self, table) -> List[List]:
|
def getChunkRows(self, table) -> List[Sequence]:
|
||||||
lim = self.usnLim()
|
lim = self.usnLim()
|
||||||
x = self.col.db.all
|
x = self.col.db.all
|
||||||
d = (self.maxUsn, lim)
|
d = (self.maxUsn, lim)
|
||||||
|
|
Loading…
Reference in a new issue