Feat/simplified relearning steps logic with updated FSRS training API (#3867)

* Feat/simplified relearning steps logic with updated FSRS training API

* Update params.rs

* use ComputeParametersInput

* update fsrs-rs dependency

* update cargo/format/rust-toolchain
This commit is contained in:
Jarrett Ye 2025-03-20 15:04:38 +08:00 committed by GitHub
parent 5d7f6b25c0
commit d52889f45c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 45 additions and 31 deletions

2
Cargo.lock generated
View file

@ -2099,7 +2099,7 @@ dependencies = [
[[package]] [[package]]
name = "fsrs" name = "fsrs"
version = "3.0.0" version = "3.0.0"
source = "git+https://github.com/open-spaced-repetition/fsrs-rs.git?rev=22f8e453c120f5bc5996f86558a559c6b7abfc49#22f8e453c120f5bc5996f86558a559c6b7abfc49" source = "git+https://github.com/open-spaced-repetition/fsrs-rs.git?rev=08d90d1363b0c4722422bf0ef71ed8fd7d053f8a#08d90d1363b0c4722422bf0ef71ed8fd7d053f8a"
dependencies = [ dependencies = [
"burn", "burn",
"itertools 0.12.1", "itertools 0.12.1",

View file

@ -37,7 +37,7 @@ rev = "184b2ca50ed39ca43da13f0b830a463861adb9ca"
[workspace.dependencies.fsrs] [workspace.dependencies.fsrs]
# version = "=2.0.3" # version = "=2.0.3"
git = "https://github.com/open-spaced-repetition/fsrs-rs.git" git = "https://github.com/open-spaced-repetition/fsrs-rs.git"
rev = "22f8e453c120f5bc5996f86558a559c6b7abfc49" rev = "08d90d1363b0c4722422bf0ef71ed8fd7d053f8a"
# path = "../open-spaced-repetition/fsrs-rs" # path = "../open-spaced-repetition/fsrs-rs"
[workspace.dependencies] [workspace.dependencies]

View file

@ -1,4 +1,4 @@
[toolchain] [toolchain]
channel = "nightly-2023-09-02" channel = "nightly-2025-03-20"
profile = "minimal" profile = "minimal"
components = ["rustfmt"] components = ["rustfmt"]

View file

@ -14,6 +14,7 @@ use anki_proto::stats::DeckEntry;
use chrono::NaiveDate; use chrono::NaiveDate;
use chrono::NaiveTime; use chrono::NaiveTime;
use fsrs::CombinedProgressState; use fsrs::CombinedProgressState;
use fsrs::ComputeParametersInput;
use fsrs::FSRSItem; use fsrs::FSRSItem;
use fsrs::FSRSReview; use fsrs::FSRSReview;
use fsrs::MemoryState; use fsrs::MemoryState;
@ -107,34 +108,40 @@ impl Collection {
let (progress, progress_thread) = create_progress_thread()?; let (progress, progress_thread) = create_progress_thread()?;
let fsrs = FSRS::new(None)?; let fsrs = FSRS::new(None)?;
let mut params = fsrs.compute_parameters(items.clone(), Some(progress.clone()), true)?; let mut params = fsrs.compute_parameters(ComputeParametersInput {
train_set: items.clone(),
progress: Some(progress.clone()),
enable_short_term: true,
num_relearning_steps: Some(num_of_relearning_steps),
})?;
progress_thread.join().ok(); progress_thread.join().ok();
if let Ok(fsrs) = FSRS::new(Some(current_params)) { if let Ok(fsrs) = FSRS::new(Some(current_params)) {
let current_rmse = fsrs.evaluate(items.clone(), |_| true)?.rmse_bins; let current_rmse = fsrs.evaluate(items.clone(), |_| true)?.rmse_bins;
let optimized_fsrs = FSRS::new(Some(&params))?; let optimized_fsrs = FSRS::new(Some(&params))?;
let optimized_rmse = optimized_fsrs.evaluate(items.clone(), |_| true)?.rmse_bins; let optimized_rmse = optimized_fsrs.evaluate(items.clone(), |_| true)?.rmse_bins;
if current_rmse <= optimized_rmse { if current_rmse <= optimized_rmse {
params = current_params.to_vec(); if num_of_relearning_steps <= 1 {
} params = current_params.to_vec();
if num_of_relearning_steps > 1 { } else {
let memory_state = MemoryState { let current_fsrs = FSRS::new(Some(current_params))?;
stability: 1.0, let memory_state = MemoryState {
difficulty: 1.0, stability: 1.0,
}; difficulty: 1.0,
let s_fail = optimized_fsrs };
.next_states(Some(memory_state), 0.9, 2)?
.again; let s_fail = current_fsrs.next_states(Some(memory_state), 0.9, 2)?.again;
let mut s_short_term = s_fail.memory; let mut s_short_term = s_fail.memory;
for _ in 0..num_of_relearning_steps {
s_short_term = optimized_fsrs for _ in 0..num_of_relearning_steps {
.next_states(Some(s_short_term), 0.9, 0)? s_short_term = current_fsrs
.good .next_states(Some(s_short_term), 0.9, 0)?
.memory; .good
} .memory;
if s_short_term.stability > memory_state.stability { }
let (progress, progress_thread) = create_progress_thread()?;
params = fsrs.compute_parameters(items.clone(), Some(progress), false)?; if s_short_term.stability < memory_state.stability {
progress_thread.join().ok(); params = current_params.to_vec();
}
} }
} }
} }

View file

@ -17,6 +17,7 @@ use anki_proto::scheduler::FuzzDeltaResponse;
use anki_proto::scheduler::GetOptimalRetentionParametersResponse; use anki_proto::scheduler::GetOptimalRetentionParametersResponse;
use anki_proto::scheduler::SimulateFsrsReviewRequest; use anki_proto::scheduler::SimulateFsrsReviewRequest;
use anki_proto::scheduler::SimulateFsrsReviewResponse; use anki_proto::scheduler::SimulateFsrsReviewResponse;
use fsrs::ComputeParametersInput;
use fsrs::FSRSItem; use fsrs::FSRSItem;
use fsrs::FSRSReview; use fsrs::FSRSReview;
use fsrs::FSRS; use fsrs::FSRS;
@ -352,11 +353,12 @@ impl crate::services::BackendSchedulerService for Backend {
) -> Result<scheduler::ComputeFsrsParamsResponse> { ) -> Result<scheduler::ComputeFsrsParamsResponse> {
let fsrs = FSRS::new(None)?; let fsrs = FSRS::new(None)?;
let fsrs_items = req.items.len() as u32; let fsrs_items = req.items.len() as u32;
let params = fsrs.compute_parameters( let params = fsrs.compute_parameters(ComputeParametersInput {
req.items.into_iter().map(fsrs_item_proto_to_fsrs).collect(), train_set: req.items.into_iter().map(fsrs_item_proto_to_fsrs).collect(),
None, progress: None,
true, enable_short_term: true,
)?; num_relearning_steps: None,
})?;
Ok(ComputeFsrsParamsResponse { params, fsrs_items }) Ok(ComputeFsrsParamsResponse { params, fsrs_items })
} }
@ -370,7 +372,12 @@ impl crate::services::BackendSchedulerService for Backend {
.into_iter() .into_iter()
.map(fsrs_item_proto_to_fsrs) .map(fsrs_item_proto_to_fsrs)
.collect(); .collect();
let params = fsrs.benchmark(train_set, true); let params = fsrs.benchmark(ComputeParametersInput {
train_set,
progress: None,
enable_short_term: true,
num_relearning_steps: None,
});
Ok(FsrsBenchmarkResponse { params }) Ok(FsrsBenchmarkResponse { params })
} }