Compare commits

..

2 commits

Author SHA1 Message Date
833eeb6158 Improved encryption implementation 2024-08-15 00:35:01 -05:00
9172a524d3 Initial working WASM encryption 2024-08-15 00:04:55 -05:00
4 changed files with 158 additions and 139 deletions

View file

@ -3,4 +3,4 @@ rustflags = ["--cfg=web_sys_unstable_apis"]
# Enable for testing WASM-only stuff # Enable for testing WASM-only stuff
[build] [build]
target = "wasm32-unknown-unknown" #target = "wasm32-unknown-unknown"

View file

@ -45,16 +45,13 @@ getrandom = { version = "0.2", features = ["js"] }
des = "0.8" des = "0.8"
cbc = "0.1" cbc = "0.1"
ecb = "0.1" ecb = "0.1"
tokio = { version = "1.36", features = ["sync"] }
g2-unicode-jp = "0.4.1" g2-unicode-jp = "0.4.1"
thiserror = "1.0.57" thiserror = "1.0.57"
phf = { version = "0.11.2", features = ["phf_macros", "macros"] } phf = { version = "0.11.2", features = ["phf_macros", "macros"] }
byteorder = "1.5.0" byteorder = "1.5.0"
log = "0.4.22" log = "0.4.22"
[target.'cfg(not(target_family = "wasm"))'.dependencies]
tokio = { version = "1.36", features = ["sync"] }
[target.'cfg(target_family = "wasm")'.dependencies] [target.'cfg(target_family = "wasm")'.dependencies]
gloo = { version = "0.11", features = ["futures", "worker"] } gloo = { version = "0.11.0", features = ["futures", "worker"] }
futures = "0.3" futures = "0.3.30"
serde = { version = "1.0", features = ["derive"] }

View file

@ -1,28 +1,36 @@
use cbc::cipher::block_padding::NoPadding; use cbc::cipher::block_padding::NoPadding;
use cbc::cipher::{BlockDecryptMut, BlockEncryptMut, KeyInit, KeyIvInit}; use cbc::cipher::{BlockDecryptMut, BlockEncryptMut, KeyInit, KeyIvInit};
use rand::RngCore; use rand::RngCore;
use super::interface::DataEncryptorInput;
#[cfg(not(target_family = "wasm"))]
use std::thread; use std::thread;
#[cfg(not(target_family = "wasm"))]
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver}; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};
#[cfg(target_family = "wasm")] use super::interface::DataEncryptorInput;
use futures::{SinkExt, StreamExt};
#[cfg(target_family = "wasm")]
use gloo::worker::reactor::{reactor, ReactorBridge, ReactorScope};
type DesEcbEnc = ecb::Decryptor<des::Des>; type DesEcbEnc = ecb::Decryptor<des::Des>;
type DesCbcEnc = cbc::Encryptor<des::Des>; type DesCbcEnc = cbc::Encryptor<des::Des>;
#[cfg(not(target_family = "wasm"))] pub struct Encryptor {
pub fn new_thread_encryptor( channel: Option<UnboundedReceiver<(Vec<u8>, Vec<u8>, Vec<u8>)>>,
input: DataEncryptorInput, state: Option<EncryptorState>,
) -> UnboundedReceiver<(Vec<u8>, Vec<u8>, Vec<u8>)> { }
struct EncryptorState {
input_data: Vec<u8>,
iv: [u8; 8],
random_key: [u8; 8],
encrypted_random_key: [u8; 8],
default_chunk_size: usize,
current_chunk_size: usize,
offset: usize,
packet_count: usize,
closed: bool,
}
impl Encryptor {
pub fn new_threaded(input: DataEncryptorInput) -> Self {
let (tx, rx) = unbounded_channel::<(Vec<u8>, Vec<u8>, Vec<u8>)>(); let (tx, rx) = unbounded_channel::<(Vec<u8>, Vec<u8>, Vec<u8>)>();
let _ = thread::spawn(move || { thread::spawn(move || {
let mut iv = [0u8; 8]; let mut iv = [0u8; 8];
// Create the random key // Create the random key
@ -81,31 +89,14 @@ pub fn new_thread_encryptor(
} }
}); });
rx Self {
channel: Some(rx),
state: None
}
} }
pub fn new(input: DataEncryptorInput) -> Self {
#[cfg(target_family = "wasm")] let iv = [0u8; 8];
pub fn web_worker_encryptor(
input: DataEncryptorInput,
) -> ReactorBridge<WebThread> {
use gloo::worker::Spawnable;
let bridge = WebThread::spawner().spawn("...");
bridge.send_input(input);
bridge
}
#[cfg(target_family = "wasm")]
#[reactor]
pub async fn WebThread(
mut scope: ReactorScope<DataEncryptorInput, (Vec<u8>, Vec<u8>, Vec<u8>)>
) {
// Get the initial input data
let input = scope.next().await.unwrap();
let mut iv = [0u8; 8];
// Create the random key // Create the random key
let mut random_key = [0u8; 8]; let mut random_key = [0u8; 8];
@ -124,40 +115,82 @@ pub async fn WebThread(
e => e, e => e,
}; };
let mut packet_count = 0u32; let packet_count = 0;
let mut current_chunk_size; let current_chunk_size = 0;
let mut input_data = input.data.clone(); let mut input_data = input.data.clone();
if (input_data.len() % input.frame_size) != 0 { if (input_data.len() % input.frame_size) != 0 {
let padding_remaining = input.frame_size - (input_data.len() % input.frame_size); let padding_remaining = input.frame_size - (input_data.len() % input.frame_size);
input_data.extend(std::iter::repeat(0).take(padding_remaining)); input_data.extend(std::iter::repeat(0).take(padding_remaining));
} }
let input_data_length = input_data.len();
let mut offset: usize = 0; let offset: usize = 0;
while offset < input_data_length {
if packet_count > 0 { Encryptor {
current_chunk_size = default_chunk_size; channel: None,
} else { state: Some(EncryptorState {
current_chunk_size = default_chunk_size - 24; input_data,
iv,
random_key,
encrypted_random_key,
current_chunk_size,
offset,
default_chunk_size,
packet_count,
closed: false,
})
}
} }
current_chunk_size = std::cmp::min(current_chunk_size, input_data_length - offset); /// Get the next encrypted value
pub async fn next(&mut self) -> Option<(Vec<u8>, Vec<u8>, Vec<u8>)> {
let output;
let this_data_chunk = &mut input_data[offset..offset + current_chunk_size]; if let Some(state) = self.state.as_mut() {
DesCbcEnc::new(&random_key.into(), &iv.into()) if state.closed {
.encrypt_padded_mut::<NoPadding>(this_data_chunk, current_chunk_size) return None
}
if state.packet_count > 0 {
state.current_chunk_size = state.default_chunk_size;
} else {
state.current_chunk_size = state.default_chunk_size - 24;
}
state.current_chunk_size = std::cmp::min(state.current_chunk_size, state.input_data.len() - state.offset);
let this_data_chunk = &mut state.input_data[state.offset..state.offset + state.current_chunk_size];
DesCbcEnc::new(&state.random_key.into(), &state.iv.into())
.encrypt_padded_mut::<NoPadding>(this_data_chunk, state.current_chunk_size)
.unwrap(); .unwrap();
scope.send(( output = Some((
encrypted_random_key.to_vec(), state.encrypted_random_key.to_vec(),
iv.to_vec(), state.iv.to_vec(),
this_data_chunk.to_vec(), this_data_chunk.to_vec(),
)).await.unwrap(); ));
iv.copy_from_slice(&this_data_chunk[this_data_chunk.len() - 8..]); state.iv.copy_from_slice(&this_data_chunk[this_data_chunk.len() - 8..]);
packet_count += 1; state.packet_count += 1;
offset += current_chunk_size; state.offset += state.current_chunk_size;
} else if let Some(channel) = self.channel.as_mut() {
output = channel.recv().await
} else {
unreachable!("If you got here, this is bad!");
}
output
}
/// Call close to return none from subsequent calls
pub fn close(&mut self) {
if let Some(state) = self.state.as_mut() {
state.closed = true;
} else if let Some(channel) = self.channel.as_mut() {
channel.close()
} else {
unreachable!("If you got here, this is bad!");
}
} }
} }

View file

@ -17,6 +17,7 @@ use std::time::Duration;
use thiserror::Error; use thiserror::Error;
use super::base::NetMD; use super::base::NetMD;
use super::encryption::Encryptor;
use super::utils::{cross_sleep, to_sjis}; use super::utils::{cross_sleep, to_sjis};
/// An action to take on the player /// An action to take on the player
@ -1689,11 +1690,7 @@ impl NetMDInterface {
discformat: u8, discformat: u8,
frames: u32, frames: u32,
pkt_size: u32, pkt_size: u32,
// key, iv, data mut packets: Encryptor,
#[cfg(not(target_family = "wasm"))]
mut packets: tokio::sync::mpsc::UnboundedReceiver<(Vec<u8>, Vec<u8>, Vec<u8>)>,
#[cfg(target_family = "wasm")]
mut packets: gloo::worker::reactor::ReactorBridge<super::encryption::WebThread>,
hex_session_key: &[u8], hex_session_key: &[u8],
progress_callback: F, progress_callback: F,
) -> Result<(u16, Vec<u8>, Vec<u8>), InterfaceError> ) -> Result<(u16, Vec<u8>, Vec<u8>), InterfaceError>
@ -1732,15 +1729,7 @@ impl NetMDInterface {
let mut written_bytes = 0; let mut written_bytes = 0;
let mut packet_count = 0; let mut packet_count = 0;
while let Some((key, iv, data)) = { while let Some((key, iv, data)) = packets.next().await {
{
#[cfg(not(target_family = "wasm"))]
packets.recv().await
}
#[cfg(target_family = "wasm")]
futures::StreamExt::next(&mut packets).await
} {
let binpack = if packet_count == 0 { let binpack = if packet_count == 0 {
let packed_length: Vec<u8> = pkt_size.to_be_bytes().to_vec(); let packed_length: Vec<u8> = pkt_size.to_be_bytes().to_vec();
[vec![0, 0, 0, 0], packed_length, key, iv, data].concat() [vec![0, 0, 0, 0], packed_length, key, iv, data].concat()
@ -1752,6 +1741,7 @@ impl NetMDInterface {
packet_count += 1; packet_count += 1;
(progress_callback)(total_bytes, written_bytes); (progress_callback)(total_bytes, written_bytes);
if total_bytes == written_bytes { if total_bytes == written_bytes {
packets.close();
break; break;
} }
} }
@ -1874,7 +1864,6 @@ pub struct MDTrack {
pub full_width_title: Option<String>, pub full_width_title: Option<String>,
} }
#[cfg_attr(target_family = "wasm", derive(serde::Serialize, serde::Deserialize))]
pub struct DataEncryptorInput { pub struct DataEncryptorInput {
pub kek: [u8; 8], pub kek: [u8; 8],
pub frame_size: usize, pub frame_size: usize,
@ -1928,8 +1917,8 @@ impl MDTrack {
} }
#[cfg(not(target_family = "wasm"))] #[cfg(not(target_family = "wasm"))]
pub fn get_encrypting_iterator(&mut self) -> tokio::sync::mpsc::UnboundedReceiver<(Vec<u8>, Vec<u8>, Vec<u8>)> { pub fn get_encrypting_iterator(&mut self) -> Encryptor {
super::encryption::new_thread_encryptor(DataEncryptorInput { Encryptor::new_threaded(DataEncryptorInput {
kek: self.get_kek(), kek: self.get_kek(),
frame_size: self.frame_size(), frame_size: self.frame_size(),
chunk_size: self.chunk_size(), chunk_size: self.chunk_size(),
@ -1938,8 +1927,8 @@ impl MDTrack {
} }
#[cfg(target_family = "wasm")] #[cfg(target_family = "wasm")]
pub fn get_encrypting_iterator(&mut self) -> gloo::worker::reactor::ReactorBridge<super::encryption::WebThread> { pub fn get_encrypting_iterator(&mut self) -> Encryptor {
super::encryption::web_worker_encryptor(DataEncryptorInput { Encryptor::new(DataEncryptorInput {
kek: self.get_kek(), kek: self.get_kek(),
frame_size: self.frame_size(), frame_size: self.frame_size(),
chunk_size: self.chunk_size(), chunk_size: self.chunk_size(),