diff --git a/rslib/src/import_export/gather.rs b/rslib/src/import_export/gather.rs index 38251d3bb..90c775159 100644 --- a/rslib/src/import_export/gather.rs +++ b/rslib/src/import_export/gather.rs @@ -11,15 +11,13 @@ use itertools::Itertools; use crate::{ card::{CardQueue, CardType}, decks::NormalDeck, + io::filename_is_safe, latex::extract_latex, prelude::*, revlog::RevlogEntry, search::{Negated, SearchNode, SortMode}, storage::ids_to_string, - text::{ - extract_media_refs, extract_underscored_css_imports, extract_underscored_references, - is_remote_filename, - }, + text::{extract_media_refs, extract_underscored_css_imports, extract_underscored_references}, }; #[derive(Debug, Default)] @@ -49,20 +47,7 @@ fn optional_deck_search(deck_id: Option) -> SearchNode { } } -fn is_local_base_name(name: &str) -> bool { - !is_remote_filename(name) && Path::new(name).parent().is_none() -} - impl ExportData { - /* - pub(super) fn new(, media_folder: Option) -> Self { - Self { - with_scheduling, - media_folder, - ..Default::default() - } - } - */ pub(super) fn gather_data( &mut self, col: &mut Collection, @@ -87,7 +72,7 @@ impl ExportData { pub(super) fn gather_media_paths(&mut self, media_folder: &Path) { let mut inserter = |name: &str| { - if is_local_base_name(name) { + if filename_is_safe(name) { self.media_paths.insert(media_folder.join(name)); } }; diff --git a/rslib/src/import_export/package/colpkg/import.rs b/rslib/src/import_export/package/colpkg/import.rs index 7664637a4..78f923638 100644 --- a/rslib/src/import_export/package/colpkg/import.rs +++ b/rslib/src/import_export/package/colpkg/import.rs @@ -6,7 +6,7 @@ use std::{ collections::HashMap, fs::{self, File}, io::{self, Read, Write}, - path::{Component, Path, PathBuf}, + path::{Path, PathBuf}, }; use prost::Message; @@ -21,7 +21,7 @@ use crate::{ package::{MediaEntries, MediaEntry, Meta}, ImportProgress, }, - io::{atomic_rename, tempfile_in_parent_of}, + io::{atomic_rename, filename_is_safe, tempfile_in_parent_of}, media::files::normalize_filename, prelude::*, }; @@ -149,7 +149,9 @@ fn restore_media_file(meta: &Meta, zip_file: &mut ZipFile, path: &Path) -> Resul impl MediaEntry { fn safe_normalized_file_path(&self, meta: &Meta, media_folder: &Path) -> Result { - check_filename_safe(&self.name)?; + if !filename_is_safe(&self.name) { + return Err(AnkiError::ImportError(ImportError::Corrupt)); + } let normalized = maybe_normalizing(&self.name, meta.strict_media_checks())?; Ok(media_folder.join(normalized.as_ref())) } @@ -179,20 +181,6 @@ fn maybe_normalizing(name: &str, strict: bool) -> Result> { } } -/// Return an error if name contains any path separators. -fn check_filename_safe(name: &str) -> Result<()> { - let mut components = Path::new(name).components(); - let first_element_normal = components - .next() - .map(|component| matches!(component, Component::Normal(_))) - .unwrap_or_default(); - if !first_element_normal || components.next().is_some() { - Err(AnkiError::ImportError(ImportError::Corrupt)) - } else { - Ok(()) - } -} - fn extract_media_entries(meta: &Meta, archive: &mut ZipArchive) -> Result> { let mut file = archive.by_name("media")?; let mut buf = Vec::new(); @@ -251,22 +239,6 @@ fn copy_collection( mod test { use super::*; - #[test] - fn path_traversal() { - assert!(check_filename_safe("foo").is_ok(),); - - assert!(check_filename_safe("..").is_err()); - assert!(check_filename_safe("foo/bar").is_err()); - assert!(check_filename_safe("/foo").is_err()); - assert!(check_filename_safe("../foo").is_err()); - - if cfg!(windows) { - assert!(check_filename_safe("foo\\bar").is_err()); - assert!(check_filename_safe("c:\\foo").is_err()); - assert!(check_filename_safe("\\foo").is_err()); - } - } - #[test] fn normalization() { assert_eq!(&maybe_normalizing("con", false).unwrap(), "con_"); diff --git a/rslib/src/io.rs b/rslib/src/io.rs index 8a68ffdb3..fc86c5db6 100644 --- a/rslib/src/io.rs +++ b/rslib/src/io.rs @@ -1,7 +1,7 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -use std::path::Path; +use std::path::{Component, Path}; use tempfile::NamedTempFile; @@ -42,6 +42,17 @@ pub(crate) fn read_dir_files(path: impl AsRef) -> std::io::Result bool { + let mut components = Path::new(name).components(); + let first_element_normal = components + .next() + .map(|component| matches!(component, Component::Normal(_))) + .unwrap_or_default(); + + first_element_normal && components.next().is_none() +} + pub(crate) struct ReadDirFiles(std::fs::ReadDir); impl Iterator for ReadDirFiles { @@ -60,3 +71,24 @@ impl Iterator for ReadDirFiles { } } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn path_traversal() { + assert!(filename_is_safe("foo")); + + assert!(!filename_is_safe("..")); + assert!(!filename_is_safe("foo/bar")); + assert!(!filename_is_safe("/foo")); + assert!(!filename_is_safe("../foo")); + + if cfg!(windows) { + assert!(!filename_is_safe("foo\\bar")); + assert!(!filename_is_safe("c:\\foo")); + assert!(!filename_is_safe("\\foo")); + } + } +} diff --git a/rslib/src/text.rs b/rslib/src/text.rs index c15ac7edf..f4496837c 100644 --- a/rslib/src/text.rs +++ b/rslib/src/text.rs @@ -437,10 +437,6 @@ lazy_static! { pub(crate) static ref REMOTE_FILENAME: Regex = Regex::new("(?i)^https?://").unwrap(); } -pub(crate) fn is_remote_filename(name: &str) -> bool { - REMOTE_FILENAME.is_match(name) -} - /// IRI-encode unescaped local paths in HTML fragment. pub(crate) fn encode_iri_paths(unescaped_html: &str) -> Cow { transform_html_paths(unescaped_html, |fname| {