diff --git a/Cargo.lock b/Cargo.lock index a18f94d36..5c8d338c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1792,9 +1792,9 @@ dependencies = [ [[package]] name = "fsrs" -version = "0.4.4" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5210280b424b4b47187c7af3935fcd13aca73afec897675b860b0c6e133bab" +checksum = "7c7e6a1986cc2b7a64445d84e2c453ecd8d95dcf90b797205c54573697e10b17" dependencies = [ "burn", "itertools 0.12.1", diff --git a/Cargo.toml b/Cargo.toml index 535f397ef..34277b4e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,10 +35,10 @@ git = "https://github.com/ankitects/linkcheck.git" rev = "184b2ca50ed39ca43da13f0b830a463861adb9ca" [workspace.dependencies.fsrs] -version = "0.4.4" +version = "0.5.0" # git = "https://github.com/open-spaced-repetition/fsrs-rs.git" # rev = "58ca25ed2bc4bb1dc376208bbcaed7f5a501b941" -# path = "../../../fsrs-rs" +# path = "../open-spaced-repetition/fsrs-rs" [workspace.dependencies] # local diff --git a/cargo/licenses.json b/cargo/licenses.json index 6e0921918..3efaf806f 100644 --- a/cargo/licenses.json +++ b/cargo/licenses.json @@ -1198,7 +1198,7 @@ }, { "name": "fsrs", - "version": "0.4.4", + "version": "0.5.0", "authors": "Open Spaced Repetition", "repository": "https://github.com/open-spaced-repetition/fsrs-rs", "license": "BSD-3-Clause", diff --git a/proto/anki/scheduler.proto b/proto/anki/scheduler.proto index 90a0a4f3b..0611d41bb 100644 --- a/proto/anki/scheduler.proto +++ b/proto/anki/scheduler.proto @@ -63,6 +63,8 @@ service SchedulerService { service BackendSchedulerService { rpc ComputeFsrsWeightsFromItems(ComputeFsrsWeightsFromItemsRequest) returns (ComputeFsrsWeightsResponse); + // Generates parameters used for FSRS's scheduler benchmarks. + rpc FsrsBenchmark(FsrsBenchmarkRequest) returns (FsrsBenchmarkResponse); } message SchedulingState { @@ -351,6 +353,15 @@ message ComputeFsrsWeightsFromItemsRequest { repeated FsrsItem items = 1; } +message FsrsBenchmarkRequest { + repeated FsrsItem train_set = 1; + repeated FsrsItem test_set = 2; +} + +message FsrsBenchmarkResponse { + repeated float weights = 1; +} + message FsrsItem { repeated FsrsReview reviews = 1; } diff --git a/pylib/anki/_backend.py b/pylib/anki/_backend.py index 2dfeb3a90..7a2ab571a 100644 --- a/pylib/anki/_backend.py +++ b/pylib/anki/_backend.py @@ -151,6 +151,11 @@ class RustBackend(RustBackendGenerated): def compute_weights_from_items(self, items: Iterable[FsrsItem]) -> Sequence[float]: return self.compute_fsrs_weights_from_items(items).weights + def benchmark( + self, train_set: Iterable[FsrsItem], test_set: Iterable[FsrsItem] + ) -> Sequence[float]: + return self.fsrs_benchmark(train_set=train_set, test_set=test_set) + def _run_command(self, service: int, method: int, input: bytes) -> bytes: start = time.time() try: diff --git a/rslib/src/scheduler/service/mod.rs b/rslib/src/scheduler/service/mod.rs index 7a2a14dad..e40f49d32 100644 --- a/rslib/src/scheduler/service/mod.rs +++ b/rslib/src/scheduler/service/mod.rs @@ -11,6 +11,7 @@ use anki_proto::scheduler::ComputeFsrsWeightsResponse; use anki_proto::scheduler::ComputeMemoryStateResponse; use anki_proto::scheduler::ComputeOptimalRetentionRequest; use anki_proto::scheduler::ComputeOptimalRetentionResponse; +use anki_proto::scheduler::FsrsBenchmarkResponse; use anki_proto::scheduler::FuzzDeltaRequest; use anki_proto::scheduler::FuzzDeltaResponse; use anki_proto::scheduler::GetOptimalRetentionParametersResponse; @@ -325,6 +326,25 @@ impl crate::services::BackendSchedulerService for Backend { fsrs_items, }) } + + fn fsrs_benchmark( + &self, + req: scheduler::FsrsBenchmarkRequest, + ) -> Result { + let fsrs = FSRS::new(None)?; + let train_set = req + .train_set + .into_iter() + .map(fsrs_item_proto_to_fsrs) + .collect(); + let test_set = req + .test_set + .into_iter() + .map(fsrs_item_proto_to_fsrs) + .collect(); + let weights = fsrs.benchmark(train_set, test_set); + Ok(FsrsBenchmarkResponse { weights }) + } } fn fsrs_item_proto_to_fsrs(item: anki_proto::scheduler::FsrsItem) -> FSRSItem {