diff --git a/rslib/src/scheduler/fsrs/retention.rs b/rslib/src/scheduler/fsrs/retention.rs index 07602a09c..f4dc1935d 100644 --- a/rslib/src/scheduler/fsrs/retention.rs +++ b/rslib/src/scheduler/fsrs/retention.rs @@ -135,15 +135,14 @@ impl Collection { r.review_kind == RevlogReviewKind::Review && r.button_chosen > 0 && r.taken_millis > 0 + && r.taken_millis < 1200000 // 20 minutes }) .sorted_by(|a, b| a.button_chosen.cmp(&b.button_chosen)) .group_by(|r| r.button_chosen) .into_iter() .for_each(|(button_chosen, group)| { let group_vec = group.into_iter().map(|r| r.taken_millis).collect_vec(); - let average_secs = - group_vec.iter().sum::() as f64 / group_vec.len() as f64 / 1000.0; - arr[button_chosen as usize - 1] = average_secs + arr[button_chosen as usize - 1] = median_secs(&group_vec); }); if arr == default { return Err(AnkiError::FsrsInsufficientData); @@ -157,19 +156,23 @@ impl Collection { r.review_kind == RevlogReviewKind::Learning && r.button_chosen >= 1 && r.taken_millis > 0 + && r.taken_millis < 1200000 // 20 minutes }) .map(|r| r.taken_millis); - let length = revlogs_filter.clone().count() as f64; - if length > 0.0 { - revlogs_filter.sum::() as f64 / length / 1000.0 - } else { - return Err(AnkiError::FsrsInsufficientData); - } + let group_vec = revlogs_filter.collect_vec(); + median_secs(&group_vec) }; + if learn_cost == 0.0 { + return Err(AnkiError::FsrsInsufficientData); + } let forget_cost = { let review_kind_to_total_millis = revlogs .iter() + .filter(|r| { + r.button_chosen > 0 && r.taken_millis > 0 && r.taken_millis < 1200000 + // 20 minutes + }) .sorted_by(|a, b| a.cid.cmp(&b.cid).then(a.id.cmp(&b.id))) .group_by(|r| r.review_kind) /* @@ -192,12 +195,7 @@ impl Collection { } let mut arr = [0.0; 5]; for (review_kind, group) in group_sec_by_review_kind.iter().enumerate() { - let average_secs = group.iter().sum::() as f64 / group.len() as f64 / 1000.0; - arr[review_kind] = if average_secs.is_nan() { - 0.0 - } else { - average_secs - } + arr[review_kind] = median_secs(group); } arr }; @@ -221,3 +219,20 @@ impl Collection { Ok(params) } } + +fn median_secs(group: &Vec) -> f64 { + let length = group.len(); + if length > 0 { + let mut group_vec = group.clone(); + group_vec.sort_unstable(); + let median_millis = if length % 2 == 0 { + let mid = length / 2; + (group_vec[mid - 1] + group_vec[mid]) as f64 / 2.0 + } else { + group_vec[length / 2] as f64 + }; + median_millis / 1000.0 + } else { + 0.0 + } +}