Store desired retention in card data

If we want to be able to factor the desired retention into a sort based
on relative overdueness, having the values accessible on the card makes
things easier.
This commit is contained in:
Damien Elmes 2023-09-23 15:12:48 +10:00
parent c78de23cf9
commit 03edb7bf9e
10 changed files with 50 additions and 20 deletions

View file

@ -50,6 +50,7 @@ message Card {
uint32 flags = 17; uint32 flags = 17;
optional uint32 original_position = 18; optional uint32 original_position = 18;
optional FsrsMemoryState memory_state = 20; optional FsrsMemoryState memory_state = 20;
optional float desired_retention = 21;
string custom_data = 19; string custom_data = 19;
} }

View file

@ -46,6 +46,7 @@ class Card(DeprecatedNamesMixin):
queue: CardQueue queue: CardQueue
type: CardType type: CardType
memory_state: FSRSMemoryState | None memory_state: FSRSMemoryState | None
desired_retention: float | None
def __init__( def __init__(
self, self,
@ -96,6 +97,9 @@ class Card(DeprecatedNamesMixin):
) )
self.custom_data = card.custom_data self.custom_data = card.custom_data
self.memory_state = card.memory_state if card.HasField("memory_state") else None self.memory_state = card.memory_state if card.HasField("memory_state") else None
self.desired_retention = (
card.desired_retention if card.HasField("desired_retention") else None
)
def _to_backend_card(self) -> cards_pb2.Card: def _to_backend_card(self) -> cards_pb2.Card:
# mtime & usn are set by backend # mtime & usn are set by backend
@ -118,6 +122,7 @@ class Card(DeprecatedNamesMixin):
original_position=self.original_position, original_position=self.original_position,
custom_data=self.custom_data, custom_data=self.custom_data,
memory_state=self.memory_state, memory_state=self.memory_state,
desired_retention=self.desired_retention,
) )
def flush(self) -> None: def flush(self) -> None:

View file

@ -94,6 +94,7 @@ pub struct Card {
/// The position in the new queue before leaving it. /// The position in the new queue before leaving it.
pub(crate) original_position: Option<u32>, pub(crate) original_position: Option<u32>,
pub(crate) memory_state: Option<FsrsMemoryState>, pub(crate) memory_state: Option<FsrsMemoryState>,
pub(crate) desired_retention: Option<f32>,
/// JSON object or empty; exposed through the reviewer for persisting custom /// JSON object or empty; exposed through the reviewer for persisting custom
/// state /// state
pub(crate) custom_data: String, pub(crate) custom_data: String,
@ -143,6 +144,7 @@ impl Default for Card {
flags: 0, flags: 0,
original_position: None, original_position: None,
memory_state: None, memory_state: None,
desired_retention: None,
custom_data: String::new(), custom_data: String::new(),
} }
} }

View file

@ -101,6 +101,7 @@ impl TryFrom<anki_proto::cards::Card> for Card {
flags: c.flags as u8, flags: c.flags as u8,
original_position: c.original_position, original_position: c.original_position,
memory_state: c.memory_state.map(Into::into), memory_state: c.memory_state.map(Into::into),
desired_retention: c.desired_retention,
custom_data: c.custom_data, custom_data: c.custom_data,
}) })
} }
@ -128,6 +129,7 @@ impl From<Card> for anki_proto::cards::Card {
flags: c.flags as u32, flags: c.flags as u32,
original_position: c.original_position.map(Into::into), original_position: c.original_position.map(Into::into),
memory_state: c.memory_state.map(Into::into), memory_state: c.memory_state.map(Into::into),
desired_retention: c.desired_retention,
custom_data: c.custom_data, custom_data: c.custom_data,
} }
} }

View file

@ -15,7 +15,7 @@ use anki_proto::decks::deck::normal::DayLimit;
use crate::config::StringKey; use crate::config::StringKey;
use crate::decks::NormalDeck; use crate::decks::NormalDeck;
use crate::prelude::*; use crate::prelude::*;
use crate::scheduler::fsrs::weights::Weights; use crate::scheduler::fsrs::memory_state::WeightsAndDesiredRetention;
use crate::search::JoinSearches; use crate::search::JoinSearches;
use crate::search::SearchNode; use crate::search::SearchNode;
@ -216,19 +216,20 @@ impl Collection {
} }
if !decks_needing_memory_recompute.is_empty() { if !decks_needing_memory_recompute.is_empty() {
let input: Vec<(Option<Weights>, Vec<SearchNode>)> = decks_needing_memory_recompute let input: Vec<(Option<WeightsAndDesiredRetention>, Vec<SearchNode>)> =
.into_iter() decks_needing_memory_recompute
.map(|(conf_id, search)| { .into_iter()
let weights = configs_after_update.get(&conf_id).and_then(|c| { .map(|(conf_id, search)| {
if input.fsrs { let weights = configs_after_update.get(&conf_id).and_then(|c| {
Some(c.inner.fsrs_weights.clone()) if input.fsrs {
} else { Some((c.inner.fsrs_weights.clone(), c.inner.desired_retention))
None } else {
} None
}); }
Ok((weights, search)) });
}) Ok((weights, search))
.collect::<Result<_>>()?; })
.collect::<Result<_>>()?;
self.update_memory_state(input)?; self.update_memory_state(input)?;
} }

View file

@ -68,6 +68,8 @@ struct CardStateUpdater {
fuzz_seed: Option<u64>, fuzz_seed: Option<u64>,
/// Set if FSRS is enabled. /// Set if FSRS is enabled.
fsrs_next_states: Option<NextStates>, fsrs_next_states: Option<NextStates>,
/// Set if FSRS is enabled.
desired_retention: Option<f32>,
} }
impl CardStateUpdater { impl CardStateUpdater {
@ -159,6 +161,7 @@ impl CardStateUpdater {
) -> RevlogEntryPartial { ) -> RevlogEntryPartial {
self.card.reps += 1; self.card.reps += 1;
self.card.original_due = 0; self.card.original_due = 0;
self.card.desired_retention = self.desired_retention;
let revlog = match next { let revlog = match next {
NormalState::New(next) => self.apply_new_state(current, next), NormalState::New(next) => self.apply_new_state(current, next),
@ -351,7 +354,8 @@ impl Collection {
.get_deck(card.deck_id)? .get_deck(card.deck_id)?
.or_not_found(card.deck_id)?; .or_not_found(card.deck_id)?;
let config = self.home_deck_config(deck.config_id(), card.original_deck_id)?; let config = self.home_deck_config(deck.config_id(), card.original_deck_id)?;
let fsrs_next_states = if self.get_config_bool(BoolKey::Fsrs) { let fsrs_enabled = self.get_config_bool(BoolKey::Fsrs);
let fsrs_next_states = if fsrs_enabled {
let fsrs = FSRS::new(Some(&config.inner.fsrs_weights))?; let fsrs = FSRS::new(Some(&config.inner.fsrs_weights))?;
let memory_state = if let Some(state) = card.memory_state { let memory_state = if let Some(state) = card.memory_state {
Some(MemoryState::from(state)) Some(MemoryState::from(state))
@ -373,7 +377,7 @@ impl Collection {
} else { } else {
None None
}; };
let desired_retention = fsrs_enabled.then_some(config.inner.desired_retention);
Ok(CardStateUpdater { Ok(CardStateUpdater {
fuzz_seed: get_fuzz_seed(&card), fuzz_seed: get_fuzz_seed(&card),
card, card,
@ -382,6 +386,7 @@ impl Collection {
timing, timing,
now: TimestampSecs::now(), now: TimestampSecs::now(),
fsrs_next_states, fsrs_next_states,
desired_retention,
}) })
} }

View file

@ -17,6 +17,8 @@ pub struct ComputeMemoryProgress {
pub total_cards: u32, pub total_cards: u32,
} }
pub(crate) type WeightsAndDesiredRetention = (Weights, f32);
impl Collection { impl Collection {
/// For each provided set of weights, locate cards with the provided search, /// For each provided set of weights, locate cards with the provided search,
/// and update their memory state. /// and update their memory state.
@ -25,27 +27,30 @@ impl Collection {
/// memory state should be removed. /// memory state should be removed.
pub(crate) fn update_memory_state( pub(crate) fn update_memory_state(
&mut self, &mut self,
entries: Vec<(Option<Weights>, Vec<SearchNode>)>, entries: Vec<(Option<WeightsAndDesiredRetention>, Vec<SearchNode>)>,
) -> Result<()> { ) -> Result<()> {
let timing = self.timing_today()?; let timing = self.timing_today()?;
let usn = self.usn()?; let usn = self.usn()?;
for (weights, search) in entries { for (weights_and_desired_retention, search) in entries {
let search = SearchBuilder::any(search.into_iter()) let search = SearchBuilder::any(search.into_iter())
.and(SearchNode::State(StateKind::New).negated()); .and(SearchNode::State(StateKind::New).negated());
let revlog = self.revlog_for_srs(search)?; let revlog = self.revlog_for_srs(search)?;
let items = fsrs_items_for_memory_state(revlog, timing.next_day_at); let items = fsrs_items_for_memory_state(revlog, timing.next_day_at);
let fsrs = FSRS::new(weights.as_deref())?; let desired_retention = weights_and_desired_retention.as_ref().map(|w| w.1);
let fsrs = FSRS::new(weights_and_desired_retention.as_ref().map(|w| &w.0[..]))?;
let mut progress = self.new_progress_handler::<ComputeMemoryProgress>(); let mut progress = self.new_progress_handler::<ComputeMemoryProgress>();
progress.update(false, |s| s.total_cards = items.len() as u32)?; progress.update(false, |s| s.total_cards = items.len() as u32)?;
for (idx, (card_id, item)) in items.into_iter().enumerate() { for (idx, (card_id, item)) in items.into_iter().enumerate() {
progress.update(true, |state| state.current_cards = idx as u32 + 1)?; progress.update(true, |state| state.current_cards = idx as u32 + 1)?;
let mut card = self.storage.get_card(card_id)?.or_not_found(card_id)?; let mut card = self.storage.get_card(card_id)?.or_not_found(card_id)?;
let original = card.clone(); let original = card.clone();
if weights.is_some() { if weights_and_desired_retention.is_some() {
let state = fsrs.memory_state(item); let state = fsrs.memory_state(item);
card.memory_state = Some(state.into()); card.memory_state = Some(state.into());
card.desired_retention = desired_retention;
} else { } else {
card.memory_state = None; card.memory_state = None;
card.desired_retention = None;
} }
self.update_card_inner(&mut card, original, usn)?; self.update_card_inner(&mut card, original, usn)?;
} }

View file

@ -38,6 +38,12 @@ pub(crate) struct CardData {
deserialize_with = "default_on_invalid" deserialize_with = "default_on_invalid"
)] )]
pub(crate) fsrs_difficulty: Option<f32>, pub(crate) fsrs_difficulty: Option<f32>,
#[serde(
rename = "dr",
skip_serializing_if = "Option::is_none",
deserialize_with = "default_on_invalid"
)]
pub(crate) fsrs_desired_retention: Option<f32>,
/// A string representation of a JSON object storing optional data /// A string representation of a JSON object storing optional data
/// associated with the card, so v3 custom scheduling code can persist /// associated with the card, so v3 custom scheduling code can persist
@ -52,6 +58,7 @@ impl CardData {
original_position: card.original_position, original_position: card.original_position,
fsrs_stability: card.memory_state.as_ref().map(|m| m.stability), fsrs_stability: card.memory_state.as_ref().map(|m| m.stability),
fsrs_difficulty: card.memory_state.as_ref().map(|m| m.difficulty), fsrs_difficulty: card.memory_state.as_ref().map(|m| m.difficulty),
fsrs_desired_retention: card.desired_retention,
custom_data: card.custom_data.clone(), custom_data: card.custom_data.clone(),
} }
} }

View file

@ -81,6 +81,7 @@ fn row_to_card(row: &Row) -> result::Result<Card, rusqlite::Error> {
flags: row.get(16)?, flags: row.get(16)?,
original_position: data.original_position, original_position: data.original_position,
memory_state: data.memory_state(), memory_state: data.memory_state(),
desired_retention: data.fsrs_desired_retention,
custom_data: data.custom_data, custom_data: data.custom_data,
}) })
} }

View file

@ -331,6 +331,7 @@ impl From<CardEntry> for Card {
flags: e.flags, flags: e.flags,
original_position: data.original_position, original_position: data.original_position,
memory_state: data.memory_state(), memory_state: data.memory_state(),
desired_retention: data.fsrs_desired_retention,
custom_data: data.custom_data, custom_data: data.custom_data,
} }
} }