-- |
-- Module      :  Cryptol.Utils.RecordMap
-- Copyright   :  (c) 2020 Galois, Inc.
-- License     :  BSD3
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable
--
-- This module implements an "order insensitive" data structure for
-- record types and values.  For most purposes, we want to deal with
-- record fields in a canonical order; but for user interaction
-- purposes, we generally want to display the fields in the order they
-- were specified by the user (in source files, at the REPL, etc.).

{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Cryptol.Utils.RecordMap
  ( RecordMap
  , displayOrder
  , canonicalFields
  , displayFields
  , recordElements
  , fieldSet
  , recordFromFields
  , recordFromFieldsErr
  , recordFromFieldsWithDisplay
  , lookupField
  , adjustField
  , traverseRecordMap
  , mapWithFieldName
  , zipRecordsM
  , zipRecords
  , recordMapAccum
  ) where

import           Control.DeepSeq
import           Control.Monad.Except
import           Data.Functor.Identity
import           Data.Set (Set)
import           Data.Map (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Map.Merge.Strict as Map

import Cryptol.Utils.Panic

-- | An "order insensitive" data structure mapping field names to values.
--   Fields can be visited either in a /canonical/ order (ascending by
--   field name, i.e. the key order of the underlying 'Map') or in a
--   /display/ order, which records the order in which the fields were
--   originally specified.
data RecordMap a b = RecordMap
  { recordMap     :: !(Map a b)
    -- ^ The field values, keyed by field name (strict field).
  , _displayOrder :: [a]
    -- ^ Field names in display order (lazy field).
  }
-- RecordMap Invariant:
--   The `_displayOrder` field should contain exactly the
--   same set of field names as the keys of `recordMap`.
--   Moreover, each field name should occur at most once.

-- | Equality ignores the display order: two records are equal exactly
--   when their field maps are equal.
instance (Ord a, Eq b) => Eq (RecordMap a b) where
  RecordMap m1 _ == RecordMap m2 _ = m1 == m2

-- | Ordering compares only the field maps; the display order plays no part.
instance (Ord a, Ord b) => Ord (RecordMap a b) where
  compare (RecordMap m1 _) (RecordMap m2 _) = compare m1 m2

-- | Renders the record as a list of field/value pairs in display order.
instance (Show a, Ord a, Show b) => Show (RecordMap a b) where
  showsPrec _ r = shows (displayFields r)

-- | Fully evaluates a record: every key and value in the field map, and
--   the display-order list as well.  The display order is a lazy field, so
--   it has to be forced explicitly.  Otherwise any thunk it still holds
--   would survive a 'deepseq'.
instance (NFData a, NFData b) => NFData (RecordMap a b) where
  rnf (RecordMap m order) = rnf m `seq` rnf order


-- | The set of field names present in the record.
fieldSet :: Ord a => RecordMap a b -> Set a
fieldSet = Map.keysSet . recordMap

-- | Field names listed in the order in which they originally appeared.
displayOrder :: RecordMap a b -> [a]
displayOrder (RecordMap _ order) = order

-- | The values of the record, listed in canonical order (ascending by
--   field name).
recordElements :: RecordMap a b -> [b]
recordElements r = Map.elems (recordMap r)

-- | Field/value pairs in canonical order, i.e. ascending by field name.
canonicalFields :: RecordMap a b -> [(a,b)]
canonicalFields r = Map.toAscList (recordMap r)

-- | Field/value pairs listed in display order.
--   Panics if the display order names a field that is absent from the
--   field map, which would mean the 'RecordMap' invariant is broken.
displayFields :: (Show a, Ord a) => RecordMap a b -> [(a,b)]
displayFields r = [ pick k | k <- displayOrder r ]
  where
  fields = recordMap r

  -- The pair is only built once the lookup has succeeded.
  pick k = maybe (missing k) (\v -> (k, v)) (Map.lookup k fields)

  missing k =
    panic "displayFields"
          [ "Could not find field in recordMap " ++ show k
          , "Display order: " ++ show (displayOrder r)
          , "Canonical order:" ++ show (map fst (canonicalFields r))
          ]

-- | Build a record from field/value pairs.  The display order follows
--   the order of the input list.
--   Precondition: no field name occurs more than once in the input;
--   a repeated name causes a panic.
recordFromFields :: (Show a, Ord a) => [(a,b)] -> RecordMap a b
recordFromFields xs = either repeated id (recordFromFieldsErr xs)
  where
  repeated (x, _) =
    panic "recordFromFields"
          ["Repeated field value: " ++ show x]

-- | Build a record from field/value pairs, reporting duplicates.
--   The display order is the order in which the names appear in the input.
--   If some field name occurs more than once, the result is 'Left' with the
--   first pair whose name had already been seen earlier in the list.
--   Otherwise the result is 'Right' with the constructed record.
recordFromFieldsErr :: Ord a => [(a,b)] -> Either (a,b) (RecordMap a b)
recordFromFieldsErr xs0 = loop Map.empty xs0
  where
  -- 'm' holds the fields seen so far.
  loop m [] = Right (RecordMap m (map fst xs0))
  loop m (entry@(x, v) : xs)
    | x `Map.member` m = Left entry
    | otherwise        = loop (Map.insert x v m) xs

-- | Build a record from field/value pairs using an explicitly supplied
--   display order in place of the order of the pair list.
--   Preconditions: no name is repeated in either list, and a name appears
--   in the display order exactly when it appears in the field list.
--   Duplicates in the field list cause a panic.  The display-order
--   precondition is NOT checked here.
recordFromFieldsWithDisplay :: (Show a, Ord a) => [a] -> [(a,b)] -> RecordMap a b
recordFromFieldsWithDisplay dOrd fs =
  (recordFromFields fs) { _displayOrder = dOrd }

-- | Look up the value stored under a field name, if the field is present.
lookupField :: Ord a => a -> RecordMap a b -> Maybe b
lookupField x = Map.lookup x . recordMap

-- | Replace the value of an existing field with the result of applying
--   the given function to it.  Returns 'Nothing' when the field is
--   absent.  The display order is left untouched.
adjustField :: forall a b. Ord a => a -> (b -> b) -> RecordMap a b -> Maybe (RecordMap a b)
adjustField x f r =
  fmap (\m -> r { recordMap = m }) (Map.alterF bump x (recordMap r))
  where
  -- A missing field makes the whole alteration fail with 'Nothing';
  -- a present one is rewritten in place.
  bump :: Maybe b -> Maybe (Maybe b)
  bump = fmap (Just . f)

-- | Run an action on every field (name and value), in canonical order,
--   and collect the results into a record with the same display order.
traverseRecordMap :: Applicative t =>
  (a -> b -> t c) -> RecordMap a b -> t (RecordMap a c)
traverseRecordMap f r = fmap rebuild (Map.traverseWithKey f (recordMap r))
  where
  rebuild m = RecordMap m (displayOrder r)

-- | Transform every value of a record.  The function also receives the
--   field name.  The display order is preserved.
mapWithFieldName :: (a -> b -> c) -> RecordMap a b -> RecordMap a c
mapWithFieldName f r = r { recordMap = Map.mapWithKey f (recordMap r) }

-- | Maps over values only; field names and display order are unchanged.
instance Functor (RecordMap a) where
  fmap = mapWithFieldName . const

-- | Folds visit the values in canonical order (ascending by field name).
--   The display order is not consulted.  The methods delegate to the
--   underlying 'Map's instance, which has the same value order.  As a
--   result no intermediate list of pairs is built, and 'length' and 'null'
--   do not walk the fields.
instance Foldable (RecordMap a) where
  foldMap f = foldMap f . recordMap
  foldr f z = foldr f z . recordMap
  length    = length . recordMap
  null      = null . recordMap

-- | Traverses values in canonical order, ignoring field names.
instance Traversable (RecordMap a) where
  traverse = traverseRecordMap . const

-- | Thread an accumulator through the values of a record in canonical
--   (ascending field name) order, replacing each value with the one
--   produced for it.  The result pairs the final accumulator with the
--   updated record.  The display order is unchanged.
recordMapAccum :: (a -> b -> (a,c)) -> a -> RecordMap k b -> (a, RecordMap k c)
recordMapAccum f z r =
  -- A lazy 'let' binding: the outer pair is available without first
  -- running the accumulation.
  let (acc, m) = Map.mapAccum f z (recordMap r)
  in  (acc, r { recordMap = m })

-- | Zip two records field-by-field, combining each pair of values that
--   share a field name with the given monadic action.
--
--   Both records are expected to have the same field names.  The zip stops
--   at the first field that occurs in only one record, as found by the
--   'Map.mergeA' traversal, and reports it:
--
--   * @Left x@: field @x@ is in the first record but not the second;
--
--   * @Right x@: field @x@ is in the second record but not the first.
--
--   On success, the @displayOrder@ of the result is taken from the first
--   argument (rather arbitrarily).
zipRecordsM :: forall t a b c d. (Ord a, Monad t) =>
  (a -> b -> c -> t d) -> RecordMap a b -> RecordMap a c -> t (Either (Either a a) (RecordMap a d))
zipRecordsM f r1 r2 =
  runExceptT (rebuild <$> mergeFields (recordMap r1) (recordMap r2))
  where
  rebuild m = RecordMap m (displayOrder r1)

  -- Keys only in the first map, keys only in the second map, and shared
  -- keys, in that argument order.
  mergeFields :: Map a b -> Map a c -> ExceptT (Either a a) t (Map a d)
  mergeFields =
    Map.mergeA
      (Map.traverseMissing (\k _ -> throwError (Left k)))
      (Map.traverseMissing (\k _ -> throwError (Right k)))
      (Map.zipWithAMatched (\k x y -> lift (f k x y)))

-- | Pure version of `zipRecordsM`: combine matching fields with a plain
--   function.  The same @Left@/@Right@ convention reports a field that is
--   present in only one record.
zipRecords :: forall a b c d. Ord a =>
  (a -> b -> c -> d) -> RecordMap a b -> RecordMap a c -> Either (Either a a) (RecordMap a d)
zipRecords f r1 r2 = runIdentity $ zipRecordsM (\k x y -> pure (f k x y)) r1 r2