import Data.Char (isLower)
import Data.List (intercalate)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Parsing (splitByString)

-- | Adjacency list of the cave system: each cave label maps to the labels of
-- every cave directly connected to it. 'parseCaves' stores every connection
-- in both directions.
type CaveMap = Map.Map String [String]

-- | Read cave connections from stdin and print the answers to both parts.
main :: IO ()
main = do
  input <- getContents
  let caves = parseCaves input
  print (solution1 caves)
  print (solution2 caves)

-- | Part 1: number of start-to-end paths that enter each small cave at most
-- once.
solution1 :: CaveMap -> Int
solution1 = length . cavePaths

-- | Part 2: number of start-to-end paths in which one small cave may be
-- entered twice (and "start" is never re-entered).
solution2 :: CaveMap -> Int
solution2 = length . cavePaths2

-- | Parse one @a-b@ connection per line into an undirected adjacency map.
--
-- Carriage returns are stripped (so CRLF input yields clean labels) and blank
-- lines are skipped. Any other line that does not split on "-" into exactly
-- two labels aborts with a descriptive error.
parseCaves :: String -> CaveMap
parseCaves =
  Map.fromListWith (++)
    . concatMap (edges . splitByString "-")
    . filter (not . null)
    . map (filter (/= '\r'))
    . lines
  where
    edges :: [String] -> [(String, [String])]
    -- Record the connection in both directions; caves can be walked either way.
    edges [a, b] = [(a, [b]), (b, [a])]
    edges parts =
      error ("parseCaves: expected a line of the form \"a-b\", got " ++ show parts)

-- | Every path from "start" to "end" that enters each small cave at most once.
-- Each path lists cave labels from "start" to "end" inclusive.
cavePaths :: CaveMap -> [[String]]
cavePaths caveMap = followSingle caveMap Set.empty "start"

-- | Every path from "start" to "end" in which at most one small cave (never
-- "start") is entered a second time.
cavePaths2 :: CaveMap -> [[String]]
cavePaths2 caveMap = followOneRepeat caveMap Set.empty "start"

-- | A cave is small when its label consists only of lower-case letters.
isSmall :: String -> Bool
isSmall = all isLower

-- | All paths from @node@ to "end" that never re-enter a small cave already in
-- @visited@ (including @node@ itself). Big caves may be revisited freely.
followSingle :: CaveMap -> Set.Set String -> String -> [[String]]
followSingle caveMap visited node
  | node == "end" = [[node]]
  | otherwise = map (node :) (concatMap (followSingle caveMap visited') adjacent)
  where
    visited' = Set.insert node visited
    -- Set.member is O(log n); Foldable's elem on a Set scans linearly.
    adjacent = filter (not . visitedSmall) (Map.findWithDefault [] node caveMap)
    visitedSmall n = isSmall n && Set.member n visited'

-- | Like 'followSingle', but a single small cave may be entered twice. Once
-- that repeat has been spent, the rest of the path is searched with
-- 'followSingle'. "start" is never re-entered.
followOneRepeat :: CaveMap -> Set.Set String -> String -> [[String]]
followOneRepeat caveMap visited node
  | node == "end" = [[node]]
  -- Entering an already-visited small cave uses up the one allowed repeat.
  | isSmall node && Set.member node visited = followSingle caveMap visited node
  | otherwise = map (node :) (concatMap (followOneRepeat caveMap visited') adjacent)
  where
    visited' = Set.insert node visited
    adjacent = filter (/= "start") (Map.findWithDefault [] node caveMap)

-- | Kept for compatibility with interactive use; identical to 'concatMap'.
flatmap :: (t -> [a]) -> [t] -> [a]
flatmap = concatMap

-- Tests

testInput1 :: String
testInput1 =
  unlines
    [ "start-A"
    , "start-b"
    , "A-c"
    , "A-b"
    , "b-d"
    , "A-end"
    , "b-end"
    ]

testInput2 :: String
testInput2 =
  unlines
    [ "dc-end"
    , "HN-start"
    , "start-kj"
    , "dc-start"
    , "dc-HN"
    , "LN-dc"
    , "HN-end"
    , "kj-sa"
    , "kj-HN"
    , "kj-dc"
    ]

testInput3 :: String
testInput3 =
  unlines
    [ "fs-end"
    , "he-DX"
    , "fs-he"
    , "start-DX"
    , "pj-DX"
    , "end-zg"
    , "zg-sl"
    , "zg-pj"
    , "pj-he"
    , "RW-he"
    , "fs-DX"
    , "pj-RW"
    , "zg-RW"
    , "start-pj"
    , "he-WI"
    , "zg-he"
    , "pj-fs"
    , "start-RW"
    ]

parsedTestInput1 :: CaveMap
parsedTestInput1 = parseCaves testInput1

test1 :: [[String]]
test1 = cavePaths parsedTestInput1

test2 :: [[String]]
test2 = cavePaths2 parsedTestInput1

test3 :: [[String]]
test3 = cavePaths2 (parseCaves testInput2)

test4 :: [[String]]
test4 = cavePaths2 (parseCaves testInput3)

-- | Print each path on its own line with labels joined by "->".
printPaths :: [[String]] -> IO ()
printPaths = putStr . unlines . map (intercalate "->")