diff --git a/cz/src/binio.rs b/cz/src/binio.rs index 334d0e4..02acfd5 100644 --- a/cz/src/binio.rs +++ b/cz/src/binio.rs @@ -1,80 +1,70 @@ -pub struct BitIo { - data: Vec, +use std::io::{self, Read, Write}; + +use byteorder::{ReadBytesExt, WriteBytesExt}; + +/// A simple way to write individual bits to an input implementing [Write]. +pub struct BitWriter<'a, O: Write + WriteBytesExt> { + output: &'a mut O, + + current_byte: u8, + byte_offset: usize, bit_offset: usize, byte_size: usize, } -impl BitIo { - /// Create a new BitIO reader and writer over some data - pub fn new(data: Vec) -> Self { +impl<'a, O: Write + WriteBytesExt> BitWriter<'a, O> { + /// Create a new BitWriter wrapper around something which + /// implements [Write]. + pub fn new(output: &'a mut O) -> Self { Self { - data, + output, + + current_byte: 0, + byte_offset: 0, bit_offset: 0, + byte_size: 0, } } - /// Get the byte offset of the reader - pub fn byte_offset(&self) -> usize { - self.byte_offset - } - - /// Get the byte size of the reader + /// Get the number of whole bytes written to the stream. pub fn byte_size(&self) -> usize { self.byte_size } - /// Get the current bytes up to `byte_size` in the reader - pub fn bytes(&self) -> Vec { - self.data[..self.byte_size].to_vec() + /// Get the bit offset within the current byte. + pub fn bit_offset(&self) -> u8 { + self.bit_offset as u8 } - /// Read some bits from the buffer - pub fn read_bit(&mut self, bit_len: usize) -> u64 { - if bit_len > 8 * 8 { - panic!("Cannot read more than 64 bits") - } - - if bit_len % 8 == 0 && self.bit_offset == 0 { - return self.read(bit_len / 8); - } - - let mut result = 0; - for i in 0..bit_len { - let bit_value = ((self.data[self.byte_offset] as usize >> self.bit_offset) & 1) as u64; - self.bit_offset += 1; - - if self.bit_offset == 8 { - self.byte_offset += 1; - self.bit_offset = 0; - } - - result |= bit_value << i; - } - - result + /// Check if the stream is aligned to a byte. + pub fn aligned(&self) -> bool { + self.bit_offset() == 0 } - /// Read some bytes from the buffer - pub fn read(&mut self, byte_len: usize) -> u64 { - if byte_len > 8 { - panic!("Cannot read more than 8 bytes") - } + /// Align the writer to the nearest byte by padding with zero bits. + /// + /// Returns the number of zero bits + pub fn flush(&mut self) -> Result { + self.byte_offset += 1; - let mut padded_slice = [0u8; 8]; - padded_slice.copy_from_slice(&self.data[self.byte_offset..self.byte_offset + byte_len]); - self.byte_offset += byte_len; + // Write out the current byte unfinished + self.output.write_u8(self.current_byte).unwrap(); + self.current_byte = 0; + self.bit_offset = 0; - u64::from_le_bytes(padded_slice) + Ok(8 - self.bit_offset) } - /// Write some bits to the buffer + /// Write some bits to the output. pub fn write_bit(&mut self, data: u64, bit_len: usize) { - if bit_len > 8 * 8 { - panic!("Cannot write more than 64 bits"); + if bit_len > 64 { + panic!("Cannot write more than 64 bits at once."); + } else if bit_len == 0 { + panic!("Must write 1 or more bits.") } if bit_len % 8 == 0 && self.bit_offset == 0 { @@ -85,32 +75,115 @@ impl BitIo { for i in 0..bit_len { let bit_value = (data >> i) & 1; - self.data[self.byte_offset] &= !(1 << self.bit_offset); + self.current_byte &= !(1 << self.bit_offset); - self.data[self.byte_offset] |= (bit_value << self.bit_offset) as u8; + self.current_byte |= (bit_value << self.bit_offset) as u8; self.bit_offset += 1; - if self.bit_offset == 8 { + if self.bit_offset >= 8 { self.byte_offset += 1; self.bit_offset = 0; + + self.output.write_u8(self.current_byte).unwrap(); + self.current_byte = 0; } } self.byte_size = self.byte_offset + (self.bit_offset + 7) / 8; } + /// Write some bytes to the output. pub fn write(&mut self, data: u64, byte_len: usize) { if byte_len > 8 { - panic!("Cannot write more than 8 bytes") + panic!("Cannot write more than 8 bytes at once.") + } else if byte_len == 0 { + panic!("Must write 1 or more bytes.") } - let mut padded_slice = [0u8; 8]; - padded_slice.copy_from_slice(&data.to_le_bytes()); - - self.data[self.byte_offset..self.byte_offset + byte_len] - .copy_from_slice(&padded_slice[..byte_len]); + self.output + .write_all(&data.to_le_bytes()[..byte_len]) + .unwrap(); self.byte_offset += byte_len; self.byte_size = self.byte_offset + (self.bit_offset + 7) / 8; } } + +/// A simple way to read individual bits from an input implementing [Read]. +pub struct BitReader<'a, I: Read + ReadBytesExt> { + input: &'a mut I, + + current_byte: Option, + + byte_offset: usize, + bit_offset: usize, +} + +impl<'a, I: Read + ReadBytesExt> BitReader<'a, I> { + /// Create a new BitReader wrapper around something which + /// implements [Write]. + pub fn new(input: &'a mut I) -> Self { + let first = input.read_u8().unwrap(); + Self { + input, + + current_byte: Some(first), + + byte_offset: 0, + bit_offset: 0, + } + } + + /// Get the number of whole bytes read from the stream. + pub fn byte_offset(&self) -> usize { + self.byte_offset + } + + /// Read some bits from the input. + pub fn read_bit(&mut self, bit_len: usize) -> u64 { + if bit_len > 64 { + panic!("Cannot read more than 64 bits at once.") + } else if bit_len == 0 { + panic!("Must read 1 or more bits.") + } + + if bit_len % 8 == 0 && self.bit_offset == 0 { + return self.read(bit_len / 8); + } + + let mut result = 0; + for i in 0..bit_len { + let bit_value = ((self.current_byte.unwrap() as usize >> self.bit_offset) & 1) as u64; + self.bit_offset += 1; + + if self.bit_offset == 8 { + self.byte_offset += 1; + self.bit_offset = 0; + + self.current_byte = Some(self.input.read_u8().unwrap()); + } + + result |= bit_value << i; + } + + result + } + + /// Read some bytes from the input. + pub fn read(&mut self, byte_len: usize) -> u64 { + if byte_len > 8 { + panic!("Cannot read more than 8 bytes at once.") + } else if byte_len == 0 { + panic!("Must read 1 or more bytes") + } + + let mut padded_slice = vec![0u8; byte_len]; + self.input.read_exact(&mut padded_slice).unwrap(); + self.byte_offset += byte_len; + + let extra_length = padded_slice.len() - byte_len; + padded_slice.extend_from_slice(&vec![0u8; extra_length]); + + u64::from_le_bytes(padded_slice.try_into().unwrap()) + } +} diff --git a/cz/src/compression.rs b/cz/src/compression.rs index 79d0da7..2f0cae2 100644 --- a/cz/src/compression.rs +++ b/cz/src/compression.rs @@ -1,10 +1,10 @@ use byteorder::{ReadBytesExt, WriteBytesExt, LE}; use std::{ collections::HashMap, - io::{Read, Seek, Write}, + io::{Cursor, Read, Seek, Write}, }; -use crate::binio::BitIo; +use crate::binio::{BitReader, BitWriter}; use crate::common::CzError; /// The size of compressed data in each chunk @@ -163,7 +163,7 @@ pub fn decompress2( } fn decompress_lzw2(input_data: &[u8], size: usize) -> Vec { - let mut data = input_data.to_vec(); + let mut data = Cursor::new(input_data); let mut dictionary = HashMap::new(); for i in 0..256 { dictionary.insert(i as u64, vec![i as u8]); @@ -172,12 +172,15 @@ fn decompress_lzw2(input_data: &[u8], size: usize) -> Vec { let mut result = Vec::with_capacity(size); let data_size = input_data.len(); - data.extend_from_slice(&[0, 0]); - let mut bit_io = BitIo::new(data); + let mut bit_io = BitReader::new(&mut data); let mut w = dictionary.get(&0).unwrap().clone(); let mut element; loop { + if bit_io.byte_offset() >= data_size - 1 { + break; + } + let flag = bit_io.read_bit(1); if flag == 0 { element = bit_io.read_bit(15); @@ -185,10 +188,6 @@ fn decompress_lzw2(input_data: &[u8], size: usize) -> Vec { element = bit_io.read_bit(18); } - if bit_io.byte_offset() > data_size { - break; - } - let mut entry; if let Some(x) = dictionary.get(&element) { // If the element was already in the dict, get it @@ -197,7 +196,7 @@ fn decompress_lzw2(input_data: &[u8], size: usize) -> Vec { entry = w.clone(); entry.push(w[0]) } else { - panic!("Bad compressed element: {}", element) + panic!("Bad compressed element {} at offset {}", element, bit_io.byte_offset()) } //println!("{}", element); @@ -363,8 +362,9 @@ fn compress_lzw2(data: &[u8], last: Vec) -> (usize, Vec, Vec) { element = last } - let mut bit_io = BitIo::new(vec![0u8; 0xF0000]); - let write_bit = |bit_io: &mut BitIo, code: u64| { + let mut output_buf = Vec::new(); + let mut bit_io = BitWriter::new(&mut output_buf); + let write_bit = |bit_io: &mut BitWriter>, code: u64| { if code > 0x7FFF { bit_io.write_bit(1, 1); bit_io.write_bit(code, 18); @@ -402,13 +402,18 @@ fn compress_lzw2(data: &[u8], last: Vec) -> (usize, Vec, Vec) { write_bit(&mut bit_io, *dictionary.get(&vec![c]).unwrap()); } } - return (count, bit_io.bytes(), Vec::new()); + + bit_io.flush().unwrap(); + return (count, output_buf, Vec::new()); } else if bit_io.byte_size() < 0x87BDF { if !last_element.is_empty() { write_bit(&mut bit_io, *dictionary.get(&last_element).unwrap()); } - return (count, bit_io.bytes(), Vec::new()); + + bit_io.flush().unwrap(); + return (count, output_buf, Vec::new()); } - (count, bit_io.bytes(), last_element) + bit_io.flush().unwrap(); + (count, output_buf, last_element) } diff --git a/cz/tests/round_trip.rs b/cz/tests/round_trip.rs index 33a4638..2816a77 100644 --- a/cz/tests/round_trip.rs +++ b/cz/tests/round_trip.rs @@ -49,3 +49,66 @@ fn cz1_round_trip() { assert_eq!(original_cz.as_raw(), decoded_cz.as_raw()); } } + +#[test] +fn cz2_round_trip() { + let mut i = 0; + for image in TEST_IMAGES { + let original_cz = DynamicCz::from_raw( + CzVersion::CZ2, + image.0, + image.1, + image.2.to_vec() + ); + + let mut cz_bytes = Vec::new(); + original_cz.encode(&mut cz_bytes).unwrap(); + + let mut cz_bytes = Cursor::new(cz_bytes); + let decoded_cz = DynamicCz::decode(&mut cz_bytes).unwrap(); + + assert_eq!(original_cz.as_raw(), decoded_cz.as_raw()); + + i += 1; + } +} + +#[test] +fn cz3_round_trip() { + for image in TEST_IMAGES { + let original_cz = DynamicCz::from_raw( + CzVersion::CZ3, + image.0, + image.1, + image.2.to_vec() + ); + + let mut cz_bytes = Vec::new(); + original_cz.encode(&mut cz_bytes).unwrap(); + + let mut cz_bytes = Cursor::new(cz_bytes); + let decoded_cz = DynamicCz::decode(&mut cz_bytes).unwrap(); + + assert_eq!(original_cz.as_raw(), decoded_cz.as_raw()); + } +} + +#[test] +fn cz4_round_trip() { + for image in TEST_IMAGES { + let original_cz = DynamicCz::from_raw( + CzVersion::CZ4, + image.0, + image.1, + image.2.to_vec() + ); + + let mut cz_bytes = Vec::new(); + original_cz.encode(&mut cz_bytes).unwrap(); + + let mut cz_bytes = Cursor::new(cz_bytes); + let decoded_cz = DynamicCz::decode(&mut cz_bytes).unwrap(); + + assert_eq!(original_cz.as_raw(), decoded_cz.as_raw()); + } +}