diff --git a/Cargo.lock b/Cargo.lock index 8320264..bcf02d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -238,6 +238,7 @@ dependencies = [ "chrono", "log", "maud", + "rand", "rocket", "serde", "serde_with", diff --git a/Cargo.toml b/Cargo.toml index c6c9828..4002a73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ blake3 = { version = "1.5.4", features = ["mmap", "rayon", "serde"] } chrono = { version = "0.4.38", features = ["serde"] } log = "0.4" maud = { version = "0.26", features = ["rocket"] } +rand = "0.8.5" rocket = { version = "0.5", features = ["json"] } serde = { version = "1.0.213", features = ["derive"] } serde_with = { version = "3.11.0", features = ["chrono_0_4"] } diff --git a/src/database.rs b/src/database.rs index ff94105..1408e8a 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,26 +1,32 @@ use std::{ - collections::HashMap, - fs::{self, File}, - path::{Path, PathBuf}, - sync::{Arc, RwLock}, + collections::{hash_map::Values, HashMap, HashSet}, fs::{self, File}, path::{Path, PathBuf}, sync::{Arc, RwLock} }; use bincode::{config::Configuration, decode_from_std_read, encode_into_std_write, Decode, Encode}; use blake3::Hash; use chrono::{DateTime, TimeDelta, Utc}; use log::{info, warn}; +use rand::distributions::{Alphanumeric, DistString}; use rocket::{ serde::{Deserialize, Serialize}, tokio::{select, sync::mpsc::Receiver, time}, }; +use crate::settings::Settings; + const BINCODE_CFG: Configuration = bincode::config::standard(); #[derive(Debug, Clone, Decode, Encode)] pub struct Database { path: PathBuf, + + /// Every hash in the database along with the [`Mmid`]s associated with them #[bincode(with_serde)] - pub files: HashMap, + hashes: HashMap>, + + /// All entries in the database + #[bincode(with_serde)] + entries: HashMap, } impl Database { @@ -29,7 +35,8 @@ impl Database { let output = Self { path: path.as_ref().to_path_buf(), - files: HashMap::new(), + entries: HashMap::new(), + hashes: HashMap::new(), }; encode_into_std_write(&output, &mut file, BINCODE_CFG).expect("Could not write database!"); @@ -37,6 +44,7 @@ impl Database { output } + /// Open the database from a path, **or create it if it does not exist** pub fn open>(path: &P) -> Self { if !path.as_ref().exists() { Self::new(path) @@ -46,6 +54,7 @@ impl Database { } } + /// Save the database to its file pub fn save(&self) { let mut out_path = self.path.clone(); out_path.set_extension(".bkp"); @@ -54,18 +63,90 @@ impl Database { fs::rename(out_path, &self.path).unwrap(); } + + /// Insert a [`MochiFile`] into the database. + /// + /// If the database already contained this value, then `false` is returned. + pub fn insert(&mut self, entry: MochiFile) -> bool { + if let Some(s) = self.hashes.get_mut(&entry.hash) { + // If the database already contains the hash, make sure the file is unique + if !s.insert(entry.mmid.clone()) { + return false; + } + } else { + // If the database does not contain the hash, create a new set for it + self.hashes.insert(entry.hash, HashSet::from([entry.mmid.clone()])); + } + + self.entries.insert(entry.mmid.clone(), entry.clone()); + + true + } + + /// Remove an [`Mmid`] from the database entirely. + /// + /// If the database did not contain this value, then `false` is returned. + pub fn remove_mmid(&mut self, mmid: &Mmid) -> bool { + let hash = if let Some(h) = self.entries.get(mmid).and_then(|e| Some(e.hash)) { + self.entries.remove(mmid); + h + } else { + return false + }; + + if let Some(s) = self.hashes.get_mut(&hash) { + s.remove(mmid); + } + + true + } + + /// Remove a hash from the database entirely. + /// + /// Will not remove (returns [`Some(false)`] if hash contains references. + pub fn remove_hash(&mut self, hash: &Hash) -> Option { + if let Some(s) = self.hashes.get(hash) { + if s.is_empty() { + self.hashes.remove(hash); + return Some(true) + } else { + return Some(false) + } + } else { + return None + } + } + + /// Checks if a hash contained in the database contains no more [`Mmid`]s. + pub fn is_hash_empty(&self, hash: &Hash) -> Option { + if let Some(s) = self.hashes.get(hash) { + Some(s.is_empty()) + } else { + None + } + } + + /// Get an entry by its [`Mmid`]. Returns [`None`] if the value does not exist. + pub fn get(&self, mmid: &Mmid) -> Option<&MochiFile> { + self.entries.get(mmid) + } + + pub fn entries(&self) -> Values<'_, Mmid, MochiFile> { + self.entries.values() + } } +/// An entry in the database storing metadata about a file #[derive(Debug, Clone, Decode, Encode, Deserialize, Serialize)] #[serde(crate = "rocket::serde")] pub struct MochiFile { + /// A unique identifier describing this file + mmid: Mmid, + /// The original name of the file name: String, - /// The location on disk (for deletion and management) - filename: PathBuf, - - /// The hashed contents of the file as a Blake3 hash + /// The Blake3 hash of the file #[bincode(with_serde)] hash: Hash, @@ -81,17 +162,17 @@ pub struct MochiFile { impl MochiFile { /// Create a new file that expires in `expiry`. pub fn new_with_expiry( - name: &str, + mmid: Mmid, + name: String, hash: Hash, - filename: PathBuf, expire_duration: TimeDelta, ) -> Self { let current = Utc::now(); let expiry = current + expire_duration; Self { - name: name.to_string(), - filename, + mmid, + name, hash, upload_datetime: current, expiry_datetime: expiry, @@ -102,22 +183,11 @@ impl MochiFile { &self.name } - pub fn path(&self) -> &PathBuf { - &self.filename - } - - pub fn get_key(&self) -> MochiKey { - MochiKey { - name: self.name.clone(), - hash: self.hash, - } - } - - pub fn get_expiry(&self) -> DateTime { + pub fn expiry(&self) -> DateTime { self.expiry_datetime } - pub fn expired(&self) -> bool { + pub fn is_expired(&self) -> bool { let datetime = Utc::now(); datetime > self.expiry_datetime } @@ -125,64 +195,58 @@ impl MochiFile { pub fn hash(&self) -> &Hash { &self.hash } -} -#[derive(Debug, Clone, PartialEq, Eq, Hash, Decode, Encode, Deserialize, Serialize)] -#[serde(crate = "rocket::serde")] -pub struct MochiKey { - name: String, - #[bincode(with_serde)] - hash: Hash, + pub fn mmid(&self) -> &Mmid { + &self.mmid + } } /// Clean the database. Removes files which are past their expiry /// [`chrono::DateTime`]. Also removes files which no longer exist on the disk. -fn clean_database(db: &Arc>) { +fn clean_database( + db: &Arc>, + file_path: &PathBuf, +) { let mut database = db.write().unwrap(); + + // Add expired entries to the removal list let files_to_remove: Vec<_> = database - .files - .iter() + .entries() .filter_map(|e| { - if e.1.expired() { - // Check if the entry has expired - Some((e.0.clone(), e.1.clone())) - } else if !e.1.path().try_exists().is_ok_and(|r| r) { - // Check if the entry exists - Some((e.0.clone(), e.1.clone())) + if e.is_expired() { + Some((e.mmid().clone(), e.hash().clone())) } else { None } }) .collect(); - let mut expired = 0; - let mut missing = 0; - for file in &files_to_remove { - let path = file.1.path(); - // If the path does not exist, there's no reason to try to remove it. - if path.try_exists().is_ok_and(|r| r) { - match fs::remove_file(path) { - Ok(_) => (), - Err(e) => warn!("Failed to delete path at {:?}: {e}", path), - } - expired += 1; - } else { - missing += 1 - } + let mut removed_files = 0; + let mut removed_entries = 0; + for e in &files_to_remove { - database.files.remove(&file.0); + if database.remove_mmid(&e.0) { + removed_entries += 1; + } + if database.is_hash_empty(&e.1).is_some_and(|b| b) { + database.remove_hash(&e.1); + if let Err(e) = fs::remove_file(file_path.join(e.1.to_string())) { + warn!("Failed to remove expired hash: {}", e); + } else { + removed_files += 1; + } + } } - info!( - "{} expired and {} missing items cleared from database", - expired, missing - ); + info!("Cleaned database. Removed {removed_entries} expired entries. Removed {removed_files} no longer referenced files."); + database.save(); } /// A loop to clean the database periodically. pub async fn clean_loop( db: Arc>, + file_path: PathBuf, mut shutdown_signal: Receiver<()>, interval: TimeDelta, ) { @@ -190,8 +254,37 @@ pub async fn clean_loop( loop { select! { - _ = interval.tick() => clean_database(&db), + _ = interval.tick() => clean_database(&db, &file_path), _ = shutdown_signal.recv() => break, }; } } + +/// A unique identifier for an entry in the database, 8 characters long, +/// consists of ASCII alphanumeric characters (`a-z`, `A-Z`, and `0-9`). +#[derive(Debug, PartialEq, Eq, Clone, Decode, Encode, Hash)] +#[derive(Deserialize, Serialize)] +pub struct Mmid(String); + +impl Mmid { + /// Create a new random MMID + pub fn new() -> Self { + let string = Alphanumeric.sample_string(&mut rand::thread_rng(), 8); + + Self(string) + } + + pub fn as_str(&self) -> &str { + &self.0 + } + + pub fn to_string(&self) -> String { + self.0.to_owned() + } +} + +impl From<&str> for Mmid { + fn from(value: &str) -> Self { + Self(value.to_owned()) + } +} diff --git a/src/endpoints.rs b/src/endpoints.rs index 54293e0..572bb44 100644 --- a/src/endpoints.rs +++ b/src/endpoints.rs @@ -1,15 +1,11 @@ use std::sync::{Arc, RwLock}; use rocket::{ - get, - http::RawStr, - response::{status::NotFound, Redirect}, - serde::{self, json::Json}, - State, + fs::NamedFile, get, serde::{self, json::Json}, State }; use serde::Serialize; -use crate::{database::Database, get_id, settings::Settings}; +use crate::{database::Database, settings::Settings}; /// An endpoint to obtain information about the server's capabilities #[get("/info")] @@ -41,14 +37,17 @@ pub struct ServerInfo { /// Look up the hash of a file to find it. This only returns the first /// hit for a hash, so different filenames may not be found. #[get("/f/")] -pub fn lookup(db: &State>>, id: &str) -> Result> { - for file in db.read().unwrap().files.values() { - if file.hash().to_hex()[0..10].to_string() == id { - let filename = get_id(file.name(), *file.hash()); - let filename = RawStr::new(&filename).percent_encode().to_string(); - return Ok(Redirect::to(format!("/files/{}", filename))); - } - } +pub async fn lookup( + db: &State>>, + settings: &State, + id: &str +) -> Option { + dbg!(db.read().unwrap()); + let entry = if let Some(e) = db.read().unwrap().get(&id.into()).cloned() { + e + } else { + return None + }; - Err(NotFound(())) + NamedFile::open(settings.file_dir.join(entry.hash().to_string())).await.ok() } diff --git a/src/file_server.rs b/src/file_server.rs new file mode 100644 index 0000000..b7a28f3 --- /dev/null +++ b/src/file_server.rs @@ -0,0 +1,14 @@ +use std::{path::PathBuf, sync::{Arc, RwLock}}; +use rocket::{fs::NamedFile, get, State}; + +use crate::database::Database; + +#[get("/")] +async fn files( + db: &State>>, + ident: PathBuf +) -> Option { + //let file = NamedFile::open(Path::new("static/").join(file)).await.ok(); + + todo!() +} diff --git a/src/main.rs b/src/main.rs index a60aeba..78b7322 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ mod endpoints; mod settings; mod strings; mod utils; +mod file_server; use std::{ fs, @@ -10,7 +11,7 @@ use std::{ }; use chrono::{DateTime, TimeDelta, Utc}; -use database::{clean_loop, Database, MochiFile}; +use database::{clean_loop, Database, Mmid, MochiFile}; use endpoints::{lookup, server_info}; use log::info; use maud::{html, Markup, PreEscaped, DOCTYPE}; @@ -28,7 +29,7 @@ use rocket::{ }; use settings::Settings; use strings::{parse_time_string, to_pretty_time}; -use utils::{get_id, hash_file}; +use utils::hash_file; use uuid::Uuid; fn head(page_title: &str) -> Markup { @@ -86,9 +87,9 @@ fn home(settings: &State) -> Markup { } form #uploadForm { // It's stupid how these can't be styled so they're just hidden here... - input id="fileInput" type="file" name="fileUpload" multiple + input #fileInput type="file" name="fileUpload" multiple onchange="formSubmit(this.parentNode)" data-max-filesize=(settings.max_filesize) style="display:none;"; - input id="fileDuration" type="text" name="duration" minlength="2" + input #fileDuration type="text" name="duration" minlength="2" maxlength="7" value=(settings.duration.default.num_seconds().to_string() + "s") style="display:none;"; } hr; @@ -130,23 +131,21 @@ async fn handle_upload( let mut out_path = settings.file_dir.clone(); let expire_time = if let Ok(t) = parse_time_string(&file_data.expire_time) { - if t > settings.duration.maximum { - return Ok(Json(ClientResponse::failure( - "Duration larger than maximum", - ))); - } - if settings.duration.restrict_to_allowed && !settings.duration.allowed.contains(&t) { return Ok(Json(ClientResponse::failure("Duration not allowed"))); } + if t > settings.duration.maximum { + return Ok(Json(ClientResponse::failure("Duration larger than max"))); + } + t } else { return Ok(Json(ClientResponse::failure("Duration invalid"))); }; // TODO: Properly sanitize this... - let raw_name = &*file_data + let raw_name = file_data .file .raw_name() .unwrap() @@ -158,47 +157,28 @@ async fn handle_upload( temp_dir.push(Uuid::new_v4().to_string()); let temp_filename = temp_dir; file_data.file.persist_to(&temp_filename).await?; - let hash = hash_file(&temp_filename).await?; + let file_hash = hash_file(&temp_filename).await?; - let filename = get_id(raw_name, hash); - out_path.push(filename.clone()); + let file_mmid = Mmid::new(); + out_path.push(file_hash.to_string()); let constructed_file = - MochiFile::new_with_expiry(raw_name, hash, out_path.clone(), expire_time); - - if !settings.overwrite - && db - .read() - .unwrap() - .files - .contains_key(&constructed_file.get_key()) - { - info!("Key already in DB, NOT ADDING"); - return Ok(Json(ClientResponse { - status: true, - response: "File already exists", - name: constructed_file.name().clone(), - url: filename, - hash: hash.to_hex()[0..10].to_string(), - expires: Some(constructed_file.get_expiry()), - })); - } + MochiFile::new_with_expiry(file_mmid.clone(), raw_name, file_hash, expire_time); // Move it to the new proper place std::fs::rename(temp_filename, out_path)?; db.write() .unwrap() - .files - .insert(constructed_file.get_key(), constructed_file.clone()); + .insert(constructed_file.clone()); db.write().unwrap().save(); Ok(Json(ClientResponse { status: true, name: constructed_file.name().clone(), - url: filename, - hash: hash.to_hex()[0..10].to_string(), - expires: Some(constructed_file.get_expiry()), + mmid: Some(file_mmid), + hash: file_hash.to_hex()[0..10].to_string(), + expires: Some(constructed_file.expiry()), ..Default::default() })) } @@ -214,8 +194,8 @@ struct ClientResponse { #[serde(skip_serializing_if = "str::is_empty")] pub name: String, - #[serde(skip_serializing_if = "str::is_empty")] - pub url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub mmid: Option, #[serde(skip_serializing_if = "str::is_empty")] pub hash: String, #[serde(skip_serializing_if = "Option::is_none")] @@ -263,7 +243,8 @@ async fn main() { let (shutdown, rx) = tokio::sync::mpsc::channel(1); tokio::spawn({ let cleaner_db = database.clone(); - async move { clean_loop(cleaner_db, rx, TimeDelta::minutes(2)).await } + let file_path = config.file_dir.clone(); + async move { clean_loop(cleaner_db, file_path, rx, TimeDelta::seconds(10)).await } }); let rocket = rocket::build() diff --git a/src/utils.rs b/src/utils.rs index 4d2cf65..dfb202b 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,11 +1,6 @@ use blake3::Hash; use std::path::Path; -/// Get a filename based on the file's hashed name -pub fn get_id(name: &str, hash: Hash) -> String { - hash.to_hex()[0..10].to_string() + "_" + name -} - /// Get the Blake3 hash of a file, without reading it all into memory, and also get the size pub async fn hash_file>(input: &P) -> Result { let mut hasher = blake3::Hasher::new(); diff --git a/web/request.js b/web/request.js index 59a898b..c931545 100644 --- a/web/request.js +++ b/web/request.js @@ -90,11 +90,11 @@ function makeErrored(progressBar, progressText, linkRow, errorMessage) { linkRow.style.background = "#ffb2ae"; } -function makeFinished(progressBar, progressText, linkRow, linkAddress, hash) { +function makeFinished(progressBar, progressText, linkRow, MMID, _hash) { progressText.textContent = ""; const link = progressText.appendChild(document.createElement("a")); - link.textContent = hash; - link.href = "/files/" + linkAddress; + link.textContent = MMID; + link.href = "/f/" + MMID; link.target = "_blank"; let button = linkRow.appendChild(document.createElement("button")); @@ -105,7 +105,7 @@ function makeFinished(progressBar, progressText, linkRow, linkAddress, hash) { clearTimeout(buttonTimeout) } navigator.clipboard.writeText( - encodeURI(window.location.protocol + "//" + window.location.host + "/files/" + linkAddress) + encodeURI(window.location.protocol + "//" + window.location.host + "/f/" + MMID) ) button.textContent = "✅"; buttonTimeout = setTimeout(function() { @@ -143,7 +143,7 @@ function uploadComplete(response, progressBar, progressText, linkRow) { if (response.status) { console.log("Successfully uploaded file", response); - makeFinished(progressBar, progressText, linkRow, response.url, response.hash); + makeFinished(progressBar, progressText, linkRow, response.mmid, response.hash); } else { console.error("Error uploading", response); makeErrored(progressBar, progressText, linkRow, response.response); @@ -179,6 +179,8 @@ async function initEverything() { if (this.classList.contains("selected")) { return } + document.getElementById("uploadForm").elements["duration"].value + = this.dataset.durationSeconds + "s"; let selected = this.parentNode.getElementsByClassName("selected"); selected[0].classList.remove("selected"); this.classList.add("selected");