Feat/export dataset for research (#3511)

* Feat/export dataset for research

* add comment

Co-authored-by: Damien Elmes <dae@users.noreply.github.com>

* target_path is required

* format

* improve efficiency of the parent_id lookup

* move `use` down
This commit is contained in:
Jarrett Ye 2024-10-18 16:57:06 +08:00 committed by GitHub
parent c1a2b03871
commit b09326cddd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 105 additions and 15 deletions

View file

@ -67,6 +67,8 @@ service BackendSchedulerService {
returns (ComputeFsrsWeightsResponse);
// Generates parameters used for FSRS's scheduler benchmarks.
rpc FsrsBenchmark(FsrsBenchmarkRequest) returns (FsrsBenchmarkResponse);
// Used for exporting revlogs for algorithm research.
rpc ExportDataset(ExportDatasetRequest) returns (generic.Empty);
}
message SchedulingState {
@ -363,6 +365,11 @@ message FsrsBenchmarkResponse {
repeated float weights = 1;
}
// Request to export the collection's review data for algorithm research.
message ExportDatasetRequest {
// Minimum number of revlog entries required; the export fails with an
// "insufficient data" error if the collection has fewer.
uint32 min_entries = 1;
// Path the protobuf-encoded dataset is written to. Required.
string target_path = 2;
}
// The reviews that make up a single FSRS training item.
message FsrsItem {
repeated FsrsReview reviews = 1;
}

View file

@ -217,7 +217,21 @@ message RevlogEntry {
ReviewKind review_kind = 9;
}
message RevlogEntries {
repeated RevlogEntry entries = 1;
int64 next_day_at = 2;
// Minimal per-card information included in the research dataset export.
message CardEntry {
int64 id = 1;
int64 note_id = 2;
// The card's effective deck: taken from odid when odid is nonzero,
// otherwise did (see get_card_entry.sql).
int64 deck_id = 3;
}
// A deck's place in the hierarchy and its options preset, as included in
// the research dataset export.
message DeckEntry {
int64 id = 1;
// 0 when the deck has no parent (top-level deck).
int64 parent_id = 2;
// The deck options preset (deck config) this deck uses.
int64 preset_id = 3;
}
// The complete payload written by the dataset export, used for FSRS
// algorithm research.
message Dataset {
repeated RevlogEntry revlogs = 1;
repeated CardEntry cards = 2;
repeated DeckEntry decks = 3;
// When the collection's next day starts — presumably epoch seconds; confirm
// against SchedTimingToday.
int64 next_day_at = 4;
}

View file

@ -420,6 +420,11 @@ class Collection(DeprecatedNamesMixin):
def import_json_string(self, json: str) -> ImportLogWithChanges:
    """Import collection data from a JSON string via the backend,
    returning the log of applied changes."""
    log = self._backend.import_json_string(json)
    return log
def export_dataset_for_research(self, target_path: str, min_entries: int = 0) -> None:
    """Write the collection's review data to `target_path` for FSRS research.

    Delegates to the backend's export_dataset RPC, which encodes the data
    as protobuf. NOTE(review): presumably errors when the collection holds
    fewer than `min_entries` review entries — confirm against the backend.
    """
    self._backend.export_dataset(target_path=target_path, min_entries=min_entries)
# Image Occlusion
##########################################################################

View file

@ -1,5 +1,6 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::collections::HashMap;
use std::iter;
use std::path::Path;
use std::thread;
@ -8,7 +9,8 @@ use std::time::Duration;
use anki_io::write_file;
use anki_proto::scheduler::ComputeFsrsWeightsResponse;
use anki_proto::stats::revlog_entry;
use anki_proto::stats::RevlogEntries;
use anki_proto::stats::Dataset;
use anki_proto::stats::DeckEntry;
use chrono::NaiveDate;
use chrono::NaiveTime;
use fsrs::CombinedProgressState;
@ -19,6 +21,7 @@ use fsrs::FSRS;
use itertools::Itertools;
use prost::Message;
use crate::decks::immediate_parent_name;
use crate::prelude::*;
use crate::revlog::RevlogEntry;
use crate::revlog::RevlogReviewKind;
@ -127,22 +130,51 @@ impl Collection {
}
/// Used for exporting revlogs for algorithm research.
pub fn export_revlog_entries_to_protobuf(
&mut self,
min_entries: usize,
target_path: &Path,
) -> Result<()> {
let entries = self.storage.get_all_revlog_entries_in_card_order()?;
if entries.len() < min_entries {
pub fn export_dataset(&mut self, min_entries: usize, target_path: &Path) -> Result<()> {
let revlog_entries = self.storage.get_all_revlog_entries_in_card_order()?;
if revlog_entries.len() < min_entries {
return Err(AnkiError::FsrsInsufficientData);
}
let entries = entries.into_iter().map(revlog_entry_to_proto).collect_vec();
let revlogs = revlog_entries
.into_iter()
.map(revlog_entry_to_proto)
.collect_vec();
let cards = self.storage.get_all_card_entries()?;
let decks_map = self.storage.get_decks_map()?;
let deck_name_to_id: HashMap<String, DeckId> = decks_map
.into_iter()
.map(|(id, deck)| (deck.name.to_string(), id))
.collect();
let decks = self
.storage
.get_all_decks()?
.into_iter()
.filter_map(|deck| {
if let Some(preset_id) = deck.config_id().map(|id| id.0) {
let parent_id = immediate_parent_name(&deck.name.to_string())
.and_then(|parent_name| deck_name_to_id.get(parent_name))
.map(|id| id.0)
.unwrap_or(0);
Some(DeckEntry {
id: deck.id.0,
parent_id,
preset_id,
})
} else {
None
}
})
.collect_vec();
let next_day_at = self.timing_today()?.next_day_at.0;
let entries = RevlogEntries {
entries,
let dataset = Dataset {
revlogs,
cards,
decks,
next_day_at,
};
let data = entries.encode_to_vec();
let data = dataset.encode_to_vec();
write_file(target_path, data)?;
Ok(())
}

View file

@ -368,6 +368,15 @@ impl crate::services::BackendSchedulerService for Backend {
let weights = fsrs.benchmark(train_set);
Ok(FsrsBenchmarkResponse { weights })
}
/// Backend entry point for `ExportDatasetRequest`: exports the open
/// collection's review data to the requested path.
fn export_dataset(&self, req: scheduler::ExportDatasetRequest) -> Result<()> {
    self.with_col(|col| {
        // u32 -> usize conversion; only fallible on <32-bit targets.
        let min_entries = req.min_entries.try_into().unwrap();
        let target_path = req.target_path.as_ref();
        col.export_dataset(min_entries, target_path)
    })
}
}
fn fsrs_item_proto_to_fsrs(item: anki_proto::scheduler::FsrsItem) -> FSRSItem {

View file

@ -0,0 +1,7 @@
-- One row per card: the card id, its note id, and its effective deck id.
SELECT id,
nid,
-- Prefer odid over did when odid is nonzero (odid appears to be the deck
-- the card originally came from, for cards in a filtered deck — confirm
-- against the cards table schema).
CASE
WHEN odid = 0 THEN did
ELSE odid
END AS did
FROM cards;

View file

@ -9,6 +9,7 @@ use std::convert::TryFrom;
use std::fmt;
use std::result;
use anki_proto::stats::CardEntry;
use rusqlite::named_params;
use rusqlite::params;
use rusqlite::types::FromSql;
@ -87,6 +88,14 @@ fn row_to_card(row: &Row) -> result::Result<Card, rusqlite::Error> {
})
}
/// Map one row from get_card_entry.sql (card id, note id, effective deck
/// id) into a `CardEntry` proto.
fn row_to_card_entry(row: &Row) -> Result<CardEntry> {
    let (id, note_id, deck_id) = (row.get(0)?, row.get(1)?, row.get(2)?);
    Ok(CardEntry {
        id,
        note_id,
        deck_id,
    })
}
fn row_to_new_card(row: &Row) -> result::Result<NewCard, rusqlite::Error> {
Ok(NewCard {
id: row.get(0)?,
@ -108,6 +117,13 @@ impl super::SqliteStorage {
.map_err(Into::into)
}
/// Fetch a `CardEntry` for every card in the collection.
pub(crate) fn get_all_card_entries(&self) -> Result<Vec<CardEntry>> {
    let mut stmt = self.db.prepare_cached(include_str!("get_card_entry.sql"))?;
    let rows = stmt.query_and_then([], row_to_card_entry)?;
    rows.collect()
}
pub(crate) fn update_card(&self, card: &Card) -> Result<()> {
let mut stmt = self.db.prepare_cached(include_str!("update_card.sql"))?;
stmt.execute(params![