Fix/re-optimize FSRS if short-term param is weird (#3742)

* Fix/re-optimize FSRS if short-term param is weird

* Reset progress when another run is required (dae)

* Only count the same-day steps

* Fix flicker when optimizing again (dae)
Authored by Jarrett Ye on 2025-01-26 07:42:17 +08:00; committed by GitHub
parent 5883e4eae8
commit 43e860783b
6 changed files with 80 additions and 24 deletions


@@ -346,6 +346,7 @@ message ComputeFsrsParamsRequest {
   string search = 1;
   repeated float current_params = 2;
   int64 ignore_revlogs_before_ms = 3;
+  uint32 num_of_relearning_steps = 4;
 }

 message ComputeFsrsParamsResponse {


@@ -356,12 +356,14 @@ impl Collection {
                 config.inner.param_search.clone()
             };
             let ignore_revlogs_before_ms = ignore_revlogs_before_ms_from_config(config)?;
+            let num_of_relearning_steps = config.inner.relearn_steps.len();
             match self.compute_params(
                 &search,
                 ignore_revlogs_before_ms,
                 idx as u32 + 1,
                 config_len,
                 config.fsrs_params(),
+                num_of_relearning_steps,
             ) {
                 Ok(params) => {
                     println!("{}: {:?}", config.name, params.params);


@@ -122,6 +122,13 @@ pub struct ProgressState {
     pub last_progress: Option<Progress>,
 }

+impl ProgressState {
+    pub fn reset(&mut self) {
+        self.want_abort = false;
+        self.last_progress = None;
+    }
+}
+
 #[derive(Clone, Copy, Debug)]
 pub enum Progress {
     MediaSync(MediaSyncProgress),
@@ -320,6 +327,10 @@ impl Collection {
     ) -> ThrottlingProgressHandler<P> {
         ThrottlingProgressHandler::new(self.state.progress.clone())
     }
+
+    pub(crate) fn clear_progress(&mut self) {
+        self.state.progress.lock().unwrap().reset();
+    }
 }

 pub(crate) struct Incrementor<'f, F: 'f + FnMut(usize) -> Result<()>> {


@@ -16,6 +16,7 @@ use chrono::NaiveTime;
 use fsrs::CombinedProgressState;
 use fsrs::FSRSItem;
 use fsrs::FSRSReview;
+use fsrs::MemoryState;
 use fsrs::ModelEvaluation;
 use fsrs::FSRS;
 use itertools::Itertools;
@@ -60,8 +61,9 @@ impl Collection {
         current_preset: u32,
         total_presets: u32,
         current_params: &Params,
+        num_of_relearning_steps: usize,
     ) -> Result<ComputeFsrsParamsResponse> {
-        let mut anki_progress = self.new_progress_handler::<ComputeParamsProgress>();
+        self.clear_progress();
         let timing = self.timing_today()?;
         let revlogs = self.revlog_for_srs(search)?;
         let (items, review_count) =
@@ -74,31 +76,38 @@ impl Collection {
                 fsrs_items,
             });
         }
-        anki_progress.update(false, |p| {
-            p.current_preset = current_preset;
-            p.total_presets = total_presets;
-        })?;
         // adapt the progress handler to our built-in progress handling
-        let progress = CombinedProgressState::new_shared();
-        let progress2 = progress.clone();
-        let progress_thread = thread::spawn(move || {
-            let mut finished = false;
-            while !finished {
-                thread::sleep(Duration::from_millis(100));
-                let mut guard = progress.lock().unwrap();
-                if let Err(_err) = anki_progress.update(false, |s| {
-                    s.total_iterations = guard.total() as u32;
-                    s.current_iteration = guard.current() as u32;
-                    s.reviews = review_count as u32;
-                    finished = guard.finished();
-                }) {
-                    guard.want_abort = true;
-                    return;
-                }
-            }
-        });
-        let mut params =
-            FSRS::new(None)?.compute_parameters(items.clone(), Some(progress2), true)?;
+        let create_progress_thread = || -> Result<_> {
+            let mut anki_progress = self.new_progress_handler::<ComputeParamsProgress>();
+            anki_progress.update(false, |p| {
+                p.current_preset = current_preset;
+                p.total_presets = total_presets;
+            })?;
+            let progress = CombinedProgressState::new_shared();
+            let progress2 = progress.clone();
+            let progress_thread = thread::spawn(move || {
+                let mut finished = false;
+                while !finished {
+                    thread::sleep(Duration::from_millis(100));
+                    let mut guard = progress.lock().unwrap();
+                    if let Err(_err) = anki_progress.update(false, |s| {
+                        s.total_iterations = guard.total() as u32;
+                        s.current_iteration = guard.current() as u32;
+                        s.reviews = review_count as u32;
+                        finished = guard.finished();
+                    }) {
+                        guard.want_abort = true;
+                        return;
+                    }
+                }
+            });
+            Ok((progress2, progress_thread))
+        };
+        let (progress, progress_thread) = create_progress_thread()?;
+        let fsrs = FSRS::new(None)?;
+        let mut params = fsrs.compute_parameters(items.clone(), Some(progress.clone()), true)?;
         progress_thread.join().ok();
         if let Ok(fsrs) = FSRS::new(Some(current_params)) {
             let current_rmse = fsrs.evaluate(items.clone(), |_| true)?.rmse_bins;
@@ -107,6 +116,27 @@ impl Collection {
             if current_rmse <= optimized_rmse {
                 params = current_params.to_vec();
             }
+            if num_of_relearning_steps > 1 {
+                let memory_state = MemoryState {
+                    stability: 1.0,
+                    difficulty: 1.0,
+                };
+                let s_fail = optimized_fsrs
+                    .next_states(Some(memory_state), 0.9, 2)?
+                    .again;
+                let mut s_short_term = s_fail.memory;
+                for _ in 0..num_of_relearning_steps {
+                    s_short_term = optimized_fsrs
+                        .next_states(Some(s_short_term), 0.9, 0)?
+                        .good
+                        .memory;
+                }
+                if s_short_term.stability > memory_state.stability {
+                    let (progress, progress_thread) = create_progress_thread()?;
+                    params = fsrs.compute_parameters(items.clone(), Some(progress), false)?;
+                    progress_thread.join().ok();
+                }
+            }
         }

         Ok(ComputeFsrsParamsResponse { params, fsrs_items })
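
Read on its own, the check added above does the following when more than one same-day relearning step is configured: start from an arbitrary pre-lapse memory state, fail the card once, then pass every same-day relearning step; if those same-day steps alone push stability above the pre-lapse value, the short-term parameters are treated as weird and compute_parameters is run again with its final argument set to false, using a fresh progress thread. A minimal standalone sketch of that check, reusing only the fsrs crate calls that appear in this hunk; the helper name, parameter names, and error type are illustrative, not part of the patch:

    // Sketch only: mirrors the hunk above. `params` stands in for the freshly
    // optimized parameters; `num_of_relearning_steps` is the same-day step count.
    use fsrs::{MemoryState, FSRS};

    fn short_term_stability_is_weird(
        params: &[f32],
        num_of_relearning_steps: usize,
    ) -> Result<bool, Box<dyn std::error::Error>> {
        let fsrs = FSRS::new(Some(params))?;
        // Arbitrary pre-lapse state, matching the values used in the patch.
        let start = MemoryState { stability: 1.0, difficulty: 1.0 };
        // Fail the card (Again), then pass each same-day relearning step (Good).
        let mut state = fsrs.next_states(Some(start), 0.9, 2)?.again.memory;
        for _ in 0..num_of_relearning_steps {
            state = fsrs.next_states(Some(state), 0.9, 0)?.good.memory;
        }
        // Same-day steps alone should not lift stability past the pre-lapse value.
        Ok(state.stability > start.stability)
    }

When the check fires, the patch keeps the result of that second, re-run optimization instead of the first one.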


@@ -264,6 +264,7 @@ impl crate::services::SchedulerService for Collection {
             1,
             1,
             &input.current_params,
+            input.num_of_relearning_steps as usize,
         )
     }


@@ -151,12 +151,23 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
     await runWithBackendProgress(
         async () => {
             const params = fsrsParams($config);
+            const RelearningSteps = $config.relearnSteps;
+            let numOfRelearningStepsInDay = 0;
+            let accumulatedTime = 0;
+            for (let i = 0; i < RelearningSteps.length; i++) {
+                accumulatedTime += RelearningSteps[i];
+                if (accumulatedTime >= 1440) {
+                    break;
+                }
+                numOfRelearningStepsInDay++;
+            }
             const resp = await computeFsrsParams({
                 search: $config.paramSearch
                     ? $config.paramSearch
                     : defaultparamSearch,
                 ignoreRevlogsBeforeMs: getIgnoreRevlogsBeforeMs(),
                 currentParams: params,
+                numOfRelearningSteps: numOfRelearningStepsInDay,
             });

             const already_optimal =
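
The frontend counts only the relearning steps that fit within a single day before passing the total to the backend: step lengths are in minutes, and the loop stops once the running total reaches 1440 (one day). With illustrative relearnSteps of [10, 20, 1440], for example, the running total is 10 and then 30, so two steps are counted; the third step would bring it to 1470 ≥ 1440, so the loop breaks and numOfRelearningSteps is sent as 2.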