diff --git a/src/netmd/encryption.rs b/src/netmd/encryption.rs index 944870c..18c6b07 100644 --- a/src/netmd/encryption.rs +++ b/src/netmd/encryption.rs @@ -9,74 +9,12 @@ use super::interface::DataEncryptorInput; type DesEcbEnc = ecb::Decryptor; type DesCbcEnc = cbc::Encryptor; -pub fn threaded_encryptor( - input: DataEncryptorInput, -) -> UnboundedReceiver<(Vec, Vec, Vec)> { - let (tx, rx) = unbounded_channel::<(Vec, Vec, Vec)>(); - - thread::spawn(move || { - let mut iv = [0u8; 8]; - - // Create the random key - let mut random_key = [0u8; 8]; - rand::thread_rng().fill_bytes(&mut random_key); - - // Encrypt it with the kek - let mut encrypted_random_key = random_key; - if let Err(x) = DesEcbEnc::new(&input.kek.into()) - .decrypt_padded_mut::(&mut encrypted_random_key) - { - panic!("Cannot create main key {:?}", x) - }; - - let default_chunk_size = match input.chunk_size { - 0 => 0x00100000, - e => e, - }; - - let mut packet_count = 0u32; - let mut current_chunk_size; - - let mut input_data = input.data.clone(); - if (input_data.len() % input.frame_size) != 0 { - let padding_remaining = input.frame_size - (input_data.len() % input.frame_size); - input_data.extend(std::iter::repeat(0).take(padding_remaining)); - } - let input_data_length = input_data.len(); - - let mut offset: usize = 0; - while offset < input_data_length { - if packet_count > 0 { - current_chunk_size = default_chunk_size; - } else { - current_chunk_size = default_chunk_size - 24; - } - - current_chunk_size = std::cmp::min(current_chunk_size, input_data_length - offset); - - let this_data_chunk = &mut input_data[offset..offset + current_chunk_size]; - DesCbcEnc::new(&random_key.into(), &iv.into()) - .encrypt_padded_mut::(this_data_chunk, current_chunk_size) - .unwrap(); - - tx.send(( - encrypted_random_key.to_vec(), - iv.to_vec(), - this_data_chunk.to_vec(), - )) - .unwrap(); - - iv.copy_from_slice(&this_data_chunk[this_data_chunk.len() - 8..]); - - packet_count += 1; - offset += current_chunk_size; - } - }); - - rx +pub struct Encryptor { + channel: Option, Vec, Vec)>>, + state: Option, } -pub struct Encryptor { +struct EncryptorState { input_data: Vec, iv: [u8; 8], random_key: [u8; 8], @@ -89,6 +27,74 @@ pub struct Encryptor { } impl Encryptor { + pub fn new_threaded(input: DataEncryptorInput) -> Self { + let (tx, rx) = unbounded_channel::<(Vec, Vec, Vec)>(); + + thread::spawn(move || { + let mut iv = [0u8; 8]; + + // Create the random key + let mut random_key = [0u8; 8]; + rand::thread_rng().fill_bytes(&mut random_key); + + // Encrypt it with the kek + let mut encrypted_random_key = random_key; + if let Err(x) = DesEcbEnc::new(&input.kek.into()) + .decrypt_padded_mut::(&mut encrypted_random_key) + { + panic!("Cannot create main key {:?}", x) + }; + + let default_chunk_size = match input.chunk_size { + 0 => 0x00100000, + e => e, + }; + + let mut packet_count = 0u32; + let mut current_chunk_size; + + let mut input_data = input.data.clone(); + if (input_data.len() % input.frame_size) != 0 { + let padding_remaining = input.frame_size - (input_data.len() % input.frame_size); + input_data.extend(std::iter::repeat(0).take(padding_remaining)); + } + let input_data_length = input_data.len(); + + let mut offset: usize = 0; + while offset < input_data_length { + if packet_count > 0 { + current_chunk_size = default_chunk_size; + } else { + current_chunk_size = default_chunk_size - 24; + } + + current_chunk_size = std::cmp::min(current_chunk_size, input_data_length - offset); + + let this_data_chunk = &mut input_data[offset..offset + current_chunk_size]; + DesCbcEnc::new(&random_key.into(), &iv.into()) + .encrypt_padded_mut::(this_data_chunk, current_chunk_size) + .unwrap(); + + tx.send(( + encrypted_random_key.to_vec(), + iv.to_vec(), + this_data_chunk.to_vec(), + )) + .unwrap(); + + iv.copy_from_slice(&this_data_chunk[this_data_chunk.len() - 8..]); + + packet_count += 1; + offset += current_chunk_size; + } + }); + + Self { + channel: Some(rx), + state: None + } + } + pub fn new(input: DataEncryptorInput) -> Self { let iv = [0u8; 8]; @@ -120,58 +126,71 @@ impl Encryptor { let offset: usize = 0; - Self { - input_data, - iv, - random_key, - encrypted_random_key, - current_chunk_size, - offset, - default_chunk_size, - packet_count, - closed: false, + Encryptor { + channel: None, + state: Some(EncryptorState { + input_data, + iv, + random_key, + encrypted_random_key, + current_chunk_size, + offset, + default_chunk_size, + packet_count, + closed: false, + }) } } /// Get the next encrypted value pub async fn next(&mut self) -> Option<(Vec, Vec, Vec)> { - if self.closed { - return None - } + let output; - if self.packet_count > 0 { - self.current_chunk_size = self.default_chunk_size; + if let Some(state) = self.state.as_mut() { + if state.closed { + return None + } + + if state.packet_count > 0 { + state.current_chunk_size = state.default_chunk_size; + } else { + state.current_chunk_size = state.default_chunk_size - 24; + } + + state.current_chunk_size = std::cmp::min(state.current_chunk_size, state.input_data.len() - state.offset); + + let this_data_chunk = &mut state.input_data[state.offset..state.offset + state.current_chunk_size]; + DesCbcEnc::new(&state.random_key.into(), &state.iv.into()) + .encrypt_padded_mut::(this_data_chunk, state.current_chunk_size) + .unwrap(); + + output = Some(( + state.encrypted_random_key.to_vec(), + state.iv.to_vec(), + this_data_chunk.to_vec(), + )); + + state.iv.copy_from_slice(&this_data_chunk[this_data_chunk.len() - 8..]); + + state.packet_count += 1; + state.offset += state.current_chunk_size; + } else if let Some(channel) = self.channel.as_mut() { + output = channel.recv().await } else { - self.current_chunk_size = self.default_chunk_size - 24; + unreachable!("If you got here, this is bad!"); } - self.current_chunk_size = std::cmp::min(self.current_chunk_size, self.input_data.len() - self.offset); - - let this_data_chunk = &mut self.input_data[self.offset..self.offset + self.current_chunk_size]; - DesCbcEnc::new(&self.random_key.into(), &self.iv.into()) - .encrypt_padded_mut::(this_data_chunk, self.current_chunk_size) - .unwrap(); - - let output = ( - self.encrypted_random_key.to_vec(), - self.iv.to_vec(), - this_data_chunk.to_vec(), - ); - - self.iv.copy_from_slice(&this_data_chunk[this_data_chunk.len() - 8..]); - - self.packet_count += 1; - self.offset += self.current_chunk_size; - - Some(output) + output } /// Call close to return none from subsequent calls pub fn close(&mut self) { - self.closed = true; + if let Some(state) = self.state.as_mut() { + state.closed = true; + } else if let Some(channel) = self.channel.as_mut() { + channel.close() + } else { + unreachable!("If you got here, this is bad!"); + } } } - -pub fn encryptor() { - -} diff --git a/src/netmd/interface.rs b/src/netmd/interface.rs index 7be62ec..b707ea7 100644 --- a/src/netmd/interface.rs +++ b/src/netmd/interface.rs @@ -17,7 +17,7 @@ use std::time::Duration; use thiserror::Error; use super::base::NetMD; -use super::encryption::{threaded_encryptor, EncryptorState}; +use super::encryption::Encryptor; use super::utils::{cross_sleep, to_sjis}; /// An action to take on the player @@ -1690,11 +1690,7 @@ impl NetMDInterface { discformat: u8, frames: u32, pkt_size: u32, - // key, iv, data - #[cfg(not(target_family = "wasm"))] - mut packets: UnboundedReceiver<(Vec, Vec, Vec)>, - #[cfg(target_family = "wasm")] - mut packets: EncryptorState, + mut packets: Encryptor, hex_session_key: &[u8], progress_callback: F, ) -> Result<(u16, Vec, Vec), InterfaceError> @@ -1733,7 +1729,7 @@ impl NetMDInterface { let mut written_bytes = 0; let mut packet_count = 0; - while let Some((key, iv, data)) = packets.recv().await { + while let Some((key, iv, data)) = packets.next().await { let binpack = if packet_count == 0 { let packed_length: Vec = pkt_size.to_be_bytes().to_vec(); [vec![0, 0, 0, 0], packed_length, key, iv, data].concat() @@ -1921,8 +1917,8 @@ impl MDTrack { } #[cfg(not(target_family = "wasm"))] - pub fn get_encrypting_iterator(&mut self) -> UnboundedReceiver<(Vec, Vec, Vec)> { - threaded_encryptor(DataEncryptorInput { + pub fn get_encrypting_iterator(&mut self) -> Encryptor { + Encryptor::new_threaded(DataEncryptorInput { kek: self.get_kek(), frame_size: self.frame_size(), chunk_size: self.chunk_size(), @@ -1931,8 +1927,8 @@ impl MDTrack { } #[cfg(target_family = "wasm")] - pub fn get_encrypting_iterator(&mut self) -> EncryptorState { - EncryptorState::new(DataEncryptorInput { + pub fn get_encrypting_iterator(&mut self) -> Encryptor { + Encryptor::new(DataEncryptorInput { kek: self.get_kek(), frame_size: self.frame_size(), chunk_size: self.chunk_size(),