{-# LANGUAGE CPP #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}
module Downhill.TH
(
mkHasGradInstances,
AffineSpaceOptions (..),
RecordNamer (..),
BVarOptions (..),
defaultBVarOptions,
)
where
import Control.Monad
import Data.AdditiveGroup ((^+^), (^-^))
import Data.AffineSpace (AffineSpace (Diff, (.+^), (.-.)))
import Data.Foldable (traverse_)
import qualified Data.Map as Map
import Data.Maybe (catMaybes)
import Data.VectorSpace (AdditiveGroup (negateV, zeroV), VectorSpace (Scalar, (*^)))
import Downhill.BVar (BVar (BVar))
import Downhill.Grad
( Dual (evalGrad),
HasGrad (Grad, MScalar, Metric, Tang),
MetricTensor (MtCovector, MtVector, evalMetric, sqrNorm),
)
import Downhill.Linear.Expr (BasicVector (VecBuilder, sumBuilder))
import Downhill.Linear.Lift (lift1_sparse)
import GHC.Records (HasField (getField))
import Language.Haskell.TH
( Bang (Bang),
Con (NormalC, RecC),
Cxt,
Dec (DataD, InstanceD, NewtypeD, SigD),
Exp (AppE, ConE, InfixE, VarE),
Name,
Pat (VarP),
Q,
SourceStrictness (NoSourceStrictness),
SourceUnpackedness (NoSourceUnpackedness),
Type (AppT, ConT, VarT),
nameBase,
newName,
)
import Language.Haskell.TH.Datatype (ConstructorInfo (constructorFields, constructorName, constructorVariant), ConstructorVariant (InfixConstructor, NormalConstructor, RecordConstructor), DatatypeInfo (datatypeCons, datatypeInstTypes, datatypeName, datatypeVariant, datatypeVars), DatatypeVariant (Newtype), TypeSubstitution (applySubstitution), reifyDatatype)
import Language.Haskell.TH.Datatype.TyVarBndr (TyVarBndrUnit)
import Language.Haskell.TH.Syntax
( BangType,
Body (NormalB),
Clause (Clause),
Dec (FunD, TySynInstD, ValD),
Exp (AppTypeE),
TyLit (StrTyLit),
TySynEqn (TySynEqn),
Type (ArrowT, EqualityT, LitT, SigT),
VarBangType,
mkNameS,
)
import qualified Language.Haskell.TH
-- | Field layout of a data constructor: either bare field types or
-- named record fields paired with their types.
data DatatypeFields
  = -- | Positional (non-record) fields.
    NormalFields [Type]
  | -- | Record fields as @(fieldName, fieldType)@ pairs.
    RecordFields [(String, Type)]
  deriving (Show)
-- | Description of a single-constructor data type, as gathered by
-- 'parseGradConstructor', that the instance generators work from.
data DownhillRecord = DownhillRecord
  { -- | Name of the type constructor.
    ddtTypeConName :: Name,
    -- | Name of the (single) data constructor.
    ddtDataConName :: Name,
    -- | Types of the constructor's fields.
    ddtFieldTypes :: [Type],
    -- | Unqualified field names for a record constructor; 'Nothing' for
    -- normal and infix constructors.
    ddtFieldNames :: Maybe [String],
    -- | Type variables of the data type declaration.
    ddtTypeVars :: [TyVarBndrUnit],
    -- | Number of constructor fields.
    ddtFieldCount :: Int,
    -- | Declaration variant as reported by th-abstraction's 'datatypeVariant'.
    ddtVariant :: DatatypeVariant
  }
  deriving (Show)
-- | Three 'String' transformations used to name a generated record. The
-- default namers in this module append a suffix (e.g. @\"Tang\"@) to the
-- type and data constructor names and leave field names unchanged.
data RecordNamer = RecordNamer
  { -- | Produces the type constructor name.
    typeConNamer :: String -> String,
    -- | Produces the data constructor name.
    dataConNamer :: String -> String,
    -- | Produces each field name.
    fieldNamer :: String -> String
  }
-- | A 'RecordNamer' together with a 'Type' transformation.
-- NOTE(review): the type function is presumably applied to each field type
-- when deriving a new record (see 'renameDownhillRecord') — confirm. The
-- misspelled name is kept because other code refers to it.
data RecordTranstorm
  = RecordTranstorm
      RecordNamer
      (Type -> Type)
-- | Choice about generating an 'AffineSpace' instance for the record.
-- NOTE(review): this option is interpreted outside this part of the module;
-- the constructor notes follow the constructor names — confirm there.
data AffineSpaceOptions
  = -- | Generate the instance.
    MakeAffineSpace
  | -- | Do not generate the instance.
    NoAffineSpace
  | -- | Let the generator decide; this is the value in 'defaultBVarOptions'.
    AutoAffineSpace
-- | Configuration for the generated records and instances
-- (presumably consumed by 'mkHasGradInstances' — confirm).
data BVarOptions = BVarOptions
  { -- | Naming scheme for the tangent record.
    optTangNamer :: RecordNamer,
    -- | Naming scheme for the gradient record.
    optGradNamer :: RecordNamer,
    -- | Naming scheme for the metric record.
    optMetricNamer :: RecordNamer,
    -- | Naming scheme for the builder record used as 'VecBuilder'.
    optBuilderNamer :: RecordNamer,
    -- | Whether to generate an 'AffineSpace' instance.
    optAffineSpace :: AffineSpaceOptions,
    -- | Names of fields to leave out (exact effect defined elsewhere — confirm).
    optExcludeFields :: [String]
  }
-- | Constructor pattern taking just a constructor name and sub-patterns,
-- independent of the template-haskell version. Since template-haskell 2.18
-- 'Language.Haskell.TH.ConP' carries an extra list (type applications); this
-- synonym always builds it empty, and when matching it only matches patterns
-- whose list is empty.
pattern ConP :: Name -> [Pat] -> Pat
#if MIN_VERSION_template_haskell(2,18,0)
pattern ConP conName subPats = Language.Haskell.TH.ConP conName [] subPats
#else
pattern ConP conName subPats = Language.Haskell.TH.ConP conName subPats
#endif
-- | A 'RecordNamer' that appends @suffix@ to both the type constructor and
-- the data constructor names and keeps field names unchanged. Shared by the
-- default namers below, which previously repeated this record four times.
suffixRecordNamer :: String -> RecordNamer
suffixRecordNamer suffix =
  RecordNamer
    { typeConNamer = (++ suffix),
      dataConNamer = (++ suffix),
      fieldNamer = id
    }

-- | Default namer for the tangent record: appends @\"Tang\"@.
defaultTangRecordNamer :: RecordNamer
defaultTangRecordNamer = suffixRecordNamer "Tang"

-- | Default namer for the gradient record: appends @\"Grad\"@.
defaultGradRecordNamer :: RecordNamer
defaultGradRecordNamer = suffixRecordNamer "Grad"

-- | Default namer for the metric record: appends @\"Metric\"@.
defaultMetricRecordNamer :: RecordNamer
defaultMetricRecordNamer = suffixRecordNamer "Metric"

-- | Default namer for the builder record: appends @\"Builder\"@.
defaultBuilderRecordNamer :: RecordNamer
defaultBuilderRecordNamer = suffixRecordNamer "Builder"
-- | Default options. Each generated record is named by appending a suffix
-- (@Tang@, @Grad@, @Metric@, @Builder@) to the source type and constructor
-- names, with field names unchanged. The 'AffineSpace' choice is
-- 'AutoAffineSpace', and no fields are excluded.
defaultBVarOptions :: BVarOptions
defaultBVarOptions =
  BVarOptions
    { optExcludeFields = [],
      optAffineSpace = AutoAffineSpace,
      optTangNamer = defaultTangRecordNamer,
      optGradNamer = defaultGradRecordNamer,
      optMetricNamer = defaultMetricRecordNamer,
      optBuilderNamer = defaultBuilderRecordNamer
    }
-- | TH constructor declaration for the record's data constructor.
--
-- It is a record constructor ('RecC') when field names are known, otherwise
-- a positional one ('NormalC'). Fields carry no strictness or UNPACK
-- annotations.
mkConstructor :: DownhillRecord -> Con
mkConstructor record = maybe positional named (ddtFieldNames record)
  where
    conName :: Name
    conName = ddtDataConName record

    fieldTypes :: [Type]
    fieldTypes = ddtFieldTypes record

    noBang :: Bang
    noBang = Bang NoSourceUnpackedness NoSourceStrictness

    positional :: Con
    positional = NormalC conName [(noBang, ty) | ty <- fieldTypes]

    named :: [String] -> Con
    named fieldNames = RecC conName (zipWith field fieldNames fieldTypes)

    field :: String -> Type -> VarBangType
    field fieldName ty = (mkNameS fieldName, noBang, ty)
-- | Describe constructor @cinfo@ of the data type @dinfo@ as a
-- 'DownhillRecord' whose type constructor name is @tyName@.
--
-- Field names (reduced with 'nameBase') are recorded only for record
-- constructors; normal and infix constructors get 'Nothing'.
parseGradConstructor :: Name -> DatatypeInfo -> ConstructorInfo -> [TyVarBndrUnit] -> Q DownhillRecord
parseGradConstructor tyName dinfo cinfo typevars =
  pure
    DownhillRecord
      { ddtTypeConName = tyName,
        ddtDataConName = constructorName cinfo,
        ddtTypeVars = typevars,
        ddtFieldCount = length fieldTypes,
        ddtFieldTypes = fieldTypes,
        ddtFieldNames = fieldNames,
        ddtVariant = datatypeVariant dinfo
      }
  where
    fieldTypes :: [Type]
    fieldTypes = constructorFields cinfo

    fieldNames :: Maybe [String]
    fieldNames = case constructorVariant cinfo of
      RecordConstructor names -> Just (map nameBase names)
      NormalConstructor -> Nothing
      InfixConstructor -> Nothing
-- | Require exactly one data constructor in @dinfo@ and describe it with
-- 'parseGradConstructor'.
--
-- Returns the resulting record together with the constructor's
-- 'ConstructorInfo'. Fails in 'Q' if there are no constructors or several;
-- the error message starts with @show recordName@.
parseDownhillRecord :: Name -> DatatypeInfo -> Q (DownhillRecord, ConstructorInfo)
parseDownhillRecord recordName dinfo =
  case datatypeCons dinfo of
    [onlyCon] -> do
      parsed <- parseGradConstructor (datatypeName dinfo) dinfo onlyCon (datatypeVars dinfo)
      pure (parsed, onlyCon)
    [] -> failWith " has no data constructors"
    _ -> failWith " has multiple data constructors"
  where
    failWith :: String -> Q a
    failWith problem = fail (show recordName <> problem)
-- | Field-wise binary function whose arguments and result all use the
-- record's own constructor:
--
-- > func (C x1 .. xn) (C y1 .. yn) = C (x1 `func` y1) .. (xn `func` yn)
elementwiseOp :: DownhillRecord -> Name -> Q Dec
elementwiseOp record = elementwiseOp' record record record
-- | Declaration of a binary function that combines two records field by
-- field:
--
-- > func (L x1 .. xn) (R y1 .. yn) = Res (x1 `func` y1) .. (xn `func` yn)
--
-- Here @L@, @R@ and @Res@ are the data constructors of @leftRecord@,
-- @rightRecord@ and @resRecord@. Fails in 'Q' when the three records do not
-- all have the same number of fields.
elementwiseOp' :: DownhillRecord -> DownhillRecord -> DownhillRecord -> Name -> Q Dec
elementwiseOp' leftRecord rightRecord resRecord func = do
  let n :: Int
      n = ddtFieldCount resRecord
  -- One arity n is used for all three constructors. A mismatch would otherwise
  -- only show up later as a confusing arity error in the spliced code.
  unless (all ((== n) . ddtFieldCount) [leftRecord, rightRecord]) $
    fail $
      "elementwiseOp': field count mismatch between "
        <> unwords (map (show . ddtTypeConName) [leftRecord, rightRecord, resRecord])
  xs <- replicateM n (newName "x")
  ys <- replicateM n (newName "y")
  let fieldOp :: Name -> Name -> Exp
      fieldOp x y = InfixE (Just (VarE x)) (VarE func) (Just (VarE y))
      leftPat, rightPat :: Pat
      leftPat = ConP (ddtDataConName leftRecord) (map VarP xs)
      rightPat = ConP (ddtDataConName rightRecord) (map VarP ys)
      rhs :: Exp
      rhs = foldl AppE (ConE (ddtDataConName resRecord)) (zipWith fieldOp xs ys)
  return (FunD func [Clause [leftPat, rightPat] (NormalB rhs) []])
-- | Value declaration @func = C zeroV .. zeroV@, with one 'zeroV' per field
-- of the record's constructor. The fields are always 'zeroV', whatever name
-- @func@ is.
elementwiseValue :: DownhillRecord -> Name -> Q Dec
elementwiseValue record func =
  pure (ValD (VarP func) (NormalB body) [])
  where
    zeros :: [Exp]
    zeros = replicate (ddtFieldCount record) (VarE 'zeroV)

    body :: Exp
    body = foldl AppE (ConE (ddtDataConName record)) zeros
-- | Unary function that applies @func@ to every field:
--
-- > func (C x1 .. xn) = C (func x1) .. (func xn)
--
-- Pattern variables are fresh names derived from the field names for record
-- constructors, and fresh @x@ names otherwise.
elementwiseFunc :: DownhillRecord -> Name -> Q Dec
elementwiseFunc record func = do
  vars <-
    maybe
      (replicateM (ddtFieldCount record) (newName "x"))
      (mapM newName)
      (ddtFieldNames record)
  let con :: Name
      con = ddtDataConName record
      argPat :: Pat
      argPat = ConP con (VarP <$> vars)
      body :: Exp
      body = foldl AppE (ConE con) [AppE (VarE func) (VarE v) | v <- vars]
  pure (FunD func [Clause [argPat] (NormalB body) []])
-- | Wrap @decs@ in an instance declaration
-- @instance cxt => className (T v1 .. vk)@. @T@ is the record's type
-- constructor and @v1 .. vk@ are @instVars@.
mkClassInstance :: Name -> Cxt -> DownhillRecord -> [Type] -> [Dec] -> Q [Dec]
mkClassInstance className cxt record instVars decs =
  pure [InstanceD Nothing cxt instanceHead decs]
  where
    appliedType :: Type
    appliedType = foldl AppT (ConT (ddtTypeConName record)) instVars

    instanceHead :: Type
    instanceHead = ConT className `AppT` appliedType
-- | @instance cxt => Semigroup (T ..)@ with '<>' applied field by field.
mkSemigroupInstance :: Cxt -> DownhillRecord -> [Type] -> Q [Dec]
mkSemigroupInstance cxt record instVars =
  elementwiseOp record '(<>) >>= \appendDec ->
    mkClassInstance ''Semigroup cxt record instVars [appendDec]
-- | @instance cxt => AdditiveGroup (T ..)@ in which 'zeroV', 'negateV',
-- '^+^' and '^-^' all act field by field.
mkAdditiveGroupInstance :: Cxt -> DownhillRecord -> [Type] -> Q [Dec]
mkAdditiveGroupInstance cxt record instVars = do
  methodDecs <-
    sequence
      [ elementwiseValue record 'zeroV,
        elementwiseFunc record 'negateV,
        elementwiseOp record '(^+^),
        elementwiseOp record '(^-^)
      ]
  mkClassInstance ''AdditiveGroup cxt record instVars methodDecs
-- | @instance cxt => VectorSpace (T ..)@ with
-- @type Scalar (T ..) = scalarType@ and field-wise scaling:
--
-- > s *^ C x1 .. xn = C (s *^ x1) .. (s *^ xn)
mkVectorSpaceInstance :: DownhillRecord -> Type -> Cxt -> [Type] -> Q [Dec]
mkVectorSpaceInstance record scalarType cxt instVars = do
  fieldVars <- case ddtFieldNames record of
    Nothing -> replicateM (ddtFieldCount record) (newName "x")
    Just names -> traverse newName names
  scalar <- newName "s"
  let con :: Name
      con = ddtDataConName record
      recordType :: Type
      recordType = foldl AppT (ConT (ddtTypeConName record)) instVars
      scale :: Name -> Exp
      scale v = InfixE (Just (VarE scalar)) (VarE '(*^)) (Just (VarE v))
      scaledRecord :: Exp
      scaledRecord = foldl AppE (ConE con) (map scale fieldVars)
      scaleDec :: Dec
      scaleDec =
        FunD
          '(*^)
          [Clause [VarP scalar, ConP con (map VarP fieldVars)] (NormalB scaledRecord) []]
      scalarDec :: Dec
      scalarDec =
        TySynInstD (TySynEqn Nothing (ConT ''Scalar `AppT` recordType) scalarType)
  mkClassInstance ''VectorSpace cxt record instVars [scalarDec, scaleDec]
-- | @instance cxt => BasicVector (T ..)@, where @T@ is @vectorRecord@'s type.
--
-- * @VecBuilder (T ..) = Maybe (B ..)@, where @B@ is the builder record
--   obtained by renaming @vectorRecord@ with the options' builder transform.
-- * @sumBuilder Nothing = zeroV@ and
--   @sumBuilder (Just (B b1 .. bn)) = T (sumBuilder b1) .. (sumBuilder bn)@.
mkBasicVectorInstance :: DownhillRecord -> BVarOptions -> Cxt -> [Type] -> Q [Dec]
mkBasicVectorInstance vectorRecord options cxt instVars = do
  sumBuilderDec <- mkSumBuilder
  mkClassInstance ''BasicVector cxt vectorRecord instVars [vecBuilderDec, sumBuilderDec]
  where
    builderRecord :: DownhillRecord
    builderRecord = renameDownhillRecord (builderTransform options) vectorRecord

    -- Type constructor of the given record applied to the instance variables.
    appliedTo :: DownhillRecord -> Type
    appliedTo r = foldl AppT (ConT (ddtTypeConName r)) instVars

    vecBuilderDec :: Dec
    vecBuilderDec =
      TySynInstD $
        TySynEqn
          Nothing
          (ConT ''VecBuilder `AppT` appliedTo vectorRecord)
          (ConT ''Maybe `AppT` appliedTo builderRecord)

    mkSumBuilder :: Q Dec
    mkSumBuilder = do
      parts <- replicateM (ddtFieldCount vectorRecord) (newName "x")
      let builderPat :: Pat
          builderPat = ConP (ddtDataConName builderRecord) (map VarP parts)
          summed :: Exp
          summed =
            foldl
              AppE
              (ConE (ddtDataConName vectorRecord))
              [VarE 'sumBuilder `AppE` VarE p | p <- parts]
      pure $
        FunD
          'sumBuilder
          [ Clause [ConP 'Nothing []] (NormalB (VarE 'zeroV)) [],
            Clause [ConP 'Just [builderPat]] (NormalB summed) []
          ]
-- | Combine expressions with '^+^', associating to the left
-- (@((e1 ^+^ e2) ^+^ e3) ..@). An empty list gives 'zeroV'.
sumVExpr :: [Exp] -> Exp
sumVExpr [] = VarE 'zeroV
sumVExpr (first : rest) = foldl plus first rest
  where
    plus :: Exp -> Exp -> Exp
    plus lhs rhs = InfixE (Just lhs) (VarE '(^+^)) (Just rhs)
-- | Builds an instance
-- @instance (cxt, AdditiveGroup s, s ~ scalarType) => Dual s (T ..) (G ..)@,
-- where @T@ is @tangRecord@'s type and @G@ is @gradRecord@'s type, and @s@ is
-- a fresh type variable.
--
-- The generated method is
--
-- > evalGrad (G x1 .. xn) (T y1 .. yn) = evalGrad x1 y1 ^+^ .. ^+^ evalGrad xn yn
--
-- built with 'sumVExpr', so a record with no fields yields 'zeroV'. Fails in
-- 'Q' when the two records have different field counts.
mkDualInstance ::
  DownhillRecord ->
  DownhillRecord ->
  Type ->
  Cxt ->
  [Type] ->
  Q [Dec]
mkDualInstance tangRecord gradRecord scalarType cxt instVars = do
  when (ddtFieldCount tangRecord /= ddtFieldCount gradRecord) $
    fail "mkDualInstance: ddtFieldCount tangRecord /= ddtFieldCount gradRecord"
  s <- VarT <$> newName "s"
  evalGradDec <- mkEvalGradDec
  let applied :: DownhillRecord -> Type
      applied r = foldl AppT (ConT (ddtTypeConName r)) instVars
      instanceHead :: Type
      instanceHead = ConT ''Dual `AppT` s `AppT` applied tangRecord `AppT` applied gradRecord
      extraCxt :: Cxt
      extraCxt =
        [ ConT ''AdditiveGroup `AppT` s,
          EqualityT `AppT` s `AppT` scalarType
        ]
  pure [InstanceD Nothing (cxt ++ extraCxt) instanceHead [evalGradDec]]
  where
    n :: Int
    n = ddtFieldCount tangRecord

    mkEvalGradDec :: Q Dec
    mkEvalGradDec = do
      gradVars <- replicateM n (newName "x")
      tangVars <- replicateM n (newName "y")
      let gradPat, tangPat :: Pat
          gradPat = ConP (ddtDataConName gradRecord) (VarP <$> gradVars)
          tangPat = ConP (ddtDataConName tangRecord) (VarP <$> tangVars)
          pairTerm :: Name -> Name -> Exp
          pairTerm g t = VarE 'evalGrad `AppE` VarE g `AppE` VarE t
          body :: Exp
          body = sumVExpr (zipWith pairTerm gradVars tangVars)
      pure (FunD 'evalGrad [Clause [gradPat, tangPat] (NormalB body) []])
mkMetricInstance ::
DownhillRecord ->
DownhillRecord ->
DownhillRecord ->
Type ->
Cxt ->
[Type] ->
Q [Dec]
mkMetricInstance :: DownhillRecord
-> DownhillRecord
-> DownhillRecord
-> Type
-> [Type]
-> [Type]
-> Q [Dec]
mkMetricInstance DownhillRecord
metricRecord DownhillRecord
tangRecord DownhillRecord
gradRecord Type
scalarType [Type]
cxt [Type]
instVars = do
Name
scalarTypeName <- String -> Q Name
newName String
"s"
Type -> Q [Dec]
mkClassDec (Name -> Type
VarT Name
scalarTypeName)
where
mkClassDec :: Type -> Q [Dec]
mkClassDec :: Type -> Q [Dec]
mkClassDec Type
scalarVar = do
let newConstraints :: [Type]
newConstraints =
[
Type -> Type -> Type
AppT (Type -> Type -> Type
AppT Type
EqualityT Type
scalarVar) Type
scalarType
]
ihead :: Type
ihead = Name -> Type
ConT ''MetricTensor Type -> Type -> Type
`AppT` Type
metricType
Dec
evalMetricDec <- Q Dec
mkEvalMetric
Dec
sqrNormDec <- Q Dec
mkSqrNorm
[Dec] -> Q [Dec]
forall (m :: * -> *) a. Monad m => a -> m a
return
[ Maybe Overlap -> [Type] -> Type -> [Dec] -> Dec
InstanceD
Maybe Overlap
forall a. Maybe a
Nothing
([Type]
cxt [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
newConstraints)
Type
ihead
[Dec
vectypeDec, Dec
covectorTypeDec, Dec
evalMetricDec, Dec
sqrNormDec]
]
where
vectorType :: Type
vectorType :: Type
vectorType = (Type -> Type -> Type) -> Type -> [Type] -> Type
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Type -> Type -> Type
AppT (Name -> Type
ConT (Name -> Type) -> Name -> Type
forall a b. (a -> b) -> a -> b
$ DownhillRecord -> Name
ddtTypeConName DownhillRecord
tangRecord) [Type]
instVars
covectorType :: Type
covectorType :: Type
covectorType = (Type -> Type -> Type) -> Type -> [Type] -> Type
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Type -> Type -> Type
AppT (Name -> Type
ConT (Name -> Type) -> Name -> Type
forall a b. (a -> b) -> a -> b
$ DownhillRecord -> Name
ddtTypeConName DownhillRecord
gradRecord) [Type]
instVars
metricType :: Type
metricType :: Type
metricType = (Type -> Type -> Type) -> Type -> [Type] -> Type
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Type -> Type -> Type
AppT (Name -> Type
ConT (Name -> Type) -> Name -> Type
forall a b. (a -> b) -> a -> b
$ DownhillRecord -> Name
ddtTypeConName DownhillRecord
metricRecord) [Type]
instVars
vectypeDec :: Dec
vectypeDec =
TySynEqn -> Dec
TySynInstD
( Maybe [TyVarBndrUnit] -> Type -> Type -> TySynEqn
TySynEqn
Maybe [TyVarBndrUnit]
forall a. Maybe a
Nothing
(Type -> Type -> Type
AppT (Name -> Type
ConT ''MtVector) Type
metricType)
Type
vectorType
)
covectorTypeDec :: Dec
covectorTypeDec =
TySynEqn -> Dec
TySynInstD
( Maybe [TyVarBndrUnit] -> Type -> Type -> TySynEqn
TySynEqn
Maybe [TyVarBndrUnit]
forall a. Maybe a
Nothing
(Type -> Type -> Type
AppT (Name -> Type
ConT ''MtCovector) Type
metricType)
Type
covectorType
)
mkEvalMetric :: Q Dec
mkEvalMetric :: Q Dec
mkEvalMetric = do
let n :: Int
n = DownhillRecord -> Int
ddtFieldCount DownhillRecord
metricRecord
[Name]
xs <- Int -> Q Name -> Q [Name]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
n (String -> Q Name
newName String
"m")
[Name]
ys <- Int -> Q Name -> Q [Name]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
n (String -> Q Name
newName String
"dv")
let leftPat, rightPat :: Pat
leftPat :: Pat
leftPat = Name -> [Pat] -> Pat
ConP (DownhillRecord -> Name
ddtDataConName DownhillRecord
metricRecord) ((Name -> Pat) -> [Name] -> [Pat]
forall a b. (a -> b) -> [a] -> [b]
map Name -> Pat
VarP [Name]
xs)
rightPat :: Pat
rightPat = Name -> [Pat] -> Pat
ConP (DownhillRecord -> Name
ddtDataConName DownhillRecord
gradRecord) ((Name -> Pat) -> [Name] -> [Pat]
forall a b. (a -> b) -> [a] -> [b]
map Name -> Pat
VarP [Name]
ys)
terms :: [Exp]
terms :: [Exp]
terms = (Name -> Name -> Exp) -> [Name] -> [Name] -> [Exp]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Name -> Name -> Exp
evalGradExp [Name]
xs [Name]
ys
where
evalGradExp :: Name -> Name -> Exp
evalGradExp :: Name -> Name -> Exp
evalGradExp Name
x Name
y = Name -> Exp
VarE 'evalMetric Exp -> Exp -> Exp
`AppE` Name -> Exp
VarE Name
x Exp -> Exp -> Exp
`AppE` Name -> Exp
VarE Name
y
rhs :: Exp
rhs =
(Exp -> Exp -> Exp) -> Exp -> [Exp] -> Exp
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
Exp -> Exp -> Exp
AppE
(Name -> Exp
ConE (DownhillRecord -> Name
ddtDataConName DownhillRecord
tangRecord))
[Exp]
terms
Dec -> Q Dec
forall (m :: * -> *) a. Monad m => a -> m a
return (Dec -> Q Dec) -> Dec -> Q Dec
forall a b. (a -> b) -> a -> b
$
Name -> [Clause] -> Dec
FunD
'evalMetric
[ [Pat] -> Body -> [Dec] -> Clause
Clause
[Pat
leftPat, Pat
rightPat]
(Exp -> Body
NormalB Exp
rhs)
[]
]
mkSqrNorm :: Q Dec
mkSqrNorm :: Q Dec
mkSqrNorm = do
let n :: Int
n = DownhillRecord -> Int
ddtFieldCount DownhillRecord
metricRecord
[Name]
xs <- Int -> Q Name -> Q [Name]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
n (String -> Q Name
newName String
"m")
[Name]
ys <- Int -> Q Name -> Q [Name]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
n (String -> Q Name
newName String
"dv")
let leftPat, rightPat :: Pat
leftPat :: Pat
leftPat = Name -> [Pat] -> Pat
ConP (DownhillRecord -> Name
ddtDataConName DownhillRecord
metricRecord) ((Name -> Pat) -> [Name] -> [Pat]
forall a b. (a -> b) -> [a] -> [b]
map Name -> Pat
VarP [Name]
xs)
rightPat :: Pat
rightPat = Name -> [Pat] -> Pat
ConP (DownhillRecord -> Name
ddtDataConName DownhillRecord
gradRecord) ((Name -> Pat) -> [Name] -> [Pat]
forall a b. (a -> b) -> [a] -> [b]
map Name -> Pat
VarP [Name]
ys)
terms :: [Exp]
terms :: [Exp]
terms = (Name -> Name -> Exp) -> [Name] -> [Name] -> [Exp]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Name -> Name -> Exp
evalSqrtNorm [Name]
xs [Name]
ys
where
evalSqrtNorm :: Name -> Name -> Exp
evalSqrtNorm :: Name -> Name -> Exp
evalSqrtNorm Name
x Name
y = Name -> Exp
VarE 'sqrNorm Exp -> Exp -> Exp
`AppE` Name -> Exp
VarE Name
x Exp -> Exp -> Exp
`AppE` Name -> Exp
VarE Name
y
rhs :: Exp
rhs = [Exp] -> Exp
sumVExpr [Exp]
terms
Dec -> Q Dec
forall (m :: * -> *) a. Monad m => a -> m a
return (Dec -> Q Dec) -> Dec -> Q Dec
forall a b. (a -> b) -> a -> b
$
Name -> [Clause] -> Dec
FunD
'sqrNorm
[ [Pat] -> Body -> [Dec] -> Clause
Clause
[Pat
leftPat, Pat
rightPat]
(Exp -> Body
NormalB Exp
rhs)
[]
]
-- | Declare the record type described by a 'DownhillRecord', using the single
-- constructor produced by 'mkConstructor'.
--
-- The declaration is a @newtype@ when the record's variant is 'Newtype' and a
-- @data@ type otherwise. It has an empty context, no kind signature and no
-- deriving clauses.
mkRecord :: DownhillRecord -> Q [Dec]
mkRecord record = pure [declaration]
  where
    typeName :: Name
    typeName = ddtTypeConName record
    constructor :: Con
    constructor = mkConstructor record
    declaration :: Dec
    declaration
      | Newtype <- ddtVariant record =
          NewtypeD [] typeName (ddtTypeVars record) Nothing constructor []
      | otherwise =
          DataD [] typeName (ddtTypeVars record) Nothing [constructor] []
-- | Apply a string transformation to the unqualified part of a 'Name'
-- (as returned by 'nameBase') and turn the result back into a 'Name'
-- with 'mkNameS'.
renameTypeS :: (String -> String) -> Name -> Name
renameTypeS f name = mkNameS (f (nameBase name))
-- | One record field for which 'mkGetField' generates a 'HasField'
-- instance on 'BVar'.
data FieldInfo = FieldInfo
  { -- | Field label; used as the type-level symbol of the 'HasField' instance.
    fiName :: String,
    -- | Zero-based position of the field among the record's kept fields.
    fiIndex :: Int,
    -- | Field type with the instance's type arguments substituted in.
    fiType :: Type
  }
-- | Generate a 'HasField' instance that projects one field out of a
-- @'BVar' r Point@, producing a @'BVar' r FieldType@.
--
-- For a field @f :: F@ at index @i@ of an @n@-field record, the generated
-- code is, schematically:
--
-- > instance cxt => HasField "f" (BVar r (Point v1 .. vk)) (BVar r F) where
-- >   getField (BVar x dx) = BVar (getField @"f" x) (lift1_sparse go dx)
-- >     where
-- >       go :: VecBuilder (Grad F) -> Maybe (Builder v1 .. vk)
-- >       go dx_da = Just (Builder mempty .. dx_da .. mempty)
--
-- Here @Builder@ is the gradient-builder record. @dx_da@ sits at position
-- @i@, and every other position is 'mempty'.
mkGetField ::
  -- | Point record. Its type constructor forms the instance head, and its
  -- field count sizes the builder expression.
  DownhillRecord ->
  -- | Gradient-builder record whose data constructor is applied in @go@.
  DownhillRecord ->
  -- | Context of the generated instance.
  Cxt ->
  -- | Type arguments applied to both record type constructors.
  [Type] ->
  -- | The field being exposed.
  FieldInfo ->
  Q [Dec]
mkGetField pointRecord gradBuilderRecord cxt instVars field = do
  r <- newName "r"
  x <- newName "x"
  dx <- newName "dx"
  go <- newName "go"
  dxDa <- newName "dx_da"
  let label :: Type
      label = LitT (StrTyLit (fiName field))
      bvarOf :: Type -> Type
      bvarOf t = ConT ''BVar `AppT` VarT r `AppT` t
      instanceHead :: Type
      instanceHead =
        ConT ''HasField `AppT` label `AppT` bvarOf pointType `AppT` bvarOf (fiType field)
      -- Only this field's slot receives the incoming builder; all other
      -- slots are 'mempty'.
      builderArgs :: [Exp]
      builderArgs =
        concat
          [ replicate (fiIndex field) (VarE 'mempty),
            [VarE dxDa],
            replicate (fieldCount - fiIndex field - 1) (VarE 'mempty)
          ]
      builderExp :: Exp
      builderExp = foldl AppE (ConE (ddtDataConName gradBuilderRecord)) builderArgs
      goSig :: Dec
      goSig =
        SigD go $
          ArrowT
            `AppT` (ConT ''VecBuilder `AppT` (ConT ''Grad `AppT` fiType field))
            `AppT` (ConT ''Maybe `AppT` gradBuilderType)
      goDef :: Dec
      goDef = FunD go [Clause [VarP dxDa] (NormalB (ConE 'Just `AppE` builderExp)) []]
      getFieldBody :: Exp
      getFieldBody =
        ConE 'BVar
          `AppE` (AppTypeE (VarE 'getField) label `AppE` VarE x)
          `AppE` (VarE 'lift1_sparse `AppE` VarE go `AppE` VarE dx)
      getFieldDec :: Dec
      getFieldDec =
        FunD
          'getField
          [Clause [ConP 'BVar [VarP x, VarP dx]] (NormalB getFieldBody) [goSig, goDef]]
  pure [InstanceD Nothing cxt instanceHead [getFieldDec]]
  where
    fieldCount :: Int
    fieldCount = ddtFieldCount pointRecord
    applyVars :: Type -> Type
    applyVars t = foldl AppT t instVars
    pointType :: Type
    pointType = applyVars (ConT (ddtTypeConName pointRecord))
    gradBuilderType :: Type
    gradBuilderType = applyVars (ConT (ddtTypeConName gradBuilderRecord))
-- | Describe a companion record derived from @record@.
--
-- The type constructor name, data constructor name and field labels are
-- rewritten with the transform's 'RecordNamer'. Each field type is passed
-- through the transform's type function. Type variables, field count and
-- datatype variant are carried over unchanged.
renameDownhillRecord :: RecordTranstorm -> DownhillRecord -> DownhillRecord
renameDownhillRecord (RecordTranstorm namer typeFun) record =
  record
    { ddtTypeConName = rename typeConNamer (ddtTypeConName record),
      ddtDataConName = rename dataConNamer (ddtDataConName record),
      ddtFieldTypes = map typeFun (ddtFieldTypes record),
      ddtFieldNames = map (fieldNamer namer) <$> ddtFieldNames record
    }
  where
    rename :: (RecordNamer -> String -> String) -> Name -> Name
    rename pick = renameTypeS (pick namer)
-- | Transform for gradient-builder records. Names come from
-- 'optBuilderNamer', and each field type @t@ becomes @VecBuilder t@.
builderTransform :: BVarOptions -> RecordTranstorm
builderTransform options = RecordTranstorm namer wrapField
  where
    namer = optBuilderNamer options
    wrapField fieldType = ConT ''VecBuilder `AppT` fieldType
-- | Transform for tangent records. Names come from 'optTangNamer', and each
-- field type @t@ becomes @Tang t@.
tangTransform :: BVarOptions -> RecordTranstorm
tangTransform options = RecordTranstorm namer wrapField
  where
    namer = optTangNamer options
    wrapField fieldType = ConT ''Tang `AppT` fieldType
-- | Transform for gradient records. Names come from 'optGradNamer', and each
-- field type @t@ becomes @Grad t@.
gradTransform :: BVarOptions -> RecordTranstorm
gradTransform options = RecordTranstorm namer wrapField
  where
    namer = optGradNamer options
    wrapField fieldType = ConT ''Grad `AppT` fieldType
-- | Transform for metric records. Names come from 'optMetricNamer', and each
-- field type @t@ becomes @Metric t@.
metricTransform :: BVarOptions -> RecordTranstorm
metricTransform options = RecordTranstorm namer wrapField
  where
    namer = optMetricNamer options
    wrapField fieldType = ConT ''Metric `AppT` fieldType
-- | Generate a vector-like record together with its supporting instances.
--
-- The output contains, in order:
--
-- 1. the record type itself;
-- 2. its 'VecBuilder' companion record;
-- 3. the 'BasicVector' instance;
-- 4. the 'Semigroup' instance of the builder;
-- 5. the 'AdditiveGroup' instance;
-- 6. the 'VectorSpace' instance over @scalarType@.
mkVec :: Cxt -> [Type] -> Type -> DownhillRecord -> BVarOptions -> Q [Dec]
mkVec cxt instVars scalarType vectorRecord options = do
  recordDecs <- traverse mkRecord [vectorRecord, builderRecord]
  semigroupDecs <- mkSemigroupInstance cxt builderRecord instVars
  basicVectorDecs <- mkBasicVectorInstance vectorRecord options cxt instVars
  additiveDecs <- mkAdditiveGroupInstance cxt vectorRecord instVars
  vectorSpaceDecs <- mkVectorSpaceInstance vectorRecord scalarType cxt instVars
  pure
    ( concat recordDecs
        ++ basicVectorDecs
        ++ semigroupDecs
        ++ additiveDecs
        ++ vectorSpaceDecs
    )
  where
    builderRecord :: DownhillRecord
    builderRecord = renameDownhillRecord (builderTransform options) vectorRecord
-- | Generate everything needed to differentiate through a record type:
--
-- * the tangent and gradient records and their instances (via 'mkVec');
-- * the metric record with its 'AdditiveGroup' and 'VectorSpace' instances;
-- * the 'Dual' and metric instances;
-- * one 'HasField' instance on 'BVar' per named (non-excluded) field;
-- * optionally an 'AffineSpace' instance for the point record.
mkDVar'' ::
  -- | Instance context passed to the instance generators.
  Cxt ->
  -- | The point record, already restricted by 'filterFields'.
  DownhillRecord ->
  BVarOptions ->
  -- | Scalar type taken from the user's @MScalar@ equation.
  Type ->
  -- | Type arguments of the instance head.
  [Type] ->
  -- | Constructor of the point type, with type arguments substituted.
  -- It still describes /all/ fields, including excluded ones.
  ConstructorInfo ->
  Q [Dec]
mkDVar'' cxt pointRecord options scalarType instVars substitutedCInfo = do
  let tangRecord, gradRecord, metricRecord :: DownhillRecord
      tangRecord = renameDownhillRecord (tangTransform options) pointRecord
      gradRecord = renameDownhillRecord (gradTransform options) pointRecord
      metricRecord = renameDownhillRecord (metricTransform options) pointRecord
  tangDecs <- mkVec cxt instVars scalarType tangRecord options
  gradDecs <- mkVec cxt instVars scalarType gradRecord options
  metricDec <- mkRecord metricRecord
  additiveMetric <- mkAdditiveGroupInstance cxt metricRecord instVars
  vspaceMetric <- mkVectorSpaceInstance metricRecord scalarType cxt instVars
  dualInstance <- mkDualInstance tangRecord gradRecord scalarType cxt instVars
  metricInstance <- mkMetricInstance metricRecord tangRecord gradRecord scalarType cxt instVars
  let needAffineSpace :: Bool
      needAffineSpace = case optAffineSpace options of
        MakeAffineSpace -> True
        NoAffineSpace -> False
        -- Automatic mode only generates the instance when no field is excluded.
        AutoAffineSpace -> null (optExcludeFields options)
  affineSpaceInstance <-
    if needAffineSpace
      then mkAffineSpaceInstance cxt pointRecord tangRecord instVars
      else return []
  hasFieldInstance <- case ddtFieldNames pointRecord of
    Nothing -> return []
    Just names -> do
      let gradBuilderRecord :: DownhillRecord
          gradBuilderRecord = renameDownhillRecord (builderTransform options) gradRecord
          fields :: [FieldInfo]
          fields = zipWith3 (\index name -> FieldInfo name index) [0 ..] names keptFieldTypes
      concat <$> traverse (mkGetField pointRecord gradBuilderRecord cxt instVars) fields
  return $
    concat
      [ tangDecs,
        gradDecs,
        additiveMetric,
        vspaceMetric,
        dualInstance,
        metricDec,
        metricInstance,
        hasFieldInstance,
        affineSpaceInstance
      ]
  where
    -- 'pointRecord' has already been through 'filterFields', but
    -- 'substitutedCInfo' still lists every constructor field. Zipping the
    -- kept names positionally with the full type list would pair every field
    -- after an excluded one with the wrong type. So drop the excluded
    -- fields' types first. This uses the same membership test as
    -- 'filterFields'. NOTE(review): it assumes 'ddtFieldNames' holds the
    -- 'nameBase' of the constructor's record field names (set by
    -- parseDownhillRecord, not visible here).
    keptFieldTypes :: [Type]
    keptFieldTypes =
      case constructorVariant substitutedCInfo of
        RecordConstructor allFieldNames ->
          [ fieldType
            | (fieldName, fieldType) <- zip allFieldNames allFieldTypes,
              nameBase fieldName `notElem` optExcludeFields options
          ]
        -- Without record field names there is nothing to exclude by.
        _ -> allFieldTypes
    allFieldTypes :: [Type]
    allFieldTypes = constructorFields substitutedCInfo
-- | Split a type application @T a1 ... an@ into the constructor name @T@ and
-- its arguments in order.
--
-- The second argument is an accumulator that is prepended to as
-- applications are peeled off (callers start with @[]@). Any type that does
-- not bottom out in a 'ConT' is rejected with 'fail'.
parseRecordType :: Type -> [Type] -> Q (Name, [Type])
parseRecordType (AppT fun arg) acc = parseRecordType fun (arg : acc)
parseRecordType (ConT tyCon) acc = pure (tyCon, acc)
parseRecordType _ _ = fail "Expected (T a1 ... an) in constraint"
-- | Generate an 'AffineSpace' instance for the point record, with the
-- tangent record (applied to the same type arguments) as its 'Diff' type.
-- '(.+^)' and '(.-.)' are produced by 'elementwiseOp''.
mkAffineSpaceInstance :: Cxt -> DownhillRecord -> DownhillRecord -> [Type] -> Q [Dec]
mkAffineSpaceInstance cxt recordPoint recordTang instVars = do
  -- (.+^) :: point -> Diff point -> point
  plusDec <- elementwiseOp' recordPoint recordTang recordPoint '(.+^)
  -- (.-.) :: point -> point -> Diff point
  minusDec <- elementwiseOp' recordPoint recordPoint recordTang '(.-.)
  mkClassInstance ''AffineSpace cxt recordPoint instVars [plusDec, minusDec, diffTypeDec]
  where
    applied :: DownhillRecord -> Type
    applied r = foldl AppT (ConT (ddtTypeConName r)) instVars
    -- type Diff (Point v1 .. vn) = Tang v1 .. vn
    diffTypeDec :: Dec
    diffTypeDec =
      TySynInstD
        (TySynEqn Nothing (ConT ''Diff `AppT` applied recordPoint) (applied recordTang))
-- | Remove the fields listed in 'optExcludeFields' from a record description.
--
-- With no exclusions, the record is returned unchanged. Otherwise the record
-- must have field names, and every excluded name must be one of them. Either
-- violation is reported with 'fail'. Field types, field names and the field
-- count are filtered consistently.
filterFields :: forall m. MonadFail m => BVarOptions -> DownhillRecord -> m DownhillRecord
filterFields options record
  | null excluded = return record
  | otherwise = do
      fieldList <- maybe notARecord return (ddtFieldNames record)
      traverse_ (checkMember fieldList) excluded
      return (dropExcluded fieldList)
  where
    excluded :: [String]
    excluded = optExcludeFields options

    typeName :: String
    typeName = nameBase (ddtTypeConName record)

    notARecord :: m [String]
    notARecord = fail (typeName ++ " is not a records, can't exclude fields")

    checkMember :: [String] -> String -> m ()
    checkMember fieldList name
      | name `elem` fieldList = return ()
      | otherwise = fail ("Field " ++ name ++ " is not a member of " ++ typeName)

    dropExcluded :: [String] -> DownhillRecord
    dropExcluded fieldList =
      record
        { ddtFieldTypes = keep (ddtFieldTypes record),
          ddtFieldNames = keep <$> ddtFieldNames record,
          ddtFieldCount = length (keep (replicate (ddtFieldCount record) ()))
        }
      where
        -- Walk the list in lockstep with the field names and keep only the
        -- positions whose name is not excluded.
        keep :: [a] -> [a]
        keep = catMaybes . zipWith pick fieldList
        pick :: String -> a -> Maybe a
        pick fieldName x
          | fieldName `elem` excluded = Nothing
          | otherwise = Just x
mkDVarC1 :: BVarOptions -> Dec -> Q [Dec]
mkDVarC1 :: BVarOptions -> Dec -> Q [Dec]
mkDVarC1 BVarOptions
options = \case
InstanceD Maybe Overlap
mayOverlap [Type]
cxt Type
type_ [Dec]
decs -> do
case Maybe Overlap
mayOverlap of
Just Overlap
_ -> String -> Q ()
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Overlapping instances not implemented"
Maybe Overlap
_ -> () -> Q ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
case Type
type_ of
AppT (ConT Name
hasgradCtx) Type
recordInConstraintType -> do
Bool -> Q () -> Q ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Name
hasgradCtx Name -> Name -> Bool
forall a. Eq a => a -> a -> Bool
/= ''HasGrad) (Q () -> Q ()) -> Q () -> Q ()
forall a b. (a -> b) -> a -> b
$
String -> Q ()
forall (m :: * -> *) a. MonadFail m => String -> m a
fail (String -> Q ()) -> String -> Q ()
forall a b. (a -> b) -> a -> b
$ String
"Constraint must be `HasGrad`, got " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Name -> String
forall a. Show a => a -> String
show Name
hasgradCtx
(Name
recordName, [Type]
instVars) <- Type -> [Type] -> Q (Name, [Type])
parseRecordType Type
recordInConstraintType []
DatatypeInfo
record' <- Name -> Q DatatypeInfo
reifyDatatype Name
recordName
(DownhillRecord
fullParsedRecord, ConstructorInfo
cinfo) <- Name -> DatatypeInfo -> Q (DownhillRecord, ConstructorInfo)
parseDownhillRecord Name
recordName DatatypeInfo
record'
DownhillRecord
parsedRecord <- BVarOptions -> DownhillRecord -> Q DownhillRecord
forall (m :: * -> *).
MonadFail m =>
BVarOptions -> DownhillRecord -> m DownhillRecord
filterFields BVarOptions
options DownhillRecord
fullParsedRecord
[Name]
recordTypeVarNames <- do
let getName :: Type -> m Name
getName Type
x = case Type
x of
SigT (VarT Name
y) Type
_ -> Name -> m Name
forall (m :: * -> *) a. Monad m => a -> m a
return Name
y
Type
_ -> String -> m Name
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Type variable is not VarT"
(Type -> Q Name) -> [Type] -> Q [Name]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse Type -> Q Name
forall (m :: * -> *). MonadFail m => Type -> m Name
getName (DatatypeInfo -> [Type]
datatypeInstTypes DatatypeInfo
record')
let substPairs :: [(Name, Type)]
substPairs = [Name] -> [Type] -> [(Name, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Name]
recordTypeVarNames [Type]
instVars
substitutedRecord :: ConstructorInfo
substitutedRecord = Map Name Type -> ConstructorInfo -> ConstructorInfo
forall a. TypeSubstitution a => Map Name Type -> a -> a
applySubstitution ([(Name, Type)] -> Map Name Type
forall k a. Ord k => [(k, a)] -> Map k a
Map.fromList [(Name, Type)]
substPairs) ConstructorInfo
cinfo
Type
scalarType <- case [Dec]
decs of
[] -> String -> Q Type
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"`HasGrad` instance has no declarations"
[Dec
dec1] -> case Dec
dec1 of
TySynInstD (TySynEqn Maybe [TyVarBndrUnit]
_ (AppT (ConT Name
scalarName) Type
_) Type
scalarType) -> do
Bool -> Q () -> Q ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Name
scalarName Name -> Name -> Bool
forall a. Eq a => a -> a -> Bool
/= ''MScalar) (Q () -> Q ()) -> Q () -> Q ()
forall a b. (a -> b) -> a -> b
$
String -> Q ()
forall (m :: * -> *) a. MonadFail m => String -> m a
fail (String
"Expected `Scalar` equation, got " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Name -> String
forall a. Show a => a -> String
show Name
scalarName)
Type -> Q Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
scalarType
Dec
_ -> String -> Q Type
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"HasGrad instance must contain `Scalar ... = ...` declaration"
[Dec]
_ -> String -> Q Type
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"`HasGrad` has multiple declarations"
[Dec]
dvar <- [Type]
-> DownhillRecord
-> BVarOptions
-> Type
-> [Type]
-> ConstructorInfo
-> Q [Dec]
mkDVar'' [Type]
cxt DownhillRecord
parsedRecord BVarOptions
options Type
scalarType [Type]
instVars ConstructorInfo
substitutedRecord
let tangName :: Name
tangName = DownhillRecord -> Name
ddtTypeConName (RecordTranstorm -> DownhillRecord -> DownhillRecord
renameDownhillRecord (BVarOptions -> RecordTranstorm
tangTransform BVarOptions
options) DownhillRecord
parsedRecord)
gradName :: Name
gradName = DownhillRecord -> Name
ddtTypeConName (RecordTranstorm -> DownhillRecord -> DownhillRecord
renameDownhillRecord (BVarOptions -> RecordTranstorm
gradTransform BVarOptions
options) DownhillRecord
parsedRecord)
metricName :: Name
metricName = DownhillRecord -> Name
ddtTypeConName (RecordTranstorm -> DownhillRecord -> DownhillRecord
renameDownhillRecord (BVarOptions -> RecordTranstorm
metricTransform BVarOptions
options) DownhillRecord
parsedRecord)
tangTypeDec :: Dec
tangTypeDec =
TySynEqn -> Dec
TySynInstD
( Maybe [TyVarBndrUnit] -> Type -> Type -> TySynEqn
TySynEqn
Maybe [TyVarBndrUnit]
forall a. Maybe a
Nothing
(Type -> Type -> Type
AppT (Name -> Type
ConT ''Tang) Type
recordInConstraintType)
((Type -> Type -> Type) -> Type -> [Type] -> Type
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Type -> Type -> Type
AppT (Name -> Type
ConT Name
tangName) [Type]
instVars)
)
gradTypeDec :: Dec
gradTypeDec =
TySynEqn -> Dec
TySynInstD
( Maybe [TyVarBndrUnit] -> Type -> Type -> TySynEqn
TySynEqn
Maybe [TyVarBndrUnit]
forall a. Maybe a
Nothing
(Type -> Type -> Type
AppT (Name -> Type
ConT ''Grad) Type
recordInConstraintType)
((Type -> Type -> Type) -> Type -> [Type] -> Type
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Type -> Type -> Type
AppT (Name -> Type
ConT Name
gradName) [Type]
instVars)
)
metricTypeDec :: Dec
metricTypeDec =
TySynEqn -> Dec
TySynInstD
( Maybe [TyVarBndrUnit] -> Type -> Type -> TySynEqn
TySynEqn
Maybe [TyVarBndrUnit]
forall a. Maybe a
Nothing
(Type -> Type -> Type
AppT (Name -> Type
ConT ''Metric) Type
recordInConstraintType)
((Type -> Type -> Type) -> Type -> [Type] -> Type
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Type -> Type -> Type
AppT (Name -> Type
ConT Name
metricName) [Type]
instVars)
)
hasgradInstance :: Dec
hasgradInstance =
Maybe Overlap -> [Type] -> Type -> [Dec] -> Dec
InstanceD
Maybe Overlap
forall a. Maybe a
Nothing
[Type]
cxt
Type
type_
( [Dec]
decs
[Dec] -> [Dec] -> [Dec]
forall a. [a] -> [a] -> [a]
++ [ Dec
tangTypeDec,
Dec
gradTypeDec,
Dec
metricTypeDec
]
)
[Dec] -> Q [Dec]
forall (m :: * -> *) a. Monad m => a -> m a
return ([Dec] -> Q [Dec]) -> [Dec] -> Q [Dec]
forall a b. (a -> b) -> a -> b
$ [Dec]
dvar [Dec] -> [Dec] -> [Dec]
forall a. [a] -> [a] -> [a]
++ [Dec
hasgradInstance]
Type
_ -> String -> Q [Dec]
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Instance head is not a constraint"
Dec
_ -> String -> Q [Dec]
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Expected instance declaration"
-- | Expand a quoted block of declarations into generated code.
--
-- The splice @decs@ is run first. Each resulting declaration is then handed
-- to 'mkDVarC1' together with @options@. The declaration lists it returns are
-- concatenated in the same order as the input declarations.
--
-- NOTE(review): 'mkDVarC1' appears to accept only @HasGrad@ instance
-- declarations and to 'fail' on anything else. Confirm against its
-- definition.
mkHasGradInstances :: BVarOptions -> Q [Dec] -> Q [Dec]
mkHasGradInstances options decs = do
  inputDecs <- decs
  -- One list of generated declarations per input declaration.
  generated <- traverse (mkDVarC1 options) inputDecs
  return (concat generated)