Expose the ability to train weights from items (#2687)

This commit is contained in:
Damien Elmes 2023-09-28 08:28:24 +10:00 committed by GitHub
parent 05499297e0
commit 1f55ad1d44
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 0 deletions

View file

@ -47,6 +47,8 @@ service SchedulerService {
rpc RepositionDefaults(generic.Empty) returns (RepositionDefaultsResponse); rpc RepositionDefaults(generic.Empty) returns (RepositionDefaultsResponse);
rpc ComputeFsrsWeights(ComputeFsrsWeightsRequest) rpc ComputeFsrsWeights(ComputeFsrsWeightsRequest)
returns (ComputeFsrsWeightsResponse); returns (ComputeFsrsWeightsResponse);
rpc ComputeFsrsWeightsFromItems(ComputeFsrsWeightsFromItemsRequest)
returns (ComputeFsrsWeightsResponse);
rpc GetOptimalRetentionParameters(GetOptimalRetentionParametersRequest) rpc GetOptimalRetentionParameters(GetOptimalRetentionParametersRequest)
returns (GetOptimalRetentionParametersResponse); returns (GetOptimalRetentionParametersResponse);
rpc ComputeOptimalRetention(ComputeOptimalRetentionRequest) rpc ComputeOptimalRetention(ComputeOptimalRetentionRequest)
@ -339,6 +341,19 @@ message ComputeFsrsWeightsResponse {
uint32 fsrs_items = 2; uint32 fsrs_items = 2;
} }
message ComputeFsrsWeightsFromItemsRequest {
repeated FsrsItem items = 1;
}
message FsrsItem {
repeated FsrsReview reviews = 1;
}
message FsrsReview {
uint32 rating = 1;
uint32 delta_t = 2;
}
message ComputeOptimalRetentionRequest { message ComputeOptimalRetentionRequest {
repeated float weights = 1; repeated float weights = 1;
uint32 deck_size = 2; uint32 deck_size = 2;

View file

@ -15,6 +15,7 @@ from anki import (
import_export_pb2, import_export_pb2,
links_pb2, links_pb2,
notes_pb2, notes_pb2,
scheduler_pb2,
search_pb2, search_pb2,
stats_pb2, stats_pb2,
sync_pb2, sync_pb2,
@ -52,6 +53,8 @@ GetImageOcclusionNoteResponse = image_occlusion_pb2.GetImageOcclusionNoteRespons
AddonInfo = ankiweb_pb2.AddonInfo AddonInfo = ankiweb_pb2.AddonInfo
CheckForUpdateResponse = ankiweb_pb2.CheckForUpdateResponse CheckForUpdateResponse = ankiweb_pb2.CheckForUpdateResponse
MediaSyncStatus = sync_pb2.MediaSyncStatusResponse MediaSyncStatus = sync_pb2.MediaSyncStatusResponse
FsrsItem = scheduler_pb2.FsrsItem
FsrsReview = scheduler_pb2.FsrsReview
import copy import copy
import os import os
@ -1338,6 +1341,9 @@ class Collection(DeprecatedNamesMixin):
else: else:
return ComputedMemoryState(desired_retention=resp.desired_retention) return ComputedMemoryState(desired_retention=resp.desired_retention)
def compute_weights_from_items(self, items: Iterable[FsrsItem]) -> Sequence[float]:
return self._backend.compute_fsrs_weights_from_items(items).weights
# Timeboxing # Timeboxing
########################################################################## ##########################################################################
# fixme: there doesn't seem to be a good reason why this code is in main.py # fixme: there doesn't seem to be a good reason why this code is in main.py

View file

@ -4,6 +4,7 @@ use std::iter;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use anki_proto::scheduler::ComputeFsrsWeightsFromItemsRequest;
use anki_proto::scheduler::ComputeFsrsWeightsResponse; use anki_proto::scheduler::ComputeFsrsWeightsResponse;
use fsrs::FSRSItem; use fsrs::FSRSItem;
use fsrs::FSRSReview; use fsrs::FSRSReview;
@ -53,6 +54,22 @@ impl Collection {
}) })
} }
pub fn compute_weights_from_items(
&mut self,
req: ComputeFsrsWeightsFromItemsRequest,
) -> Result<ComputeFsrsWeightsResponse> {
let fsrs = FSRS::new(None)?;
let fsrs_items = req.items.len() as u32;
let weights = fsrs.compute_weights(
req.items.into_iter().map(fsrs_item_proto_to_fsrs).collect(),
None,
)?;
Ok(ComputeFsrsWeightsResponse {
weights,
fsrs_items,
})
}
pub(crate) fn revlog_for_srs( pub(crate) fn revlog_for_srs(
&mut self, &mut self,
search: impl TryIntoSearch, search: impl TryIntoSearch,
@ -209,6 +226,23 @@ impl RevlogEntry {
} }
} }
fn fsrs_item_proto_to_fsrs(item: anki_proto::scheduler::FsrsItem) -> FSRSItem {
FSRSItem {
reviews: item
.reviews
.into_iter()
.map(fsrs_review_proto_to_fsrs)
.collect(),
}
}
fn fsrs_review_proto_to_fsrs(review: anki_proto::scheduler::FsrsReview) -> FSRSReview {
FSRSReview {
delta_t: review.delta_t,
rating: review.rating,
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View file

@ -250,6 +250,13 @@ impl crate::services::SchedulerService for Collection {
self.compute_weights(&input.search) self.compute_weights(&input.search)
} }
fn compute_fsrs_weights_from_items(
&mut self,
input: scheduler::ComputeFsrsWeightsFromItemsRequest,
) -> Result<scheduler::ComputeFsrsWeightsResponse> {
self.compute_weights_from_items(input)
}
fn compute_optimal_retention( fn compute_optimal_retention(
&mut self, &mut self,
input: ComputeOptimalRetentionRequest, input: ComputeOptimalRetentionRequest,