-- |
-- Module      :  Cryptol.Utils.RecordMap
-- Copyright   :  (c) 2020 Galois, Inc.
-- License     :  BSD3
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable
--
-- This module implements an "order insensitive" data structure for
-- record types and values.  For most purposes, we want to deal with
-- record fields in a canonical order; but for user interaction
-- purposes, we generally want to display the fields in the order they
-- were specified by the user (in source files, at the REPL, etc.).

{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Safe #-}

module Cryptol.Utils.RecordMap
  ( RecordMap
  , displayOrder
  , canonicalFields
  , displayFields
  , recordElements
  , displayElements
  , fieldSet
  , recordFromFields
  , recordFromFieldsErr
  , recordFromFieldsWithDisplay
  , lookupField
  , adjustField
  , traverseRecordMap
  , mapWithFieldName
  , zipRecordsM
  , zipRecords
  , recordMapAccum
  ) where

import           Control.DeepSeq
import           Control.Monad.Except
import           Data.Functor.Identity
import           Data.Set (Set)
import           Data.Map (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Map.Merge.Strict as Map

import Cryptol.Utils.Panic

-- | An "order insensitive" data structure for record fields.
--   The fields can be accessed either according to a "canonical"
--   order (ascending by field name, as given by the underlying 'Map'),
--   or based on a "display" order, which matches the order in which
--   the fields were originally specified.
--
--   Invariant: the display order contains exactly the same set of
--   field names as the keys of 'recordMap', and each field name occurs
--   at most once.
data RecordMap a b =
  RecordMap
  { recordMap :: !(Map a b)
    -- ^ The field values, keyed by field name.
  , _displayOrder :: [a]
    -- ^ The field names, in the order used when presenting the record.
  }

-- | Two records are equal exactly when their field maps are equal;
--   the display order plays no part in the comparison.
instance (Ord a, Eq b) => Eq (RecordMap a b) where
  (==) lhs rhs = recordMap lhs == recordMap rhs

-- | Records are ordered by comparing their field maps; as with 'Eq',
--   the display order is ignored.
instance (Ord a, Ord b) => Ord (RecordMap a b) where
  compare lhs rhs = recordMap lhs `compare` recordMap rhs

-- | A record is shown as its list of field/value pairs, listed in
--   display order.
instance (Show a, Ord a, Show b) => Show (RecordMap a b) where
  show r = show (displayFields r)

-- | Fully evaluate a record: the field map (keys and values) and the
--   display-order list.
--
--   The display order lives in a lazy field, so it must be forced
--   explicitly.  Otherwise it could remain an unevaluated thunk (and keep
--   whatever it references alive) even after 'rnf' / 'force'.
instance (NFData a, NFData b) => NFData (RecordMap a b) where
  rnf (RecordMap fields order) = rnf fields `seq` rnf order


-- | The set of field names present in the record.
fieldSet :: Ord a => RecordMap a b -> Set a
fieldSet = Map.keysSet . recordMap

-- | The field names in the order in which they originally appeared
--   (the "display" order).
displayOrder :: RecordMap a b -> [a]
displayOrder = _displayOrder

-- | The values of the record, listed in canonical order of the field
--   names (i.e. the same order as 'canonicalFields').
recordElements :: RecordMap a b -> [b]
recordElements r = [ value | (_, value) <- canonicalFields r ]

-- | The field/value pairs of the record in canonical order, i.e. in
--   ascending order of field name.
canonicalFields :: RecordMap a b -> [(a,b)]
canonicalFields r = Map.toAscList (recordMap r)

-- | The values of the record, listed in display order of the field
--   names (i.e. the same order as 'displayFields').
displayElements :: (Show a, Ord a) => RecordMap a b -> [b]
displayElements r = [ value | (_, value) <- displayFields r ]

-- | The field/value pairs of the record, listed in display order.
--
--   The result is driven entirely by the display order: each listed
--   name is paired with its value from the field map.  If a name in the
--   display order has no entry in the map (a broken 'RecordMap'
--   invariant), this panics and reports both orders.
displayFields :: (Show a, Ord a) => RecordMap a b -> [(a,b)]
displayFields r = [ withValue name | name <- displayOrder r ]
  where
  fields = recordMap r

  withValue name =
    maybe (missing name) (\value -> (name, value)) (Map.lookup name fields)

  missing name =
    panic "displayFields"
      [ "Could not find field in recordMap " ++ show name
      , "Display order: " ++ show (displayOrder r)
      , "Canonical order:" ++ show (map fst (canonicalFields r))
      ]

-- | Generate a record from a list of field/value pairs.  The display
--   order of the result is the order in which the fields appear in the
--   list.
--
--   Precondition: each field identifier appears at most once in the
--   given list.  A repeated field name causes a 'panic'; use
--   'recordFromFieldsErr' to handle that case without panicking.
recordFromFields :: (Show a, Ord a) => [(a,b)] -> RecordMap a b
recordFromFields xs =
  case recordFromFieldsErr xs of
    Right r -> r
    Left (x, _) ->
      -- What is repeated is the field *name* @x@, not its value.
      panic "recordFromFields" [ "Repeated field name: " ++ show x ]

-- | Generate a record from a list of field/value pairs, using the list
--   order as the display order.
--
--   The pairs are scanned from left to right.  The first time a pair's
--   field name has already been seen, that (later, duplicate) pair is
--   returned in 'Left'.  If no name repeats, the constructed record is
--   returned in 'Right'.
recordFromFieldsErr :: (Show a, Ord a) => [(a,b)] -> Either (a,b) (RecordMap a b)
recordFromFieldsErr xs0 = go Map.empty xs0
  where
  go acc remaining =
    case remaining of
      [] -> Right (RecordMap acc (map fst xs0))
      (name, value) : rest
        | name `Map.member` acc -> Left (name, value)
        | otherwise             -> go (Map.insert name value acc) rest

-- | Generate a record from a list of field/value pairs, but take the
--   "display" order from the first argument instead of the list order.
--
--   Precondition: each field identifier appears at most once in each
--   list, and a field name appears in the display order iff it appears
--   in the field list.  Only repeats in the field list are detected
--   (via 'recordFromFields', which panics).  The display-order list is
--   stored as given, without being checked.
recordFromFieldsWithDisplay :: (Show a, Ord a) => [a] -> [(a,b)] -> RecordMap a b
recordFromFieldsWithDisplay dOrd fs =
  let base = recordFromFields fs
  in  base { _displayOrder = dOrd }

-- | Look up the value stored for the given field name, or 'Nothing' if
--   the record has no such field.
lookupField :: Ord a => a -> RecordMap a b -> Maybe b
lookupField x = Map.lookup x . recordMap

-- | Update the value of a field by applying the given function.
--   If the field is not present in the record, return 'Nothing';
--   otherwise return the updated record, whose display order is
--   unchanged.
adjustField :: forall a b. Ord a => a -> (b -> b) -> RecordMap a b -> Maybe (RecordMap a b)
adjustField x f r = replaceMap <$> Map.alterF step x (recordMap r)
  where
  replaceMap :: Map a b -> RecordMap a b
  replaceMap m = r { recordMap = m }

  -- Running in the 'Maybe' functor: an absent field produces 'Nothing',
  -- which aborts the whole update.  A present value @v@ becomes @f v@.
  step :: Maybe b -> Maybe (Maybe b)
  step = fmap (Just . f)

-- | Traverse the elements of the given record map in canonical order,
--   applying the given action to each field name and its value.  The
--   resulting record keeps the display order of the input.
traverseRecordMap :: Applicative t =>
  (a -> b -> t c) -> RecordMap a b -> t (RecordMap a c)
traverseRecordMap f r = fmap withSameOrder (Map.traverseWithKey f (recordMap r))
  where
  withSameOrder m = RecordMap m (displayOrder r)

-- | Apply the given function to each field name and its value.  The
--   display order is preserved.  This is 'traverseRecordMap'
--   specialised to the 'Identity' functor.
mapWithFieldName :: (a -> b -> c) -> RecordMap a b -> RecordMap a c
mapWithFieldName f r = runIdentity (traverseRecordMap step r)
  where
  step name value = Identity (f name value)

-- | Map over the values of a record, ignoring the field names.
instance Functor (RecordMap a) where
  fmap f = mapWithFieldName (const f)

-- | Fold over the values of a record in canonical order of the field
--   names (the order of 'recordElements').
instance Foldable (RecordMap a) where
  foldMap f r = foldMap f (recordElements r)

-- | Traverse the values of a record in canonical order, ignoring the
--   field names.
instance Traversable (RecordMap a) where
  traverse f = traverseRecordMap (const f)

-- | Thread an accumulating argument through the values of the record,
--   visiting them in canonical order of the field names.  Returns the
--   final accumulator paired with a record of the new values; that
--   record keeps the original display order.
recordMapAccum :: (a -> b -> (a,c)) -> a -> RecordMap k b -> (a, RecordMap k c)
recordMapAccum f z r =
  let (final, values') = Map.mapAccum f z (recordMap r)
  in  (final, r { recordMap = values' })

-- | Zip together the fields of two records using the provided action.
--
--   If the two records do not have the same field names, the result is
--   @Left e@ and names one mismatched field:
--
--   * @Left name@: @name@ is present in the first record but missing
--     from the second;
--
--   * @Right name@: @name@ is present in the second record but missing
--     from the first.
--
--   Otherwise the action is applied to each shared field, and the
--   result is @Right@ of the zipped record.  Its display order is taken
--   from the first argument (rather arbitrarily).  Actions run in @t@
--   and a mismatch aborts via 'ExceptT'.  Which mismatch is reported,
--   and which actions have already run, follows the traversal order
--   of 'Map.mergeA'.
zipRecordsM :: forall t a b c d. (Ord a, Monad t) =>
  (a -> b -> c -> t d) -> RecordMap a b -> RecordMap a c -> t (Either (Either a a) (RecordMap a d))
zipRecordsM f r1 r2 = runExceptT (withOrder <$> merged)
  where
  withOrder :: Map a d -> RecordMap a d
  withOrder m = RecordMap m (displayOrder r1)

  merged :: ExceptT (Either a a) t (Map a d)
  merged = Map.mergeA onlyInFirst onlyInSecond inBoth (recordMap r1) (recordMap r2)

  onlyInFirst :: Map.WhenMissing (ExceptT (Either a a) t) a b d
  onlyInFirst = Map.traverseMissing (\name _ -> throwError (Left name))

  onlyInSecond :: Map.WhenMissing (ExceptT (Either a a) t) a c d
  onlyInSecond = Map.traverseMissing (\name _ -> throwError (Right name))

  inBoth :: Map.WhenMatched (ExceptT (Either a a) t) a b c d
  inBoth = Map.zipWithAMatched (\name x y -> lift (f name x y))

-- | Pure version of 'zipRecordsM'.  It reports mismatched field names
--   the same way: @Left@ for a field only in the first record, @Right@
--   for a field only in the second.
zipRecords :: forall a b c d. Ord a =>
  (a -> b -> c -> d) -> RecordMap a b -> RecordMap a c -> Either (Either a a) (RecordMap a d)
zipRecords f r1 r2 = runIdentity (zipRecordsM pureStep r1 r2)
  where
  pureStep name x y = Identity (f name x y)