{-# LANGUAGE TemplateHaskellQuotes #-}

module Clash.Core.TermLiteral.TH
  (  deriveTermToData
  ) where

import           Data.Either
import qualified Data.Text                       as Text
import           Language.Haskell.TH.Syntax

import           Clash.Core.DataCon
import           Clash.Core.Term                 (collectArgs, Term(Data))
import           Clash.Core.Name                 (nameOcc)

-- Workaround for a strange GHC bug, where it complains about Subst only
-- existing as a boot file:
--
-- module Clash.Core.Subst cannot be linked; it is only available as a boot module
import Clash.Core.Subst ()

-- | The occurrence name of a Core data constructor, unpacked to a 'String'.
-- Used as a view pattern in the code generated by 'deriveTermToData1'.
dcName' :: DataCon -> String
dcName' dc = Text.unpack (nameOcc (dcName dc))

-- | Dynamically bound, fully qualified name of @termToData@, referenced by the
-- generated code. Built with 'mkName' rather than a @'termToData@ quote;
-- presumably to avoid importing "Clash.Core.TermLiteral" here (likely an
-- import cycle) — NOTE(review): confirm.
termToDataName :: Name
termToDataName = mkName qualifiedName
 where
  qualifiedName = "Clash.Core.TermLiteral.termToData"

-- | Derive a @termToData@ implementation for the data type named by
-- @typName@, e.g. @$(deriveTermToData ''MyType)@.
--
-- Only plain ('NormalC') and record ('RecC') constructors are supported; any
-- other constructor form makes the splice fail, as does a name that does not
-- refer to a @data@ declaration.
deriveTermToData :: Name -> Q Exp
deriveTermToData typName = do
  info <- reify typName
  constrs <-
    case info of
      TyConI (DataD _ _ _ _ cs _) -> pure cs
      -- Report a clear error instead of a bare pattern-match failure.
      _ -> fail ("deriveTermToData: " ++ show typName ++ " is not a data type")
  deriveTermToData1 <$> traverse toConstr' constrs
 where
  -- Constructor name paired with its number of fields. Record fields are
  -- positional in the generated code, so records are handled like 'NormalC'.
  toConstr' :: Con -> Q (Name, Int)
  toConstr' (NormalC cName fields) = pure (cName, length fields)
  toConstr' (RecC cName fields) = pure (cName, length fields)
  toConstr' c = fail ("Unexpected constructor: " ++ show c)

-- | Build the lambda implementing @termToData@ for a data type, given each
-- constructor's name and field count. The generated expression has the shape:
--
-- > \term@(collectArgs -> (Data (dcName' -> nm), lefts -> args)) ->
-- >   let arg0 = args !! 0; arg1 = args !! 1; ...
-- >   in case nm of
-- >        "ConstrOne" -> ConstrOne <$> termToData arg0
-- >        ...
-- >        _           -> Left term
--
-- The @let@ is omitted when no constructor has fields. A constructor is
-- selected by comparing @dcName'@ of the Core data constructor with 'show' of
-- the Template Haskell 'Name'; anything else yields @Left term@.
deriveTermToData1 :: [(Name, Int)] -> Exp
deriveTermToData1 constrs =
  LamE
    [pat]
    (if null args then theCase else LetE args theCase)
 where
  -- Largest field count over all constructors, i.e. how many argument
  -- bindings the generated code needs. The leading 0 keeps this total for
  -- data types without constructors, where a bare 'maximum' would crash.
  nArgs :: Int
  nArgs = maximum (0 : map snd constrs)

  -- let arg0 = args !! 0; arg1 = args !! 1; ...
  -- The bindings are lazy, so an alternative only indexes the arguments it
  -- uses. NOTE(review): (!!) is partial, so the generated code assumes the
  -- matched Data term has at least as many (Left) arguments as the selected
  -- constructor has fields.
  args :: [Dec]
  args =
    zipWith (\n nm -> ValD (VarP nm) (NormalB (arg n)) []) [0 ..] argNames

  arg :: Integer -> Exp
  arg n = UInfixE (VarE argsName) (VarE '(!!)) (LitE (IntegerL n))

  -- case nm of {"ConstrOne" -> ConstOne <$> termToData arg0; "ConstrTwo" -> ...}
  theCase :: Exp
  theCase =
    CaseE
      (VarE nameName)
      (map match constrs ++ [emptyMatch])

  -- Fallback alternative: unrecognised constructor names give back the term.
  emptyMatch :: Match
  emptyMatch = Match WildP (NormalB (ConE 'Left `AppE` VarE termName)) []

  match :: (Name, Int) -> Match
  match (cName, nFields) =
    Match (LitP (StringL (show cName))) (NormalB (mkCall cName nFields)) []

  -- Right Con                                          (no fields)
  -- Con <$> termToData arg0 <*> ... <*> termToData argN (fields arg0..argN)
  -- Matching on the taken names keeps this total, without head/tail.
  mkCall :: Name -> Int -> Exp
  mkCall cName nFields =
    case take nFields argNames of
      [] -> ConE 'Right `AppE` ConE cName
      (a0:rest) ->
        foldl
          (\e aName -> UInfixE e (VarE '(<*>)) (callTermToData aName))
          (UInfixE (ConE cName) (VarE '(<$>)) (callTermToData a0))
          rest

  -- termToData argK
  callTermToData :: Name -> Exp
  callTermToData aName = VarE termToDataName `AppE` VarE aName

  -- term@(collectArgs -> (Data (dcName' -> nm), args))
  -- NOTE(review): 'lefts' presumably keeps the term (not type) arguments
  -- returned by collectArgs — confirm against Clash.Core.Term.
  pat :: Pat
  pat =
    AsP
      termName
      (ViewP
        (VarE 'collectArgs)
        (TupP [ ConP 'Data [ViewP (VarE 'dcName') (VarP nameName)]
              , ViewP
                 (VarE 'lefts)
                 (if nArgs == 0 then WildP else VarP argsName)]))

  termName :: Name
  termName = mkName "term"

  argsName :: Name
  argsName = mkName "args"

  -- arg0, arg1, ... : one binding per field of the widest constructor.
  argNames :: [Name]
  argNames = [mkName ("arg" ++ show n) | n <- [0 .. nArgs - 1]]

  nameName :: Name
  nameName = mkName "nm"