-- |
-- Module      :  Cryptol.Utils.RecordMap
-- Copyright   :  (c) 2020 Galois, Inc.
-- License     :  BSD3
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable
--
-- This module implements an "order insensitive" data structure for
-- record types and values.  For most purposes, we want to deal with
-- record fields in a canonical order; but for user interaction
-- purposes, we generally want to display the fields in the order they
-- were specified by the user (in source files, at the REPL, etc.).

{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Cryptol.Utils.RecordMap
  ( RecordMap
  , displayOrder
  , canonicalFields
  , displayFields
  , recordElements
  , fieldSet
  , recordFromFields
  , recordFromFieldsErr
  , recordFromFieldsWithDisplay
  , lookupField
  , adjustField
  , traverseRecordMap
  , mapWithFieldName
  , zipRecordsM
  , zipRecords
  , recordMapAccum
  ) where

import           Control.DeepSeq
import           Control.Monad.Except
import           Data.Functor.Identity
import           Data.Set (Set)
import           Data.Map (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Map.Merge.Strict as Map

import Cryptol.Utils.Panic

-- | An "order insensitive" record container.
--
--   Field values are stored in a map keyed by field name, so they can
--   always be enumerated in a "canonical" order (ascending field name).
--   Alongside the map we keep a "display" order: the order in which the
--   fields were originally specified, used when presenting the record.
data RecordMap a b = RecordMap
  { recordMap     :: !(Map a b)
    -- ^ Field values, keyed (and therefore canonically ordered) by name.
  , _displayOrder :: [a]
    -- ^ Field names in the order in which they were originally given.
  }

-- RecordMap Invariant:
--   The `_displayOrder` field must contain exactly the same set of
--   field names as the keys of `recordMap`, and each field name must
--   occur in it at most once.

-- | Records are equal when they hold the same fields with equal values.
--   The display order is deliberately not compared.
instance (Ord a, Eq b) => Eq (RecordMap a b) where
  RecordMap fields1 _ == RecordMap fields2 _ = fields1 == fields2

-- | Records are ordered by their field maps; the display order is ignored.
instance (Ord a, Ord b) => Ord (RecordMap a b) where
  compare (RecordMap fields1 _) (RecordMap fields2 _) = compare fields1 fields2

-- | A record is shown as its list of field/value pairs, in display order.
instance (Show a, Ord a, Show b) => Show (RecordMap a b) where
  show r = show (displayFields r)

-- | Fully evaluate a record: the field map (names and values) AND the
--   display-order list.
--
--   The display order is a lazy field.  A record built from a field list
--   by 'recordFromFieldsErr' starts out with an unevaluated display order
--   that refers to the entire input list, values included.  If 'rnf'
--   skipped it, the record would not actually be in normal form and that
--   input list would stay reachable.
instance (NFData a, NFData b) => NFData (RecordMap a b) where
  rnf (RecordMap fields order) = rnf fields `seq` rnf order


-- | The set of field names present in the record.
fieldSet :: Ord a => RecordMap a b -> Set a
fieldSet = Map.keysSet . recordMap

-- | The field names, in the order in which they were originally given.
displayOrder :: RecordMap a b -> [a]
displayOrder = _displayOrder

-- | The field values, listed in the canonical (ascending) order of their
--   field names.
recordElements :: RecordMap a b -> [b]
recordElements r = Map.elems (recordMap r)

-- | The field/value pairs of the record, in canonical (ascending
--   field-name) order.
canonicalFields :: RecordMap a b -> [(a,b)]
canonicalFields r = Map.toAscList (recordMap r)

-- | The field/value pairs of the record, in display order.
--
--   Panics if the display order names a field that is absent from the
--   underlying map, which means the 'RecordMap' invariant was broken.
displayFields :: (Show a, Ord a) => RecordMap a b -> [(a,b)]
displayFields r = [ withValue name | name <- displayOrder r ]
  where
  fields = recordMap r

  withValue name =
    maybe (missing name) (\val -> (name, val)) (Map.lookup name fields)

  missing name =
    panic "displayFields"
      [ "Could not find field in recordMap " ++ show name
      , "Display order: " ++ show (displayOrder r)
      , "Canonical order:" ++ show (map fst (canonicalFields r))
      ]

-- | Build a record from a list of field/value pairs.  The order of the
--   list becomes the display order.
--
--   Precondition: each field name occurs at most once in the list.
--   If a name is repeated, this panics and reports the first repeat.
recordFromFields :: (Show a, Ord a) => [(a,b)] -> RecordMap a b
recordFromFields xs = either repeated id (recordFromFieldsErr xs)
  where
  repeated (name, _) =
    panic "recordFromFields" ["Repeated field value: " ++ show name]

-- | Build a record from a list of field/value pairs.  The order of the
--   list becomes the display order.
--
--   If a field name occurs more than once, returns @Left@ with the first
--   pair whose name was already seen earlier in the list.  Otherwise
--   returns the constructed record.
recordFromFieldsErr :: (Show a, Ord a) => [(a,b)] -> Either (a,b) (RecordMap a b)
recordFromFieldsErr fields = go Map.empty fields
  where
  -- 'seen' accumulates the pairs accepted so far, keyed by field name.
  go seen [] = Right (RecordMap seen (map fst fields))
  go seen (entry@(name, val) : rest)
    | name `Map.member` seen = Left entry
    | otherwise              = go (Map.insert name val seen) rest

-- | Build a record from a list of field/value pairs, using an explicitly
--   supplied display order instead of the order of the pair list.
--
--   Preconditions: each field name occurs at most once in each list, and
--   a name appears in the display order iff it appears in the pair list.
--   A repeated name in the pair list causes a panic (via
--   'recordFromFields').  The display-order precondition is NOT checked
--   here.
recordFromFieldsWithDisplay :: (Show a, Ord a) => [a] -> [(a,b)] -> RecordMap a b
recordFromFieldsWithDisplay dOrd fs =
  (recordFromFields fs) { _displayOrder = dOrd }

-- | The value stored for a field, or 'Nothing' if the field is absent.
lookupField :: Ord a => a -> RecordMap a b -> Maybe b
lookupField x = Map.lookup x . recordMap

-- | Update the value of one field by applying the given function.
--   Returns 'Nothing' if the field is not present in the record.
--   The display order is left unchanged.
adjustField :: forall a b. Ord a => a -> (b -> b) -> RecordMap a b -> Maybe (RecordMap a b)
adjustField x f r =
  fmap (\m -> r { recordMap = m }) (Map.alterF update x (recordMap r))
  where
  -- Absent field: abort the whole alteration with 'Nothing'.
  -- Present field: keep it, holding the updated value.
  update :: Maybe b -> Maybe (Maybe b)
  update = fmap (Just . f)

-- | Run an action on every field (name and value) in canonical order.
--   The results are collected into a new record that keeps the original
--   display order.
traverseRecordMap :: Applicative t =>
  (a -> b -> t c) -> RecordMap a b -> t (RecordMap a c)
traverseRecordMap f r = rebuild <$> Map.traverseWithKey f (recordMap r)
  where
  rebuild m = RecordMap m (displayOrder r)

-- | Apply a pure function to every field value.  The function also
--   receives the field name.  The display order is preserved.
mapWithFieldName :: (a -> b -> c) -> RecordMap a b -> RecordMap a c
mapWithFieldName f r = runIdentity (traverseRecordMap step r)
  where
  step name val = Identity (f name val)

-- | Maps over the field values, ignoring the field names.
instance Functor (RecordMap a) where
  fmap f = mapWithFieldName (const f)

-- | Folds over the field values in canonical (ascending field-name) order.
--
--   Every method delegates to the underlying 'Map'.  Its 'Foldable'
--   instance visits values in ascending key order, which is the same
--   order as 'canonicalFields'.  Delegating avoids building an
--   intermediate list of pairs on every fold.  It also lets 'length' and
--   'null' use the map's own constant-time implementations instead of
--   traversing every element.
instance Foldable (RecordMap a) where
  foldMap f = foldMap f . recordMap
  foldr f z = foldr f z . recordMap
  length    = length . recordMap
  null      = null . recordMap

-- | Traverses the field values in canonical order, ignoring field names.
--   The display order is preserved.
instance Traversable (RecordMap a) where
  traverse f = traverseRecordMap (const f)

-- | Thread an accumulating argument through the field values in
--   canonical order of the field names.
--
--   Returns the final accumulator together with a record holding the new
--   values.  The display order is preserved.
recordMapAccum :: (a -> b -> (a,c)) -> a -> RecordMap k b -> (a, RecordMap k c)
recordMapAccum f z r =
  let (acc, newFields) = Map.mapAccum f z (recordMap r)
  in  (acc, r { recordMap = newFields })

-- | Zip together the fields of two records using the provided action.
--
--   Both records must have exactly the same field names.  When they do
--   not, the result is @Left err@:
--
--   * @err = Left name@ for a field that occurs only in the first record;
--   * @err = Right name@ for a field that occurs only in the second record.
--
--   The @displayOrder@ of the resulting record is taken from the first
--   argument (rather arbitrarily).
zipRecordsM :: forall t a b c d. (Ord a, Monad t) =>
  (a -> b -> c -> t d) -> RecordMap a b -> RecordMap a c -> t (Either (Either a a) (RecordMap a d))
zipRecordsM f r1 r2 = runExceptT (fmap rebuild zipped)
  where
  rebuild m = RecordMap m (displayOrder r1)

  zipped :: ExceptT (Either a a) t (Map a d)
  zipped = Map.mergeA onlyInFirst onlyInSecond inBoth (recordMap r1) (recordMap r2)

  -- A key present in just one map aborts the merge, reporting its side.
  onlyInFirst  = Map.traverseMissing (\name _ -> throwError (Left name))
  onlyInSecond = Map.traverseMissing (\name _ -> throwError (Right name))

  -- A key present in both maps runs the user's action in the base monad.
  inBoth = Map.zipWithAMatched (\name x y -> lift (f name x y))

-- | Pure version of `zipRecordsM`.  It zips two records field by field
--   with a pure function.  When the records do not have the same field
--   names, it reports the mismatch exactly as `zipRecordsM` does.
zipRecords :: forall a b c d. Ord a =>
  (a -> b -> c -> d) -> RecordMap a b -> RecordMap a c -> Either (Either a a) (RecordMap a d)
zipRecords f r1 r2 = runIdentity (zipRecordsM step r1 r2)
  where
  step name x y = Identity (f name x y)