Ran cargo fmt

This commit is contained in:
G2-Games 2024-10-27 02:03:55 -05:00
parent 159300a0d0
commit e05f0373b3
6 changed files with 159 additions and 83 deletions

View file

@ -1,19 +1,26 @@
use std::{collections::HashMap, fs::{self, File}, path::{Path, PathBuf}, sync::{Arc, RwLock}}; use std::{
collections::HashMap,
fs::{self, File},
path::{Path, PathBuf},
sync::{Arc, RwLock},
};
use bincode::{config::Configuration, decode_from_std_read, encode_into_std_write, Decode, Encode}; use bincode::{config::Configuration, decode_from_std_read, encode_into_std_write, Decode, Encode};
use chrono::{DateTime, TimeDelta, Utc};
use blake3::Hash; use blake3::Hash;
use chrono::{DateTime, TimeDelta, Utc};
use log::{info, warn}; use log::{info, warn};
use rocket::{serde::{Deserialize, Serialize}, tokio::{select, sync::mpsc::Receiver, time}}; use rocket::{
serde::{Deserialize, Serialize},
tokio::{select, sync::mpsc::Receiver, time},
};
const BINCODE_CFG: Configuration = bincode::config::standard(); const BINCODE_CFG: Configuration = bincode::config::standard();
#[derive(Debug, Clone)] #[derive(Debug, Clone, Decode, Encode)]
#[derive(Decode, Encode)]
pub struct Database { pub struct Database {
path: PathBuf, path: PathBuf,
#[bincode(with_serde)] #[bincode(with_serde)]
pub files: HashMap<MochiKey, MochiFile> pub files: HashMap<MochiKey, MochiFile>,
} }
impl Database { impl Database {
@ -22,7 +29,7 @@ impl Database {
let output = Self { let output = Self {
path: path.as_ref().to_path_buf(), path: path.as_ref().to_path_buf(),
files: HashMap::new() files: HashMap::new(),
}; };
encode_into_std_write(&output, &mut file, BINCODE_CFG).expect("Could not write database!"); encode_into_std_write(&output, &mut file, BINCODE_CFG).expect("Could not write database!");
@ -49,9 +56,7 @@ impl Database {
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone, Decode, Encode, Deserialize, Serialize)]
#[derive(Decode, Encode)]
#[derive(Deserialize, Serialize)]
#[serde(crate = "rocket::serde")] #[serde(crate = "rocket::serde")]
pub struct MochiFile { pub struct MochiFile {
/// The original name of the file /// The original name of the file
@ -79,7 +84,7 @@ impl MochiFile {
name: &str, name: &str,
hash: Hash, hash: Hash,
filename: PathBuf, filename: PathBuf,
expire_duration: TimeDelta expire_duration: TimeDelta,
) -> Self { ) -> Self {
let current = Utc::now(); let current = Utc::now();
let expiry = current + expire_duration; let expiry = current + expire_duration;
@ -104,7 +109,7 @@ impl MochiFile {
pub fn get_key(&self) -> MochiKey { pub fn get_key(&self) -> MochiKey {
MochiKey { MochiKey {
name: self.name.clone(), name: self.name.clone(),
hash: self.hash hash: self.hash,
} }
} }
@ -122,9 +127,7 @@ impl MochiFile {
} }
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Decode, Encode, Deserialize, Serialize)]
#[derive(Decode, Encode)]
#[derive(Deserialize, Serialize)]
#[serde(crate = "rocket::serde")] #[serde(crate = "rocket::serde")]
pub struct MochiKey { pub struct MochiKey {
name: String, name: String,
@ -136,7 +139,10 @@ pub struct MochiKey {
/// [`chrono::DateTime`]. Also removes files which no longer exist on the disk. /// [`chrono::DateTime`]. Also removes files which no longer exist on the disk.
fn clean_database(db: &Arc<RwLock<Database>>) { fn clean_database(db: &Arc<RwLock<Database>>) {
let mut database = db.write().unwrap(); let mut database = db.write().unwrap();
let files_to_remove: Vec<_> = database.files.iter().filter_map(|e| { let files_to_remove: Vec<_> = database
.files
.iter()
.filter_map(|e| {
if e.1.expired() { if e.1.expired() {
// Check if the entry has expired // Check if the entry has expired
Some((e.0.clone(), e.1.clone())) Some((e.0.clone(), e.1.clone()))
@ -146,7 +152,8 @@ fn clean_database(db: &Arc<RwLock<Database>>) {
} else { } else {
None None
} }
}).collect(); })
.collect();
let mut expired = 0; let mut expired = 0;
let mut missing = 0; let mut missing = 0;
@ -166,7 +173,10 @@ fn clean_database(db: &Arc<RwLock<Database>>) {
database.files.remove(&file.0); database.files.remove(&file.0);
} }
info!("{} expired and {} missing items cleared from database", expired, missing); info!(
"{} expired and {} missing items cleared from database",
expired, missing
);
database.save(); database.save();
} }

View file

@ -1,6 +1,12 @@
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use rocket::{get, http::RawStr, response::{status::NotFound, Redirect}, serde::{self, json::Json}, State}; use rocket::{
get,
http::RawStr,
response::{status::NotFound, Redirect},
serde::{self, json::Json},
State,
};
use serde::Serialize; use serde::Serialize;
use crate::{database::Database, get_id, settings::Settings}; use crate::{database::Database, get_id, settings::Settings};
@ -12,7 +18,13 @@ pub fn server_info(settings: &State<Settings>) -> Json<ServerInfo> {
max_filesize: settings.max_filesize, max_filesize: settings.max_filesize,
max_duration: settings.duration.maximum.num_seconds() as u32, max_duration: settings.duration.maximum.num_seconds() as u32,
default_duration: settings.duration.default.num_seconds() as u32, default_duration: settings.duration.default.num_seconds() as u32,
allowed_durations: settings.duration.allowed.clone().into_iter().map(|t| t.num_seconds() as u32).collect(), allowed_durations: settings
.duration
.allowed
.clone()
.into_iter()
.map(|t| t.num_seconds() as u32)
.collect(),
}) })
} }
@ -29,18 +41,12 @@ pub struct ServerInfo {
/// Look up the hash of a file to find it. This only returns the first /// Look up the hash of a file to find it. This only returns the first
/// hit for a hash, so different filenames may not be found. /// hit for a hash, so different filenames may not be found.
#[get("/f/<id>")] #[get("/f/<id>")]
pub fn lookup( pub fn lookup(db: &State<Arc<RwLock<Database>>>, id: &str) -> Result<Redirect, NotFound<()>> {
db: &State<Arc<RwLock<Database>>>,
id: &str
) -> Result<Redirect, NotFound<()>> {
for file in db.read().unwrap().files.values() { for file in db.read().unwrap().files.values() {
if file.hash().to_hex()[0..10].to_string() == id { if file.hash().to_hex()[0..10].to_string() == id {
let filename = get_id( let filename = get_id(file.name(), *file.hash());
file.name(),
*file.hash()
);
let filename = RawStr::new(&filename).percent_encode().to_string(); let filename = RawStr::new(&filename).percent_encode().to_string();
return Ok(Redirect::to(format!("/files/{}", filename))) return Ok(Redirect::to(format!("/files/{}", filename)));
} }
} }

View file

@ -1,23 +1,35 @@
mod database; mod database;
mod strings;
mod settings;
mod endpoints; mod endpoints;
mod settings;
mod strings;
mod utils; mod utils;
use std::{fs, sync::{Arc, RwLock}}; use std::{
fs,
sync::{Arc, RwLock},
};
use chrono::{DateTime, TimeDelta, Utc}; use chrono::{DateTime, TimeDelta, Utc};
use database::{clean_loop, Database, MochiFile}; use database::{clean_loop, Database, MochiFile};
use endpoints::{lookup, server_info}; use endpoints::{lookup, server_info};
use log::info; use log::info;
use maud::{html, Markup, PreEscaped, DOCTYPE};
use rocket::{ use rocket::{
data::{Limits, ToByteUnit}, form::Form, fs::{FileServer, Options, TempFile}, get, http::ContentType, post, response::content::{RawCss, RawJavaScript}, routes, serde::{json::Json, Serialize}, tokio, Config, FromForm, State data::{Limits, ToByteUnit},
form::Form,
fs::{FileServer, Options, TempFile},
get,
http::ContentType,
post,
response::content::{RawCss, RawJavaScript},
routes,
serde::{json::Json, Serialize},
tokio, Config, FromForm, State,
}; };
use settings::Settings; use settings::Settings;
use strings::{parse_time_string, to_pretty_time}; use strings::{parse_time_string, to_pretty_time};
use utils::{get_id, hash_file}; use utils::{get_id, hash_file};
use uuid::Uuid; use uuid::Uuid;
use maud::{html, Markup, DOCTYPE, PreEscaped};
fn head(page_title: &str) -> Markup { fn head(page_title: &str) -> Markup {
html! { html! {
@ -119,20 +131,23 @@ async fn handle_upload(
let expire_time = if let Ok(t) = parse_time_string(&file_data.expire_time) { let expire_time = if let Ok(t) = parse_time_string(&file_data.expire_time) {
if t > settings.duration.maximum { if t > settings.duration.maximum {
return Ok(Json(ClientResponse::failure("Duration larger than maximum"))) return Ok(Json(ClientResponse::failure(
"Duration larger than maximum",
)));
} }
if settings.duration.restrict_to_allowed && !settings.duration.allowed.contains(&t) { if settings.duration.restrict_to_allowed && !settings.duration.allowed.contains(&t) {
return Ok(Json(ClientResponse::failure("Duration not allowed"))) return Ok(Json(ClientResponse::failure("Duration not allowed")));
} }
t t
} else { } else {
return Ok(Json(ClientResponse::failure("Duration invalid"))) return Ok(Json(ClientResponse::failure("Duration invalid")));
}; };
// TODO: Properly sanitize this... // TODO: Properly sanitize this...
let raw_name = &*file_data.file let raw_name = &*file_data
.file
.raw_name() .raw_name()
.unwrap() .unwrap()
.dangerous_unsafe_unsanitized_raw() .dangerous_unsafe_unsanitized_raw()
@ -145,21 +160,18 @@ async fn handle_upload(
file_data.file.persist_to(&temp_filename).await?; file_data.file.persist_to(&temp_filename).await?;
let hash = hash_file(&temp_filename).await?; let hash = hash_file(&temp_filename).await?;
let filename = get_id( let filename = get_id(raw_name, hash);
raw_name,
hash
);
out_path.push(filename.clone()); out_path.push(filename.clone());
let constructed_file = MochiFile::new_with_expiry( let constructed_file =
raw_name, MochiFile::new_with_expiry(raw_name, hash, out_path.clone(), expire_time);
hash,
out_path.clone(),
expire_time
);
if !settings.overwrite if !settings.overwrite
&& db.read().unwrap().files.contains_key(&constructed_file.get_key()) && db
.read()
.unwrap()
.files
.contains_key(&constructed_file.get_key())
{ {
info!("Key already in DB, NOT ADDING"); info!("Key already in DB, NOT ADDING");
return Ok(Json(ClientResponse { return Ok(Json(ClientResponse {
@ -169,14 +181,16 @@ async fn handle_upload(
url: filename, url: filename,
hash: hash.to_hex()[0..10].to_string(), hash: hash.to_hex()[0..10].to_string(),
expires: Some(constructed_file.get_expiry()), expires: Some(constructed_file.get_expiry()),
..Default::default() }));
}))
} }
// Move it to the new proper place // Move it to the new proper place
std::fs::rename(temp_filename, out_path)?; std::fs::rename(temp_filename, out_path)?;
db.write().unwrap().files.insert(constructed_file.get_key(), constructed_file.clone()); db.write()
.unwrap()
.files
.insert(constructed_file.get_key(), constructed_file.clone());
db.write().unwrap().save(); db.write().unwrap().save();
Ok(Json(ClientResponse { Ok(Json(ClientResponse {
@ -221,8 +235,7 @@ impl ClientResponse {
#[rocket::main] #[rocket::main]
async fn main() { async fn main() {
// Get or create config file // Get or create config file
let config = Settings::open(&"./settings.toml") let config = Settings::open(&"./settings.toml").expect("Could not open settings file");
.expect("Could not open settings file");
if !config.temp_dir.try_exists().is_ok_and(|e| e) { if !config.temp_dir.try_exists().is_ok_and(|e| e) {
fs::create_dir_all(config.temp_dir.clone()).expect("Failed to create temp directory"); fs::create_dir_all(config.temp_dir.clone()).expect("Failed to create temp directory");
@ -256,11 +269,22 @@ async fn main() {
let rocket = rocket::build() let rocket = rocket::build()
.mount( .mount(
config.server.root_path.clone() + "/", config.server.root_path.clone() + "/",
routes![home, handle_upload, form_handler_js, stylesheet, server_info, favicon, lookup] routes![
home,
handle_upload,
form_handler_js,
stylesheet,
server_info,
favicon,
lookup
],
) )
.mount( .mount(
config.server.root_path.clone() + "/files", config.server.root_path.clone() + "/files",
FileServer::new(config.file_dir.clone(), Options::Missing | Options::NormalizeDirs) FileServer::new(
config.file_dir.clone(),
Options::Missing | Options::NormalizeDirs,
),
) )
.manage(database) .manage(database)
.manage(config) .manage(config)
@ -272,7 +296,10 @@ async fn main() {
rocket.expect("Server failed to shutdown gracefully"); rocket.expect("Server failed to shutdown gracefully");
info!("Stopping database cleaning thread"); info!("Stopping database cleaning thread");
shutdown.send(()).await.expect("Failed to stop cleaner thread"); shutdown
.send(())
.await
.expect("Failed to stop cleaner thread");
info!("Saving database on shutdown..."); info!("Saving database on shutdown...");
local_db.write().unwrap().save(); local_db.write().unwrap().save();

View file

@ -1,9 +1,13 @@
use std::{fs::{self, File}, io::{self, Read, Write}, path::{Path, PathBuf}}; use std::{
fs::{self, File},
io::{self, Read, Write},
path::{Path, PathBuf},
};
use chrono::TimeDelta; use chrono::TimeDelta;
use serde_with::serde_as;
use rocket::serde::{Deserialize, Serialize};
use rocket::data::ToByteUnit; use rocket::data::ToByteUnit;
use rocket::serde::{Deserialize, Serialize};
use serde_with::serde_as;
/// A response to the client from the server /// A response to the client from the server
#[derive(Deserialize, Serialize, Debug)] #[derive(Deserialize, Serialize, Debug)]
@ -103,7 +107,7 @@ impl Default for ServerSettings {
Self { Self {
address: "127.0.0.1".into(), address: "127.0.0.1".into(),
root_path: "/".into(), root_path: "/".into(),
port: 8950 port: 8950,
} }
} }
} }

View file

@ -4,13 +4,13 @@ use chrono::TimeDelta;
pub fn parse_time_string(string: &str) -> Result<TimeDelta, Box<dyn Error>> { pub fn parse_time_string(string: &str) -> Result<TimeDelta, Box<dyn Error>> {
if string.len() > 7 { if string.len() > 7 {
return Err("Not valid time string".into()) return Err("Not valid time string".into());
} }
let unit = string.chars().last(); let unit = string.chars().last();
let multiplier = if let Some(u) = unit { let multiplier = if let Some(u) = unit {
if !u.is_ascii_alphabetic() { if !u.is_ascii_alphabetic() {
return Err("Not valid time string".into()) return Err("Not valid time string".into());
} }
match u { match u {
@ -21,13 +21,13 @@ pub fn parse_time_string(string: &str) -> Result<TimeDelta, Box<dyn Error>> {
_ => return Err("Not valid time string".into()), _ => return Err("Not valid time string".into()),
} }
} else { } else {
return Err("Not valid time string".into()) return Err("Not valid time string".into());
}; };
let time = if let Ok(n) = string[..string.len() - 1].parse::<i32>() { let time = if let Ok(n) = string[..string.len() - 1].parse::<i32>() {
n n
} else { } else {
return Err("Not valid time string".into()) return Err("Not valid time string".into());
}; };
let final_time = multiplier * time; let final_time = multiplier * time;
@ -41,10 +41,39 @@ pub fn to_pretty_time(seconds: u32) -> String {
let mins = ((seconds as f32 - (hour * 3600.0) - (days * 86400.0)) / 60.0).floor(); let mins = ((seconds as f32 - (hour * 3600.0) - (days * 86400.0)) / 60.0).floor();
let secs = seconds as f32 - (hour * 3600.0) - (mins * 60.0) - (days * 86400.0); let secs = seconds as f32 - (hour * 3600.0) - (mins * 60.0) - (days * 86400.0);
let days = if days == 0.0 {"".to_string()} else if days == 1.0 {days.to_string() + "<br>day"} else {days.to_string() + "<br>days"}; let days = if days == 0.0 {
let hour = if hour == 0.0 {"".to_string()} else if hour == 1.0 {hour.to_string() + "<br>hour"} else {hour.to_string() + "<br>hours"}; "".to_string()
let mins = if mins == 0.0 {"".to_string()} else if mins == 1.0 {mins.to_string() + "<br>minute"} else {mins.to_string() + "<br>minutes"}; } else if days == 1.0 {
let secs = if secs == 0.0 {"".to_string()} else if secs == 1.0 {secs.to_string() + "<br>second"} else {secs.to_string() + "<br>seconds"}; days.to_string() + "<br>day"
} else {
days.to_string() + "<br>days"
};
(days + " " + &hour + " " + &mins + " " + &secs).trim().to_string() let hour = if hour == 0.0 {
"".to_string()
} else if hour == 1.0 {
hour.to_string() + "<br>hour"
} else {
hour.to_string() + "<br>hours"
};
let mins = if mins == 0.0 {
"".to_string()
} else if mins == 1.0 {
mins.to_string() + "<br>minute"
} else {
mins.to_string() + "<br>minutes"
};
let secs = if secs == 0.0 {
"".to_string()
} else if secs == 1.0 {
secs.to_string() + "<br>second"
} else {
secs.to_string() + "<br>seconds"
};
(days + " " + &hour + " " + &mins + " " + &secs)
.trim()
.to_string()
} }

View file

@ -1,5 +1,5 @@
use std::path::Path;
use blake3::Hash; use blake3::Hash;
use std::path::Path;
/// Get a filename based on the file's hashed name /// Get a filename based on the file's hashed name
pub fn get_id(name: &str, hash: Hash) -> String { pub fn get_id(name: &str, hash: Hash) -> String {