-
João Magalhães authored
Also simplified error to string conversion.
João Magalhães authored: Also simplified error to string conversion.
huffman.rs 6.31 KiB
use boytacean_common::error::Error;
use std::{
cmp::Ordering,
collections::BinaryHeap,
io::{Cursor, Read},
mem::size_of,
};
/// A node of the Huffman tree.
///
/// A leaf stores a byte in `character` and has no children. A node created
/// by merging two sub-trees stores `None` in `character` and owns both
/// children.
#[derive(Debug, Eq, PartialEq)]
struct Node {
    /// Occurrence count of the byte for a leaf, or the sum of both children
    /// for a merged node. Nodes rebuilt by `decode_tree` always carry `0`.
    frequency: u32,
    /// The byte represented by this node, `None` for internal nodes.
    character: Option<u8>,
    /// Sub-tree reached with a `0` bit.
    left: Option<Box<Node>>,
    /// Sub-tree reached with a `1` bit.
    right: Option<Box<Node>>,
}

impl Ord for Node {
    /// Compares nodes by frequency only, in *reverse* order.
    ///
    /// `BinaryHeap` is a max-heap, so reversing the comparison makes it pop
    /// the node with the lowest frequency first. Note that the derived
    /// equality looks at every field, while this ordering only looks at
    /// `frequency`.
    fn cmp(&self, other: &Self) -> Ordering {
        self.frequency.cmp(&other.frequency).reverse()
    }
}

impl PartialOrd for Node {
    /// Delegates to [`Ord::cmp`], so the ordering is always total.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
pub fn encode_huffman(data: &[u8]) -> Result<Vec<u8>, Error> {
let frequency_map = build_frequency(data);
let tree = build_tree(&frequency_map)
.ok_or(Error::CustomError(String::from("Failed to build tree")))?;
let mut codes = vec![Vec::new(); 256];
build_codes(&tree, Vec::new(), &mut codes);
let encoded_tree = encode_tree(&tree);
let encoded_data = encode_data(data, &codes);
let tree_length = encoded_tree.len() as u32;
let data_length = data.len() as u64;
let mut result = Vec::new();
result.extend(tree_length.to_be_bytes());
result.extend(encoded_tree);
result.extend(data_length.to_be_bytes());
result.extend(encoded_data);
Ok(result)
}
pub fn decode_huffman(data: &[u8]) -> Result<Vec<u8>, Error> {
let mut reader = Cursor::new(data);
let mut buffer = [0x00; size_of::<u32>()];
reader.read_exact(&mut buffer)?;
let tree_length = u32::from_be_bytes(buffer);
let mut buffer = vec![0; tree_length as usize];
reader.read_exact(&mut buffer)?;
let tree = decode_tree(&mut buffer.as_slice());
let mut buffer = [0x00; size_of::<u64>()];
reader.read_exact(&mut buffer)?;
let data_length = u64::from_be_bytes(buffer);
let mut buffer =
vec![0; data.len() - size_of::<u32>() - tree_length as usize - size_of::<u64>()];
reader.read_exact(&mut buffer)?;
let result = decode_data(&buffer, &tree, data_length);
Ok(result)
}
/// Counts how many times each byte value occurs in `data`.
///
/// The returned table is indexed by byte value, so entry `b` holds the number
/// of occurrences of byte `b`.
fn build_frequency(data: &[u8]) -> [u32; 256] {
    let mut counts = [0_u32; 256];
    data.iter()
        .for_each(|&symbol| counts[usize::from(symbol)] += 1);
    counts
}
/// Builds the Huffman tree for the given frequency table.
///
/// Every byte with a non-zero frequency becomes a leaf. The two nodes that
/// the heap yields first are then merged repeatedly (first popped on the
/// left, second on the right) until a single root is left.
///
/// Returns `None` when every frequency is zero. When only one byte is
/// present the root is that single leaf.
fn build_tree(frequency_map: &[u32; 256]) -> Option<Box<Node>> {
    let mut queue: BinaryHeap<Box<Node>> = BinaryHeap::new();

    // Leaves are pushed one at a time, in ascending byte order, so the heap
    // layout (and with it the handling of equal frequencies) stays stable.
    let present = frequency_map
        .iter()
        .enumerate()
        .filter(|&(_, &count)| count != 0);
    for (symbol, &count) in present {
        queue.push(Box::new(Node {
            frequency: count,
            character: Some(symbol as u8),
            left: None,
            right: None,
        }));
    }

    loop {
        // An empty heap means there was nothing to encode.
        let first = queue.pop()?;
        match queue.pop() {
            // `first` was the last node standing: it is the root.
            None => return Some(first),
            Some(second) => queue.push(Box::new(Node {
                frequency: first.frequency + second.frequency,
                character: None,
                left: Some(first),
                right: Some(second),
            })),
        }
    }
}
/// Fills `codes` with the bit sequence of every leaf below `node`.
///
/// `prefix` is the path already taken to reach `node`. Going left appends a
/// `0`, going right appends a `1`, and each leaf stores its finished path in
/// `codes` at the index of its byte. Children of a node that carries a byte
/// are ignored.
///
/// # Panics
///
/// Panics if `codes` is too short to be indexed by a byte found in the tree.
fn build_codes(node: &Node, prefix: Vec<u8>, codes: &mut [Vec<u8>]) {
    // Depth-first walk with an explicit stack. The right child is pushed
    // first so that the left sub-tree is visited first.
    let mut pending = vec![(node, prefix)];

    while let Some((current, path)) = pending.pop() {
        match current.character {
            Some(symbol) => codes[usize::from(symbol)] = path,
            None => {
                if let Some(right) = current.right.as_deref() {
                    let mut right_path = path.clone();
                    right_path.push(1);
                    pending.push((right, right_path));
                }
                if let Some(left) = current.left.as_deref() {
                    let mut left_path = path;
                    left_path.push(0);
                    pending.push((left, left_path));
                }
            }
        }
    }
}
/// Packs the codes of every byte in `data` into a bit stream.
///
/// `codes[b]` is the bit sequence for byte `b`, one bit per element; an
/// element equal to `1` is a set bit and anything else is a clear bit. Bits
/// are written most-significant first, and the last byte is padded with zero
/// bits on the right when the stream does not end on a byte boundary.
///
/// # Panics
///
/// Panics if `codes` is too short to be indexed by a byte found in `data`.
fn encode_data(data: &[u8], codes: &[Vec<u8>]) -> Vec<u8> {
    let mut packed = Vec::new();
    let mut accumulator = 0_u8;
    let mut filled = 0_u32;

    // Flattened view of every bit to emit, in output order.
    let bits = data
        .iter()
        .flat_map(|&symbol| codes[usize::from(symbol)].iter());

    for &bit in bits {
        accumulator = (accumulator << 1) | u8::from(bit == 1);
        filled += 1;
        if filled == 8 {
            packed.push(accumulator);
            accumulator = 0;
            filled = 0;
        }
    }

    // Left-align whatever is left so the padding ends up in the low bits.
    if filled != 0 {
        packed.push(accumulator << (8 - filled));
    }
    packed
}
/// Decodes a bit-packed Huffman payload back into bytes.
///
/// * `encoded` - payload bits, most-significant bit of each byte first.
/// * `root` - root of the Huffman tree; a `0` bit follows `left` and a `1`
///   bit follows `right`.
/// * `data_length` - number of bytes to produce; decoding stops as soon as
///   this many bytes have been emitted, so trailing padding bits are ignored.
///
/// When `root` is itself a leaf (the input had a single distinct byte) every
/// code is empty and the payload carries no bits, so the output is simply
/// `data_length` copies of that byte.
///
/// If the payload runs out, or a node lacks the child a bit points to, the
/// bytes decoded so far are returned; the result is then shorter than
/// `data_length`.
///
/// # Panics
///
/// In the single-leaf case the output is allocated up front from
/// `data_length`, which comes from the caller unchecked; a length that cannot
/// be allocated makes the allocation fail.
fn decode_data(encoded: &[u8], root: &Node, data_length: u64) -> Vec<u8> {
    // Single-symbol tree: there are no bits to walk, only a repeat count.
    if let Some(character) = root.character {
        let count: usize = std::convert::TryFrom::try_from(data_length).unwrap_or(usize::MAX);
        return vec![character; count];
    }

    let mut decoded = Vec::new();
    if data_length == 0 {
        return decoded;
    }

    let mut current_node = root;
    for &byte in encoded {
        for bit_offset in (0..8).rev() {
            let bit = (byte >> bit_offset) & 1;
            let child = if bit == 0 {
                current_node.left.as_deref()
            } else {
                current_node.right.as_deref()
            };
            current_node = match child {
                Some(node) => node,
                // Malformed tree: stop instead of panicking.
                None => return decoded,
            };
            if let Some(character) = current_node.character {
                decoded.push(character);
                if decoded.len() as u64 == data_length {
                    return decoded;
                }
                current_node = root;
            }
        }
    }
    decoded
}
/// Serializes the tree rooted at `node` in pre-order.
///
/// A leaf is written as the tag `1` followed by its byte. Any other node is
/// written as the tag `0` followed by its left sub-tree and then its right
/// sub-tree (a missing child contributes nothing).
fn encode_tree(node: &Node) -> Vec<u8> {
    let mut bytes = Vec::new();
    // Explicit stack; the right child goes in first so the left one is
    // emitted first.
    let mut pending = vec![node];

    while let Some(current) = pending.pop() {
        match current.character {
            Some(symbol) => bytes.extend_from_slice(&[1, symbol]),
            None => {
                bytes.push(0);
                if let Some(right) = current.right.as_deref() {
                    pending.push(right);
                }
                if let Some(left) = current.left.as_deref() {
                    pending.push(left);
                }
            }
        }
    }
    bytes
}
/// Rebuilds a tree from its pre-order serialization.
///
/// Reads from the front of `data` and advances the slice past the bytes it
/// consumed. A tag of `1` is a leaf followed by its byte; any other tag is an
/// internal node followed by its left and right sub-trees. Every rebuilt node
/// has a `frequency` of `0`.
///
/// # Panics
///
/// Panics when `data` ends before the tree is complete.
fn decode_tree(data: &mut &[u8]) -> Box<Node> {
    let tag = data[0];

    if tag == 1 {
        let symbol = data[1];
        *data = &data[2..];
        return Box::new(Node {
            frequency: 0,
            character: Some(symbol),
            left: None,
            right: None,
        });
    }

    *data = &data[1..];
    // The left sub-tree must be consumed before the right one.
    let left = decode_tree(data);
    let right = decode_tree(data);
    Box::new(Node {
        frequency: 0,
        character: None,
        left: Some(left),
        right: Some(right),
    })
}
#[cfg(test)]
mod tests {
    use super::{decode_huffman, encode_huffman};

    /// Round-trips a sample sentence and pins both the compressed and the
    /// restored sizes.
    #[test]
    fn test_huffman_encoding() {
        let original = b"this is an example for huffman encoding, huffman encoding, huffman encoding";

        let compressed = encode_huffman(original).unwrap();
        let restored = decode_huffman(&compressed).unwrap();

        assert_eq!(original.to_vec(), restored);
        assert_eq!(compressed.len(), 109);
        assert_eq!(restored.len(), 75);
    }
}