module Language.Hakaru.Parser.AST where
import Prelude hiding (replicate, unlines)
import Control.Arrow ((***))
import qualified Data.Foldable as F
import qualified Data.Vector as V
import qualified Data.Number.Nat as N
import qualified Data.Number.Natural as N
import qualified Data.List.NonEmpty as L
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Syntax.ABT hiding (Var, Bind)
import Language.Hakaru.Syntax.AST
(Literal(..), MeasureOp(..), LCs(), UnLCs ())
import Language.Hakaru.Syntax.IClasses
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (Monoid(..))
#endif
import Data.Text
import Text.Printf
import Text.Parsec (SourcePos)
import Text.Parsec.Pos
data Name = Name !N.Nat !Text
deriving (Read, Show, Eq, Ord)
nameID :: Name -> N.Nat
nameID (Name i _) = i
hintID :: Name -> Text
hintID (Name _ t) = t
type Name' = Text
data Branch' a
= Branch' (Pattern' Name') (AST' a)
| Branch'' (Pattern' Name) (AST' a)
deriving (Eq, Show)
data Pattern' a
= PVar' a
| PWild'
| PData' (PDatum a)
deriving (Eq, Show)
data PDatum a = DV Name' [Pattern' a]
deriving (Eq, Show)
data SourceSpan = SourceSpan !SourcePos !SourcePos
deriving (Eq, Show)
numberLine :: Text -> Int -> Text
numberLine s n = append (pack (printf "%5d| " n)) s
printSourceSpan :: SourceSpan -> V.Vector Text -> Text
printSourceSpan (SourceSpan start stop) input
| sourceLine start == sourceLine stop =
unlines [ numberLine endLine (sourceLine stop)
, append " " textSpan
]
| otherwise =
unlines $ flip fmap [sourceLine start .. sourceLine stop] $ \i ->
numberLine (input V.! (i 1)) i
where endLine = input V.! (sourceLine stop 1)
spanLen = sourceColumn stop sourceColumn start
textSpan =
append (replicate (sourceColumn start 1) " ")
(replicate spanLen "^")
data Literal'
= Nat Integer
| Int Integer
| Prob Rational
| Real Rational
deriving (Eq, Show)
data NaryOp
= And | Or | Xor
| Iff | Min | Max
| Sum | Prod
deriving (Eq, Show)
data ArrayOp = Index_ | Size | Reduce
data TypeAST'
= TypeVar Name'
| TypeApp Name' [TypeAST']
| TypeFun TypeAST' TypeAST'
deriving (Eq, Show)
data AST' a
= Var a
| Lam a TypeAST' (AST' a)
| App (AST' a) (AST' a)
| Let a (AST' a) (AST' a)
| If (AST' a) (AST' a) (AST' a)
| Ann (AST' a) TypeAST'
| Infinity'
| ULiteral Literal'
| NaryOp NaryOp [AST' a]
| Unit
| Empty
| Pair (AST' a) (AST' a)
| Array a (AST' a) (AST' a)
| Index (AST' a) (AST' a)
| Case (AST' a) [(Branch' a)]
| Dirac (AST' a)
| Bind a (AST' a) (AST' a)
| Plate a (AST' a) (AST' a)
| Chain a (AST' a) (AST' a) (AST' a)
| Integrate a (AST' a) (AST' a) (AST' a)
| Summate a (AST' a) (AST' a) (AST' a)
| Product a (AST' a) (AST' a) (AST' a)
| Expect a (AST' a) (AST' a)
| Observe (AST' a) (AST' a)
| Msum [AST' a]
| Data a [TypeAST']
| WithMeta (AST' a) SourceSpan
deriving (Show)
instance Eq a => Eq (AST' a) where
(Var t) == (Var t') = t == t'
(Lam n e1 e2) == (Lam n' e1' e2') = n == n' &&
e1 == e1' &&
e2 == e2'
(App e1 e2) == (App e1' e2') = e1 == e1' &&
e2 == e2'
(Let n e1 e2) == (Let n' e1' e2') = n == n' &&
e1 == e1' &&
e2 == e2'
(If c e1 e2) == (If c' e1' e2') = c == c' &&
e1 == e1' &&
e2 == e2'
(Ann e typ) == (Ann e' typ') = e == e' &&
typ == typ'
Infinity' == Infinity' = True
(ULiteral v) == (ULiteral v') = v == v'
(NaryOp op args) == (NaryOp op' args') = op == op' &&
args == args'
Unit == Unit = True
Empty == Empty = True
(Pair e1 e2) == (Pair e1' e2') = e1 == e1' &&
e2 == e2'
(Array e1 e2 e3) == (Array e1' e2' e3') = e1 == e1' &&
e2 == e2' &&
e3 == e3'
(Index e1 e2) == (Index e1' e2') = e1 == e1' &&
e2 == e2'
(Case e1 bs) == (Case e1' bs') = e1 == e1' &&
bs == bs'
(Dirac e1) == (Dirac e1') = e1 == e1'
(Bind e1 e2 e3) == (Bind e1' e2' e3') = e1 == e1' &&
e2 == e2' &&
e3 == e3'
(Plate e1 e2 e3) == (Plate e1' e2' e3') = e1 == e1' &&
e2 == e2' &&
e3 == e3'
(Chain e1 e2 e3 e4) == (Chain e1' e2' e3' e4') = e1 == e1' &&
e2 == e2' &&
e3 == e3' &&
e4 == e4'
(Integrate a b c d) == (Integrate a' b' c' d') = a == a' &&
b == b' &&
c == c' &&
d == d'
(Summate a b c d) == (Summate a' b' c' d') = a == a' &&
b == b' &&
c == c' &&
d == d'
(Product a b c d) == (Product a' b' c' d') = a == a' &&
b == b' &&
c == c' &&
d == d'
(Expect e1 e2 e3) == (Expect e1' e2' e3') = e1 == e1' &&
e2 == e2' &&
e3 == e3'
(Observe e1 e2) == (Observe e1' e2') = e1 == e1' &&
e2 == e2'
(Msum es) == (Msum es') = es == es'
(Data e1 ts) == (Data e1' ts') = e1 == e1' &&
ts == ts'
(WithMeta e1 _ ) == e2 = e1 == e2
e1 == (WithMeta e2 _) = e1 == e2
_ == _ = False
data Import a = Import a
deriving (Eq, Show)
data ASTWithImport' a = ASTWithImport' [Import a] (AST' a)
deriving (Eq, Show)
val :: Literal' -> Some1 Literal
val (Nat n) = Some1 $ LNat (N.unsafeNatural n)
val (Int n) = Some1 $ LInt n
val (Prob n) = Some1 $ LProb (N.unsafeNonNegativeRational n)
val (Real n) = Some1 $ LReal n
data PrimOp
= Not | Impl | Diff | Nand | Nor
| Pi
| Sin | Cos | Tan
| Asin | Acos | Atan
| Sinh | Cosh | Tanh
| Asinh | Acosh | Atanh
| RealPow | NatPow
| Exp | Log | Infinity
| GammaFunc | BetaFunc
| Equal | Less
| Negate | Recip
| Abs | Signum | NatRoot | Erf
deriving (Eq, Show)
data SomeOp op where
SomeOp
:: (typs ~ UnLCs args, args ~ LCs typs)
=> !(op typs a)
-> SomeOp op
data SSing =
forall (a :: Hakaru). SSing !(Sing a)
data Branch_ abt = Branch_ Pattern (abt '[] 'U)
fmapBranch
:: (f '[] 'U -> g '[] 'U)
-> Branch_ f
-> Branch_ g
fmapBranch f (Branch_ pat e) = Branch_ pat (f e)
foldBranch
:: (abt '[] 'U -> m)
-> Branch_ abt
-> m
foldBranch f (Branch_ _ e) = f e
data Pattern
= PWild
| PVar Name
| PDatum Text PCode
data PFun
= PKonst Pattern
| PIdent Pattern
data PStruct
= PEt PFun PStruct
| PDone
data PCode
= PInr PCode
| PInl PStruct
infixr 7 `Et`, `PEt`
data DFun abt
= Konst (abt '[] 'U)
| Ident (abt '[] 'U)
data DStruct abt
= Et (DFun abt) (DStruct abt)
| Done
data DCode abt
= Inr (DCode abt)
| Inl (DStruct abt)
data Datum abt = Datum Text (DCode abt)
fmapDatum
:: (f '[] 'U -> g '[] 'U)
-> Datum f
-> Datum g
fmapDatum f (Datum t dcode) = Datum t (fdcode f dcode)
where fdcode f' (Inr d) = Inr (fdcode f' d)
fdcode f' (Inl d) = Inl (fdstruct f' d)
fdstruct f' (Et df ds) = Et (fdfun f' df)
(fdstruct f' ds)
fdstruct _ Done = Done
fdfun f' (Konst e) = Konst (f' e)
fdfun f' (Ident e) = Ident (f' e)
foldDatum
:: (Monoid m)
=> (abt '[] 'U -> m)
-> Datum abt
-> m
foldDatum f (Datum _ dcode) = fdcode f dcode
where fdcode f' (Inr d) = fdcode f' d
fdcode f' (Inl d) = fdstruct f' d
fdstruct f' (Et df ds) = fdfun f' df `mappend`
fdstruct f' ds
fdstruct _ Done = mempty
fdfun f' (Konst e) = f' e
fdfun f' (Ident e) = f' e
data Untyped = U
deriving (Read, Show)
data instance Sing (a :: Untyped) where
SU :: Sing 'U
instance SingI 'U where
sing = SU
instance Show (Sing (a :: Untyped)) where
showsPrec = showsPrec1
show = show1
instance Show1 (Sing :: Untyped -> *) where
showsPrec1 _ SU = ("SU" ++)
instance Eq (Sing (a :: Untyped)) where
SU == SU = True
instance Eq1 (Sing :: Untyped -> *) where
eq1 SU SU = True
instance JmEq1 (Sing :: Untyped -> *) where
jmEq1 SU SU = Just Refl
nameToVar :: Name -> Variable 'U
nameToVar (Name i h) = Variable h i SU
data Term :: ([Untyped] -> Untyped -> *) -> Untyped -> * where
Lam_ :: SSing -> abt '[ 'U ] 'U -> Term abt 'U
App_ :: abt '[] 'U -> abt '[] 'U -> Term abt 'U
Let_ :: abt '[] 'U -> abt '[ 'U ] 'U -> Term abt 'U
Ann_ :: SSing -> abt '[] 'U -> Term abt 'U
CoerceTo_ :: Some2 Coercion -> abt '[] 'U -> Term abt 'U
UnsafeTo_ :: Some2 Coercion -> abt '[] 'U -> Term abt 'U
PrimOp_ :: PrimOp -> [abt '[] 'U] -> Term abt 'U
ArrayOp_ :: ArrayOp -> [abt '[] 'U] -> Term abt 'U
MeasureOp_ :: SomeOp MeasureOp -> [abt '[] 'U] -> Term abt 'U
NaryOp_ :: NaryOp -> [abt '[] 'U] -> Term abt 'U
Literal_ :: Some1 Literal -> Term abt 'U
Empty_ :: Term abt 'U
Pair_ :: abt '[] 'U -> abt '[] 'U -> Term abt 'U
Array_ :: abt '[] 'U -> abt '[ 'U ] 'U -> Term abt 'U
Datum_ :: Datum abt -> Term abt 'U
Case_ :: abt '[] 'U -> [Branch_ abt] -> Term abt 'U
Dirac_ :: abt '[] 'U -> Term abt 'U
MBind_ :: abt '[] 'U -> abt '[ 'U ] 'U -> Term abt 'U
Plate_ :: abt '[] 'U -> abt '[ 'U ] 'U -> Term abt 'U
Chain_ :: abt '[] 'U -> abt '[] 'U -> abt '[ 'U ] 'U -> Term abt 'U
Integrate_ :: abt '[] 'U -> abt '[] 'U -> abt '[ 'U ] 'U -> Term abt 'U
Summate_ :: abt '[] 'U -> abt '[] 'U -> abt '[ 'U ] 'U -> Term abt 'U
Product_ :: abt '[] 'U -> abt '[] 'U -> abt '[ 'U ] 'U -> Term abt 'U
Expect_ :: abt '[] 'U -> abt '[ 'U ] 'U -> Term abt 'U
Observe_ :: abt '[] 'U -> abt '[] 'U -> Term abt 'U
Superpose_ :: L.NonEmpty (abt '[] 'U, abt '[] 'U) -> Term abt 'U
Reject_ :: Term abt 'U
instance Functor21 Term where
fmap21 f (Lam_ typ e1) = Lam_ typ (f e1)
fmap21 f (App_ e1 e2) = App_ (f e1) (f e2)
fmap21 f (Let_ e1 e2) = Let_ (f e1) (f e2)
fmap21 f (Ann_ typ e1) = Ann_ typ (f e1)
fmap21 f (CoerceTo_ c e1) = CoerceTo_ c (f e1)
fmap21 f (UnsafeTo_ c e1) = UnsafeTo_ c (f e1)
fmap21 f (PrimOp_ op es) = PrimOp_ op (fmap f es)
fmap21 f (ArrayOp_ op es) = ArrayOp_ op (fmap f es)
fmap21 f (MeasureOp_ op es) = MeasureOp_ op (fmap f es)
fmap21 f (NaryOp_ op es) = NaryOp_ op (fmap f es)
fmap21 _ (Literal_ v) = Literal_ v
fmap21 _ Empty_ = Empty_
fmap21 f (Pair_ e1 e2) = Pair_ (f e1) (f e2)
fmap21 f (Array_ e1 e2) = Array_ (f e1) (f e2)
fmap21 f (Datum_ d) = Datum_ (fmapDatum f d)
fmap21 f (Case_ e1 bs) = Case_ (f e1) (fmap (fmapBranch f) bs)
fmap21 f (Dirac_ e1) = Dirac_ (f e1)
fmap21 f (MBind_ e1 e2) = MBind_ (f e1) (f e2)
fmap21 f (Plate_ e1 e2) = Plate_ (f e1) (f e2)
fmap21 f (Chain_ e1 e2 e3) = Chain_ (f e1) (f e2) (f e3)
fmap21 f (Integrate_ e1 e2 e3) = Integrate_ (f e1) (f e2) (f e3)
fmap21 f (Summate_ e1 e2 e3) = Summate_ (f e1) (f e2) (f e3)
fmap21 f (Product_ e1 e2 e3) = Product_ (f e1) (f e2) (f e3)
fmap21 f (Expect_ e1 e2) = Expect_ (f e1) (f e2)
fmap21 f (Observe_ e1 e2) = Observe_ (f e1) (f e2)
fmap21 f (Superpose_ es) = Superpose_ (L.map (f *** f) es)
fmap21 _ Reject_ = Reject_
instance Foldable21 Term where
foldMap21 f (Lam_ _ e1) = f e1
foldMap21 f (App_ e1 e2) = f e1 `mappend` f e2
foldMap21 f (Let_ e1 e2) = f e1 `mappend` f e2
foldMap21 f (Ann_ _ e1) = f e1
foldMap21 f (CoerceTo_ _ e1) = f e1
foldMap21 f (UnsafeTo_ _ e1) = f e1
foldMap21 f (PrimOp_ _ es) = F.foldMap f es
foldMap21 f (ArrayOp_ _ es) = F.foldMap f es
foldMap21 f (MeasureOp_ _ es) = F.foldMap f es
foldMap21 f (NaryOp_ _ es) = F.foldMap f es
foldMap21 _ (Literal_ _) = mempty
foldMap21 _ Empty_ = mempty
foldMap21 f (Pair_ e1 e2) = f e1 `mappend` f e2
foldMap21 f (Array_ e1 e2) = f e1 `mappend` f e2
foldMap21 f (Datum_ d) = foldDatum f d
foldMap21 f (Case_ e1 bs) = f e1 `mappend` F.foldMap (foldBranch f) bs
foldMap21 f (Dirac_ e1) = f e1
foldMap21 f (MBind_ e1 e2) = f e1 `mappend` f e2
foldMap21 f (Plate_ e1 e2) = f e1 `mappend` f e2
foldMap21 f (Chain_ e1 e2 e3) = f e1 `mappend` f e2 `mappend` f e3
foldMap21 f (Integrate_ e1 e2 e3) = f e1 `mappend` f e2 `mappend` f e3
foldMap21 f (Summate_ e1 e2 e3) = f e1 `mappend` f e2 `mappend` f e3
foldMap21 f (Product_ e1 e2 e3) = f e1 `mappend` f e2 `mappend` f e3
foldMap21 f (Expect_ e1 e2) = f e1 `mappend` f e2
foldMap21 f (Observe_ e1 e2) = f e1 `mappend` f e2
foldMap21 f (Superpose_ es) = F.foldMap (\(e1,e2) -> f e1 `mappend` f e2) es
foldMap21 _ Reject_ = mempty
type AST = MetaABT SourceSpan Term '[] 'U
type MetaTerm = Term (MetaABT SourceSpan Term) 'U
type Branch = Branch_ (MetaABT SourceSpan Term)
type DFun_ = DFun (MetaABT SourceSpan Term)
type DStruct_ = DStruct (MetaABT SourceSpan Term)
type DCode_ = DCode (MetaABT SourceSpan Term)