use std::arch::x86_64::*;

use crate::prefilter::x86_64::overlapping_load;

/// Converts a string to a u64 bitmask using AVX-512 intrinsics.
/// Each bit represents the existence of a character in the ASCII range [32, 90].
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512bw")]
pub unsafe fn string_to_bitmask_avx512(s: &[u8]) -> i64 {
    let mut mask: i64 = 0;

    let one = _mm512_set1_epi64(1);

    let len = s.len();
    for start in (0..len).step_by(16) {
        let chunk = unsafe { overlapping_load(s, start, len) };

        // Convert to uppercase: subtract 0x20 if in [a-z]
        let ge_a = unsafe { _mm_cmpge_epi8_mask(chunk, _mm_set1_epi8(b'a' as i8)) };
        let le_z = unsafe { _mm_cmple_epi8_mask(chunk, _mm_set1_epi8(b'z' as i8)) };
        let is_lower = ge_a & le_z;
        let chunk = unsafe { _mm_mask_sub_epi8(chunk, is_lower, chunk, _mm_set1_epi8(0x20)) };

        let chunk = _mm_subs_epi8(chunk, _mm_set1_epi8(33));

        // Zero-extend LO 8 bytes to 8 u64s
        let indices = _mm512_cvtepu8_epi64(chunk);
        // 1 << (char - 32) and reduction
        mask |= _mm512_reduce_or_epi64(_mm512_sllv_epi64(one, indices));

        // Zero-extend HI 8 bytes to 8 u64s
        let indices = _mm512_cvtepu8_epi64(_mm_slli_si128(chunk, 8));
        // 1 << (char - 32) and reduction
        mask |= _mm512_reduce_or_epi64(_mm512_sllv_epi64(one, indices));
    }

    mask
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar model of `string_to_bitmask_avx512`: fold ASCII lowercase to uppercase, then set
    /// bit `c - 33` for every byte `c` in `b'!'..=b'`'`; every other byte is ignored.
    fn reference_bitmask(s: &[u8]) -> i64 {
        s.iter()
            .map(u8::to_ascii_uppercase)
            .filter(|c| (b'!'..=b'`').contains(c))
            .fold(0i64, |acc, c| acc | (1i64 << (c - b'!')))
    }

    /// True when the CPU supports every target feature `string_to_bitmask_avx512` enables.
    fn has_required_features() -> bool {
        is_x86_feature_detected!("avx512f")
            && is_x86_feature_detected!("avx512bw")
            && is_x86_feature_detected!("avx512vl")
    }

    #[test]
    fn test_bitmask() {
        if !has_required_features() {
            // Nothing to exercise on hosts without the required AVX-512 subsets.
            return;
        }

        // SAFETY: every target feature required by `string_to_bitmask_avx512` was detected above.
        unsafe {
            // Empty string sets no bits.
            assert_eq!(string_to_bitmask_avx512(b""), 0);

            // Single char 'A' (65) -> bit 32.
            assert_eq!(string_to_bitmask_avx512(b"A"), 1i64 << (b'A' - 33));

            // Lowest and highest mapped characters: '!' -> bit 0, '`' -> bit 63.
            assert_eq!(string_to_bitmask_avx512(b"!"), 1);
            assert_eq!(string_to_bitmask_avx512(b"`"), i64::MIN);

            // Lowercase is folded to uppercase.
            assert_eq!(string_to_bitmask_avx512(b"a"), string_to_bitmask_avx512(b"A"));

            // Multiple chars.
            assert_eq!(
                string_to_bitmask_avx512(b"ABC"),
                (1i64 << (b'A' - 33)) | (1i64 << (b'B' - 33)) | (1i64 << (b'C' - 33))
            );

            // Longer than 8 bytes: exercises the high half of the 16-byte chunk.
            let expected = (b'A'..=b'J').fold(0i64, |acc, c| acc | (1i64 << (c - 33)));
            assert_eq!(string_to_bitmask_avx512(b"ABCDEFGHIJ"), expected);

            // Characters only in the high half must be seen, and no spurious '!' (bit 0)
            // may be injected.
            let m = string_to_bitmask_avx512(b"@@@@@@@@ABCDEFGH");
            assert_eq!(m & 1, 0);
            assert_ne!(m & (1i64 << (b'H' - 33)), 0);

            // Space, `{|}~`, DEL, NUL and non-ASCII bytes are ignored.
            assert_eq!(string_to_bitmask_avx512(b" {|}~\x7f\x80\xff\x00"), 0);

            // Agreement with the scalar model on every prefix, covering partial and
            // multi-chunk inputs. NOTE(review): assumes `overlapping_load` only yields bytes
            // from `s` (or padding bytes that map to no bit).
            let text: &[u8] =
                b"Hello, World! The quick brown fox jumps over the lazy dog [0-9]_^\\@`~{}";
            for end in 0..=text.len() {
                let prefix = &text[..end];
                assert_eq!(
                    string_to_bitmask_avx512(prefix),
                    reference_bitmask(prefix),
                    "prefix length {end}"
                );
            }

            // Every possible byte value, in both orders.
            let mut all_bytes: Vec<u8> = (0..=u8::MAX).collect();
            assert_eq!(
                string_to_bitmask_avx512(&all_bytes),
                reference_bitmask(&all_bytes)
            );
            all_bytes.reverse();
            assert_eq!(
                string_to_bitmask_avx512(&all_bytes),
                reference_bitmask(&all_bytes)
            );
        }
    }
}
