-- adventofcode-2021/12/cavesearcher.hs
--
-- (Source listing: 123 lines, 3.1 KiB, Haskell.)

import Data.Char (isLower)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Parsing (splitByString)
-- | Read the cave connections from standard input, then print the part 1
-- answer ('solution1') followed by the part 2 answer ('solution2'), one per
-- line.
main :: IO ()
main = do
  caves <- parseCaves <$> getContents
  print (solution1 caves)
  print (solution2 caves)
-- | Part 1: how many start-to-end paths 'cavePaths' finds.
solution1 :: Map.Map String [String] -> Int
solution1 caveMap = length (cavePaths caveMap)
-- | Part 2: how many start-to-end paths 'cavePaths2' finds.
solution2 :: Map.Map String [String] -> Int
solution2 caveMap = length (cavePaths2 caveMap)
-- | Parse the puzzle input into an undirected adjacency map.
--
-- Each line is expected to be @a-b@, an edge between caves @a@ and @b@.
-- The edge is recorded in both directions, so every cave maps to the list of
-- caves one step away from it. A line that does not split (via
-- 'splitByString') into exactly two parts aborts with a descriptive 'error'.
parseCaves :: String -> Map.Map String [String]
parseCaves =
  Map.fromListWith (++) . concatMap (edgePairs . splitByString "-") . lines
  where
    -- Both directions of a single edge, as singleton-list entries that
    -- fromListWith merges per cave.
    edgePairs :: [String] -> [(String, [String])]
    edgePairs [from, to] = [(from, [to]), (to, [from])]
    edgePairs parts =
      error
        ( "parseCaves: expected a line of the form \"a-b\", got "
            ++ show parts
        )
-- | All paths from @\"start\"@ to @\"end\"@ that enter no small
-- (all-lowercase) cave more than once. Each path is listed starting with
-- @\"start\"@.
cavePaths :: Map.Map String [String] -> [[String]]
cavePaths caveMap = followSingle caveMap noneVisited "start"
  where
    noneVisited = Set.empty
-- | Like 'cavePaths', but each path may enter one small cave a second time.
-- The walk is delegated to 'followOneRepeat', beginning at @\"start\"@ with
-- nothing visited yet.
cavePaths2 :: Map.Map String [String] -> [[String]]
cavePaths2 caveMap = followOneRepeat caveMap noneVisited "start"
  where
    noneVisited = Set.empty
-- | All paths from @node@ to @\"end\"@ that never step into a small cave
-- already in @visited@. A cave counts as small when its name is entirely
-- lowercase. Each returned path begins with @node@ and ends with @\"end\"@.
-- Branches that run out of neighbours contribute no paths.
followSingle :: Map.Map String [String] -> Set.Set String -> String -> [[String]]
followSingle caveMap visited node
  | node == "end" = [[node]]
  | otherwise =
      map (node :) (concatMap (followSingle caveMap visited') adjacent)
  where
    visited' :: Set.Set String
    visited' = Set.insert node visited
    adjacent :: [String]
    adjacent =
      filter (not . visitedSmall) (Map.findWithDefault [] node caveMap)
    -- Set.member is O(log n). The Foldable 'elem' on a Set only has an Eq
    -- constraint, so it would scan every element.
    visitedSmall :: String -> Bool
    visitedSmall n = all isLower n && n `Set.member` visited
-- | All paths from @node@ to @\"end\"@ in which at most one small
-- (all-lowercase) cave is entered a second time. This function itself never
-- steps into @\"start\"@. When it reaches a small cave that is already in
-- @visited@, that counts as the single allowed repeat, and the rest of the
-- path is left to 'followSingle', which allows no further repeats.
followOneRepeat :: Map.Map String [String] -> Set.Set String -> String -> [[String]]
followOneRepeat caveMap visited node
  | node == "end" = [[node]]
  -- The one permitted revisit: continue without allowing another.
  | all isLower node && node `Set.member` visited =
      followSingle caveMap visited node
  | otherwise =
      map (node :) (concatMap (followOneRepeat caveMap visited') adjacent)
  where
    visited' :: Set.Set String
    visited' = Set.insert node visited
    -- Only "start" is excluded here. Already-visited small caves stay
    -- eligible, because reaching one triggers the repeat guard above.
    adjacent :: [String]
    adjacent = filter (/= "start") (Map.findWithDefault [] node caveMap)
-- | Map each element to a list and concatenate the results, lazily and in
-- order. This is a list-only alias for the Prelude's 'concatMap', kept so
-- existing callers continue to work.
flatmap :: (t -> [a]) -> [t] -> [a]
flatmap = concatMap
-- Tests

-- | The smallest example cave system, one edge per line.
testInput1 :: String
testInput1 =
  unlines
    [ "start-A"
    , "start-b"
    , "A-c"
    , "A-b"
    , "b-d"
    , "A-end"
    , "b-end"
    ]
-- | The second example cave system, one edge per line.
testInput2 :: String
testInput2 =
  unlines
    [ "dc-end"
    , "HN-start"
    , "start-kj"
    , "dc-start"
    , "dc-HN"
    , "LN-dc"
    , "HN-end"
    , "kj-sa"
    , "kj-HN"
    , "kj-dc"
    ]
-- | The largest example cave system, one edge per line.
testInput3 :: String
testInput3 =
  unlines
    [ "fs-end"
    , "he-DX"
    , "fs-he"
    , "start-DX"
    , "pj-DX"
    , "end-zg"
    , "zg-sl"
    , "zg-pj"
    , "pj-he"
    , "RW-he"
    , "fs-DX"
    , "pj-RW"
    , "zg-RW"
    , "start-pj"
    , "he-WI"
    , "zg-he"
    , "pj-fs"
    , "start-RW"
    ]
-- | 'testInput1' parsed once, shared by the part 1 and part 2 checks below.
parsedTestInput1 :: Map.Map String [String]
parsedTestInput1 = parseCaves testInput1

-- Paths found for each example: test1 is part 1 on the first example, and
-- test2..test4 are part 2 on the first, second and third examples.
-- NOTE(review): the puzzle text appears to give path counts of 10, 36, 103
-- and 3509 respectively; these come from the puzzle, not from this file.
test1, test2, test3, test4 :: [[String]]
test1 = cavePaths parsedTestInput1
test2 = cavePaths2 parsedTestInput1
test3 = cavePaths2 (parseCaves testInput2)
test4 = cavePaths2 (parseCaves testInput3)
-- | Print each path on its own line, with caves joined by @->@
-- (e.g. @start->A->end@). An empty path prints as a blank line rather than
-- failing.
printPaths :: [[String]] -> IO ()
printPaths = putStr . unlines . map showPath
  where
    -- Total replacement for a foldr1-based join: same output for non-empty
    -- paths, and "" for an empty one.
    showPath :: [String] -> String
    showPath [] = ""
    showPath (cave : rest) = cave ++ concatMap ("->" ++) rest