Expose the ability to train weights from items (#2687)

This commit is contained in:
Damien Elmes 2023-09-28 08:28:24 +10:00 committed by GitHub
parent 05499297e0
commit 1f55ad1d44
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 0 deletions

View file

@ -47,6 +47,8 @@ service SchedulerService {
rpc RepositionDefaults(generic.Empty) returns (RepositionDefaultsResponse);
rpc ComputeFsrsWeights(ComputeFsrsWeightsRequest)
returns (ComputeFsrsWeightsResponse);
rpc ComputeFsrsWeightsFromItems(ComputeFsrsWeightsFromItemsRequest)
returns (ComputeFsrsWeightsResponse);
rpc GetOptimalRetentionParameters(GetOptimalRetentionParametersRequest)
returns (GetOptimalRetentionParametersResponse);
rpc ComputeOptimalRetention(ComputeOptimalRetentionRequest)
@ -339,6 +341,19 @@ message ComputeFsrsWeightsResponse {
uint32 fsrs_items = 2;
}
message ComputeFsrsWeightsFromItemsRequest {
repeated FsrsItem items = 1;
}
message FsrsItem {
repeated FsrsReview reviews = 1;
}
message FsrsReview {
uint32 rating = 1;
uint32 delta_t = 2;
}
message ComputeOptimalRetentionRequest {
repeated float weights = 1;
uint32 deck_size = 2;

View file

@ -15,6 +15,7 @@ from anki import (
import_export_pb2,
links_pb2,
notes_pb2,
scheduler_pb2,
search_pb2,
stats_pb2,
sync_pb2,
@ -52,6 +53,8 @@ GetImageOcclusionNoteResponse = image_occlusion_pb2.GetImageOcclusionNoteRespons
AddonInfo = ankiweb_pb2.AddonInfo
CheckForUpdateResponse = ankiweb_pb2.CheckForUpdateResponse
MediaSyncStatus = sync_pb2.MediaSyncStatusResponse
FsrsItem = scheduler_pb2.FsrsItem
FsrsReview = scheduler_pb2.FsrsReview
import copy
import os
@ -1338,6 +1341,9 @@ class Collection(DeprecatedNamesMixin):
else:
return ComputedMemoryState(desired_retention=resp.desired_retention)
def compute_weights_from_items(self, items: Iterable[FsrsItem]) -> Sequence[float]:
return self._backend.compute_fsrs_weights_from_items(items).weights
# Timeboxing
##########################################################################
# fixme: there doesn't seem to be a good reason why this code is in main.py

View file

@ -4,6 +4,7 @@ use std::iter;
use std::thread;
use std::time::Duration;
use anki_proto::scheduler::ComputeFsrsWeightsFromItemsRequest;
use anki_proto::scheduler::ComputeFsrsWeightsResponse;
use fsrs::FSRSItem;
use fsrs::FSRSReview;
@ -53,6 +54,22 @@ impl Collection {
})
}
pub fn compute_weights_from_items(
&mut self,
req: ComputeFsrsWeightsFromItemsRequest,
) -> Result<ComputeFsrsWeightsResponse> {
let fsrs = FSRS::new(None)?;
let fsrs_items = req.items.len() as u32;
let weights = fsrs.compute_weights(
req.items.into_iter().map(fsrs_item_proto_to_fsrs).collect(),
None,
)?;
Ok(ComputeFsrsWeightsResponse {
weights,
fsrs_items,
})
}
pub(crate) fn revlog_for_srs(
&mut self,
search: impl TryIntoSearch,
@ -209,6 +226,23 @@ impl RevlogEntry {
}
}
fn fsrs_item_proto_to_fsrs(item: anki_proto::scheduler::FsrsItem) -> FSRSItem {
FSRSItem {
reviews: item
.reviews
.into_iter()
.map(fsrs_review_proto_to_fsrs)
.collect(),
}
}
fn fsrs_review_proto_to_fsrs(review: anki_proto::scheduler::FsrsReview) -> FSRSReview {
FSRSReview {
delta_t: review.delta_t,
rating: review.rating,
}
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -250,6 +250,13 @@ impl crate::services::SchedulerService for Collection {
self.compute_weights(&input.search)
}
fn compute_fsrs_weights_from_items(
&mut self,
input: scheduler::ComputeFsrsWeightsFromItemsRequest,
) -> Result<scheduler::ComputeFsrsWeightsResponse> {
self.compute_weights_from_items(input)
}
fn compute_optimal_retention(
&mut self,
input: ComputeOptimalRetentionRequest,