From 17a1c809d8dae048a861abcfa1df4efcf58ce421 Mon Sep 17 00:00:00 2001
From: asivery <asivery@protonmail.com>
Date: Sat, 9 Mar 2024 20:52:18 +0100
Subject: [PATCH] Add track upload

---
 Cargo.lock                          |   1 +
 minidisc-rs/Cargo.toml              |   1 +
 minidisc-rs/src/netmd/base.rs       |   2 +-
 minidisc-rs/src/netmd/commands.rs   |  11 ++-
 minidisc-rs/src/netmd/encryption.rs |  67 +++++++++++++
 minidisc-rs/src/netmd/interface.rs  | 145 +++++++++++++++++++---------
 minidisc-rs/src/netmd/mod.rs        |   3 +-
 7 files changed, 179 insertions(+), 51 deletions(-)
 create mode 100644 minidisc-rs/src/netmd/encryption.rs

diff --git a/Cargo.lock b/Cargo.lock
index d845a1a..bd88186 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -760,6 +760,7 @@ dependencies = [
  "once_cell",
  "rand",
  "regex",
+ "tokio",
  "unicode-jp",
  "unicode-normalization",
  "web-sys",
diff --git a/minidisc-rs/Cargo.toml b/minidisc-rs/Cargo.toml
index de83235..5ea1d68 100644
--- a/minidisc-rs/Cargo.toml
+++ b/minidisc-rs/Cargo.toml
@@ -28,6 +28,7 @@ getrandom = { version = "0.2.12", features = ["js"] }
 des = "0.8.1"
 cbc = "0.1.2"
 ecb = "0.1.2"
+tokio = { version = "1.35.1", features = ["sync"] }
 
 [dependencies.unicode-jp]
 # Relying on this fork for now as it has up-to-date deps
diff --git a/minidisc-rs/src/netmd/base.rs b/minidisc-rs/src/netmd/base.rs
index a1acc16..a9f18c1 100644
--- a/minidisc-rs/src/netmd/base.rs
+++ b/minidisc-rs/src/netmd/base.rs
@@ -254,7 +254,7 @@ impl NetMD {
 
             // Back off while trying again
             let sleep_time = Self::READ_REPLY_RETRY_INTERVAL
-                * (u32::pow(2, attempt / 10) - 1);
+                * (u32::pow(2, attempt) - 1);
 
             cross_sleep(sleep_time).await;
         }
diff --git a/minidisc-rs/src/netmd/commands.rs b/minidisc-rs/src/netmd/commands.rs
index 28b1afb..22350e2 100644
--- a/minidisc-rs/src/netmd/commands.rs
+++ b/minidisc-rs/src/netmd/commands.rs
@@ -3,7 +3,7 @@ use std::error::Error;
 use num_derive::FromPrimitive;
 use num_traits::FromPrimitive;
 
-use super::interface::{NetMDInterface, MDTrack};
+use super::interface::{NetMDInterface, MDTrack, MDSession};
 use super::utils::cross_sleep;
 
 #[derive(FromPrimitive)]
@@ -72,8 +72,13 @@ pub async fn prepare_download(interface: &mut NetMDInterface) -> Result<(), Box<
     Ok(())
 }
 
-pub async fn download(interface: &mut NetMDInterface, _track: MDTrack) -> Result<(), Box<dyn Error>>{
+pub async fn download<F>(interface: &mut NetMDInterface, track: MDTrack, progress_callback: F) -> Result<(u16, Vec<u8>, Vec<u8>), Box<dyn Error>> where F: Fn(usize, usize){
     prepare_download(interface).await?;
+    let mut session = MDSession::new(interface);
+    session.init().await?;
+    let result = session.download_track(track, progress_callback, None).await?;
+    session.close().await?;
+    interface.release().await?;
 
-    Ok(())
+    Ok(result)
 }
diff --git a/minidisc-rs/src/netmd/encryption.rs b/minidisc-rs/src/netmd/encryption.rs
new file mode 100644
index 0000000..947cd22
--- /dev/null
+++ b/minidisc-rs/src/netmd/encryption.rs
@@ -0,0 +1,67 @@
+use std::thread;
+use cbc::cipher::block_padding::NoPadding;
+use cbc::cipher::{KeyInit, BlockDecryptMut, KeyIvInit, BlockEncryptMut};
+use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};
+use rand::RngCore;
+
+use super::interface::DataEncryptorInput;
+
+type DesEcbEnc = ecb::Decryptor<des::Des>;
+type DesCbcEnc = cbc::Encryptor<des::Des>;
+
+pub fn new_thread_encryptor(_input: DataEncryptorInput) -> UnboundedReceiver<(Vec<u8>, Vec<u8>, Vec<u8>)> {
+    let (tx, rx) = unbounded_channel::<(Vec<u8>, Vec<u8>, Vec<u8>)>();
+    let input = Box::from(_input);
+    thread::spawn(move || {
+        let mut iv = [0u8; 8];
+
+        // Create the random key
+        let mut random_key = [0u8; 8];
+        rand::thread_rng().fill_bytes(&mut random_key);
+
+        // Encrypt it with the kek
+        let mut encrypted_random_key = random_key.clone();
+        match DesEcbEnc::new(&input.kek.into()).decrypt_padded_mut::<NoPadding>(&mut encrypted_random_key){
+            Err(x) => panic!("Cannot create main key {:?}", x),
+            Ok(_) => {}
+        };
+
+        let default_chunk_size = match input.chunk_size{
+            0 => 0x00100000,
+            e => e
+        };
+
+        let mut packet_count = 0u32;
+        let mut current_chunk_size;
+
+        let mut input_data = input.data.clone();
+        if (input_data.len() % input.frame_size) != 0 {
+            let padding_remaining = input.frame_size - (input_data.len() % input.frame_size);
+            input_data.extend(std::iter::repeat(0).take(padding_remaining));
+        }
+        let input_data_length = input_data.len();
+        
+        let mut offset: usize = 0;
+        while offset < input_data_length {
+            if packet_count > 0 {
+                current_chunk_size = default_chunk_size;
+            } else {
+                current_chunk_size = default_chunk_size - 24;
+            }
+
+            current_chunk_size = std::cmp::min(current_chunk_size, input_data_length - offset);
+
+            let this_data_chunk = &mut input_data[offset..offset+current_chunk_size];
+            DesCbcEnc::new(&random_key.into(), &iv.into()).encrypt_padded_mut::<NoPadding>(this_data_chunk, current_chunk_size).unwrap();
+
+            tx.send((encrypted_random_key.to_vec(), iv.to_vec(), this_data_chunk.to_vec())).unwrap();
+
+            iv.copy_from_slice(&this_data_chunk[this_data_chunk.len()-8..]);
+
+            packet_count += 1;
+            offset += current_chunk_size;
+        }
+    });
+
+    rx
+}
diff --git a/minidisc-rs/src/netmd/interface.rs b/minidisc-rs/src/netmd/interface.rs
index 598c329..9911b96 100644
--- a/minidisc-rs/src/netmd/interface.rs
+++ b/minidisc-rs/src/netmd/interface.rs
@@ -8,10 +8,11 @@ use crate::netmd::utils::{
 use cbc::cipher::block_padding::NoPadding;
 use cbc::cipher::{KeyIvInit, BlockEncryptMut, BlockDecryptMut, KeyInit};
 use encoding_rs::SHIFT_JIS;
+use num_derive::FromPrimitive;
 use rand::RngCore;
+use tokio::sync::mpsc::UnboundedReceiver;
 use std::collections::HashMap;
 use std::error::Error;
-use std::str::FromStr;
 
 use lazy_static::lazy_static;
 
@@ -31,7 +32,7 @@ enum Track {
     Restart = 0x0001,
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone, Copy, FromPrimitive)]
 pub enum DiscFormat {
     LP4 = 0,
     LP2 = 2,
@@ -39,7 +40,7 @@ pub enum DiscFormat {
     SPStereo = 6,
 }
 
-#[derive(Clone, Hash, Eq, PartialEq)]
+#[derive(Clone, Hash, Eq, PartialEq, FromPrimitive)]
 pub enum WireFormat {
     Pcm = 0x00,
     L105kbps = 0x90,
@@ -391,6 +392,7 @@ impl NetMDInterface {
                     let sleep_time = Self::INTERIM_RESPONSE_RETRY_INTERVAL
                         * (u32::pow(2, current_attempt as u32) - 1);
 
+
                     cross_sleep(sleep_time).await;
 
                     current_attempt += 1;
@@ -459,7 +461,7 @@ impl NetMDInterface {
         Ok(())
     }
 
-    async fn release(&mut self) -> Result<(), Box<dyn Error>> {
+    pub async fn release(&mut self) -> Result<(), Box<dyn Error>> {
         let mut query = format_query("ff 0100 ffff ffff ffff ffff ffff ffff".to_string(), vec![])?;
 
         let reply = self.send_query(&mut query, false, false).await?;
@@ -1431,8 +1433,8 @@ impl NetMDInterface {
 
     pub async fn setup_download(
         &mut self,
-        contentid: Vec<u8>,
-        keyenckey: Vec<u8>,
+        contentid: &[u8],
+        keyenckey: &[u8],
         hex_session_key: &[u8],
     ) -> Result<(), Box<dyn Error>> {
         if contentid.len() != 20 {
@@ -1441,11 +1443,11 @@ impl NetMDInterface {
         if keyenckey.len() != 8 {
             return Err("Supplied Key Encryption Key length wrong".into());
         }
-        if hex_session_key.len() != 16 {
+        if hex_session_key.len() != 8 {
             return Err("Supplied Session Key length wrong".into());
         }
 
-        let mut message = [vec![1, 1, 1, 1], contentid, keyenckey].concat();
+        let mut message = [vec![1, 1, 1, 1], contentid.to_vec(), keyenckey.to_vec()].concat();
         DesCbcEnc::new(hex_session_key.into(), &[0u8; 8].into()).encrypt_padded_mut::<NoPadding>(message.as_mut_slice(), 32).unwrap();
 
         let mut query = format_query(
@@ -1465,7 +1467,7 @@ impl NetMDInterface {
         track_number: u16,
         hex_session_key: &[u8],
     ) -> Result<(), Box<dyn Error>> {
-        if hex_session_key.len() != 16 {
+        if hex_session_key.len() != 8 {
             return Err("Supplied Session Key length wrong".into());
         }
 
@@ -1473,7 +1475,7 @@ impl NetMDInterface {
         DesEcbEnc::new(hex_session_key.into()).encrypt_padded_mut::<NoPadding>(&mut message, 8).unwrap();
 
         let mut query = format_query(
-            "1800 080046 f0030103 22 ff 0000 %*".to_string(),
+            "1800 080046 f0030103 48 ff 00 1001 %w %*".to_string(),
             vec![
                 QueryValue::Number(track_number as i64),
                 QueryValue::Array(Vec::from(message)),
@@ -1482,29 +1484,30 @@ impl NetMDInterface {
 
         let reply = self.send_query(&mut query, false, false).await?;
 
-        scan_query(reply, "1800 080046 f0030103 22 00 0000".to_string())?;
+        scan_query(reply, "1800 080046 f0030103 48 00 00 1001 %?%?".to_string())?;
 
         Ok(())
     }
 
-    pub async fn send_track(
+    pub async fn send_track<F>(
         &mut self,
         wireformat: u8,
         discformat: u8,
-        frames: i32,
+        frames: u32,
         pkt_size: u32,
         // key   // iv    // data
-        packets: Vec<(Vec<u8>, Vec<u8>, Vec<u8>)>,
+        mut packets: UnboundedReceiver<(Vec<u8>, Vec<u8>, Vec<u8>)>,
         hex_session_key: &[u8],
-    ) -> Result<(i64, String, String), Box<dyn Error>> {
-        if hex_session_key.len() != 16 {
+        progress_callback: F
+    ) -> Result<(u16, Vec<u8>, Vec<u8>), Box<dyn Error>> where F: Fn(usize, usize) {
+        if hex_session_key.len() != 8 {
             return Err("Supplied Session Key length wrong".into());
         }
 
         // Sharps are slow
         cross_sleep(200).await;
 
-        let total_bytes = pkt_size + 24; //framesizedict[wireformat] * frames + pktcount * 24;
+        let total_bytes: usize = (pkt_size + 24) as usize; //framesizedict[wireformat] * frames + pktcount * 24;
 
         let mut query = format_query(
             "1800 080046 f0030103 28 ff 000100 1001 ffff 00 %b %b %d %d".to_string(),
@@ -1520,20 +1523,29 @@ impl NetMDInterface {
             reply,
             "1800 080046 f0030103 28 00 000100 1001 %?%? 00 %*".to_string(),
         )?;
+        self.net_md_device.poll().await?;
 
         // Sharps are slow
         cross_sleep(200).await;
 
         let mut _written_bytes = 0;
-        for (packet_count, (key, iv, data)) in packets.into_iter().enumerate() {
+        let mut packet_count = 0;
+        
+        while let Some((key, iv, data)) = packets.recv().await {
             let binpack = if packet_count == 0 {
-                let packed_length: Vec<u8> = pkt_size.to_le_bytes().to_vec();
-                [vec![0, 0, 0, 0], packed_length, key, iv, data.clone()].concat()
+                let packed_length: Vec<u8> = pkt_size.to_be_bytes().to_vec();
+                [vec![0, 0, 0, 0], packed_length, key, iv, data].concat()
             } else {
-                data.clone()
+                data
             };
             self.net_md_device.write_bulk(&binpack).await?;
-            _written_bytes += data.len();
+            _written_bytes += binpack.len();
+            packet_count += 1;
+            (progress_callback)(total_bytes, _written_bytes);
+            if total_bytes == _written_bytes.try_into().unwrap() {
+                packets.close();
+                break;
+            }
         }
 
         reply = self.read_reply(false).await?;
@@ -1546,12 +1558,10 @@ impl NetMDInterface {
         let mut encrypted_data = res[1].to_vec().unwrap();
         DesCbcDec::new(hex_session_key.into(), &[0u8; 8].into()).decrypt_padded_mut::<NoPadding>(&mut encrypted_data).unwrap();
 
-        let reply_data = String::from_utf8(encrypted_data)?;
+        let part1 = encrypted_data[0..8].to_vec();
+        let part2 = encrypted_data[12..32].to_vec();
 
-        let part1 = String::from_str(&reply_data[0..8]).unwrap();
-        let part2 = String::from_str(&reply_data[12..32]).unwrap();
-
-        Ok((res[0].to_i64().unwrap(), part1, part2))
+        Ok((res[0].to_i64().unwrap() as u16, part1, part2))
     }
 
     pub async fn track_uuid(&mut self, track: u16) -> Result<String, Box<dyn Error>> {
@@ -1644,22 +1654,20 @@ impl EKBOpenSource {
     }
 }
 
-#[derive(Clone)]
 pub struct MDTrack {
-    title: String,
-    format: WireFormat,
-    data: Vec<u8>,
-    chunk_size: i32,
-    full_width_title: Option<String>,
-    encrypt_packets_iterator: EncryptPacketsIterator,
+    pub title: String,
+    pub format: WireFormat,
+    pub data: Vec<u8>,
+    pub chunk_size: usize,
+    pub full_width_title: Option<String>,
+    pub encrypt_packets_iterator: Box<dyn Fn(DataEncryptorInput) -> UnboundedReceiver<(Vec<u8>, Vec<u8>, Vec<u8>)>>,
 }
 
-#[derive(Clone)]
-pub struct EncryptPacketsIterator {
-    kek: Vec<u8>,
-    frame_size: i32,
-    data: Vec<u8>,
-    chunk_size: i32,
+pub struct DataEncryptorInput {
+    pub kek: [u8; 8],
+    pub frame_size: usize,
+    pub data: Vec<u8>,
+    pub chunk_size: usize,
 }
 
 impl MDTrack {
@@ -1683,7 +1691,7 @@ impl MDTrack {
         *FRAME_SIZE.get(&self.format).unwrap()
     }
 
-    pub fn chunk_size(self) -> i32 {
+    pub fn chunk_size(&self) -> usize {
         self.chunk_size
     }
 
@@ -1696,25 +1704,34 @@ impl MDTrack {
         len
     }
 
-    pub fn content_id() -> [u8; 20] {
+    pub fn content_id(&self) -> [u8; 20] {
         [
             0x01, 0x0f, 0x50, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x48, 0xa2, 0x8d, 0x3e, 0x1a,
             0x3b, 0x0c, 0x44, 0xaf, 0x2f, 0xa0,
         ]
     }
 
-    pub fn get_kek() -> [u8; 8] {
+    pub fn get_kek(&self) -> [u8; 8] {
         [0x14, 0xe3, 0x83, 0x4e, 0xe2, 0xd3, 0xcc, 0xa5]
     }
+
+    pub fn get_encrypting_iterator(&mut self) -> UnboundedReceiver<(Vec<u8>, Vec<u8>, Vec<u8>)>{
+        (self.encrypt_packets_iterator)(DataEncryptorInput {
+            kek: self.get_kek().clone(),
+            frame_size: self.frame_size(),
+            chunk_size: self.chunk_size(),
+            data: std::mem::take(&mut self.data)
+        })
+    }
 }
 
-pub struct MDSession {
-    pub md: NetMDInterface,
+pub struct MDSession<'a> {
+    pub md: &'a mut NetMDInterface,
     pub ekb_object: EKBOpenSource,
     pub hex_session_key: Option<Vec<u8>>,
 }
 
-impl MDSession {
+impl<'a> MDSession<'a> {
     pub async fn init(&mut self) -> Result<(), Box<dyn Error>>{
         self.md.enter_secure_session().await?;
         self.md.leaf_id().await?;
@@ -1739,4 +1756,40 @@ impl MDSession {
 
         Ok(())
     }
+
+    pub async fn download_track<F>(&mut self, mut track: MDTrack, progress_callback: F, disc_format: Option<DiscFormat>) -> Result<(u16, Vec<u8>, Vec<u8>), Box<dyn Error>> where F: Fn(usize, usize) {
+        if let None = self.hex_session_key{
+            return Err("Cannot download a track using a non-init()'ed session!".into());
+
+        }
+        self.md.setup_download(&track.content_id(), &track.get_kek(), &self.hex_session_key.as_ref().unwrap()).await?;
+        let data_format = track.data_format();
+        let final_disc_format = disc_format.unwrap_or(*DISC_FOR_WIRE.get(&data_format).unwrap());
+
+        let (track_index, uuid, ccid) = self.md.send_track(
+            data_format as u8,
+            final_disc_format as u8,
+            track.frame_count() as u32,
+            track.total_size() as u32,
+            track.get_encrypting_iterator(),
+            self.hex_session_key.as_ref().unwrap().as_slice(),
+            progress_callback
+        ).await?;
+
+        self.md.set_track_title(track_index, &track.title, false).await?;
+        if let Some(full_width) = track.full_width_title {
+            self.md.set_track_title(track_index, &full_width, true).await?;
+        }
+        self.md.commit_track(track_index, &self.hex_session_key.as_ref().unwrap()).await?;
+
+        Ok((track_index, uuid, ccid))
+    }
+
+    pub fn new(md: &'a mut NetMDInterface) -> Self {
+        MDSession {
+            md,
+            ekb_object: EKBOpenSource {},
+            hex_session_key: None,
+        }
+    }
 }
diff --git a/minidisc-rs/src/netmd/mod.rs b/minidisc-rs/src/netmd/mod.rs
index 234810e..84fed51 100644
--- a/minidisc-rs/src/netmd/mod.rs
+++ b/minidisc-rs/src/netmd/mod.rs
@@ -9,4 +9,5 @@ pub mod interface;
 mod mappings;
 mod query_utils;
 mod utils;
-mod commands;
+pub mod commands;
+pub mod encryption;