use std::collections::HashSet;
use std::fs;

/// Solves both parts of the puzzle for the file `input` in the current working directory.
///
/// Prints the part 1 answer, then the part 2 answer, each on its own line.
///
/// # Errors
/// Propagates the I/O error if `input` cannot be read, and the parse error if either parser
/// rejects the contents.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let input = fs::read_to_string("input")?;

    // Part 1: questions answered by anyone in each group.
    let part_one = count_answers(&parse_groups(&input)?);
    println!("{}", part_one);

    // Part 2: questions answered by everyone in each group.
    let part_two = count_individual_answers(&parse_group_individuals(&input)?);
    println!("{}", part_two);

    Ok(())
}

/// Parses the puzzle input into one answer set per group (part 1).
///
/// Groups are separated by a blank line (`"\n\n"`). Newlines inside a group are skipped, so
/// every question answered by anyone in the group ends up in the same set.
///
/// # Errors
/// Returns `Err("Invalid answer: <c>")` for the first character that is neither a newline
/// nor alphabetic.
fn parse_groups(input: &str) -> Result<Vec<HashSet<char>>, String> {
    let mut groups = Vec::new();

    for block in input.split("\n\n") {
        let mut answers = HashSet::new();
        for ch in block.chars().filter(|&ch| ch != '\n') {
            if !ch.is_alphabetic() {
                return Err(format!("Invalid answer: {}", ch));
            }
            answers.insert(ch);
        }
        groups.push(answers);
    }

    Ok(groups)
}

/// Sums, over all groups, the number of distinct questions anyone in the group answered.
///
/// Each group is already a set of answers, so its size is exactly that count. `len()` is
/// O(1), unlike the previous `iter().count()`, which walked every element.
///
/// Takes a slice rather than `&Vec` so callers may pass any contiguous collection; `&Vec`
/// arguments still coerce automatically.
fn count_answers(groups: &[HashSet<char>]) -> usize {
    groups.iter().map(HashSet::len).sum()
}

/// Parses the puzzle input into groups of individuals' answer sets (part 2).
///
/// Groups are separated by a blank line (`"\n\n"`). Within a group, each line is one
/// person, and that person's answers become a separate set.
///
/// # Errors
/// Returns `Err("Invalid answer: <c>")` for the first non-alphabetic character found on
/// any person's line.
fn parse_group_individuals(input: &str) -> Result<Vec<Vec<HashSet<char>>>, String> {
    // Turns one person's line into their answer set, rejecting non-alphabetic characters.
    fn parse_individual(line: &str) -> Result<HashSet<char>, String> {
        let mut answers = HashSet::new();
        for ch in line.chars() {
            if !ch.is_alphabetic() {
                return Err(format!("Invalid answer: {}", ch));
            }
            answers.insert(ch);
        }
        Ok(answers)
    }

    input
        .split("\n\n")
        .map(|block| {
            block
                .lines()
                .map(parse_individual)
                .collect::<Result<Vec<_>, _>>()
        })
        .collect()
}

/// Sums, over all groups, the number of questions that *every* member of the group answered.
///
/// Each group is a list of individuals' answer sets, and its count is the size of their
/// intersection. The intersection is never built: an answer from the first member counts
/// only if every other member also gave it. This borrows all the sets instead of cloning
/// each one.
///
/// A group with no individuals contributes 0 rather than panicking. Such groups arise when
/// the input ends with a blank line: `split("\n\n")` then yields a trailing `""`, which
/// parses to an empty group.
fn count_individual_answers(groups: &[Vec<HashSet<char>>]) -> usize {
    groups
        .iter()
        .map(|group| match group.split_first() {
            Some((first, rest)) => first
                .iter()
                .filter(|&answer| rest.iter().all(|individual| individual.contains(answer)))
                .count(),
            // Nobody in the group, so no shared answers.
            None => 0,
        })
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;

    /// Worked example from the puzzle statement: five groups separated by blank lines.
    const SAMPLE: &str = indoc!(
        "abc

         a
         b
         c

         ab
         ac

         a
         a
         a
         a

         b
        "
    );

    /// Part 1: questions answered by anyone, per group: 3 + 3 + 3 + 1 + 1.
    #[test]
    fn test_simple() -> Result<(), Box<dyn std::error::Error>> {
        assert_eq!(count_answers(&parse_groups(SAMPLE)?), 11);
        Ok(())
    }

    /// Part 2: questions answered by everyone, per group: 3 + 0 + 1 + 1 + 1.
    #[test]
    fn test_simple2() -> Result<(), Box<dyn std::error::Error>> {
        assert_eq!(
            count_individual_answers(&parse_group_individuals(SAMPLE)?),
            6
        );
        Ok(())
    }
}