From 442619270023595edd7c01a4ece547089f7ee015 Mon Sep 17 00:00:00 2001 From: G2-Games Date: Sat, 22 Mar 2025 17:11:47 -0500 Subject: [PATCH] Changed imu to use websockets --- confetti-cli/Cargo.toml | 4 + confetti-cli/src/main.rs | 208 +++++++++++++++++++++------------------ 2 files changed, 116 insertions(+), 96 deletions(-) diff --git a/confetti-cli/Cargo.toml b/confetti-cli/Cargo.toml index 092c9be..a3579bd 100644 --- a/confetti-cli/Cargo.toml +++ b/confetti-cli/Cargo.toml @@ -18,9 +18,11 @@ workspace = true [dependencies] anyhow = "1.0" +base64 = "0.22.1" chrono = { version = "0.4", features = ["serde"] } clap = { version = "4.5", features = ["derive", "unicode"] } directories = "6.0" +futures-util = "0.3.31" indicatif = { version = "0.17", features = ["improved_unicode"] } owo-colors = { version = "4.1", features = ["supports-colors"] } reqwest = { version = "0.12", features = ["json", "stream"] } @@ -28,6 +30,8 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" thiserror = "1.0" tokio = { version = "1.41", features = ["fs", "macros", "rt-multi-thread"] } +tokio-tungstenite = { version = "0.26.2", features = ["native-tls"] } tokio-util = { version = "0.7", features = ["codec"] } toml = "0.8" +url = { version = "2.5.4", features = ["serde"] } uuid = { version = "1.11", features = ["serde", "v4"] } diff --git a/confetti-cli/src/main.rs b/confetti-cli/src/main.rs index a61cd3f..8036b91 100644 --- a/confetti-cli/src/main.rs +++ b/confetti-cli/src/main.rs @@ -1,13 +1,17 @@ use std::{error::Error, fs, io::{self, Read, Write}, path::{Path, PathBuf}}; +use base64::{prelude::BASE64_URL_SAFE, Engine}; use chrono::{DateTime, Datelike, Local, Month, TimeDelta, Timelike, Utc}; +use futures_util::{stream::FusedStream as _, SinkExt as _, StreamExt as _}; use indicatif::{ProgressBar, ProgressStyle}; use owo_colors::OwoColorize; use reqwest::Client; use serde::{Deserialize, Serialize}; use thiserror::Error; -use tokio::{fs::File, io::{AsyncReadExt, AsyncWriteExt}, task::JoinSet}; +use tokio::{fs::File, io::{AsyncReadExt, AsyncWriteExt}, join, task::JoinSet}; +use tokio_tungstenite::{connect_async, tungstenite::{client::IntoClientRequest as _, Message}}; +use url::Url; use uuid::Uuid; use clap::{arg, builder::{styling::RgbColor, Styles}, Parser, Subcommand}; use anyhow::{anyhow, bail, Context as _, Result}; @@ -82,17 +86,16 @@ async fn main() -> Result<()> { match &cli.command { Commands::Upload { files, duration } => { - if config.url.is_empty() { + let Some(url) = config.url.clone() else { exit_error( format!("URL is empty"), Some(format!("Please set it using the {} command", "set".truecolor(246,199,219).bold())), None, ); - } + }; get_info_if_expired(&mut config).await?; - let client = Client::new(); let duration = match parse_time_string(&duration) { Ok(d) => d, Err(e) => return Err(anyhow!("Invalid duration: {e}")), @@ -125,8 +128,7 @@ async fn main() -> Result<()> { let response = upload_file( name.into_owned(), &path, - &client, - &config.url, + &url, duration, &config.login ).await.with_context(|| "Failed to upload").unwrap(); @@ -141,11 +143,19 @@ async fn main() -> Result<()> { println!( "{:>8} {}, {} (in {})\n{:>8} {}", "Expires:".truecolor(174,196,223).bold(), date, time, pretty_time_long(duration.num_seconds()), - "URL:".truecolor(174,196,223).bold(), (config.url.clone() + "/f/" + &response.mmid.0).underline() + "URL:".truecolor(174,196,223).bold(), (url.to_string() + "/f/" + &response.mmid.0).underline() ); } } Commands::Download { mmids, out_directory } => { + let Some(url) = config.url else { + exit_error( + format!("URL is empty"), + Some(format!("Please set it using the {} command", "set".truecolor(246,199,219).bold())), + None, + ); + }; + let out_directory = if let Some(dir) = out_directory { dir } else { @@ -167,7 +177,6 @@ async fn main() -> Result<()> { } }; - let url = &config.url; for mmid in mmids { let mmid = if mmid.len() != 8 { if mmid.contains(format!("{url}/f/").as_str()) { @@ -203,10 +212,10 @@ async fn main() -> Result<()> { }; let mut file_res = if let Some(login) = &config.login { - client.get(format!("{}/f/{mmid}", config.url)) + client.get(format!("{}/f/{mmid}", url)) .basic_auth(&login.user, Some(&login.pass)) } else { - client.get(format!("{}/f/{mmid}", config.url)) + client.get(format!("{}/f/{mmid}", url)) } .send() .await @@ -305,11 +314,13 @@ async fn main() -> Result<()> { url }; - if !url.starts_with("https://") && !url.starts_with("http://") { - config.url = ("https://".to_owned() + url).to_string(); + let new_url = if !url.starts_with("https://") && !url.starts_with("http://") { + ("https://".to_owned() + url).to_string() } else { - config.url = url.to_string(); - } + url.to_string() + }; + + config.url = Some(Url::parse(&new_url)?); config.save().unwrap(); println!("URL set to \"{url}\""); @@ -356,7 +367,7 @@ async fn main() -> Result<()> { #[derive(Error, Debug)] enum UploadError { #[error("request provided was invalid: {0}")] - InvalidRequest(String), + WebSocketFailed(String), #[error("error on reqwest transaction: {0}")] Reqwest(#[from] reqwest::Error), @@ -365,104 +376,101 @@ enum UploadError { async fn upload_file>( name: String, path: &P, - client: &Client, - url: &String, + url: &Url, duration: TimeDelta, login: &Option, ) -> Result { let mut file = File::open(path).await.unwrap(); - let size = file.metadata().await.unwrap().len() as u64; + let file_size = file.metadata().await.unwrap().len(); - let ChunkedResponse {status, message, uuid, chunk_size} = { - client.post(format!("{url}/upload/chunked/")) - .json( - &ChunkedInfo { - name: name.clone(), - size, - expire_duration: duration.num_seconds() as u64, - } - ) - .basic_auth(&login.as_ref().unwrap().user, login.as_ref().unwrap().pass.clone().into()) - .send() - .await? - .json() - .await? - }; - - if !status { - return Err(UploadError::InvalidRequest(message)); + // Construct the URL + let mut url = url.clone(); + if url.scheme() == "http" { + url.set_scheme("ws").unwrap(); + } else if url.scheme() == "https" { + url.set_scheme("wss").unwrap(); } - let mut i = 0; - let post_url = format!("{url}/upload/chunked/{}", uuid.unwrap()); - let mut request_set = JoinSet::new(); + url.set_path("/upload/websocket"); + url.set_query(Some(&format!("name={}&size={}&duration={}", name, file_size, duration.num_seconds()))); + + let mut request = url.to_string().into_client_request().unwrap(); + + if let Some(l) = login { + request.headers_mut().insert( + "Authorization", + format!("Basic {}", BASE64_URL_SAFE.encode(format!("{}:{}", l.user, l.pass))).parse().unwrap() + ); + } + + let (stream, _response) = connect_async(request).await.map_err(|e| UploadError::WebSocketFailed(e.to_string()))?; + let (mut write, mut read) = stream.split(); + + // Upload the file in chunks + let upload_task = async move { + let mut chunk = vec![0u8; 20_000]; + loop { + let read_len = file.read(&mut chunk).await.unwrap(); + if read_len == 0 { + break + } + + write.send(Message::binary(chunk[..read_len].to_vec())).await.unwrap(); + } + + // Close the stream because sending is over + write.send(Message::binary(b"".as_slice())).await.unwrap(); + write.flush().await.unwrap(); + + write + }; + let bar = ProgressBar::new(100); bar.set_style(ProgressStyle::with_template( &format!("{} {{bar:40.cyan/blue}} {{pos:>3}}% {{msg}}", name) ).unwrap()); - loop { - // Read the next chunk into a buffer - let mut chunk = vec![0u8; chunk_size.unwrap() as usize]; - let bytes_read = fill_buffer(&mut chunk, &mut file).await.unwrap(); - if bytes_read == 0 { - break; - } - let chunk = chunk[..bytes_read].to_owned(); - request_set.spawn({ - let post_url = post_url.clone(); - let user = login.as_ref().unwrap().user.clone(); - let pass = login.as_ref().unwrap().pass.clone(); - // Reuse the client for all the threads - let client = Client::clone(client); + // Get the progress of the file upload + let progress_task = async move { + let final_json = loop { + let Some(p) = read.next().await else { + break String::new() + }; - async move { - client.post(&post_url) - .query(&[("chunk", i)]) - .basic_auth(&user, pass.into()) - .body(chunk) - .send() - .await + let p = p.unwrap(); + + // Got the final json information, return that + if p.is_text() { + break p.into_text().unwrap().to_string() } - }); - i += 1; + // Get the progress information + let prog = p.into_data(); + let prog = u64::from_le_bytes(prog.to_vec().try_into().unwrap()); + let percent = f64::trunc((prog as f64 / file_size as f64) * 100.0); + if percent <= 100. { + bar.set_position(percent as u64); + } + }; - // Limit the number of concurrent uploads to 5 - if request_set.len() >= 5 { - bar.set_message(""); - request_set.join_next().await; - bar.set_message("⏳"); - } + (read, final_json, bar) + }; - let percent = f64::trunc(((i as f64 * chunk_size.unwrap() as f64) / size as f64) * 100.0); - if percent <= 100. { - bar.set_position(percent as u64); - } + // Wait for both of the tasks to finish + let (read, write) = join!(progress_task, upload_task); + let (read, final_json, bar) = read; + let mut stream = write.reunite(read).unwrap(); + + let file_info: MochiFile = serde_json::from_str(&final_json).unwrap(); + + // If the websocket isn't closed, do that + if !stream.is_terminated() { + stream.close(None).await.unwrap(); } - // Wait for all remaining uploads to finish - loop { - if let Some(t) = request_set.join_next().await { - match t { - Ok(_) => (), - Err(_) => todo!(), - } - } else { - break - } - } bar.finish_and_clear(); - println!("[{}] - \"{}\"", "✓".bright_green(), name); - Ok( - client.get(format!("{url}/upload/chunked/{}?finish", uuid.unwrap())) - .basic_auth(&login.as_ref().unwrap().user, login.as_ref().unwrap().pass.clone().into()) - .send() - .await.unwrap() - .json::() - .await? - ) + Ok(file_info) } async fn get_info_if_expired(config: &mut Config) -> Result<()> { @@ -482,7 +490,13 @@ async fn get_info_if_expired(config: &mut Config) -> Result<()> { } async fn get_info(config: &Config) -> Result { - let url = config.url.clone(); + let Some(url) = config.url.clone() else { + exit_error( + format!("URL is empty"), + Some(format!("Please set it using the {} command", "set".truecolor(246,199,219).bold())), + None, + ); + }; let client = Client::new(); let get_info = client.get(format!("{url}/info")); @@ -520,9 +534,11 @@ async fn fill_buffer(buffer: &mut [u8], mut stream: S) let mut bytes_read = 0; while bytes_read < buffer.len() { let len = stream.read(&mut buffer[bytes_read..]).await?; + if len == 0 { break; } + bytes_read += len; } Ok(bytes_read) @@ -596,7 +612,7 @@ struct Login { #[derive(Deserialize, Serialize, Debug, Default)] #[serde(default)] struct Config { - url: String, + url: Option, login: Option, /// The time when the info was last fetched info_fetch: Option>, @@ -611,7 +627,7 @@ impl Config { str } else { let c = Config { - url: String::new(), + url: None, login: None, info_fetch: None, info: None, @@ -644,7 +660,7 @@ impl Config { if buf.is_empty() { let c = Config { - url: String::new(), + url: None, login: None, info: None, info_fetch: None,