-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE CPP #-}
module Retrie.Substitution
  ( Substitution
  , HoleVal(..)
  , emptySubst
  , extendSubst
  , lookupSubst
  , deleteSubst
  , foldSubst
  ) where

import Retrie.ExactPrint
import Retrie.GHC

-- | A 'Substitution' is essentially a map from variable name to 'HoleVal'.
--
-- Each entry stores its 'FastString' key alongside the 'HoleVal', so the
-- name can be recovered when folding over or showing the substitution
-- (the map itself only indexes entries by the key's unique).
-- The CPP split exists because, from GHC 9.0, 'UniqFM' carries an extra
-- key-type parameter.
#if __GLASGOW_HASKELL__ < 900
newtype Substitution = Substitution (UniqFM (FastString, HoleVal))
#else
newtype Substitution = Substitution (UniqFM FastString (FastString, HoleVal))
#endif
-- See Note [Why not RdrNames?] for explanation of use of FastString

-- | Renders the substitution as the list of its stored @(key, value)@ pairs.
--
-- NOTE(review): entries come from 'eltsUFM', so their order presumably
-- follows the map's internal unique-based layout rather than insertion
-- order; do not rely on it for anything beyond debugging.
instance Show Substitution where
  show (Substitution m) = show (eltsUFM m)

-- | Sum type of possible substitution values: everything a 'Substitution'
-- can map a name to.
data HoleVal
  = HoleExpr AnnotatedHsExpr -- ^ An annotated 'HsExpr'.
  | HolePat AnnotatedPat -- ^ An annotated 'Pat'.
  | HoleType AnnotatedHsType -- ^ An annotated 'HsType'.
  | HoleRdr RdrName -- ^ Alpha-renamed binder.

-- | Debug rendering: the constructor name followed by the exact-printed
-- source of the payload (via 'printA'), or, for 'HoleRdr', the binder's
-- name as a plain string.
instance Show HoleVal where
  show (HoleExpr e) = "HoleExpr " ++ printA e
  show (HolePat p) = "HolePat " ++ printA p
  show (HoleType t) = "HoleType " ++ printA t
  show (HoleRdr r) = "HoleRdr " ++ unpackFS (rdrFS r)

-- | The empty substitution: binds no names.
emptySubst :: Substitution
emptySubst = Substitution emptyUFM

-- | Lookup a value in the substitution.
--
-- Returns 'Nothing' when the name is unbound. Only the 'HoleVal' half of
-- the stored @(key, value)@ pair is returned, since the caller already has
-- the key.
lookupSubst :: FastString -> Substitution -> Maybe HoleVal
lookupSubst k (Substitution m) = snd <$> lookupUFM m k

-- | Extend the substitution. If the key already exists, its value is replaced.
--
-- The key is stored again inside the entry so that 'foldSubst' and 'show'
-- can hand back the original name alongside its value.
extendSubst :: Substitution -> FastString -> HoleVal -> Substitution
extendSubst (Substitution m) k v = Substitution (addToUFM m k (k, v))

-- | Delete from the substitution every name in the given list.
--
-- NOTE(review): names that are not bound are presumably ignored, as per
-- 'delListFromUFM'; confirm if callers rely on this.
deleteSubst :: Substitution -> [FastString] -> Substitution
deleteSubst (Substitution m) ks = Substitution (delListFromUFM m ks)

-- | Fold over the substitution.
--
-- The folding function receives each stored @(key, value)@ pair together
-- with the accumulator, starting from the given initial value.
-- NOTE(review): traversal order is whatever 'foldUFM' uses, presumably not
-- insertion order; the folding function should not depend on it.
foldSubst :: ((FastString, HoleVal) -> a -> a) -> a -> Substitution -> a
foldSubst f x (Substitution m) = foldUFM f x m