initial work on DB handling in Rust

committing the Protobuf implementation for posterity, but will replace
it with json, as Protobuf measures about 6x slower for some workloads
like 'select * from notes'
This commit is contained in:
Damien Elmes 2020-03-03 15:04:03 +10:00
parent 77cf7dd4b7
commit 04ca8ec038
12 changed files with 422 additions and 7 deletions

View file

@ -13,6 +13,7 @@ message BackendInit {
repeated string preferred_langs = 4; repeated string preferred_langs = 4;
string locale_folder_path = 5; string locale_folder_path = 5;
string log_path = 6; string log_path = 6;
bool server = 7;
} }
message I18nBackendInit { message I18nBackendInit {
@ -44,6 +45,7 @@ message BackendInput {
CongratsLearnMsgIn congrats_learn_msg = 33; CongratsLearnMsgIn congrats_learn_msg = 33;
Empty empty_trash = 34; Empty empty_trash = 34;
Empty restore_trash = 35; Empty restore_trash = 35;
DBQueryIn db_query = 36;
} }
} }
@ -72,6 +74,7 @@ message BackendOutput {
Empty trash_media_files = 29; Empty trash_media_files = 29;
Empty empty_trash = 34; Empty empty_trash = 34;
Empty restore_trash = 35; Empty restore_trash = 35;
DBQueryOut db_query = 36;
BackendError error = 2047; BackendError error = 2047;
} }
@ -324,3 +327,26 @@ message CongratsLearnMsgIn {
float next_due = 1; float next_due = 1;
uint32 remaining = 2; uint32 remaining = 2;
} }
// A SQL statement plus its positional arguments, forwarded to the
// Rust backend for execution against the collection database.
message DBQueryIn {
// presumably uses ? positional placeholders — rusqlite binds args by index
string sql = 1;
repeated SqlValue args = 2;
}
// All rows produced by a DBQueryIn statement.
message DBQueryOut {
repeated SqlRow rows = 1;
}
// One SQLite value; mirrors SQLite's five storage classes
// (NULL, TEXT, INTEGER, REAL, BLOB).
message SqlValue {
oneof value {
// Empty stands in for SQL NULL, since a oneof cannot hold "nothing" explicitly
Empty null = 1;
string string = 2;
int64 int = 3;
double double = 4;
bytes blob = 5;
}
}
// A single result row: one SqlValue per selected column, in column order.
message SqlRow {
repeated SqlValue values = 1;
}

View file

@ -5,13 +5,25 @@
import enum import enum
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Dict, List, NewType, NoReturn, Optional, Tuple, Union from typing import (
Callable,
Dict,
Iterable,
List,
NewType,
NoReturn,
Optional,
Tuple,
Union,
Any)
import ankirspy # pytype: disable=import-error import ankirspy # pytype: disable=import-error
import anki.backend_pb2 as pb import anki.backend_pb2 as pb
import anki.buildinfo import anki.buildinfo
from anki import hooks from anki import hooks
from anki.dbproxy import Row as DBRow
from anki.dbproxy import ValueForDB
from anki.fluent_pb2 import FluentString as TR from anki.fluent_pb2 import FluentString as TR
from anki.models import AllTemplateReqs from anki.models import AllTemplateReqs
from anki.sound import AVTag, SoundOrVideoTag, TTSTag from anki.sound import AVTag, SoundOrVideoTag, TTSTag
@ -186,7 +198,12 @@ def _on_progress(progress_bytes: bytes) -> bool:
class RustBackend: class RustBackend:
def __init__( def __init__(
self, col_path: str, media_folder_path: str, media_db_path: str, log_path: str self,
col_path: str,
media_folder_path: str,
media_db_path: str,
log_path: str,
server: bool,
) -> None: ) -> None:
ftl_folder = os.path.join(anki.lang.locale_folder, "fluent") ftl_folder = os.path.join(anki.lang.locale_folder, "fluent")
init_msg = pb.BackendInit( init_msg = pb.BackendInit(
@ -196,6 +213,7 @@ class RustBackend:
locale_folder_path=ftl_folder, locale_folder_path=ftl_folder,
preferred_langs=[anki.lang.currentLang], preferred_langs=[anki.lang.currentLang],
log_path=log_path, log_path=log_path,
server=server,
) )
self._backend = ankirspy.open_backend(init_msg.SerializeToString()) self._backend = ankirspy.open_backend(init_msg.SerializeToString())
self._backend.set_progress_callback(_on_progress) self._backend.set_progress_callback(_on_progress)
@ -366,6 +384,42 @@ class RustBackend:
def restore_trash(self): def restore_trash(self):
self._run_command(pb.BackendInput(restore_trash=pb.Empty())) self._run_command(pb.BackendInput(restore_trash=pb.Empty()))
def db_query(self, sql: str, args: Iterable[ValueForDB]) -> Iterable[DBRow]:
    """Run a SQL statement via the Rust backend, returning result rows.

    sql -- the statement text; presumably uses ? positional
           placeholders bound by the Rust side (TODO confirm)
    args -- values to bind: int, float, str, bytes or None,
            matching the SqlValue proto oneof

    Returns a lazy iterable of tuples, one per result row.
    """

    def arg_to_proto(arg: ValueForDB) -> pb.SqlValue:
        # Map a Python value onto the SqlValue oneof.
        # Note: bool is a subclass of int, so True/False travel as ints.
        if isinstance(arg, int):
            return pb.SqlValue(int=arg)
        elif isinstance(arg, float):
            return pb.SqlValue(double=arg)
        elif isinstance(arg, str):
            return pb.SqlValue(string=arg)
        elif isinstance(arg, bytes):
            # bugfix: the proto declares `bytes blob = 5`, but blobs
            # previously fell through to the error branch
            return pb.SqlValue(blob=arg)
        elif arg is None:
            return pb.SqlValue(null=pb.Empty())
        else:
            # TypeError is a subclass of Exception, so existing callers
            # catching Exception still work
            raise TypeError("unexpected DB arg type: %r" % type(arg))

    output = self._run_command(
        pb.BackendInput(
            db_query=pb.DBQueryIn(sql=sql, args=map(arg_to_proto, args))
        )
    ).db_query

    def sqlvalue_to_native(arg: pb.SqlValue) -> Any:
        v = arg.WhichOneof("value")
        if v == "int":
            return arg.int
        elif v == "double":
            return arg.double
        elif v == "string":
            return arg.string
        elif v == "blob":
            # bugfix: the Rust FromSql impl emits Blob for SQLite BLOB
            # columns; previously this hit assert_impossible_literal
            return arg.blob
        elif v == "null":
            return None
        else:
            assert_impossible_literal(v)

    def sqlrow_to_tuple(arg: pb.SqlRow) -> Tuple:
        return tuple(map(sqlvalue_to_native, arg.values))

    # map() keeps conversion lazy; callers iterate once
    return map(sqlrow_to_tuple, output.rows)
def translate_string_in( def translate_string_in(
key: TR, **kwargs: Union[str, int, float] key: TR, **kwargs: Union[str, int, float]

View file

@ -35,7 +35,6 @@ def Collection(
log_path = "" log_path = ""
if not server: if not server:
log_path = path.replace(".anki2", "2.log") log_path = path.replace(".anki2", "2.log")
backend = RustBackend(path, media_dir, media_db, log_path)
path = os.path.abspath(path) path = os.path.abspath(path)
create = not os.path.exists(path) create = not os.path.exists(path)
if create: if create:
@ -43,7 +42,10 @@ def Collection(
for c in ("/", ":", "\\"): for c in ("/", ":", "\\"):
assert c not in base assert c not in base
# connect # connect
db = DBProxy(path) backend = RustBackend(
path, media_dir, media_db, log_path, server=server is not None
)
db = DBProxy(backend, path)
db.setAutocommit(True) db.setAutocommit(True)
if create: if create:
ver = _createDB(db) ver = _createDB(db)

View file

@ -0,0 +1,84 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::backend_proto as pb;
use crate::err::Result;
use crate::storage::SqliteStorage;
use rusqlite::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};
use serde_derive::{Deserialize, Serialize};
//
// #[derive(Deserialize)]
// struct DBRequest {
// sql: String,
// args: Vec<SqlValue>,
// }
//
// #[derive(Serialize)]
// struct DBResult {
// rows: Vec<Vec<SqlValue>>,
// }
//
// #[derive(Serialize, Deserialize, Debug)]
// #[serde(untagged)]
// enum SqlValue {
// Null,
// String(String),
// Int(i64),
// Float(f64),
// Blob(Vec<u8>),
// }
// protobuf implementation
/// Allow a protobuf SqlValue to be bound directly as a rusqlite parameter.
impl ToSql for pb::SqlValue {
    fn to_sql(&self) -> std::result::Result<ToSqlOutput<'_>, rusqlite::Error> {
        use pb::sql_value::Value as SqlValue;
        // A oneof that was never populated arrives as None; treat it
        // the same as an explicit proto Null, i.e. SQL NULL.
        let out = match &self.value {
            None | Some(SqlValue::Null(_)) => ValueRef::Null,
            Some(SqlValue::String(s)) => ValueRef::Text(s.as_bytes()),
            Some(SqlValue::Int(i)) => ValueRef::Integer(*i),
            Some(SqlValue::Double(d)) => ValueRef::Real(*d),
            Some(SqlValue::Blob(b)) => ValueRef::Blob(b),
        };
        // Borrowed: no copies are made when binding text/blob data.
        Ok(ToSqlOutput::Borrowed(out))
    }
}
/// Convert a SQLite column value into a protobuf SqlValue.
impl FromSql for pb::SqlValue {
    fn column_result(value: ValueRef<'_>) -> std::result::Result<Self, FromSqlError> {
        use pb::sql_value::Value as SqlValue;
        let inner = match value {
            ValueRef::Null => SqlValue::Null(pb::Empty {}),
            ValueRef::Integer(i) => SqlValue::Int(i),
            ValueRef::Real(r) => SqlValue::Double(r),
            // lossy conversion: invalid UTF-8 bytes become U+FFFD
            ValueRef::Text(t) => SqlValue::String(String::from_utf8_lossy(t).into_owned()),
            ValueRef::Blob(b) => SqlValue::Blob(b.to_vec()),
        };
        Ok(pb::SqlValue { value: Some(inner) })
    }
}
/// Execute an arbitrary SQL statement from the frontend and collect
/// every result row into the protobuf response.
pub(super) fn db_query_proto(db: &SqliteStorage, input: pb::DbQueryIn) -> Result<pb::DbQueryOut> {
    let mut stmt = db.db.prepare_cached(&input.sql)?;
    // must be read before query() mutably borrows the statement
    let col_count = stmt.column_count();

    let mut out_rows = vec![];
    let mut rows = stmt.query(&input.args)?;
    while let Some(row) = rows.next()? {
        // convert each column, short-circuiting on the first error
        let values: std::result::Result<Vec<pb::SqlValue>, _> =
            (0..col_count).map(|idx| row.get(idx)).collect();
        out_rows.push(pb::SqlRow { values: values? });
    }

    Ok(pb::DbQueryOut { rows: out_rows })
}

View file

@ -1,6 +1,7 @@
// Copyright: Ankitects Pty Ltd and contributors // Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::backend::dbproxy::db_query_proto;
use crate::backend_proto::backend_input::Value; use crate::backend_proto::backend_input::Value;
use crate::backend_proto::{Empty, RenderedTemplateReplacement, SyncMediaIn}; use crate::backend_proto::{Empty, RenderedTemplateReplacement, SyncMediaIn};
use crate::err::{AnkiError, NetworkErrorKind, Result, SyncErrorKind}; use crate::err::{AnkiError, NetworkErrorKind, Result, SyncErrorKind};
@ -12,6 +13,7 @@ use crate::media::sync::MediaSyncProgress;
use crate::media::MediaManager; use crate::media::MediaManager;
use crate::sched::cutoff::{local_minutes_west_for_stamp, sched_timing_today}; use crate::sched::cutoff::{local_minutes_west_for_stamp, sched_timing_today};
use crate::sched::timespan::{answer_button_time, learning_congrats, studied_today, time_span}; use crate::sched::timespan::{answer_button_time, learning_congrats, studied_today, time_span};
use crate::storage::SqliteStorage;
use crate::template::{ use crate::template::{
render_card, without_legacy_template_directives, FieldMap, FieldRequirements, ParsedTemplate, render_card, without_legacy_template_directives, FieldMap, FieldRequirements, ParsedTemplate,
RenderedNode, RenderedNode,
@ -24,9 +26,12 @@ use std::collections::{HashMap, HashSet};
use std::path::PathBuf; use std::path::PathBuf;
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
mod dbproxy;
pub type ProtoProgressCallback = Box<dyn Fn(Vec<u8>) -> bool + Send>; pub type ProtoProgressCallback = Box<dyn Fn(Vec<u8>) -> bool + Send>;
pub struct Backend { pub struct Backend {
col: SqliteStorage,
#[allow(dead_code)] #[allow(dead_code)]
col_path: PathBuf, col_path: PathBuf,
media_folder: PathBuf, media_folder: PathBuf,
@ -119,7 +124,11 @@ pub fn init_backend(init_msg: &[u8]) -> std::result::Result<Backend, String> {
log::terminal(), log::terminal(),
); );
let col = SqliteStorage::open_or_create(Path::new(&input.collection_path), input.server)
.map_err(|e| format!("Unable to open collection: {:?}", e))?;
match Backend::new( match Backend::new(
col,
&input.collection_path, &input.collection_path,
&input.media_folder_path, &input.media_folder_path,
&input.media_db_path, &input.media_db_path,
@ -133,6 +142,7 @@ pub fn init_backend(init_msg: &[u8]) -> std::result::Result<Backend, String> {
impl Backend { impl Backend {
pub fn new( pub fn new(
col: SqliteStorage,
col_path: &str, col_path: &str,
media_folder: &str, media_folder: &str,
media_db: &str, media_db: &str,
@ -140,6 +150,7 @@ impl Backend {
log: Logger, log: Logger,
) -> Result<Backend> { ) -> Result<Backend> {
Ok(Backend { Ok(Backend {
col,
col_path: col_path.into(), col_path: col_path.into(),
media_folder: media_folder.into(), media_folder: media_folder.into(),
media_db: media_db.into(), media_db: media_db.into(),
@ -241,6 +252,7 @@ impl Backend {
self.restore_trash()?; self.restore_trash()?;
OValue::RestoreTrash(Empty {}) OValue::RestoreTrash(Empty {})
} }
Value::DbQuery(input) => OValue::DbQuery(self.db_query(input)?),
}) })
} }
@ -481,6 +493,10 @@ impl Backend {
checker.restore_trash() checker.restore_trash()
} }
fn db_query(&self, input: pb::DbQueryIn) -> Result<pb::DbQueryOut> {
db_query_proto(&self.col, input)
}
} }
fn translate_arg_to_fluent_val(arg: &pb::TranslateArgValue) -> FluentValue { fn translate_arg_to_fluent_val(arg: &pb::TranslateArgValue) -> FluentValue {

View file

@ -20,7 +20,7 @@ pub enum AnkiError {
IOError { info: String }, IOError { info: String },
#[fail(display = "DB error: {}", info)] #[fail(display = "DB error: {}", info)]
DBError { info: String }, DBError { info: String, kind: DBErrorKind },
#[fail(display = "Network error: {:?} {}", kind, info)] #[fail(display = "Network error: {:?} {}", kind, info)]
NetworkError { NetworkError {
@ -112,6 +112,7 @@ impl From<rusqlite::Error> for AnkiError {
fn from(err: rusqlite::Error) -> Self { fn from(err: rusqlite::Error) -> Self {
AnkiError::DBError { AnkiError::DBError {
info: format!("{:?}", err), info: format!("{:?}", err),
kind: DBErrorKind::Other,
} }
} }
} }
@ -120,6 +121,7 @@ impl From<rusqlite::types::FromSqlError> for AnkiError {
fn from(err: rusqlite::types::FromSqlError) -> Self { fn from(err: rusqlite::types::FromSqlError) -> Self {
AnkiError::DBError { AnkiError::DBError {
info: format!("{:?}", err), info: format!("{:?}", err),
kind: DBErrorKind::Other,
} }
} }
} }
@ -215,3 +217,11 @@ impl From<serde_json::Error> for AnkiError {
AnkiError::sync_misc(err.to_string()) AnkiError::sync_misc(err.to_string())
} }
} }
/// Classifies DB errors so callers can react to a specific failure
/// mode instead of inspecting the free-form `info` string.
#[derive(Debug, PartialEq)]
pub enum DBErrorKind {
/// Collection schema version is newer than this build supports
/// (see SCHEMA_MAX_VERSION in storage/sqlite.rs).
FileTooNew,
/// Schema version is older than SCHEMA_MIN_VERSION.
FileTooOld,
/// An expected row/object was missing (e.g. note type lookup failed).
MissingEntity,
/// Anything else, including wrapped rusqlite errors.
Other,
}

View file

@ -17,6 +17,7 @@ pub mod latex;
pub mod log; pub mod log;
pub mod media; pub mod media;
pub mod sched; pub mod sched;
pub mod storage;
pub mod template; pub mod template;
pub mod template_filters; pub mod template_filters;
pub mod text; pub mod text;

View file

@ -1,7 +1,7 @@
// Copyright: Ankitects Pty Ltd and contributors // Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::err::{AnkiError, Result}; use crate::err::{AnkiError, DBErrorKind, Result};
use crate::i18n::{tr_args, tr_strs, FString, I18n}; use crate::i18n::{tr_args, tr_strs, FString, I18n};
use crate::latex::extract_latex_expanding_clozes; use crate::latex::extract_latex_expanding_clozes;
use crate::log::{debug, Logger}; use crate::log::{debug, Logger};
@ -403,6 +403,7 @@ where
.get(&note.mid) .get(&note.mid)
.ok_or_else(|| AnkiError::DBError { .ok_or_else(|| AnkiError::DBError {
info: "missing note type".to_string(), info: "missing note type".to_string(),
kind: DBErrorKind::MissingEntity,
})?; })?;
if fix_and_extract_media_refs(note, &mut referenced_files, renamed)? { if fix_and_extract_media_refs(note, &mut referenced_files, renamed)? {
// note was modified, needs saving // note was modified, needs saving

View file

@ -2,7 +2,7 @@
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
/// Basic note reading/updating functionality for the media DB check. /// Basic note reading/updating functionality for the media DB check.
use crate::err::{AnkiError, Result}; use crate::err::{AnkiError, DBErrorKind, Result};
use crate::text::strip_html_preserving_image_filenames; use crate::text::strip_html_preserving_image_filenames;
use crate::time::{i64_unix_millis, i64_unix_secs}; use crate::time::{i64_unix_millis, i64_unix_secs};
use crate::types::{ObjID, Timestamp, Usn}; use crate::types::{ObjID, Timestamp, Usn};
@ -85,6 +85,7 @@ pub(super) fn get_note_types(db: &Connection) -> Result<HashMap<ObjID, NoteType>
.next() .next()
.ok_or_else(|| AnkiError::DBError { .ok_or_else(|| AnkiError::DBError {
info: "col table empty".to_string(), info: "col table empty".to_string(),
kind: DBErrorKind::MissingEntity,
})??; })??;
Ok(note_types) Ok(note_types)
} }
@ -136,6 +137,7 @@ pub(super) fn set_note(db: &Connection, note: &mut Note, note_type: &NoteType) -
.get(note_type.sort_field_idx as usize) .get(note_type.sort_field_idx as usize)
.ok_or_else(|| AnkiError::DBError { .ok_or_else(|| AnkiError::DBError {
info: "sort field out of range".to_string(), info: "sort field out of range".to_string(),
kind: DBErrorKind::MissingEntity,
})?, })?,
); );

3
rslib/src/storage/mod.rs Normal file
View file

@ -0,0 +1,3 @@
// Storage layer: SQLite-backed collection access.
mod sqlite;
// Re-exported for crate-internal use (e.g. backend::dbproxy).
pub(crate) use sqlite::SqliteStorage;

View file

@ -0,0 +1,88 @@
-- Collection schema (version 11; see SCHEMA_MAX_VERSION in sqlite.rs).
-- Comments below note only what is visible here; column semantics are
-- presumably the legacy Anki meanings — verify against the Python layer.

-- Single-row table holding collection-wide state; the insert at the
-- bottom of this file creates its one row with id=1.
create table col
(
id integer primary key,
crt integer not null,
mod integer not null,
scm integer not null,
ver integer not null,
dty integer not null,
usn integer not null,
ls integer not null,
conf text not null,
models text not null,
decks text not null,
dconf text not null,
tags text not null
);
-- One row per note; flds/tags are text blobs, csum is indexed below
-- for field-uniqueness checks.
create table notes
(
id integer primary key,
guid text not null,
mid integer not null,
mod integer not null,
usn integer not null,
tags text not null,
flds text not null,
sfld integer not null,
csum integer not null,
flags integer not null,
data text not null
);
-- One row per card; nid references notes.id (no FK constraint declared).
create table cards
(
id integer primary key,
nid integer not null,
did integer not null,
ord integer not null,
mod integer not null,
usn integer not null,
type integer not null,
queue integer not null,
due integer not null,
ivl integer not null,
factor integer not null,
reps integer not null,
lapses integer not null,
left integer not null,
odue integer not null,
odid integer not null,
flags integer not null,
data text not null
);
-- Review history; cid references cards.id.
create table revlog
(
id integer primary key,
cid integer not null,
usn integer not null,
ease integer not null,
ivl integer not null,
lastIvl integer not null,
factor integer not null,
time integer not null,
type integer not null
);
-- Tombstones for deleted objects (no primary key; append-only).
create table graves
(
usn integer not null,
oid integer not null,
type integer not null
);
-- syncing
create index ix_notes_usn on notes (usn);
create index ix_cards_usn on cards (usn);
create index ix_revlog_usn on revlog (usn);
-- card spacing, etc
create index ix_cards_nid on cards (nid);
-- scheduling and deck limiting
create index ix_cards_sched on cards (did, queue, due);
-- revlog by card
create index ix_revlog_cid on revlog (cid);
-- field uniqueness
create index ix_notes_csum on notes (csum);
-- seed the single col row; crt/ver are then updated by open_or_create
insert into col values (1,0,0,0,0,0,0,0,'{}','{}','{}','{}','{}');

128
rslib/src/storage/sqlite.rs Normal file
View file

@ -0,0 +1,128 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::err::Result;
use crate::err::{AnkiError, DBErrorKind};
use crate::time::i64_unix_timestamp;
use rusqlite::types::{FromSql, FromSqlError, FromSqlResult, ValueRef};
use rusqlite::{params, Connection, OptionalExtension, NO_PARAMS};
use serde::de::DeserializeOwned;
use serde_derive::{Deserialize, Serialize};
use serde_json::Value;
use std::borrow::Cow;
use std::convert::TryFrom;
use std::fmt;
use std::path::{Path, PathBuf};
const SCHEMA_MIN_VERSION: u8 = 11;
const SCHEMA_MAX_VERSION: u8 = 11;
// Lazily prepare a statement and cache it in an Option field:
// the first expansion compiles $sql into $label, later ones reuse it.
// Expands to a `&mut` borrow of the cached statement; the `?` inside
// means it must be used in a function returning Result.
// NOTE(review): no uses are visible in this chunk — confirm it is
// exercised elsewhere before relying on it.
macro_rules! cached_sql {
( $label:expr, $db:expr, $sql:expr ) => {{
if $label.is_none() {
$label = Some($db.prepare_cached($sql)?);
}
$label.as_mut().unwrap()
}};
}
/// Owns the open SQLite connection to a collection file.
#[derive(Debug)]
pub struct SqliteStorage {
// currently crate-visible for dbproxy
pub(crate) db: Connection,
// path the collection was opened from
path: PathBuf,
// passed through from BackendInit.server; presumably toggles
// server-specific behaviour later — not read in this chunk
server: bool,
}
/// Open (creating if missing) the collection DB file and apply the
/// standard connection pragmas.
fn open_or_create_collection_db(path: &Path) -> Result<Connection> {
    // `mut` is required for the trace hook below.
    let mut db = Connection::open(path)?;

    // Optional statement logging for debugging.
    if std::env::var("TRACESQL").is_ok() {
        db.trace(Some(trace));
    }

    db.pragma_update(None, "locking_mode", &"exclusive")?;
    db.pragma_update(None, "page_size", &4096)?;
    db.pragma_update(None, "cache_size", &(-40 * 1024))?;
    db.pragma_update(None, "legacy_file_format", &false)?;
    // bugfix: the pragma is named "journal_mode"; SQLite silently
    // ignores unknown pragmas, so the previous "journal" spelling was
    // a no-op and WAL was never actually enabled.
    db.pragma_update(None, "journal_mode", &"wal")?;

    db.set_prepared_statement_cache_capacity(50);

    Ok(db)
}
/// Fetch schema version from database.
/// Return (must_create, version)
fn schema_version(db: &Connection) -> Result<(bool, u8)> {
    // No `col` table means this is a brand-new file.
    let col_exists = db
        .prepare("select null from sqlite_master where type = 'table' and name = 'col'")?
        .exists(NO_PARAMS)?;
    if !col_exists {
        return Ok((true, SCHEMA_MAX_VERSION));
    }

    let ver: u8 = db.query_row("select ver from col", NO_PARAMS, |row| row.get(0))?;
    Ok((false, ver))
}
/// rusqlite trace hook: prints each executed SQL statement to stdout.
/// Installed only when the TRACESQL env var is set (see
/// open_or_create_collection_db).
fn trace(s: &str) {
println!("sql: {}", s)
}
impl SqliteStorage {
/// Open the collection at `path`, verifying its schema version.
/// Returns DBError{FileTooNew/FileTooOld} when the version is out of
/// the supported range.
pub(crate) fn open_or_create(path: &Path, server: bool) -> Result<Self> {
let db = open_or_create_collection_db(path)?;
let (create, ver) = schema_version(&db)?;
if create {
// NOTE(review): creation is an unfinished stub — this panics
// before the schema setup below runs, so the following
// statements are currently unreachable.
unimplemented!(); // todo
db.prepare_cached("begin exclusive")?.execute(NO_PARAMS)?;
db.execute_batch(include_str!("schema11.sql"))?;
// overwrite the placeholder crt/ver seeded by schema11.sql
db.execute(
"update col set crt=?, ver=?",
params![i64_unix_timestamp(), ver],
)?;
db.prepare_cached("commit")?.execute(NO_PARAMS)?;
} else {
if ver > SCHEMA_MAX_VERSION {
return Err(AnkiError::DBError {
info: "".to_string(),
kind: DBErrorKind::FileTooNew,
});
}
if ver < SCHEMA_MIN_VERSION {
return Err(AnkiError::DBError {
info: "".to_string(),
kind: DBErrorKind::FileTooOld,
});
}
};
let storage = Self {
db,
path: path.to_owned(),
server,
};
Ok(storage)
}
/// Begin an exclusive transaction (matches the "locking_mode
/// exclusive" pragma set at open time).
pub(crate) fn begin(&self) -> Result<()> {
self.db
.prepare_cached("begin exclusive")?
.execute(NO_PARAMS)?;
Ok(())
}
/// Commit the current transaction.
pub(crate) fn commit(&self) -> Result<()> {
self.db.prepare_cached("commit")?.execute(NO_PARAMS)?;
Ok(())
}
/// Roll back the current transaction. Uses execute() directly
/// rather than the prepared-statement cache, unlike begin/commit.
pub(crate) fn rollback(&self) -> Result<()> {
self.db.execute("rollback", NO_PARAMS)?;
Ok(())
}
}