mirror of
https://github.com/ankitects/anki.git
synced 2025-09-18 14:02:21 -04:00
Feat/evaluate FSRS with time series split (#3962)
This commit is contained in:
parent
37984233cc
commit
2de0c79ba5
4 changed files with 18 additions and 17 deletions
|
@ -435,9 +435,9 @@ message GetOptimalRetentionParametersResponse {
|
||||||
}
|
}
|
||||||
|
|
||||||
message EvaluateParamsRequest {
|
message EvaluateParamsRequest {
|
||||||
repeated float params = 1;
|
string search = 1;
|
||||||
string search = 2;
|
int64 ignore_revlogs_before_ms = 2;
|
||||||
int64 ignore_revlogs_before_ms = 3;
|
uint32 num_of_relearning_steps = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message EvaluateParamsResponse {
|
message EvaluateParamsResponse {
|
||||||
|
|
|
@ -115,15 +115,14 @@ impl Collection {
|
||||||
num_relearning_steps: Some(num_of_relearning_steps),
|
num_relearning_steps: Some(num_of_relearning_steps),
|
||||||
})?;
|
})?;
|
||||||
progress_thread.join().ok();
|
progress_thread.join().ok();
|
||||||
if let Ok(fsrs) = FSRS::new(Some(current_params)) {
|
if let Ok(current_fsrs) = FSRS::new(Some(current_params)) {
|
||||||
let current_log_loss = fsrs.evaluate(items.clone(), |_| true)?.log_loss;
|
let current_log_loss = current_fsrs.evaluate(items.clone(), |_| true)?.log_loss;
|
||||||
let optimized_fsrs = FSRS::new(Some(¶ms))?;
|
let optimized_fsrs = FSRS::new(Some(¶ms))?;
|
||||||
let optimized_log_loss = optimized_fsrs.evaluate(items.clone(), |_| true)?.log_loss;
|
let optimized_log_loss = optimized_fsrs.evaluate(items.clone(), |_| true)?.log_loss;
|
||||||
if current_log_loss <= optimized_log_loss {
|
if current_log_loss <= optimized_log_loss {
|
||||||
if num_of_relearning_steps <= 1 {
|
if num_of_relearning_steps <= 1 {
|
||||||
params = current_params.to_vec();
|
params = current_params.to_vec();
|
||||||
} else {
|
} else {
|
||||||
let current_fsrs = FSRS::new(Some(current_params))?;
|
|
||||||
let memory_state = MemoryState {
|
let memory_state = MemoryState {
|
||||||
stability: 1.0,
|
stability: 1.0,
|
||||||
difficulty: 1.0,
|
difficulty: 1.0,
|
||||||
|
@ -218,22 +217,24 @@ impl Collection {
|
||||||
|
|
||||||
pub fn evaluate_params(
|
pub fn evaluate_params(
|
||||||
&mut self,
|
&mut self,
|
||||||
params: &Params,
|
|
||||||
search: &str,
|
search: &str,
|
||||||
ignore_revlogs_before: TimestampMillis,
|
ignore_revlogs_before: TimestampMillis,
|
||||||
|
num_of_relearning_steps: usize,
|
||||||
) -> Result<ModelEvaluation> {
|
) -> Result<ModelEvaluation> {
|
||||||
let timing = self.timing_today()?;
|
let timing = self.timing_today()?;
|
||||||
let mut anki_progress = self.new_progress_handler::<ComputeParamsProgress>();
|
let revlogs = self.revlog_for_srs(search)?;
|
||||||
let guard = self.search_cards_into_table(search, SortMode::NoOrder)?;
|
|
||||||
let revlogs: Vec<RevlogEntry> = guard
|
|
||||||
.col
|
|
||||||
.storage
|
|
||||||
.get_revlog_entries_for_searched_cards_in_card_order()?;
|
|
||||||
let (items, review_count) =
|
let (items, review_count) =
|
||||||
fsrs_items_for_training(revlogs, timing.next_day_at, ignore_revlogs_before);
|
fsrs_items_for_training(revlogs, timing.next_day_at, ignore_revlogs_before);
|
||||||
|
let mut anki_progress = self.new_progress_handler::<ComputeParamsProgress>();
|
||||||
anki_progress.state.reviews = review_count as u32;
|
anki_progress.state.reviews = review_count as u32;
|
||||||
let fsrs = FSRS::new(Some(params))?;
|
let fsrs = FSRS::new(None)?;
|
||||||
Ok(fsrs.evaluate(items, |ip| {
|
let input = ComputeParametersInput {
|
||||||
|
train_set: items.clone(),
|
||||||
|
progress: None,
|
||||||
|
enable_short_term: true,
|
||||||
|
num_relearning_steps: Some(num_of_relearning_steps),
|
||||||
|
};
|
||||||
|
Ok(fsrs.evaluate_with_time_series_splits(input, |ip| {
|
||||||
anki_progress
|
anki_progress
|
||||||
.update(false, |p| {
|
.update(false, |p| {
|
||||||
p.total_iterations = ip.total as u32;
|
p.total_iterations = ip.total as u32;
|
||||||
|
|
|
@ -295,9 +295,9 @@ impl crate::services::SchedulerService for Collection {
|
||||||
input: scheduler::EvaluateParamsRequest,
|
input: scheduler::EvaluateParamsRequest,
|
||||||
) -> Result<scheduler::EvaluateParamsResponse> {
|
) -> Result<scheduler::EvaluateParamsResponse> {
|
||||||
let ret = self.evaluate_params(
|
let ret = self.evaluate_params(
|
||||||
&input.params,
|
|
||||||
&input.search,
|
&input.search,
|
||||||
input.ignore_revlogs_before_ms.into(),
|
input.ignore_revlogs_before_ms.into(),
|
||||||
|
input.num_of_relearning_steps as usize,
|
||||||
)?;
|
)?;
|
||||||
Ok(scheduler::EvaluateParamsResponse {
|
Ok(scheduler::EvaluateParamsResponse {
|
||||||
log_loss: ret.log_loss,
|
log_loss: ret.log_loss,
|
||||||
|
|
|
@ -229,9 +229,9 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
|
||||||
? $config.paramSearch
|
? $config.paramSearch
|
||||||
: defaultparamSearch;
|
: defaultparamSearch;
|
||||||
const resp = await evaluateParams({
|
const resp = await evaluateParams({
|
||||||
params: fsrsParams($config),
|
|
||||||
search,
|
search,
|
||||||
ignoreRevlogsBeforeMs: getIgnoreRevlogsBeforeMs(),
|
ignoreRevlogsBeforeMs: getIgnoreRevlogsBeforeMs(),
|
||||||
|
numOfRelearningSteps: $config.relearnSteps.length,
|
||||||
});
|
});
|
||||||
if (computeParamsProgress) {
|
if (computeParamsProgress) {
|
||||||
computeParamsProgress.current = computeParamsProgress.total;
|
computeParamsProgress.current = computeParamsProgress.total;
|
||||||
|
|
Loading…
Reference in a new issue