From d52889f45c3f7999a45fd2dde367f79f3bc3bad4 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Thu, 20 Mar 2025 15:04:38 +0800 Subject: [PATCH] Feat/simplified relearning steps logic with updated FSRS training API (#3867) * Feat/simplified relearning steps logic with updated FSRS training API * Update params.rs * use ComputeParametersInput * update fsrs-rs dependency * update cargo/format/rust-toolchain --- Cargo.lock | 2 +- Cargo.toml | 2 +- cargo/format/rust-toolchain.toml | 2 +- rslib/src/scheduler/fsrs/params.rs | 51 +++++++++++++++++------------- rslib/src/scheduler/service/mod.rs | 19 +++++++---- 5 files changed, 45 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6d6651800..32f15ab2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2099,7 +2099,7 @@ dependencies = [ [[package]] name = "fsrs" version = "3.0.0" -source = "git+https://github.com/open-spaced-repetition/fsrs-rs.git?rev=22f8e453c120f5bc5996f86558a559c6b7abfc49#22f8e453c120f5bc5996f86558a559c6b7abfc49" +source = "git+https://github.com/open-spaced-repetition/fsrs-rs.git?rev=08d90d1363b0c4722422bf0ef71ed8fd7d053f8a#08d90d1363b0c4722422bf0ef71ed8fd7d053f8a" dependencies = [ "burn", "itertools 0.12.1", diff --git a/Cargo.toml b/Cargo.toml index e3299a740..c16294236 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,7 @@ rev = "184b2ca50ed39ca43da13f0b830a463861adb9ca" [workspace.dependencies.fsrs] # version = "=2.0.3" git = "https://github.com/open-spaced-repetition/fsrs-rs.git" -rev = "22f8e453c120f5bc5996f86558a559c6b7abfc49" +rev = "08d90d1363b0c4722422bf0ef71ed8fd7d053f8a" # path = "../open-spaced-repetition/fsrs-rs" [workspace.dependencies] diff --git a/cargo/format/rust-toolchain.toml b/cargo/format/rust-toolchain.toml index 66a834c36..42af1fe66 100644 --- a/cargo/format/rust-toolchain.toml +++ b/cargo/format/rust-toolchain.toml @@ -1,4 +1,4 @@ [toolchain] -channel = "nightly-2023-09-02" +channel = "nightly-2025-03-20" profile = "minimal" components = ["rustfmt"] diff --git a/rslib/src/scheduler/fsrs/params.rs b/rslib/src/scheduler/fsrs/params.rs index f304c6e63..2bc3338eb 100644 --- a/rslib/src/scheduler/fsrs/params.rs +++ b/rslib/src/scheduler/fsrs/params.rs @@ -14,6 +14,7 @@ use anki_proto::stats::DeckEntry; use chrono::NaiveDate; use chrono::NaiveTime; use fsrs::CombinedProgressState; +use fsrs::ComputeParametersInput; use fsrs::FSRSItem; use fsrs::FSRSReview; use fsrs::MemoryState; @@ -107,34 +108,40 @@ impl Collection { let (progress, progress_thread) = create_progress_thread()?; let fsrs = FSRS::new(None)?; - let mut params = fsrs.compute_parameters(items.clone(), Some(progress.clone()), true)?; + let mut params = fsrs.compute_parameters(ComputeParametersInput { + train_set: items.clone(), + progress: Some(progress.clone()), + enable_short_term: true, + num_relearning_steps: Some(num_of_relearning_steps), + })?; progress_thread.join().ok(); if let Ok(fsrs) = FSRS::new(Some(current_params)) { let current_rmse = fsrs.evaluate(items.clone(), |_| true)?.rmse_bins; let optimized_fsrs = FSRS::new(Some(¶ms))?; let optimized_rmse = optimized_fsrs.evaluate(items.clone(), |_| true)?.rmse_bins; if current_rmse <= optimized_rmse { - params = current_params.to_vec(); - } - if num_of_relearning_steps > 1 { - let memory_state = MemoryState { - stability: 1.0, - difficulty: 1.0, - }; - let s_fail = optimized_fsrs - .next_states(Some(memory_state), 0.9, 2)? - .again; - let mut s_short_term = s_fail.memory; - for _ in 0..num_of_relearning_steps { - s_short_term = optimized_fsrs - .next_states(Some(s_short_term), 0.9, 0)? - .good - .memory; - } - if s_short_term.stability > memory_state.stability { - let (progress, progress_thread) = create_progress_thread()?; - params = fsrs.compute_parameters(items.clone(), Some(progress), false)?; - progress_thread.join().ok(); + if num_of_relearning_steps <= 1 { + params = current_params.to_vec(); + } else { + let current_fsrs = FSRS::new(Some(current_params))?; + let memory_state = MemoryState { + stability: 1.0, + difficulty: 1.0, + }; + + let s_fail = current_fsrs.next_states(Some(memory_state), 0.9, 2)?.again; + let mut s_short_term = s_fail.memory; + + for _ in 0..num_of_relearning_steps { + s_short_term = current_fsrs + .next_states(Some(s_short_term), 0.9, 0)? + .good + .memory; + } + + if s_short_term.stability < memory_state.stability { + params = current_params.to_vec(); + } } } } diff --git a/rslib/src/scheduler/service/mod.rs b/rslib/src/scheduler/service/mod.rs index e77dce6b3..d398ae65b 100644 --- a/rslib/src/scheduler/service/mod.rs +++ b/rslib/src/scheduler/service/mod.rs @@ -17,6 +17,7 @@ use anki_proto::scheduler::FuzzDeltaResponse; use anki_proto::scheduler::GetOptimalRetentionParametersResponse; use anki_proto::scheduler::SimulateFsrsReviewRequest; use anki_proto::scheduler::SimulateFsrsReviewResponse; +use fsrs::ComputeParametersInput; use fsrs::FSRSItem; use fsrs::FSRSReview; use fsrs::FSRS; @@ -352,11 +353,12 @@ impl crate::services::BackendSchedulerService for Backend { ) -> Result { let fsrs = FSRS::new(None)?; let fsrs_items = req.items.len() as u32; - let params = fsrs.compute_parameters( - req.items.into_iter().map(fsrs_item_proto_to_fsrs).collect(), - None, - true, - )?; + let params = fsrs.compute_parameters(ComputeParametersInput { + train_set: req.items.into_iter().map(fsrs_item_proto_to_fsrs).collect(), + progress: None, + enable_short_term: true, + num_relearning_steps: None, + })?; Ok(ComputeFsrsParamsResponse { params, fsrs_items }) } @@ -370,7 +372,12 @@ impl crate::services::BackendSchedulerService for Backend { .into_iter() .map(fsrs_item_proto_to_fsrs) .collect(); - let params = fsrs.benchmark(train_set, true); + let params = fsrs.benchmark(ComputeParametersInput { + train_set, + progress: None, + enable_short_term: true, + num_relearning_steps: None, + }); Ok(FsrsBenchmarkResponse { params }) }