mirror of
https://github.com/ankitects/anki.git
synced 2026-01-13 22:13:58 -05:00
implement smoothing & a separate struct for return value
This commit is contained in:
parent
e3afbd93e2
commit
9c955a2bbb
3 changed files with 81 additions and 16 deletions
|
|
@ -103,19 +103,49 @@ impl crate::services::DeckConfigService for Collection {
|
|||
) -> Result<anki_proto::deck_config::GetRetentionWorkloadResponse> {
|
||||
const LEARN_SPAN: usize = 100_000_000;
|
||||
const TERMINATION_PROB: f32 = 0.01;
|
||||
// the default values are from https://github.com/open-spaced-repetition/Anki-button-usage/blob/881009015c2a85ac911021d76d0aacb124849937/analysis.ipynb
|
||||
const DEFAULT_LEARN_COST: f32 = 19.4698;
|
||||
const DEFAULT_PASS_COST: f32 = 7.8454;
|
||||
const DEFAULT_FAIL_COST: f32 = 23.185;
|
||||
const DEFAULT_INITIAL_PASS_RATE: f32 = 0.7645;
|
||||
|
||||
let guard =
|
||||
self.search_cards_into_table(&input.search, crate::search::SortMode::NoOrder)?;
|
||||
let (pass_cost, fail_cost, learn_cost, initial_pass_rate) =
|
||||
guard.col.storage.get_costs_for_retention()?;
|
||||
let costs = guard.col.storage.get_costs_for_retention()?;
|
||||
|
||||
fn smoothing(obs: f32, default: f32, count: u32) -> f32 {
|
||||
let alpha = count as f32 / (50.0 + count as f32);
|
||||
obs * alpha + default * (1.0 - alpha)
|
||||
}
|
||||
|
||||
let cost_success = smoothing(
|
||||
costs.average_pass_time_ms / 1000.0,
|
||||
DEFAULT_PASS_COST,
|
||||
costs.pass_count,
|
||||
);
|
||||
let cost_failure = smoothing(
|
||||
costs.average_fail_time_ms / 1000.0,
|
||||
DEFAULT_FAIL_COST,
|
||||
costs.fail_count,
|
||||
);
|
||||
let cost_learn = smoothing(
|
||||
costs.average_learn_time_ms / 1000.0,
|
||||
DEFAULT_LEARN_COST,
|
||||
costs.learn_count,
|
||||
);
|
||||
let initial_pass_rate = smoothing(
|
||||
costs.initial_pass_rate,
|
||||
DEFAULT_INITIAL_PASS_RATE,
|
||||
costs.pass_count,
|
||||
);
|
||||
|
||||
let before = fsrs::expected_workload(
|
||||
&input.w,
|
||||
input.before,
|
||||
LEARN_SPAN,
|
||||
pass_cost,
|
||||
fail_cost,
|
||||
learn_cost,
|
||||
cost_success,
|
||||
cost_failure,
|
||||
cost_learn,
|
||||
initial_pass_rate,
|
||||
TERMINATION_PROB,
|
||||
)?;
|
||||
|
|
@ -123,9 +153,9 @@ impl crate::services::DeckConfigService for Collection {
|
|||
&input.w,
|
||||
input.after,
|
||||
LEARN_SPAN,
|
||||
pass_cost,
|
||||
fail_cost,
|
||||
learn_cost,
|
||||
cost_success,
|
||||
cost_failure,
|
||||
cost_learn,
|
||||
initial_pass_rate,
|
||||
TERMINATION_PROB,
|
||||
)?;
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ WITH searched_revlogs AS (
|
|||
SELECT AVG(time)
|
||||
FROM searched_revlogs
|
||||
WHERE ease > 1
|
||||
AND type = 1
|
||||
),
|
||||
lapse_count AS (
|
||||
SELECT COUNT(time) AS lapse_count
|
||||
|
|
@ -56,9 +57,29 @@ initial_pass_rate AS (
|
|||
) AS initial_pass_rate
|
||||
FROM searched_revlogs
|
||||
WHERE rank_num = 1
|
||||
),
|
||||
pass_cnt AS (
|
||||
SELECT COUNT(*) AS cnt
|
||||
FROM searched_revlogs
|
||||
WHERE ease > 1
|
||||
AND type = 1
|
||||
),
|
||||
fail_cnt AS (
|
||||
SELECT COUNT(*) AS cnt
|
||||
FROM searched_revlogs
|
||||
WHERE ease = 1
|
||||
AND type = 1
|
||||
),
|
||||
learn_cnt AS (
|
||||
SELECT COUNT(*) AS cnt
|
||||
FROM searched_revlogs
|
||||
WHERE type = 0
|
||||
)
|
||||
SELECT *
|
||||
FROM average_pass,
|
||||
average_fail,
|
||||
average_learn,
|
||||
initial_pass_rate;
|
||||
initial_pass_rate,
|
||||
pass_cnt,
|
||||
fail_cnt,
|
||||
learn_cnt;
|
||||
|
|
@ -42,6 +42,17 @@ use crate::timestamp::TimestampMillis;
|
|||
use crate::timestamp::TimestampSecs;
|
||||
use crate::types::Usn;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RetentionCosts {
|
||||
pub average_pass_time_ms: f32,
|
||||
pub average_fail_time_ms: f32,
|
||||
pub average_learn_time_ms: f32,
|
||||
pub initial_pass_rate: f32,
|
||||
pub pass_count: u32,
|
||||
pub fail_count: u32,
|
||||
pub learn_count: u32,
|
||||
}
|
||||
|
||||
impl FromSql for CardType {
|
||||
fn column_result(value: ValueRef<'_>) -> result::Result<Self, FromSqlError> {
|
||||
if let ValueRef::Integer(i) = value {
|
||||
|
|
@ -747,19 +758,22 @@ impl super::SqliteStorage {
|
|||
.get(0)?)
|
||||
}
|
||||
|
||||
pub(crate) fn get_costs_for_retention(&self) -> Result<(f32, f32, f32, f32)> {
|
||||
pub(crate) fn get_costs_for_retention(&self) -> Result<RetentionCosts> {
|
||||
let mut statement = self
|
||||
.db
|
||||
.prepare(include_str!("get_costs_for_retention.sql"))?;
|
||||
let mut query = statement.query(params![])?;
|
||||
let row = query.next()?.unwrap();
|
||||
|
||||
Ok((
|
||||
row.get(0).unwrap_or(7000.),
|
||||
row.get(1).unwrap_or(23_000.),
|
||||
row.get(2).unwrap_or(30_000.),
|
||||
row.get(3).unwrap_or(0.5),
|
||||
))
|
||||
Ok(RetentionCosts {
|
||||
average_pass_time_ms: row.get(0).unwrap_or(7000.),
|
||||
average_fail_time_ms: row.get(1).unwrap_or(23_000.),
|
||||
average_learn_time_ms: row.get(2).unwrap_or(30_000.),
|
||||
initial_pass_rate: row.get(3).unwrap_or(0.5),
|
||||
pass_count: row.get(4).unwrap_or(0),
|
||||
fail_count: row.get(5).unwrap_or(0),
|
||||
learn_count: row.get(6).unwrap_or(0),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
Loading…
Reference in a new issue