diff --git a/Cargo.lock b/Cargo.lock index bcf02d6..ddc10fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -236,6 +236,7 @@ dependencies = [ "bincode", "blake3", "chrono", + "file-format", "log", "maud", "rand", @@ -380,9 +381,9 @@ checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "encoding_rs" -version = "0.8.34" +version = "0.8.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" dependencies = [ "cfg-if", ] @@ -423,6 +424,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "file-format" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ffe3a660c3a1b10e96f304a9413d673b2118d62e4520f7ddf4a4faccfe8b9b9" + [[package]] name = "fnv" version = "1.0.7" @@ -983,9 +990,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "pin-utils" @@ -1133,9 +1140,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", diff --git a/Cargo.toml b/Cargo.toml index 4002a73..dfdcb2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" bincode = { version = "2.0.0-rc.3", features = ["serde"] } blake3 = { version = "1.5.4", features = ["mmap", "rayon", "serde"] } chrono = { version = "0.4.38", features = ["serde"] } +file-format = { version = "0.25.0", features = ["reader"] } log = "0.4" maud = { version = "0.26", features = ["rocket"] } rand = "0.8.5" diff --git a/src/database.rs b/src/database.rs index 1408e8a..84f13dd 100644 --- a/src/database.rs +++ b/src/database.rs @@ -5,6 +5,7 @@ use std::{ use bincode::{config::Configuration, decode_from_std_read, encode_into_std_write, Decode, Encode}; use blake3::Hash; use chrono::{DateTime, TimeDelta, Utc}; +use file_format::FileFormat; use log::{info, warn}; use rand::distributions::{Alphanumeric, DistString}; use rocket::{ @@ -146,6 +147,9 @@ pub struct MochiFile { /// The original name of the file name: String, + /// The format the file is, for serving + extension: String, + /// The Blake3 hash of the file #[bincode(with_serde)] hash: Hash, @@ -164,6 +168,7 @@ impl MochiFile { pub fn new_with_expiry( mmid: Mmid, name: String, + extension: &str, hash: Hash, expire_duration: TimeDelta, ) -> Self { @@ -173,6 +178,7 @@ impl MochiFile { Self { mmid, name, + extension: extension.to_string(), hash, upload_datetime: current, expiry_datetime: expiry, @@ -199,6 +205,10 @@ impl MochiFile { pub fn mmid(&self) -> &Mmid { &self.mmid } + + pub fn extension(&self) -> &String { + &self.extension + } } /// Clean the database. Removes files which are past their expiry @@ -283,8 +293,18 @@ impl Mmid { } } -impl From<&str> for Mmid { - fn from(value: &str) -> Self { - Self(value.to_owned()) +impl TryFrom<&str> for Mmid { + type Error = (); + + fn try_from(value: &str) -> Result { + if value.len() != 8 { + return Err(()) + } + + if value.chars().any(|c| !c.is_ascii_alphanumeric()) { + return Err(()) + } + + Ok(Self(value.to_owned())) } } diff --git a/src/endpoints.rs b/src/endpoints.rs index 572bb44..0ae142f 100644 --- a/src/endpoints.rs +++ b/src/endpoints.rs @@ -1,11 +1,11 @@ use std::sync::{Arc, RwLock}; use rocket::{ - fs::NamedFile, get, serde::{self, json::Json}, State + fs::NamedFile, get, http::ContentType, serde::{self, json::Json}, tokio::fs::File, State }; use serde::Serialize; -use crate::{database::Database, settings::Settings}; +use crate::{database::{Database, Mmid}, settings::Settings}; /// An endpoint to obtain information about the server's capabilities #[get("/info")] @@ -34,20 +34,28 @@ pub struct ServerInfo { allowed_durations: Vec, } -/// Look up the hash of a file to find it. This only returns the first -/// hit for a hash, so different filenames may not be found. -#[get("/f/")] +/// Look up the [`Mmid`] of a file to find it. +#[get("/f/")] pub async fn lookup( db: &State>>, settings: &State, - id: &str -) -> Option { - dbg!(db.read().unwrap()); - let entry = if let Some(e) = db.read().unwrap().get(&id.into()).cloned() { + mmid: &str +) -> Option<(ContentType, NamedFile)> { + let mmid: Mmid = match mmid.try_into() { + Ok(v) => v, + Err(_) => return None, + }; + + let entry = if let Some(e) = db.read().unwrap().get(&mmid).cloned() { e } else { return None }; - NamedFile::open(settings.file_dir.join(entry.hash().to_string())).await.ok() + let file = NamedFile::open(settings.file_dir.join(entry.hash().to_string())).await.ok()?; + + Some(( + ContentType::from_extension(entry.extension()).unwrap_or(ContentType::Binary), + file + )) } diff --git a/src/main.rs b/src/main.rs index 78b7322..8165d2e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,6 @@ mod endpoints; mod settings; mod strings; mod utils; -mod file_server; use std::{ fs, @@ -127,9 +126,7 @@ async fn handle_upload( db: &State>>, settings: &State, ) -> Result, std::io::Error> { - let mut temp_dir = settings.temp_dir.clone(); - let mut out_path = settings.file_dir.clone(); - + // Ensure the expiry time is valid, if not return an error let expire_time = if let Ok(t) = parse_time_string(&file_data.expire_time) { if settings.duration.restrict_to_allowed && !settings.duration.allowed.contains(&t) { return Ok(Json(ClientResponse::failure("Duration not allowed"))); @@ -144,7 +141,6 @@ async fn handle_upload( return Ok(Json(ClientResponse::failure("Duration invalid"))); }; - // TODO: Properly sanitize this... let raw_name = file_data .file .raw_name() @@ -153,20 +149,31 @@ async fn handle_upload( .as_str() .to_string(); - // Get temp path and hash it - temp_dir.push(Uuid::new_v4().to_string()); - let temp_filename = temp_dir; + // Get temp path for the file + let temp_filename = settings.temp_dir.join(Uuid::new_v4().to_string()); file_data.file.persist_to(&temp_filename).await?; + + // Get hash and random identifier + let file_mmid = Mmid::new(); let file_hash = hash_file(&temp_filename).await?; - let file_mmid = Mmid::new(); - out_path.push(file_hash.to_string()); + // Process filetype + let file_type = file_format::FileFormat::from_file(&temp_filename)?; let constructed_file = - MochiFile::new_with_expiry(file_mmid.clone(), raw_name, file_hash, expire_time); + MochiFile::new_with_expiry( + file_mmid.clone(), + raw_name, + file_type.extension(), + file_hash, + expire_time + ); // Move it to the new proper place - std::fs::rename(temp_filename, out_path)?; + std::fs::rename( + temp_filename, + settings.file_dir.join(file_hash.to_string()) + )?; db.write() .unwrap() @@ -177,7 +184,7 @@ async fn handle_upload( status: true, name: constructed_file.name().clone(), mmid: Some(file_mmid), - hash: file_hash.to_hex()[0..10].to_string(), + hash: file_hash.to_string(), expires: Some(constructed_file.expiry()), ..Default::default() })) @@ -244,7 +251,7 @@ async fn main() { tokio::spawn({ let cleaner_db = database.clone(); let file_path = config.file_dir.clone(); - async move { clean_loop(cleaner_db, file_path, rx, TimeDelta::seconds(10)).await } + async move { clean_loop(cleaner_db, file_path, rx, TimeDelta::minutes(2)).await } }); let rocket = rocket::build() diff --git a/web/request.js b/web/request.js index c931545..8618f13 100644 --- a/web/request.js +++ b/web/request.js @@ -42,13 +42,11 @@ function getDroppedFiles(evt) { }); } - console.log(files); return files; } async function fileSend(files, duration, maxSize) { for (const file of files) { - console.log(file); const [linkRow, progressBar, progressText] = addNewToList(file.name); if (file.size > maxSize) { makeErrored(progressBar, progressText, linkRow, TOO_LARGE_TEXT);