Feat/evaluate FSRS with time series split (#3962)

This commit is contained in:
Jarrett Ye 2025-06-03 16:26:33 +08:00 committed by GitHub
parent 37984233cc
commit 2de0c79ba5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 18 additions and 17 deletions

View file

@ -435,9 +435,9 @@ message GetOptimalRetentionParametersResponse {
} }
message EvaluateParamsRequest { message EvaluateParamsRequest {
repeated float params = 1; string search = 1;
string search = 2; int64 ignore_revlogs_before_ms = 2;
int64 ignore_revlogs_before_ms = 3; uint32 num_of_relearning_steps = 3;
} }
message EvaluateParamsResponse { message EvaluateParamsResponse {

View file

@ -115,15 +115,14 @@ impl Collection {
num_relearning_steps: Some(num_of_relearning_steps), num_relearning_steps: Some(num_of_relearning_steps),
})?; })?;
progress_thread.join().ok(); progress_thread.join().ok();
if let Ok(fsrs) = FSRS::new(Some(current_params)) { if let Ok(current_fsrs) = FSRS::new(Some(current_params)) {
let current_log_loss = fsrs.evaluate(items.clone(), |_| true)?.log_loss; let current_log_loss = current_fsrs.evaluate(items.clone(), |_| true)?.log_loss;
let optimized_fsrs = FSRS::new(Some(&params))?; let optimized_fsrs = FSRS::new(Some(&params))?;
let optimized_log_loss = optimized_fsrs.evaluate(items.clone(), |_| true)?.log_loss; let optimized_log_loss = optimized_fsrs.evaluate(items.clone(), |_| true)?.log_loss;
if current_log_loss <= optimized_log_loss { if current_log_loss <= optimized_log_loss {
if num_of_relearning_steps <= 1 { if num_of_relearning_steps <= 1 {
params = current_params.to_vec(); params = current_params.to_vec();
} else { } else {
let current_fsrs = FSRS::new(Some(current_params))?;
let memory_state = MemoryState { let memory_state = MemoryState {
stability: 1.0, stability: 1.0,
difficulty: 1.0, difficulty: 1.0,
@ -218,22 +217,24 @@ impl Collection {
pub fn evaluate_params( pub fn evaluate_params(
&mut self, &mut self,
params: &Params,
search: &str, search: &str,
ignore_revlogs_before: TimestampMillis, ignore_revlogs_before: TimestampMillis,
num_of_relearning_steps: usize,
) -> Result<ModelEvaluation> { ) -> Result<ModelEvaluation> {
let timing = self.timing_today()?; let timing = self.timing_today()?;
let mut anki_progress = self.new_progress_handler::<ComputeParamsProgress>(); let revlogs = self.revlog_for_srs(search)?;
let guard = self.search_cards_into_table(search, SortMode::NoOrder)?;
let revlogs: Vec<RevlogEntry> = guard
.col
.storage
.get_revlog_entries_for_searched_cards_in_card_order()?;
let (items, review_count) = let (items, review_count) =
fsrs_items_for_training(revlogs, timing.next_day_at, ignore_revlogs_before); fsrs_items_for_training(revlogs, timing.next_day_at, ignore_revlogs_before);
let mut anki_progress = self.new_progress_handler::<ComputeParamsProgress>();
anki_progress.state.reviews = review_count as u32; anki_progress.state.reviews = review_count as u32;
let fsrs = FSRS::new(Some(params))?; let fsrs = FSRS::new(None)?;
Ok(fsrs.evaluate(items, |ip| { let input = ComputeParametersInput {
train_set: items.clone(),
progress: None,
enable_short_term: true,
num_relearning_steps: Some(num_of_relearning_steps),
};
Ok(fsrs.evaluate_with_time_series_splits(input, |ip| {
anki_progress anki_progress
.update(false, |p| { .update(false, |p| {
p.total_iterations = ip.total as u32; p.total_iterations = ip.total as u32;

View file

@ -295,9 +295,9 @@ impl crate::services::SchedulerService for Collection {
input: scheduler::EvaluateParamsRequest, input: scheduler::EvaluateParamsRequest,
) -> Result<scheduler::EvaluateParamsResponse> { ) -> Result<scheduler::EvaluateParamsResponse> {
let ret = self.evaluate_params( let ret = self.evaluate_params(
&input.params,
&input.search, &input.search,
input.ignore_revlogs_before_ms.into(), input.ignore_revlogs_before_ms.into(),
input.num_of_relearning_steps as usize,
)?; )?;
Ok(scheduler::EvaluateParamsResponse { Ok(scheduler::EvaluateParamsResponse {
log_loss: ret.log_loss, log_loss: ret.log_loss,

View file

@ -229,9 +229,9 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
? $config.paramSearch ? $config.paramSearch
: defaultparamSearch; : defaultparamSearch;
const resp = await evaluateParams({ const resp = await evaluateParams({
params: fsrsParams($config),
search, search,
ignoreRevlogsBeforeMs: getIgnoreRevlogsBeforeMs(), ignoreRevlogsBeforeMs: getIgnoreRevlogsBeforeMs(),
numOfRelearningSteps: $config.relearnSteps.length,
}); });
if (computeParamsProgress) { if (computeParamsProgress) {
computeParamsProgress.current = computeParamsProgress.total; computeParamsProgress.current = computeParamsProgress.total;