diff --git a/proto/backend.proto b/proto/backend.proto index b5a35d0d1..67197e9bc 100644 --- a/proto/backend.proto +++ b/proto/backend.proto @@ -13,6 +13,7 @@ message BackendInit { repeated string preferred_langs = 4; string locale_folder_path = 5; string log_path = 6; + bool server = 7; } message I18nBackendInit { @@ -44,6 +45,7 @@ message BackendInput { CongratsLearnMsgIn congrats_learn_msg = 33; Empty empty_trash = 34; Empty restore_trash = 35; + DBQueryIn db_query = 36; } } @@ -72,6 +74,7 @@ message BackendOutput { Empty trash_media_files = 29; Empty empty_trash = 34; Empty restore_trash = 35; + DBQueryOut db_query = 36; BackendError error = 2047; } @@ -324,3 +327,26 @@ message CongratsLearnMsgIn { float next_due = 1; uint32 remaining = 2; } + +message DBQueryIn { + string sql = 1; + repeated SqlValue args = 2; +} + +message DBQueryOut { + repeated SqlRow rows = 1; +} + +message SqlValue { + oneof value { + Empty null = 1; + string string = 2; + int64 int = 3; + double double = 4; + bytes blob = 5; + } +} + +message SqlRow { + repeated SqlValue values = 1; +} diff --git a/pylib/anki/rsbackend.py b/pylib/anki/rsbackend.py index 9bef66dfd..18b405031 100644 --- a/pylib/anki/rsbackend.py +++ b/pylib/anki/rsbackend.py @@ -5,13 +5,25 @@ import enum import os from dataclasses import dataclass -from typing import Callable, Dict, List, NewType, NoReturn, Optional, Tuple, Union +from typing import ( + Callable, + Dict, + Iterable, + List, + NewType, + NoReturn, + Optional, + Tuple, + Union, + Any) import ankirspy # pytype: disable=import-error import anki.backend_pb2 as pb import anki.buildinfo from anki import hooks +from anki.dbproxy import Row as DBRow +from anki.dbproxy import ValueForDB from anki.fluent_pb2 import FluentString as TR from anki.models import AllTemplateReqs from anki.sound import AVTag, SoundOrVideoTag, TTSTag @@ -186,7 +198,12 @@ def _on_progress(progress_bytes: bytes) -> bool: class RustBackend: def __init__( - self, col_path: str, media_folder_path: str, media_db_path: str, log_path: str + self, + col_path: str, + media_folder_path: str, + media_db_path: str, + log_path: str, + server: bool, ) -> None: ftl_folder = os.path.join(anki.lang.locale_folder, "fluent") init_msg = pb.BackendInit( @@ -196,6 +213,7 @@ class RustBackend: locale_folder_path=ftl_folder, preferred_langs=[anki.lang.currentLang], log_path=log_path, + server=server, ) self._backend = ankirspy.open_backend(init_msg.SerializeToString()) self._backend.set_progress_callback(_on_progress) @@ -366,6 +384,42 @@ class RustBackend: def restore_trash(self): self._run_command(pb.BackendInput(restore_trash=pb.Empty())) + def db_query(self, sql: str, args: Iterable[ValueForDB]) -> Iterable[DBRow]: + def arg_to_proto(arg: ValueForDB) -> pb.SqlValue: + if isinstance(arg, int): + return pb.SqlValue(int=arg) + elif isinstance(arg, float): + return pb.SqlValue(double=arg) + elif isinstance(arg, str): + return pb.SqlValue(string=arg) + elif arg is None: + return pb.SqlValue(null=pb.Empty()) + else: + raise Exception("unexpected DB type") + + output = self._run_command( + pb.BackendInput( + db_query=pb.DBQueryIn(sql=sql, args=map(arg_to_proto, args)) + ) + ).db_query + + def sqlvalue_to_native(arg: pb.SqlValue) -> Any: + v = arg.WhichOneof("value") + if v == "int": + return arg.int + elif v == "double": + return arg.double + elif v == "string": + return arg.string + elif v == "null": + return None + else: + assert_impossible_literal(v) + + def sqlrow_to_tuple(arg: pb.SqlRow) -> Tuple: + return tuple(map(sqlvalue_to_native, arg.values)) + + return map(sqlrow_to_tuple, output.rows) def translate_string_in( key: TR, **kwargs: Union[str, int, float] diff --git a/pylib/anki/storage.py b/pylib/anki/storage.py index 01a71ad3e..c81dac715 100644 --- a/pylib/anki/storage.py +++ b/pylib/anki/storage.py @@ -35,7 +35,6 @@ def Collection( log_path = "" if not server: log_path = path.replace(".anki2", "2.log") - backend = RustBackend(path, media_dir, media_db, log_path) path = os.path.abspath(path) create = not os.path.exists(path) if create: @@ -43,7 +42,10 @@ def Collection( for c in ("/", ":", "\\"): assert c not in base # connect - db = DBProxy(path) + backend = RustBackend( + path, media_dir, media_db, log_path, server=server is not None + ) + db = DBProxy(backend, path) db.setAutocommit(True) if create: ver = _createDB(db) diff --git a/rslib/src/backend/dbproxy.rs b/rslib/src/backend/dbproxy.rs new file mode 100644 index 000000000..3d8cca005 --- /dev/null +++ b/rslib/src/backend/dbproxy.rs @@ -0,0 +1,84 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use crate::backend_proto as pb; +use crate::err::Result; +use crate::storage::SqliteStorage; +use rusqlite::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef}; +use serde_derive::{Deserialize, Serialize}; +// +// #[derive(Deserialize)] +// struct DBRequest { +// sql: String, +// args: Vec, +// } +// +// #[derive(Serialize)] +// struct DBResult { +// rows: Vec>, +// } +// +// #[derive(Serialize, Deserialize, Debug)] +// #[serde(untagged)] +// enum SqlValue { +// Null, +// String(String), +// Int(i64), +// Float(f64), +// Blob(Vec), +// } +// protobuf implementation + +impl ToSql for pb::SqlValue { + fn to_sql(&self) -> std::result::Result, rusqlite::Error> { + use pb::sql_value::Value as SqlValue; + let val = match self + .value + .as_ref() + .unwrap_or_else(|| &SqlValue::Null(pb::Empty {})) + { + SqlValue::Null(_) => ValueRef::Null, + SqlValue::String(v) => ValueRef::Text(v.as_bytes()), + SqlValue::Int(v) => ValueRef::Integer(*v), + SqlValue::Double(v) => ValueRef::Real(*v), + SqlValue::Blob(v) => ValueRef::Blob(&v), + }; + Ok(ToSqlOutput::Borrowed(val)) + } +} + +impl FromSql for pb::SqlValue { + fn column_result(value: ValueRef<'_>) -> std::result::Result { + use pb::sql_value::Value as SqlValue; + let val = match value { + ValueRef::Null => SqlValue::Null(pb::Empty {}), + ValueRef::Integer(i) => SqlValue::Int(i), + ValueRef::Real(v) => SqlValue::Double(v), + ValueRef::Text(v) => SqlValue::String(String::from_utf8_lossy(v).to_string()), + ValueRef::Blob(v) => SqlValue::Blob(v.to_vec()), + }; + Ok(pb::SqlValue { value: Some(val) }) + } +} + +pub(super) fn db_query_proto(db: &SqliteStorage, input: pb::DbQueryIn) -> Result { + let mut stmt = db.db.prepare_cached(&input.sql)?; + + let columns = stmt.column_count(); + + let mut rows = stmt.query(&input.args)?; + + let mut output_rows = vec![]; + + while let Some(row) = rows.next()? { + let mut orow = Vec::with_capacity(columns); + for i in 0..columns { + let v: pb::SqlValue = row.get(i)?; + orow.push(v); + } + + output_rows.push(pb::SqlRow { values: orow }); + } + + Ok(pb::DbQueryOut { rows: output_rows }) +} diff --git a/rslib/src/backend.rs b/rslib/src/backend/mod.rs similarity index 97% rename from rslib/src/backend.rs rename to rslib/src/backend/mod.rs index 970768fb5..6e88ed2a1 100644 --- a/rslib/src/backend.rs +++ b/rslib/src/backend/mod.rs @@ -1,6 +1,7 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +use crate::backend::dbproxy::db_query_proto; use crate::backend_proto::backend_input::Value; use crate::backend_proto::{Empty, RenderedTemplateReplacement, SyncMediaIn}; use crate::err::{AnkiError, NetworkErrorKind, Result, SyncErrorKind}; @@ -12,6 +13,7 @@ use crate::media::sync::MediaSyncProgress; use crate::media::MediaManager; use crate::sched::cutoff::{local_minutes_west_for_stamp, sched_timing_today}; use crate::sched::timespan::{answer_button_time, learning_congrats, studied_today, time_span}; +use crate::storage::SqliteStorage; use crate::template::{ render_card, without_legacy_template_directives, FieldMap, FieldRequirements, ParsedTemplate, RenderedNode, @@ -24,9 +26,12 @@ use std::collections::{HashMap, HashSet}; use std::path::PathBuf; use tokio::runtime::Runtime; +mod dbproxy; + pub type ProtoProgressCallback = Box) -> bool + Send>; pub struct Backend { + col: SqliteStorage, #[allow(dead_code)] col_path: PathBuf, media_folder: PathBuf, @@ -119,7 +124,11 @@ pub fn init_backend(init_msg: &[u8]) -> std::result::Result { log::terminal(), ); + let col = SqliteStorage::open_or_create(Path::new(&input.collection_path), input.server) + .map_err(|e| format!("Unable to open collection: {:?}", e))?; + match Backend::new( + col, &input.collection_path, &input.media_folder_path, &input.media_db_path, @@ -133,6 +142,7 @@ pub fn init_backend(init_msg: &[u8]) -> std::result::Result { impl Backend { pub fn new( + col: SqliteStorage, col_path: &str, media_folder: &str, media_db: &str, @@ -140,6 +150,7 @@ impl Backend { log: Logger, ) -> Result { Ok(Backend { + col, col_path: col_path.into(), media_folder: media_folder.into(), media_db: media_db.into(), @@ -241,6 +252,7 @@ impl Backend { self.restore_trash()?; OValue::RestoreTrash(Empty {}) } + Value::DbQuery(input) => OValue::DbQuery(self.db_query(input)?), }) } @@ -481,6 +493,10 @@ impl Backend { checker.restore_trash() } + + fn db_query(&self, input: pb::DbQueryIn) -> Result { + db_query_proto(&self.col, input) + } } fn translate_arg_to_fluent_val(arg: &pb::TranslateArgValue) -> FluentValue { diff --git a/rslib/src/err.rs b/rslib/src/err.rs index 1d2bd2faf..66e4bc66c 100644 --- a/rslib/src/err.rs +++ b/rslib/src/err.rs @@ -20,7 +20,7 @@ pub enum AnkiError { IOError { info: String }, #[fail(display = "DB error: {}", info)] - DBError { info: String }, + DBError { info: String, kind: DBErrorKind }, #[fail(display = "Network error: {:?} {}", kind, info)] NetworkError { @@ -112,6 +112,7 @@ impl From for AnkiError { fn from(err: rusqlite::Error) -> Self { AnkiError::DBError { info: format!("{:?}", err), + kind: DBErrorKind::Other, } } } @@ -120,6 +121,7 @@ impl From for AnkiError { fn from(err: rusqlite::types::FromSqlError) -> Self { AnkiError::DBError { info: format!("{:?}", err), + kind: DBErrorKind::Other, } } } @@ -215,3 +217,11 @@ impl From for AnkiError { AnkiError::sync_misc(err.to_string()) } } + +#[derive(Debug, PartialEq)] +pub enum DBErrorKind { + FileTooNew, + FileTooOld, + MissingEntity, + Other, +} diff --git a/rslib/src/lib.rs b/rslib/src/lib.rs index cce95c5b0..82ce394fb 100644 --- a/rslib/src/lib.rs +++ b/rslib/src/lib.rs @@ -17,6 +17,7 @@ pub mod latex; pub mod log; pub mod media; pub mod sched; +pub mod storage; pub mod template; pub mod template_filters; pub mod text; diff --git a/rslib/src/media/check.rs b/rslib/src/media/check.rs index 2049ed0e2..a6f1fa76a 100644 --- a/rslib/src/media/check.rs +++ b/rslib/src/media/check.rs @@ -1,7 +1,7 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -use crate::err::{AnkiError, Result}; +use crate::err::{AnkiError, DBErrorKind, Result}; use crate::i18n::{tr_args, tr_strs, FString, I18n}; use crate::latex::extract_latex_expanding_clozes; use crate::log::{debug, Logger}; @@ -403,6 +403,7 @@ where .get(¬e.mid) .ok_or_else(|| AnkiError::DBError { info: "missing note type".to_string(), + kind: DBErrorKind::MissingEntity, })?; if fix_and_extract_media_refs(note, &mut referenced_files, renamed)? { // note was modified, needs saving diff --git a/rslib/src/media/col.rs b/rslib/src/media/col.rs index a563bb93e..460d64b5f 100644 --- a/rslib/src/media/col.rs +++ b/rslib/src/media/col.rs @@ -2,7 +2,7 @@ // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html /// Basic note reading/updating functionality for the media DB check. -use crate::err::{AnkiError, Result}; +use crate::err::{AnkiError, DBErrorKind, Result}; use crate::text::strip_html_preserving_image_filenames; use crate::time::{i64_unix_millis, i64_unix_secs}; use crate::types::{ObjID, Timestamp, Usn}; @@ -85,6 +85,7 @@ pub(super) fn get_note_types(db: &Connection) -> Result .next() .ok_or_else(|| AnkiError::DBError { info: "col table empty".to_string(), + kind: DBErrorKind::MissingEntity, })??; Ok(note_types) } @@ -136,6 +137,7 @@ pub(super) fn set_note(db: &Connection, note: &mut Note, note_type: &NoteType) - .get(note_type.sort_field_idx as usize) .ok_or_else(|| AnkiError::DBError { info: "sort field out of range".to_string(), + kind: DBErrorKind::MissingEntity, })?, ); diff --git a/rslib/src/storage/mod.rs b/rslib/src/storage/mod.rs new file mode 100644 index 000000000..2b474f145 --- /dev/null +++ b/rslib/src/storage/mod.rs @@ -0,0 +1,3 @@ +mod sqlite; + +pub(crate) use sqlite::SqliteStorage; diff --git a/rslib/src/storage/schema11.sql b/rslib/src/storage/schema11.sql new file mode 100644 index 000000000..e5d1b7f4d --- /dev/null +++ b/rslib/src/storage/schema11.sql @@ -0,0 +1,88 @@ +create table col +( + id integer primary key, + crt integer not null, + mod integer not null, + scm integer not null, + ver integer not null, + dty integer not null, + usn integer not null, + ls integer not null, + conf text not null, + models text not null, + decks text not null, + dconf text not null, + tags text not null +); + +create table notes +( + id integer primary key, + guid text not null, + mid integer not null, + mod integer not null, + usn integer not null, + tags text not null, + flds text not null, + sfld integer not null, + csum integer not null, + flags integer not null, + data text not null +); + +create table cards +( + id integer primary key, + nid integer not null, + did integer not null, + ord integer not null, + mod integer not null, + usn integer not null, + type integer not null, + queue integer not null, + due integer not null, + ivl integer not null, + factor integer not null, + reps integer not null, + lapses integer not null, + left integer not null, + odue integer not null, + odid integer not null, + flags integer not null, + data text not null +); + +create table revlog +( + id integer primary key, + cid integer not null, + usn integer not null, + ease integer not null, + ivl integer not null, + lastIvl integer not null, + factor integer not null, + time integer not null, + type integer not null +); + +create table graves +( + usn integer not null, + oid integer not null, + type integer not null +); + +-- syncing +create index ix_notes_usn on notes (usn); +create index ix_cards_usn on cards (usn); +create index ix_revlog_usn on revlog (usn); +-- card spacing, etc +create index ix_cards_nid on cards (nid); +-- scheduling and deck limiting +create index ix_cards_sched on cards (did, queue, due); +-- revlog by card +create index ix_revlog_cid on revlog (cid); +-- field uniqueness +create index ix_notes_csum on notes (csum); + +insert into col values (1,0,0,0,0,0,0,0,'{}','{}','{}','{}','{}'); diff --git a/rslib/src/storage/sqlite.rs b/rslib/src/storage/sqlite.rs new file mode 100644 index 000000000..853a22759 --- /dev/null +++ b/rslib/src/storage/sqlite.rs @@ -0,0 +1,128 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html + +use crate::err::Result; +use crate::err::{AnkiError, DBErrorKind}; +use crate::time::i64_unix_timestamp; +use rusqlite::types::{FromSql, FromSqlError, FromSqlResult, ValueRef}; +use rusqlite::{params, Connection, OptionalExtension, NO_PARAMS}; +use serde::de::DeserializeOwned; +use serde_derive::{Deserialize, Serialize}; +use serde_json::Value; +use std::borrow::Cow; +use std::convert::TryFrom; +use std::fmt; +use std::path::{Path, PathBuf}; + +const SCHEMA_MIN_VERSION: u8 = 11; +const SCHEMA_MAX_VERSION: u8 = 11; + +macro_rules! cached_sql { + ( $label:expr, $db:expr, $sql:expr ) => {{ + if $label.is_none() { + $label = Some($db.prepare_cached($sql)?); + } + $label.as_mut().unwrap() + }}; +} + +#[derive(Debug)] +pub struct SqliteStorage { + // currently crate-visible for dbproxy + pub(crate) db: Connection, + path: PathBuf, + server: bool, +} + +fn open_or_create_collection_db(path: &Path) -> Result { + let mut db = Connection::open(path)?; + + if std::env::var("TRACESQL").is_ok() { + db.trace(Some(trace)); + } + + db.pragma_update(None, "locking_mode", &"exclusive")?; + db.pragma_update(None, "page_size", &4096)?; + db.pragma_update(None, "cache_size", &(-40 * 1024))?; + db.pragma_update(None, "legacy_file_format", &false)?; + db.pragma_update(None, "journal", &"wal")?; + db.set_prepared_statement_cache_capacity(50); + + Ok(db) +} + +/// Fetch schema version from database. +/// Return (must_create, version) +fn schema_version(db: &Connection) -> Result<(bool, u8)> { + if !db + .prepare("select null from sqlite_master where type = 'table' and name = 'col'")? + .exists(NO_PARAMS)? + { + return Ok((true, SCHEMA_MAX_VERSION)); + } + + Ok(( + false, + db.query_row("select ver from col", NO_PARAMS, |r| Ok(r.get(0)?))?, + )) +} + +fn trace(s: &str) { + println!("sql: {}", s) +} + +impl SqliteStorage { + pub(crate) fn open_or_create(path: &Path, server: bool) -> Result { + let db = open_or_create_collection_db(path)?; + + let (create, ver) = schema_version(&db)?; + if create { + unimplemented!(); // todo + db.prepare_cached("begin exclusive")?.execute(NO_PARAMS)?; + db.execute_batch(include_str!("schema11.sql"))?; + db.execute( + "update col set crt=?, ver=?", + params![i64_unix_timestamp(), ver], + )?; + db.prepare_cached("commit")?.execute(NO_PARAMS)?; + } else { + if ver > SCHEMA_MAX_VERSION { + return Err(AnkiError::DBError { + info: "".to_string(), + kind: DBErrorKind::FileTooNew, + }); + } + if ver < SCHEMA_MIN_VERSION { + return Err(AnkiError::DBError { + info: "".to_string(), + kind: DBErrorKind::FileTooOld, + }); + } + }; + + let storage = Self { + db, + path: path.to_owned(), + server, + }; + + Ok(storage) + } + + pub(crate) fn begin(&self) -> Result<()> { + self.db + .prepare_cached("begin exclusive")? + .execute(NO_PARAMS)?; + Ok(()) + } + + pub(crate) fn commit(&self) -> Result<()> { + self.db.prepare_cached("commit")?.execute(NO_PARAMS)?; + Ok(()) + } + + pub(crate) fn rollback(&self) -> Result<()> { + self.db.execute("rollback", NO_PARAMS)?; + Ok(()) + } +}