Integrate AnkiDroid's backend patches into the repo (#2290)

* Relax chrono specification for AnkiDroid

https://github.com/ankidroid/Anki-Android-Backend/pull/251

* Add AnkiDroid service and AnkiDroid customizations

Most of the work here was done by David in the Backend repo; it is being
integrated into this repo for ease of future maintenance.

Based on 5d9f262f4c
with some tweaks:

- Protobuf imports have been fixed to match the recent refactor
- FatalError has been renamed to AnkidroidPanicError
- Tweaks to the desktop code to handle the extra arg to open_collection,
and to exclude AnkiDroid service methods from our Python code.

* Refactor AnkiDroid's DB code to avoid uses of unsafe
Damien Elmes 2023-01-03 13:11:23 +10:00 committed by GitHub
parent 922574444f
commit 0eddb25287
16 changed files with 884 additions and 7 deletions


@@ -0,0 +1,74 @@
syntax = "proto3";
option java_multiple_files = true;
import "anki/generic.proto";
import "anki/scheduler.proto";
package anki.ankidroid;
service AnkidroidService {
rpc SchedTimingTodayLegacy(SchedTimingTodayLegacyRequest)
returns (scheduler.SchedTimingTodayResponse);
rpc LocalMinutesWestLegacy(generic.Int64) returns (generic.Int32);
rpc RunDbCommand(generic.Json) returns (generic.Json);
rpc RunDbCommandProto(generic.Json) returns (DBResponse);
rpc InsertForId(generic.Json) returns (generic.Int64);
rpc RunDbCommandForRowCount(generic.Json) returns (generic.Int64);
rpc FlushAllQueries(generic.Empty) returns (generic.Empty);
rpc FlushQuery(generic.Int32) returns (generic.Empty);
rpc GetNextResultPage(GetNextResultPageRequest) returns (DBResponse);
rpc SetPageSize(generic.Int64) returns (generic.Empty);
rpc GetColumnNamesFromQuery(generic.String) returns (generic.StringList);
rpc GetActiveSequenceNumbers(generic.Empty)
returns (GetActiveSequenceNumbersResponse);
rpc DebugProduceError(generic.String) returns (generic.Empty);
}
message DebugActiveDatabaseSequenceNumbersResponse {
repeated int32 sequence_numbers = 1;
}
message SchedTimingTodayLegacyRequest {
int64 created_secs = 1;
optional sint32 created_mins_west = 2;
int64 now_secs = 3;
sint32 now_mins_west = 4;
sint32 rollover_hour = 5;
}
// We expect in Java: Null, String, Short, Int, Long, Float, Double, Boolean,
// Blob (unused). We get: DbResult (Null, String, i64, f64, Vec<u8>), which
// matches the SQLite documentation.
message SqlValue {
oneof Data {
string stringValue = 1;
int64 longValue = 2;
double doubleValue = 3;
bytes blobValue = 4;
}
}
message Row {
repeated SqlValue fields = 1;
}
message DbResult {
repeated Row rows = 1;
}
message DBResponse {
DbResult result = 1;
int32 sequenceNumber = 2;
int32 rowCount = 3;
int64 startIndex = 4;
}
message GetNextResultPageRequest {
int32 sequence = 1;
int64 index = 2;
}
message GetActiveSequenceNumbersResponse {
repeated int32 numbers = 1;
}
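The paging RPCs above are meant to be used together: RunDbCommandProto returns the first page of rows plus a sequence number, and GetNextResultPage is then called with that sequence number and a row index until an empty page comes back. A rough client-side sketch of that flow (not part of this commit; the AnkidroidClient trait stands in for whatever generated bindings the consumer uses):

use crate::pb::ankidroid::{DbResponse, GetNextResultPageRequest, Row};

// Illustrative transport trait only -- real callers go through generated bindings.
trait AnkidroidClient {
    fn run_db_command_proto(&mut self, request_json: Vec<u8>) -> DbResponse;
    fn get_next_result_page(&mut self, req: GetNextResultPageRequest) -> DbResponse;
}

// Collect every row of a query by walking the cached pages.
fn fetch_all_rows(client: &mut dyn AnkidroidClient, request_json: Vec<u8>) -> Vec<Row> {
    let mut response = client.run_db_command_proto(request_json);
    let mut rows = Vec::new();
    loop {
        let page = response.result.take().unwrap_or_default();
        if page.rows.is_empty() {
            // an empty page means the cached result has been exhausted
            return rows;
        }
        // ask for the next page, starting after the rows just received
        let next_index = response.start_index + page.rows.len() as i64;
        rows.extend(page.rows);
        response = client.get_next_result_page(GetNextResultPageRequest {
            sequence: response.sequence_number,
            index: next_index,
        });
    }
}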


@@ -30,6 +30,7 @@ enum ServiceIndex {
SERVICE_INDEX_CARDS = 14;
SERVICE_INDEX_LINKS = 15;
SERVICE_INDEX_IMPORT_EXPORT = 16;
SERVICE_INDEX_ANKIDROID = 17;
}
message BackendInit {
@@ -64,6 +65,7 @@ message BackendError {
IMPORT_ERROR = 16;
DELETED = 17;
CARD_TYPE_ERROR = 18;
ANKIDROID_PANIC_ERROR = 19;
}
// error description, usually localized, suitable for displaying to the user


@@ -34,6 +34,9 @@ message OpenCollectionRequest {
string collection_path = 1;
string media_folder_path = 2;
string media_db_path = 3;
// temporary option for AnkiDroid
bool force_schema11 = 99;
}
message CloseCollectionRequest {


@@ -309,6 +309,7 @@ class Collection(DeprecatedNamesMixin):
collection_path=self.path,
media_folder_path=media_dir,
media_db_path=media_db,
force_schema11=False,
)
self.db = DBProxy(weakref.proxy(self._backend))
self.db.begin()


@@ -194,6 +194,8 @@ for service in anki.backend_pb2.ServiceIndex.DESCRIPTOR.values:
base = service.name.replace("SERVICE_INDEX_", "")
service_pkg = service_modules.get(base)
service_var = "_" + base.replace("_", "") + "SERVICE"
if service_var == "_ANKIDROIDSERVICE":
continue
service_obj = getattr(service_pkg, service_var)
service_index = service.number
render_service(service_obj, service_index)


@@ -54,7 +54,7 @@ ammonia = "3.3.0"
async-trait = "0.1.59"
blake3 = "1.3.3"
bytes = "1.3.0"
-chrono = { version = "0.4.23", default-features = false, features = ["std", "clock"] }
+chrono = { version = "0.4.19", default-features = false, features = ["std", "clock"] }
coarsetime = "0.1.22"
convert_case = "0.6.0"
dissimilar = "1.0.4"


@@ -0,0 +1,436 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{
collections::HashMap,
mem::size_of,
sync::{
atomic::{AtomicI32, Ordering},
Mutex,
},
};
use itertools::{
FoldWhile,
FoldWhile::{Continue, Done},
Itertools,
};
use lazy_static::lazy_static;
use rusqlite::ToSql;
use serde_derive::Deserialize;
use crate::{
collection::Collection,
error::Result,
pb::ankidroid::{sql_value::Data, DbResponse, DbResult, Row, SqlValue},
};
/// A pointer to the SqliteStorage object stored in a collection, used to
/// uniquely index results from multiple open collections at once.
impl Collection {
fn id_for_db_cache(&self) -> CollectionId {
CollectionId((&self.storage as *const _) as i64)
}
}
#[derive(Hash, PartialEq, Eq)]
struct CollectionId(i64);
#[derive(Deserialize)]
struct DBArgs {
sql: String,
args: Vec<crate::backend::dbproxy::SqlValue>,
}
pub trait Sizable {
/** Estimates the heap size of the value, in bytes */
fn estimate_size(&self) -> usize;
}
impl Sizable for Data {
fn estimate_size(&self) -> usize {
match self {
Data::StringValue(s) => s.len(),
Data::LongValue(_) => size_of::<i64>(),
Data::DoubleValue(_) => size_of::<f64>(),
Data::BlobValue(b) => b.len(),
}
}
}
impl Sizable for SqlValue {
fn estimate_size(&self) -> usize {
// Add a byte for the optional
self.data
.as_ref()
.map(|f| f.estimate_size() + 1)
.unwrap_or(1)
}
}
impl Sizable for Row {
fn estimate_size(&self) -> usize {
self.fields.iter().map(|x| x.estimate_size()).sum()
}
}
impl Sizable for DbResult {
fn estimate_size(&self) -> usize {
// Performance: It might be best to take the first x rows and determine the data types
// If we have floats or longs, they'll be a fixed size (excluding nulls) and should speed
// up the calculation as we'll only calculate a subset of the columns.
self.rows.iter().map(|x| x.estimate_size()).sum()
}
}
pub(crate) fn select_next_slice<'a>(rows: impl Iterator<Item = &'a Row>) -> Vec<Row> {
select_slice_of_size(rows, get_max_page_size())
.into_inner()
.1
}
fn select_slice_of_size<'a>(
mut rows: impl Iterator<Item = &'a Row>,
max_size: usize,
) -> FoldWhile<(usize, Vec<Row>)> {
let init: Vec<Row> = Vec::new();
rows.fold_while((0, init), |mut acc, x| {
let new_size = acc.0 + x.estimate_size();
// If the accumulator is 0, but we're over the size: return a single result so we don't loop forever.
// Theoretically, this shouldn't happen as data should be reasonably sized
if new_size > max_size && acc.0 > 0 {
Done(acc)
} else {
// PERF: should be faster to return (size, numElements) then bulk copy/slice
acc.1.push(x.to_owned());
Continue((new_size, acc.1))
}
})
}
type SequenceNumber = i32;
lazy_static! {
static ref HASHMAP: Mutex<HashMap<CollectionId, HashMap<SequenceNumber, DbResponse>>> =
Mutex::new(HashMap::new());
}
pub(crate) fn flush_single_result(col: &Collection, sequence_number: i32) {
HASHMAP
.lock()
.unwrap()
.get_mut(&col.id_for_db_cache())
.map(|storage| storage.remove(&sequence_number));
}
pub(crate) fn flush_collection(col: &Collection) {
HASHMAP.lock().unwrap().remove(&col.id_for_db_cache());
}
pub(crate) fn active_sequences(col: &Collection) -> Vec<i32> {
HASHMAP
.lock()
.unwrap()
.get(&col.id_for_db_cache())
.map(|h| h.keys().copied().collect())
.unwrap_or_default()
}
/**
Store the data in the cache if larger than the page size.<br/>
Returns: The data capped to the page size
*/
pub(crate) fn trim_and_cache_remaining(
col: &Collection,
values: DbResult,
sequence_number: i32,
) -> DbResponse {
let start_index = 0;
// PERF: Could speed this up by not creating the vector and just calculating the count
let first_result = select_next_slice(values.rows.iter());
let row_count = values.rows.len() as i32;
if first_result.len() < values.rows.len() {
let to_store = DbResponse {
result: Some(values),
sequence_number,
row_count,
start_index,
};
insert_cache(col, to_store);
DbResponse {
result: Some(DbResult { rows: first_result }),
sequence_number,
row_count,
start_index,
}
} else {
DbResponse {
result: Some(values),
sequence_number,
row_count,
start_index,
}
}
}
fn insert_cache(col: &Collection, result: DbResponse) {
HASHMAP
.lock()
.unwrap()
.entry(col.id_for_db_cache())
.or_default()
.insert(result.sequence_number, result);
}
pub(crate) fn get_next(
col: &Collection,
sequence_number: i32,
start_index: i64,
) -> Option<DbResponse> {
let result = get_next_result(col, &sequence_number, start_index);
if let Some(resp) = result.as_ref() {
if resp.result.is_none() || resp.result.as_ref().unwrap().rows.is_empty() {
flush_single_result(col, sequence_number)
}
}
result
}
fn get_next_result(
col: &Collection,
sequence_number: &i32,
start_index: i64,
) -> Option<DbResponse> {
let map = HASHMAP.lock().unwrap();
let result_map = map.get(&col.id_for_db_cache())?;
let current_result = result_map.get(sequence_number)?;
// TODO: This shouldn't need to exist
let tmp: Vec<Row> = Vec::new();
let next_rows = current_result
.result
.as_ref()
.map(|x| x.rows.iter())
.unwrap_or_else(|| tmp.iter());
let skipped_rows = next_rows.clone().skip(start_index as usize).collect_vec();
println!("{}", skipped_rows.len());
let filtered_rows = select_next_slice(next_rows.skip(start_index as usize));
let result = DbResult {
rows: filtered_rows,
};
let trimmed_result = DbResponse {
result: Some(result),
sequence_number: current_result.sequence_number,
row_count: current_result.row_count,
start_index,
};
Some(trimmed_result)
}
static SEQUENCE_NUMBER: AtomicI32 = AtomicI32::new(0);
pub(crate) fn next_sequence_number() -> i32 {
SEQUENCE_NUMBER.fetch_add(1, Ordering::SeqCst)
}
lazy_static! {
// same as we get from io.requery.android.database.CursorWindow.sCursorWindowSize
static ref DB_COMMAND_PAGE_SIZE: Mutex<usize> = Mutex::new(1024 * 1024 * 2);
}
pub(crate) fn set_max_page_size(size: usize) {
let mut state = DB_COMMAND_PAGE_SIZE.lock().expect("Could not lock mutex");
*state = size;
}
fn get_max_page_size() -> usize {
*DB_COMMAND_PAGE_SIZE.lock().unwrap()
}
fn get_args(in_bytes: &[u8]) -> Result<DBArgs> {
let ret: DBArgs = serde_json::from_slice(in_bytes)?;
Ok(ret)
}
pub(crate) fn insert_for_id(col: &Collection, json: &[u8]) -> Result<i64> {
let req = get_args(json)?;
let args: Vec<_> = req.args.iter().map(|a| a as &dyn ToSql).collect();
col.storage.db.execute(&req.sql, &args[..])?;
Ok(col.storage.db.last_insert_rowid())
}
pub(crate) fn execute_for_row_count(col: &Collection, req: &[u8]) -> Result<i64> {
let req = get_args(req)?;
let args: Vec<_> = req.args.iter().map(|a| a as &dyn ToSql).collect();
let count = col.storage.db.execute(&req.sql, &args[..])?;
Ok(count as i64)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
backend::ankidroid::db::{select_slice_of_size, Sizable},
collection::open_test_collection,
pb::ankidroid::{sql_value, Row, SqlValue},
};
fn gen_data() -> Vec<SqlValue> {
vec![
SqlValue {
data: Some(sql_value::Data::DoubleValue(12.0)),
},
SqlValue {
data: Some(sql_value::Data::LongValue(12)),
},
SqlValue {
data: Some(sql_value::Data::StringValue(
"Hellooooooo World".to_string(),
)),
},
SqlValue {
data: Some(sql_value::Data::BlobValue(vec![])),
},
]
}
#[test]
fn test_size_estimate() {
let row = Row { fields: gen_data() };
let result = DbResult {
rows: vec![row.clone(), row],
};
let actual_size = result.estimate_size();
let expected_size = (17 + 8 + 8) * 2; // 1 variable string, 1 long, 1 float
let expected_overhead = 4 * 2; // 4 optional columns
assert_eq!(actual_size, expected_overhead + expected_size);
}
#[test]
fn test_stream_size() {
let row = Row { fields: gen_data() };
let result = DbResult {
rows: vec![row.clone(), row.clone(), row],
};
let limit = 74 + 1; // two rows are 74
let result = select_slice_of_size(result.rows.iter(), limit).into_inner();
assert_eq!(
2,
result.1.len(),
"The final element should not be included"
);
assert_eq!(
74, result.0,
"The size should be the size of the first two objects"
);
}
#[test]
fn test_stream_size_too_small() {
let row = Row { fields: gen_data() };
let result = DbResult { rows: vec![row] };
let limit = 1;
let result = select_slice_of_size(result.rows.iter(), limit).into_inner();
assert_eq!(
1,
result.1.len(),
"If the limit is too small, a result is still returned"
);
assert_eq!(
37, result.0,
"The size should be the size of the first objects"
);
}
const SEQUENCE_NUMBER: i32 = 1;
fn get(col: &Collection, index: i64) -> Option<DbResponse> {
get_next(col, SEQUENCE_NUMBER, index)
}
fn get_first(col: &Collection, result: DbResult) -> DbResponse {
trim_and_cache_remaining(col, result, SEQUENCE_NUMBER)
}
fn seq_number_used(col: &Collection) -> bool {
HASHMAP
.lock()
.unwrap()
.get(&col.id_for_db_cache())
.unwrap()
.contains_key(&SEQUENCE_NUMBER)
}
#[test]
fn integration_test() {
let col = open_test_collection();
let row = Row { fields: gen_data() };
// return one row at a time
set_max_page_size(row.estimate_size() - 1);
let db_query_result = DbResult {
rows: vec![row.clone(), row],
};
let first_jni_response = get_first(&col, db_query_result);
assert_eq!(
row_count(&first_jni_response),
1,
"The first call should only return one row"
);
let next_index = first_jni_response.start_index + row_count(&first_jni_response);
let second_response = get(&col, next_index);
assert!(
second_response.is_some(),
"The second response should return a value"
);
let valid_second_response = second_response.unwrap();
assert_eq!(row_count(&valid_second_response), 1);
let final_index = valid_second_response.start_index + row_count(&valid_second_response);
assert!(seq_number_used(&col), "The sequence number is assigned");
let final_response = get(&col, final_index);
assert!(
final_response.is_some(),
"The third call should return something with no rows"
);
assert_eq!(
row_count(&final_response.unwrap()),
0,
"The third call should return something with no rows"
);
assert!(
!seq_number_used(&col),
"Sequence number data has been cleared"
);
}
fn row_count(resp: &DbResponse) -> i64 {
resp.result.as_ref().map(|x| x.rows.len()).unwrap_or(0) as i64
}
}
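For reference, the constants 37 and 74 used in the tests above follow directly from the Sizable rules: each SqlValue counts its payload bytes plus one byte for the Option wrapper. A standalone check of that arithmetic (not part of the commit):

// Sanity check of the estimate_size arithmetic used by the tests above.
fn main() {
    let double_val: usize = 8 + 1; // DoubleValue(12.0): size_of::<f64>() + option byte
    let long_val = 8 + 1; // LongValue(12): size_of::<i64>() + option byte
    let string_val = "Hellooooooo World".len() + 1; // 17 payload bytes + option byte
    let blob_val = 0 + 1; // BlobValue(vec![]): empty payload + option byte
    let row = double_val + long_val + string_val + blob_val;
    assert_eq!(row, 37); // single row, as in test_stream_size_too_small
    assert_eq!(row * 2, 74); // two rows, as in test_size_estimate and test_stream_size
}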


@@ -0,0 +1,139 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::{
error::{
DbError, DbErrorKind as DB, FilteredDeckError, InvalidInputError, NetworkError,
NetworkErrorKind as Net, NotFoundError, SearchErrorKind, SyncError, SyncErrorKind as Sync,
},
prelude::AnkiError,
};
pub(super) fn debug_produce_error(s: &str) -> AnkiError {
let info = "error_value".to_string();
match s {
"TemplateError" => AnkiError::TemplateError { info },
"DbErrorFileTooNew" => AnkiError::DbError {
source: DbError {
info,
kind: DB::FileTooNew,
},
},
"DbErrorFileTooOld" => AnkiError::DbError {
source: DbError {
info,
kind: DB::FileTooOld,
},
},
"DbErrorMissingEntity" => AnkiError::DbError {
source: DbError {
info,
kind: DB::MissingEntity,
},
},
"DbErrorCorrupt" => AnkiError::DbError {
source: DbError {
info,
kind: DB::Corrupt,
},
},
"DbErrorLocked" => AnkiError::DbError {
source: DbError {
info,
kind: DB::Locked,
},
},
"DbErrorOther" => AnkiError::DbError {
source: DbError {
info,
kind: DB::Other,
},
},
"NetworkError" => AnkiError::NetworkError {
source: NetworkError {
info,
kind: Net::Offline,
},
},
"SyncErrorConflict" => AnkiError::SyncError {
source: SyncError {
info,
kind: Sync::Conflict,
},
},
"SyncErrorServerError" => AnkiError::SyncError {
source: SyncError {
info,
kind: Sync::ServerError,
},
},
"SyncErrorClientTooOld" => AnkiError::SyncError {
source: SyncError {
info,
kind: Sync::ClientTooOld,
},
},
"SyncErrorAuthFailed" => AnkiError::SyncError {
source: SyncError {
info,
kind: Sync::AuthFailed,
},
},
"SyncErrorServerMessage" => AnkiError::SyncError {
source: SyncError {
info,
kind: Sync::ServerMessage,
},
},
"SyncErrorClockIncorrect" => AnkiError::SyncError {
source: SyncError {
info,
kind: Sync::ClockIncorrect,
},
},
"SyncErrorOther" => AnkiError::SyncError {
source: SyncError {
info,
kind: Sync::Other,
},
},
"SyncErrorResyncRequired" => AnkiError::SyncError {
source: SyncError {
info,
kind: Sync::ResyncRequired,
},
},
"SyncErrorDatabaseCheckRequired" => AnkiError::SyncError {
source: SyncError {
info,
kind: Sync::DatabaseCheckRequired,
},
},
"JSONError" => AnkiError::JsonError { info },
"ProtoError" => AnkiError::ProtoError { info },
"Interrupted" => AnkiError::Interrupted,
"CollectionNotOpen" => AnkiError::CollectionNotOpen,
"CollectionAlreadyOpen" => AnkiError::CollectionAlreadyOpen,
"NotFound" => AnkiError::NotFound {
source: NotFoundError {
type_name: "".to_string(),
identifier: "".to_string(),
backtrace: None,
},
},
"Existing" => AnkiError::Existing,
"FilteredDeckError" => AnkiError::FilteredDeckError {
source: FilteredDeckError::FilteredDeckRequired,
},
"SearchError" => AnkiError::SearchError {
source: SearchErrorKind::EmptyGroup,
},
_ => AnkiError::InvalidInput {
source: InvalidInputError {
message: info,
source: None,
backtrace: None,
},
},
}
}
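This mapping lets AnkiDroid ask the backend for each error variant by name (via the DebugProduceError RPC) and verify that its own error translation handles all of them. A sketch of a test that could sit in rslib/src/backend/ankidroid/mod.rs, where the pub(super) function is visible (not part of this commit):

#[cfg(test)]
mod tests {
    use super::error::debug_produce_error;
    use crate::error::AnkiError;

    #[test]
    fn error_names_map_to_variants() {
        // known names map to the corresponding error variant...
        assert!(matches!(
            debug_produce_error("SyncErrorConflict"),
            AnkiError::SyncError { .. }
        ));
        // ...and unknown names fall through to InvalidInput
        assert!(matches!(
            debug_produce_error("SomethingUnknown"),
            AnkiError::InvalidInput { .. }
        ));
    }
}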


@@ -0,0 +1,117 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
pub(crate) mod db;
pub(crate) mod error;
use self::{db::active_sequences, error::debug_produce_error};
use super::{
dbproxy::{db_command_bytes, db_command_proto},
Backend,
};
pub(super) use crate::pb::ankidroid::ankidroid_service::Service as AnkidroidService;
use crate::{
backend::ankidroid::db::{execute_for_row_count, insert_for_id},
pb::{
self as pb,
ankidroid::{DbResponse, GetActiveSequenceNumbersResponse, GetNextResultPageRequest},
generic::{self, Empty, Int32, Json},
},
prelude::*,
scheduler::timing::{self, fixed_offset_from_minutes},
};
impl AnkidroidService for Backend {
fn sched_timing_today_legacy(
&self,
input: pb::ankidroid::SchedTimingTodayLegacyRequest,
) -> Result<pb::scheduler::SchedTimingTodayResponse> {
let result = timing::sched_timing_today(
TimestampSecs::from(input.created_secs),
TimestampSecs::from(input.now_secs),
input.created_mins_west.map(fixed_offset_from_minutes),
fixed_offset_from_minutes(input.now_mins_west),
Some(input.rollover_hour as u8),
)?;
Ok(pb::scheduler::SchedTimingTodayResponse::from(result))
}
fn local_minutes_west_legacy(&self, input: pb::generic::Int64) -> Result<pb::generic::Int32> {
Ok(pb::generic::Int32 {
val: timing::local_minutes_west_for_stamp(input.val.into())?,
})
}
fn run_db_command(&self, input: Json) -> Result<Json> {
self.with_col(|col| db_command_bytes(col, &input.json))
.map(|json| Json { json })
}
fn run_db_command_proto(&self, input: Json) -> Result<DbResponse> {
self.with_col(|col| db_command_proto(col, &input.json))
}
fn run_db_command_for_row_count(&self, input: Json) -> Result<pb::generic::Int64> {
self.with_col(|col| execute_for_row_count(col, &input.json))
.map(|val| pb::generic::Int64 { val })
}
fn flush_all_queries(&self, _input: Empty) -> Result<Empty> {
self.with_col(|col| {
db::flush_collection(col);
Ok(Empty {})
})
}
fn flush_query(&self, input: Int32) -> Result<Empty> {
self.with_col(|col| {
db::flush_single_result(col, input.val);
Ok(Empty {})
})
}
fn get_next_result_page(&self, input: GetNextResultPageRequest) -> Result<DbResponse> {
self.with_col(|col| {
db::get_next(col, input.sequence, input.index).or_invalid("missing result page")
})
}
fn insert_for_id(&self, input: Json) -> Result<pb::generic::Int64> {
self.with_col(|col| insert_for_id(col, &input.json).map(Into::into))
}
fn set_page_size(&self, input: pb::generic::Int64) -> Result<Empty> {
// we don't require an open collection, but should avoid modifying this
// concurrently
let _guard = self.col.lock();
db::set_max_page_size(input.val as usize);
Ok(().into())
}
fn get_column_names_from_query(
&self,
input: generic::String,
) -> Result<pb::generic::StringList> {
self.with_col(|col| {
let stmt = col.storage.db.prepare(&input.val)?;
let names = stmt.column_names();
let names: Vec<_> = names.iter().map(ToString::to_string).collect();
Ok(names.into())
})
}
fn get_active_sequence_numbers(
&self,
_input: Empty,
) -> Result<GetActiveSequenceNumbersResponse> {
self.with_col(|col| {
Ok(GetActiveSequenceNumbersResponse {
numbers: active_sequences(col),
})
})
}
fn debug_produce_error(&self, input: generic::String) -> Result<Empty> {
Err(debug_produce_error(&input.val))
}
}
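The Json payloads taken by InsertForId and RunDbCommandForRowCount are the { sql, args } shape deserialized into DBArgs in backend/ankidroid/db.rs, while RunDbCommand/RunDbCommandProto accept the same request format the desktop DBProxy already sends to dbproxy.rs (shown below). A hedged sketch of building such payloads with serde_json; the SQL is arbitrary example SQL, and both the plain-scalar args encoding and the kind-tagged query fields are assumptions based on dbproxy's serde derives:

use serde_json::json;

// Illustrative request bodies only (not part of this commit).
fn example_payloads() -> (Vec<u8>, Vec<u8>) {
    // InsertForId / RunDbCommandForRowCount: matches the DBArgs struct above.
    let dbargs = json!({
        "sql": "update cards set due = ? where id = ?",
        "args": [100, 1234567890]
    });

    // RunDbCommand / RunDbCommandProto: assumed kind-tagged DbRequest format
    // from rslib/src/backend/dbproxy.rs.
    let query = json!({
        "kind": "query",
        "sql": "select id, name from decks",
        "args": [],
        "first_row_only": false
    });

    (
        serde_json::to_vec(&dbargs).unwrap(),
        serde_json::to_vec(&query).unwrap()
    )
}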


@@ -34,6 +34,7 @@ impl CollectionService for Backend {
let mut builder = CollectionBuilder::new(input.collection_path);
builder
.set_force_schema11(input.force_schema11)
.set_media_paths(input.media_folder_path, input.media_db_path)
.set_server(self.server)
.set_tr(self.tr.clone());


@@ -8,7 +8,14 @@ use rusqlite::{
};
use serde_derive::{Deserialize, Serialize};
use crate::{prelude::*, storage::SqliteStorage};
use crate::{
pb,
pb::ankidroid::{
sql_value::Data, DbResponse, DbResult as ProtoDbResult, Row, SqlValue as pb_SqlValue,
},
prelude::*,
storage::SqliteStorage,
};
#[derive(Deserialize)]
#[serde(tag = "kind", rename_all = "lowercase")]
@@ -57,6 +64,42 @@ impl ToSql for SqlValue {
}
}
impl From<&SqlValue> for pb::ankidroid::SqlValue {
fn from(item: &SqlValue) -> Self {
match item {
SqlValue::Null => pb_SqlValue { data: Option::None },
SqlValue::String(s) => pb_SqlValue {
data: Some(Data::StringValue(s.to_string())),
},
SqlValue::Int(i) => pb_SqlValue {
data: Some(Data::LongValue(*i)),
},
SqlValue::Double(d) => pb_SqlValue {
data: Some(Data::DoubleValue(*d)),
},
SqlValue::Blob(b) => pb_SqlValue {
data: Some(Data::BlobValue(b.clone())),
},
}
}
}
impl From<&Vec<SqlValue>> for pb::ankidroid::Row {
fn from(item: &Vec<SqlValue>) -> Self {
Row {
fields: item.iter().map(pb::ankidroid::SqlValue::from).collect(),
}
}
}
impl From<&Vec<Vec<SqlValue>>> for pb::ankidroid::DbResult {
fn from(item: &Vec<Vec<SqlValue>>) -> Self {
ProtoDbResult {
rows: item.iter().map(Row::from).collect(),
}
}
}
impl FromSql for SqlValue {
fn column_result(value: ValueRef<'_>) -> std::result::Result<Self, FromSqlError> {
let val = match value {
@@ -71,6 +114,10 @@ impl FromSql for SqlValue {
}
pub(super) fn db_command_bytes(col: &mut Collection, input: &[u8]) -> Result<Vec<u8>> {
serde_json::to_vec(&db_command_bytes_inner(col, input)?).map_err(Into::into)
}
pub(super) fn db_command_bytes_inner(col: &mut Collection, input: &[u8]) -> Result<DbResult> {
let req: DbRequest = serde_json::from_slice(input)?;
let resp = match req {
DbRequest::Query {
@@ -107,7 +154,7 @@ pub(super) fn db_command_bytes(col: &mut Collection, input: &[u8]) -> Result<Vec
db_execute_many(&col.storage, &sql, &args)?
}
};
Ok(serde_json::to_vec(&resp)?)
Ok(resp)
}
fn update_state_after_modification(col: &mut Collection, sql: &str) {
@@ -128,6 +175,20 @@ fn is_dql(sql: &str) -> bool {
head.starts_with("select")
}
pub(crate) fn db_command_proto(col: &mut Collection, input: &[u8]) -> Result<DbResponse> {
let result = db_command_bytes_inner(col, input)?;
let proto_resp = match result {
DbResult::None => ProtoDbResult { rows: Vec::new() },
DbResult::Rows(rows) => ProtoDbResult::from(&rows),
};
let trimmed = super::ankidroid::db::trim_and_cache_remaining(
col,
proto_resp,
super::ankidroid::db::next_sequence_number(),
);
Ok(trimmed)
}
pub(super) fn db_query_row(ctx: &SqliteStorage, sql: &str, args: &[SqlValue]) -> Result<DbResult> {
let mut stmt = ctx.db.prepare_cached(sql)?;
let columns = stmt.column_count();


@@ -5,6 +5,7 @@
#![allow(clippy::unnecessary_wraps)]
mod adding;
mod ankidroid;
mod card;
mod cardrendering;
mod collection;
@@ -42,6 +43,7 @@ use tokio::runtime::{
};
use self::{
ankidroid::AnkidroidService,
card::CardsService,
cardrendering::CardRenderingService,
collection::CollectionService,
@@ -120,6 +122,7 @@ impl Backend {
ServiceIndex::from_i32(service as i32)
.or_invalid("invalid service")
.and_then(|service| match service {
ServiceIndex::Ankidroid => AnkidroidService::run_method(self, method, input),
ServiceIndex::Scheduler => SchedulerService::run_method(self, method, input),
ServiceIndex::Decks => DecksService::run_method(self, method, input),
ServiceIndex::Notes => NotesService::run_method(self, method, input),


@@ -33,6 +33,8 @@ pub struct CollectionBuilder {
media_db: Option<PathBuf>,
server: Option<bool>,
tr: Option<I18n>,
// temporary option for AnkiDroid
force_schema11: Option<bool>,
}
impl CollectionBuilder {
@@ -53,8 +55,8 @@ impl CollectionBuilder {
let server = self.server.unwrap_or_default();
let media_folder = self.media_folder.clone().unwrap_or_default();
let media_db = self.media_db.clone().unwrap_or_default();
let storage = SqliteStorage::open_or_create(&col_path, &tr, server)?;
let force_schema11 = self.force_schema11.unwrap_or_default();
let storage = SqliteStorage::open_or_create(&col_path, &tr, server, force_schema11)?;
let col = Collection {
storage,
col_path,
@@ -88,6 +90,11 @@ impl CollectionBuilder {
self.tr = Some(tr);
self
}
pub fn set_force_schema11(&mut self, force: bool) -> &mut Self {
self.force_schema11 = Some(force);
self
}
}
#[cfg(test)]


@@ -9,6 +9,7 @@ macro_rules! protobuf {
};
}
protobuf!(ankidroid, "ankidroid");
protobuf!(backend, "backend");
protobuf!(card_rendering, "card_rendering");
protobuf!(cards, "cards");


@@ -762,7 +762,8 @@ mod test {
#[test]
fn add_card() {
let tr = I18n::template_only();
let storage = SqliteStorage::open_or_create(Path::new(":memory:"), &tr, false).unwrap();
let storage =
SqliteStorage::open_or_create(Path::new(":memory:"), &tr, false, false).unwrap();
let mut card = Card::default();
storage.add_card(&mut card).unwrap();
let id1 = card.id;


@@ -204,7 +204,12 @@ fn trace(s: &str) {
}
impl SqliteStorage {
pub(crate) fn open_or_create(path: &Path, tr: &I18n, server: bool) -> Result<Self> {
pub(crate) fn open_or_create(
path: &Path,
tr: &I18n,
server: bool,
force_schema11: bool,
) -> Result<Self> {
let db = open_or_create_collection_db(path)?;
let (create, ver) = schema_version(&db)?;
@@ -249,6 +254,13 @@ impl SqliteStorage {
let storage = Self { db };
if force_schema11 {
if create || upgrade {
storage.commit_trx()?;
}
return storage_with_schema11(storage, ver);
}
if create || upgrade {
storage.upgrade_to_latest_schema(ver, server)?;
}
@@ -369,3 +381,20 @@ impl SqliteStorage {
self.db.query_row(sql, [], |r| r.get(0)).map_err(Into::into)
}
}
fn storage_with_schema11(storage: SqliteStorage, ver: u8) -> Result<SqliteStorage> {
if ver != 11 {
if ver != SCHEMA_MAX_VERSION {
// partially upgraded; need to fully upgrade before downgrading
storage.begin_trx()?;
storage.upgrade_to_latest_schema(ver, false)?;
storage.commit_trx()?;
}
storage.downgrade_to(SchemaVersion::V11)?;
}
// Requery uses "TRUNCATE" by default if WAL is not enabled.
// We copy this behaviour here. See https://github.com/ankidroid/Anki-Android/pull/7977 for
// analysis. We may be able to enable WAL at a later time.
storage.db.pragma_update(None, "journal_mode", "TRUNCATE")?;
Ok(storage)
}
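Taken together, AnkiDroid can keep its collections at schema 11 by passing the new flag when opening, while desktop callers pass false (as in the Python change above). A minimal sketch of the Rust side, with placeholder paths and module locations as used elsewhere in this diff (not part of the commit):

use crate::collection::CollectionBuilder;
use crate::error::Result;

// Open a collection forced to schema 11, taking the downgrade path added above.
fn open_collection_for_ankidroid() -> Result<()> {
    let mut builder = CollectionBuilder::new("/path/to/collection.anki2".to_string());
    builder
        .set_media_paths(
            "/path/to/collection.media".to_string(),
            "/path/to/collection.media.db2".to_string(),
        )
        .set_force_schema11(true);
    let _col = builder.build()?;
    Ok(())
}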