Batch both max # of items processed and max # of items passed to fsrs

This commit is contained in:
Daniel Pechersky 2025-09-15 20:59:32 +07:00
parent dda1324872
commit f54e79c737
4 changed files with 140 additions and 94 deletions

7
Cargo.lock generated
View file

@ -124,6 +124,7 @@ dependencies = [
"once_cell", "once_cell",
"pbkdf2", "pbkdf2",
"percent-encoding-iri", "percent-encoding-iri",
"permutation",
"phf 0.11.3", "phf 0.11.3",
"pin-project", "pin-project",
"prettyplease", "prettyplease",
@ -4560,6 +4561,12 @@ name = "percent-encoding-iri"
version = "2.2.0" version = "2.2.0"
source = "git+https://github.com/ankitects/rust-url.git?rev=bb930b8d089f4d30d7d19c12e54e66191de47b88#bb930b8d089f4d30d7d19c12e54e66191de47b88" source = "git+https://github.com/ankitects/rust-url.git?rev=bb930b8d089f4d30d7d19c12e54e66191de47b88#bb930b8d089f4d30d7d19c12e54e66191de47b88"
[[package]]
name = "permutation"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df202b0b0f5b8e389955afd5f27b007b00fb948162953f1db9c70d2c7e3157d7"
[[package]] [[package]]
name = "pest" name = "pest"
version = "2.8.1" version = "2.8.1"

View file

@ -102,6 +102,7 @@ num_cpus = "1.17.0"
num_enum = "0.7.3" num_enum = "0.7.3"
once_cell = "1.21.3" once_cell = "1.21.3"
pbkdf2 = { version = "0.12", features = ["simple"] } pbkdf2 = { version = "0.12", features = ["simple"] }
permutation = "0.4.1"
phf = { version = "0.11.3", features = ["macros"] } phf = { version = "0.11.3", features = ["macros"] }
pin-project = "1.1.10" pin-project = "1.1.10"
prettyplease = "0.2.34" prettyplease = "0.2.34"

View file

@ -76,6 +76,7 @@ num_enum.workspace = true
once_cell.workspace = true once_cell.workspace = true
pbkdf2.workspace = true pbkdf2.workspace = true
percent-encoding-iri.workspace = true percent-encoding-iri.workspace = true
permutation.workspace = true
phf.workspace = true phf.workspace = true
pin-project.workspace = true pin-project.workspace = true
prost.workspace = true prost.workspace = true

View file

@ -58,6 +58,18 @@ pub(crate) struct UpdateMemoryStateEntry {
pub ignore_before: TimestampMillis, pub ignore_before: TimestampMillis,
} }
trait ChunkIntoVecs<T> {
fn chunk_into_vecs(&mut self, chunk_size: usize) -> impl Iterator<Item = Vec<T>>;
}
impl<T> ChunkIntoVecs<T> for Vec<T> {
fn chunk_into_vecs(&mut self, chunk_size: usize) -> impl Iterator<Item = Vec<T>> {
std::iter::from_fn(move || {
(!self.is_empty()).then(|| self.split_off(chunk_size.min(self.len())))
})
}
}
impl Collection { impl Collection {
/// For each provided set of params, locate cards with the provided search, /// For each provided set of params, locate cards with the provided search,
/// and update their memory state. /// and update their memory state.
@ -68,6 +80,9 @@ impl Collection {
&mut self, &mut self,
entries: Vec<UpdateMemoryStateEntry>, entries: Vec<UpdateMemoryStateEntry>,
) -> Result<()> { ) -> Result<()> {
const ITEM_CHUNK_SIZE: usize = 100_000;
const FSRS_CHUNK_SIZE: usize = 1000;
let timing = self.timing_today()?; let timing = self.timing_today()?;
let usn = self.usn()?; let usn = self.usn()?;
for UpdateMemoryStateEntry { for UpdateMemoryStateEntry {
@ -88,7 +103,7 @@ impl Collection {
let fsrs = FSRS::new(req.as_ref().map(|w| &w.params[..]).or(Some([].as_slice())))?; let fsrs = FSRS::new(req.as_ref().map(|w| &w.params[..]).or(Some([].as_slice())))?;
let decay = req.as_ref().map(|w| get_decay_from_params(&w.params)); let decay = req.as_ref().map(|w| get_decay_from_params(&w.params));
let historical_retention = req.as_ref().map(|w| w.historical_retention); let historical_retention = req.as_ref().map(|w| w.historical_retention);
let items = fsrs_items_for_memory_states( let mut items = fsrs_items_for_memory_states(
&fsrs, &fsrs,
revlog, revlog,
timing.next_day_at, timing.next_day_at,
@ -114,103 +129,125 @@ impl Collection {
let mut to_update = Vec::new(); let mut to_update = Vec::new();
let mut fsrs_items = Vec::new(); let mut fsrs_items = Vec::new();
let mut starting_states = Vec::new(); let mut starting_states = Vec::new();
for (idx, (card_id, item)) in items.into_iter().enumerate() { for (i, items) in items.chunk_into_vecs(ITEM_CHUNK_SIZE).enumerate() {
progress.update(true, |state| state.current_cards = idx as u32 + 1)?; progress.update(true, |state| {
let mut card = self.storage.get_card(card_id)?.or_not_found(card_id)?; let end_of_chunk_index = i * ITEM_CHUNK_SIZE + items.len();
let original = card.clone(); state.current_cards = end_of_chunk_index as u32 + 1
})?;
for (card_id, item) in items.into_iter() {
let mut card = self.storage.get_card(card_id)?.or_not_found(card_id)?;
let original = card.clone();
// Store decay and desired retention in the card so that add-ons, card info, // Store decay and desired retention in the card so that add-ons, card info,
// stats and browser search/sorts don't need to access the deck config. // stats and browser search/sorts don't need to access the deck config.
// Unlike memory states, scheduler doesn't use decay and dr stored in the card. // Unlike memory states, scheduler doesn't use decay and dr stored in the card.
let deck_id = card.original_or_current_deck_id(); let deck_id = card.original_or_current_deck_id();
let desired_retention = *req let desired_retention = *req
.deck_desired_retention .deck_desired_retention
.get(&deck_id) .get(&deck_id)
.unwrap_or(&preset_desired_retention); .unwrap_or(&preset_desired_retention);
card.desired_retention = Some(desired_retention); card.desired_retention = Some(desired_retention);
card.decay = decay; card.decay = decay;
if let Some(item) = item { if let Some(item) = item {
to_update.push((card, original)); to_update.push((card, original));
fsrs_items.push(item.item); fsrs_items.push(item.item);
starting_states.push(item.starting_state); starting_states.push(item.starting_state);
} else {
// clear memory states if item is None
card.memory_state = None;
self.update_card_inner(&mut card, original, usn)?;
}
}
let memory_states = fsrs.memory_state_batch(fsrs_items, starting_states)?;
for ((mut card, original), memory_state) in to_update.into_iter().zip(memory_states) {
card.memory_state = Some(memory_state.into());
'reschedule_card: {
// if rescheduling
let Some(reviews) = &last_revlog_info else {
break 'reschedule_card;
};
// and we have a last review time for the card
let Some(last_info) = reviews.get(&card.id) else {
break 'reschedule_card;
};
let Some(last_review) = &last_info.last_reviewed_at else {
break 'reschedule_card;
};
// or in (re)learning
if card.ctype != CardType::Review {
break 'reschedule_card;
};
let deck = self
.get_deck(card.original_or_current_deck_id())?
.or_not_found(card.original_or_current_deck_id())?;
let deckconfig_id = deck.config_id().unwrap();
// reschedule it
let days_elapsed = timing.next_day_at.elapsed_days_since(*last_review) as i32;
let original_interval = card.interval;
let interval = fsrs.next_interval(
Some(memory_state.stability),
card.desired_retention
.expect("We set desired retention above"),
0,
);
card.interval = rescheduler
.as_mut()
.and_then(|r| {
r.find_interval(
interval,
1,
req.max_interval,
days_elapsed as u32,
deckconfig_id,
get_fuzz_seed(&card, true),
)
})
.unwrap_or_else(|| {
with_review_fuzz(
card.get_fuzz_factor(true),
interval,
1,
req.max_interval,
)
});
let due = if card.original_due != 0 {
&mut card.original_due
} else { } else {
&mut card.due // clear memory states if item is None
}; card.memory_state = None;
let new_due = self.update_card_inner(&mut card, original, usn)?;
(timing.days_elapsed as i32) - days_elapsed + card.interval as i32; }
if let Some(rescheduler) = &mut rescheduler { }
rescheduler.update_due_cnt_per_day(*due, new_due, deckconfig_id);
// fsrs.memory_state_batch is O(nm) where n is the number of cards and m is the max review count between all items.
// Therefore we want to pass batches to fsrs.memory_state_batch where the review count is relatively even.
let mut p =
permutation::sort_unstable_by_key(&fsrs_items, |item| item.reviews.len());
p.apply_slice_in_place(&mut to_update);
p.apply_slice_in_place(&mut fsrs_items);
p.apply_slice_in_place(&mut starting_states);
for ((to_update, fsrs_items), starting_states) in to_update
.chunk_into_vecs(FSRS_CHUNK_SIZE)
.zip_eq(fsrs_items.chunk_into_vecs(FSRS_CHUNK_SIZE))
.zip_eq(starting_states.chunk_into_vecs(FSRS_CHUNK_SIZE))
{
let memory_states = fsrs.memory_state_batch(fsrs_items, starting_states)?;
for ((mut card, original), memory_state) in
to_update.into_iter().zip(memory_states)
{
card.memory_state = Some(memory_state.into());
'reschedule_card: {
// if rescheduling
let Some(reviews) = &last_revlog_info else {
break 'reschedule_card;
};
// and we have a last review time for the card
let Some(last_info) = reviews.get(&card.id) else {
break 'reschedule_card;
};
let Some(last_review) = &last_info.last_reviewed_at else {
break 'reschedule_card;
};
// or in (re)learning
if card.ctype != CardType::Review {
break 'reschedule_card;
};
let deck = self
.get_deck(card.original_or_current_deck_id())?
.or_not_found(card.original_or_current_deck_id())?;
let deckconfig_id = deck.config_id().unwrap();
// reschedule it
let days_elapsed =
timing.next_day_at.elapsed_days_since(*last_review) as i32;
let original_interval = card.interval;
let interval = fsrs.next_interval(
Some(memory_state.stability),
card.desired_retention
.expect("We set desired retention above"),
0,
);
card.interval = rescheduler
.as_mut()
.and_then(|r| {
r.find_interval(
interval,
1,
req.max_interval,
days_elapsed as u32,
deckconfig_id,
get_fuzz_seed(&card, true),
)
})
.unwrap_or_else(|| {
with_review_fuzz(
card.get_fuzz_factor(true),
interval,
1,
req.max_interval,
)
});
let due = if card.original_due != 0 {
&mut card.original_due
} else {
&mut card.due
};
let new_due =
(timing.days_elapsed as i32) - days_elapsed + card.interval as i32;
if let Some(rescheduler) = &mut rescheduler {
rescheduler.update_due_cnt_per_day(*due, new_due, deckconfig_id);
}
*due = new_due;
// Add a rescheduled revlog entry
self.log_rescheduled_review(&card, original_interval, usn)?;
}
self.update_card_inner(&mut card, original, usn)?;
} }
*due = new_due;
// Add a rescheduled revlog entry
self.log_rescheduled_review(&card, original_interval, usn)?;
} }
self.update_card_inner(&mut card, original, usn)?;
} }
} }
Ok(()) Ok(())