Feat/Use cached workload values (#4208)

* Feat/Use cached workload values

* Fix: Calculation when unchanged

* Modify constants

* Cache clearing logic

* Use function params

* use https://github.com/open-spaced-repetition/fsrs-rs/pull/352

* Revert "use https://github.com/open-spaced-repetition/fsrs-rs/pull/352"

This reverts commit 72efcf230c.

* Reapply "use https://github.com/open-spaced-repetition/fsrs-rs/pull/352"

This reverts commit 49eab2969f.

* ./check

* bump fsrs
This commit is contained in:
Luc Mcgrady 2025-07-28 10:00:16 +01:00 committed by GitHub
parent 1af3c58d40
commit c947690aeb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 50 additions and 198 deletions

4
Cargo.lock generated
View file

@ -2214,9 +2214,9 @@ dependencies = [
[[package]]
name = "fsrs"
version = "4.1.1"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1f3a8c3df2c324ebab71461178fe8c1fe2d7373cf603f312b652befd026f06d"
checksum = "f590cfcbe25079bb54a39900f45e6e308935bd6067249ce00d265b280465cde2"
dependencies = [
"burn",
"itertools 0.14.0",

View file

@ -33,8 +33,9 @@ git = "https://github.com/ankitects/linkcheck.git"
rev = "184b2ca50ed39ca43da13f0b830a463861adb9ca"
[workspace.dependencies.fsrs]
version = "4.1.1"
version = "5.0.0"
# git = "https://github.com/open-spaced-repetition/fsrs-rs.git"
# branch = "Refactor/expected_workload_via_dp"
# rev = "a7f7efc10f0a26b14ee348cc7402155685f2a24f"
# path = "../open-spaced-repetition/fsrs-rs"

View file

@ -1450,7 +1450,7 @@
},
{
"name": "fsrs",
"version": "4.1.1",
"version": "5.0.0",
"authors": "Open Spaced Repetition",
"repository": "https://github.com/open-spaced-repetition/fsrs-rs",
"license": "BSD-3-Clause",

View file

@ -40,12 +40,10 @@ message DeckConfigId {
message GetRetentionWorkloadRequest {
repeated float w = 1;
string search = 2;
float before = 3;
float after = 4;
}
message GetRetentionWorkloadResponse {
float factor = 1;
map<uint32, float> costs = 1;
}
message GetIgnoredBeforeCountRequest {

View file

@ -1,5 +1,7 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::collections::HashMap;
use anki_proto::generic;
use crate::collection::Collection;
@ -101,68 +103,26 @@ impl crate::services::DeckConfigService for Collection {
&mut self,
input: anki_proto::deck_config::GetRetentionWorkloadRequest,
) -> Result<anki_proto::deck_config::GetRetentionWorkloadResponse> {
const LEARN_SPAN: usize = 100_000_000;
const TERMINATION_PROB: f32 = 0.001;
// the default values are from https://github.com/open-spaced-repetition/Anki-button-usage/blob/881009015c2a85ac911021d76d0aacb124849937/analysis.ipynb
const DEFAULT_LEARN_COST: f32 = 19.4698;
const DEFAULT_PASS_COST: f32 = 7.8454;
const DEFAULT_FAIL_COST: f32 = 23.185;
const DEFAULT_INITIAL_PASS_RATE: f32 = 0.7645;
let guard =
self.search_cards_into_table(&input.search, crate::search::SortMode::NoOrder)?;
let costs = guard.col.storage.get_costs_for_retention()?;
fn smoothing(obs: f32, default: f32, count: u32) -> f32 {
let alpha = count as f32 / (50.0 + count as f32);
obs * alpha + default * (1.0 - alpha)
}
let revlogs = guard
.col
.storage
.get_revlog_entries_for_searched_cards_in_card_order()?;
let cost_success = smoothing(
costs.average_pass_time_ms / 1000.0,
DEFAULT_PASS_COST,
costs.pass_count,
);
let cost_failure = smoothing(
costs.average_fail_time_ms / 1000.0,
DEFAULT_FAIL_COST,
costs.fail_count,
);
let cost_learn = smoothing(
costs.average_learn_time_ms / 1000.0,
DEFAULT_LEARN_COST,
costs.learn_count,
);
let initial_pass_rate = smoothing(
costs.initial_pass_rate,
DEFAULT_INITIAL_PASS_RATE,
costs.pass_count,
);
let config = guard.col.get_optimal_retention_parameters(revlogs)?;
let before = fsrs::expected_workload(
&input.w,
input.before,
LEARN_SPAN,
cost_success,
cost_failure,
cost_learn,
initial_pass_rate,
TERMINATION_PROB,
)?;
let after = fsrs::expected_workload(
&input.w,
input.after,
LEARN_SPAN,
cost_success,
cost_failure,
cost_learn,
initial_pass_rate,
TERMINATION_PROB,
)?;
let costs = (70u32..=99u32)
.map(|dr| {
Ok((
dr,
fsrs::expected_workload(&input.w, dr as f32 / 100., &config)?,
))
})
.collect::<Result<HashMap<_, _>>>()?;
Ok(anki_proto::deck_config::GetRetentionWorkloadResponse {
factor: after / before,
})
Ok(anki_proto::deck_config::GetRetentionWorkloadResponse { costs })
}
}

View file

@ -1,85 +0,0 @@
WITH searched_revlogs AS (
SELECT *,
RANK() OVER (
PARTITION BY cid
ORDER BY id ASC
) AS rank_num
FROM revlog
WHERE ease > 0
AND cid IN search_cids
ORDER BY id DESC -- Use the last 10_000 reviews
LIMIT 10000
), average_pass AS (
SELECT AVG(time)
FROM searched_revlogs
WHERE ease > 1
AND type = 1
),
lapse_count AS (
SELECT COUNT(time) AS lapse_count
FROM searched_revlogs
WHERE ease = 1
AND type = 1
),
fail_sum AS (
SELECT SUM(time) AS total_fail_time
FROM searched_revlogs
WHERE (
ease = 1
AND type = 1
)
OR type = 2
),
-- (sum(Relearning) + sum(Lapses)) / count(Lapses)
average_fail AS (
SELECT total_fail_time * 1.0 / NULLIF(lapse_count, 0) AS avg_fail_time
FROM fail_sum,
lapse_count
),
-- Can lead to cards with partial learn histories skewing the time
summed_learns AS (
SELECT cid,
SUM(time) AS total_time
FROM searched_revlogs
WHERE searched_revlogs.type = 0
GROUP BY cid
),
average_learn AS (
SELECT AVG(total_time) AS avg_learn_time
FROM summed_learns
),
initial_pass_rate AS (
SELECT AVG(
CASE
WHEN ease > 1 THEN 1.0
ELSE 0.0
END
) AS initial_pass_rate
FROM searched_revlogs
WHERE rank_num = 1
),
pass_cnt AS (
SELECT COUNT(*) AS cnt
FROM searched_revlogs
WHERE ease > 1
AND type = 1
),
fail_cnt AS (
SELECT COUNT(*) AS cnt
FROM searched_revlogs
WHERE ease = 1
AND type = 1
),
learn_cnt AS (
SELECT COUNT(*) AS cnt
FROM searched_revlogs
WHERE type = 0
)
SELECT *
FROM average_pass,
average_fail,
average_learn,
initial_pass_rate,
pass_cnt,
fail_cnt,
learn_cnt;

View file

@ -42,17 +42,6 @@ use crate::timestamp::TimestampMillis;
use crate::timestamp::TimestampSecs;
use crate::types::Usn;
#[derive(Debug, Clone, Default)]
pub struct RetentionCosts {
pub average_pass_time_ms: f32,
pub average_fail_time_ms: f32,
pub average_learn_time_ms: f32,
pub initial_pass_rate: f32,
pub pass_count: u32,
pub fail_count: u32,
pub learn_count: u32,
}
impl FromSql for CardType {
fn column_result(value: ValueRef<'_>) -> result::Result<Self, FromSqlError> {
if let ValueRef::Integer(i) = value {
@ -759,24 +748,6 @@ impl super::SqliteStorage {
.get(0)?)
}
pub(crate) fn get_costs_for_retention(&self) -> Result<RetentionCosts> {
let mut statement = self
.db
.prepare(include_str!("get_costs_for_retention.sql"))?;
let mut query = statement.query(params![])?;
let row = query.next()?.unwrap();
Ok(RetentionCosts {
average_pass_time_ms: row.get(0).unwrap_or(7000.),
average_fail_time_ms: row.get(1).unwrap_or(23_000.),
average_learn_time_ms: row.get(2).unwrap_or(30_000.),
initial_pass_rate: row.get(3).unwrap_or(0.5),
pass_count: row.get(4).unwrap_or(0),
fail_count: row.get(5).unwrap_or(0),
learn_count: row.get(6).unwrap_or(0),
})
}
#[cfg(test)]
pub(crate) fn get_all_cards(&self) -> Vec<Card> {
self.db

View file

@ -29,6 +29,7 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
import SimulatorModal from "./SimulatorModal.svelte";
import {
GetRetentionWorkloadRequest,
type GetRetentionWorkloadResponse,
UpdateDeckConfigsMode,
} from "@generated/anki/deck_config_pb";
import type Modal from "bootstrap/js/dist/modal";
@ -69,19 +70,9 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
$: roundedRetention = Number(effectiveDesiredRetention.toFixed(2));
$: desiredRetentionWarning = getRetentionLongShortWarning(roundedRetention);
let timeoutId: ReturnType<typeof setTimeout> | undefined = undefined;
const WORKLOAD_UPDATE_DELAY_MS = 100;
let desiredRetentionChangeInfo = "";
$: {
clearTimeout(timeoutId);
if (showDesiredRetentionTooltip) {
timeoutId = setTimeout(() => {
getRetentionChangeInfo(roundedRetention, fsrsParams($config));
}, WORKLOAD_UPDATE_DELAY_MS);
} else {
desiredRetentionChangeInfo = "";
}
$: if (showDesiredRetentionTooltip) {
getRetentionChangeInfo(roundedRetention, fsrsParams($config));
}
$: retentionWarningClass = getRetentionWarningClass(roundedRetention);
@ -137,21 +128,37 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
}
}
let retentionWorloadInfo: undefined | Promise<GetRetentionWorkloadResponse> =
undefined;
let lastParams = [...fsrsParams($config)];
async function getRetentionChangeInfo(retention: number, params: number[]) {
if (+startingDesiredRetention == roundedRetention) {
desiredRetentionChangeInfo = tr.deckConfigWorkloadFactorUnchanged();
return;
}
const request = new GetRetentionWorkloadRequest({
w: params,
search: defaultparamSearch,
before: +startingDesiredRetention,
after: retention,
});
const resp = await getRetentionWorkload(request);
if (
// If the cache is empty and a request has not yet been made to fill it
!retentionWorloadInfo ||
// If the parameters have been changed
lastParams.toString() !== params.toString()
) {
const request = new GetRetentionWorkloadRequest({
w: params,
search: defaultparamSearch,
});
lastParams = [...params];
retentionWorloadInfo = getRetentionWorkload(request);
}
const previous = +startingDesiredRetention * 100;
const after = retention * 100;
const resp = await retentionWorloadInfo;
const factor = resp.costs[after] / resp.costs[previous];
desiredRetentionChangeInfo = tr.deckConfigWorkloadFactorChange({
factor: resp.factor.toFixed(2),
previousDr: (+startingDesiredRetention * 100).toString(),
factor: factor.toFixed(2),
previousDr: previous.toString(),
});
}