implement smoothing & a separate struct for return value

This commit is contained in:
Jarrett Ye 2025-06-04 16:46:42 +08:00
parent e3afbd93e2
commit 9c955a2bbb
No known key found for this signature in database
GPG key ID: EBFC55E0C1A352BB
3 changed files with 81 additions and 16 deletions

View file

@ -103,19 +103,49 @@ impl crate::services::DeckConfigService for Collection {
) -> Result<anki_proto::deck_config::GetRetentionWorkloadResponse> {
const LEARN_SPAN: usize = 100_000_000;
const TERMINATION_PROB: f32 = 0.01;
// the default values are from https://github.com/open-spaced-repetition/Anki-button-usage/blob/881009015c2a85ac911021d76d0aacb124849937/analysis.ipynb
const DEFAULT_LEARN_COST: f32 = 19.4698;
const DEFAULT_PASS_COST: f32 = 7.8454;
const DEFAULT_FAIL_COST: f32 = 23.185;
const DEFAULT_INITIAL_PASS_RATE: f32 = 0.7645;
let guard =
self.search_cards_into_table(&input.search, crate::search::SortMode::NoOrder)?;
let (pass_cost, fail_cost, learn_cost, initial_pass_rate) =
guard.col.storage.get_costs_for_retention()?;
let costs = guard.col.storage.get_costs_for_retention()?;
fn smoothing(obs: f32, default: f32, count: u32) -> f32 {
let alpha = count as f32 / (50.0 + count as f32);
obs * alpha + default * (1.0 - alpha)
}
let cost_success = smoothing(
costs.average_pass_time_ms / 1000.0,
DEFAULT_PASS_COST,
costs.pass_count,
);
let cost_failure = smoothing(
costs.average_fail_time_ms / 1000.0,
DEFAULT_FAIL_COST,
costs.fail_count,
);
let cost_learn = smoothing(
costs.average_learn_time_ms / 1000.0,
DEFAULT_LEARN_COST,
costs.learn_count,
);
let initial_pass_rate = smoothing(
costs.initial_pass_rate,
DEFAULT_INITIAL_PASS_RATE,
costs.pass_count,
);
let before = fsrs::expected_workload(
&input.w,
input.before,
LEARN_SPAN,
pass_cost,
fail_cost,
learn_cost,
cost_success,
cost_failure,
cost_learn,
initial_pass_rate,
TERMINATION_PROB,
)?;
@ -123,9 +153,9 @@ impl crate::services::DeckConfigService for Collection {
&input.w,
input.after,
LEARN_SPAN,
pass_cost,
fail_cost,
learn_cost,
cost_success,
cost_failure,
cost_learn,
initial_pass_rate,
TERMINATION_PROB,
)?;

View file

@ -13,6 +13,7 @@ WITH searched_revlogs AS (
SELECT AVG(time)
FROM searched_revlogs
WHERE ease > 1
AND type = 1
),
lapse_count AS (
SELECT COUNT(time) AS lapse_count
@ -56,9 +57,29 @@ initial_pass_rate AS (
) AS initial_pass_rate
FROM searched_revlogs
WHERE rank_num = 1
),
pass_cnt AS (
SELECT COUNT(*) AS cnt
FROM searched_revlogs
WHERE ease > 1
AND type = 1
),
fail_cnt AS (
SELECT COUNT(*) AS cnt
FROM searched_revlogs
WHERE ease = 1
AND type = 1
),
learn_cnt AS (
SELECT COUNT(*) AS cnt
FROM searched_revlogs
WHERE type = 0
)
SELECT *
FROM average_pass,
average_fail,
average_learn,
initial_pass_rate;
initial_pass_rate,
pass_cnt,
fail_cnt,
learn_cnt;

View file

@ -42,6 +42,17 @@ use crate::timestamp::TimestampMillis;
use crate::timestamp::TimestampSecs;
use crate::types::Usn;
#[derive(Debug, Clone, Default)]
pub struct RetentionCosts {
pub average_pass_time_ms: f32,
pub average_fail_time_ms: f32,
pub average_learn_time_ms: f32,
pub initial_pass_rate: f32,
pub pass_count: u32,
pub fail_count: u32,
pub learn_count: u32,
}
impl FromSql for CardType {
fn column_result(value: ValueRef<'_>) -> result::Result<Self, FromSqlError> {
if let ValueRef::Integer(i) = value {
@ -747,19 +758,22 @@ impl super::SqliteStorage {
.get(0)?)
}
pub(crate) fn get_costs_for_retention(&self) -> Result<(f32, f32, f32, f32)> {
pub(crate) fn get_costs_for_retention(&self) -> Result<RetentionCosts> {
let mut statement = self
.db
.prepare(include_str!("get_costs_for_retention.sql"))?;
let mut query = statement.query(params![])?;
let row = query.next()?.unwrap();
Ok((
row.get(0).unwrap_or(7000.),
row.get(1).unwrap_or(23_000.),
row.get(2).unwrap_or(30_000.),
row.get(3).unwrap_or(0.5),
))
Ok(RetentionCosts {
average_pass_time_ms: row.get(0).unwrap_or(7000.),
average_fail_time_ms: row.get(1).unwrap_or(23_000.),
average_learn_time_ms: row.get(2).unwrap_or(30_000.),
initial_pass_rate: row.get(3).unwrap_or(0.5),
pass_count: row.get(4).unwrap_or(0),
fail_count: row.get(5).unwrap_or(0),
learn_count: row.get(6).unwrap_or(0),
})
}
#[cfg(test)]