diff --git a/src/database.rs b/src/database.rs index 10fb93e..ff94105 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,19 +1,26 @@ -use std::{collections::HashMap, fs::{self, File}, path::{Path, PathBuf}, sync::{Arc, RwLock}}; +use std::{ + collections::HashMap, + fs::{self, File}, + path::{Path, PathBuf}, + sync::{Arc, RwLock}, +}; use bincode::{config::Configuration, decode_from_std_read, encode_into_std_write, Decode, Encode}; -use chrono::{DateTime, TimeDelta, Utc}; use blake3::Hash; +use chrono::{DateTime, TimeDelta, Utc}; use log::{info, warn}; -use rocket::{serde::{Deserialize, Serialize}, tokio::{select, sync::mpsc::Receiver, time}}; +use rocket::{ + serde::{Deserialize, Serialize}, + tokio::{select, sync::mpsc::Receiver, time}, +}; const BINCODE_CFG: Configuration = bincode::config::standard(); -#[derive(Debug, Clone)] -#[derive(Decode, Encode)] +#[derive(Debug, Clone, Decode, Encode)] pub struct Database { path: PathBuf, #[bincode(with_serde)] - pub files: HashMap + pub files: HashMap, } impl Database { @@ -22,7 +29,7 @@ impl Database { let output = Self { path: path.as_ref().to_path_buf(), - files: HashMap::new() + files: HashMap::new(), }; encode_into_std_write(&output, &mut file, BINCODE_CFG).expect("Could not write database!"); @@ -49,9 +56,7 @@ impl Database { } } -#[derive(Debug, Clone)] -#[derive(Decode, Encode)] -#[derive(Deserialize, Serialize)] +#[derive(Debug, Clone, Decode, Encode, Deserialize, Serialize)] #[serde(crate = "rocket::serde")] pub struct MochiFile { /// The original name of the file @@ -79,7 +84,7 @@ impl MochiFile { name: &str, hash: Hash, filename: PathBuf, - expire_duration: TimeDelta + expire_duration: TimeDelta, ) -> Self { let current = Utc::now(); let expiry = current + expire_duration; @@ -104,7 +109,7 @@ impl MochiFile { pub fn get_key(&self) -> MochiKey { MochiKey { name: self.name.clone(), - hash: self.hash + hash: self.hash, } } @@ -122,9 +127,7 @@ impl MochiFile { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[derive(Decode, Encode)] -#[derive(Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Decode, Encode, Deserialize, Serialize)] #[serde(crate = "rocket::serde")] pub struct MochiKey { name: String, @@ -136,17 +139,21 @@ pub struct MochiKey { /// [`chrono::DateTime`]. Also removes files which no longer exist on the disk. fn clean_database(db: &Arc>) { let mut database = db.write().unwrap(); - let files_to_remove: Vec<_> = database.files.iter().filter_map(|e| { - if e.1.expired() { - // Check if the entry has expired - Some((e.0.clone(), e.1.clone())) - } else if !e.1.path().try_exists().is_ok_and(|r| r) { - // Check if the entry exists - Some((e.0.clone(), e.1.clone())) - } else { - None - } - }).collect(); + let files_to_remove: Vec<_> = database + .files + .iter() + .filter_map(|e| { + if e.1.expired() { + // Check if the entry has expired + Some((e.0.clone(), e.1.clone())) + } else if !e.1.path().try_exists().is_ok_and(|r| r) { + // Check if the entry exists + Some((e.0.clone(), e.1.clone())) + } else { + None + } + }) + .collect(); let mut expired = 0; let mut missing = 0; @@ -166,7 +173,10 @@ fn clean_database(db: &Arc>) { database.files.remove(&file.0); } - info!("{} expired and {} missing items cleared from database", expired, missing); + info!( + "{} expired and {} missing items cleared from database", + expired, missing + ); database.save(); } diff --git a/src/endpoints.rs b/src/endpoints.rs index 6effbfc..54293e0 100644 --- a/src/endpoints.rs +++ b/src/endpoints.rs @@ -1,6 +1,12 @@ use std::sync::{Arc, RwLock}; -use rocket::{get, http::RawStr, response::{status::NotFound, Redirect}, serde::{self, json::Json}, State}; +use rocket::{ + get, + http::RawStr, + response::{status::NotFound, Redirect}, + serde::{self, json::Json}, + State, +}; use serde::Serialize; use crate::{database::Database, get_id, settings::Settings}; @@ -12,7 +18,13 @@ pub fn server_info(settings: &State) -> Json { max_filesize: settings.max_filesize, max_duration: settings.duration.maximum.num_seconds() as u32, default_duration: settings.duration.default.num_seconds() as u32, - allowed_durations: settings.duration.allowed.clone().into_iter().map(|t| t.num_seconds() as u32).collect(), + allowed_durations: settings + .duration + .allowed + .clone() + .into_iter() + .map(|t| t.num_seconds() as u32) + .collect(), }) } @@ -29,18 +41,12 @@ pub struct ServerInfo { /// Look up the hash of a file to find it. This only returns the first /// hit for a hash, so different filenames may not be found. #[get("/f/")] -pub fn lookup( - db: &State>>, - id: &str -) -> Result> { +pub fn lookup(db: &State>>, id: &str) -> Result> { for file in db.read().unwrap().files.values() { if file.hash().to_hex()[0..10].to_string() == id { - let filename = get_id( - file.name(), - *file.hash() - ); + let filename = get_id(file.name(), *file.hash()); let filename = RawStr::new(&filename).percent_encode().to_string(); - return Ok(Redirect::to(format!("/files/{}", filename))) + return Ok(Redirect::to(format!("/files/{}", filename))); } } diff --git a/src/main.rs b/src/main.rs index c4adff4..a60aeba 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,23 +1,35 @@ mod database; -mod strings; -mod settings; mod endpoints; +mod settings; +mod strings; mod utils; -use std::{fs, sync::{Arc, RwLock}}; +use std::{ + fs, + sync::{Arc, RwLock}, +}; use chrono::{DateTime, TimeDelta, Utc}; use database::{clean_loop, Database, MochiFile}; use endpoints::{lookup, server_info}; use log::info; +use maud::{html, Markup, PreEscaped, DOCTYPE}; use rocket::{ - data::{Limits, ToByteUnit}, form::Form, fs::{FileServer, Options, TempFile}, get, http::ContentType, post, response::content::{RawCss, RawJavaScript}, routes, serde::{json::Json, Serialize}, tokio, Config, FromForm, State + data::{Limits, ToByteUnit}, + form::Form, + fs::{FileServer, Options, TempFile}, + get, + http::ContentType, + post, + response::content::{RawCss, RawJavaScript}, + routes, + serde::{json::Json, Serialize}, + tokio, Config, FromForm, State, }; use settings::Settings; use strings::{parse_time_string, to_pretty_time}; use utils::{get_id, hash_file}; use uuid::Uuid; -use maud::{html, Markup, DOCTYPE, PreEscaped}; fn head(page_title: &str) -> Markup { html! { @@ -119,20 +131,23 @@ async fn handle_upload( let expire_time = if let Ok(t) = parse_time_string(&file_data.expire_time) { if t > settings.duration.maximum { - return Ok(Json(ClientResponse::failure("Duration larger than maximum"))) + return Ok(Json(ClientResponse::failure( + "Duration larger than maximum", + ))); } if settings.duration.restrict_to_allowed && !settings.duration.allowed.contains(&t) { - return Ok(Json(ClientResponse::failure("Duration not allowed"))) + return Ok(Json(ClientResponse::failure("Duration not allowed"))); } t } else { - return Ok(Json(ClientResponse::failure("Duration invalid"))) + return Ok(Json(ClientResponse::failure("Duration invalid"))); }; // TODO: Properly sanitize this... - let raw_name = &*file_data.file + let raw_name = &*file_data + .file .raw_name() .unwrap() .dangerous_unsafe_unsanitized_raw() @@ -145,21 +160,18 @@ async fn handle_upload( file_data.file.persist_to(&temp_filename).await?; let hash = hash_file(&temp_filename).await?; - let filename = get_id( - raw_name, - hash - ); + let filename = get_id(raw_name, hash); out_path.push(filename.clone()); - let constructed_file = MochiFile::new_with_expiry( - raw_name, - hash, - out_path.clone(), - expire_time - ); + let constructed_file = + MochiFile::new_with_expiry(raw_name, hash, out_path.clone(), expire_time); if !settings.overwrite - && db.read().unwrap().files.contains_key(&constructed_file.get_key()) + && db + .read() + .unwrap() + .files + .contains_key(&constructed_file.get_key()) { info!("Key already in DB, NOT ADDING"); return Ok(Json(ClientResponse { @@ -169,14 +181,16 @@ async fn handle_upload( url: filename, hash: hash.to_hex()[0..10].to_string(), expires: Some(constructed_file.get_expiry()), - ..Default::default() - })) + })); } // Move it to the new proper place std::fs::rename(temp_filename, out_path)?; - db.write().unwrap().files.insert(constructed_file.get_key(), constructed_file.clone()); + db.write() + .unwrap() + .files + .insert(constructed_file.get_key(), constructed_file.clone()); db.write().unwrap().save(); Ok(Json(ClientResponse { @@ -221,8 +235,7 @@ impl ClientResponse { #[rocket::main] async fn main() { // Get or create config file - let config = Settings::open(&"./settings.toml") - .expect("Could not open settings file"); + let config = Settings::open(&"./settings.toml").expect("Could not open settings file"); if !config.temp_dir.try_exists().is_ok_and(|e| e) { fs::create_dir_all(config.temp_dir.clone()).expect("Failed to create temp directory"); @@ -256,11 +269,22 @@ async fn main() { let rocket = rocket::build() .mount( config.server.root_path.clone() + "/", - routes![home, handle_upload, form_handler_js, stylesheet, server_info, favicon, lookup] + routes![ + home, + handle_upload, + form_handler_js, + stylesheet, + server_info, + favicon, + lookup + ], ) .mount( config.server.root_path.clone() + "/files", - FileServer::new(config.file_dir.clone(), Options::Missing | Options::NormalizeDirs) + FileServer::new( + config.file_dir.clone(), + Options::Missing | Options::NormalizeDirs, + ), ) .manage(database) .manage(config) @@ -272,7 +296,10 @@ async fn main() { rocket.expect("Server failed to shutdown gracefully"); info!("Stopping database cleaning thread"); - shutdown.send(()).await.expect("Failed to stop cleaner thread"); + shutdown + .send(()) + .await + .expect("Failed to stop cleaner thread"); info!("Saving database on shutdown..."); local_db.write().unwrap().save(); diff --git a/src/settings.rs b/src/settings.rs index f918eeb..33eda2b 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,9 +1,13 @@ -use std::{fs::{self, File}, io::{self, Read, Write}, path::{Path, PathBuf}}; +use std::{ + fs::{self, File}, + io::{self, Read, Write}, + path::{Path, PathBuf}, +}; use chrono::TimeDelta; -use serde_with::serde_as; -use rocket::serde::{Deserialize, Serialize}; use rocket::data::ToByteUnit; +use rocket::serde::{Deserialize, Serialize}; +use serde_with::serde_as; /// A response to the client from the server #[derive(Deserialize, Serialize, Debug)] @@ -44,7 +48,7 @@ pub struct Settings { impl Default for Settings { fn default() -> Self { Self { - max_filesize: 1.megabytes().into(), // 128 MB + max_filesize: 1.megabytes().into(), // 128 MB overwrite: true, duration: DurationSettings::default(), server: ServerSettings::default(), @@ -103,7 +107,7 @@ impl Default for ServerSettings { Self { address: "127.0.0.1".into(), root_path: "/".into(), - port: 8950 + port: 8950, } } } @@ -134,8 +138,8 @@ pub struct DurationSettings { impl Default for DurationSettings { fn default() -> Self { Self { - maximum: TimeDelta::days(3), // 72 hours - default: TimeDelta::hours(6), // 6 hours + maximum: TimeDelta::days(3), // 72 hours + default: TimeDelta::hours(6), // 6 hours // 1 hour, 6 hours, 24 hours, and 48 hours allowed: vec![ TimeDelta::hours(1), diff --git a/src/strings.rs b/src/strings.rs index e0033fd..98f8a6d 100644 --- a/src/strings.rs +++ b/src/strings.rs @@ -4,13 +4,13 @@ use chrono::TimeDelta; pub fn parse_time_string(string: &str) -> Result> { if string.len() > 7 { - return Err("Not valid time string".into()) + return Err("Not valid time string".into()); } let unit = string.chars().last(); let multiplier = if let Some(u) = unit { if !u.is_ascii_alphabetic() { - return Err("Not valid time string".into()) + return Err("Not valid time string".into()); } match u { @@ -21,13 +21,13 @@ pub fn parse_time_string(string: &str) -> Result> { _ => return Err("Not valid time string".into()), } } else { - return Err("Not valid time string".into()) + return Err("Not valid time string".into()); }; let time = if let Ok(n) = string[..string.len() - 1].parse::() { n } else { - return Err("Not valid time string".into()) + return Err("Not valid time string".into()); }; let final_time = multiplier * time; @@ -41,10 +41,39 @@ pub fn to_pretty_time(seconds: u32) -> String { let mins = ((seconds as f32 - (hour * 3600.0) - (days * 86400.0)) / 60.0).floor(); let secs = seconds as f32 - (hour * 3600.0) - (mins * 60.0) - (days * 86400.0); - let days = if days == 0.0 {"".to_string()} else if days == 1.0 {days.to_string() + "
day"} else {days.to_string() + "
days"}; - let hour = if hour == 0.0 {"".to_string()} else if hour == 1.0 {hour.to_string() + "
hour"} else {hour.to_string() + "
hours"}; - let mins = if mins == 0.0 {"".to_string()} else if mins == 1.0 {mins.to_string() + "
minute"} else {mins.to_string() + "
minutes"}; - let secs = if secs == 0.0 {"".to_string()} else if secs == 1.0 {secs.to_string() + "
second"} else {secs.to_string() + "
seconds"}; + let days = if days == 0.0 { + "".to_string() + } else if days == 1.0 { + days.to_string() + "
day" + } else { + days.to_string() + "
days" + }; - (days + " " + &hour + " " + &mins + " " + &secs).trim().to_string() + let hour = if hour == 0.0 { + "".to_string() + } else if hour == 1.0 { + hour.to_string() + "
hour" + } else { + hour.to_string() + "
hours" + }; + + let mins = if mins == 0.0 { + "".to_string() + } else if mins == 1.0 { + mins.to_string() + "
minute" + } else { + mins.to_string() + "
minutes" + }; + + let secs = if secs == 0.0 { + "".to_string() + } else if secs == 1.0 { + secs.to_string() + "
second" + } else { + secs.to_string() + "
seconds" + }; + + (days + " " + &hour + " " + &mins + " " + &secs) + .trim() + .to_string() } diff --git a/src/utils.rs b/src/utils.rs index 7a0faae..4d2cf65 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,5 @@ -use std::path::Path; use blake3::Hash; +use std::path::Path; /// Get a filename based on the file's hashed name pub fn get_id(name: &str, hash: Hash) -> String {