diff --git a/.gitignore b/.gitignore index 355dd8d..328d2b5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ /target /tmp /src/files +temp_files *.mochi +settings.toml diff --git a/Cargo.lock b/Cargo.lock index 5cc2997..c555fbd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -740,6 +740,7 @@ dependencies = [ "log", "maud", "rocket", + "toml", "uuid", ] diff --git a/Cargo.toml b/Cargo.toml index 132eaa6..a83d903 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ chrono = { version = "0.4.38", features = ["serde"] } log = "0.4" maud = { version = "0.26", features = ["rocket"] } rocket = { version = "0.5", features = ["json"] } +toml = "0.8.19" uuid = { version = "1.11.0", features = ["v4"] } [profile.production] diff --git a/src/database.rs b/src/database.rs index 6db93a1..c072ad7 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,9 +1,10 @@ -use std::{collections::HashMap, fs::{self, File}, path::{Path, PathBuf}}; +use std::{collections::HashMap, fs::{self, File}, path::{Path, PathBuf}, sync::{Arc, RwLock}, time::Duration}; use bincode::{config::Configuration, decode_from_std_read, encode_into_std_write, Decode, Encode}; use chrono::{DateTime, TimeDelta, Utc}; use blake3::Hash; -use rocket::serde::{Deserialize, Serialize}; +use log::{info, warn}; +use rocket::{serde::{Deserialize, Serialize}, tokio::{select, sync::mpsc::Receiver, time}}; const BINCODE_CFG: Configuration = bincode::config::standard(); @@ -131,3 +132,57 @@ pub struct MochiKey { #[bincode(with_serde)] hash: Hash, } + +/// Clean the database. Removes files which are past their expiry +/// [`chrono::DateTime`]. Also removes files which no longer exist on the disk. 
+fn clean_database(db: &Arc>) { + let mut database = db.write().unwrap(); + let files_to_remove: Vec<_> = database.files.iter().filter_map(|e| { + if e.1.expired() { + // Check if the entry has expired + Some((e.0.clone(), e.1.clone())) + } else if !e.1.path().try_exists().is_ok_and(|r| r) { + // Check if the entry exists + Some((e.0.clone(), e.1.clone())) + } else { + None + } + }).collect(); + + let mut expired = 0; + let mut missing = 0; + for file in &files_to_remove { + let path = file.1.path(); + // If the path does not exist, there's no reason to try to remove it. + if path.try_exists().is_ok_and(|r| r) { + match fs::remove_file(path) { + Ok(_) => (), + Err(e) => warn!("Failed to delete path at {:?}: {e}", path), + } + expired += 1; + } else { + missing += 1 + } + + database.files.remove(&file.0); + } + + info!("{} expired and {} missing items cleared from database", expired, missing); + database.save(); +} + +/// A loop to clean the database periodically. +pub async fn clean_loop( + db: Arc>, + mut shutdown_signal: Receiver<()>, + interval: Duration, +) { + let mut interval = time::interval(interval); + + loop { + select! 
{ + _ = interval.tick() => clean_database(&db), + _ = shutdown_signal.recv() => break, + }; + } +} diff --git a/src/main.rs b/src/main.rs index 556eb4e..4596d0a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,15 +1,19 @@ mod database; +mod time_string; +mod settings; -use std::{fs, path::{Path, PathBuf}, sync::{Arc, RwLock}, time::Duration}; +use std::{path::{Path, PathBuf}, sync::{Arc, RwLock}, time::Duration}; use blake3::Hash; use chrono::{DateTime, TimeDelta, Utc}; -use database::{Database, MochiFile}; -use log::{debug, info, warn}; -use maud::{html, Markup, DOCTYPE, PreEscaped}; +use database::{clean_loop, Database, MochiFile}; +use log::info; use rocket::{ - form::Form, fs::{FileServer, Options, TempFile}, get, post, routes, serde::{json::Json, Serialize}, tokio::{self, fs::File, io::AsyncReadExt, select, spawn, time}, FromForm, State + data::{Limits, ToByteUnit}, form::Form, fs::{FileServer, Options, TempFile}, get, post, response::content::{RawCss, RawJavaScript}, routes, serde::{json::Json, Serialize}, tokio::{self, fs::File, io::AsyncReadExt}, Config, FromForm, State }; +use settings::Settings; +use time_string::parse_time_string; use uuid::Uuid; +use maud::{html, Markup, DOCTYPE}; fn head(page_title: &str) -> Markup { html! { @@ -18,12 +22,22 @@ fn head(page_title: &str) -> Markup { meta name="viewport" content="width=device-width, initial-scale=1"; title { (page_title) } // Javascript stuff for client side handling - script { (PreEscaped(include_str!("static/request.js"))) } - // CSS for styling the sheets - style { (PreEscaped(include_str!("static/main.css"))) } + script src="request.js" { } } } +/// Stylesheet +#[get("/main.css")] +fn stylesheet() -> RawCss<&'static str> { + RawCss(include_str!("static/main.css")) +} + +/// Upload handler javascript +#[get("/request.js")] +fn form_handler_js() -> RawJavaScript<&'static str> { + RawJavaScript(include_str!("static/request.js")) +} + #[get("/")] fn home() -> Markup { html! 
{ @@ -36,6 +50,8 @@ fn home() -> Markup { "Upload File" } input id="fileInput" type="file" name="fileUpload" onchange="formSubmit(this.parentNode)" style="display:none;"; + br; + input type="text" name="duration" minlength="2" maxlength="4"; } div class="progress-box" { progress id="uploadProgress" value="0" max="100" {} @@ -55,6 +71,9 @@ fn home() -> Markup { struct Upload<'r> { #[field(name = "fileUpload")] file: TempFile<'r>, + + #[field(name = "duration")] + expire_time: String, } /// Handle a file upload and store it @@ -62,27 +81,42 @@ struct Upload<'r> { async fn handle_upload( mut file_data: Form>, db: &State>> -) -> Result, std::io::Error> { +) -> Result, std::io::Error> { let mut out_path = PathBuf::from("files/"); + let expire_time = if let Ok(t) = parse_time_string(&file_data.expire_time) { + if t < TimeDelta::days(365) { + t + } else { + TimeDelta::days(365) + } + } else { + return Ok(Json(ClientResponse { + status: false, + response: "Invalid duration", + ..Default::default() + })) + }; + // Get temp path and hash it let temp_filename = "temp_files/".to_owned() + &Uuid::new_v4().to_string(); file_data.file.persist_to(&temp_filename).await?; let hash = hash_file(&temp_filename).await?; - let filename = get_filename( - // TODO: Properly sanitize this... - file_data.file.raw_name().unwrap().dangerous_unsafe_unsanitized_raw().as_str(), + // TODO: Properly sanitize this... 
+ let raw_name = file_data.file.raw_name().unwrap().dangerous_unsafe_unsanitized_raw().as_str(); + let filename = get_id( + raw_name, hash.0 ); out_path.push(filename.clone()); let constructed_file = MochiFile::new_with_expiry( - file_data.file.raw_name().unwrap().dangerous_unsafe_unsanitized_raw().as_str(), + raw_name, hash.1, hash.0, out_path.clone(), - TimeDelta::hours(24) + expire_time ); // Move it to the new proper place @@ -91,23 +125,35 @@ async fn handle_upload( db.write().unwrap().files.insert(constructed_file.get_key(), constructed_file.clone()); db.write().unwrap().save(); - let location = FileLocation { - name: constructed_file.name().clone(), + Ok(Json(ClientResponse { status: true, + name: Some(constructed_file.name().clone()), url: Some("files/".to_string() + &filename), - expires: constructed_file.get_expiry(), - }; - - Ok(Json(location)) + expires: Some(constructed_file.get_expiry()), + ..Default::default() + })) } -#[derive(Serialize)] +/// A response to the client from the server +#[derive(Serialize, Default, Debug)] #[serde(crate = "rocket::serde")] -struct FileLocation { - pub name: String, +struct ClientResponse { + /// Success or failure pub status: bool, + + pub response: &'static str, + + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub url: Option, - pub expires: DateTime, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires: Option>, +} + +/// Get a filename based on the file's hashed name +fn get_id(name: &str, hash: Hash) -> String { + hash.to_hex()[0..10].to_string() + "_" + name } /// Get the Blake3 hash of a file, without reading it all into memory, and also get the size @@ -127,63 +173,65 @@ async fn hash_file>(input: &P) -> Result<(Hash, usize), std::io:: Ok((hasher.finalize(), total)) } -/// Get a random filename for use as the uploaded file's name -fn get_filename(name: &str, hash: Hash) -> String { - hash.to_hex()[0..10].to_string() + "_" 
+ name +/// An endpoint to obtain information about the server's capabilities +#[get("/info")] +fn server_info(settings: &State) -> Json { + Json(ServerInfo { + max_filesize: settings.max_filesize, + max_duration: settings.max_duration, + }) } -fn clean_database(db: &Arc>) { - info!("Cleaning database"); - let mut expired_list = Vec::new(); - - let mut database = db.write().unwrap(); - for file in &database.files { - if file.1.expired() { - expired_list.push((file.0.clone(), file.1.clone())); - } - } - - for file in expired_list { - let path = file.1.path(); - if path.exists() { - match fs::remove_file(path) { - Ok(_) => (), - Err(_) => warn!("Failed to delete path at {:?}", path), - } - } - debug!("Deleted file: {}", file.1.name()); - database.files.remove(&file.0); - } - - database.save(); +#[derive(Serialize, Debug)] +#[serde(crate = "rocket::serde")] +struct ServerInfo { + max_filesize: u64, + max_duration: u32, } #[rocket::main] async fn main() { + // Get or create config file + let config = Settings::open(&"./settings.toml") + .expect("Could not open settings file"); + + // Set rocket configuration settings + let rocket_config = Config { + address: config.server.address.parse().expect("IP address invalid"), + port: config.server.port, + temp_dir: config.temp_dir.clone().into(), + limits: Limits::default() + .limit("data-form", config.max_filesize.bytes()) + .limit("file", config.max_filesize.bytes()), + ..Default::default() + }; + let database = Arc::new(RwLock::new(Database::open(&"database.mochi"))); let local_db = database.clone(); // Start monitoring thread - let (shutdown, mut rx) = tokio::sync::mpsc::channel(1); - let cleaner_db = database.clone(); - spawn(async move { - let mut interval = time::interval(Duration::from_secs(120)); - - loop { - select! 
{ _ = interval.tick() => clean_database(&cleaner_db), _ = rx.recv() => break, }; } + let (shutdown, rx) = tokio::sync::mpsc::channel(1); + tokio::spawn({ + let cleaner_db = database.clone(); + async move { clean_loop(cleaner_db, rx, Duration::from_secs(120)).await } }); let rocket = rocket::build() + .mount( + config.root_path.clone() + "/", + routes![home, handle_upload, form_handler_js, stylesheet, server_info] + ) + .mount( + config.root_path.clone() + "/files", + FileServer::new("files/", Options::Missing | Options::NormalizeDirs) + ) .manage(database) - .mount("/", routes![home, handle_upload]) - .mount("/files", FileServer::new("files/", Options::Missing | Options::NormalizeDirs)) + .manage(config) + .configure(rocket_config) .launch() .await; + // Ensure the server gracefully shuts down rocket.expect("Server failed to shutdown gracefully"); info!("Stopping database cleaning thread"); diff --git a/src/settings.rs b/src/settings.rs new file mode 100644 index 0000000..5105a93 --- /dev/null +++ b/src/settings.rs @@ -0,0 +1,90 @@ +use std::{fs::{self, File}, io::{self, Read, Write}, path::{Path, PathBuf}}; + +use rocket::serde::{Deserialize, Serialize}; + +/// Program settings, read from (and written back to) a TOML config file +#[derive(Deserialize, Serialize, Debug)] +#[serde(crate = "rocket::serde")] +pub struct Settings { + /// Maximum filesize in bytes + pub max_filesize: u64, + + /// Maximum file lifetime, seconds + pub max_duration: u32, + + /// The path to the root directory of the program, ex `/filehost/` + pub root_path: String, + + /// The path to the database file + pub database_path: PathBuf, + + /// Temporary directory for in-progress file uploads (used as Rocket's `temp_dir`) + pub temp_dir: PathBuf, + + pub server: ServerSettings, + + #[serde(skip)] + path: PathBuf, +} + +impl Default for Settings { + fn default() -> Self { + Self { + max_filesize: 128_000_000, // 128MB + max_duration: 86_400, // 1 day + root_path: "/".into(), + server: ServerSettings::default(), + path: "./settings.toml".into(), + database_path:
"./database.mochi".into(), + temp_dir: std::env::temp_dir() + } + } +} + +impl Settings { + pub fn open>(path: &P) -> Result { + let mut input_str = String::new(); + if !path.as_ref().exists() { + let new_self = Self { + path: path.as_ref().to_path_buf(), + ..Default::default() + }; + new_self.save()?; + return Ok(new_self); + } else { + File::open(path).unwrap().read_to_string(&mut input_str)?; + } + + let mut parsed_settings: Self = toml::from_str(&input_str).unwrap(); + parsed_settings.path = path.as_ref().to_path_buf(); + + Ok(parsed_settings) + } + + pub fn save(&self) -> Result<(), io::Error> { + let mut out_path = self.path.clone(); + out_path.set_extension(".bkp"); + let mut file = File::create(&out_path).expect("Could not save!"); + file.write_all(&toml::to_string_pretty(self).unwrap().into_bytes())?; + + fs::rename(out_path, &self.path).unwrap(); + + Ok(()) + } +} + +#[derive(Deserialize, Serialize, Debug)] +#[serde(crate = "rocket::serde")] +pub struct ServerSettings { + pub address: String, + pub port: u16, +} + +impl Default for ServerSettings { + fn default() -> Self { + Self { + address: "127.0.0.1".into(), + port: 8955 + } + } +} diff --git a/src/static/main.css b/src/static/main.css index 9d97277..858278a 100644 --- a/src/static/main.css +++ b/src/static/main.css @@ -1,5 +1,4 @@ @import url('https://g2games.dev/assets/fonts/fonts.css'); -@import url('https://g2games.dev/assets/main-style.css'); .main-wrapper { margin: auto; diff --git a/src/static/request.js b/src/static/request.js index 03ba5f1..dc96421 100644 --- a/src/static/request.js +++ b/src/static/request.js @@ -1,11 +1,30 @@ -let progressBar = null; -let progressValue = null; -let statusNotifier = null; +let progressBar; +let progressValue; +let statusNotifier; +let uploadedFilesDisplay; let uploadInProgress = false; -let uploadedFilesDisplay = null; + +const TOO_LARGE_TEXT = "File is too large!"; +const ERROR_TEXT = "An error occured!"; + +let MAX_FILESIZE; +let MAX_DURATION; async 
function formSubmit(form) { + if (uploadInProgress) { + return; + } + + // Get file size and don't upload if it's too large + let file_upload = document.getElementById("fileInput"); + let file = file_upload.files[0]; + if (file.size > MAX_FILESIZE) { + progressValue.textContent = TOO_LARGE_TEXT; + console.error("Provided file is too large", file.size, "bytes; max", MAX_FILESIZE, "bytes"); + return; + } + let url = "/upload"; let request = new XMLHttpRequest(); request.open('POST', url, true); @@ -19,7 +38,7 @@ async function formSubmit(form) { try { request.send(new FormData(form)); } catch (e) { - console.log(e); + console.error("An error occured while uploading", e); } // Reset the form data since we've successfully submitted it @@ -28,24 +47,27 @@ async function formSubmit(form) { function networkErrorHandler(_err) { uploadInProgress = false; - console.log("An error occured while uploading"); + console.error("A network error occured while uploading"); progressValue.textContent = "A network error occured!"; } function uploadComplete(response) { let target = response.target; - console.log(target); if (target.status === 200) { const response = JSON.parse(target.responseText); - console.log(response); if (response.status) { progressValue.textContent = "Success"; addToList(response.name, response.url); + } else { + console.error("Error uploading", response) + progressValue.textContent = response.response; } } else if (target.status === 413) { - progressValue.textContent = "File too large!"; + progressValue.textContent = TOO_LARGE_TEXT; + } else { + progressValue.textContent = ERROR_TEXT; } uploadInProgress = false; @@ -67,18 +89,18 @@ function uploadProgress(progress) { } } -function attachFormSubmitEvent(formId) { - if (uploadInProgress) { - return; - } - - document.getElementById(formId).addEventListener("submit", formSubmit); +async function getServerCapabilities() { + let capabilities = await fetch("info").then((response) => response.json()); + MAX_FILESIZE = 
capabilities.max_filesize; + MAX_DURATION = capabilities.max_duration; } document.addEventListener("DOMContentLoaded", function(_event){ - attachFormSubmitEvent("uploadForm"); + document.getElementById("uploadForm").addEventListener("submit", formSubmit); progressBar = document.getElementById("uploadProgress"); progressValue = document.getElementById("uploadProgressValue"); statusNotifier = document.getElementById("uploadStatus"); uploadedFilesDisplay = document.getElementById("uploadedFilesDisplay"); -}) +}); + +getServerCapabilities(); diff --git a/src/time_string.rs b/src/time_string.rs new file mode 100644 index 0000000..9549f46 --- /dev/null +++ b/src/time_string.rs @@ -0,0 +1,36 @@ +use std::error::Error; + +use chrono::TimeDelta; + +pub fn parse_time_string(string: &str) -> Result> { + if string.len() > 5 { + return Err("Not valid time string".into()) + } + + let unit = string.chars().last(); + let multiplier = if let Some(u) = unit { + if !u.is_ascii_alphabetic() { + return Err("Not valid time string".into()) + } + + match u { + 'D' | 'd' => TimeDelta::days(1), + 'H' | 'h' => TimeDelta::hours(1), + 'M' | 'm' => TimeDelta::minutes(1), + 'S' | 's' => TimeDelta::seconds(1), + _ => return Err("Not valid time string".into()), + } + } else { + return Err("Not valid time string".into()) + }; + + let time = if let Ok(n) = string[..string.len() - 1].parse::() { + n + } else { + return Err("Not valid time string".into()) + }; + + let final_time = multiplier * time; + + Ok(final_time) +}