/// Locate single-byte corruptions in a file from its expected CRC-32 and MD5.
///
/// Usage: `crc-flip infile crc32 md5`
///
/// * `crc32` — the expected CRC-32 in hexadecimal (an optional `0x`/`0X` prefix is accepted).
/// * `md5` — the expected MD5 digest as exactly 32 hexadecimal characters.
///
/// Every offset at which replacing one byte would give the expected CRC-32 is found in
/// parallel: the file is split into `CHUNK_SIZE` pieces and each piece is scanned backwards
/// from its last byte with the rolling masks from `init_roll`/`roll`. Each candidate is then
/// confirmed against the expected MD5, and confirmed fixes are printed as
/// `0xOFFSET: 0xORIGINAL => 0xREPLACEMENT`.
///
/// Exits with status 2 on a usage error, and with status 1 when the input file cannot be
/// read or is empty, or when an argument is malformed.
fn main() {
    let args = std::env::args_os().collect::<Vec<_>>();

    if args.len() != 4 {
        eprintln!("usage: crc-flip infile crc32 md5");
        std::process::exit(2);
    }

    let infile = std::fs::read(&args[1]).unwrap_or_else(|err| {
        die(format!(
            "cannot read {}: {err}",
            std::path::Path::new(&args[1]).display()
        ))
    });
    if infile.is_empty() {
        die("input file is empty");
    }
    let target_crc32 = parse_crc32_hex(&args[2])
        .unwrap_or_else(|| die("crc32 must be a hexadecimal 32-bit value"));
    let target_md5 =
        parse_md5_hex(&args[3]).unwrap_or_else(|| die("md5 must be exactly 32 hex digits"));
    let actual_crc32 = crc32fast::hash(&infile);

    if actual_crc32 == target_crc32 {
        // With a zero CRC difference every offset "solves" with an unchanged byte, and the
        // md5 check below would hash the whole file once per byte. A real single-byte
        // change can never leave the CRC-32 unchanged, so there is nothing to search for.
        let md5_status = if md5::compute(&infile[..]) == target_md5 {
            "matches"
        } else {
            "does not match"
        };
        println!("crc32 already matches; md5 {md5_status}");
        return;
    }

    use rayon::prelude::*;

    const CHUNK_SIZE: usize = 8 * 1024 * 1024;

    // The final CRC xor cancels out, so only the xor of the two CRCs matters.
    let crc_difference = actual_crc32 ^ target_crc32;

    let mut candidates: Vec<(usize, u8)> = (0..infile.len())
        .into_par_iter()
        .step_by(CHUNK_SIZE)
        .flat_map_iter(|chunk_start| {
            let infile = &infile;
            let chunk_end =
                std::cmp::min(chunk_start.checked_add(CHUNK_SIZE).unwrap(), infile.len());
            // The masks start at the chunk's last byte and `roll` moves them one byte
            // towards the start of the file, so the chunk is walked in reverse.
            let (advance_one_byte, block_zero, mut rolling_bit_masks) =
                init_roll(infile.len(), chunk_end - 1);

            (chunk_start..chunk_end).rev().filter_map(move |i| {
                let solved_byte = solve_byte(crc_difference, block_zero, &rolling_bit_masks)
                    .map(|byte_error| (i, infile[i] ^ byte_error));
                roll(&mut rolling_bit_masks, advance_one_byte);
                solved_byte
            })
        })
        .collect();

    // Each offset yields at most one candidate, so an unstable sort gives the same order.
    candidates.sort_unstable();

    if candidates.is_empty() {
        println!("no single-byte change produces the target crc32");
        return;
    }

    println!("{} candidate crc, checking md5", candidates.len());

    // Hash the shared prefix only once: walk the candidates in offset order, extend the
    // running context up to each candidate, and fork a copy that takes the replacement
    // byte. The forks are finished (suffix + finalize) in parallel below.
    let mut md5_context = md5::Context::new();
    let mut prev_offset = 0;
    let mut to_finalize = Vec::with_capacity(candidates.len());
    for (offset, byte) in candidates {
        md5_context.consume(&infile[prev_offset..offset]);
        let mut local_context = md5_context.clone();
        local_context.consume(&[byte]);

        to_finalize.push((offset, byte, local_context));
        prev_offset = offset;
    }

    let confirmed = to_finalize
        .into_par_iter()
        .filter_map(|(offset, byte, mut local_context)| {
            local_context.consume(&infile[offset + 1..]);
            (local_context.finalize() == target_md5).then_some((offset, byte))
        })
        .collect::<Vec<_>>();

    if confirmed.is_empty() {
        println!("no candidate matches the target md5");
    }
    for (offset, byte) in confirmed {
        println!(
            "0x{offset:x}: 0x{orig_byte:02x} => 0x{byte:02x}",
            orig_byte = infile[offset]
        );
    }
}

/// Print `crc-flip: {msg}` to stderr and exit with status 1.
fn die(msg: impl std::fmt::Display) -> ! {
    eprintln!("crc-flip: {msg}");
    std::process::exit(1)
}

/// Parse a CRC-32 written in hexadecimal, with an optional `0x`/`0X` prefix.
///
/// Only hex digits are accepted (no sign), and the value must fit in 32 bits.
fn parse_crc32_hex(arg: &std::ffi::OsStr) -> Option<u32> {
    let s = arg.to_str()?;
    let digits = s
        .strip_prefix("0x")
        .or_else(|| s.strip_prefix("0X"))
        .unwrap_or(s);
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    u32::from_str_radix(digits, 16).ok()
}

/// Parse an MD5 digest written as exactly 32 hexadecimal characters.
fn parse_md5_hex(arg: &std::ffi::OsStr) -> Option<md5::Digest> {
    let s = arg.to_str()?;
    if s.len() != 32 || !s.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    let mut digest = [0u8; 16];
    for (out, pair) in digest.iter_mut().zip(s.as_bytes().chunks_exact(2)) {
        // `pair` is two ASCII hex digits, so both conversions succeed.
        *out = u8::from_str_radix(std::str::from_utf8(pair).ok()?, 16).ok()?;
    }
    Some(md5::Digest(digest))
}

/// CRC register value before any input has been processed (all ones).
const INITIAL_CRC: u32 = u32::MAX;
/// Value xor-ed into the register to produce the published CRC. Unused here, since the
/// search only ever compares the xor of two CRCs, where this term cancels.
const _FINAL_CRC_XOR: u32 = u32::MAX;
/// CRC-32 (IEEE 802.3) generator polynomial in bit-reflected (lsb-first) form; the
/// implicit x^32 term is not stored.
const IEEE_802_3_POLY: u32 = 0xEDB8_8320;

/// Feed a single input bit `b` into the bit-reflected CRC-32 register `crc`.
///
/// The bit is xor-ed into the register's lowest bit. The register is then shifted right by
/// one, and the polynomial is folded back in whenever the bit shifted out was set.
const fn update_one_bit(crc: u32, b: bool) -> u32 {
    let mixed = crc ^ b as u32;
    let shifted = mixed >> 1;
    if mixed & 1 == 0 {
        shifted
    } else {
        shifted ^ IEEE_802_3_POLY
    }
}

/// Multiply two polynomials modulo the CRC-32 generator, both in bit-reflected form.
///
/// In this representation bit `31 - k` holds the coefficient of `x^k` (so `0x0080_0000` is
/// `x^8`). The carry-less product is built in a 64-bit accumulator whose bit `63 - k` holds
/// `x^k`; bit 0 is never used because the product's degree is at most 62.
const fn mult_mod(v0: u32, v1: u32) -> u32 {
    let wide_v1 = v1 as u64;
    let wide_poly = IEEE_802_3_POLY as u64;
    let mut product: u64 = 0;

    // Carry-less multiply: add a shifted copy of v1 for every set coefficient of v0.
    let mut bit = 0;
    while bit < 32 {
        if (v0 >> bit) & 1 == 1 {
            product ^= wide_v1 << (bit + 1);
        }
        bit += 1;
    }

    // Reduce the high-degree coefficients (bits 1..=31) from the top down. Each fold only
    // touches higher bits, and the low 32 bits are discarded, so the set bit itself need
    // not be cleared.
    let mut bit = 1;
    while bit < 32 {
        if (product >> bit) & 1 == 1 {
            product ^= wide_poly << (bit + 1);
        }
        bit += 1;
    }

    (product >> 32) as u32
}

/// Build the repeated squarings of `x^8` used by `add_zeroes`.
///
/// Entry `k` is the multiplier that appends `2^k` zero bytes to a CRC register.
const fn compute_byte_powers() -> [u32; 64] {
    // Every slot starts as x^8 (one followed by one byte of zeroes); slots 1.. are then
    // overwritten with the square of their predecessor.
    let mut table = [0x0080_0000u32; 64];
    let mut k = 1;
    while k < 64 {
        let prev = table[k - 1];
        table[k] = mult_mod(prev, prev);
        k += 1;
    }
    table
}

/// `BYTE_POWERS[k] = x^(8 * 2^k) mod IEEE_802_3_POLY`, in bit-reflected form.
const BYTE_POWERS: [u32; 64] = compute_byte_powers();

/// Advance the CRC register `crc` as if `block_size` zero bytes were fed into it.
///
/// Appending `n` zero bytes multiplies the register by `x^(8n)`. That multiplier is
/// assembled from the set bits of `block_size`, using the entries of `BYTE_POWERS`.
const fn add_zeroes(crc: u32, block_size: u64) -> u32 {
    let mut result = crc;
    let mut pending = block_size;
    while pending != 0 {
        let k = pending.trailing_zeros() as usize;
        result = mult_mod(result, BYTE_POWERS[k]);
        // Clear the lowest set bit; bits are consumed in ascending order.
        pending &= pending - 1;
    }
    result
}

/// Build the byte-at-a-time lookup table for the reflected CRC-32.
///
/// Entry `b` is the register obtained by feeding the 8 bits of `b`, least significant bit
/// first, into an all-zero register with `update_one_bit`.
const fn compute_byte_table() -> [u32; 256] {
    let mut table = [0u32; 256];
    let mut byte = 0usize;
    while byte < 256 {
        let mut reg = 0u32;
        let mut bits = byte;
        let mut n = 0;
        while n < 8 {
            reg = update_one_bit(reg, bits & 1 == 1);
            bits >>= 1;
            n += 1;
        }
        table[byte] = reg;
        byte += 1;
    }
    table
}

/// Byte lookup table: `BYTE_TABLE[r as u8] ^ (r >> 8)` feeds one zero byte into register `r`.
static BYTE_TABLE: [u32; 256] = compute_byte_table();

/// Build the 8-bit binary-reflected Gray code sequence, `g(n) = n ^ (n >> 1)`.
///
/// Consecutive entries differ in exactly one bit, including the wrap from index 255 back to 0.
/// The sequence is also mirror-symmetric: for `i < 2^j`, entry `2^(j+1) - 1 - i` equals
/// entry `i` with bit `j` set. `solve_byte` uses that symmetry to map its four
/// simultaneous walks back onto indices of this table.
const fn compute_gray_table() -> [u8; 256] {
    let mut table = [0u8; 256];
    let mut n = 0usize;
    while n < 256 {
        table[n] = (n ^ (n >> 1)) as u8;
        n += 1;
    }
    table
}

/// The 256 byte values in reflected Gray-code order.
static GRAY_BYTES: [u8; 256] = compute_gray_table();

/// For every Gray-code step, record which bit changes.
///
/// Entry `i` is the msb-first index (0 means `0x80`, 7 means `0x01`) of the single bit in
/// which `GRAY_BYTES[i]` and `GRAY_BYTES[(i + 1) % 256]` differ. Index 255 wraps to 0.
const fn compute_gray_flipped_bit() -> [usize; 256] {
    let mut flips = [0usize; 256];
    let mut idx = 0;
    while idx < 256 {
        let next = (idx + 1) % 256;
        let changed = GRAY_BYTES[idx] ^ GRAY_BYTES[next];
        // Convert the lsb-first bit position into an msb-first index.
        flips[idx] = 7 - changed.trailing_zeros() as usize;
        idx += 1;
    }
    flips
}

/// `GRAY_FLIPPED_BIT[i]`: msb-first index of the bit flipped going from Gray step `i` to `i + 1`.
static GRAY_FLIPPED_BIT: [usize; 256] = compute_gray_flipped_bit();

/// Prepare the rolling state for scanning a `buf_len`-byte buffer backwards from `pos`.
///
/// Returns `(advance_one_byte, block_zero, rolling_bit_masks)`:
/// * `block_zero` — the CRC register (without final xor) after `buf_len` zero bytes.
/// * `rolling_bit_masks[k]` — `block_zero` xor the register change caused by flipping
///   bit `0x80 >> k` of the byte at `pos`.
/// * `advance_one_byte` — the term `roll` xors in, so the masks stay relative to
///   `block_zero` after each one-byte step towards the start of the buffer.
///
/// # Panics
/// If `buf_len` is zero or `pos >= buf_len`.
fn init_roll(buf_len: usize, pos: usize) -> (u32, u32, [u32; 8]) {
    let trailing = buf_len
        .checked_sub(1)
        .and_then(|last| last.checked_sub(pos))
        .unwrap();
    let [len64, leading64, trailing64] =
        [buf_len, pos, trailing].map(|n| u64::try_from(n).unwrap());

    let block_zero = add_zeroes(INITIAL_CRC, len64);
    let advance_one_bit = update_one_bit(block_zero, false) ^ block_zero;
    let advance_one_byte = block_zero ^ add_zeroes(block_zero, 1);

    // Register after the leading zero bytes, then after the single byte 0x80 at `pos`.
    let before = add_zeroes(INITIAL_CRC, leading64);
    let with_top_bit = BYTE_TABLE[usize::from(before as u8 ^ 0x80)] ^ (before >> 8);

    let mut rolling_bit_masks = [0u32; 8];
    rolling_bit_masks[0] = add_zeroes(with_top_bit, trailing64);
    // Bits are fed lsb-first, so bit 0x80 >> k is processed one bit earlier than
    // 0x80 >> (k - 1). Its mask is therefore the previous mask advanced by one zero bit,
    // re-based onto `block_zero`.
    for k in 1..rolling_bit_masks.len() {
        rolling_bit_masks[k] = update_one_bit(rolling_bit_masks[k - 1], false) ^ advance_one_bit;
    }

    (advance_one_byte, block_zero, rolling_bit_masks)
}

/// Find the byte error at the current position whose CRC difference equals `target_crc`.
///
/// `target_crc` is the xor of the current and the wanted CRC-32 (the final xor cancels).
/// `block_zero` and `rolling_bit_masks` are the state from `init_roll`/`roll` for the
/// position being tested; each mask xor `block_zero` is the CRC difference of flipping one
/// bit (`0x80 >> k`) of that byte. The function returns `Some(e)`, where the corrected
/// byte is the current byte xor `e`, or `None` if no byte value at this position works.
///
/// All 256 error values are enumerated in Gray-code order, so each step is a single xor.
/// To cut the number of steps to 64, four walks run in lock-step over the first 64 Gray
/// codes. They start from errors 0x00, 0x80, 0xC0 and 0x40; the suffix of `v_crc_XY` says
/// whether bit 0x80 (X) and bit 0x40 (Y) are set. Because the reflected Gray code is
/// mirror-symmetric, each walk covers one quarter of `GRAY_BYTES`.
const fn solve_byte(target_crc: u32, block_zero: u32, rolling_bit_masks: &[u32; 8]) -> Option<u8> {
    // Start each walk at the CRC difference of its starting error value.
    let mut v_crc_00 = 0;
    let mut v_crc_10 = v_crc_00 ^ rolling_bit_masks[0] ^ block_zero;
    let mut v_crc_11 = v_crc_10 ^ rolling_bit_masks[1] ^ block_zero;
    let mut v_crc_01 = v_crc_00 ^ rolling_bit_masks[1] ^ block_zero;
    let mut j = 0;

    // Steps 0..63 only flip bits 0x20..0x01, so each walk stays in its own quarter.
    while v_crc_00 != target_crc
        && v_crc_10 != target_crc
        && v_crc_11 != target_crc
        && v_crc_01 != target_crc
        && j < 63
    {
        let bit_mask = rolling_bit_masks[GRAY_FLIPPED_BIT[j]] ^ block_zero;
        v_crc_00 ^= bit_mask;
        v_crc_01 ^= bit_mask;
        v_crc_11 ^= bit_mask;
        v_crc_10 ^= bit_mask;
        j += 1;
    }

    // Only one walk should ever match. The xor of two different byte errors at the same
    // offset is a nonzero error burst of at most 8 bits, and CRC-32 detects every burst of
    // up to 32 bits, so distinct byte errors give distinct CRC differences.
    // (This answers a former TODO that asked for this to be verified.)
    //
    // Map step `j` of each walk back to its Gray-code index:
    //   00 -> j, 10 -> 255 - j, 11 -> 128 + j, 01 -> 127 - j.
    if v_crc_00 == target_crc {
        Some(GRAY_BYTES[j])
    } else if v_crc_10 == target_crc {
        Some(GRAY_BYTES[255 - j])
    } else if v_crc_11 == target_crc {
        Some(GRAY_BYTES[255 - (127 - j)])
    } else if v_crc_01 == target_crc {
        Some(GRAY_BYTES[127 - j])
    } else {
        None
    }
}

/// Move every mask in `rolling_bit_masks` one byte towards the start of the buffer.
///
/// Each mask is pushed through one zero byte with `BYTE_TABLE`. It is then re-based onto
/// the unchanged all-zero register by xoring in `advance_one_byte`, as produced by
/// `init_roll`.
fn roll(rolling_bit_masks: &mut [u32; 8], advance_one_byte: u32) {
    for mask in rolling_bit_masks.iter_mut() {
        let shifted = BYTE_TABLE[(*mask & 0xFF) as usize] ^ (*mask >> 8);
        *mask = shifted ^ advance_one_byte;
    }
}

#[cfg(test)]
mod test {
    //! Cross-checks the rolling single-byte search against `crc32fast`.

    use rand::prelude::IteratorRandom;
    use rand_xoshiro::Xoshiro256StarStar;
    use rand_xoshiro::rand_core::RngCore;
    use rand_xoshiro::rand_core::SeedableRng;

    /// Corrupt every byte of small random buffers with every nonzero error, then scan every
    /// offset. At each offset, `solve_byte` must either return a byte that restores the
    /// original CRC (and the original byte at the corrupted offset), or return `None` only
    /// when no byte value at that offset restores it.
    #[test]
    fn test_small_thorough() {
        use super::init_roll;
        use super::roll;
        use super::solve_byte;
        let mut rand = Xoshiro256StarStar::seed_from_u64(1);

        for buf_len in 1..=20 {
            let mut orig_buf = vec![0; buf_len];

            rand.fill_bytes(orig_buf.as_mut_slice());
            let crc_orig = crc32fast::hash(&orig_buf);

            for err_i in 0..buf_len {
                for err in 1..=u8::MAX {
                    let mut buf = orig_buf.clone();
                    buf[err_i] ^= err;
                    let crc_err = crc32fast::hash(&buf);

                    let (advance_one_byte, block_zero, mut rolling_bit_masks) =
                        init_roll(buf_len, buf_len - 1);

                    for i in (0..buf.len()).rev() {
                        let mut solved_buf = buf.clone();

                        if let Some(missing_byte) =
                            solve_byte(crc_orig ^ crc_err, block_zero, &rolling_bit_masks)
                                .map(|e| buf[i] ^ e)
                        {
                            solved_buf[i] = missing_byte;
                            if i == err_i {
                                assert_eq!(missing_byte, orig_buf[err_i]);
                            }
                            assert_eq!(crc32fast::hash(&solved_buf), crc_orig);
                        } else {
                            assert_ne!(i, err_i);
                            // Check every possible byte value, including 0xFF.
                            for non_solution_byte in 0..=u8::MAX {
                                solved_buf[i] = non_solution_byte;
                                assert_ne!(crc32fast::hash(&solved_buf), crc_orig);
                            }
                        }

                        roll(&mut rolling_bit_masks, advance_one_byte);
                    }
                }
            }
        }
    }

    /// Split random buffers of many lengths across threads, the same way `main` splits
    /// chunks, and check every candidate byte against a direct CRC computation.
    #[test]
    fn test_solve_byte() {
        use super::init_roll;
        use super::roll;
        use super::solve_byte;
        let mut rand = Xoshiro256StarStar::seed_from_u64(3);

        // Lengths are kept small so the test stays fast. Much larger buffers (tens of MiB)
        // exercise realistic offsets but are far too slow for a unit test.
        for buf_len in 3..=6000 {
            let mut buf = vec![0; buf_len];
            rand.fill_bytes(buf.as_mut_slice());

            let crc_all = crc32fast::hash(&buf);
            let err_i = (0..buf.len()).choose(&mut rand).unwrap();
            let err_b = buf[err_i] ^ (1..=u8::MAX).choose(&mut rand).unwrap();
            let crc_err = crc32fast::hash(
                (0..buf.len())
                    .map(|i| if i == err_i { err_b } else { buf[i] })
                    .collect::<Vec<u8>>()
                    .as_slice(),
            );
            let buf = std::sync::Arc::new(buf);

            let mut ts = vec![];
            const THREADS: usize = 3;
            for ti in 0..THREADS {
                let buf = buf.clone();
                let start_point = buf_len * ti / THREADS;
                let end_point = buf_len * (ti + 1) / THREADS;
                ts.push(std::thread::spawn(move || {
                    let (advance_one_byte, block_zero, mut rolling_bit_masks) =
                        init_roll(buf_len, end_point - 1);

                    for i in (start_point..end_point).rev() {
                        let missing_byte =
                            solve_byte(crc_all ^ crc_err, block_zero, &rolling_bit_masks)
                                .map(|e| buf[i] ^ e);

                        if let Some(missing_byte) = missing_byte {
                            let crc_filled = crc32fast::hash(
                                (0..buf.len())
                                    .map(|ii| if ii == i { missing_byte } else { buf[ii] })
                                    .collect::<Vec<u8>>()
                                    .as_slice(),
                            );

                            assert_eq!(crc_filled, crc_err, "{i}");
                            if i == err_i {
                                assert_eq!(missing_byte, err_b);
                            }
                        } else {
                            assert_ne!(i, err_i);
                        }

                        roll(&mut rolling_bit_masks, advance_one_byte);
                    }
                }));
            }

            for ts in ts {
                ts.join().unwrap();
            }
        }
    }

    /// Replaying `GRAY_FLIPPED_BIT` from 0 must reproduce `GRAY_BYTES` and wrap back to 0.
    #[test]
    fn test_gray_flips() {
        let mut v = 0;
        for (i, flip) in super::GRAY_FLIPPED_BIT.iter().enumerate() {
            assert_eq!(super::GRAY_BYTES[i], v, "{i}");
            v ^= 0x80 >> flip;
        }
        assert_eq!(super::GRAY_BYTES[0], v);
    }
}
