diff --git a/Cargo.lock b/Cargo.lock index f73263c28..562f4242f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2295,7 +2295,7 @@ dependencies = [ [[package]] name = "fsrs" version = "4.0.0" -source = "git+https://github.com/open-spaced-repetition/fsrs-rs.git?rev=33ec3ee4d5d73e704633469cf5bf1a42e620a524#33ec3ee4d5d73e704633469cf5bf1a42e620a524" +source = "git+https://github.com/open-spaced-repetition/fsrs-rs.git?rev=a7f7efc10f0a26b14ee348cc7402155685f2a24f#a7f7efc10f0a26b14ee348cc7402155685f2a24f" dependencies = [ "burn", "itertools 0.14.0", @@ -5010,9 +5010,9 @@ dependencies = [ [[package]] name = "priority-queue" -version = "2.3.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef08705fa1589a1a59aa924ad77d14722cb0cd97b67dd5004ed5f4a4873fce8d" +checksum = "5676d703dda103cbb035b653a9f11448c0a7216c7926bd35fcb5865475d0c970" dependencies = [ "autocfg", "equivalent", @@ -6290,18 +6290,18 @@ checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" [[package]] name = "snafu" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "223891c85e2a29c3fe8fb900c1fae5e69c2e42415e3177752e8718475efa5019" +checksum = "320b01e011bf8d5d7a4a4a4be966d9160968935849c83b918827f6a435e7f627" dependencies = [ "snafu-derive", ] [[package]] name = "snafu-derive" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" +checksum = "1961e2ef424c1424204d3a5d6975f934f56b6d50ff5732382d84ebf460e147f7" dependencies = [ "heck", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index 7d1645fc4..fbca1fe56 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,7 @@ rev = "184b2ca50ed39ca43da13f0b830a463861adb9ca" [workspace.dependencies.fsrs] # version = "3.0.0" git = "https://github.com/open-spaced-repetition/fsrs-rs.git" -rev = "33ec3ee4d5d73e704633469cf5bf1a42e620a524" +rev = "a7f7efc10f0a26b14ee348cc7402155685f2a24f" # path = "../open-spaced-repetition/fsrs-rs" [workspace.dependencies] @@ -125,7 +125,7 @@ serde_tuple = "0.5.0" sha1 = "0.10.6" sha2 = { version = "0.10.8" } simple-file-manifest = "0.11.0" -snafu = { version = "0.8.5", features = ["rust_1_61"] } +snafu = { version = "0.8.6", features = ["rust_1_61"] } strum = { version = "0.26.3", features = ["derive"] } syn = { version = "2.0.82", features = ["parsing", "printing"] } tar = "0.4.42" diff --git a/cargo/licenses.json b/cargo/licenses.json index 4854cb085..87a69df25 100644 --- a/cargo/licenses.json +++ b/cargo/licenses.json @@ -3151,7 +3151,7 @@ }, { "name": "priority-queue", - "version": "2.3.1", + "version": "2.5.0", "authors": "Gianmarco Garrisi ", "repository": "https://github.com/garro95/priority-queue", "license": "LGPL-3.0-or-later OR MPL-2.0", @@ -4015,7 +4015,7 @@ }, { "name": "snafu", - "version": "0.8.5", + "version": "0.8.6", "authors": "Jake Goulding ", "repository": "https://github.com/shepmaster/snafu", "license": "Apache-2.0 OR MIT", @@ -4024,7 +4024,7 @@ }, { "name": "snafu-derive", - "version": "0.8.5", + "version": "0.8.6", "authors": "Jake Goulding ", "repository": "https://github.com/shepmaster/snafu", "license": "Apache-2.0 OR MIT", diff --git a/rslib/src/deckconfig/service.rs b/rslib/src/deckconfig/service.rs index 516132763..37e1f407d 100644 --- a/rslib/src/deckconfig/service.rs +++ b/rslib/src/deckconfig/service.rs @@ -101,30 +101,64 @@ impl crate::services::DeckConfigService for Collection { &mut self, input: anki_proto::deck_config::GetRetentionWorkloadRequest, ) -> Result { - const LEARN_SPAN: usize = 1000; + const LEARN_SPAN: usize = 100_000_000; + const TERMINATION_PROB: f32 = 0.001; + // the default values are from https://github.com/open-spaced-repetition/Anki-button-usage/blob/881009015c2a85ac911021d76d0aacb124849937/analysis.ipynb + const DEFAULT_LEARN_COST: f32 = 19.4698; + const DEFAULT_PASS_COST: f32 = 7.8454; + const DEFAULT_FAIL_COST: f32 = 23.185; + const DEFAULT_INITIAL_PASS_RATE: f32 = 0.7645; let guard = self.search_cards_into_table(&input.search, crate::search::SortMode::NoOrder)?; - let (pass_cost, fail_cost, learn_cost) = guard.col.storage.get_costs_for_retention()?; + let costs = guard.col.storage.get_costs_for_retention()?; + + fn smoothing(obs: f32, default: f32, count: u32) -> f32 { + let alpha = count as f32 / (50.0 + count as f32); + obs * alpha + default * (1.0 - alpha) + } + + let cost_success = smoothing( + costs.average_pass_time_ms / 1000.0, + DEFAULT_PASS_COST, + costs.pass_count, + ); + let cost_failure = smoothing( + costs.average_fail_time_ms / 1000.0, + DEFAULT_FAIL_COST, + costs.fail_count, + ); + let cost_learn = smoothing( + costs.average_learn_time_ms / 1000.0, + DEFAULT_LEARN_COST, + costs.learn_count, + ); + let initial_pass_rate = smoothing( + costs.initial_pass_rate, + DEFAULT_INITIAL_PASS_RATE, + costs.pass_count, + ); let before = fsrs::expected_workload( &input.w, input.before, LEARN_SPAN, - pass_cost, - fail_cost, - 0., - input.before, - )? + learn_cost; + cost_success, + cost_failure, + cost_learn, + initial_pass_rate, + TERMINATION_PROB, + )?; let after = fsrs::expected_workload( &input.w, input.after, LEARN_SPAN, - pass_cost, - fail_cost, - 0., - input.after, - )? + learn_cost; + cost_success, + cost_failure, + cost_learn, + initial_pass_rate, + TERMINATION_PROB, + )?; Ok(anki_proto::deck_config::GetRetentionWorkloadResponse { factor: after / before, diff --git a/rslib/src/storage/card/get_costs_for_retention.sql b/rslib/src/storage/card/get_costs_for_retention.sql index 811ca3050..ba21cc3f6 100644 --- a/rslib/src/storage/card/get_costs_for_retention.sql +++ b/rslib/src/storage/card/get_costs_for_retention.sql @@ -1,5 +1,9 @@ WITH searched_revlogs AS ( - SELECT * + SELECT *, + RANK() OVER ( + PARTITION BY cid + ORDER BY id ASC + ) AS rank_num FROM revlog WHERE ease > 0 AND cid IN search_cids @@ -9,6 +13,7 @@ WITH searched_revlogs AS ( SELECT AVG(time) FROM searched_revlogs WHERE ease > 1 + AND type = 1 ), lapse_count AS ( SELECT COUNT(time) AS lapse_count @@ -42,8 +47,39 @@ summed_learns AS ( average_learn AS ( SELECT AVG(total_time) AS avg_learn_time FROM summed_learns +), +initial_pass_rate AS ( + SELECT AVG( + CASE + WHEN ease > 1 THEN 1.0 + ELSE 0.0 + END + ) AS initial_pass_rate + FROM searched_revlogs + WHERE rank_num = 1 +), +pass_cnt AS ( + SELECT COUNT(*) AS cnt + FROM searched_revlogs + WHERE ease > 1 + AND type = 1 +), +fail_cnt AS ( + SELECT COUNT(*) AS cnt + FROM searched_revlogs + WHERE ease = 1 + AND type = 1 +), +learn_cnt AS ( + SELECT COUNT(*) AS cnt + FROM searched_revlogs + WHERE type = 0 ) SELECT * FROM average_pass, average_fail, - average_learn; \ No newline at end of file + average_learn, + initial_pass_rate, + pass_cnt, + fail_cnt, + learn_cnt; \ No newline at end of file diff --git a/rslib/src/storage/card/mod.rs b/rslib/src/storage/card/mod.rs index bef3251de..38cf5ef0f 100644 --- a/rslib/src/storage/card/mod.rs +++ b/rslib/src/storage/card/mod.rs @@ -42,6 +42,17 @@ use crate::timestamp::TimestampMillis; use crate::timestamp::TimestampSecs; use crate::types::Usn; +#[derive(Debug, Clone, Default)] +pub struct RetentionCosts { + pub average_pass_time_ms: f32, + pub average_fail_time_ms: f32, + pub average_learn_time_ms: f32, + pub initial_pass_rate: f32, + pub pass_count: u32, + pub fail_count: u32, + pub learn_count: u32, +} + impl FromSql for CardType { fn column_result(value: ValueRef<'_>) -> result::Result { if let ValueRef::Integer(i) = value { @@ -747,18 +758,22 @@ impl super::SqliteStorage { .get(0)?) } - pub(crate) fn get_costs_for_retention(&self) -> Result<(f32, f32, f32)> { + pub(crate) fn get_costs_for_retention(&self) -> Result { let mut statement = self .db .prepare(include_str!("get_costs_for_retention.sql"))?; let mut query = statement.query(params![])?; let row = query.next()?.unwrap(); - Ok(( - row.get(0).unwrap_or(7000.), - row.get(1).unwrap_or(23_000.), - row.get(2).unwrap_or(30_000.), - )) + Ok(RetentionCosts { + average_pass_time_ms: row.get(0).unwrap_or(7000.), + average_fail_time_ms: row.get(1).unwrap_or(23_000.), + average_learn_time_ms: row.get(2).unwrap_or(30_000.), + initial_pass_rate: row.get(3).unwrap_or(0.5), + pass_count: row.get(4).unwrap_or(0), + fail_count: row.get(5).unwrap_or(0), + learn_count: row.get(6).unwrap_or(0), + }) } #[cfg(test)]