From 2de0c79ba599d69b2e142c4187f41888af9f4a60 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Tue, 3 Jun 2025 16:26:33 +0800 Subject: [PATCH] Feat/evaluate FSRS with time series split (#3962) --- proto/anki/scheduler.proto | 6 +++--- rslib/src/scheduler/fsrs/params.rs | 25 ++++++++++++----------- rslib/src/scheduler/service/mod.rs | 2 +- ts/routes/deck-options/FsrsOptions.svelte | 2 +- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/proto/anki/scheduler.proto b/proto/anki/scheduler.proto index 364bf50ad..45facd7ee 100644 --- a/proto/anki/scheduler.proto +++ b/proto/anki/scheduler.proto @@ -435,9 +435,9 @@ message GetOptimalRetentionParametersResponse { } message EvaluateParamsRequest { - repeated float params = 1; - string search = 2; - int64 ignore_revlogs_before_ms = 3; + string search = 1; + int64 ignore_revlogs_before_ms = 2; + uint32 num_of_relearning_steps = 3; } message EvaluateParamsResponse { diff --git a/rslib/src/scheduler/fsrs/params.rs b/rslib/src/scheduler/fsrs/params.rs index 6da02776a..840d16217 100644 --- a/rslib/src/scheduler/fsrs/params.rs +++ b/rslib/src/scheduler/fsrs/params.rs @@ -115,15 +115,14 @@ impl Collection { num_relearning_steps: Some(num_of_relearning_steps), })?; progress_thread.join().ok(); - if let Ok(fsrs) = FSRS::new(Some(current_params)) { - let current_log_loss = fsrs.evaluate(items.clone(), |_| true)?.log_loss; + if let Ok(current_fsrs) = FSRS::new(Some(current_params)) { + let current_log_loss = current_fsrs.evaluate(items.clone(), |_| true)?.log_loss; let optimized_fsrs = FSRS::new(Some(¶ms))?; let optimized_log_loss = optimized_fsrs.evaluate(items.clone(), |_| true)?.log_loss; if current_log_loss <= optimized_log_loss { if num_of_relearning_steps <= 1 { params = current_params.to_vec(); } else { - let current_fsrs = FSRS::new(Some(current_params))?; let memory_state = MemoryState { stability: 1.0, difficulty: 1.0, @@ -218,22 +217,24 @@ impl Collection { pub fn evaluate_params( &mut self, - params: &Params, search: &str, ignore_revlogs_before: TimestampMillis, + num_of_relearning_steps: usize, ) -> Result { let timing = self.timing_today()?; - let mut anki_progress = self.new_progress_handler::(); - let guard = self.search_cards_into_table(search, SortMode::NoOrder)?; - let revlogs: Vec = guard - .col - .storage - .get_revlog_entries_for_searched_cards_in_card_order()?; + let revlogs = self.revlog_for_srs(search)?; let (items, review_count) = fsrs_items_for_training(revlogs, timing.next_day_at, ignore_revlogs_before); + let mut anki_progress = self.new_progress_handler::(); anki_progress.state.reviews = review_count as u32; - let fsrs = FSRS::new(Some(params))?; - Ok(fsrs.evaluate(items, |ip| { + let fsrs = FSRS::new(None)?; + let input = ComputeParametersInput { + train_set: items.clone(), + progress: None, + enable_short_term: true, + num_relearning_steps: Some(num_of_relearning_steps), + }; + Ok(fsrs.evaluate_with_time_series_splits(input, |ip| { anki_progress .update(false, |p| { p.total_iterations = ip.total as u32; diff --git a/rslib/src/scheduler/service/mod.rs b/rslib/src/scheduler/service/mod.rs index e7d9a04eb..dc3de1dc7 100644 --- a/rslib/src/scheduler/service/mod.rs +++ b/rslib/src/scheduler/service/mod.rs @@ -295,9 +295,9 @@ impl crate::services::SchedulerService for Collection { input: scheduler::EvaluateParamsRequest, ) -> Result { let ret = self.evaluate_params( - &input.params, &input.search, input.ignore_revlogs_before_ms.into(), + input.num_of_relearning_steps as usize, )?; Ok(scheduler::EvaluateParamsResponse { log_loss: ret.log_loss, diff --git a/ts/routes/deck-options/FsrsOptions.svelte b/ts/routes/deck-options/FsrsOptions.svelte index 7b2318e00..97890a411 100644 --- a/ts/routes/deck-options/FsrsOptions.svelte +++ b/ts/routes/deck-options/FsrsOptions.svelte @@ -229,9 +229,9 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html ? $config.paramSearch : defaultparamSearch; const resp = await evaluateParams({ - params: fsrsParams($config), search, ignoreRevlogsBeforeMs: getIgnoreRevlogsBeforeMs(), + numOfRelearningSteps: $config.relearnSteps.length, }); if (computeParamsProgress) { computeParamsProgress.current = computeParamsProgress.total;