{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecursiveDo       #-}
{-# LANGUAGE TupleSections     #-}
{-# LANGUAGE ViewPatterns      #-}

-- | Generate VHDL for assorted Netlist datatypes
module CLaSH.Netlist.VHDL
  ( genVHDL
  , mkTyPackage
  , vhdlType
  , vhdlTypeDefault
  , vhdlTypeMark
  , inst
  , expr
  )
where

import qualified Control.Applicative                  as A
import           Control.Lens                         hiding (Indexed)
import           Control.Monad                        (liftM,zipWithM)
import           Control.Monad.State                  (State)
import           Data.Graph.Inductive                 (Gr, mkGraph, topsort')
import qualified Data.HashMap.Lazy                    as HashMap
import qualified Data.HashSet                         as HashSet
import           Data.List                            (nub)
import           Data.Maybe                           (catMaybes,mapMaybe)
import           Data.Text.Lazy                       (unpack)
import qualified Data.Text.Lazy                       as T
import           Text.PrettyPrint.Leijen.Text.Monadic

import           CLaSH.Netlist.Types
import           CLaSH.Netlist.Util
import           CLaSH.Util                           (makeCached, (<:>))

-- | Monad in which all VHDL documents are produced: a 'State' computation
-- threading a 'VHDLState'.
type VHDLM a = State VHDLState a

-- | Render a complete VHDL design unit for a Netlist component.
--
-- Every type that 'needsTyDec' reports for the component's output, inputs and
-- net declarations is first inserted into the set held in the first field of
-- the state. The result pairs the component name with the rendered document:
-- import clauses, entity and architecture, each separated by a blank line.
genVHDL :: Component -> VHDLM (String,Doc)
genVHDL c = do
    _1 %= flip (foldr HashSet.insert) requiredTys
    doc <- tyImports (not (null requiredTys)) <$$> linebreak <>
           entity c <$$> linebreak <>
           architecture c
    return (unpack (componentName c), doc)
  where
    -- Only net declarations carry a type; every other declaration is skipped
    netDeclTy (NetDecl _ ty _) = [ty]
    netDeclTy _                = []

    usedTys     = snd (output c)
                : map snd (inputs c)
               ++ concatMap netDeclTy (declarations c)

    -- The types package is imported only when this list is non-empty
    requiredTys = nub (concatMap needsTyDec usedTys)

-- | Generate a VHDL package named @types@ for the given HWTypes.
--
-- The types are ordered with 'topSortHWTys' and each one is declared through
-- 'tyDec'. A package body is appended only when 'funDec' yields at least one
-- function definition.
mkTyPackage :: [HWType]
            -> VHDLM Doc
mkTyPackage hwtys =
    "library IEEE;" <$>
    "use IEEE.STD_LOGIC_1164.ALL;" <$>
    "use IEEE.NUMERIC_STD.ALL;" <$$> linebreak <>
    pkgHeader <$>
      pkgDecls <$>
    pkgFooter
  where
    orderedTys = topSortHWTys hwtys

    pkgHeader  = "package" <+> "types" <+> "is"
    pkgDecls   = indent 2 . vcat $ mapM tyDec orderedTys
    pkgFooter  = "end" <> semi <> pkgBody

    -- Empty unless some type needs a helper function body
    pkgBody = do
      bodies <- liftM catMaybes (mapM funDec orderedTys)
      if null bodies
        then empty
        else linebreak <$>
             "package" <+> "body" <+> "types" <+> "is" <$>
               indent 2 (vcat (return bodies)) <$>
             "end" <> semi

-- | Order the given types so that every type appears after the types it is
-- built from.
--
-- A graph is built with one node per input type and an edge from each
-- 'Vector', 'Product' and 'SP' type to those of its component types that are
-- themselves in the input list. The reversed topological sort of that graph
-- places the components first.
--
-- 'needsTyDec' registers vector types with their length replaced by 0
-- (@Vector 0 elTy@), while a component type nested inside another type keeps
-- its real length. A component is therefore looked up as-is first and, when
-- that fails, with its vector length reset to 0; otherwise the dependency on
-- the element array type would be silently dropped.
topSortHWTys :: [HWType]
             -> [HWType]
topSortHWTys hwtys = sorted
  where
    nodes  = zip [0..] hwtys
    nodesI = HashMap.fromList (zip hwtys [0..])
    edges  = concatMap edge hwtys
    graph  = mkGraph nodes edges :: Gr HWType ()
    sorted = reverse $ topsort' graph

    -- Mirror the length-erasing normalisation performed by 'needsTyDec'
    sizeless (Vector _ elTy) = Vector 0 elTy
    sizeless ty              = ty

    -- Node of a component type, if that type is part of the input list
    depNode ty = HashMap.lookup ty nodesI A.<|> HashMap.lookup (sizeless ty) nodesI

    -- Edges from @t@ to each of the listed component types that has a node
    edgesFrom t deps = mapMaybe (\ty -> liftM (nodesI HashMap.! t,,()) (depNode ty)) deps

    edge t@(Vector _ elTy) = edgesFrom t [elTy]
    edge t@(Product _ tys) = edgesFrom t tys
    edge t@(SP _ ctys)     = edgesFrom t (concatMap snd ctys)
    edge _                 = []

-- | List the types that must be declared (or need helper functions) before a
-- value of the given type can be used.
--
-- Bit vectors need nothing. Other vectors need whatever their element type
-- needs plus the array type itself, recorded with its length set to 0.
-- Products need their field types followed by themselves. 'SP' types only
-- need what their constructor arguments need. 'Bool' and 'Integer' are
-- reported as themselves.
needsTyDec :: HWType -> [HWType]
needsTyDec hwty = case hwty of
  Vector _ Bit  -> []
  Vector _ elTy -> needsTyDec elTy ++ [Vector 0 elTy]
  Product _ tys -> concatMap needsTyDec tys ++ [hwty]
  SP _ cons     -> concatMap needsTyDec (concatMap snd cons)
  Bool          -> [Bool]
  Integer       -> [Integer]
  _             -> []

-- | Package-level declaration for a type.
--
-- 'Bool' and 'Integer' yield function prototypes, vectors an unconstrained
-- array type named after their element type, and products a record type whose
-- fields are named @\<record name\>_sel\<index\>@. Every other type yields an
-- empty document.
tyDec :: HWType -> VHDLM Doc
tyDec hwty = case hwty of
    Bool          -> "function" <+> "toSLV" <+> parens ("b" <+> colon <+> "in" <+> "boolean") <+> "return" <+> "std_logic_vector" <> semi
    Integer       -> "function" <+> "to_integer" <+> parens ("i" <+> colon <+> "in" <+> "integer") <+> "return" <+> "integer" <> semi
    Vector _ elTy -> "type" <+> "array_of_" <> tyName elTy <+> "is array (natural range <>) of" <+> vhdlType elTy <> semi
    Product _ tys -> recordDec tys
    _             -> empty
  where
    -- Only evaluated for products
    recName :: VHDLM Doc
    recName = tyName hwty

    recordDec fieldTys =
      "type" <+> recName <+> "is record" <$>
        indent 2 (vcat $ zipWithM field [0..] fieldTys) <$>
      "end record" <> semi

    -- One record field: name, colon, type, semicolon
    field i fieldTy = (recName <> "_sel" <> int i) <+> colon <+> vhdlType fieldTy <> semi

-- | Package-body function definition belonging to a type, if it has one.
--
-- 'Bool' gets the body of @toSLV@, returning @"1"@ for true and @"0"@
-- otherwise. 'Integer' gets an identity @to_integer@. Every other type has no
-- function.
funDec :: HWType -> VHDLM (Maybe Doc)
funDec hwty = case hwty of
    Bool    -> liftM Just boolToSLV
    Integer -> liftM Just integerIdentity
    _       -> return Nothing
  where
    boolToSLV, integerIdentity, boolBranches :: VHDLM Doc

    boolToSLV =
      "function" <+> "toSLV" <+> parens ("b" <+> colon <+> "in" <+> "boolean") <+> "return" <+> "std_logic_vector" <+> "is" <$>
      "begin" <$>
        indent 2 boolBranches <$>
      "end" <> semi

    boolBranches = vcat $ sequence
      [ "if" <+> "b" <+> "then"
      , indent 2 ("return" <+> dquotes (int 1) <> semi)
      , "else"
      , indent 2 ("return" <+> dquotes (int 0) <> semi)
      , "end" <+> "if" <> semi
      ]

    integerIdentity =
      "function" <+> "to_integer" <+> parens ("i" <+> colon <+> "in" <+> "integer") <+> "return" <+> "integer" <+> "is" <$>
      "begin" <$>
        indent 2 ("return" <+> "i" <> semi) <$>
      "end" <> semi

-- | Identifier fragment naming a type, used to build the names of array types
-- (@array_of_\<name\>@) and of record types and their fields.
--
-- Product names are generated on first use as @product\<n\>@, with @n@ taken
-- from the counter in the second state field, and cached via 'makeCached' in
-- the third state field.
--
-- 'Bool', 'Clock', 'Reset' and 'SP' are named after the VHDL type that
-- 'vhdlType' maps them to. Without these cases a vector of such elements was
-- declared and referenced as @array_of_@, which is not a legal VHDL identifier
-- and is shared by every such element type.
--
-- Any remaining type yields an empty document.
tyName :: HWType -> VHDLM Doc
tyName Integer           = "integer"
tyName Bit               = "std_logic"
tyName Bool              = "boolean"
tyName (Clock _)         = "std_logic"
tyName (Reset _)         = "std_logic"
tyName (Vector n Bit)    = "std_logic_vector_" <> int n
tyName (Vector n elTy)   = "array_of_" <> int n <> "_" <> tyName elTy
tyName (Signed n)        = "signed_" <> int n
tyName (Unsigned n)      = "unsigned_" <> int n
tyName t@(Sum _ _)       = "unsigned_" <> int (typeSize t)
tyName t@(SP _ _)        = "std_logic_vector_" <> int (typeSize t)
tyName t@(Product _ _)   = makeCached t _3 prodName
  where
    -- Post-increment the product counter and use the old value in the name
    prodName = do i <- _2 <<%= (+1)
                  "product" <> int i

tyName _ = empty

-- | Library and use clauses placed at the top of a generated design unit, one
-- per line and each terminated by a semicolon. The @work.types@ package is
-- imported only when the argument is 'True'.
tyImports :: Bool -> VHDLM Doc
tyImports needsDec = punctuate' semi (sequence clauses)
  where
    clauses :: [VHDLM Doc]
    clauses = standardClauses ++ [ "use work.types.all" | needsDec ]

    standardClauses =
      [ "library IEEE"
      , "use IEEE.STD_LOGIC_1164.ALL"
      , "use IEEE.NUMERIC_STD.ALL"
      , "use work.all"
      ]


-- | Entity declaration of a component.
--
-- Ports are listed in the order inputs, hidden ports, output; each one is
-- given a default value from 'vhdlTypeDefault'. Port names are padded with
-- 'fill' to the length of the longest name, which is only known after all
-- ports have been visited, hence the recursive binding.
entity :: Component -> VHDLM Doc
entity c = do
    rec (portDocs,nameLens) <- fmap unzip (portDecls (maximum nameLens))
    "entity" <+> text (componentName c) <+> "is" <$>
      portClause portDocs <$>
      "end" <> semi
  where
    portClause [] = empty
    portClause ps = indent 2 ("port" <>
                              parens (align $ vcat $ punctuate semi (A.pure ps)) <>
                              semi)

    portDecls w = sequence $
      map (portDecl w "in") (inputs c ++ hiddenPorts c) ++
      [ portDecl w "out" (output c) ]

    -- A single port line, paired with the length of the port name
    portDecl w dir (nm,ty) =
      (,fromIntegral $ T.length nm) A.<$>
        (fill w (text nm) <+> colon <+> dir <+> vhdlType ty <+> ":=" <+> vhdlTypeDefault ty)

-- | Architecture body (named @structural@) of a component: the signal
-- declarations produced by 'decls', followed by the concurrent statements
-- produced by 'insts'.
architecture :: Component -> VHDLM Doc
architecture c = declarativePart <$$> statementPart <$$> "end" <> semi
  where
    ds = declarations c

    declarativePart =
      nest 2 ("architecture structural of" <+> text (componentName c) <+> "is" <$$>
              decls ds)

    statementPart =
      nest 2 ("begin" <$$>
              insts ds)

-- | Convert a Netlist HWType to a VHDL type, including the index constraint
-- (@n-1 downto 0@) for sized types. Calls 'error' for any type not listed.
vhdlType :: HWType -> VHDLM Doc
vhdlType hwty = case hwty of
    Bit           -> "std_logic"
    Bool          -> "boolean"
    Clock _       -> "std_logic"
    Reset _       -> "std_logic"
    Integer       -> "integer"
    Signed n      -> "signed" <> downtoZero n
    Unsigned n    -> "unsigned" <> downtoZero n
    Vector n Bit  -> "std_logic_vector" <> downtoZero n
    Vector n elTy -> "array_of_" <> tyName elTy <> downtoZero n
    SP _ _        -> "std_logic_vector" <> downtoZero (typeSize hwty)
    Sum _ _       -> "unsigned" <> downtoZero (typeSize hwty)
    Product _ _   -> tyName hwty
    _             -> error $ "vhdlType: " ++ show hwty
  where
    -- Index constraint covering @n@ elements
    downtoZero n = parens (int (n-1) <+> "downto 0")

-- | Convert a Netlist HWType to the root of a VHDL type, i.e. the type name
-- without any index constraint. Calls 'error' for any type not listed.
vhdlTypeMark :: HWType -> VHDLM Doc
vhdlTypeMark hwty = case hwty of
  Bit           -> "std_logic"
  Bool          -> "boolean"
  Clock _       -> "std_logic"
  Reset _       -> "std_logic"
  Integer       -> "integer"
  Signed _      -> "signed"
  Unsigned _    -> "unsigned"
  Vector _ Bit  -> "std_logic_vector"
  Vector _ elTy -> "array_of_" <> tyName elTy
  SP _ _        -> "std_logic_vector"
  Sum _ _       -> "unsigned"
  Product _ _   -> tyName hwty
  _             -> error $ "vhdlTypeMark: " ++ show hwty

-- | Convert a Netlist HWType to a default VHDL value for that type: zero or
-- false for scalars, an all-zero aggregate for bit-level types, and the
-- element-wise or field-wise default for vectors and products. Calls 'error'
-- for any type not listed.
vhdlTypeDefault :: HWType -> VHDLM Doc
vhdlTypeDefault hwty = case hwty of
    Bit             -> zeroBit
    Bool            -> "false"
    Integer         -> "0"
    Signed _        -> allZeros
    Unsigned _      -> allZeros
    Vector _ elTy   -> parens ("others" <+> rarrow <+> vhdlTypeDefault elTy)
    SP _ _          -> allZeros
    Sum _ _         -> allZeros
    Product _ elTys -> tupled (mapM vhdlTypeDefault elTys)
    Reset _         -> zeroBit
    Clock _         -> zeroBit
    _               -> error $ "vhdlTypeDefault: " ++ show hwty
  where
    zeroBit, allZeros :: VHDLM Doc
    zeroBit  = "'0'"
    allZeros = "(others => '0')"

-- | Signal declarations for the declarative part of an architecture.
--
-- Declarations for which 'decl' returns 'Nothing' are dropped. Signal names
-- are padded to the length of the longest name, which is tied back into
-- 'decl' through the recursive binding. Yields an empty document when nothing
-- is left to declare.
decls :: [Declaration] -> VHDLM Doc
decls ds
  | null ds   = empty
  | otherwise = do
      rec (docs,nameLens) <- fmap (unzip . catMaybes) (mapM (decl (maximum nameLens)) ds)
      if null docs
        then empty
        else vcat (punctuate semi (A.pure docs)) <> semi

-- | Signal declaration for a single net, paired with the length of the
-- signal's name. The first argument is the width the name is padded to. The
-- initial value is the net's own initialiser when present, otherwise the
-- default for its type. Anything other than a 'NetDecl' yields 'Nothing'.
decl :: Int ->  Declaration -> VHDLM (Maybe (Doc,Int))
decl l (NetDecl id_ ty netInit) = do
    sigDoc <- "signal" <+> fill l (text id_) <+> colon <+> vhdlType ty <+> ":=" <+> initDoc
    return $ Just (sigDoc, fromIntegral (T.length id_))
  where
    initDoc = case netInit of
                Nothing -> vhdlTypeDefault ty
                Just e  -> expr False e

decl _ _ = return Nothing

-- | Concurrent statements of an architecture: every declaration for which
-- 'inst' produces a document, with 'linebreak' as the separator.
insts :: [Declaration] -> VHDLM Doc
insts ds
  | null ds   = empty
  | otherwise = vcat (punctuate linebreak blocks)
  where
    blocks = fmap catMaybes (mapM inst ds)

-- | Turn a Netlist Declaration to a VHDL concurrent block.
--
-- Plain and conditional assignments, component instantiations and black boxes
-- each produce a statement; any other declaration yields 'Nothing'.
inst :: Declaration -> VHDLM (Maybe Doc)
inst d = case d of
    Assignment id_ e ->
      emit $ text id_ <+> larrow <+> expr False e <> semi
    CondAssignment id_ scrut es ->
      emit $ text id_ <+> larrow <+> align (vcat (mapM (branch scrut) es)) <> semi
    InstDecl nm lbl pms ->
      emit $ nest 2 $ text lbl <> "_comp_inst" <+> colon <+> "entity"
                        <+> text nm <$$> portMap pms <> semi
    BlackBoxD bs ->
      emit (string bs)
    _ ->
      return Nothing
  where
    emit :: VHDLM Doc -> VHDLM (Maybe Doc)
    emit = fmap Just

    -- One alternative of a conditional assignment; an alternative without a
    -- pattern is emitted as the bare value with no @when@ clause
    branch :: Expr -> (Maybe Expr,Expr) -> VHDLM Doc
    branch _     (Nothing,e) = expr False e
    branch scrut (Just c ,e) = expr False e <+> "when" <+> parens (expr True scrut <+> "=" <+> expr True c) <+> "else"

    -- Port map with the formal names padded to the longest formal name
    portMap pms = do
      rec (assocs,nameLens) <- fmap unzip $ mapM (assoc (maximum nameLens)) pms
      nest 2 $ "port map" <$$> tupled (A.pure assocs)

    assoc w (i,e) = (,fromIntegral (T.length i)) A.<$> fill w (text i) <+> "=>" <+> expr False e

-- | Turn a Netlist expression into a VHDL expression
--
-- The equations are tried in order and several of them overlap, so their
-- order is significant. The parenthesis flag is only consulted by the
-- 'BlackBoxE' equations. Expressions matched by none of the equations render
-- as an empty document.
expr :: Bool -- ^ Enclose in parenthesis?
     -> Expr -- ^ Expr to convert
     -> VHDLM Doc
-- Literals are delegated to 'exprLit'
expr _ (Literal sizeM lit)                           = exprLit sizeM lit
-- An unmodified identifier is just its name
expr _ (Identifier id_ Nothing)                      = text id_
-- Field fI of constructor dcI of an SP value: slice the field's bits out of
-- the vector and convert the slice back with 'fromSLV'
expr _ (Identifier id_ (Just (Indexed (ty@(SP _ args),dcI,fI)))) = fromSLV argTy selected
  where
    argTys   = snd $ args !! dcI
    argTy    = argTys !! fI
    argSize  = typeSize argTy
    -- Combined size of the fields that precede field fI
    other    = otherSize argTys (fI-1)
    -- Highest bit of the field: skip the constructor bits and the earlier fields
    start    = typeSize ty - 1 - conSize ty - other
    end      = start - argSize + 1
    selected = text id_ <> parens (int start <+> "downto" <+> int end)

-- Field fI of a product: record selection using the generated field name
expr _ (Identifier id_ (Just (Indexed (ty@(Product _ _),_,fI)))) = text id_ <> dot <> tyName ty <> "_sel" <> int fI
-- Constructor part of an SP value: the topmost 'conSize' bits
expr _ (Identifier id_ (Just (DC (ty@(SP _ _),_)))) = text id_ <> parens (int start <+> "downto" <+> int end)
  where
    start = typeSize ty - 1
    end   = typeSize ty - conSize ty

-- Any other modifier is ignored
expr _ (Identifier id_ (Just _)) = text id_
-- A vector constructor chain recognised by 'vectorChain' becomes a
-- parenthesised, comma-separated list of its elements
expr _ (vectorChain -> Just es)                  = tupled (mapM (expr False) es)
-- NOTE(review): 'vectorChain' returns 'Just' for every
-- @DataCon (Vector 1 _) _ [e]@ (both the 'Nothing' and the 'Just' modifier
-- case), so the equation above always wins and this one appears unreachable.
expr _ (DataCon (Vector 1 _) _ [e])              = parens ("others" <+> rarrow <+> expr False e)
-- Reached when 'vectorChain' fails on the tail: concatenate head and tail
expr _ (DataCon (Vector _ _) _ [e1,e2])          = expr False e1 <+> "&" <+> expr False e2
-- SP constructor application: constructor bits, then every argument converted
-- with 'toSLV', then zero padding up to the size of the whole type
expr _ (DataCon ty@(SP _ args) (Just (DC (_,i))) es) = assignExpr
  where
    argTys     = snd $ args !! i
    dcSize     = conSize ty + sum (map typeSize argTys)
    dcExpr     = expr False (dcToExpr ty i)
    argExprs   = zipWith toSLV argTys $ map (expr False) es
    -- Padding for constructors that do not use all bits of the type
    extraArg   = case typeSize ty - dcSize of
                   0 -> []
                   n -> [exprLit (Just n) (NumLit 0)]
    assignExpr = hcat $ punctuate " & " $ sequence (dcExpr:argExprs ++ extraArg)

-- Constructor of a sum type: its index as an unsigned of the type's size
expr _ (DataCon ty@(Sum _ _) (Just (DC (_,i))) []) = "to_unsigned" <> tupled (sequence [int i,int (typeSize ty)])
-- Product constructor: aggregate with named association per record field
expr _ (DataCon ty@(Product _ _) _ es)             = tupled $ zipWithM (\i e -> tName <> "_sel" <> int i <+> rarrow <+> expr False e) [0..] es
  where
    tName = tyName ty

-- Constructor part of a black box producing an SP value: slice the topmost
-- 'conSize' bits of the parenthesised black box text
expr b (BlackBoxE bs (Just (DC (ty@(SP _ _),_)))) = parenIf b $ parens (string bs) <> parens (int start <+> "downto" <+> int end)
  where
    start = typeSize ty - 1
    end   = typeSize ty - conSize ty
-- Any other black box is emitted verbatim
expr b (BlackBoxE bs _) = parenIf b $ string bs

-- Everything else renders as nothing
expr _ _ = empty

-- | Combined 'typeSize' of the elements at indices @0@ up to and including
-- @n@. The result is 0 for a negative @n@, and the list may be shorter than
-- @n + 1@ elements.
otherSize :: [HWType] -> Int -> Int
otherSize tys n = sum (map typeSize (take (n + 1) tys))

-- | Collect the elements of a chain of vector constructor applications.
--
-- A vector constructor without a modifier contributes no elements, a
-- one-element vector contributes its single argument, and a two-argument
-- constructor contributes its first argument followed by the chain of the
-- second. The result is 'Nothing' as soon as any link has another shape.
vectorChain :: Expr -> Maybe [Expr]
vectorChain e = case e of
  DataCon (Vector _ _) Nothing  _        -> Just []
  DataCon (Vector 1 _) (Just _) [x]      -> Just [x]
  DataCon (Vector _ _) (Just _) [x,rest] -> Just x <:> vectorChain rest
  _                                      -> Nothing

-- | Render a literal. Numbers without a size are printed in decimal, numbers
-- with a size as a bit string of that many bits. Booleans and single bits
-- ignore the size. Any other combination calls 'error'.
exprLit :: Maybe Size -> Literal -> VHDLM Doc
exprLit mSize lit = case (mSize,lit) of
  (Nothing, NumLit i)  -> int i
  (Just sz, NumLit i)  -> bits (toBits sz i)
  (_      , BoolLit t) -> if t then "true" else "false"
  (_      , BitLit b)  -> squotes (bit_char b)
  _                    -> error "exprLit"

-- | The @size@ least significant bits of a value, most significant bit first.
-- Each bit is obtained by repeatedly halving the value with 'div' and testing
-- the parity of the intermediate results.
toBits :: Integral a => Int -> a -> [Bit]
toBits size val = reverse [ if odd q then H else L | q <- take size halvings ]
  where
    halvings = iterate (`div` 2) val

-- | Render a list of bits as a double-quoted string, one character per bit.
bits :: [Bit] -> VHDLM Doc
bits bs = dquotes (hcat (mapM bit_char bs))

-- | The single character that stands for a bit value.
bit_char :: Bit -> VHDLM Doc
bit_char b = char $ case b of
  H -> '1'
  L -> '0'
  U -> 'U'
  Z -> 'Z'

-- | Wrap an expression of the given type so that it becomes a
-- @std_logic_vector@. Used by 'expr' to pack the arguments of an 'SP'
-- constructor.
--
-- 'SP' values are already rendered as @std_logic_vector@ by 'vhdlType', so
-- they are passed through unchanged, mirroring the 'SP' case of 'fromSLV'.
-- Before this case existed, an 'SP'-typed constructor argument hit the final
-- 'error' equation.
--
-- Calls 'error' for any type not listed.
toSLV :: HWType -> VHDLM Doc -> VHDLM Doc
toSLV Bit        d   = parens (int 0 <+> rarrow <+> d)
toSLV Bool       d   = "toSLV" <> parens d
-- Integers are packed as 32-bit signed numbers
toSLV Integer    d   = toSLV (Signed 32) ("to_signed" <> tupled (sequence [d,int 32]))
toSLV (Signed _) d   = "std_logic_vector" <> parens d
toSLV (Unsigned _) d = "std_logic_vector" <> parens d
toSLV (Sum _ _) d    = "std_logic_vector" <> parens d
toSLV (SP _ _) d     = d
toSLV hty          _ = error $ "toSLV: " ++ show hty

-- | Wrap a @std_logic_vector@ expression so that it becomes a value of the
-- given type. Used by 'expr' on the slice that holds a field of an 'SP'
-- value.
--
-- The 'Bool' case compares the slice with @"1"@, the encoding the generated
-- @toSLV@ function uses for true. It previously emitted a call to a
-- @fromSLV@ function, but neither 'tyDec' nor 'funDec' declares or defines
-- one, only @toSLV@.
-- NOTE(review): assumes the slice passed for a 'Bool' is one bit wide, as the
-- one-bit result of the generated @toSLV@ suggests — confirm against
-- 'typeSize'.
--
-- Calls 'error' for any type not listed.
fromSLV :: HWType -> VHDLM Doc -> VHDLM Doc
fromSLV Bit d          = d <> parens (int 0)
fromSLV Bool d         = parens (d <+> "=" <+> dquotes (int 1))
-- Integers are unpacked from 32-bit signed numbers
fromSLV Integer d      = "to_integer" <> parens (fromSLV (Signed 32) d)
fromSLV (Signed _) d   = "signed" <> parens d
fromSLV (Unsigned _) d = "unsigned" <> parens d
fromSLV (SP _ _) d     = d
fromSLV (Sum _ _) d    = "unsigned" <> parens d
fromSLV hty _          = error $ "fromSLV: " ++ show hty

-- | Literal holding a constructor index, sized to the constructor field
-- ('conSize') of the given type.
dcToExpr :: HWType -> Int -> Expr
dcToExpr ty = Literal (Just (conSize ty)) . NumLit

-- | The @<=@ token that 'inst' places between an assignment target and the
-- assigned value.
larrow :: VHDLM Doc
larrow = "<="

-- | The @=>@ token used inside aggregates, e.g. @others => x@.
rarrow :: VHDLM Doc
rarrow = "=>"

-- | Enclose the document in parentheses when the flag is 'True', otherwise
-- return it unchanged.
parenIf :: Monad m => Bool -> m Doc -> m Doc
parenIf wrap doc
  | wrap      = parens doc
  | otherwise = doc

-- | Stack the documents vertically with the separator after every one of
-- them, including the last.
punctuate' :: Monad m => m Doc -> m [Doc] -> m Doc
punctuate' sep docs = separated <> sep
  where
    separated = vcat (punctuate sep docs)