tweak db type hints

This commit is contained in:
Damien Elmes 2020-03-03 12:05:33 +10:00
parent b5c6134d80
commit 77cf7dd4b7
3 changed files with 23 additions and 13 deletions

View file

@ -5,7 +5,15 @@
# fixme: progress # fixme: progress
from sqlite3 import dbapi2 as sqlite from sqlite3 import dbapi2 as sqlite
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional, Sequence, Union
# DBValue is actually Union[str, int, float, None], but if defined
# that way, every call site needs to do a type check prior to using
# the return values.
ValueFromDB = Any
Row = Sequence[ValueFromDB]
ValueForDB = Union[str, int, float, None]
class DBProxy: class DBProxy:
@ -38,7 +46,9 @@ class DBProxy:
# Querying # Querying
################ ################
def _query(self, sql: str, *args, first_row_only: bool = False) -> List[List]: def _query(
self, sql: str, *args: ValueForDB, first_row_only: bool = False
) -> List[Row]:
# mark modified? # mark modified?
s = sql.strip().lower() s = sql.strip().lower()
for stmt in "insert", "update", "delete": for stmt in "insert", "update", "delete":
@ -59,20 +69,20 @@ class DBProxy:
# Query shortcuts # Query shortcuts
################### ###################
def all(self, sql: str, *args) -> List: def all(self, sql: str, *args: ValueForDB) -> List[Row]:
return self._query(sql, *args) return self._query(sql, *args)
def list(self, sql: str, *args) -> List: def list(self, sql: str, *args: ValueForDB) -> List[ValueFromDB]:
return [x[0] for x in self._query(sql, *args)] return [x[0] for x in self._query(sql, *args)]
def first(self, sql: str, *args) -> Optional[List]: def first(self, sql: str, *args: ValueForDB) -> Optional[Row]:
rows = self._query(sql, *args, first_row_only=True) rows = self._query(sql, *args, first_row_only=True)
if rows: if rows:
return rows[0] return rows[0]
else: else:
return None return None
def scalar(self, sql: str, *args) -> Optional[Any]: def scalar(self, sql: str, *args: ValueForDB) -> ValueFromDB:
rows = self._query(sql, *args, first_row_only=True) rows = self._query(sql, *args, first_row_only=True)
if rows: if rows:
return rows[0][0] return rows[0][0]
@ -86,7 +96,7 @@ class DBProxy:
# Updates # Updates
################ ################
def executemany(self, sql: str, args: Iterable) -> None: def executemany(self, sql: str, args: Iterable[Iterable[ValueForDB]]) -> None:
self.mod = True self.mod = True
self._db.executemany(sql, args) self._db.executemany(sql, args)

View file

@ -138,8 +138,8 @@ class Scheduler:
def dueForecast(self, days: int = 7) -> List[Any]: def dueForecast(self, days: int = 7) -> List[Any]:
"Return counts over next DAYS. Includes today." "Return counts over next DAYS. Includes today."
daysd = dict( daysd: Dict[int, int] = dict(
self.col.db.all( self.col.db.all( # type: ignore
f""" f"""
select due, count() from cards select due, count() from cards
where did in %s and queue = {QUEUE_TYPE_REV} where did in %s and queue = {QUEUE_TYPE_REV}
@ -542,7 +542,7 @@ select count() from cards where did in %s and queue = {QUEUE_TYPE_PREVIEW}
if self._lrnQueue: if self._lrnQueue:
return True return True
cutoff = intTime() + self.col.conf["collapseTime"] cutoff = intTime() + self.col.conf["collapseTime"]
self._lrnQueue = self.col.db.all( self._lrnQueue = self.col.db.all( # type: ignore
f""" f"""
select due, id from cards where select due, id from cards where
did in %s and queue in ({QUEUE_TYPE_LRN},{QUEUE_TYPE_PREVIEW}) and due < ? did in %s and queue in ({QUEUE_TYPE_LRN},{QUEUE_TYPE_PREVIEW}) and due < ?

View file

@ -8,7 +8,7 @@ import io
import json import json
import os import os
import random import random
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import anki import anki
from anki.consts import * from anki.consts import *
@ -31,7 +31,7 @@ class UnexpectedSchemaChange(Exception):
class Syncer: class Syncer:
chunkRows: Optional[List[List]] chunkRows: Optional[List[Sequence]]
def __init__(self, col: anki.storage._Collection, server=None) -> None: def __init__(self, col: anki.storage._Collection, server=None) -> None:
self.col = col.weakref() self.col = col.weakref()
@ -248,7 +248,7 @@ class Syncer:
self.tablesLeft = ["revlog", "cards", "notes"] self.tablesLeft = ["revlog", "cards", "notes"]
self.chunkRows = None self.chunkRows = None
def getChunkRows(self, table) -> List[List]: def getChunkRows(self, table) -> List[Sequence]:
lim = self.usnLim() lim = self.usnLim()
x = self.col.db.all x = self.col.db.all
d = (self.maxUsn, lim) d = (self.maxUsn, lim)