Break out main into more files for organization

This commit is contained in:
G2-Games 2024-10-24 23:29:53 -05:00
parent 2fac064c38
commit 46eaf5b4fd
3 changed files with 70 additions and 58 deletions

48
src/endpoints.rs Normal file
View file

@ -0,0 +1,48 @@
use std::sync::{Arc, RwLock};
use rocket::{get, http::RawStr, response::{status::NotFound, Redirect}, serde::{self, json::Json}, State};
use serde::Serialize;
use crate::{database::Database, get_id, settings::Settings};
/// An endpoint to obtain information about the server's capabilities
#[get("/info")]
pub fn server_info(settings: &State<Settings>) -> Json<ServerInfo> {
Json(ServerInfo {
max_filesize: settings.max_filesize,
max_duration: settings.duration.maximum.num_seconds() as u32,
default_duration: settings.duration.default.num_seconds() as u32,
allowed_durations: settings.duration.allowed.clone().into_iter().map(|t| t.num_seconds() as u32).collect(),
})
}
#[derive(Serialize, Debug)]
#[serde(crate = "rocket::serde")]
pub struct ServerInfo {
max_filesize: u64,
max_duration: u32,
default_duration: u32,
#[serde(skip_serializing_if = "Vec::is_empty")]
allowed_durations: Vec<u32>,
}
/// Look up the hash of a file to find it. This only returns the first
/// hit for a hash, so different filenames may not be found.
#[get("/f/<id>")]
pub fn lookup(
db: &State<Arc<RwLock<Database>>>,
id: &str
) -> Result<Redirect, NotFound<()>> {
for file in db.read().unwrap().files.values() {
if file.hash().to_hex()[0..10].to_string() == id {
let filename = get_id(
file.name(),
*file.hash()
);
let filename = RawStr::new(&filename).percent_encode().to_string();
return Ok(Redirect::to(format!("/files/{}", filename)))
}
}
Err(NotFound(()))
}

View file

@ -1,17 +1,21 @@
mod database;
mod strings;
mod settings;
mod endpoints;
mod utils;
use std::{fs, sync::{Arc, RwLock}};
use std::{fs, path::Path, sync::{Arc, RwLock}};
use blake3::Hash;
use chrono::{DateTime, TimeDelta, Utc};
use database::{clean_loop, Database, MochiFile};
use endpoints::{lookup, server_info};
use log::info;
use rocket::{
data::{Limits, ToByteUnit}, form::Form, fs::{FileServer, Options, TempFile}, get, http::{ContentType, RawStr}, post, response::{content::{RawCss, RawJavaScript}, status::NotFound, Redirect}, routes, serde::{json::Json, Serialize}, tokio, Config, FromForm, State
data::{Limits, ToByteUnit}, form::Form, fs::{FileServer, Options, TempFile}, get, http::ContentType, post, response::content::{RawCss, RawJavaScript}, routes, serde::{json::Json, Serialize}, tokio, Config, FromForm, State
};
use settings::Settings;
use strings::{parse_time_string, to_pretty_time};
use utils::{get_id, hash_file};
use uuid::Uuid;
use maud::{html, Markup, DOCTYPE, PreEscaped};
@ -216,61 +220,6 @@ struct ClientResponse {
pub expires: Option<DateTime<Utc>>,
}
/// Get a filename based on the file's hashed name
fn get_id(name: &str, hash: Hash) -> String {
hash.to_hex()[0..10].to_string() + "_" + name
}
/// Get the Blake3 hash of a file, without reading it all into memory, and also get the size
async fn hash_file<P: AsRef<Path>>(input: &P) -> Result<Hash, std::io::Error> {
let mut hasher = blake3::Hasher::new();
hasher.update_mmap_rayon(input)?;
Ok(hasher.finalize())
}
/// An endpoint to obtain information about the server's capabilities
#[get("/info")]
fn server_info(settings: &State<Settings>) -> Json<ServerInfo> {
Json(ServerInfo {
max_filesize: settings.max_filesize,
max_duration: settings.duration.maximum.num_seconds() as u32,
default_duration: settings.duration.default.num_seconds() as u32,
allowed_durations: settings.duration.allowed.clone().into_iter().map(|t| t.num_seconds() as u32).collect(),
})
}
#[derive(Serialize, Debug)]
#[serde(crate = "rocket::serde")]
struct ServerInfo {
max_filesize: u64,
max_duration: u32,
default_duration: u32,
#[serde(skip_serializing_if = "Vec::is_empty")]
allowed_durations: Vec<u32>,
}
/// Look up the hash of a file to find it. This only returns the first
/// hit for a hash, so different filenames may not be found.
#[get("/f/<id>")]
fn lookup(
db: &State<Arc<RwLock<Database>>>,
id: &str
) -> Result<Redirect, NotFound<()>> {
for file in db.read().unwrap().files.values() {
if file.hash().to_hex()[0..10].to_string() == id {
let filename = get_id(
file.name(),
*file.hash()
);
let filename = RawStr::new(&filename).percent_encode().to_string();
return Ok(Redirect::to(format!("/files/{}", filename)))
}
}
Err(NotFound(()))
}
#[rocket::main]
async fn main() {
// Get or create config file

15
src/utils.rs Normal file
View file

@ -0,0 +1,15 @@
use std::path::Path;
use blake3::Hash;
/// Get a filename based on the file's hashed name
pub fn get_id(name: &str, hash: Hash) -> String {
hash.to_hex()[0..10].to_string() + "_" + name
}
/// Get the Blake3 hash of a file, without reading it all into memory, and also get the size
pub async fn hash_file<P: AsRef<Path>>(input: &P) -> Result<Hash, std::io::Error> {
let mut hasher = blake3::Hasher::new();
hasher.update_mmap_rayon(input)?;
Ok(hasher.finalize())
}