{-# LANGUAGE CPP #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Use like this:
--
-- @
-- mkHasGradInstances
--   defaultBVarOptions
--   [d|
--     instance HasGrad MyRecord where
--       type MScalar MyRecord = Float
--     |]
-- @
--
-- The instance declaration passed to @mkHasGradInstances@ provides two important pieces of information:
--
--   * Type arguments for @MyRecord@, which can be concrete types (such as @instance HasGrad (MyRecord Float)@)
--     or regular type variables (@instance HasGrad (MyRecord a)@)
--
--   * The scalar type, given by the @MScalar@ associated type instance.
--
module Downhill.TH
  (
    mkHasGradInstances,
    AffineSpaceOptions (..),
    RecordNamer (..),
    BVarOptions (..),
    defaultBVarOptions,
  )
where

import Control.Monad
import Data.AdditiveGroup ((^+^), (^-^))
import Data.AffineSpace (AffineSpace (Diff, (.+^), (.-.)))
import Data.Foldable (traverse_)
import qualified Data.Map as Map
import Data.Maybe (catMaybes)
import Data.VectorSpace (AdditiveGroup (negateV, zeroV), VectorSpace (Scalar, (*^)))
import Downhill.BVar (BVar (BVar))
import Downhill.Grad
  ( Dual (evalGrad),
    HasGrad (Grad, MScalar, Metric, Tang),
    MetricTensor (MtCovector, MtVector, evalMetric, sqrNorm),
  )
import Downhill.Linear.Expr (BasicVector (VecBuilder, sumBuilder))
import Downhill.Linear.Lift (lift1_sparse)
import GHC.Records (HasField (getField))
import Language.Haskell.TH
  ( Bang (Bang),
    Con (NormalC, RecC),
    Cxt,
    Dec (DataD, InstanceD, NewtypeD, SigD),
    Exp (AppE, ConE, InfixE, VarE),
    Name,
    Pat (VarP),
    Q,
    SourceStrictness (NoSourceStrictness),
    SourceUnpackedness (NoSourceUnpackedness),
    Type (AppT, ConT, VarT),
    nameBase,
    newName,
  )
import Language.Haskell.TH.Datatype (ConstructorInfo (constructorFields, constructorName, constructorVariant), ConstructorVariant (InfixConstructor, NormalConstructor, RecordConstructor), DatatypeInfo (datatypeCons, datatypeInstTypes, datatypeName, datatypeVariant, datatypeVars), DatatypeVariant (Newtype), TypeSubstitution (applySubstitution), reifyDatatype)
import Language.Haskell.TH.Datatype.TyVarBndr (TyVarBndrUnit)
import Language.Haskell.TH.Syntax
  ( BangType,
    Body (NormalB),
    Clause (Clause),
    Dec (FunD, TySynInstD, ValD),
    Exp (AppTypeE),
    TyLit (StrTyLit),
    TySynEqn (TySynEqn),
    Type (ArrowT, EqualityT, LitT, SigT),
    VarBangType,
    mkNameS,
  )
import qualified  Language.Haskell.TH

-- | Fields of a data constructor: a list of field types for positional
-- constructors, or (field name, field type) pairs for record constructors.
--
-- NOTE(review): not referenced in the visible part of this module; confirm
-- whether it is still needed.
data DatatypeFields
  = NormalFields [Type]
  | RecordFields [(String, Type)]
  deriving (Show)

-- | Description of a single-constructor data type, built by
-- 'parseGradConstructor' and used as the template for generated
-- declarations and instances.
data DownhillRecord = DownhillRecord
  { -- | Name of the type constructor.
    ddtTypeConName :: Name,
    -- | Name of the (only) data constructor.
    ddtDataConName :: Name,
    -- | Types of the constructor's fields, in declaration order.
    ddtFieldTypes :: [Type],
    -- | Field names for record constructors; 'Nothing' for normal and
    -- infix constructors.
    ddtFieldNames :: Maybe [String],
    -- | Type variables of the data type.
    ddtTypeVars :: [TyVarBndrUnit],
    -- | Number of constructor fields (the length of 'ddtFieldTypes').
    ddtFieldCount :: Int,
    -- | Declaration flavour reported by th-abstraction (e.g. data or newtype).
    ddtVariant :: DatatypeVariant
  }
  deriving (Show)

-- | Functions that derive the names of a generated record from the names of
-- the original record (see the default namers, which append a suffix).
data RecordNamer = RecordNamer
  { -- | Maps the original type constructor name to the generated one.
    typeConNamer :: String -> String,
    -- | Maps the original data constructor name to the generated one.
    dataConNamer :: String -> String,
    -- | Maps each original field name to the generated one.
    fieldNamer :: String -> String
  }

-- | A renaming ('RecordNamer') paired with a function on 'Type's, consumed
-- by @renameDownhillRecord@ to derive one 'DownhillRecord' from another.
--
-- NOTE(review): presumably the function is applied to each field type of
-- the derived record; confirm against @renameDownhillRecord@. The
-- misspelled type name is kept because other code refers to it.
data RecordTranstorm = RecordTranstorm RecordNamer (Type -> Type)

-- | Controls whether an 'AffineSpace' instance is generated for the record.
data AffineSpaceOptions
  = MakeAffineSpace
  -- ^ Generate AffineSpace instance
  | NoAffineSpace
  -- ^ Don't generate AffineSpace instance
  | AutoAffineSpace
  -- ^ Generate AffineSpace instance if @optExcludeFields@ is empty

-- | Options controlling the code generated by 'mkHasGradInstances'.
data BVarOptions = BVarOptions
  { -- | Naming of the generated tangent record.
    optTangNamer :: RecordNamer,
    -- | Naming of the generated gradient record.
    optGradNamer :: RecordNamer,
    -- | Naming of the generated metric record.
    optMetricNamer :: RecordNamer,
    -- | Naming of the generated builder record.
    optBuilderNamer :: RecordNamer,
    -- | Whether to generate an 'AffineSpace' instance.
    optAffineSpace :: AffineSpaceOptions,
    -- | List of fields that take no part in differentiation
    optExcludeFields :: [String]
  }

-- | Constructor pattern that works across template-haskell versions.
--
-- Since template-haskell 2.18, 'Language.Haskell.TH.ConP' carries an extra
-- list of type applications. This synonym builds it with @[]@ and, when
-- matching, only matches patterns whose type-application list is @[]@.
pattern ConP :: Name -> [Pat] -> Pat
#if MIN_VERSION_template_haskell(2,18,0)
pattern ConP x y = Language.Haskell.TH.ConP x [] y
#else
pattern ConP x y = Language.Haskell.TH.ConP x y
#endif

-- | Default naming for the tangent record: appends @Tang@ to the type and
-- data constructor names and keeps field names unchanged.
defaultTangRecordNamer :: RecordNamer
defaultTangRecordNamer =
  RecordNamer
    { typeConNamer = (++ "Tang"),
      dataConNamer = (++ "Tang"),
      fieldNamer = id
    }

-- | Default naming for the gradient record: appends @Grad@ to the type and
-- data constructor names and keeps field names unchanged.
defaultGradRecordNamer :: RecordNamer
defaultGradRecordNamer =
  RecordNamer
    { typeConNamer = (++ "Grad"),
      dataConNamer = (++ "Grad"),
      fieldNamer = id
    }

-- | Default naming for the metric record: appends @Metric@ to the type and
-- data constructor names and keeps field names unchanged.
defaultMetricRecordNamer :: RecordNamer
defaultMetricRecordNamer =
  RecordNamer
    { typeConNamer = (++ "Metric"),
      dataConNamer = (++ "Metric"),
      fieldNamer = id
    }

-- | Default naming for the builder record: appends @Builder@ to the type and
-- data constructor names and keeps field names unchanged.
defaultBuilderRecordNamer :: RecordNamer
defaultBuilderRecordNamer =
  RecordNamer
    { typeConNamer = (++ "Builder"),
      dataConNamer = (++ "Builder"),
      fieldNamer = id
    }

-- | Default options: the @Tang@ / @Grad@ / @Metric@ / @Builder@ suffix
-- namers, 'AutoAffineSpace', and no excluded fields.
defaultBVarOptions :: BVarOptions
defaultBVarOptions =
  BVarOptions
    { optTangNamer = defaultTangRecordNamer,
      optGradNamer = defaultGradRecordNamer,
      optMetricNamer = defaultMetricRecordNamer,
      optBuilderNamer = defaultBuilderRecordNamer,
      optAffineSpace = AutoAffineSpace,
      optExcludeFields = []
    }

-- | Build a data constructor declaration from a 'DownhillRecord'.
--
-- Uses record syntax when 'ddtFieldNames' is present and positional syntax
-- otherwise. Fields carry no strictness or unpackedness annotations.
mkConstructor :: DownhillRecord -> Con
mkConstructor record =
  case ddtFieldNames record of
    Nothing -> NormalC conName (map plainField fieldTypes)
    Just names -> RecC conName (zipWith namedField names fieldTypes)
  where
    conName = ddtDataConName record
    fieldTypes = ddtFieldTypes record
    -- no {-# UNPACK #-} and no strictness bang
    noBang = Bang NoSourceUnpackedness NoSourceStrictness

    plainField :: Type -> BangType
    plainField ty = (noBang, ty)

    namedField :: String -> Type -> VarBangType
    namedField name ty = (mkNameS name, noBang, ty)

-- | Convert a reified constructor into a 'DownhillRecord'.
--
-- @tyName@ becomes the type constructor name and @typevars@ the type
-- variables. Record constructors keep their field names (via 'nameBase');
-- normal and infix constructors get 'Nothing'.
parseGradConstructor ::
  Name -> DatatypeInfo -> ConstructorInfo -> [TyVarBndrUnit] -> Q DownhillRecord
parseGradConstructor tyName dinfo cinfo typevars =
  return
    DownhillRecord
      { ddtTypeConName = tyName,
        ddtDataConName = constructorName cinfo,
        ddtTypeVars = typevars,
        ddtFieldCount = length fieldTypes,
        ddtFieldTypes = fieldTypes,
        ddtFieldNames = fieldNames,
        ddtVariant = datatypeVariant dinfo
      }
  where
    fieldTypes = constructorFields cinfo
    -- pure: the original needlessly computed this inside Q
    fieldNames = case constructorVariant cinfo of
      NormalConstructor -> Nothing
      InfixConstructor -> Nothing
      RecordConstructor names -> Just (map nameBase names)

-- | Extract the single data constructor of a reified type.
--
-- Returns the parsed 'DownhillRecord' together with the raw
-- 'ConstructorInfo'. Fails in 'Q' when the type has no constructors or more
-- than one. @recordName@ is only used in those error messages; the record's
-- type name comes from 'datatypeName'.
parseDownhillRecord :: Name -> DatatypeInfo -> Q (DownhillRecord, ConstructorInfo)
parseDownhillRecord recordName record' = do
  constr' <- case datatypeCons record' of
    [] -> fail (show recordName <> " has no data constructors")
    [single] -> return single
    _ -> fail (show recordName <> " has multiple data constructors")
  r <- parseGradConstructor (datatypeName record') record' constr' (datatypeVars record')
  return (r, constr')

-- | @elementwiseOp record f@ declares
-- @f (C x1 … xn) (C y1 … yn) = C (x1 \`f\` y1) … (xn \`f\` yn)@,
-- where @C@ is the record's data constructor.
elementwiseOp :: DownhillRecord -> Name -> Q Dec
elementwiseOp record = elementwiseOp' record record record

-- | @elementwiseOp' l r res f@ declares a two-argument function @f@.
--
-- The first argument is matched with @l@'s constructor and the second with
-- @r@'s. The result is @res@'s constructor applied to the fields combined
-- pairwise with infix @f@. The field count is taken from @res@; @l@ and @r@
-- are assumed to have the same number of fields.
elementwiseOp' :: DownhillRecord -> DownhillRecord -> DownhillRecord -> Name -> Q Dec
elementwiseOp' leftRecord rightRecord resRecord func = do
  let n = ddtFieldCount resRecord
  xs <- replicateM n (newName "x")
  ys <- replicateM n (newName "y")
  let conPat r vars = ConP (ddtDataConName r) (map VarP vars)
      combine x y = InfixE (Just (VarE x)) (VarE func) (Just (VarE y))
      body = foldl AppE (ConE (ddtDataConName resRecord)) (zipWith combine xs ys)
  return $
    FunD
      func
      [Clause [conPat leftRecord xs, conPat rightRecord ys] (NormalB body) []]

-- | @elementwiseValue record v@ declares @v = C v … v@: the record's data
-- constructor with @v@ in every field (e.g. @zeroV@).
--
-- Previously the field value was hard-coded to @zeroV@ regardless of @v@;
-- it now uses @v@, which is identical for the @'zeroV@ caller.
elementwiseValue :: DownhillRecord -> Name -> Q Dec
elementwiseValue record func =
  return (ValD (VarP func) (NormalB rhs) [])
  where
    fieldValues = replicate (ddtFieldCount record) (VarE func)
    rhs = foldl AppE (ConE (ddtDataConName record)) fieldValues

-- | @elementwiseFunc record f@ declares @f (C x1 … xn) = C (f x1) … (f xn)@.
--
-- For record constructors, pattern variables are created with 'newName'
-- from the field names (presumably for readable generated code); otherwise
-- they are named @x@.
elementwiseFunc :: DownhillRecord -> Name -> Q Dec
elementwiseFunc record func = do
  let conName = ddtDataConName record
  xs <- case ddtFieldNames record of
    Nothing -> replicateM (ddtFieldCount record) (newName "x")
    Just names -> traverse newName names
  let body = foldl AppE (ConE conName) [AppE (VarE func) (VarE x) | x <- xs]
  return (FunD func [Clause [ConP conName (map VarP xs)] (NormalB body) []])

-- | @mkClassInstance cls cxt record instVars decs@ produces
-- @instance cxt => cls (T instVars…) where decs@, where @T@ is the record's
-- type constructor.
mkClassInstance :: Name -> Cxt -> DownhillRecord -> [Type] -> [Dec] -> Q [Dec]
mkClassInstance className cxt record instVars decs =
  return [InstanceD Nothing cxt ihead decs]
  where
    recordType = foldl AppT (ConT (ddtTypeConName record)) instVars
    ihead = AppT (ConT className) recordType

-- | @instance cxt => Semigroup (T …)@ with field-wise @(<>)@.
mkSemigroupInstance :: Cxt -> DownhillRecord -> [Type] -> Q [Dec]
mkSemigroupInstance cxt record instVars = do
  appendDec <- elementwiseOp record '(<>)
  mkClassInstance ''Semigroup cxt record instVars [appendDec]

-- | @instance cxt => AdditiveGroup (T …)@ with field-wise @zeroV@,
-- @negateV@, @(^+^)@ and @(^-^)@.
mkAdditiveGroupInstance :: Cxt -> DownhillRecord -> [Type] -> Q [Dec]
mkAdditiveGroupInstance cxt record instVars = do
  -- 'sequence' runs the generators left to right, same order as before
  decs <-
    sequence
      [ elementwiseValue record 'zeroV,
        elementwiseFunc record 'negateV,
        elementwiseOp record '(^+^),
        elementwiseOp record '(^-^)
      ]
  mkClassInstance ''AdditiveGroup cxt record instVars decs

-- | Generates
--
-- @
-- instance cxt => VectorSpace (T …) where
--   type Scalar (T …) = scalarType
--   s *^ C x1 … xn = C (s *^ x1) … (s *^ xn)
-- @
mkVectorSpaceInstance :: DownhillRecord -> Type -> Cxt -> [Type] -> Q [Dec]
mkVectorSpaceInstance record scalarType cxt instVars = do
  xs <- case ddtFieldNames record of
    Nothing -> replicateM (ddtFieldCount record) (newName "x")
    Just names -> traverse newName names
  s <- newName "s"
  let conName = ddtDataConName record
      recordType = foldl AppT (ConT (ddtTypeConName record)) instVars
      scale x = InfixE (Just (VarE s)) (VarE '(*^)) (Just (VarE x))
      scaledRecord = foldl AppE (ConE conName) (map scale xs)
      vmulDec =
        FunD
          '(*^)
          [Clause [VarP s, ConP conName (map VarP xs)] (NormalB scaledRecord) []]
      scalarTypeDec =
        TySynInstD (TySynEqn Nothing (AppT (ConT ''Scalar) recordType) scalarType)
  mkClassInstance ''VectorSpace cxt record instVars [scalarTypeDec, vmulDec]

-- | Generates
--
-- @
-- instance cxt => BasicVector (T …) where
--   type VecBuilder (T …) = Maybe (TBuilder …)
--   sumBuilder Nothing = zeroV
--   sumBuilder (Just (TBuilder x1 … xn)) = T (sumBuilder x1) … (sumBuilder xn)
-- @
--
-- The builder record is derived from @vectorRecord@ with
-- @renameDownhillRecord (builderTransform options)@ (both defined elsewhere).
mkBasicVectorInstance :: DownhillRecord -> BVarOptions -> Cxt -> [Type] -> Q [Dec]
mkBasicVectorInstance vectorRecord options cxt instVars = do
  sumBuilderDec <- mkSumBuilder
  mkClassInstance ''BasicVector cxt vectorRecord instVars [vecbuilderDec, sumBuilderDec]
  where
    n = ddtFieldCount vectorRecord
    builderRecord = renameDownhillRecord (builderTransform options) vectorRecord

    -- T a1 … an, for either record
    applyInstVars r = foldl AppT (ConT (ddtTypeConName r)) instVars

    vecbuilderDec =
      TySynInstD
        ( TySynEqn
            Nothing
            (AppT (ConT ''VecBuilder) (applyInstVars vectorRecord))
            (AppT (ConT ''Maybe) (applyInstVars builderRecord))
        )

    -- not an elementwiseOp, because the argument is wrapped in Maybe
    mkSumBuilder = do
      builders <- replicateM n (newName "x")
      let pat = ConP (ddtDataConName builderRecord) (map VarP builders)
          rhs =
            foldl
              AppE
              (ConE (ddtDataConName vectorRecord))
              [AppE (VarE 'sumBuilder) (VarE x) | x <- builders]
      return $
        FunD
          'sumBuilder
          [ Clause [ConP 'Nothing []] (NormalB (VarE 'zeroV)) [],
            Clause [ConP 'Just [pat]] (NormalB rhs) []
          ]

-- | Sum expressions with @(^+^)@, left-associated:
-- @[a, b, c]@ becomes @(a ^+^ b) ^+^ c@. The empty list becomes @zeroV@.
sumVExpr :: [Exp] -> Exp
sumVExpr [] = VarE 'zeroV
sumVExpr (e : es) = foldl plus e es
  where
    plus :: Exp -> Exp -> Exp
    plus x y = InfixE (Just x) (VarE '(^+^)) (Just y)

-- | Generates
--
-- @
-- instance (cxt, AdditiveGroup s, s ~ scalarType) => Dual s (TTang a1 … an) (TGrad a1 … an) where
--   evalGrad (TGrad x1 … xn) (TTang y1 … yn) = evalGrad x1 y1 ^+^ … ^+^ evalGrad xn yn
-- @
--
-- Fails in 'Q' if the tangent and gradient records have different field
-- counts.
mkDualInstance ::
  DownhillRecord ->
  DownhillRecord ->
  Type ->
  Cxt ->
  [Type] ->
  Q [Dec]
mkDualInstance tangRecord gradRecord scalarType cxt instVars = do
  when (ddtFieldCount tangRecord /= ddtFieldCount gradRecord) $
    fail "mkDualInstance: ddtFieldCount tangRecord /= ddtFieldCount gradRecord"
  s <- VarT <$> newName "s"
  evalGradDec <- mkEvalGradDec
  let newConstraints =
        [ AppT (ConT ''AdditiveGroup) s, -- AdditiveGroup s
          AppT (AppT EqualityT s) scalarType -- s ~ scalarType
        ]
      ihead =
        ConT ''Dual
          `AppT` s
          `AppT` applyInstVars tangRecord
          `AppT` applyInstVars gradRecord
  return [InstanceD Nothing (cxt ++ newConstraints) ihead [evalGradDec]]
  where
    n = ddtFieldCount tangRecord

    -- T a1 … an, for either record
    applyInstVars r = foldl AppT (ConT (ddtTypeConName r)) instVars

    mkEvalGradDec = do
      xs <- replicateM n (newName "x")
      ys <- replicateM n (newName "y")
      let gradPat = ConP (ddtDataConName gradRecord) (map VarP xs)
          tangPat = ConP (ddtDataConName tangRecord) (map VarP ys)
          -- [evalGrad x1 y1, …, evalGrad xn yn]
          terms = [VarE 'evalGrad `AppE` VarE x `AppE` VarE y | (x, y) <- zip xs ys]
      return $
        FunD 'evalGrad [Clause [gradPat, tangPat] (NormalB (sumVExpr terms)) []]

-- | Generate the 'MetricTensor' instance for a metric record:
--
-- @
-- instance (cxt, s ~ scalarType) => MetricTensor (RecordMetric a1 … an) where
--   type MtVector (RecordMetric a1 … an) = RecordTang a1 … an
--   type MtCovector (RecordMetric a1 … an) = RecordGrad a1 … an
--   evalMetric (RecordMetric m1 … mn) (RecordGrad dv1 … dvn) =
--     RecordTang (evalMetric m1 dv1) … (evalMetric mn dvn)
--   sqrNorm (RecordMetric m1 … mn) (RecordGrad dv1 … dvn) =
--     sumVExpr [sqrNorm m1 dv1, …, sqrNorm mn dvn]
-- @
--
-- Arguments: metric record, tangent record, gradient record, scalar type,
-- user-supplied instance context, and the instance type arguments @a1 … an@.
mkMetricInstance ::
  DownhillRecord ->
  DownhillRecord ->
  DownhillRecord ->
  Type ->
  Cxt ->
  [Type] ->
  Q [Dec]
mkMetricInstance metricRecord tangRecord gradRecord scalarType cxt instVars = do
  scalarTypeName <- newName "s"
  mkClassDec (VarT scalarTypeName)
  where
    -- Apply a generated record's type constructor to the instance arguments,
    -- e.g. @RecordTang a1 … an@.
    recordType :: DownhillRecord -> Type
    recordType record = foldl AppT (ConT (ddtTypeConName record)) instVars

    metricType, vectorType, covectorType :: Type
    metricType = recordType metricRecord
    vectorType = recordType tangRecord
    covectorType = recordType gradRecord

    -- instance (cxt, s ~ scalarType) => MetricTensor (RecordMetric a1 … an) where …
    mkClassDec :: Type -> Q [Dec]
    mkClassDec scalarVar = do
      -- NOTE(review): the fresh variable @s@ does not occur in the instance
      -- head (only metricType does), so this equality merely pins @s@ to
      -- scalarType. Confirm whether it is still required before removing it.
      let newConstraints :: Cxt
          newConstraints =
            [ -- s ~ scalarType
              AppT (AppT EqualityT scalarVar) scalarType
            ]
          -- MetricTensor (RecordMetric a1 … an)
          ihead :: Type
          ihead = ConT ''MetricTensor `AppT` metricType
      evalMetricDec <- mkEvalMetric
      sqrNormDec <- mkSqrNorm
      return
        [ InstanceD
            Nothing
            (cxt ++ newConstraints)
            ihead
            [vectypeDec, covectorTypeDec, evalMetricDec, sqrNormDec]
        ]

    -- type MtVector (RecordMetric a1 … an) = RecordTang a1 … an
    vectypeDec :: Dec
    vectypeDec =
      TySynInstD (TySynEqn Nothing (AppT (ConT ''MtVector) metricType) vectorType)

    -- type MtCovector (RecordMetric a1 … an) = RecordGrad a1 … an
    covectorTypeDec :: Dec
    covectorTypeDec =
      TySynInstD (TySynEqn Nothing (AppT (ConT ''MtCovector) metricType) covectorType)

    -- evalMetric (RecordMetric m1 … mn) (RecordGrad dv1 … dvn)
    --   = RecordTang (evalMetric m1 dv1) … (evalMetric mn dvn)
    mkEvalMetric :: Q Dec
    mkEvalMetric =
      mkFieldwiseMethod 'evalMetric (foldl AppE (ConE (ddtDataConName tangRecord)))

    -- sqrNorm (RecordMetric m1 … mn) (RecordGrad dv1 … dvn)
    --   = sumVExpr [sqrNorm m1 dv1, …, sqrNorm mn dvn]
    mkSqrNorm :: Q Dec
    mkSqrNorm = mkFieldwiseMethod 'sqrNorm sumVExpr

    -- Build the single-clause definition
    -- @method (RecordMetric m1 … mn) (RecordGrad dv1 … dvn) = combine [method m1 dv1, …]@.
    -- Both evalMetric and sqrNorm share this shape and differ only in how the
    -- per-field terms are combined.
    mkFieldwiseMethod :: Name -> ([Exp] -> Exp) -> Q Dec
    mkFieldwiseMethod method combine = do
      let n = ddtFieldCount metricRecord
      ms <- replicateM n (newName "m")
      dvs <- replicateM n (newName "dv")
      let leftPat, rightPat :: Pat
          leftPat = ConP (ddtDataConName metricRecord) (map VarP ms)
          rightPat = ConP (ddtDataConName gradRecord) (map VarP dvs)
          term :: Name -> Name -> Exp
          term m dv = VarE method `AppE` VarE m `AppE` VarE dv
          rhs :: Exp
          rhs = combine (zipWith term ms dvs)
      return $ FunD method [Clause [leftPat, rightPat] (NormalB rhs) []]

-- | Emit the data declaration for a generated record (tangent, gradient,
-- metric or builder):
--
-- @
-- data RecordName a1 … an = RecordCon field1 … fieldn
-- @
--
-- The declaration has an empty context, no kind signature and no deriving
-- clauses. A @newtype@ is emitted only when the source datatype was a newtype
-- and it still has exactly one field. If 'filterFields' removed the only
-- field, a @data@ declaration is used instead, because GHC rejects a newtype
-- whose constructor has no field.
mkRecord :: DownhillRecord -> Q [Dec]
mkRecord record = return [dataType]
  where
    newRecordName :: Name
    newRecordName = ddtTypeConName record

    newConstr :: Con
    newConstr = mkConstructor record

    dataType :: Dec
    dataType = case ddtVariant record of
      Newtype
        | ddtFieldCount record == 1 ->
            NewtypeD [] newRecordName (ddtTypeVars record) Nothing newConstr []
      _ -> DataD [] newRecordName (ddtTypeVars record) Nothing [newConstr] []

-- | Derive the name of a generated type or data constructor: apply @f@ to the
-- unqualified base name of the original and turn the result back into a
-- 'Name' with 'mkNameS'.
renameTypeS :: (String -> String) -> Name -> Name
renameTypeS f original = mkNameS (f (nameBase original))

-- | Description of one record field, used by 'mkGetField' to generate a
-- @HasField@ instance on 'BVar'.
data FieldInfo = FieldInfo
  { -- | Field name, used both as the type-level label and in @getField \@"name"@.
    fiName :: String,
    -- | Zero-based position of the field within the generated records.
    fiIndex :: Int,
    -- | Type of the field.
    fiType :: Type
  }

-- | Generate a 'HasField' instance that reads one record field through a
-- 'BVar':
--
-- @
-- instance cxt => HasField "field" (BVar r (Record a1 … an)) (BVar r FieldType) where
--   getField (BVar x dx) = BVar (getField \@"field" x) (lift1_sparse go dx)
--     where
--       go :: VecBuilder (Grad FieldType) -> Maybe (RecordGradBuilder a1 … an)
--       go dx_da = Just (RecordGradBuilder mempty … mempty dx_da mempty … mempty)
-- @
--
-- The builder passed back has @dx_da@ at position 'fiIndex' and 'mempty' in
-- every other slot.
--
-- Arguments: point record, gradient-builder record, instance context,
-- instance type arguments, and the field to expose.
mkGetField ::
  DownhillRecord ->
  DownhillRecord ->
  Cxt ->
  [Type] ->
  FieldInfo ->
  Q [Dec]
mkGetField pointRecord gradBuilderRecord cxt instVars field = do
  rName <- newName "r"
  xName <- newName "x"
  dxName <- newName "dx"
  goName <- newName "go"
  dxdaName <- newName "dx_da"
  let label :: Type
      label = LitT (StrTyLit (fiName field))

      -- BVar r t
      bvarOf :: Type -> Type
      bvarOf t = ConT ''BVar `AppT` VarT rName `AppT` t

      -- HasField "field" (BVar r (Record a1 … an)) (BVar r FieldType)
      instanceHead :: Type
      instanceHead =
        ConT ''HasField
          `AppT` label
          `AppT` bvarOf pointType
          `AppT` bvarOf (fiType field)

      -- mempty … mempty dx_da mempty … mempty, with dx_da at index fiIndex
      rhsFieldList :: [Exp]
      rhsFieldList =
        replicate (fiIndex field) (VarE 'mempty)
          ++ [VarE dxdaName]
          ++ replicate (n - fiIndex field - 1) (VarE 'mempty)

      -- rhs = MyRecordGradBuilder mempty … mempty dx_da_a6SX mempty … mempty
      rhs :: Exp
      rhs = foldl AppE (ConE (ddtDataConName gradBuilderRecord)) rhsFieldList

      -- go :: VecBuilder (Grad FieldType) -> Maybe (RecordGradBuilder a1 … an)
      goSig :: Dec
      goSig =
        SigD goName $
          ArrowT
            `AppT` (ConT ''VecBuilder `AppT` (ConT ''Grad `AppT` fiType field))
            `AppT` (ConT ''Maybe `AppT` gradBuilderType)

      -- go dx_da = Just rhs
      goDef :: Dec
      goDef = FunD goName [Clause [VarP dxdaName] (NormalB (ConE 'Just `AppE` rhs)) []]

      -- getField (BVar x dx) = BVar (getField @"field" x) (lift1_sparse go dx)
      getFieldDec :: Dec
      getFieldDec =
        FunD
          'getField
          [ Clause
              [ConP 'BVar [VarP xName, VarP dxName]]
              ( NormalB $
                  ConE 'BVar
                    `AppE` (VarE 'getField `AppTypeE` label `AppE` VarE xName)
                    `AppE` (VarE 'lift1_sparse `AppE` VarE goName `AppE` VarE dxName)
              )
              [goSig, goDef]
          ]
  return [InstanceD Nothing cxt instanceHead [getFieldDec]]
  where
    -- Number of fields in the point record (and thus in the builder record).
    n :: Int
    n = ddtFieldCount pointRecord

    applyVars :: Type -> Type
    applyVars x = foldl AppT x instVars

    pointType :: Type
    pointType = applyVars (ConT (ddtTypeConName pointRecord))

    gradBuilderType :: Type
    gradBuilderType = applyVars (ConT (ddtTypeConName gradBuilderRecord))

-- | Apply a 'RecordTranstorm' to a record description.
--
-- The type constructor, data constructor and (when present) field names are
-- renamed with the transform's 'RecordNamer', and the transform's type
-- function is mapped over every field type. Type variables, field count and
-- datatype variant are kept unchanged.
renameDownhillRecord :: RecordTranstorm -> DownhillRecord -> DownhillRecord
renameDownhillRecord (RecordTranstorm namer typeFun) record =
  record
    { ddtTypeConName = renameTypeS (typeConNamer namer) (ddtTypeConName record),
      ddtDataConName = renameTypeS (dataConNamer namer) (ddtDataConName record),
      ddtFieldTypes = map typeFun (ddtFieldTypes record),
      ddtFieldNames = map (fieldNamer namer) <$> ddtFieldNames record
    }

-- | Transform for gradient-builder records: names come from
-- 'optBuilderNamer', and every field type @t@ becomes @VecBuilder t@.
builderTransform :: BVarOptions -> RecordTranstorm
builderTransform options =
  RecordTranstorm (optBuilderNamer options) (AppT (ConT ''VecBuilder))

-- | Transform for tangent records: names come from 'optTangNamer', and every
-- field type @t@ becomes @Tang t@.
tangTransform :: BVarOptions -> RecordTranstorm
tangTransform options =
  RecordTranstorm (optTangNamer options) (AppT (ConT ''Tang))

-- | Transform for gradient records: names come from 'optGradNamer', and every
-- field type @t@ becomes @Grad t@.
gradTransform :: BVarOptions -> RecordTranstorm
gradTransform options =
  RecordTranstorm (optGradNamer options) (AppT (ConT ''Grad))

-- | Transform for metric records: names come from 'optMetricNamer', and every
-- field type @t@ becomes @Metric t@.
metricTransform :: BVarOptions -> RecordTranstorm
metricTransform options =
  RecordTranstorm (optMetricNamer options) (AppT (ConT ''Metric))

-- | Generate a vector record (tangent or gradient) together with its
-- supporting declarations:
--
--   * the record itself ('mkRecord');
--   * a builder record whose field types are wrapped by 'builderTransform',
--     with an instance from 'mkSemigroupInstance';
--   * instances for the vector record from 'mkBasicVectorInstance',
--     'mkAdditiveGroupInstance' and 'mkVectorSpaceInstance'.
--
-- Declarations are returned in this order: record, builder record,
-- BasicVector instance, Semigroup instance, AdditiveGroup instance,
-- VectorSpace instance.
mkVec :: Cxt -> [Type] -> Type -> DownhillRecord -> BVarOptions -> Q [Dec]
mkVec cxt instVars scalarType vectorRecord options = do
  let builderRecord = renameDownhillRecord (builderTransform options) vectorRecord
  recordDecs <- mkRecord vectorRecord
  builderDecs <- mkRecord builderRecord
  semigroupDecs <- mkSemigroupInstance cxt builderRecord instVars
  basicVectorDecs <- mkBasicVectorInstance vectorRecord options cxt instVars
  additiveDecs <- mkAdditiveGroupInstance cxt vectorRecord instVars
  vectorSpaceDecs <- mkVectorSpaceInstance vectorRecord scalarType cxt instVars
  return $
    concat
      [ recordDecs,
        builderDecs,
        basicVectorDecs,
        semigroupDecs,
        additiveDecs,
        vectorSpaceDecs
      ]

-- | Generate every declaration derived from one point record:
--
--   * tangent and gradient records, each with a builder record and instances
--     ('mkVec');
--   * the metric record with instances from 'mkAdditiveGroupInstance' and
--     'mkVectorSpaceInstance';
--   * the instance produced by 'mkDualInstance' for the tangent and gradient
--     records;
--   * the 'MetricTensor' instance ('mkMetricInstance');
--   * an 'AffineSpace' instance. It is always generated for
--     'MakeAffineSpace', never for 'NoAffineSpace', and for
--     'AutoAffineSpace' only when no fields are excluded;
--   * one @HasField@ instance on 'BVar' per named field ('mkGetField'), or
--     none when the record has no field names.
--
-- Field types for the @HasField@ instances are taken from @substitutedCInfo@.
mkDVar'' ::
  Cxt ->
  DownhillRecord ->
  BVarOptions ->
  Type ->
  [Type] ->
  ConstructorInfo ->
  Q [Dec]
mkDVar'' cxt pointRecord options scalarType instVars substitutedCInfo = do
  let tangRecord = renameDownhillRecord (tangTransform options) pointRecord
      gradRecord = renameDownhillRecord (gradTransform options) pointRecord
      metricRecord = renameDownhillRecord (metricTransform options) pointRecord
      gradBuilderRecord = renameDownhillRecord (builderTransform options) gradRecord

  tangDecs <- mkVec cxt instVars scalarType tangRecord options
  gradDecs <- mkVec cxt instVars scalarType gradRecord options

  metricDec <- mkRecord metricRecord
  additiveMetric <- mkAdditiveGroupInstance cxt metricRecord instVars
  vspaceMetric <- mkVectorSpaceInstance metricRecord scalarType cxt instVars
  dualInstance <- mkDualInstance tangRecord gradRecord scalarType cxt instVars
  metricInstance <-
    mkMetricInstance metricRecord tangRecord gradRecord scalarType cxt instVars

  affineSpaceInstance <-
    if needAffineSpace
      then mkAffineSpaceInstance cxt pointRecord tangRecord instVars
      else return []

  hasFieldInstance <- case ddtFieldNames pointRecord of
    Nothing -> return []
    Just names -> do
      -- NOTE(review): names come from pointRecord, which may have had
      -- fields removed by filterFields, while types come from
      -- substitutedCInfo. zipWith3 pairs them positionally, so if
      -- substitutedCInfo still lists excluded fields the two misalign.
      -- Confirm against the caller.
      let fields =
            zipWith3
              (\i name -> FieldInfo name i)
              [0 ..]
              names
              (constructorFields substitutedCInfo)
      concat <$> traverse (mkGetField pointRecord gradBuilderRecord cxt instVars) fields

  return $
    concat
      [ tangDecs,
        gradDecs,
        additiveMetric,
        vspaceMetric,
        dualInstance,
        metricDec,
        metricInstance,
        hasFieldInstance,
        affineSpaceInstance
      ]
  where
    needAffineSpace :: Bool
    needAffineSpace = case optAffineSpace options of
      MakeAffineSpace -> True
      NoAffineSpace -> False
      -- Presumably because excluded fields have no counterpart in the
      -- tangent record, so a field-wise AffineSpace is not derivable. TODO confirm.
      AutoAffineSpace -> null (optExcludeFields options)

-- | Split an applied type @T a1 … an@ into the type constructor name @T@ and
-- its arguments @[a1, …, an]@.
--
-- @vars@ accumulates the arguments already peeled off the right-hand end
-- (callers start with @[]@). Any type that is not a chain of 'AppT' ending in
-- a 'ConT' is rejected with 'fail'.
parseRecordType :: Type -> [Type] -> Q (Name, [Type])
parseRecordType type_ vars = case type_ of
  AppT inner arg -> parseRecordType inner (arg : vars)
  ConT recordName -> return (recordName, vars)
  _ -> fail "Expected (T a1 ... an) in constraint"

-- | Generate an 'AffineSpace' instance for the point record, with the tangent
-- record as its difference type:
--
-- @
-- instance cxt => AffineSpace (RecordPoint a1 … an) where
--   type Diff (RecordPoint a1 … an) = RecordTang a1 … an
--   (.+^) = …  -- generated by elementwiseOp'
--   (.-.) = …  -- generated by elementwiseOp'
-- @
--
-- The instance declaration itself is assembled by 'mkClassInstance'.
mkAffineSpaceInstance :: Cxt -> DownhillRecord -> DownhillRecord -> [Type] -> Q [Dec]
mkAffineSpaceInstance cxt recordPoint recordTang instVars = do
  -- Record arguments presumably follow (left operand, right operand, result),
  -- matching (.+^) :: p -> Diff p -> p and (.-.) :: p -> p -> Diff p.
  plusDec <- elementwiseOp' recordPoint recordTang recordPoint '(.+^)
  minusDec <- elementwiseOp' recordPoint recordPoint recordTang '(.-.)
  mkClassInstance ''AffineSpace cxt recordPoint instVars [plusDec, minusDec, diffTypeDec]
  where
    -- Apply a record's type constructor to the instance type arguments.
    applyInstVars :: DownhillRecord -> Type
    applyInstVars record = foldl AppT (ConT (ddtTypeConName record)) instVars

    -- type Diff (RecordPoint a1 … an) = RecordTang a1 … an
    diffTypeDec :: Dec
    diffTypeDec =
      TySynInstD
        ( TySynEqn
            Nothing
            (AppT (ConT ''Diff) (applyInstVars recordPoint))
            (applyInstVars recordTang)
        )

-- | Remove the fields named in 'optExcludeFields' from a record description.
--
-- When nothing is excluded, the record is returned unchanged. Otherwise this
-- fails if the record has no field names (positional constructor), or if an
-- excluded name is not one of the record's fields. On success,
-- 'ddtFieldTypes', 'ddtFieldNames' and 'ddtFieldCount' are all restricted
-- to the kept fields, in declaration order.
filterFields :: forall m. MonadFail m => BVarOptions -> DownhillRecord -> m DownhillRecord
filterFields options record =
  case optExcludeFields options of
    [] -> return record
    _ -> do
      fieldList <- case ddtFieldNames record of
        Just fields -> return fields
        Nothing -> fail (recordName ++ " is not a record, can't exclude fields")
      doFilterFields fieldList
  where
    recordName :: String
    recordName = nameBase (ddtTypeConName record)

    doFilterFields :: [String] -> m DownhillRecord
    doFilterFields fieldList = do
      traverse_ check (optExcludeFields options)
      return
        record
          { ddtFieldTypes = go (ddtFieldTypes record),
            ddtFieldNames = go <$> ddtFieldNames record,
            ddtFieldCount = goN (ddtFieldCount record)
          }
      where
        -- Every excluded name must be an actual field of the record.
        check :: String -> m ()
        check name
          | name `elem` fieldList = return ()
          | otherwise = fail ("Field " ++ name ++ " is not a member of " ++ recordName)

        -- One selector per declared field, in declaration order.
        -- 'Nothing' drops the field and 'Just' keeps it.
        selectors :: [x -> Maybe x]
        selectors = map selector fieldList
          where
            selector :: String -> x -> Maybe x
            selector fieldName
              | fieldName `elem` optExcludeFields options = const Nothing
              | otherwise = Just

        -- Keep the elements whose positionally-corresponding field is kept.
        go :: [a] -> [a]
        go = catMaybes . zipWith ($) selectors

        -- Number of kept fields among the first @count@ fields.
        goN :: Int -> Int
        goN count = length (go (replicate count ()))

-- | Expands one quoted @instance HasGrad (T a b ...) where type MScalar (T a b ...) = s@
-- declaration.
--
-- The result consists of:
--
--   * the declarations produced by @mkDVar''@ for the record (after
--     'filterFields' has removed excluded fields), and
--
--   * the original @HasGrad@ instance, extended with @Tang@, @Grad@ and
--     @Metric@ type instances pointing at the renamed record types.
--
-- Fails (in 'Q') when:
--
--   * the declaration is not an instance, or it is an overlapping instance;
--
--   * the instance head is not @HasGrad@ applied to a type;
--
--   * the instance body is not exactly one @MScalar@ equation;
--
--   * the instance head does not apply the record to exactly as many type
--     arguments as the record has type parameters.
mkDVarC1 :: BVarOptions -> Dec -> Q [Dec]
mkDVarC1 options = \case
  InstanceD mayOverlap instanceCxt type_ decs -> do
    case mayOverlap of
      Just _ -> fail "Overlapping instances not implemented"
      Nothing -> return ()
    case type_ of
      AppT (ConT hasgradCtx) recordInConstraintType -> do
        when (hasgradCtx /= ''HasGrad) $
          fail $ "Constraint must be `HasGrad`, got " ++ show hasgradCtx
        (recordName, instVars) <- parseRecordType recordInConstraintType []
        datatypeInfo <- reifyDatatype recordName

        (fullParsedRecord, cinfo) <- parseDownhillRecord recordName datatypeInfo
        parsedRecord <- filterFields options fullParsedRecord
        recordTypeVarNames <- traverse getTyVarName (datatypeInstTypes datatypeInfo)

        -- The instance head must apply the record to every one of its type
        -- parameters. Otherwise 'zip' below would silently drop the surplus
        -- names, and the generated code would fail with an obscure error.
        when (length recordTypeVarNames /= length instVars) $
          fail $
            "Instance head applies "
              ++ nameBase recordName
              ++ " to "
              ++ show (length instVars)
              ++ " type argument(s), but it has "
              ++ show (length recordTypeVarNames)
              ++ " type parameter(s)"

        -- We have two sets of type variables: one in record definition (as in `data MyRecord a b c = ...`)
        -- and another one in instance head (`instance HasGrad (MyRecord a' b' c')`). We need
        -- those from instance head for HasField instances.
        let substitutedRecord :: ConstructorInfo
            substitutedRecord =
              applySubstitution (Map.fromList (zip recordTypeVarNames instVars)) cinfo

        scalarType <- parseScalarType decs

        dvar <- mkDVar'' instanceCxt parsedRecord options scalarType instVars substitutedRecord

        let -- Name of the record's type constructor after one of the
            -- option-selected renamings (tangent / gradient / metric).
            renamedTypeName transformOf =
              ddtTypeConName (renameDownhillRecord (transformOf options) parsedRecord)

            -- Builds @type <family> (Record args) = <renamed> args@.
            -- 'foldl' left-nests the applications, as type application
            -- requires; the argument list is tiny, so strictness is irrelevant.
            assocTypeDec :: Name -> Name -> Dec
            assocTypeDec family renamed =
              TySynInstD
                ( TySynEqn
                    Nothing
                    (AppT (ConT family) recordInConstraintType)
                    (foldl AppT (ConT renamed) instVars)
                )

            hasgradInstance :: Dec
            hasgradInstance =
              InstanceD
                Nothing
                instanceCxt
                type_
                ( decs
                    ++ [ assocTypeDec ''Tang (renamedTypeName tangTransform),
                         assocTypeDec ''Grad (renamedTypeName gradTransform),
                         assocTypeDec ''Metric (renamedTypeName metricTransform)
                       ]
                )
        return $ dvar ++ [hasgradInstance]
      _ -> fail "Instance head is not a constraint"
  _ -> fail "Expected instance declaration"
  where
    -- Extracts the name of a record type parameter as reported by
    -- 'datatypeInstTypes'. Both kind-annotated and bare variables are accepted.
    getTyVarName :: Type -> Q Name
    getTyVarName = \case
      SigT (VarT name) _ -> return name
      VarT name -> return name
      _ -> fail "Type variable is not VarT"

    -- The user-written instance body must be exactly one
    -- @type MScalar ... = s@ equation; returns @s@.
    parseScalarType :: [Dec] -> Q Type
    parseScalarType = \case
      [] -> fail "`HasGrad` instance has no declarations"
      [TySynInstD (TySynEqn _ (AppT (ConT scalarName) _) scalarType)] -> do
        when (scalarName /= ''MScalar) $
          fail ("Expected `MScalar` equation, got " ++ show scalarName)
        return scalarType
      [_] -> fail "HasGrad instance must contain `MScalar ... = ...` declaration"
      _ -> fail "`HasGrad` has multiple declarations"

-- | Runs the quoted declarations and expands each @instance HasGrad ...@ in them.
--
-- The output contains the @HasGrad@ instances (extended with @Tang@, @Grad@ and
-- @Metric@ type instances), the @VecBuilder@ types, and all other instances
-- needed for @HasGrad@. Each input declaration is expanded in turn, and the
-- results are concatenated in input order.
mkHasGradInstances :: BVarOptions -> Q [Dec] -> Q [Dec]
mkHasGradInstances options decs = do
  instanceDecs <- decs
  perInstance <- mapM (mkDVarC1 options) instanceDecs
  return (concat perInstance)