Shift weight calculation to backend so it can be run in parallel

This commit is contained in:
Damien Elmes 2023-09-28 09:10:54 +10:00
parent 1f55ad1d44
commit b8390d096e
5 changed files with 51 additions and 48 deletions

View file

@ -47,8 +47,6 @@ service SchedulerService {
rpc RepositionDefaults(generic.Empty) returns (RepositionDefaultsResponse); rpc RepositionDefaults(generic.Empty) returns (RepositionDefaultsResponse);
rpc ComputeFsrsWeights(ComputeFsrsWeightsRequest) rpc ComputeFsrsWeights(ComputeFsrsWeightsRequest)
returns (ComputeFsrsWeightsResponse); returns (ComputeFsrsWeightsResponse);
rpc ComputeFsrsWeightsFromItems(ComputeFsrsWeightsFromItemsRequest)
returns (ComputeFsrsWeightsResponse);
rpc GetOptimalRetentionParameters(GetOptimalRetentionParametersRequest) rpc GetOptimalRetentionParameters(GetOptimalRetentionParametersRequest)
returns (GetOptimalRetentionParametersResponse); returns (GetOptimalRetentionParametersResponse);
rpc ComputeOptimalRetention(ComputeOptimalRetentionRequest) rpc ComputeOptimalRetention(ComputeOptimalRetentionRequest)
@ -59,7 +57,10 @@ service SchedulerService {
// Implicitly includes any of the above methods that are not listed in the // Implicitly includes any of the above methods that are not listed in the
// backend service. // backend service.
service BackendSchedulerService {} service BackendSchedulerService {
rpc ComputeFsrsWeightsFromItems(ComputeFsrsWeightsFromItemsRequest)
returns (ComputeFsrsWeightsResponse);
}
message SchedulingState { message SchedulingState {
message New { message New {

View file

@ -5,7 +5,7 @@ from __future__ import annotations
import sys import sys
import traceback import traceback
from typing import Any, Sequence from typing import TYPE_CHECKING, Any, Iterable, Sequence
from weakref import ref from weakref import ref
from markdown import markdown from markdown import markdown
@ -18,6 +18,9 @@ from anki.dbproxy import Row as DBRow
from anki.dbproxy import ValueForDB from anki.dbproxy import ValueForDB
from anki.utils import from_json_bytes, to_json_bytes from anki.utils import from_json_bytes, to_json_bytes
if TYPE_CHECKING:
from anki.collection import FsrsItem
from .errors import ( from .errors import (
BackendError, BackendError,
BackendIOError, BackendIOError,
@ -140,6 +143,9 @@ class RustBackend(RustBackendGenerated):
) )
return self.format_timespan(seconds=seconds, context=context) return self.format_timespan(seconds=seconds, context=context)
def compute_weights_from_items(self, items: Iterable[FsrsItem]) -> Sequence[float]:
return self.compute_fsrs_weights_from_items(items).weights
def _run_command(self, service: int, method: int, input: bytes) -> bytes: def _run_command(self, service: int, method: int, input: bytes) -> bytes:
try: try:
return self._backend.command(service, method, input) return self._backend.command(service, method, input)

View file

@ -1341,9 +1341,6 @@ class Collection(DeprecatedNamesMixin):
else: else:
return ComputedMemoryState(desired_retention=resp.desired_retention) return ComputedMemoryState(desired_retention=resp.desired_retention)
def compute_weights_from_items(self, items: Iterable[FsrsItem]) -> Sequence[float]:
return self._backend.compute_fsrs_weights_from_items(items).weights
# Timeboxing # Timeboxing
########################################################################## ##########################################################################
# fixme: there doesn't seem to be a good reason why this code is in main.py # fixme: there doesn't seem to be a good reason why this code is in main.py

View file

@ -4,7 +4,6 @@ use std::iter;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use anki_proto::scheduler::ComputeFsrsWeightsFromItemsRequest;
use anki_proto::scheduler::ComputeFsrsWeightsResponse; use anki_proto::scheduler::ComputeFsrsWeightsResponse;
use fsrs::FSRSItem; use fsrs::FSRSItem;
use fsrs::FSRSReview; use fsrs::FSRSReview;
@ -54,22 +53,6 @@ impl Collection {
}) })
} }
pub fn compute_weights_from_items(
&mut self,
req: ComputeFsrsWeightsFromItemsRequest,
) -> Result<ComputeFsrsWeightsResponse> {
let fsrs = FSRS::new(None)?;
let fsrs_items = req.items.len() as u32;
let weights = fsrs.compute_weights(
req.items.into_iter().map(fsrs_item_proto_to_fsrs).collect(),
None,
)?;
Ok(ComputeFsrsWeightsResponse {
weights,
fsrs_items,
})
}
pub(crate) fn revlog_for_srs( pub(crate) fn revlog_for_srs(
&mut self, &mut self,
search: impl TryIntoSearch, search: impl TryIntoSearch,
@ -226,23 +209,6 @@ impl RevlogEntry {
} }
} }
fn fsrs_item_proto_to_fsrs(item: anki_proto::scheduler::FsrsItem) -> FSRSItem {
FSRSItem {
reviews: item
.reviews
.into_iter()
.map(fsrs_review_proto_to_fsrs)
.collect(),
}
}
fn fsrs_review_proto_to_fsrs(review: anki_proto::scheduler::FsrsReview) -> FSRSReview {
FSRSReview {
delta_t: review.delta_t,
rating: review.rating,
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View file

@ -7,11 +7,16 @@ mod states;
use anki_proto::cards; use anki_proto::cards;
use anki_proto::generic; use anki_proto::generic;
use anki_proto::scheduler; use anki_proto::scheduler;
use anki_proto::scheduler::ComputeFsrsWeightsResponse;
use anki_proto::scheduler::ComputeMemoryStateResponse; use anki_proto::scheduler::ComputeMemoryStateResponse;
use anki_proto::scheduler::ComputeOptimalRetentionRequest; use anki_proto::scheduler::ComputeOptimalRetentionRequest;
use anki_proto::scheduler::ComputeOptimalRetentionResponse; use anki_proto::scheduler::ComputeOptimalRetentionResponse;
use anki_proto::scheduler::GetOptimalRetentionParametersResponse; use anki_proto::scheduler::GetOptimalRetentionParametersResponse;
use fsrs::FSRSItem;
use fsrs::FSRSReview;
use fsrs::FSRS;
use crate::backend::Backend;
use crate::prelude::*; use crate::prelude::*;
use crate::scheduler::new::NewCardDueOrder; use crate::scheduler::new::NewCardDueOrder;
use crate::scheduler::states::CardState; use crate::scheduler::states::CardState;
@ -250,13 +255,6 @@ impl crate::services::SchedulerService for Collection {
self.compute_weights(&input.search) self.compute_weights(&input.search)
} }
fn compute_fsrs_weights_from_items(
&mut self,
input: scheduler::ComputeFsrsWeightsFromItemsRequest,
) -> Result<scheduler::ComputeFsrsWeightsResponse> {
self.compute_weights_from_items(input)
}
fn compute_optimal_retention( fn compute_optimal_retention(
&mut self, &mut self,
input: ComputeOptimalRetentionRequest, input: ComputeOptimalRetentionRequest,
@ -291,3 +289,38 @@ impl crate::services::SchedulerService for Collection {
self.compute_memory_state(input.into()) self.compute_memory_state(input.into())
} }
} }
impl crate::services::BackendSchedulerService for Backend {
fn compute_fsrs_weights_from_items(
&self,
req: scheduler::ComputeFsrsWeightsFromItemsRequest,
) -> Result<scheduler::ComputeFsrsWeightsResponse> {
let fsrs = FSRS::new(None)?;
let fsrs_items = req.items.len() as u32;
let weights = fsrs.compute_weights(
req.items.into_iter().map(fsrs_item_proto_to_fsrs).collect(),
None,
)?;
Ok(ComputeFsrsWeightsResponse {
weights,
fsrs_items,
})
}
}
fn fsrs_item_proto_to_fsrs(item: anki_proto::scheduler::FsrsItem) -> FSRSItem {
FSRSItem {
reviews: item
.reviews
.into_iter()
.map(fsrs_review_proto_to_fsrs)
.collect(),
}
}
fn fsrs_review_proto_to_fsrs(review: anki_proto::scheduler::FsrsReview) -> FSRSReview {
FSRSReview {
delta_t: review.delta_t,
rating: review.rating,
}
}