Feat/FSRS Simulator (#3257)

* test using existed cards

* plot new and review

* convert learning cards & use line chart

* allow draw multiple simulations in the same chart

* support hide simulation

* convert x axis to Date

* convert y from second to minute

* support clear last simulation

* remove unused import

* rename

* add hover/tooltip

* fallback to default parameters

* update default value and maximum of deckSize

* add "processing..."

* fix mistake
This commit is contained in:
Jarrett Ye 2024-08-22 16:34:19 +08:00 committed by GitHub
parent e92aaa4478
commit 8ed9f49bdc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 457 additions and 22 deletions

View file

@ -625,6 +625,7 @@ exposed_backend_list = [
"set_wants_abort",
"evaluate_weights",
"get_optimal_retention_parameters",
"simulate_fsrs_review",
]

View file

@ -5,8 +5,10 @@ use anki_proto::scheduler::SimulateFsrsReviewRequest;
use anki_proto::scheduler::SimulateFsrsReviewResponse;
use fsrs::simulate;
use fsrs::SimulatorConfig;
use fsrs::DEFAULT_PARAMETERS;
use itertools::Itertools;
use crate::card::CardQueue;
use crate::prelude::*;
use crate::search::SortMode;
@ -22,9 +24,15 @@ impl Collection {
.get_revlog_entries_for_searched_cards_in_card_order()?;
let cards = guard.col.storage.all_searched_cards()?;
drop(guard);
let days_elapsed = self.timing_today().unwrap().days_elapsed as i32;
let converted_cards = cards
.into_iter()
.filter(|c| c.queue != CardQueue::Suspended && c.queue != CardQueue::PreviewRepeat)
.filter_map(|c| Card::convert(c, days_elapsed, req.days_to_simulate))
.collect_vec();
let p = self.get_optimal_retention_parameters(revlogs)?;
let config = SimulatorConfig {
deck_size: req.deck_size as usize,
deck_size: req.deck_size as usize + converted_cards.len(),
learn_span: req.days_to_simulate as usize,
max_cost_perday: f32::MAX,
max_ivl: req.max_interval as f32,
@ -40,7 +48,19 @@ impl Collection {
learn_limit: req.new_limit as usize,
review_limit: req.review_limit as usize,
};
let days_elapsed = self.timing_today().unwrap().days_elapsed as i32;
let parameters = if req.weights.is_empty() {
DEFAULT_PARAMETERS.to_vec()
} else if req.weights.len() != 19 {
if req.weights.len() == 17 {
let mut parameters = req.weights.to_vec();
parameters.extend_from_slice(&[0.0, 0.0]);
parameters
} else {
return Err(AnkiError::FsrsWeightsInvalid);
}
} else {
req.weights.to_vec()
};
let (
accumulated_knowledge_acquisition,
daily_review_count,
@ -48,15 +68,10 @@ impl Collection {
daily_time_cost,
) = simulate(
&config,
&req.weights,
&parameters,
req.desired_retention,
None,
Some(
cards
.into_iter()
.filter_map(|c| Card::convert(c, days_elapsed))
.collect_vec(),
),
Some(converted_cards),
);
Ok(SimulateFsrsReviewResponse {
accumulated_knowledge_acquisition: accumulated_knowledge_acquisition.to_vec(),
@ -68,9 +83,10 @@ impl Collection {
}
impl Card {
fn convert(card: Card, days_elapsed: i32) -> Option<fsrs::Card> {
fn convert(card: Card, days_elapsed: i32, day_to_simulate: u32) -> Option<fsrs::Card> {
match card.memory_state {
Some(state) => {
Some(state) => match card.queue {
CardQueue::DayLearn | CardQueue::Review => {
let due = card.original_or_current_due();
let relative_due = due - days_elapsed;
Some(fsrs::Card {
@ -80,7 +96,29 @@ impl Card {
due: relative_due as f32,
})
}
None => None,
CardQueue::New => Some(fsrs::Card {
difficulty: 1e-10,
stability: 1e-10,
last_date: 0.0,
due: day_to_simulate as f32,
}),
CardQueue::Learn | CardQueue::SchedBuried | CardQueue::UserBuried => {
Some(fsrs::Card {
difficulty: state.difficulty,
stability: state.stability,
last_date: 0.0,
due: 0.0,
})
}
CardQueue::PreviewRepeat => None,
CardQueue::Suspended => None,
},
None => Some(fsrs::Card {
difficulty: 1e-10,
stability: 1e-10,
last_date: 0.0,
due: day_to_simulate as f32,
}),
}
}
}

View file

@ -7,10 +7,15 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
ComputeRetentionProgress,
type ComputeWeightsProgress,
} from "@generated/anki/collection_pb";
import { ComputeOptimalRetentionRequest } from "@generated/anki/scheduler_pb";
import {
ComputeOptimalRetentionRequest,
SimulateFsrsReviewRequest,
type SimulateFsrsReviewResponse,
} from "@generated/anki/scheduler_pb";
import {
computeFsrsWeights,
computeOptimalRetention,
simulateFsrsReview,
evaluateWeights,
setWantsAbort,
} from "@generated/backend";
@ -28,6 +33,14 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
import Warning from "./Warning.svelte";
import WeightsInputRow from "./WeightsInputRow.svelte";
import WeightsSearchRow from "./WeightsSearchRow.svelte";
import { renderSimulationChart, type Point } from "../graphs/simulator";
import Graph from "../graphs/Graph.svelte";
import HoverColumns from "../graphs/HoverColumns.svelte";
import CumulativeOverlay from "../graphs/CumulativeOverlay.svelte";
import AxisTicks from "../graphs/AxisTicks.svelte";
import NoDataOverlay from "../graphs/NoDataOverlay.svelte";
import TableData from "../graphs/TableData.svelte";
import { defaultGraphBounds, type TableDatum } from "../graphs/graph-helpers";
export let state: DeckOptionsState;
export let openHelpModal: (String) => void;
@ -68,6 +81,17 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
optimalRetentionRequest.daysToSimulate = 3650;
}
const simulateFsrsRequest = new SimulateFsrsReviewRequest({
weights: $config.fsrsWeights,
desiredRetention: $config.desiredRetention,
deckSize: 0,
daysToSimulate: 365,
newLimit: $config.newPerDay,
reviewLimit: $config.reviewsPerDay,
maxInterval: $config.maximumReviewInterval,
search: `preset:"${state.getCurrentName()}" -is:suspended`,
});
function getRetentionWarning(retention: number): string {
const decay = -0.5;
const factor = 0.9 ** (1 / decay) - 1;
@ -256,6 +280,69 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
}
return tr.deckConfigPredictedOptimalRetention({ num: retention.toFixed(2) });
}
let tableData: TableDatum[] = [] as any;
const bounds = defaultGraphBounds();
let svg = null as HTMLElement | SVGElement | null;
const title = tr.statisticsReviewsTitle();
let simulationNumber = 0;
let points: Point[] = [];
function movingAverage(y: number[], windowSize: number): number[] {
const result: number[] = [];
for (let i = 0; i < y.length; i++) {
let sum = 0;
let count = 0;
for (let j = Math.max(0, i - windowSize + 1); j <= i; j++) {
sum += y[j];
count++;
}
result.push(sum / count);
}
return result;
}
$: simulateProgressString = "";
async function simulateFsrs(): Promise<void> {
let resp: SimulateFsrsReviewResponse | undefined;
simulationNumber += 1;
try {
await runWithBackendProgress(
async () => {
simulateFsrsRequest.weights = $config.fsrsWeights;
simulateFsrsRequest.desiredRetention = $config.desiredRetention;
simulateFsrsRequest.search = `preset:"${state.getCurrentName()}" -is:suspended`;
simulateProgressString = "processing...";
resp = await simulateFsrsReview(simulateFsrsRequest);
},
() => {},
);
} finally {
if (resp) {
simulateProgressString = "";
const dailyTimeCost = movingAverage(
resp.dailyTimeCost,
Math.round(simulateFsrsRequest.daysToSimulate / 50),
);
points = points.concat(
dailyTimeCost.map((v, i) => ({
x: i,
y: v,
label: simulationNumber,
})),
);
tableData = renderSimulationChart(svg as SVGElement, bounds, points);
}
}
}
function clearSimulation(): void {
points = points.filter((p) => p.label !== simulationNumber);
simulationNumber = Math.max(0, simulationNumber - 1);
tableData = renderSimulationChart(svg as SVGElement, bounds, points);
}
</script>
<SpinBoxFloatRow
@ -377,5 +464,94 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
</details>
</div>
<div class="m-2">
<details>
<summary>FSRS simulator (experimental)</summary>
<SpinBoxRow
bind:value={simulateFsrsRequest.daysToSimulate}
defaultValue={365}
min={1}
max={3650}
>
<SettingTitle on:click={() => openHelpModal("simulateFsrsReview")}>
Days to simulate
</SettingTitle>
</SpinBoxRow>
<SpinBoxRow
bind:value={simulateFsrsRequest.deckSize}
defaultValue={0}
min={1}
max={100000}
>
<SettingTitle on:click={() => openHelpModal("simulateFsrsReview")}>
Additional new cards to simulate
</SettingTitle>
</SpinBoxRow>
<SpinBoxRow
bind:value={simulateFsrsRequest.newLimit}
defaultValue={defaults.newPerDay}
min={0}
max={1000}
>
<SettingTitle on:click={() => openHelpModal("simulateFsrsReview")}>
New cards/day
</SettingTitle>
</SpinBoxRow>
<SpinBoxRow
bind:value={simulateFsrsRequest.reviewLimit}
defaultValue={defaults.reviewsPerDay}
min={0}
max={1000}
>
<SettingTitle on:click={() => openHelpModal("simulateFsrsReview")}>
Maximum reviews/day
</SettingTitle>
</SpinBoxRow>
<SpinBoxRow
bind:value={simulateFsrsRequest.maxInterval}
defaultValue={defaults.maximumReviewInterval}
min={1}
max={36500}
>
<SettingTitle on:click={() => openHelpModal("simulateFsrsReview")}>
Maximum interval
</SettingTitle>
</SpinBoxRow>
<button
class="btn {computing ? 'btn-warning' : 'btn-primary'}"
disabled={computing}
on:click={() => simulateFsrs()}
>
{"Simulate"}
</button>
<button
class="btn {computing ? 'btn-warning' : 'btn-primary'}"
disabled={computing}
on:click={() => clearSimulation()}
>
{"Clear last simulation"}
</button>
<div>{simulateProgressString}</div>
<Graph {title}>
<svg bind:this={svg} viewBox={`0 0 ${bounds.width} ${bounds.height}`}>
<CumulativeOverlay />
<HoverColumns />
<AxisTicks {bounds} />
<NoDataOverlay {bounds} />
</svg>
<TableData {tableData} />
</Graph>
</details>
</div>
<style>
</style>

View file

@ -0,0 +1,220 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
import { localizedNumber } from "@tslib/i18n";
import {
axisBottom,
axisLeft,
bisector,
line,
max,
pointer,
rollup,
scaleLinear,
scaleTime,
schemeCategory10,
select,
timeFormat,
} from "d3";
import type { GraphBounds, TableDatum } from "./graph-helpers";
import { setDataAvailable } from "./graph-helpers";
import { hideTooltip, showTooltip } from "./tooltip";
export interface Point {
x: number;
y: number;
label: number;
}
export function renderSimulationChart(
svgElem: SVGElement,
bounds: GraphBounds,
data: Point[],
): TableDatum[] {
const svg = select(svgElem);
svg.selectAll(".lines").remove();
svg.selectAll(".hover-columns").remove();
svg.selectAll(".focus-line").remove();
svg.selectAll(".legend").remove();
if (data.length == 0) {
setDataAvailable(svg, false);
return [];
}
const trans = svg.transition().duration(600) as any;
// Prepare data
const today = new Date();
const convertedData = data.map(d => ({
...d,
date: new Date(today.getTime() + d.x * 24 * 60 * 60 * 1000),
yMinutes: d.y / 60,
}));
const xMin = today;
const xMax = max(convertedData, d => d.date);
const x = scaleTime()
.domain([xMin, xMax!])
.range([bounds.marginLeft, bounds.width - bounds.marginRight]);
const formatDate = timeFormat("%Y-%m-%d");
svg.select<SVGGElement>(".x-ticks")
.call((selection) =>
selection.transition(trans).call(
axisBottom(x)
.ticks(7)
.tickFormat((d: any) => formatDate(d))
.tickSizeOuter(0),
)
)
.attr("direction", "ltr");
// y scale
const yTickFormat = (n: number): string => {
if (Math.round(n) != n) {
return "";
} else {
return localizedNumber(n);
}
};
const yMax = max(convertedData, d => d.yMinutes)!;
const y = scaleLinear()
.range([bounds.height - bounds.marginBottom, bounds.marginTop])
.domain([0, yMax])
.nice();
svg.select<SVGGElement>(".y-ticks")
.call((selection) =>
selection.transition(trans).call(
axisLeft(y)
.ticks(bounds.height / 50)
.tickSizeOuter(0)
.tickFormat(yTickFormat as any),
)
)
.attr("direction", "ltr");
svg.select(".y-ticks")
.append("text")
.attr("class", "y-axis-title")
.attr("transform", "rotate(-90)")
.attr("y", 0 - bounds.marginLeft)
.attr("x", 0 - (bounds.height / 2))
.attr("dy", "1em")
.attr("fill", "currentColor")
.style("text-anchor", "middle")
.text("Review Time per day (minutes)");
// x lines
const points = convertedData.map((d) => [x(d.date), y(d.yMinutes), d.label]);
const groups = rollup(points, v => Object.assign(v, { z: v[0][2] }), d => d[2]);
const color = schemeCategory10;
svg.append("g")
.attr("class", "lines")
.attr("fill", "none")
.attr("stroke-width", 1.5)
.attr("stroke-linejoin", "round")
.attr("stroke-linecap", "round")
.selectAll("path")
.data(Array.from(groups.entries()))
.join("path")
.style("mix-blend-mode", "multiply")
.attr("stroke", (d, i) => color[i % color.length])
.attr("d", d => line()(d[1].map(p => [p[0], p[1]])))
.attr("data-group", d => d[0]);
const focusLine = svg.append("line")
.attr("class", "focus-line")
.attr("y1", bounds.marginTop)
.attr("y2", bounds.height - bounds.marginBottom)
.attr("stroke", "black")
.attr("stroke-width", 1)
.style("opacity", 0);
const LongestGroupData = Array.from(groups.values()).reduce((a, b) => a.length > b.length ? a : b);
const barWidth = bounds.width / LongestGroupData.length;
// hover/tooltip
svg.append("g")
.attr("class", "hover-columns")
.selectAll("rect")
.data(LongestGroupData)
.join("rect")
.attr("x", d => d[0] - barWidth / 2)
.attr("y", bounds.marginTop)
.attr("width", barWidth)
.attr("height", bounds.height - bounds.marginTop - bounds.marginBottom)
.attr("fill", "transparent")
.on("mousemove", mousemove)
.on("mouseout", hideTooltip);
function mousemove(event: MouseEvent, d: any): void {
pointer(event, document.body);
const date = x.invert(d[0]);
const groupData: { [key: string]: number } = {};
groups.forEach((groupPoints, key) => {
const bisect = bisector((d: number[]) => x.invert(d[0])).left;
const index = bisect(groupPoints, date);
const dataPoint = groupPoints[index - 1] || groupPoints[index];
if (dataPoint) {
groupData[key] = y.invert(dataPoint[1]);
}
});
focusLine.attr("x1", d[0]).attr("x2", d[0]).style("opacity", 1);
let tooltipContent = `Date: ${timeFormat("%Y-%m-%d")(date)}<br>`;
for (const [key, value] of Object.entries(groupData)) {
tooltipContent += `Simulation ${key}: ${value.toFixed(2)} minutes<br>`;
}
showTooltip(tooltipContent, event.pageX, event.pageY);
}
const legend = svg.append("g")
.attr("class", "legend")
.attr("font-family", "sans-serif")
.attr("font-size", 10)
.attr("text-anchor", "start")
.selectAll("g")
.data(Array.from(groups.keys()))
.join("g")
.attr("transform", (d, i) => `translate(0,${i * 20})`)
.attr("cursor", "pointer")
.on("click", (event, d) => toggleGroup(event, d));
legend.append("rect")
.attr("x", bounds.width - bounds.marginRight + 10)
.attr("width", 19)
.attr("height", 19)
.attr("fill", (d, i) => color[i % color.length]);
legend.append("text")
.attr("x", bounds.width - bounds.marginRight + 34)
.attr("y", 9.5)
.attr("dy", "0.32em")
.text(d => `Simulation ${d}`);
const toggleGroup = (event: MouseEvent, d: number) => {
const group = d;
const path = svg.select(`path[data-group="${group}"]`);
const hidden = path.classed("hidden");
const target = event.currentTarget as HTMLElement;
path.classed("hidden", !hidden);
path.style("display", () => hidden ? null : "none");
select(target).select("rect")
.style("opacity", hidden ? 1 : 0.5);
};
setDataAvailable(svg, true);
const tableData: TableDatum[] = [];
return tableData;
}