diff --git a/src/database.rs b/src/database.rs index 716aac6..0cc168d 100644 --- a/src/database.rs +++ b/src/database.rs @@ -13,7 +13,7 @@ use ciborium::{from_reader, into_writer}; use log::{error, info, warn}; use rand::distributions::{Alphanumeric, DistString}; use rocket::{ - form::{self, FromFormField, ValueField}, serde::{Deserialize, Serialize}, tokio::{select, sync::mpsc::Receiver, time} + form::{self, FromFormField, ValueField}, serde::{Deserialize, Serialize} }; use serde_with::{serde_as, DisplayFromStr}; use uuid::Uuid; @@ -224,7 +224,7 @@ impl MochiFile { /// Clean the database. Removes files which are past their expiry /// [`chrono::DateTime`]. Also removes files which no longer exist on the disk. -fn clean_database(db: &Arc>, file_path: &Path) { +pub fn clean_database(db: &Arc>, file_path: &Path) { let mut database = db.write().unwrap(); // Add expired entries to the removal list @@ -263,23 +263,6 @@ fn clean_database(db: &Arc>, file_path: &Path) { drop(database); // Just to be sure } -/// A loop to clean the database periodically. -pub async fn clean_loop( - db: Arc>, - file_path: PathBuf, - mut shutdown_signal: Receiver<()>, - interval: TimeDelta, -) { - let mut interval = time::interval(interval.to_std().unwrap()); - - loop { - select! { - _ = interval.tick() => clean_database(&db, &file_path), - _ = shutdown_signal.recv() => break, - }; - } -} - /// A unique identifier for an entry in the database, 8 characters long, /// consists of ASCII alphanumeric characters (`a-z`, `A-Z`, and `0-9`). #[derive(Debug, PartialEq, Eq, Clone, Hash)] @@ -358,28 +341,42 @@ impl<'r> FromFormField<'r> for Mmid { /// An in-memory database for partially uploaded chunks of files #[derive(Default, Debug)] pub struct Chunkbase { - chunks: HashMap, + chunks: HashMap, ChunkedInfo)>, } impl Chunkbase { - pub fn chunks(&self) -> &HashMap { + pub fn chunks(&self) -> &HashMap, ChunkedInfo)> { &self.chunks } - pub fn mut_chunks(&mut self) -> &mut HashMap { + pub fn mut_chunks(&mut self) -> &mut HashMap, ChunkedInfo)> { &mut self.chunks } /// Delete all temporary chunk files pub fn delete_all(&mut self) -> Result<(), io::Error> { - for chunk in &self.chunks { - fs::remove_file(&chunk.1.path)?; + for (_timeout, chunk) in self.chunks.values() { + fs::remove_file(&chunk.path)?; } self.chunks.clear(); Ok(()) } + + pub fn delete_timed_out(&mut self) -> Result<(), io::Error> { + let now = Utc::now(); + self.mut_chunks().retain(|_u, (t, c)| { + if *t <= now { + let _ = fs::remove_file(&c.path); + false + } else { + true + } + }); + + Ok(()) + } } /// Information about how to manage partially uploaded chunks of files diff --git a/src/lib.rs b/src/lib.rs index 948ba7b..19247f6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ use crate::{ settings::Settings, strings::to_pretty_time, }; -use chrono::Utc; +use chrono::{TimeDelta, Utc}; use database::{Chunkbase, ChunkedInfo, Mmid, MochiFile, Mochibase}; use maud::{html, Markup, PreEscaped}; use rocket::{ @@ -118,7 +118,7 @@ pub async fn chunked_upload_start( db.write() .unwrap() .mut_chunks() - .insert(uuid, file_info.into_inner()); + .insert(uuid, (Utc::now() + TimeDelta::seconds(30), file_info.into_inner())); Ok(Json(ChunkedResponse { status: true, @@ -136,7 +136,7 @@ pub async fn chunked_upload_continue( uuid: &str, offset: u64, ) -> Result<(), io::Error> { - let uuid = Uuid::parse_str(&uuid).map_err(|e| io::Error::other(e))?; + let uuid = Uuid::parse_str(uuid).map_err(io::Error::other)?; let data_stream = data.open((settings.chunk_size + 100).bytes()); let chunked_info = match chunk_db.read().unwrap().chunks().get(&uuid) { @@ -148,19 +148,19 @@ pub async fn chunked_upload_continue( .read(true) .write(true) .truncate(false) - .open(&chunked_info.path) + .open(&chunked_info.1.path) .await?; - if offset > chunked_info.size { + if offset > chunked_info.1.size { return Err(io::Error::new(ErrorKind::InvalidInput, "The seek position is larger than the file size")) } file.seek(io::SeekFrom::Start(offset)).await?; - data_stream.stream_to(&mut file).await?.written; + data_stream.stream_to(&mut file).await?; file.flush().await?; let position = file.stream_position().await?; - if position > chunked_info.size { + if position > chunked_info.1.size { chunk_db.write() .unwrap() .mut_chunks() @@ -180,7 +180,7 @@ pub async fn chunked_upload_finish( uuid: &str, ) -> Result, io::Error> { let now = Utc::now(); - let uuid = Uuid::parse_str(&uuid).map_err(|e| io::Error::other(e))?; + let uuid = Uuid::parse_str(uuid).map_err(io::Error::other)?; let chunked_info = match chunk_db.read().unwrap().chunks().get(&uuid) { Some(s) => s.clone(), None => return Err(io::Error::other("Invalid UUID")), @@ -193,22 +193,22 @@ pub async fn chunked_upload_finish( .remove(&uuid) .unwrap(); - if !chunked_info.path.try_exists().is_ok_and(|e| e) { + if !chunked_info.1.path.try_exists().is_ok_and(|e| e) { return Err(io::Error::other("File does not exist")) } // Get file hash let mut hasher = blake3::Hasher::new(); - hasher.update_mmap_rayon(&chunked_info.path).unwrap(); + hasher.update_mmap_rayon(&chunked_info.1.path).unwrap(); let hash = hasher.finalize(); let new_filename = settings.file_dir.join(hash.to_string()); // If the hash does not exist in the database, // move the file to the backend, else, delete it if main_db.read().unwrap().get_hash(&hash).is_none() { - std::fs::rename(&chunked_info.path, &new_filename).unwrap(); + std::fs::rename(&chunked_info.1.path, &new_filename).unwrap(); } else { - std::fs::remove_file(&chunked_info.path).unwrap(); + std::fs::remove_file(&chunked_info.1.path).unwrap(); } let mmid = Mmid::new_random(); @@ -216,11 +216,11 @@ pub async fn chunked_upload_finish( let constructed_file = MochiFile::new( mmid.clone(), - chunked_info.name, + chunked_info.1.name, file_type.media_type().to_string(), hash, now, - now + chunked_info.expire_duration + now + chunked_info.1.expire_duration ); main_db.write().unwrap().insert(&mmid, constructed_file.clone()); diff --git a/src/main.rs b/src/main.rs index 99af0e4..fb4151e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,15 @@ use std::{ - fs, - sync::{Arc, RwLock}, + fs, path::PathBuf, sync::{Arc, RwLock} }; use chrono::TimeDelta; use confetti_box::{ - database::{clean_loop, Chunkbase, Mochibase}, + database::{clean_database, Chunkbase, Mochibase}, endpoints, pages, resources, settings::Settings, }; use log::info; -use rocket::{data::ToByteUnit as _, routes, tokio}; +use rocket::{data::ToByteUnit as _, routes, tokio::{self, select, sync::broadcast::Receiver, time}}; #[rocket::main] async fn main() { @@ -45,12 +44,18 @@ async fn main() { let local_db = database.clone(); let local_chunk = chunkbase.clone(); - // Start monitoring thread, cleaning the database every 2 minutes - let (shutdown, rx) = tokio::sync::mpsc::channel(1); + let (shutdown, rx) = tokio::sync::broadcast::channel(1); + // Clean the database every 2 minutes tokio::spawn({ let cleaner_db = database.clone(); let file_path = config.file_dir.clone(); - async move { clean_loop(cleaner_db, file_path, rx, TimeDelta::minutes(2)).await } + async move { clean_loop(cleaner_db, file_path, rx).await } + }); + tokio::spawn({ + let cleaner_db = database.clone(); + let file_path = config.file_dir.clone(); + let rx2 = shutdown.subscribe(); + async move { clean_loop(cleaner_db, file_path, rx2).await } }); let rocket = rocket::build() @@ -92,7 +97,6 @@ async fn main() { info!("Stopping database cleaning thread..."); shutdown .send(()) - .await .expect("Failed to stop cleaner thread."); info!("Stopping database cleaning thread completed successfully."); @@ -112,3 +116,32 @@ async fn main() { .expect("Failed to delete chunks"); info!("Deleting chunk data completed successfully."); } + +/// A loop to clean the database periodically. +pub async fn clean_loop( + main_db: Arc>, + file_path: PathBuf, + mut shutdown_signal: Receiver<()>, +) { + let mut interval = time::interval(TimeDelta::minutes(2).to_std().unwrap()); + loop { + select! { + _ = interval.tick() => clean_database(&main_db, &file_path), + _ = shutdown_signal.recv() => break, + }; + } +} + +pub async fn clean_chunks( + chunk_db: Arc>, + mut shutdown_signal: Receiver<()>, +) { + let mut interval = time::interval(TimeDelta::seconds(30).to_std().unwrap()); + + loop { + select! { + _ = interval.tick() => {let _ = chunk_db.write().unwrap().delete_timed_out();}, + _ = shutdown_signal.recv() => break, + }; + } +} diff --git a/src/settings.rs b/src/settings.rs index a02d810..2f3592a 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -87,11 +87,11 @@ impl Settings { pub fn save(&self) -> Result<(), io::Error> { let out_path = &self.path.with_extension("new"); - let mut file = File::create(&out_path)?; + let mut file = File::create(out_path)?; file.write_all(&toml::to_string_pretty(self).unwrap().into_bytes())?; // Overwrite the original DB with - fs::rename(&out_path, &self.path).unwrap(); + fs::rename(out_path, &self.path).unwrap(); Ok(()) }