{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Cryptol.Utils.RecordMap
( RecordMap
, displayOrder
, canonicalFields
, displayFields
, recordElements
, fieldSet
, recordFromFields
, recordFromFieldsErr
, recordFromFieldsWithDisplay
, lookupField
, adjustField
, traverseRecordMap
, mapWithFieldName
, zipRecordsM
, zipRecords
, recordMapAccum
) where
import Control.DeepSeq
import Control.Monad.Except
import Data.Functor.Identity
import Data.Set (Set)
import Data.Map (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Map.Merge.Strict as Map
import Cryptol.Utils.Panic
-- | A collection of named fields that also remembers the order in which
-- the fields should be displayed.
--
-- Equality and ordering compare only the underlying 'Map' and ignore the
-- display order. The display order is used by 'displayFields' and hence
-- by the 'Show' instance.
--
-- NOTE(review): the display order is presumably meant to be a permutation
-- of the map's keys. 'displayFields' panics if it names a key that is not
-- in the map.
data RecordMap a b =
  RecordMap
    { recordMap :: !(Map a b)
      -- ^ Field values keyed by field name (the canonical representation).
    , _displayOrder :: [a]
      -- ^ Field names in display order. Lazy field (no strictness bang).
    }
-- | Records are equal when they contain the same field names bound to
-- equal values. The display order plays no part in the comparison.
instance (Ord a, Eq b) => Eq (RecordMap a b) where
  r1 == r2 = fields1 == fields2
    where
      fields1 = recordMap r1
      fields2 = recordMap r2
-- | Records are ordered by their underlying field maps. The display order
-- is ignored, which keeps this instance consistent with 'Eq'.
instance (Ord a, Ord b) => Ord (RecordMap a b) where
  compare r1 r2 = recordMap r1 `compare` recordMap r2
-- | Shows a record as its association list in display order.
instance (Show a, Ord a, Show b) => Show (RecordMap a b) where
  -- The list 'Show' instance ignores the precedence argument, so this
  -- renders exactly like @show . displayFields@.
  showsPrec prec = showsPrec prec . displayFields
-- | Fully evaluates a record: every field name and value in the map, and
-- every name in the display-order list.
--
-- '_displayOrder' is a lazy field, so it must be forced explicitly.
-- Otherwise a pending thunk (for example @map fst xs@ built from the
-- caller's input list) would survive 'force' / 'deepseq' and keep that
-- input reachable.
instance (NFData a, NFData b) => NFData (RecordMap a b) where
  rnf r = rnf (recordMap r) `seq` rnf (_displayOrder r)
-- | The set of field names present in the record.
fieldSet :: Ord a => RecordMap a b -> Set a
fieldSet = Map.keysSet . recordMap
-- | Field names in the order they should be displayed.
displayOrder :: RecordMap a b -> [a]
displayOrder = _displayOrder
-- | Field values in canonical order, i.e. ascending by field name.
recordElements :: RecordMap a b -> [b]
recordElements r = Map.elems (recordMap r)
-- | @(name, value)@ pairs in canonical order, i.e. ascending by field name.
canonicalFields :: RecordMap a b -> [(a,b)]
canonicalFields r = Map.toAscList (recordMap r)
-- | @(name, value)@ pairs in display order.
--
-- Panics if the display order names a field that is not in the map.
-- Fields missing from the display order are not returned.
displayFields :: (Show a, Ord a) => RecordMap a b -> [(a,b)]
displayFields r = [ entryFor name | name <- displayOrder r ]
  where
    fields = recordMap r

    -- The whole pair (not just its value) is the panic when the name is
    -- missing, so even inspecting the names of the result will panic.
    entryFor name =
      maybe (missing name) (\val -> (name, val)) (Map.lookup name fields)

    missing name =
      panic "displayFields"
        [ "Could not find field in recordMap " ++ show name
        , "Display order: " ++ show (displayOrder r)
        , "Canonical order:" ++ show (map fst (canonicalFields r))
        ]
-- | Builds a record from @(name, value)@ pairs. The order of the input list
-- becomes the display order.
--
-- Panics if a field name occurs more than once. Use 'recordFromFieldsErr'
-- to handle that case without panicking.
recordFromFields :: (Show a, Ord a) => [(a,b)] -> RecordMap a b
recordFromFields xs =
  case recordFromFieldsErr xs of
    Right r -> r
    Left (x, _) ->
      -- 'x' is the duplicated field *name* (the map key), not its value.
      panic "recordFromFields"
        ["Repeated field name: " ++ show x]
-- | Builds a record from @(name, value)@ pairs, using the input order as the
-- display order.
--
-- Returns @Left pair@ for the first pair whose name was already seen
-- earlier in the list.
recordFromFieldsErr :: (Show a, Ord a) => [(a,b)] -> Either (a,b) (RecordMap a b)
recordFromFieldsErr xs0 = go Map.empty xs0
  where
    go seen pairs =
      case pairs of
        [] -> Right (RecordMap seen (map fst xs0))
        entry@(name, val) : rest
          | name `Map.member` seen -> Left entry
          | otherwise -> go (Map.insert name val seen) rest
-- | Like 'recordFromFields' (and so panics on repeated names), but the
-- display order is supplied explicitly instead of taken from the list.
--
-- NOTE(review): @dOrd@ is not checked against the field names.
-- 'displayFields' panics on names absent from @fs@ and omits fields
-- absent from @dOrd@. Callers presumably pass a permutation; confirm.
recordFromFieldsWithDisplay :: (Show a, Ord a) => [a] -> [(a,b)] -> RecordMap a b
recordFromFieldsWithDisplay dOrd fs =
  (recordFromFields fs) { _displayOrder = dOrd }
-- | Looks up the value of a field by name.
lookupField :: Ord a => a -> RecordMap a b -> Maybe b
lookupField x = Map.lookup x . recordMap
-- | Applies a function to the value of one field.
--
-- Returns 'Nothing' when the field is absent. Otherwise returns the record
-- with the updated value and an unchanged display order.
adjustField :: forall a b. Ord a => a -> (b -> b) -> RecordMap a b -> Maybe (RecordMap a b)
adjustField x f r =
  fmap withFields (Map.alterF (fmap (Just . f)) x (recordMap r))
  where
    -- A missing key makes the 'Maybe' functor short-circuit to 'Nothing',
    -- and a present key yields the new value.
    withFields fields = r { recordMap = fields }
-- | Effectful map over every field. The function also receives the field
-- name. Effects run in 'Map.traverseWithKey' order (ascending field name),
-- and the display order is kept.
traverseRecordMap :: Applicative t =>
  (a -> b -> t c) -> RecordMap a b -> t (RecordMap a c)
traverseRecordMap f r = withOrder <$> Map.traverseWithKey f (recordMap r)
  where
    withOrder fields =
      RecordMap { recordMap = fields, _displayOrder = displayOrder r }
-- | Pure map over every field. The function also receives the field name.
-- The display order is kept. Values are forced to WHNF, as with the strict
-- 'Map' operations used throughout this module.
mapWithFieldName :: (a -> b -> c) -> RecordMap a b -> RecordMap a c
mapWithFieldName f r = r { recordMap = Map.mapWithKey f (recordMap r) }
-- | Maps over field values, ignoring field names.
instance Functor (RecordMap a) where
  fmap f = mapWithFieldName (const f)
-- | Folds over field values in canonical order (ascending by field name).
-- The display order is not used.
--
-- Delegating directly to the underlying 'Map' avoids building an
-- intermediate association list. It also makes 'length' and 'null' O(1)
-- instead of the default O(n) fold.
instance Foldable (RecordMap a) where
  foldMap f = foldMap f . recordMap
  foldr f z = foldr f z . recordMap
  length = Map.size . recordMap
  null = Map.null . recordMap
-- | Traverses field values (ignoring names) and keeps the display order.
instance Traversable (RecordMap a) where
  traverse f = traverseRecordMap (const f)
-- | Threads an accumulator through the field values in canonical order
-- (ascending by field name). Returns the final accumulator together with
-- the mapped record, which keeps its display order.
recordMapAccum :: (a -> b -> (a,c)) -> a -> RecordMap k b -> (a, RecordMap k c)
recordMapAccum f z r =
  let (final, fields) = Map.mapAccum f z (recordMap r)
  in (final, r { recordMap = fields })
-- | Combines two records that must have exactly the same field names,
-- applying a monadic function to each pair of matching values.
--
-- If a field name occurs in only one record, the result is
-- @Left (Left name)@ (only in the first record) or @Left (Right name)@
-- (only in the second). On success the result takes the display order of
-- the first record.
zipRecordsM :: forall t a b c d. (Ord a, Monad t) =>
  (a -> b -> c -> t d) -> RecordMap a b -> RecordMap a c -> t (Either (Either a a) (RecordMap a d))
zipRecordsM f r1 r2 = runExceptT (withOrder <$> merge (recordMap r1) (recordMap r2))
  where
    withOrder fields =
      RecordMap { recordMap = fields, _displayOrder = displayOrder r1 }

    -- The signature pins the error type and monad for the merge tactics.
    merge :: Map a b -> Map a c -> ExceptT (Either a a) t (Map a d)
    merge =
      Map.mergeA
        (Map.traverseMissing (\name _ -> throwError (Left name)))
        (Map.traverseMissing (\name _ -> throwError (Right name)))
        (Map.zipWithAMatched (\name x y -> lift (f name x y)))
-- | Pure version of 'zipRecordsM'. It has the same error reporting for
-- mismatched field names and uses the first record's display order.
zipRecords :: forall a b c d. Ord a =>
  (a -> b -> c -> d) -> RecordMap a b -> RecordMap a c -> Either (Either a a) (RecordMap a d)
zipRecords f r1 r2 =
  runIdentity $ zipRecordsM (\name x y -> pure (f name x y)) r1 r2