use hashbrown::HashMap;

/// Command-line entry point: `crc-flip-two-bytes infile crc32 md5`.
///
/// Reads `infile` and looks for a single 16-bit word, at an even byte offset, whose replacement
/// makes the file's CRC-32 equal `crc32` (hexadecimal). Every CRC match is then verified against
/// `md5` (exactly 32 hexadecimal digits), and each repair that satisfies both is printed as
/// `0x<offset>: <old byte 0> <old byte 1> => <new byte 0> <new byte 1>`.
///
/// Exits with status 2 on a usage error and with status 1 when the input file or one of the
/// arguments is unusable.
fn main() {
    let args = std::env::args_os().collect::<Vec<_>>();

    if args.len() != 4 {
        eprintln!("usage: crc-flip-two-bytes infile crc32 md5");
        // A usage error must not look like success to the calling shell/script.
        std::process::exit(2);
    }

    let infile = std::fs::read(&args[1]).unwrap_or_else(|err| {
        fail(&format!("cannot read {}: {err}", args[1].to_string_lossy()))
    });
    // The search works on whole 16-bit words, so the file must consist of complete words.
    if infile.is_empty() || infile.len() % 2 != 0 {
        fail("infile must be non-empty and contain an even number of bytes");
    }

    let target_crc32 = args[2]
        .to_str()
        .and_then(parse_crc32_hex)
        .unwrap_or_else(|| fail("crc32 must be a hexadecimal number that fits in 32 bits"));
    let target_md5: md5::Digest = args[3]
        .to_str()
        .and_then(parse_md5_hex)
        .map(md5::Digest)
        .unwrap_or_else(|| fail("md5 must be exactly 32 hexadecimal digits"));

    let actual_crc32 = crc32fast::hash(&infile);
    // XOR of the two CRCs; it depends only on the error pattern and its position.
    let syndrome = actual_crc32 ^ target_crc32;

    if syndrome == 0 {
        // The CRC already matches, so the only pattern the lookup below could return is the
        // all-zero one, at every even offset. Verifying each of those would re-hash the file
        // once per word, so check the MD5 a single time instead.
        let mut context = md5::Context::new();
        context.consume(&infile[..]);
        if context.finalize() == target_md5 {
            println!("infile already matches both crc32 and md5");
        } else {
            println!("infile already matches crc32 but not md5; no single 16-bit change can fix it");
        }
        return;
    }

    let infile_len = infile.len();

    use rayon::prelude::*;

    const CHUNK_SIZE: usize = 8 * 1024 * 1024;

    // Syndrome of every 16-bit error pattern as if it sat in the first two bytes of the file.
    let bytes_at_start_hash = init_u16_at_start(infile_len);

    let candidate_err: Vec<(usize, u16)> = (0..infile_len - 1)
        .into_par_iter()
        .step_by(CHUNK_SIZE)
        .flat_map_iter(|chunk_start| {
            let chunk_end = std::cmp::min(chunk_start.checked_add(CHUNK_SIZE).unwrap(), infile_len);
            // Shift the syndrome by `chunk_start` zero bytes so that an error located at
            // `chunk_start` looks like one located at offset 0.
            let mut crc = advance_zeroes(syndrome, u64::try_from(chunk_start).unwrap());
            let bytes_at_start_hash = &bytes_at_start_hash;

            (chunk_start..chunk_end).filter_map(move |offset| {
                // Only word-aligned offsets are candidates.
                let result = if offset % 2 == 0 {
                    bytes_at_start_hash.get(&crc).copied()
                } else {
                    None
                };
                // Move on to the next byte offset.
                crc = update_one_byte(crc, 0);
                result.map(|byte| (offset, byte))
            })
        })
        .collect();

    // The incremental MD5 pass below relies on ascending offsets.
    assert!(candidate_err.is_sorted());

    if !candidate_err.is_empty() {
        println!("checking md5 for {} crc matches", candidate_err.len());
        let mut md5_context = md5::Context::new();
        let mut prev_offset = 0;

        // One pass over the file prefix: snapshot the running MD5 state at each candidate and
        // feed the repaired word into the snapshot.
        let mut to_finalize = vec![];
        for (offset, u16_err) in candidate_err {
            // Low byte of the pattern belongs to `offset`, high byte to `offset + 1`.
            let byte0 = (u16_err as u8) ^ infile[offset];
            let byte1 = ((u16_err >> 8) as u8) ^ infile[offset + 1];
            md5_context.consume(&infile[prev_offset..offset]);
            let mut local_context = md5_context.clone();
            local_context.consume([byte0, byte1]);

            to_finalize.push((offset, [byte0, byte1], local_context));
            prev_offset = offset;
        }

        // Finish each snapshot with the unchanged tail of the file, in parallel.
        for (offset, bytes) in to_finalize
            .into_par_iter()
            .filter_map(|(offset, bytes, mut local_context)| {
                local_context.consume(&infile[offset + 2..]);
                if local_context.finalize() == target_md5 {
                    Some((offset, bytes))
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
        {
            println!(
                "0x{offset:x}: 0x{:02x} 0x{:02x} => 0x{:02x} 0x{:02x}",
                infile[offset],
                infile[offset + 1],
                bytes[0],
                bytes[1],
            );
        }
    }
}

/// Prints `message` to stderr and terminates the process with exit status 1.
fn fail(message: &str) -> ! {
    eprintln!("error: {message}");
    std::process::exit(1)
}

/// Parses a CRC-32 written as plain hexadecimal digits (either case, no sign, no prefix).
///
/// Returns `None` for an empty string, for any non-hex character (including the leading `+`
/// that `from_str_radix` would otherwise accept), or when the value does not fit in 32 bits.
fn parse_crc32_hex(text: &str) -> Option<u32> {
    if !text.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    u32::from_str_radix(text, 16).ok()
}

/// Parses an MD5 digest written as exactly 32 hexadecimal digits (either case).
///
/// Returns `None` for any other length or for any non-hex character, including a `+` sign that
/// a per-pair `from_str_radix` call would otherwise accept.
fn parse_md5_hex(text: &str) -> Option<[u8; 16]> {
    let digits = text.as_bytes();
    if digits.len() != 32 || !digits.iter().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }

    let mut bytes = [0u8; 16];
    for (byte, pair) in bytes.iter_mut().zip(digits.chunks_exact(2)) {
        // Hex digits are ASCII, so every two-byte chunk is valid UTF-8.
        let pair = std::str::from_utf8(pair).ok()?;
        *byte = u8::from_str_radix(pair, 16).ok()?;
    }
    Some(bytes)
}

/// Register value standard CRC-32 starts from (all ones). Not referenced by the code in this
/// file, hence the leading underscore.
const _INITIAL_CRC: u32 = !0;
/// Value standard CRC-32 XORs into the register at the end (all ones). Not referenced by the
/// code in this file, hence the leading underscore.
const _FINAL_CRC_XOR: u32 = !0;
/// IEEE 802.3 generator polynomial in lsb-first bit order: bit 31 is the coefficient of `x^0`
/// and bit 0 the coefficient of `x^31` (the `x^32` term is implicit).
const IEEE_802_3_POLY: u32 = 0xEDB88320; // lsb-first

/// Rolls a single message bit `b` into the lsb-first CRC register `crc` and returns the new
/// register value.
///
/// The bit is XORed into the low end of the register, the register is shifted right by one,
/// and the polynomial is folded back in when a one was shifted out.
const fn update_one_bit(crc: u32, b: bool) -> u32 {
    let mixed = crc ^ (b as u32);
    let shifted = mixed >> 1;
    match mixed & 1 {
        0 => shifted,
        _ => shifted ^ IEEE_802_3_POLY,
    }
}

/// Multiplies two polynomials over GF(2), both in the lsb-first 32-bit representation, and
/// reduces the product modulo `IEEE_802_3_POLY`.
///
/// In this representation `0x8000_0000` is the polynomial `1`, so it acts as the identity.
const fn mult_mod(v0: u32, v1: u32) -> u32 {
    let multiplicand = v1 as u64;
    let modulus = IEEE_802_3_POLY as u64;

    // Carry-less multiplication into a 64-bit accumulator. The extra shift by one keeps the
    // result in the upper 63 bits; bit 0 of the accumulator is never used.
    let mut product: u64 = 0;
    let mut bit = 0;
    while bit < 32 {
        if (v0 >> bit) & 1 == 1 {
            product ^= multiplicand << (bit + 1);
        }
        bit += 1;
    }

    // Bits 1..=31 hold the terms that overflow 32 bits; fold each one back with the
    // polynomial, walking from the lowest bit upwards.
    let mut bit = 1;
    while bit < 32 {
        if (product >> bit) & 1 == 1 {
            product ^= modulus << (bit + 1);
        }
        bit += 1;
    }

    (product >> 32) as u32
}

/// Builds the table of repeated squarings of `x^8` modulo the CRC polynomial: entry 0 is
/// `x^8` and every following entry is the square of the one before it.
const fn compute_byte_powers() -> [u32; 64] {
    let mut table = [0u32; 64];

    // x^8 in lsb-first form: a one followed by one byte of zeroes.
    let mut current: u32 = 0x0080_0000;
    let mut k = 0;
    while k < 64 {
        table[k] = current;
        current = mult_mod(current, current);
        k += 1;
    }

    table
}

/// Precomputed powers used to skip over runs of zero bytes:
/// `BYTE_POWERS[k] = (x ^ (8 * 2 ^ k)) MOD IEEE_802_3_POLY`, so multiplying a CRC register by
/// `BYTE_POWERS[k]` (with `mult_mod`) advances it past `2 ^ k` zero bytes.
const BYTE_POWERS: [u32; 64] = compute_byte_powers();

/// Returns the register `crc` after `block_size` zero bytes have been rolled into it.
///
/// Instead of stepping byte by byte, the register is multiplied by one precomputed power from
/// `BYTE_POWERS` per set bit of `block_size`, lowest bit first.
const fn advance_zeroes(mut crc: u32, mut block_size: u64) -> u32 {
    while block_size != 0 {
        // Index of the lowest set bit still to be handled.
        let k = block_size.trailing_zeros() as usize;
        crc = mult_mod(crc, BYTE_POWERS[k]);
        // Clear that bit.
        block_size &= block_size - 1;
    }

    crc
}

/// Builds the 256-entry byte lookup table: entry `i` is the register that results from rolling
/// the eight bits of `i`, least significant bit first, into a zero register.
const fn compute_byte_table() -> [u32; 256] {
    let mut table = [0u32; 256];

    let mut byte: usize = 0;
    while byte < 256 {
        let mut register = 0;

        let mut bit = 0;
        while bit < 8 {
            let is_set = (byte >> bit) & 1 == 1;
            register = update_one_bit(register, is_set);
            bit += 1;
        }

        table[byte] = register;
        byte += 1;
    }

    table
}

/// Byte-at-a-time lookup table built by `compute_byte_table`: entry `i` is the register
/// obtained by rolling the eight bits of `i` into a zero register. Read by `update_one_byte`.
static BYTE_TABLE: [u32; 256] = compute_byte_table();

/// Rolls one message byte `b` into the CRC register `crc` using `BYTE_TABLE` and returns the
/// new register value.
const fn update_one_byte(crc: u32, b: u8) -> u32 {
    // The low byte of the register, mixed with the input, selects the table entry...
    let index = ((crc & 0xFF) as u8 ^ b) as usize;
    // ...and the remaining three bytes move down by one byte.
    let carried = crc >> 8;
    BYTE_TABLE[index] ^ carried
}

/// Computes the CRC contribution of each bit of the byte at index `pos` in a buffer of
/// `buf_len` bytes.
///
/// Element 0 of the result is a single one bit followed by exactly the `buf_len - 1 - pos`
/// bytes that come after `pos`; each following element has one more zero bit rolled in.
///
/// # Panics
///
/// Panics if `buf_len` is zero or `pos` is not a valid index into the buffer.
fn init_roll(buf_len: usize, pos: usize) -> [u32; 8] {
    let last_index = buf_len.checked_sub(1).unwrap();
    let trailing_bytes = u64::try_from(last_index.checked_sub(pos).unwrap()).unwrap();

    // Roll in a single bit, then skip over everything that follows the byte at `pos`.
    let mut mask = advance_zeroes(update_one_bit(0, true), trailing_bytes);

    let mut masks = [0u32; 8];
    for slot in masks.iter_mut() {
        *slot = mask;
        mask = update_one_bit(mask, false);
    }

    masks
}

/// Builds a lookup table from CRC syndrome to the single-byte error pattern that produces it
/// when the error sits in the first byte of a `buf_len`-byte buffer.
///
/// # Panics
///
/// Panics if `buf_len` is zero, or if two byte patterns map to the same syndrome.
fn init_bytes_at_start(buf_len: usize) -> HashMap<u32, u8> {
    let bit_masks: [u32; 8] = init_roll(buf_len, 0);

    let mut table: HashMap<u32, u8> = HashMap::with_capacity(256);
    for pattern in u8::MIN..=u8::MAX {
        // XOR together the contributions of every set bit; mask 0 belongs to bit 0x80.
        let syndrome = bit_masks
            .iter()
            .enumerate()
            .filter(|&(k, _)| pattern & (0x80 >> k) != 0)
            .fold(0u32, |acc, (_, &mask)| acc ^ mask);
        table.insert(syndrome, pattern);
    }
    // Every pattern must have a distinct syndrome.
    assert_eq!(table.len(), 256);

    table
}

/// Builds a lookup table from CRC syndrome to the two-byte error pattern that produces it when
/// the error sits in the first two bytes of a `buf_len`-byte buffer.
///
/// The low byte of each `u16` pattern applies to byte 0 of the buffer and the high byte to
/// byte 1.
///
/// # Panics
///
/// Panics if `buf_len` is less than 2, or if two patterns map to the same syndrome.
fn init_u16_at_start(buf_len: usize) -> HashMap<u32, u16> {
    let first_byte_masks: [u32; 8] = init_roll(buf_len, 0);
    let second_byte_masks: [u32; 8] = init_roll(buf_len, 1);

    let mut table: HashMap<u32, u16> = HashMap::with_capacity(0x10000);
    for pattern in u16::MIN..=u16::MAX {
        let mut syndrome = 0u32;
        for k in 0..8 {
            // Mask 0 belongs to the top bit of each byte.
            if pattern & (0x0080 >> k) != 0 {
                syndrome ^= first_byte_masks[k];
            }
            if pattern & (0x8000 >> k) != 0 {
                syndrome ^= second_byte_masks[k];
            }
        }
        table.insert(syndrome, pattern);
    }
    // Every pattern must have a distinct syndrome.
    assert_eq!(table.len(), 0x10000);

    table
}

#[cfg(test)]
mod test {
    use rand::prelude::IteratorRandom;
    use rand_xoshiro::Xoshiro256StarStar;
    use rand_xoshiro::rand_core::RngCore;
    use rand_xoshiro::rand_core::SeedableRng;

    use super::advance_zeroes;
    use super::init_bytes_at_start;
    use super::update_one_byte;

    /// Exhaustive check on tiny buffers: for every length in `1..=20`, every position and every
    /// non-zero single-byte error, walk all positions with the syndrome advanced one zero byte
    /// per step. Wherever the table yields a byte, substituting it must restore the original
    /// CRC; wherever it does not, no byte value at that position may restore it.
    #[test]
    fn test_small_thorough() {
        let mut rand = Xoshiro256StarStar::seed_from_u64(1);

        for buf_len in 1..=20 {
            let mut orig_buf = vec![0; buf_len];

            rand.fill_bytes(orig_buf.as_mut_slice());
            let crc_orig = crc32fast::hash(&orig_buf);

            let bytes_at_start_hash = init_bytes_at_start(buf_len);

            for err_i in 0..buf_len {
                for err in 1..=u8::MAX {
                    let mut buf = orig_buf.clone();
                    buf[err_i] ^= err;
                    let crc_err = crc32fast::hash(&buf);
                    let syndrome = crc_orig ^ crc_err;

                    let mut crc = syndrome;

                    for i in 0..buf.len() {
                        let mut solved_buf = buf.clone();

                        if let Some(&byte_diff) = bytes_at_start_hash.get(&crc) {
                            let missing_byte = byte_diff ^ buf[i];
                            solved_buf[i] = missing_byte;
                            if i == err_i {
                                assert_eq!(missing_byte, orig_buf[err_i]);
                            }
                            assert_eq!(crc32fast::hash(&solved_buf), crc_orig);
                        } else {
                            assert_ne!(i, err_i);
                            // Inclusive range: 0xff has to be ruled out as well.
                            for non_solution_byte in 0..=u8::MAX {
                                solved_buf[i] = non_solution_byte;
                                assert_ne!(crc32fast::hash(&solved_buf), crc_orig);
                            }
                        }

                        crc = update_one_byte(crc, 0);
                    }
                }
            }
        }
    }

    /// Randomised check of the chunked search: for each buffer length, corrupt one random byte
    /// and scan the buffer in `THREADS` independent ranges, each starting from the syndrome
    /// advanced to the beginning of its range with `advance_zeroes`. Every hit must reproduce
    /// the corrupted CRC, and the corrupted position itself must be found.
    #[test]
    fn test_solve_byte() {
        let mut rand = Xoshiro256StarStar::seed_from_u64(3);

        // For a stress run, replace the range with a single large length, e.g.
        // `100_000_000..=100_000_000`.
        for buf_len in 3..=6000 {
            let mut buf = vec![0; buf_len];
            rand.fill_bytes(buf.as_mut_slice());

            let bytes_at_start_hash = init_bytes_at_start(buf_len);

            let crc_all = crc32fast::hash(&buf);
            let err_i = (0..buf.len()).choose(&mut rand).unwrap();
            let err_b = buf[err_i] ^ (1..=u8::MAX).choose(&mut rand).unwrap();
            let crc_err = crc32fast::hash(
                (0..buf.len())
                    .map(|i| if i == err_i { err_b } else { buf[i] })
                    .collect::<Vec<u8>>()
                    .as_slice(),
            );
            let syndrome = crc_all ^ crc_err;

            const THREADS: usize = 3;
            // Scoped threads borrow the buffer and the table instead of cloning them; the
            // scope joins every thread and propagates a panic from any of them.
            std::thread::scope(|scope| {
                for ti in 0..THREADS {
                    let start_point = buf_len * ti / THREADS;
                    let end_point = buf_len * (ti + 1) / THREADS;
                    let buf = &buf;
                    let bytes_at_start_hash = &bytes_at_start_hash;
                    scope.spawn(move || {
                        let mut crc =
                            advance_zeroes(syndrome, u64::try_from(start_point).unwrap());

                        for i in start_point..end_point {
                            let byte_diff = bytes_at_start_hash.get(&crc);

                            if let Some(&byte_diff) = byte_diff {
                                let missing_byte = buf[i] ^ byte_diff;
                                let crc_filled = crc32fast::hash(
                                    (0..buf.len())
                                        .map(|ii| if ii == i { missing_byte } else { buf[ii] })
                                        .collect::<Vec<u8>>()
                                        .as_slice(),
                                );

                                assert_eq!(crc_filled, crc_err, "{i}");
                                if i == err_i {
                                    assert_eq!(missing_byte, err_b);
                                }
                            } else {
                                assert_ne!(i, err_i);
                            }

                            crc = update_one_byte(crc, 0);
                        }
                    });
                }
            });
        }
    }
}
