Ran cargo fmt

This commit is contained in:
G2-Games 2024-10-27 02:03:55 -05:00
parent 159300a0d0
commit e05f0373b3
6 changed files with 159 additions and 83 deletions

View file

@ -1,19 +1,26 @@
use std::{collections::HashMap, fs::{self, File}, path::{Path, PathBuf}, sync::{Arc, RwLock}};
use std::{
collections::HashMap,
fs::{self, File},
path::{Path, PathBuf},
sync::{Arc, RwLock},
};
use bincode::{config::Configuration, decode_from_std_read, encode_into_std_write, Decode, Encode};
use chrono::{DateTime, TimeDelta, Utc};
use blake3::Hash;
use chrono::{DateTime, TimeDelta, Utc};
use log::{info, warn};
use rocket::{serde::{Deserialize, Serialize}, tokio::{select, sync::mpsc::Receiver, time}};
use rocket::{
serde::{Deserialize, Serialize},
tokio::{select, sync::mpsc::Receiver, time},
};
const BINCODE_CFG: Configuration = bincode::config::standard();
#[derive(Debug, Clone)]
#[derive(Decode, Encode)]
#[derive(Debug, Clone, Decode, Encode)]
pub struct Database {
path: PathBuf,
#[bincode(with_serde)]
pub files: HashMap<MochiKey, MochiFile>
pub files: HashMap<MochiKey, MochiFile>,
}
impl Database {
@ -22,7 +29,7 @@ impl Database {
let output = Self {
path: path.as_ref().to_path_buf(),
files: HashMap::new()
files: HashMap::new(),
};
encode_into_std_write(&output, &mut file, BINCODE_CFG).expect("Could not write database!");
@ -49,9 +56,7 @@ impl Database {
}
}
#[derive(Debug, Clone)]
#[derive(Decode, Encode)]
#[derive(Deserialize, Serialize)]
#[derive(Debug, Clone, Decode, Encode, Deserialize, Serialize)]
#[serde(crate = "rocket::serde")]
pub struct MochiFile {
/// The original name of the file
@ -79,7 +84,7 @@ impl MochiFile {
name: &str,
hash: Hash,
filename: PathBuf,
expire_duration: TimeDelta
expire_duration: TimeDelta,
) -> Self {
let current = Utc::now();
let expiry = current + expire_duration;
@ -104,7 +109,7 @@ impl MochiFile {
pub fn get_key(&self) -> MochiKey {
MochiKey {
name: self.name.clone(),
hash: self.hash
hash: self.hash,
}
}
@ -122,9 +127,7 @@ impl MochiFile {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Decode, Encode)]
#[derive(Deserialize, Serialize)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Decode, Encode, Deserialize, Serialize)]
#[serde(crate = "rocket::serde")]
pub struct MochiKey {
name: String,
@ -136,17 +139,21 @@ pub struct MochiKey {
/// [`chrono::DateTime`]. Also removes files which no longer exist on the disk.
fn clean_database(db: &Arc<RwLock<Database>>) {
let mut database = db.write().unwrap();
let files_to_remove: Vec<_> = database.files.iter().filter_map(|e| {
if e.1.expired() {
// Check if the entry has expired
Some((e.0.clone(), e.1.clone()))
} else if !e.1.path().try_exists().is_ok_and(|r| r) {
// Check if the entry exists
Some((e.0.clone(), e.1.clone()))
} else {
None
}
}).collect();
let files_to_remove: Vec<_> = database
.files
.iter()
.filter_map(|e| {
if e.1.expired() {
// Check if the entry has expired
Some((e.0.clone(), e.1.clone()))
} else if !e.1.path().try_exists().is_ok_and(|r| r) {
// Check if the entry exists
Some((e.0.clone(), e.1.clone()))
} else {
None
}
})
.collect();
let mut expired = 0;
let mut missing = 0;
@ -166,7 +173,10 @@ fn clean_database(db: &Arc<RwLock<Database>>) {
database.files.remove(&file.0);
}
info!("{} expired and {} missing items cleared from database", expired, missing);
info!(
"{} expired and {} missing items cleared from database",
expired, missing
);
database.save();
}

View file

@ -1,6 +1,12 @@
use std::sync::{Arc, RwLock};
use rocket::{get, http::RawStr, response::{status::NotFound, Redirect}, serde::{self, json::Json}, State};
use rocket::{
get,
http::RawStr,
response::{status::NotFound, Redirect},
serde::{self, json::Json},
State,
};
use serde::Serialize;
use crate::{database::Database, get_id, settings::Settings};
@ -12,7 +18,13 @@ pub fn server_info(settings: &State<Settings>) -> Json<ServerInfo> {
max_filesize: settings.max_filesize,
max_duration: settings.duration.maximum.num_seconds() as u32,
default_duration: settings.duration.default.num_seconds() as u32,
allowed_durations: settings.duration.allowed.clone().into_iter().map(|t| t.num_seconds() as u32).collect(),
allowed_durations: settings
.duration
.allowed
.clone()
.into_iter()
.map(|t| t.num_seconds() as u32)
.collect(),
})
}
@ -29,18 +41,12 @@ pub struct ServerInfo {
/// Look up the hash of a file to find it. This only returns the first
/// hit for a hash, so different filenames may not be found.
#[get("/f/<id>")]
pub fn lookup(
db: &State<Arc<RwLock<Database>>>,
id: &str
) -> Result<Redirect, NotFound<()>> {
pub fn lookup(db: &State<Arc<RwLock<Database>>>, id: &str) -> Result<Redirect, NotFound<()>> {
for file in db.read().unwrap().files.values() {
if file.hash().to_hex()[0..10].to_string() == id {
let filename = get_id(
file.name(),
*file.hash()
);
let filename = get_id(file.name(), *file.hash());
let filename = RawStr::new(&filename).percent_encode().to_string();
return Ok(Redirect::to(format!("/files/{}", filename)))
return Ok(Redirect::to(format!("/files/{}", filename)));
}
}

View file

@ -1,23 +1,35 @@
mod database;
mod strings;
mod settings;
mod endpoints;
mod settings;
mod strings;
mod utils;
use std::{fs, sync::{Arc, RwLock}};
use std::{
fs,
sync::{Arc, RwLock},
};
use chrono::{DateTime, TimeDelta, Utc};
use database::{clean_loop, Database, MochiFile};
use endpoints::{lookup, server_info};
use log::info;
use maud::{html, Markup, PreEscaped, DOCTYPE};
use rocket::{
data::{Limits, ToByteUnit}, form::Form, fs::{FileServer, Options, TempFile}, get, http::ContentType, post, response::content::{RawCss, RawJavaScript}, routes, serde::{json::Json, Serialize}, tokio, Config, FromForm, State
data::{Limits, ToByteUnit},
form::Form,
fs::{FileServer, Options, TempFile},
get,
http::ContentType,
post,
response::content::{RawCss, RawJavaScript},
routes,
serde::{json::Json, Serialize},
tokio, Config, FromForm, State,
};
use settings::Settings;
use strings::{parse_time_string, to_pretty_time};
use utils::{get_id, hash_file};
use uuid::Uuid;
use maud::{html, Markup, DOCTYPE, PreEscaped};
fn head(page_title: &str) -> Markup {
html! {
@ -119,20 +131,23 @@ async fn handle_upload(
let expire_time = if let Ok(t) = parse_time_string(&file_data.expire_time) {
if t > settings.duration.maximum {
return Ok(Json(ClientResponse::failure("Duration larger than maximum")))
return Ok(Json(ClientResponse::failure(
"Duration larger than maximum",
)));
}
if settings.duration.restrict_to_allowed && !settings.duration.allowed.contains(&t) {
return Ok(Json(ClientResponse::failure("Duration not allowed")))
return Ok(Json(ClientResponse::failure("Duration not allowed")));
}
t
} else {
return Ok(Json(ClientResponse::failure("Duration invalid")))
return Ok(Json(ClientResponse::failure("Duration invalid")));
};
// TODO: Properly sanitize this...
let raw_name = &*file_data.file
let raw_name = &*file_data
.file
.raw_name()
.unwrap()
.dangerous_unsafe_unsanitized_raw()
@ -145,21 +160,18 @@ async fn handle_upload(
file_data.file.persist_to(&temp_filename).await?;
let hash = hash_file(&temp_filename).await?;
let filename = get_id(
raw_name,
hash
);
let filename = get_id(raw_name, hash);
out_path.push(filename.clone());
let constructed_file = MochiFile::new_with_expiry(
raw_name,
hash,
out_path.clone(),
expire_time
);
let constructed_file =
MochiFile::new_with_expiry(raw_name, hash, out_path.clone(), expire_time);
if !settings.overwrite
&& db.read().unwrap().files.contains_key(&constructed_file.get_key())
&& db
.read()
.unwrap()
.files
.contains_key(&constructed_file.get_key())
{
info!("Key already in DB, NOT ADDING");
return Ok(Json(ClientResponse {
@ -169,14 +181,16 @@ async fn handle_upload(
url: filename,
hash: hash.to_hex()[0..10].to_string(),
expires: Some(constructed_file.get_expiry()),
..Default::default()
}))
}));
}
// Move it to the new proper place
std::fs::rename(temp_filename, out_path)?;
db.write().unwrap().files.insert(constructed_file.get_key(), constructed_file.clone());
db.write()
.unwrap()
.files
.insert(constructed_file.get_key(), constructed_file.clone());
db.write().unwrap().save();
Ok(Json(ClientResponse {
@ -221,8 +235,7 @@ impl ClientResponse {
#[rocket::main]
async fn main() {
// Get or create config file
let config = Settings::open(&"./settings.toml")
.expect("Could not open settings file");
let config = Settings::open(&"./settings.toml").expect("Could not open settings file");
if !config.temp_dir.try_exists().is_ok_and(|e| e) {
fs::create_dir_all(config.temp_dir.clone()).expect("Failed to create temp directory");
@ -256,11 +269,22 @@ async fn main() {
let rocket = rocket::build()
.mount(
config.server.root_path.clone() + "/",
routes![home, handle_upload, form_handler_js, stylesheet, server_info, favicon, lookup]
routes![
home,
handle_upload,
form_handler_js,
stylesheet,
server_info,
favicon,
lookup
],
)
.mount(
config.server.root_path.clone() + "/files",
FileServer::new(config.file_dir.clone(), Options::Missing | Options::NormalizeDirs)
FileServer::new(
config.file_dir.clone(),
Options::Missing | Options::NormalizeDirs,
),
)
.manage(database)
.manage(config)
@ -272,7 +296,10 @@ async fn main() {
rocket.expect("Server failed to shutdown gracefully");
info!("Stopping database cleaning thread");
shutdown.send(()).await.expect("Failed to stop cleaner thread");
shutdown
.send(())
.await
.expect("Failed to stop cleaner thread");
info!("Saving database on shutdown...");
local_db.write().unwrap().save();

View file

@ -1,9 +1,13 @@
use std::{fs::{self, File}, io::{self, Read, Write}, path::{Path, PathBuf}};
use std::{
fs::{self, File},
io::{self, Read, Write},
path::{Path, PathBuf},
};
use chrono::TimeDelta;
use serde_with::serde_as;
use rocket::serde::{Deserialize, Serialize};
use rocket::data::ToByteUnit;
use rocket::serde::{Deserialize, Serialize};
use serde_with::serde_as;
/// A response to the client from the server
#[derive(Deserialize, Serialize, Debug)]
@ -44,7 +48,7 @@ pub struct Settings {
impl Default for Settings {
fn default() -> Self {
Self {
max_filesize: 1.megabytes().into(), // 128 MB
max_filesize: 1.megabytes().into(), // 128 MB
overwrite: true,
duration: DurationSettings::default(),
server: ServerSettings::default(),
@ -103,7 +107,7 @@ impl Default for ServerSettings {
Self {
address: "127.0.0.1".into(),
root_path: "/".into(),
port: 8950
port: 8950,
}
}
}
@ -134,8 +138,8 @@ pub struct DurationSettings {
impl Default for DurationSettings {
fn default() -> Self {
Self {
maximum: TimeDelta::days(3), // 72 hours
default: TimeDelta::hours(6), // 6 hours
maximum: TimeDelta::days(3), // 72 hours
default: TimeDelta::hours(6), // 6 hours
// 1 hour, 6 hours, 24 hours, and 48 hours
allowed: vec![
TimeDelta::hours(1),

View file

@ -4,13 +4,13 @@ use chrono::TimeDelta;
pub fn parse_time_string(string: &str) -> Result<TimeDelta, Box<dyn Error>> {
if string.len() > 7 {
return Err("Not valid time string".into())
return Err("Not valid time string".into());
}
let unit = string.chars().last();
let multiplier = if let Some(u) = unit {
if !u.is_ascii_alphabetic() {
return Err("Not valid time string".into())
return Err("Not valid time string".into());
}
match u {
@ -21,13 +21,13 @@ pub fn parse_time_string(string: &str) -> Result<TimeDelta, Box<dyn Error>> {
_ => return Err("Not valid time string".into()),
}
} else {
return Err("Not valid time string".into())
return Err("Not valid time string".into());
};
let time = if let Ok(n) = string[..string.len() - 1].parse::<i32>() {
n
} else {
return Err("Not valid time string".into())
return Err("Not valid time string".into());
};
let final_time = multiplier * time;
@ -41,10 +41,39 @@ pub fn to_pretty_time(seconds: u32) -> String {
let mins = ((seconds as f32 - (hour * 3600.0) - (days * 86400.0)) / 60.0).floor();
let secs = seconds as f32 - (hour * 3600.0) - (mins * 60.0) - (days * 86400.0);
let days = if days == 0.0 {"".to_string()} else if days == 1.0 {days.to_string() + "<br>day"} else {days.to_string() + "<br>days"};
let hour = if hour == 0.0 {"".to_string()} else if hour == 1.0 {hour.to_string() + "<br>hour"} else {hour.to_string() + "<br>hours"};
let mins = if mins == 0.0 {"".to_string()} else if mins == 1.0 {mins.to_string() + "<br>minute"} else {mins.to_string() + "<br>minutes"};
let secs = if secs == 0.0 {"".to_string()} else if secs == 1.0 {secs.to_string() + "<br>second"} else {secs.to_string() + "<br>seconds"};
let days = if days == 0.0 {
"".to_string()
} else if days == 1.0 {
days.to_string() + "<br>day"
} else {
days.to_string() + "<br>days"
};
(days + " " + &hour + " " + &mins + " " + &secs).trim().to_string()
let hour = if hour == 0.0 {
"".to_string()
} else if hour == 1.0 {
hour.to_string() + "<br>hour"
} else {
hour.to_string() + "<br>hours"
};
let mins = if mins == 0.0 {
"".to_string()
} else if mins == 1.0 {
mins.to_string() + "<br>minute"
} else {
mins.to_string() + "<br>minutes"
};
let secs = if secs == 0.0 {
"".to_string()
} else if secs == 1.0 {
secs.to_string() + "<br>second"
} else {
secs.to_string() + "<br>seconds"
};
(days + " " + &hour + " " + &mins + " " + &secs)
.trim()
.to_string()
}

View file

@ -1,5 +1,5 @@
use std::path::Path;
use blake3::Hash;
use std::path::Path;
/// Get a filename based on the file's hashed name
pub fn get_id(name: &str, hash: Hash) -> String {