Feat/export dataset for research (#3511)

* Feat/export dataset for research

* add comment

Co-authored-by: Damien Elmes <dae@users.noreply.github.com>

* target_path is required

* format

* improve efficiency to look up parent_id

* move `use` down
Author: Jarrett Ye
Date: 2024-10-18 16:57:06 +08:00 (committed by GitHub)
Commit: b09326cddd (parent c1a2b03871)
7 changed files with 105 additions and 15 deletions


@@ -67,6 +67,8 @@ service BackendSchedulerService {
       returns (ComputeFsrsWeightsResponse);
   // Generates parameters used for FSRS's scheduler benchmarks.
   rpc FsrsBenchmark(FsrsBenchmarkRequest) returns (FsrsBenchmarkResponse);
+  // Used for exporting revlogs for algorithm research.
+  rpc ExportDataset(ExportDatasetRequest) returns (generic.Empty);
 }

 message SchedulingState {

@@ -363,6 +365,11 @@ message FsrsBenchmarkResponse {
   repeated float weights = 1;
 }

+message ExportDatasetRequest {
+  uint32 min_entries = 1;
+  string target_path = 2;
+}
+
 message FsrsItem {
   repeated FsrsReview reviews = 1;
 }

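The rpc returns generic.Empty: success is signalled by the file appearing at target_path, and an undersized collection surfaces as an error (see the Rust implementation below). For orientation, a minimal sketch of building the prost-generated request on the Rust side, with illustrative values:

use anki_proto::scheduler::ExportDatasetRequest;

fn main() {
    // Illustrative values; any writable path works. min_entries is the
    // minimum number of revlog entries required for the export to proceed.
    let req = ExportDatasetRequest {
        min_entries: 1000,
        target_path: "/tmp/anki-dataset.bin".into(),
    };
    println!("{req:?}");
}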

@@ -217,7 +217,21 @@ message RevlogEntry {
   ReviewKind review_kind = 9;
 }

-message RevlogEntries {
-  repeated RevlogEntry entries = 1;
-  int64 next_day_at = 2;
+message CardEntry {
+  int64 id = 1;
+  int64 note_id = 2;
+  int64 deck_id = 3;
+}
+
+message DeckEntry {
+  int64 id = 1;
+  int64 parent_id = 2;
+  int64 preset_id = 3;
+}
+
+message Dataset {
+  repeated RevlogEntry revlogs = 1;
+  repeated CardEntry cards = 2;
+  repeated DeckEntry decks = 3;
+  int64 next_day_at = 4;
 }

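For downstream analysis, the exported file can be read back with the same prost-generated types from this repo's anki_proto crate; a minimal sketch in Rust (read_dataset is a hypothetical helper, the path is illustrative, and error handling is simplified to a boxed error):

use anki_proto::stats::Dataset;
use prost::Message;

// Hypothetical helper: read back a file produced by ExportDataset.
fn read_dataset(path: &std::path::Path) -> Result<Dataset, Box<dyn std::error::Error>> {
    let bytes = std::fs::read(path)?;
    // The file is a single encoded `Dataset` message (see `encode_to_vec` below).
    Ok(Dataset::decode(&bytes[..])?)
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let dataset = read_dataset(std::path::Path::new("/tmp/anki-dataset.bin"))?;
    println!(
        "{} revlogs, {} cards, {} decks, next day at {}",
        dataset.revlogs.len(),
        dataset.cards.len(),
        dataset.decks.len(),
        dataset.next_day_at
    );
    Ok(())
}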

@@ -420,6 +420,11 @@ class Collection(DeprecatedNamesMixin):
     def import_json_string(self, json: str) -> ImportLogWithChanges:
         return self._backend.import_json_string(json)

+    def export_dataset_for_research(
+        self, target_path: str, min_entries: int = 0
+    ) -> None:
+        self._backend.export_dataset(min_entries=min_entries, target_path=target_path)
+
     # Image Occlusion
     ##########################################################################

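In Python, the export is then a single call on an open collection, for example col.export_dataset_for_research("dataset.bin", min_entries=1000) (the path is illustrative). With the default min_entries=0 the entry-count check in the Rust implementation can never trigger, so the export proceeds regardless of collection size.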

@@ -1,5 +1,6 @@
 // Copyright: Ankitects Pty Ltd and contributors
 // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
+use std::collections::HashMap;
 use std::iter;
 use std::path::Path;
 use std::thread;

@@ -8,7 +9,8 @@ use std::time::Duration;
 use anki_io::write_file;
 use anki_proto::scheduler::ComputeFsrsWeightsResponse;
 use anki_proto::stats::revlog_entry;
-use anki_proto::stats::RevlogEntries;
+use anki_proto::stats::Dataset;
+use anki_proto::stats::DeckEntry;
 use chrono::NaiveDate;
 use chrono::NaiveTime;
 use fsrs::CombinedProgressState;

@@ -19,6 +21,7 @@ use fsrs::FSRS;
 use itertools::Itertools;
 use prost::Message;

+use crate::decks::immediate_parent_name;
 use crate::prelude::*;
 use crate::revlog::RevlogEntry;
 use crate::revlog::RevlogReviewKind;

@@ -127,22 +130,51 @@ impl Collection {
     }

     /// Used for exporting revlogs for algorithm research.
-    pub fn export_revlog_entries_to_protobuf(
-        &mut self,
-        min_entries: usize,
-        target_path: &Path,
-    ) -> Result<()> {
-        let entries = self.storage.get_all_revlog_entries_in_card_order()?;
-        if entries.len() < min_entries {
+    pub fn export_dataset(&mut self, min_entries: usize, target_path: &Path) -> Result<()> {
+        let revlog_entries = self.storage.get_all_revlog_entries_in_card_order()?;
+        if revlog_entries.len() < min_entries {
             return Err(AnkiError::FsrsInsufficientData);
         }
-        let entries = entries.into_iter().map(revlog_entry_to_proto).collect_vec();
+        let revlogs = revlog_entries
+            .into_iter()
+            .map(revlog_entry_to_proto)
+            .collect_vec();
+        let cards = self.storage.get_all_card_entries()?;
+        let decks_map = self.storage.get_decks_map()?;
+        let deck_name_to_id: HashMap<String, DeckId> = decks_map
+            .into_iter()
+            .map(|(id, deck)| (deck.name.to_string(), id))
+            .collect();
+        let decks = self
+            .storage
+            .get_all_decks()?
+            .into_iter()
+            .filter_map(|deck| {
+                if let Some(preset_id) = deck.config_id().map(|id| id.0) {
+                    let parent_id = immediate_parent_name(&deck.name.to_string())
+                        .and_then(|parent_name| deck_name_to_id.get(parent_name))
+                        .map(|id| id.0)
+                        .unwrap_or(0);
+                    Some(DeckEntry {
+                        id: deck.id.0,
+                        parent_id,
+                        preset_id,
+                    })
+                } else {
+                    None
+                }
+            })
+            .collect_vec();
         let next_day_at = self.timing_today()?.next_day_at.0;
-        let entries = RevlogEntries {
-            entries,
+        let dataset = Dataset {
+            revlogs,
+            cards,
+            decks,
             next_day_at,
         };
-        let data = entries.encode_to_vec();
+        let data = dataset.encode_to_vec();
         write_file(target_path, data)?;
         Ok(())
     }

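Two details of the deck pass are worth noting: filter_map drops decks whose config_id() is None (filtered decks have no preset), and the "improve efficiency to look up parent_id" commit is the deck_name_to_id HashMap, built once so that each deck's parent, which is implied by its name, resolves with a hash lookup rather than a scan over all decks. A self-contained sketch of that lookup pattern, using a simplified stand-in for immediate_parent_name (the real one works on Anki's internal name encoding):

use std::collections::HashMap;

/// Simplified stand-in for `immediate_parent_name`: drop the last
/// name component, so "Japanese::Vocab" -> "Japanese".
fn parent_name(name: &str) -> Option<&str> {
    name.rsplit_once("::").map(|(parent, _)| parent)
}

fn main() {
    let decks = [
        (1i64, "Japanese"),
        (2, "Japanese::Vocab"),
        (3, "Japanese::Vocab::N5"),
    ];
    // Built once: each parent lookup below is then a hash probe, not a scan.
    let name_to_id: HashMap<&str, i64> =
        decks.iter().map(|&(id, name)| (name, id)).collect();
    for &(id, name) in &decks {
        let parent_id = parent_name(name)
            .and_then(|parent| name_to_id.get(parent))
            .copied()
            .unwrap_or(0); // top-level decks get parent_id 0, as in the export
        println!("deck {id}: parent_id {parent_id}");
    }
}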

@@ -368,6 +368,15 @@ impl crate::services::BackendSchedulerService for Backend {
         let weights = fsrs.benchmark(train_set);
         Ok(FsrsBenchmarkResponse { weights })
     }
+
+    fn export_dataset(&self, req: scheduler::ExportDatasetRequest) -> Result<()> {
+        self.with_col(|col| {
+            col.export_dataset(
+                req.min_entries.try_into().unwrap(),
+                req.target_path.as_ref(),
+            )
+        })
+    }
 }

 fn fsrs_item_proto_to_fsrs(item: anki_proto::scheduler::FsrsItem) -> FSRSItem {

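req.min_entries arrives as a protobuf u32 and is converted with try_into().unwrap() to the usize the collection method expects; on the 32- and 64-bit targets Anki builds for, a u32 always fits in usize, so the unwrap cannot panic.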

@@ -0,0 +1,7 @@
+SELECT id,
+    nid,
+    CASE
+        WHEN odid = 0 THEN did
+        ELSE odid
+    END AS did
+FROM cards;

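The CASE maps filtered decks back to the card's home deck: when a card sits in a filtered deck, did points at the temporary deck and odid (original deck id) at the deck it came from, while odid = 0 means the card has not been moved, so the export always records the card's permanent deck rather than a temporary filtered one.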

@@ -9,6 +9,7 @@ use std::convert::TryFrom;
 use std::fmt;
 use std::result;

+use anki_proto::stats::CardEntry;
 use rusqlite::named_params;
 use rusqlite::params;
 use rusqlite::types::FromSql;

@@ -87,6 +88,14 @@ fn row_to_card(row: &Row) -> result::Result<Card, rusqlite::Error> {
     })
 }

+fn row_to_card_entry(row: &Row) -> Result<CardEntry> {
+    Ok(CardEntry {
+        id: row.get(0)?,
+        note_id: row.get(1)?,
+        deck_id: row.get(2)?,
+    })
+}
+
 fn row_to_new_card(row: &Row) -> result::Result<NewCard, rusqlite::Error> {
     Ok(NewCard {
         id: row.get(0)?,

@@ -108,6 +117,13 @@ impl super::SqliteStorage {
             .map_err(Into::into)
     }

+    pub(crate) fn get_all_card_entries(&self) -> Result<Vec<CardEntry>> {
+        self.db
+            .prepare_cached(include_str!("get_card_entry.sql"))?
+            .query_and_then([], row_to_card_entry)?
+            .collect()
+    }
+
     pub(crate) fn update_card(&self, card: &Card) -> Result<()> {
         let mut stmt = self.db.prepare_cached(include_str!("update_card.sql"))?;
         stmt.execute(params![
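In get_all_card_entries, query_and_then yields an iterator of Result<CardEntry>, and collect() on it produces Result<Vec<CardEntry>> directly, so a single malformed row aborts the export with an error rather than being silently dropped.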