From 9c955a2bbba30d5ae4f7acb70255a99e9c2c6cbf Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Wed, 4 Jun 2025 16:46:42 +0800 Subject: [PATCH] implement smoothing & a separate struct for return value --- rslib/src/deckconfig/service.rs | 46 +++++++++++++++---- .../storage/card/get_costs_for_retention.sql | 23 +++++++++- rslib/src/storage/card/mod.rs | 28 ++++++++--- 3 files changed, 81 insertions(+), 16 deletions(-) diff --git a/rslib/src/deckconfig/service.rs b/rslib/src/deckconfig/service.rs index f0d419931..c9a355672 100644 --- a/rslib/src/deckconfig/service.rs +++ b/rslib/src/deckconfig/service.rs @@ -103,19 +103,49 @@ impl crate::services::DeckConfigService for Collection { ) -> Result { const LEARN_SPAN: usize = 100_000_000; const TERMINATION_PROB: f32 = 0.01; + // the default values are from https://github.com/open-spaced-repetition/Anki-button-usage/blob/881009015c2a85ac911021d76d0aacb124849937/analysis.ipynb + const DEFAULT_LEARN_COST: f32 = 19.4698; + const DEFAULT_PASS_COST: f32 = 7.8454; + const DEFAULT_FAIL_COST: f32 = 23.185; + const DEFAULT_INITIAL_PASS_RATE: f32 = 0.7645; let guard = self.search_cards_into_table(&input.search, crate::search::SortMode::NoOrder)?; - let (pass_cost, fail_cost, learn_cost, initial_pass_rate) = - guard.col.storage.get_costs_for_retention()?; + let costs = guard.col.storage.get_costs_for_retention()?; + + fn smoothing(obs: f32, default: f32, count: u32) -> f32 { + let alpha = count as f32 / (50.0 + count as f32); + obs * alpha + default * (1.0 - alpha) + } + + let cost_success = smoothing( + costs.average_pass_time_ms / 1000.0, + DEFAULT_PASS_COST, + costs.pass_count, + ); + let cost_failure = smoothing( + costs.average_fail_time_ms / 1000.0, + DEFAULT_FAIL_COST, + costs.fail_count, + ); + let cost_learn = smoothing( + costs.average_learn_time_ms / 1000.0, + DEFAULT_LEARN_COST, + costs.learn_count, + ); + let initial_pass_rate = smoothing( + costs.initial_pass_rate, + DEFAULT_INITIAL_PASS_RATE, + costs.pass_count, + ); let before = fsrs::expected_workload( &input.w, input.before, LEARN_SPAN, - pass_cost, - fail_cost, - learn_cost, + cost_success, + cost_failure, + cost_learn, initial_pass_rate, TERMINATION_PROB, )?; @@ -123,9 +153,9 @@ impl crate::services::DeckConfigService for Collection { &input.w, input.after, LEARN_SPAN, - pass_cost, - fail_cost, - learn_cost, + cost_success, + cost_failure, + cost_learn, initial_pass_rate, TERMINATION_PROB, )?; diff --git a/rslib/src/storage/card/get_costs_for_retention.sql b/rslib/src/storage/card/get_costs_for_retention.sql index aacdbcb00..ba21cc3f6 100644 --- a/rslib/src/storage/card/get_costs_for_retention.sql +++ b/rslib/src/storage/card/get_costs_for_retention.sql @@ -13,6 +13,7 @@ WITH searched_revlogs AS ( SELECT AVG(time) FROM searched_revlogs WHERE ease > 1 + AND type = 1 ), lapse_count AS ( SELECT COUNT(time) AS lapse_count @@ -56,9 +57,29 @@ initial_pass_rate AS ( ) AS initial_pass_rate FROM searched_revlogs WHERE rank_num = 1 +), +pass_cnt AS ( + SELECT COUNT(*) AS cnt + FROM searched_revlogs + WHERE ease > 1 + AND type = 1 +), +fail_cnt AS ( + SELECT COUNT(*) AS cnt + FROM searched_revlogs + WHERE ease = 1 + AND type = 1 +), +learn_cnt AS ( + SELECT COUNT(*) AS cnt + FROM searched_revlogs + WHERE type = 0 ) SELECT * FROM average_pass, average_fail, average_learn, - initial_pass_rate; \ No newline at end of file + initial_pass_rate, + pass_cnt, + fail_cnt, + learn_cnt; \ No newline at end of file diff --git a/rslib/src/storage/card/mod.rs b/rslib/src/storage/card/mod.rs index a86e61460..38cf5ef0f 100644 --- a/rslib/src/storage/card/mod.rs +++ b/rslib/src/storage/card/mod.rs @@ -42,6 +42,17 @@ use crate::timestamp::TimestampMillis; use crate::timestamp::TimestampSecs; use crate::types::Usn; +#[derive(Debug, Clone, Default)] +pub struct RetentionCosts { + pub average_pass_time_ms: f32, + pub average_fail_time_ms: f32, + pub average_learn_time_ms: f32, + pub initial_pass_rate: f32, + pub pass_count: u32, + pub fail_count: u32, + pub learn_count: u32, +} + impl FromSql for CardType { fn column_result(value: ValueRef<'_>) -> result::Result { if let ValueRef::Integer(i) = value { @@ -747,19 +758,22 @@ impl super::SqliteStorage { .get(0)?) } - pub(crate) fn get_costs_for_retention(&self) -> Result<(f32, f32, f32, f32)> { + pub(crate) fn get_costs_for_retention(&self) -> Result { let mut statement = self .db .prepare(include_str!("get_costs_for_retention.sql"))?; let mut query = statement.query(params![])?; let row = query.next()?.unwrap(); - Ok(( - row.get(0).unwrap_or(7000.), - row.get(1).unwrap_or(23_000.), - row.get(2).unwrap_or(30_000.), - row.get(3).unwrap_or(0.5), - )) + Ok(RetentionCosts { + average_pass_time_ms: row.get(0).unwrap_or(7000.), + average_fail_time_ms: row.get(1).unwrap_or(23_000.), + average_learn_time_ms: row.get(2).unwrap_or(30_000.), + initial_pass_rate: row.get(3).unwrap_or(0.5), + pass_count: row.get(4).unwrap_or(0), + fail_count: row.get(5).unwrap_or(0), + learn_count: row.get(6).unwrap_or(0), + }) } #[cfg(test)]