From dda13248728c61604d446b590b79eb6cda56434f Mon Sep 17 00:00:00 2001 From: Daniel Pechersky Date: Mon, 15 Sep 2025 19:02:47 +0700 Subject: [PATCH] Use fsrs batched function --- Cargo.lock | 5 +- Cargo.toml | 59 ++++++++++++++++++++---- rslib/src/scheduler/fsrs/memory_state.rs | 28 +++++------ 3 files changed, 62 insertions(+), 30 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fe88eb3ab..962c70b53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2233,9 +2233,8 @@ dependencies = [ [[package]] name = "fsrs" -version = "5.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04954cc67c3c11ee342a2ee1f5222bf76d73f7772df08d37dc9a6cdd73c467eb" +version = "5.2.0" +source = "git+https://github.com/open-spaced-repetition/fsrs-rs.git#1e271981367454468391f1c686af03a0aa7aab3c" dependencies = [ "burn", "itertools 0.14.0", diff --git a/Cargo.toml b/Cargo.toml index 2e9489cb8..186335001 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,8 +33,8 @@ git = "https://github.com/ankitects/linkcheck.git" rev = "184b2ca50ed39ca43da13f0b830a463861adb9ca" [workspace.dependencies.fsrs] -version = "5.1.0" -# git = "https://github.com/open-spaced-repetition/fsrs-rs.git" +# version = "5.1.0" +git = "https://github.com/open-spaced-repetition/fsrs-rs.git" # path = "../open-spaced-repetition/fsrs-rs" [workspace.dependencies] @@ -63,7 +63,10 @@ bitflags = "2.9.1" blake3 = "1.8.2" bytes = "1.10.1" camino = "1.1.10" -chrono = { version = "0.4.41", default-features = false, features = ["std", "clock"] } +chrono = { version = "0.4.41", default-features = false, features = [ + "std", + "clock", +] } clap = { version = "4.5.40", features = ["derive"] } coarsetime = "0.1.36" convert_case = "0.8.0" @@ -107,12 +110,26 @@ prost-build = "0.13" prost-reflect = "0.14.7" prost-types = "0.13" pulldown-cmark = "0.13.0" -pyo3 = { version = "0.25.1", features = ["extension-module", "abi3", "abi3-py39"] } +pyo3 = { version = "0.25.1", features = [ + "extension-module", + "abi3", + "abi3-py39", +] } rand = "0.9.1" rayon = "1.10.0" regex = "1.11.1" -reqwest = { version = "0.12.20", default-features = false, features = ["json", "socks", "stream", "multipart"] } -rusqlite = { version = "0.36.0", features = ["trace", "functions", "collation", "bundled"] } +reqwest = { version = "0.12.20", default-features = false, features = [ + "json", + "socks", + "stream", + "multipart", +] } +rusqlite = { version = "0.36.0", features = [ + "trace", + "functions", + "collation", + "bundled", +] } rustls-pemfile = "2.2.0" scopeguard = "1.2.0" serde = { version = "1.0.219", features = ["derive"] } @@ -128,10 +145,18 @@ syn = { version = "2.0.103", features = ["parsing", "printing"] } tar = "0.4.44" tempfile = "3.20.0" termcolor = "1.4.1" -tokio = { version = "1.45", features = ["fs", "rt-multi-thread", "macros", "signal"] } +tokio = { version = "1.45", features = [ + "fs", + "rt-multi-thread", + "macros", + "signal", +] } tokio-util = { version = "0.7.15", features = ["io"] } tower-http = { version = "0.6.6", features = ["trace"] } -tracing = { version = "0.1.41", features = ["max_level_trace", "release_max_level_debug"] } +tracing = { version = "0.1.41", features = [ + "max_level_trace", + "release_max_level_debug", +] } tracing-appender = "0.2.3" tracing-subscriber = { version = "0.3.20", features = ["fmt", "env-filter"] } unic-langid = { version = "0.9.6", features = ["macros"] } @@ -141,10 +166,24 @@ walkdir = "2.5.0" which = "8.0.0" widestring = "1.1.0" winapi = { version = "0.3", features = ["wincon", "winreg"] } -windows = { version = "0.61.3", features = ["Media_SpeechSynthesis", "Media_Core", "Foundation_Collections", "Storage_Streams", "Win32_System_Console", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_Foundation", "Win32_UI_Shell", "Wdk_System_SystemServices"] } +windows = { version = "0.61.3", features = [ + "Media_SpeechSynthesis", + "Media_Core", + "Foundation_Collections", + "Storage_Streams", + "Win32_System_Console", + "Win32_System_Registry", + "Win32_System_SystemInformation", + "Win32_Foundation", + "Win32_UI_Shell", + "Wdk_System_SystemServices", +] } wiremock = "0.6.3" xz2 = "0.1.7" -zip = { version = "4.1.0", default-features = false, features = ["deflate", "time"] } +zip = { version = "4.1.0", default-features = false, features = [ + "deflate", + "time", +] } zstd = { version = "0.13.3", features = ["zstdmt"] } # Apply mild optimizations to our dependencies in dev mode, which among other things diff --git a/rslib/src/scheduler/fsrs/memory_state.rs b/rslib/src/scheduler/fsrs/memory_state.rs index 0016488c7..6fc2913c1 100644 --- a/rslib/src/scheduler/fsrs/memory_state.rs +++ b/rslib/src/scheduler/fsrs/memory_state.rs @@ -10,8 +10,6 @@ use fsrs::FSRS; use fsrs::FSRS5_DEFAULT_DECAY; use fsrs::FSRS6_DEFAULT_DECAY; use itertools::Itertools; -use rayon::iter::IntoParallelRefMutIterator as _; -use rayon::iter::ParallelIterator as _; use super::params::ignore_revlogs_before_ms_from_config; use super::rescheduler::Rescheduler; @@ -113,7 +111,9 @@ impl Collection { }; let preset_desired_retention = req.preset_desired_retention; - let mut to_update_memory_state = Vec::new(); + let mut to_update = Vec::new(); + let mut fsrs_items = Vec::new(); + let mut starting_states = Vec::new(); for (idx, (card_id, item)) in items.into_iter().enumerate() { progress.update(true, |state| state.current_cards = idx as u32 + 1)?; let mut card = self.storage.get_card(card_id)?.or_not_found(card_id)?; @@ -130,7 +130,9 @@ impl Collection { card.desired_retention = Some(desired_retention); card.decay = decay; if let Some(item) = item { - to_update_memory_state.push((card, original, item)); + to_update.push((card, original)); + fsrs_items.push(item.item); + starting_states.push(item.starting_state); } else { // clear memory states if item is None card.memory_state = None; @@ -138,14 +140,11 @@ impl Collection { } } - to_update_memory_state.par_iter_mut().try_for_each_with( - fsrs.clone(), - |fsrs, (card, _, item)| { - card.set_memory_state(fsrs, Some(item.clone()), historical_retention.unwrap()) - }, - )?; + let memory_states = fsrs.memory_state_batch(fsrs_items, starting_states)?; + + for ((mut card, original), memory_state) in to_update.into_iter().zip(memory_states) { + card.memory_state = Some(memory_state.into()); - for (mut card, original, _) in to_update_memory_state { 'reschedule_card: { // if rescheduling let Some(reviews) = &last_revlog_info else { @@ -159,11 +158,6 @@ impl Collection { let Some(last_review) = &last_info.last_reviewed_at else { break 'reschedule_card; }; - - // and the card's not new - let Some(state) = &card.memory_state else { - break 'reschedule_card; - }; // or in (re)learning if card.ctype != CardType::Review { break 'reschedule_card; @@ -177,7 +171,7 @@ impl Collection { let days_elapsed = timing.next_day_at.elapsed_days_since(*last_review) as i32; let original_interval = card.interval; let interval = fsrs.next_interval( - Some(state.stability), + Some(memory_state.stability), card.desired_retention .expect("We set desired retention above"), 0,