Use median in calculating cost and remove outliers (#3181)

* Use median in calculating cost and remove outliers

* extract fn median_secs
This commit is contained in:
Jarrett Ye 2024-05-02 18:16:04 +08:00 committed by GitHub
parent 8ad65d40ff
commit c9c7a3133c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -135,15 +135,14 @@ impl Collection {
r.review_kind == RevlogReviewKind::Review
&& r.button_chosen > 0
&& r.taken_millis > 0
&& r.taken_millis < 1200000 // 20 minutes
})
.sorted_by(|a, b| a.button_chosen.cmp(&b.button_chosen))
.group_by(|r| r.button_chosen)
.into_iter()
.for_each(|(button_chosen, group)| {
let group_vec = group.into_iter().map(|r| r.taken_millis).collect_vec();
let average_secs =
group_vec.iter().sum::<u32>() as f64 / group_vec.len() as f64 / 1000.0;
arr[button_chosen as usize - 1] = average_secs
arr[button_chosen as usize - 1] = median_secs(&group_vec);
});
if arr == default {
return Err(AnkiError::FsrsInsufficientData);
@ -157,19 +156,23 @@ impl Collection {
r.review_kind == RevlogReviewKind::Learning
&& r.button_chosen >= 1
&& r.taken_millis > 0
&& r.taken_millis < 1200000 // 20 minutes
})
.map(|r| r.taken_millis);
let length = revlogs_filter.clone().count() as f64;
if length > 0.0 {
revlogs_filter.sum::<u32>() as f64 / length / 1000.0
} else {
return Err(AnkiError::FsrsInsufficientData);
}
let group_vec = revlogs_filter.collect_vec();
median_secs(&group_vec)
};
if learn_cost == 0.0 {
return Err(AnkiError::FsrsInsufficientData);
}
let forget_cost = {
let review_kind_to_total_millis = revlogs
.iter()
.filter(|r| {
r.button_chosen > 0 && r.taken_millis > 0 && r.taken_millis < 1200000
// 20 minutes
})
.sorted_by(|a, b| a.cid.cmp(&b.cid).then(a.id.cmp(&b.id)))
.group_by(|r| r.review_kind)
/*
@ -192,12 +195,7 @@ impl Collection {
}
let mut arr = [0.0; 5];
for (review_kind, group) in group_sec_by_review_kind.iter().enumerate() {
let average_secs = group.iter().sum::<u32>() as f64 / group.len() as f64 / 1000.0;
arr[review_kind] = if average_secs.is_nan() {
0.0
} else {
average_secs
}
arr[review_kind] = median_secs(group);
}
arr
};
@ -221,3 +219,20 @@ impl Collection {
Ok(params)
}
}
/// Returns the median of `group`, converted from milliseconds to seconds.
///
/// The input does not need to be sorted and is left unmodified. For an even
/// number of samples the median is the mean of the two middle values. An
/// empty input yields `0.0`.
fn median_secs(group: &[u32]) -> f64 {
    if group.is_empty() {
        return 0.0;
    }
    // Sort a private copy so the caller's ordering is left untouched.
    let mut sorted = group.to_vec();
    sorted.sort_unstable();
    let mid = sorted.len() / 2;
    let median_millis = if sorted.len() % 2 == 0 {
        // Widen to f64 *before* adding: the sum of two u32 values can exceed
        // u32::MAX, which would panic in debug builds and wrap in release.
        (f64::from(sorted[mid - 1]) + f64::from(sorted[mid])) / 2.0
    } else {
        f64::from(sorted[mid])
    };
    median_millis / 1000.0
}