diff --git a/proto/anki/scheduler.proto b/proto/anki/scheduler.proto index 1bacd0d1f..e0e653082 100644 --- a/proto/anki/scheduler.proto +++ b/proto/anki/scheduler.proto @@ -67,6 +67,8 @@ service BackendSchedulerService { returns (ComputeFsrsWeightsResponse); // Generates parameters used for FSRS's scheduler benchmarks. rpc FsrsBenchmark(FsrsBenchmarkRequest) returns (FsrsBenchmarkResponse); + // Used for exporting revlogs for algorithm research. + rpc ExportDataset(ExportDatasetRequest) returns (generic.Empty); } message SchedulingState { @@ -363,6 +365,11 @@ message FsrsBenchmarkResponse { repeated float weights = 1; } +message ExportDatasetRequest { + uint32 min_entries = 1; + string target_path = 2; +} + message FsrsItem { repeated FsrsReview reviews = 1; } diff --git a/proto/anki/stats.proto b/proto/anki/stats.proto index 4ef811df9..14d0eef6f 100644 --- a/proto/anki/stats.proto +++ b/proto/anki/stats.proto @@ -217,7 +217,21 @@ message RevlogEntry { ReviewKind review_kind = 9; } -message RevlogEntries { - repeated RevlogEntry entries = 1; - int64 next_day_at = 2; +message CardEntry { + int64 id = 1; + int64 note_id = 2; + int64 deck_id = 3; +} + +message DeckEntry { + int64 id = 1; + int64 parent_id = 2; + int64 preset_id = 3; +} + +message Dataset { + repeated RevlogEntry revlogs = 1; + repeated CardEntry cards = 2; + repeated DeckEntry decks = 3; + int64 next_day_at = 4; } diff --git a/pylib/anki/collection.py b/pylib/anki/collection.py index 27c01e0e1..559434cda 100644 --- a/pylib/anki/collection.py +++ b/pylib/anki/collection.py @@ -420,6 +420,11 @@ class Collection(DeprecatedNamesMixin): def import_json_string(self, json: str) -> ImportLogWithChanges: return self._backend.import_json_string(json) + def export_dataset_for_research( + self, target_path: str, min_entries: int = 0 + ) -> None: + self._backend.export_dataset(min_entries=min_entries, target_path=target_path) + # Image Occlusion 
########################################################################## diff --git a/rslib/src/scheduler/fsrs/weights.rs b/rslib/src/scheduler/fsrs/weights.rs index 51bd7107d..3350487e4 100644 --- a/rslib/src/scheduler/fsrs/weights.rs +++ b/rslib/src/scheduler/fsrs/weights.rs @@ -1,5 +1,6 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +use std::collections::HashMap; use std::iter; use std::path::Path; use std::thread; @@ -8,7 +9,8 @@ use std::time::Duration; use anki_io::write_file; use anki_proto::scheduler::ComputeFsrsWeightsResponse; use anki_proto::stats::revlog_entry; -use anki_proto::stats::RevlogEntries; +use anki_proto::stats::Dataset; +use anki_proto::stats::DeckEntry; use chrono::NaiveDate; use chrono::NaiveTime; use fsrs::CombinedProgressState; @@ -19,6 +21,7 @@ use fsrs::FSRS; use itertools::Itertools; use prost::Message; +use crate::decks::immediate_parent_name; use crate::prelude::*; use crate::revlog::RevlogEntry; use crate::revlog::RevlogReviewKind; @@ -127,22 +130,51 @@ impl Collection { } /// Used for exporting revlogs for algorithm research. 
- pub fn export_revlog_entries_to_protobuf( - &mut self, - min_entries: usize, - target_path: &Path, - ) -> Result<()> { - let entries = self.storage.get_all_revlog_entries_in_card_order()?; - if entries.len() < min_entries { + pub fn export_dataset(&mut self, min_entries: usize, target_path: &Path) -> Result<()> { + let revlog_entries = self.storage.get_all_revlog_entries_in_card_order()?; + if revlog_entries.len() < min_entries { return Err(AnkiError::FsrsInsufficientData); } - let entries = entries.into_iter().map(revlog_entry_to_proto).collect_vec(); + let revlogs = revlog_entries + .into_iter() + .map(revlog_entry_to_proto) + .collect_vec(); + let cards = self.storage.get_all_card_entries()?; + + let decks_map = self.storage.get_decks_map()?; + let deck_name_to_id: HashMap<String, DeckId> = decks_map + .into_iter() + .map(|(id, deck)| (deck.name.to_string(), id)) + .collect(); + + let decks = self + .storage + .get_all_decks()? + .into_iter() + .filter_map(|deck| { + if let Some(preset_id) = deck.config_id().map(|id| id.0) { + let parent_id = immediate_parent_name(&deck.name.to_string()) + .and_then(|parent_name| deck_name_to_id.get(parent_name)) + .map(|id| id.0) + .unwrap_or(0); + Some(DeckEntry { + id: deck.id.0, + parent_id, + preset_id, + }) + } else { + None + } + }) + .collect_vec(); let next_day_at = self.timing_today()?.next_day_at.0; - let entries = RevlogEntries { - entries, + let dataset = Dataset { + revlogs, + cards, + decks, next_day_at, }; - let data = entries.encode_to_vec(); + let data = dataset.encode_to_vec(); write_file(target_path, data)?; Ok(()) } diff --git a/rslib/src/scheduler/service/mod.rs b/rslib/src/scheduler/service/mod.rs index 9cce0c44e..aa71d4829 100644 --- a/rslib/src/scheduler/service/mod.rs +++ b/rslib/src/scheduler/service/mod.rs @@ -368,6 +368,15 @@ impl crate::services::BackendSchedulerService for Backend { let weights = fsrs.benchmark(train_set); Ok(FsrsBenchmarkResponse { weights }) } + + fn export_dataset(&self, req: 
scheduler::ExportDatasetRequest) -> Result<()> { + self.with_col(|col| { + col.export_dataset( + req.min_entries.try_into().unwrap(), + req.target_path.as_ref(), + ) + }) + } } fn fsrs_item_proto_to_fsrs(item: anki_proto::scheduler::FsrsItem) -> FSRSItem { diff --git a/rslib/src/storage/card/get_card_entry.sql b/rslib/src/storage/card/get_card_entry.sql new file mode 100644 index 000000000..d7d76ab87 --- /dev/null +++ b/rslib/src/storage/card/get_card_entry.sql @@ -0,0 +1,7 @@ +SELECT id, + nid, + CASE + WHEN odid = 0 THEN did + ELSE odid + END AS did +FROM cards; \ No newline at end of file diff --git a/rslib/src/storage/card/mod.rs b/rslib/src/storage/card/mod.rs index f290a7f71..51263a8b1 100644 --- a/rslib/src/storage/card/mod.rs +++ b/rslib/src/storage/card/mod.rs @@ -9,6 +9,7 @@ use std::convert::TryFrom; use std::fmt; use std::result; +use anki_proto::stats::CardEntry; use rusqlite::named_params; use rusqlite::params; use rusqlite::types::FromSql; @@ -87,6 +88,14 @@ fn row_to_card(row: &Row) -> result::Result<Card, rusqlite::Error> { }) } +fn row_to_card_entry(row: &Row) -> Result<CardEntry> { + Ok(CardEntry { + id: row.get(0)?, + note_id: row.get(1)?, + deck_id: row.get(2)?, + }) +} + fn row_to_new_card(row: &Row) -> result::Result<NewCard, rusqlite::Error> { Ok(NewCard { id: row.get(0)?, @@ -108,6 +117,13 @@ impl super::SqliteStorage { .map_err(Into::into) } + pub(crate) fn get_all_card_entries(&self) -> Result<Vec<CardEntry>> { + self.db + .prepare_cached(include_str!("get_card_entry.sql"))? + .query_and_then([], row_to_card_entry)? + .collect() + } + pub(crate) fn update_card(&self, card: &Card) -> Result<()> { let mut stmt = self.db.prepare_cached(include_str!("update_card.sql"))?; stmt.execute(params![