Changed imu to use websockets

This commit is contained in:
G2-Games 2025-03-22 17:11:47 -05:00
parent 80a11fbb0b
commit 4426192700
2 changed files with 116 additions and 96 deletions

View file

@ -18,9 +18,11 @@ workspace = true
[dependencies] [dependencies]
anyhow = "1.0" anyhow = "1.0"
base64 = "0.22.1"
chrono = { version = "0.4", features = ["serde"] } chrono = { version = "0.4", features = ["serde"] }
clap = { version = "4.5", features = ["derive", "unicode"] } clap = { version = "4.5", features = ["derive", "unicode"] }
directories = "6.0" directories = "6.0"
futures-util = "0.3.31"
indicatif = { version = "0.17", features = ["improved_unicode"] } indicatif = { version = "0.17", features = ["improved_unicode"] }
owo-colors = { version = "4.1", features = ["supports-colors"] } owo-colors = { version = "4.1", features = ["supports-colors"] }
reqwest = { version = "0.12", features = ["json", "stream"] } reqwest = { version = "0.12", features = ["json", "stream"] }
@ -28,6 +30,8 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
thiserror = "1.0" thiserror = "1.0"
tokio = { version = "1.41", features = ["fs", "macros", "rt-multi-thread"] } tokio = { version = "1.41", features = ["fs", "macros", "rt-multi-thread"] }
tokio-tungstenite = { version = "0.26.2", features = ["native-tls"] }
tokio-util = { version = "0.7", features = ["codec"] } tokio-util = { version = "0.7", features = ["codec"] }
toml = "0.8" toml = "0.8"
url = { version = "2.5.4", features = ["serde"] }
uuid = { version = "1.11", features = ["serde", "v4"] } uuid = { version = "1.11", features = ["serde", "v4"] }

View file

@ -1,13 +1,17 @@
use std::{error::Error, fs, io::{self, Read, Write}, path::{Path, PathBuf}}; use std::{error::Error, fs, io::{self, Read, Write}, path::{Path, PathBuf}};
use base64::{prelude::BASE64_URL_SAFE, Engine};
use chrono::{DateTime, Datelike, Local, Month, TimeDelta, Timelike, Utc}; use chrono::{DateTime, Datelike, Local, Month, TimeDelta, Timelike, Utc};
use futures_util::{stream::FusedStream as _, SinkExt as _, StreamExt as _};
use indicatif::{ProgressBar, ProgressStyle}; use indicatif::{ProgressBar, ProgressStyle};
use owo_colors::OwoColorize; use owo_colors::OwoColorize;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use thiserror::Error; use thiserror::Error;
use tokio::{fs::File, io::{AsyncReadExt, AsyncWriteExt}, task::JoinSet}; use tokio::{fs::File, io::{AsyncReadExt, AsyncWriteExt}, join, task::JoinSet};
use tokio_tungstenite::{connect_async, tungstenite::{client::IntoClientRequest as _, Message}};
use url::Url;
use uuid::Uuid; use uuid::Uuid;
use clap::{arg, builder::{styling::RgbColor, Styles}, Parser, Subcommand}; use clap::{arg, builder::{styling::RgbColor, Styles}, Parser, Subcommand};
use anyhow::{anyhow, bail, Context as _, Result}; use anyhow::{anyhow, bail, Context as _, Result};
@ -82,17 +86,16 @@ async fn main() -> Result<()> {
match &cli.command { match &cli.command {
Commands::Upload { files, duration } => { Commands::Upload { files, duration } => {
if config.url.is_empty() { let Some(url) = config.url.clone() else {
exit_error( exit_error(
format!("URL is empty"), format!("URL is empty"),
Some(format!("Please set it using the {} command", "set".truecolor(246,199,219).bold())), Some(format!("Please set it using the {} command", "set".truecolor(246,199,219).bold())),
None, None,
); );
} };
get_info_if_expired(&mut config).await?; get_info_if_expired(&mut config).await?;
let client = Client::new();
let duration = match parse_time_string(&duration) { let duration = match parse_time_string(&duration) {
Ok(d) => d, Ok(d) => d,
Err(e) => return Err(anyhow!("Invalid duration: {e}")), Err(e) => return Err(anyhow!("Invalid duration: {e}")),
@ -125,8 +128,7 @@ async fn main() -> Result<()> {
let response = upload_file( let response = upload_file(
name.into_owned(), name.into_owned(),
&path, &path,
&client, &url,
&config.url,
duration, duration,
&config.login &config.login
).await.with_context(|| "Failed to upload").unwrap(); ).await.with_context(|| "Failed to upload").unwrap();
@ -141,11 +143,19 @@ async fn main() -> Result<()> {
println!( println!(
"{:>8} {}, {} (in {})\n{:>8} {}", "{:>8} {}, {} (in {})\n{:>8} {}",
"Expires:".truecolor(174,196,223).bold(), date, time, pretty_time_long(duration.num_seconds()), "Expires:".truecolor(174,196,223).bold(), date, time, pretty_time_long(duration.num_seconds()),
"URL:".truecolor(174,196,223).bold(), (config.url.clone() + "/f/" + &response.mmid.0).underline() "URL:".truecolor(174,196,223).bold(), (url.to_string() + "/f/" + &response.mmid.0).underline()
); );
} }
} }
Commands::Download { mmids, out_directory } => { Commands::Download { mmids, out_directory } => {
let Some(url) = config.url else {
exit_error(
format!("URL is empty"),
Some(format!("Please set it using the {} command", "set".truecolor(246,199,219).bold())),
None,
);
};
let out_directory = if let Some(dir) = out_directory { let out_directory = if let Some(dir) = out_directory {
dir dir
} else { } else {
@ -167,7 +177,6 @@ async fn main() -> Result<()> {
} }
}; };
let url = &config.url;
for mmid in mmids { for mmid in mmids {
let mmid = if mmid.len() != 8 { let mmid = if mmid.len() != 8 {
if mmid.contains(format!("{url}/f/").as_str()) { if mmid.contains(format!("{url}/f/").as_str()) {
@ -203,10 +212,10 @@ async fn main() -> Result<()> {
}; };
let mut file_res = if let Some(login) = &config.login { let mut file_res = if let Some(login) = &config.login {
client.get(format!("{}/f/{mmid}", config.url)) client.get(format!("{}/f/{mmid}", url))
.basic_auth(&login.user, Some(&login.pass)) .basic_auth(&login.user, Some(&login.pass))
} else { } else {
client.get(format!("{}/f/{mmid}", config.url)) client.get(format!("{}/f/{mmid}", url))
} }
.send() .send()
.await .await
@ -305,11 +314,13 @@ async fn main() -> Result<()> {
url url
}; };
if !url.starts_with("https://") && !url.starts_with("http://") { let new_url = if !url.starts_with("https://") && !url.starts_with("http://") {
config.url = ("https://".to_owned() + url).to_string(); ("https://".to_owned() + url).to_string()
} else { } else {
config.url = url.to_string(); url.to_string()
} };
config.url = Some(Url::parse(&new_url)?);
config.save().unwrap(); config.save().unwrap();
println!("URL set to \"{url}\""); println!("URL set to \"{url}\"");
@ -356,7 +367,7 @@ async fn main() -> Result<()> {
#[derive(Error, Debug)] #[derive(Error, Debug)]
enum UploadError { enum UploadError {
#[error("request provided was invalid: {0}")] #[error("request provided was invalid: {0}")]
InvalidRequest(String), WebSocketFailed(String),
#[error("error on reqwest transaction: {0}")] #[error("error on reqwest transaction: {0}")]
Reqwest(#[from] reqwest::Error), Reqwest(#[from] reqwest::Error),
@ -365,104 +376,101 @@ enum UploadError {
async fn upload_file<P: AsRef<Path>>( async fn upload_file<P: AsRef<Path>>(
name: String, name: String,
path: &P, path: &P,
client: &Client, url: &Url,
url: &String,
duration: TimeDelta, duration: TimeDelta,
login: &Option<Login>, login: &Option<Login>,
) -> Result<MochiFile, UploadError> { ) -> Result<MochiFile, UploadError> {
let mut file = File::open(path).await.unwrap(); let mut file = File::open(path).await.unwrap();
let size = file.metadata().await.unwrap().len() as u64; let file_size = file.metadata().await.unwrap().len();
let ChunkedResponse {status, message, uuid, chunk_size} = { // Construct the URL
client.post(format!("{url}/upload/chunked/")) let mut url = url.clone();
.json( if url.scheme() == "http" {
&ChunkedInfo { url.set_scheme("ws").unwrap();
name: name.clone(), } else if url.scheme() == "https" {
size, url.set_scheme("wss").unwrap();
expire_duration: duration.num_seconds() as u64,
}
)
.basic_auth(&login.as_ref().unwrap().user, login.as_ref().unwrap().pass.clone().into())
.send()
.await?
.json()
.await?
};
if !status {
return Err(UploadError::InvalidRequest(message));
} }
let mut i = 0; url.set_path("/upload/websocket");
let post_url = format!("{url}/upload/chunked/{}", uuid.unwrap()); url.set_query(Some(&format!("name={}&size={}&duration={}", name, file_size, duration.num_seconds())));
let mut request_set = JoinSet::new();
let mut request = url.to_string().into_client_request().unwrap();
if let Some(l) = login {
request.headers_mut().insert(
"Authorization",
format!("Basic {}", BASE64_URL_SAFE.encode(format!("{}:{}", l.user, l.pass))).parse().unwrap()
);
}
let (stream, _response) = connect_async(request).await.map_err(|e| UploadError::WebSocketFailed(e.to_string()))?;
let (mut write, mut read) = stream.split();
// Upload the file in chunks
let upload_task = async move {
let mut chunk = vec![0u8; 20_000];
loop {
let read_len = file.read(&mut chunk).await.unwrap();
if read_len == 0 {
break
}
write.send(Message::binary(chunk[..read_len].to_vec())).await.unwrap();
}
// Close the stream because sending is over
write.send(Message::binary(b"".as_slice())).await.unwrap();
write.flush().await.unwrap();
write
};
let bar = ProgressBar::new(100); let bar = ProgressBar::new(100);
bar.set_style(ProgressStyle::with_template( bar.set_style(ProgressStyle::with_template(
&format!("{} {{bar:40.cyan/blue}} {{pos:>3}}% {{msg}}", name) &format!("{} {{bar:40.cyan/blue}} {{pos:>3}}% {{msg}}", name)
).unwrap()); ).unwrap());
loop {
// Read the next chunk into a buffer
let mut chunk = vec![0u8; chunk_size.unwrap() as usize];
let bytes_read = fill_buffer(&mut chunk, &mut file).await.unwrap();
if bytes_read == 0 {
break;
}
let chunk = chunk[..bytes_read].to_owned();
request_set.spawn({ // Get the progress of the file upload
let post_url = post_url.clone(); let progress_task = async move {
let user = login.as_ref().unwrap().user.clone(); let final_json = loop {
let pass = login.as_ref().unwrap().pass.clone(); let Some(p) = read.next().await else {
// Reuse the client for all the threads break String::new()
let client = Client::clone(client); };
async move { let p = p.unwrap();
client.post(&post_url)
.query(&[("chunk", i)]) // Got the final json information, return that
.basic_auth(&user, pass.into()) if p.is_text() {
.body(chunk) break p.into_text().unwrap().to_string()
.send()
.await
} }
});
i += 1; // Get the progress information
let prog = p.into_data();
let prog = u64::from_le_bytes(prog.to_vec().try_into().unwrap());
let percent = f64::trunc((prog as f64 / file_size as f64) * 100.0);
if percent <= 100. {
bar.set_position(percent as u64);
}
};
// Limit the number of concurrent uploads to 5 (read, final_json, bar)
if request_set.len() >= 5 { };
bar.set_message("");
request_set.join_next().await;
bar.set_message("");
}
let percent = f64::trunc(((i as f64 * chunk_size.unwrap() as f64) / size as f64) * 100.0); // Wait for both of the tasks to finish
if percent <= 100. { let (read, write) = join!(progress_task, upload_task);
bar.set_position(percent as u64); let (read, final_json, bar) = read;
} let mut stream = write.reunite(read).unwrap();
let file_info: MochiFile = serde_json::from_str(&final_json).unwrap();
// If the websocket isn't closed, do that
if !stream.is_terminated() {
stream.close(None).await.unwrap();
} }
// Wait for all remaining uploads to finish
loop {
if let Some(t) = request_set.join_next().await {
match t {
Ok(_) => (),
Err(_) => todo!(),
}
} else {
break
}
}
bar.finish_and_clear(); bar.finish_and_clear();
println!("[{}] - \"{}\"", "".bright_green(), name);
Ok( Ok(file_info)
client.get(format!("{url}/upload/chunked/{}?finish", uuid.unwrap()))
.basic_auth(&login.as_ref().unwrap().user, login.as_ref().unwrap().pass.clone().into())
.send()
.await.unwrap()
.json::<MochiFile>()
.await?
)
} }
async fn get_info_if_expired(config: &mut Config) -> Result<()> { async fn get_info_if_expired(config: &mut Config) -> Result<()> {
@ -482,7 +490,13 @@ async fn get_info_if_expired(config: &mut Config) -> Result<()> {
} }
async fn get_info(config: &Config) -> Result<ServerInfo> { async fn get_info(config: &Config) -> Result<ServerInfo> {
let url = config.url.clone(); let Some(url) = config.url.clone() else {
exit_error(
format!("URL is empty"),
Some(format!("Please set it using the {} command", "set".truecolor(246,199,219).bold())),
None,
);
};
let client = Client::new(); let client = Client::new();
let get_info = client.get(format!("{url}/info")); let get_info = client.get(format!("{url}/info"));
@ -520,9 +534,11 @@ async fn fill_buffer<S: AsyncReadExt + Unpin>(buffer: &mut [u8], mut stream: S)
let mut bytes_read = 0; let mut bytes_read = 0;
while bytes_read < buffer.len() { while bytes_read < buffer.len() {
let len = stream.read(&mut buffer[bytes_read..]).await?; let len = stream.read(&mut buffer[bytes_read..]).await?;
if len == 0 { if len == 0 {
break; break;
} }
bytes_read += len; bytes_read += len;
} }
Ok(bytes_read) Ok(bytes_read)
@ -596,7 +612,7 @@ struct Login {
#[derive(Deserialize, Serialize, Debug, Default)] #[derive(Deserialize, Serialize, Debug, Default)]
#[serde(default)] #[serde(default)]
struct Config { struct Config {
url: String, url: Option<Url>,
login: Option<Login>, login: Option<Login>,
/// The time when the info was last fetched /// The time when the info was last fetched
info_fetch: Option<DateTime<Utc>>, info_fetch: Option<DateTime<Utc>>,
@ -611,7 +627,7 @@ impl Config {
str str
} else { } else {
let c = Config { let c = Config {
url: String::new(), url: None,
login: None, login: None,
info_fetch: None, info_fetch: None,
info: None, info: None,
@ -644,7 +660,7 @@ impl Config {
if buf.is_empty() { if buf.is_empty() {
let c = Config { let c = Config {
url: String::new(), url: None,
login: None, login: None,
info: None, info: None,
info_fetch: None, info_fetch: None,