refact(fs): file transfer refactor

1. Hide `files` in `new_write()`.
2. Remove duplicted code. - `resolve_entry_path()`.
3. Remove unseless `inline`.
4. Add comments.
5. Reduce `#[cfg]`.

Signed-off-by: fufesou <linlong1266@gmail.com>
This commit is contained in:
fufesou
2026-04-09 14:48:52 +08:00
parent 4b6eb8f909
commit 9633ad27ff

138
src/fs.rs
View File

@@ -398,7 +398,7 @@ pub struct TransferJob {
pub is_resume: bool,
pub file_num: i32,
#[serde(skip_serializing)]
pub files: Vec<FileEntry>,
files: Vec<FileEntry>,
pub conn_id: i32, // server only
#[serde(skip_serializing)]
@@ -457,25 +457,15 @@ fn is_compressed_file(name: &str) -> bool {
compressed_exts.contains(&ext)
}
#[inline]
fn validate_file_name_no_traversal(name: &str) -> ResultType<()> {
pub fn validate_file_name_no_traversal(name: &str) -> ResultType<()> {
if name.bytes().any(|b| b == 0) {
bail!("file name contains null bytes");
}
#[cfg(windows)]
if name
.split(|c| c == '/' || c == '\\')
let has_traversal = name
.split(|c: char| c == '/' || (cfg!(windows) && c == '\\'))
.filter(|s| !s.is_empty())
.any(|component| component == "..")
{
bail!("path traversal detected in file name");
}
#[cfg(not(windows))]
if name
.split('/')
.filter(|s| !s.is_empty())
.any(|component| component == "..")
{
.any(|component| component == "..");
if has_traversal {
bail!("path traversal detected in file name");
}
#[cfg(windows)]
@@ -497,8 +487,9 @@ fn validate_file_name_no_traversal(name: &str) -> ResultType<()> {
Ok(())
}
#[inline]
fn validate_transfer_file_names(files: &[FileEntry]) -> ResultType<()> {
// Single-file transfer may use an empty relative name, because
// the destination file path is carried by transfer metadata.
if files.len() == 1 && files.first().map_or(false, |f| f.name.is_empty()) {
return Ok(());
}
@@ -511,7 +502,6 @@ fn validate_transfer_file_names(files: &[FileEntry]) -> ResultType<()> {
Ok(())
}
#[inline]
fn validate_no_symlink_components(base: &PathBuf, name: &str) -> ResultType<()> {
if name.is_empty() {
return Ok(());
@@ -521,6 +511,8 @@ fn validate_no_symlink_components(base: &PathBuf, name: &str) -> ResultType<()>
match component {
std::path::Component::Normal(seg) => {
current.push(seg);
// Best-effort guard: path-based checks are inherently TOCTOU-prone
// if local filesystem state changes between validation and write.
if let Ok(meta) = std::fs::symlink_metadata(&current) {
if meta.file_type().is_symlink() {
bail!("symlink path component is not allowed");
@@ -536,7 +528,6 @@ fn validate_no_symlink_components(base: &PathBuf, name: &str) -> ResultType<()>
Ok(())
}
#[inline]
fn join_validated_path(base: &PathBuf, name: &str) -> ResultType<PathBuf> {
validate_file_name_no_traversal(name)?;
validate_no_symlink_components(base, name)?;
@@ -553,11 +544,9 @@ impl TransferJob {
file_num: i32,
show_hidden: bool,
is_remote: bool,
files: Vec<FileEntry>,
enable_overwrite_detection: bool,
) -> Self {
log::info!("new write {}", data_source);
let total_size = files.iter().map(|x| x.size).sum();
Self {
id,
r#type,
@@ -566,13 +555,18 @@ impl TransferJob {
file_num,
show_hidden,
is_remote,
files,
total_size,
files: Vec::new(),
total_size: 0,
enable_overwrite_detection,
..Default::default()
}
}
pub fn with_files(mut self, files: Vec<FileEntry>) -> ResultType<Self> {
self.set_files(files)?;
Ok(self)
}
pub fn new_read(
id: i32,
r#type: JobType,
@@ -631,6 +625,7 @@ impl TransferJob {
validate_no_symlink_components(base, &file.name)?;
}
}
self.total_size = files.iter().map(|x| x.size).sum();
self.files = files;
Ok(())
}
@@ -666,6 +661,20 @@ impl TransferJob {
self.file_num
}
fn resolve_entry_path(&self, base: &PathBuf, name: &str) -> Option<PathBuf> {
if self.r#type == JobType::Generic {
match join_validated_path(base, name) {
Ok(path) => Some(path),
Err(err) => {
log::error!("Invalid file name in transfer job {}: {}", self.id, err);
None
}
}
} else {
Some(Self::join(base, name))
}
}
pub fn modify_time(&self) {
if self.r#type == JobType::Printer {
return;
@@ -674,16 +683,8 @@ impl TransferJob {
let file_num = self.file_num as usize;
if file_num < self.files.len() {
let entry = &self.files[file_num];
let path = if self.r#type == JobType::Generic {
match join_validated_path(p, &entry.name) {
Ok(path) => path,
Err(err) => {
log::error!("Invalid file name in transfer job {}: {}", self.id, err);
return;
}
}
} else {
Self::join(p, &entry.name)
let Some(path) = self.resolve_entry_path(p, &entry.name) else {
return;
};
let download_path = format!("{}.download", get_string(&path));
let digest_path = format!("{}.digest", get_string(&path));
@@ -706,16 +707,8 @@ impl TransferJob {
let file_num = self.file_num as usize;
if file_num < self.files.len() {
let entry = &self.files[file_num];
let path = if self.r#type == JobType::Generic {
match join_validated_path(p, &entry.name) {
Ok(path) => path,
Err(err) => {
log::error!("Invalid file name in transfer job {}: {}", self.id, err);
return;
}
}
} else {
Self::join(p, &entry.name)
let Some(path) = self.resolve_entry_path(p, &entry.name) else {
return;
};
let download_path = format!("{}.download", get_string(&path));
let digest_path = format!("{}.digest", get_string(&path));
@@ -1082,16 +1075,8 @@ impl TransferJob {
async fn set_stream_offset(&mut self, file_num: usize, offset: u64) {
if let DataSource::FilePath(p) = &self.data_source {
let entry = &self.files[file_num];
let path = if self.r#type == JobType::Generic {
match join_validated_path(p, &entry.name) {
Ok(path) => path,
Err(err) => {
log::error!("Invalid file name in transfer job {}: {}", self.id, err);
return;
}
}
} else {
Self::join(p, &entry.name)
let Some(path) = self.resolve_entry_path(p, &entry.name) else {
return;
};
let file_path = get_string(&path);
let download_path = format!("{}.download", &file_path);
@@ -1529,13 +1514,12 @@ mod tests {
0,
false,
true,
Vec::new(),
false,
)
}
fn new_write_job(id: i32, download_dir: PathBuf, name: &str) -> TransferJob {
TransferJob::new_write(
fn new_write_job(id: i32, download_dir: PathBuf, name: &str) -> ResultType<TransferJob> {
let job = TransferJob::new_write(
id,
JobType::Generic,
"/fake/remote".to_string(),
@@ -1543,17 +1527,10 @@ mod tests {
0,
false,
true,
vec![new_file_entry(name)],
false,
)
}
fn make_test_block(id: i32, payload: &[u8]) -> FileTransferBlock {
let mut block = FileTransferBlock::new();
block.id = id;
block.file_num = 0;
block.data = payload.to_vec().into();
block
.with_files(vec![new_file_entry(name)])?;
Ok(job)
}
fn assert_err_contains(err: anyhow::Error, expected: &str) {
@@ -1565,17 +1542,13 @@ mod tests {
);
}
#[tokio::test]
async fn path_traversal_e2e_write_rejects_relative_escape() {
#[test]
fn path_traversal_e2e_write_rejects_relative_escape() {
let tmp_root = unique_temp_dir("rustdesk_e2e_relative");
let downloads = tmp_root.join("downloads");
std::fs::create_dir_all(&downloads).expect("create downloads dir");
let mut job = new_write_job(1, downloads, "../traversal_proof.txt");
let block = make_test_block(1, b"malicious payload");
let err = job
.write(block)
.await
let err = new_write_job(1, downloads, "../traversal_proof.txt")
.expect_err("relative path traversal must be rejected");
assert_err_contains(err, "path traversal");
assert!(!tmp_root.join("traversal_proof.txt").exists());
@@ -1583,18 +1556,14 @@ mod tests {
let _ = std::fs::remove_dir_all(&tmp_root);
}
#[tokio::test]
async fn path_traversal_e2e_write_rejects_absolute_path() {
#[test]
fn path_traversal_e2e_write_rejects_absolute_path() {
let tmp_root = unique_temp_dir("rustdesk_e2e_absolute");
let downloads = tmp_root.join("downloads");
let absolute_target = tmp_root.join("fake_ssh").join("authorized_keys");
std::fs::create_dir_all(&downloads).expect("create downloads dir");
let mut job = new_write_job(2, downloads, &absolute_target.to_string_lossy());
let block = make_test_block(2, b"ssh key payload");
let err = job
.write(block)
.await
let err = new_write_job(2, downloads, &absolute_target.to_string_lossy())
.expect_err("absolute path must be rejected");
assert_err_contains(err, "absolute path");
assert!(!absolute_target.exists());
@@ -1602,8 +1571,8 @@ mod tests {
let _ = std::fs::remove_dir_all(&tmp_root);
}
#[tokio::test]
async fn path_traversal_e2e_write_rejects_symlink_escape() {
#[test]
fn path_traversal_e2e_write_rejects_symlink_escape() {
let tmp_root = unique_temp_dir("rustdesk_e2e_symlink");
let downloads = tmp_root.join("downloads");
let outside = tmp_root.join("outside");
@@ -1633,11 +1602,7 @@ mod tests {
}
}
let mut job = new_write_job(3, downloads, "link/escape.txt");
let block = make_test_block(3, b"symlink escape payload");
let err = job
.write(block)
.await
let err = new_write_job(3, downloads, "link/escape.txt")
.expect_err("symlink traversal must be rejected");
assert_err_contains(err, "symlink");
assert!(!escaped_target.exists());
@@ -1744,7 +1709,6 @@ mod tests {
0,
false,
true,
Vec::new(),
false,
);
let err = job