From c947690aebd84abeccaaafdc06b34b9e51ef4b00 Mon Sep 17 00:00:00 2001 From: Luc Mcgrady Date: Mon, 28 Jul 2025 10:00:16 +0100 Subject: [PATCH] Feat/Use cached workload values (#4208) * Feat/Use cached workload values * Fix: Calculation when unchanged * Modify constants * Cache clearing logic * Use function params * use https://github.com/open-spaced-repetition/fsrs-rs/pull/352 * Revert "use https://github.com/open-spaced-repetition/fsrs-rs/pull/352" This reverts commit 72efcf230c273b0eb9f1f294d21e4b9f959e3dde. * Reapply "use https://github.com/open-spaced-repetition/fsrs-rs/pull/352" This reverts commit 49eab2969f56568296289b63b0a7146118d848b3. * ./check * bump fsrs --- Cargo.lock | 4 +- Cargo.toml | 3 +- cargo/licenses.json | 2 +- proto/anki/deck_config.proto | 4 +- rslib/src/deckconfig/service.rs | 72 ++++------------ .../storage/card/get_costs_for_retention.sql | 85 ------------------- rslib/src/storage/card/mod.rs | 29 ------- ts/routes/deck-options/FsrsOptions.svelte | 49 ++++++----- 8 files changed, 50 insertions(+), 198 deletions(-) delete mode 100644 rslib/src/storage/card/get_costs_for_retention.sql diff --git a/Cargo.lock b/Cargo.lock index fce546a9e..f698a4ce4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2214,9 +2214,9 @@ dependencies = [ [[package]] name = "fsrs" -version = "4.1.1" +version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1f3a8c3df2c324ebab71461178fe8c1fe2d7373cf603f312b652befd026f06d" +checksum = "f590cfcbe25079bb54a39900f45e6e308935bd6067249ce00d265b280465cde2" dependencies = [ "burn", "itertools 0.14.0", diff --git a/Cargo.toml b/Cargo.toml index 3ef2df9bd..f62a71023 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,8 +33,9 @@ git = "https://github.com/ankitects/linkcheck.git" rev = "184b2ca50ed39ca43da13f0b830a463861adb9ca" [workspace.dependencies.fsrs] -version = "4.1.1" +version = "5.0.0" # git = "https://github.com/open-spaced-repetition/fsrs-rs.git" +# branch = "Refactor/expected_workload_via_dp" # rev = "a7f7efc10f0a26b14ee348cc7402155685f2a24f" # path = "../open-spaced-repetition/fsrs-rs" diff --git a/cargo/licenses.json b/cargo/licenses.json index f2695ac76..e7b61a5fe 100644 --- a/cargo/licenses.json +++ b/cargo/licenses.json @@ -1450,7 +1450,7 @@ }, { "name": "fsrs", - "version": "4.1.1", + "version": "5.0.0", "authors": "Open Spaced Repetition", "repository": "https://github.com/open-spaced-repetition/fsrs-rs", "license": "BSD-3-Clause", diff --git a/proto/anki/deck_config.proto b/proto/anki/deck_config.proto index 55291ee5f..5ed02423e 100644 --- a/proto/anki/deck_config.proto +++ b/proto/anki/deck_config.proto @@ -40,12 +40,10 @@ message DeckConfigId { message GetRetentionWorkloadRequest { repeated float w = 1; string search = 2; - float before = 3; - float after = 4; } message GetRetentionWorkloadResponse { - float factor = 1; + map costs = 1; } message GetIgnoredBeforeCountRequest { diff --git a/rslib/src/deckconfig/service.rs b/rslib/src/deckconfig/service.rs index bc6bce8f4..8cc33fc3a 100644 --- a/rslib/src/deckconfig/service.rs +++ b/rslib/src/deckconfig/service.rs @@ -1,5 +1,7 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +use std::collections::HashMap; + use anki_proto::generic; use crate::collection::Collection; @@ -101,68 +103,26 @@ impl crate::services::DeckConfigService for Collection { &mut self, input: anki_proto::deck_config::GetRetentionWorkloadRequest, ) -> Result { - const LEARN_SPAN: usize = 100_000_000; - const TERMINATION_PROB: f32 = 0.001; - // the default values are from https://github.com/open-spaced-repetition/Anki-button-usage/blob/881009015c2a85ac911021d76d0aacb124849937/analysis.ipynb - const DEFAULT_LEARN_COST: f32 = 19.4698; - const DEFAULT_PASS_COST: f32 = 7.8454; - const DEFAULT_FAIL_COST: f32 = 23.185; - const DEFAULT_INITIAL_PASS_RATE: f32 = 0.7645; - let guard = self.search_cards_into_table(&input.search, crate::search::SortMode::NoOrder)?; - let costs = guard.col.storage.get_costs_for_retention()?; - fn smoothing(obs: f32, default: f32, count: u32) -> f32 { - let alpha = count as f32 / (50.0 + count as f32); - obs * alpha + default * (1.0 - alpha) - } + let revlogs = guard + .col + .storage + .get_revlog_entries_for_searched_cards_in_card_order()?; - let cost_success = smoothing( - costs.average_pass_time_ms / 1000.0, - DEFAULT_PASS_COST, - costs.pass_count, - ); - let cost_failure = smoothing( - costs.average_fail_time_ms / 1000.0, - DEFAULT_FAIL_COST, - costs.fail_count, - ); - let cost_learn = smoothing( - costs.average_learn_time_ms / 1000.0, - DEFAULT_LEARN_COST, - costs.learn_count, - ); - let initial_pass_rate = smoothing( - costs.initial_pass_rate, - DEFAULT_INITIAL_PASS_RATE, - costs.pass_count, - ); + let config = guard.col.get_optimal_retention_parameters(revlogs)?; - let before = fsrs::expected_workload( - &input.w, - input.before, - LEARN_SPAN, - cost_success, - cost_failure, - cost_learn, - initial_pass_rate, - TERMINATION_PROB, - )?; - let after = fsrs::expected_workload( - &input.w, - input.after, - LEARN_SPAN, - cost_success, - cost_failure, - cost_learn, - initial_pass_rate, - TERMINATION_PROB, - )?; + let costs = (70u32..=99u32) + .map(|dr| { + Ok(( + dr, + fsrs::expected_workload(&input.w, dr as f32 / 100., &config)?, + )) + }) + .collect::>>()?; - Ok(anki_proto::deck_config::GetRetentionWorkloadResponse { - factor: after / before, - }) + Ok(anki_proto::deck_config::GetRetentionWorkloadResponse { costs }) } } diff --git a/rslib/src/storage/card/get_costs_for_retention.sql b/rslib/src/storage/card/get_costs_for_retention.sql deleted file mode 100644 index ba21cc3f6..000000000 --- a/rslib/src/storage/card/get_costs_for_retention.sql +++ /dev/null @@ -1,85 +0,0 @@ -WITH searched_revlogs AS ( - SELECT *, - RANK() OVER ( - PARTITION BY cid - ORDER BY id ASC - ) AS rank_num - FROM revlog - WHERE ease > 0 - AND cid IN search_cids - ORDER BY id DESC -- Use the last 10_000 reviews - LIMIT 10000 -), average_pass AS ( - SELECT AVG(time) - FROM searched_revlogs - WHERE ease > 1 - AND type = 1 -), -lapse_count AS ( - SELECT COUNT(time) AS lapse_count - FROM searched_revlogs - WHERE ease = 1 - AND type = 1 -), -fail_sum AS ( - SELECT SUM(time) AS total_fail_time - FROM searched_revlogs - WHERE ( - ease = 1 - AND type = 1 - ) - OR type = 2 -), --- (sum(Relearning) + sum(Lapses)) / count(Lapses) -average_fail AS ( - SELECT total_fail_time * 1.0 / NULLIF(lapse_count, 0) AS avg_fail_time - FROM fail_sum, - lapse_count -), --- Can lead to cards with partial learn histories skewing the time -summed_learns AS ( - SELECT cid, - SUM(time) AS total_time - FROM searched_revlogs - WHERE searched_revlogs.type = 0 - GROUP BY cid -), -average_learn AS ( - SELECT AVG(total_time) AS avg_learn_time - FROM summed_learns -), -initial_pass_rate AS ( - SELECT AVG( - CASE - WHEN ease > 1 THEN 1.0 - ELSE 0.0 - END - ) AS initial_pass_rate - FROM searched_revlogs - WHERE rank_num = 1 -), -pass_cnt AS ( - SELECT COUNT(*) AS cnt - FROM searched_revlogs - WHERE ease > 1 - AND type = 1 -), -fail_cnt AS ( - SELECT COUNT(*) AS cnt - FROM searched_revlogs - WHERE ease = 1 - AND type = 1 -), -learn_cnt AS ( - SELECT COUNT(*) AS cnt - FROM searched_revlogs - WHERE type = 0 -) -SELECT * -FROM average_pass, - average_fail, - average_learn, - initial_pass_rate, - pass_cnt, - fail_cnt, - learn_cnt; \ No newline at end of file diff --git a/rslib/src/storage/card/mod.rs b/rslib/src/storage/card/mod.rs index 35a229e93..6a72dc6e7 100644 --- a/rslib/src/storage/card/mod.rs +++ b/rslib/src/storage/card/mod.rs @@ -42,17 +42,6 @@ use crate::timestamp::TimestampMillis; use crate::timestamp::TimestampSecs; use crate::types::Usn; -#[derive(Debug, Clone, Default)] -pub struct RetentionCosts { - pub average_pass_time_ms: f32, - pub average_fail_time_ms: f32, - pub average_learn_time_ms: f32, - pub initial_pass_rate: f32, - pub pass_count: u32, - pub fail_count: u32, - pub learn_count: u32, -} - impl FromSql for CardType { fn column_result(value: ValueRef<'_>) -> result::Result { if let ValueRef::Integer(i) = value { @@ -759,24 +748,6 @@ impl super::SqliteStorage { .get(0)?) } - pub(crate) fn get_costs_for_retention(&self) -> Result { - let mut statement = self - .db - .prepare(include_str!("get_costs_for_retention.sql"))?; - let mut query = statement.query(params![])?; - let row = query.next()?.unwrap(); - - Ok(RetentionCosts { - average_pass_time_ms: row.get(0).unwrap_or(7000.), - average_fail_time_ms: row.get(1).unwrap_or(23_000.), - average_learn_time_ms: row.get(2).unwrap_or(30_000.), - initial_pass_rate: row.get(3).unwrap_or(0.5), - pass_count: row.get(4).unwrap_or(0), - fail_count: row.get(5).unwrap_or(0), - learn_count: row.get(6).unwrap_or(0), - }) - } - #[cfg(test)] pub(crate) fn get_all_cards(&self) -> Vec { self.db diff --git a/ts/routes/deck-options/FsrsOptions.svelte b/ts/routes/deck-options/FsrsOptions.svelte index 526c3aa99..0bb6220bb 100644 --- a/ts/routes/deck-options/FsrsOptions.svelte +++ b/ts/routes/deck-options/FsrsOptions.svelte @@ -29,6 +29,7 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html import SimulatorModal from "./SimulatorModal.svelte"; import { GetRetentionWorkloadRequest, + type GetRetentionWorkloadResponse, UpdateDeckConfigsMode, } from "@generated/anki/deck_config_pb"; import type Modal from "bootstrap/js/dist/modal"; @@ -69,19 +70,9 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html $: roundedRetention = Number(effectiveDesiredRetention.toFixed(2)); $: desiredRetentionWarning = getRetentionLongShortWarning(roundedRetention); - let timeoutId: ReturnType | undefined = undefined; - const WORKLOAD_UPDATE_DELAY_MS = 100; - let desiredRetentionChangeInfo = ""; - $: { - clearTimeout(timeoutId); - if (showDesiredRetentionTooltip) { - timeoutId = setTimeout(() => { - getRetentionChangeInfo(roundedRetention, fsrsParams($config)); - }, WORKLOAD_UPDATE_DELAY_MS); - } else { - desiredRetentionChangeInfo = ""; - } + $: if (showDesiredRetentionTooltip) { + getRetentionChangeInfo(roundedRetention, fsrsParams($config)); } $: retentionWarningClass = getRetentionWarningClass(roundedRetention); @@ -137,21 +128,37 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html } } + let retentionWorloadInfo: undefined | Promise = + undefined; + let lastParams = [...fsrsParams($config)]; + async function getRetentionChangeInfo(retention: number, params: number[]) { if (+startingDesiredRetention == roundedRetention) { desiredRetentionChangeInfo = tr.deckConfigWorkloadFactorUnchanged(); return; } - const request = new GetRetentionWorkloadRequest({ - w: params, - search: defaultparamSearch, - before: +startingDesiredRetention, - after: retention, - }); - const resp = await getRetentionWorkload(request); + if ( + // If the cache is empty and a request has not yet been made to fill it + !retentionWorloadInfo || + // If the parameters have been changed + lastParams.toString() !== params.toString() + ) { + const request = new GetRetentionWorkloadRequest({ + w: params, + search: defaultparamSearch, + }); + lastParams = [...params]; + retentionWorloadInfo = getRetentionWorkload(request); + } + + const previous = +startingDesiredRetention * 100; + const after = retention * 100; + const resp = await retentionWorloadInfo; + const factor = resp.costs[after] / resp.costs[previous]; + desiredRetentionChangeInfo = tr.deckConfigWorkloadFactorChange({ - factor: resp.factor.toFixed(2), - previousDr: (+startingDesiredRetention * 100).toString(), + factor: factor.toFixed(2), + previousDr: previous.toString(), }); }