{-# LANGUAGE CPP #-}
#ifdef TRUSTWORTHY
{-# LANGUAGE Trustworthy #-}
#endif
#ifndef MIN_VERSION_template_haskell
#define MIN_VERSION_template_haskell(x,y,z) 1
#endif
module Control.Lens.Internal.PrismTH
  ( makePrisms
  , makeClassyPrisms
  , makeDecPrisms
  ) where
import Control.Applicative
import Control.Lens.Fold
import Control.Lens.Getter
import Control.Lens.Internal.TH
import Control.Lens.Lens
import Control.Lens.Setter
import Control.Lens.Tuple
import Control.Monad
import Data.Char (isUpper)
import Data.List
import Data.Monoid
import Data.Set.Lens
import Data.Traversable
import Language.Haskell.TH
import Language.Haskell.TH.Lens
import qualified Data.Map as Map
import qualified Data.Set as Set
import Prelude
-- | Generate optics for every constructor of the named type as top-level
-- definitions.
--
-- The name is reified and must refer to a type constructor. A type with a
-- single constructor that carries no context yields one iso; in every other
-- case one optic is produced per constructor, named by 'prismName'.
makePrisms :: Name {- ^ Type constructor name -} -> DecsQ
makePrisms typeName = makePrisms' True typeName
-- | Generate a class of optics for the named type, together with an
-- instance of that class for the type itself.
--
-- The class is called @As@ followed by the base name of the type, and its
-- top-level method is named by applying 'prismName' to the type name.
makeClassyPrisms :: Name {- ^ Type constructor name -} -> DecsQ
makeClassyPrisms typeName = makePrisms' False typeName
-- | Shared driver for 'makePrisms' and 'makeClassyPrisms'.
--
-- Reifies the given name and hands the resulting declaration to
-- 'makeDecPrisms'. The flag is passed through unchanged: 'True' requests
-- top-level definitions, 'False' requests the classy variant. Fails in 'Q'
-- when the name does not reify to a type constructor.
makePrisms' :: Bool -> Name -> DecsQ
makePrisms' normal typeName = reify typeName >>= dispatch
  where
  dispatch (TyConI dec) = makeDecPrisms normal dec
  dispatch _            = fail "makePrisms: expected type constructor name"
-- | Generate optics for a data or newtype declaration, or for a data or
-- newtype family instance.
--
-- The flag selects the flavour: 'True' produces top-level definitions,
-- 'False' produces a class (named after the type constructor) plus an
-- instance. Any other kind of declaration makes the computation fail.
makeDecPrisms :: Bool  -> Dec -> DecsQ
makeDecPrisms normal dec = case dec of
  -- The two preprocessor branches differ only in one extra, ignored field
  -- that precedes the constructors in the newer template-haskell API
  -- (presumably the optional kind signature; confirm against the
  -- template-haskell changelog).
#if MIN_VERSION_template_haskell(2,11,0)
  DataD        _ ty vars _ cons _ -> next ty (convertTVBs vars) cons
  NewtypeD     _ ty vars _ con  _ -> next ty (convertTVBs vars) [con]
  DataInstD    _ ty tys  _ cons _ -> next ty tys                cons
  NewtypeInstD _ ty tys  _ con  _ -> next ty tys                [con]
#else
  DataD        _ ty vars   cons _ -> next ty (convertTVBs vars) cons
  NewtypeD     _ ty vars   con  _ -> next ty (convertTVBs vars) [con]
  DataInstD    _ ty tys    cons _ -> next ty tys                cons
  NewtypeInstD _ ty tys    con  _ -> next ty tys                [con]
#endif
  _                             -> fail "makePrisms: expected type constructor dec"
  where
  -- Turn the binders of an ordinary declaration into type-variable
  -- arguments; family instances already carry their argument types.
  convertTVBs = map (VarT . bndrName)
  -- Build the fully applied outer type, normalize every constructor and
  -- delegate to makeConsPrisms.
  next ty args cons =
    makeConsPrisms (conAppsT ty args) (concatMap normalizeCon cons) cls
    where
    -- Nothing requests top-level definitions; Just carries the type name
    -- from which the class and top method names are derived.
    cls | normal    = Nothing
        | otherwise = Just ty
-- | Generate the declarations for a type given its fully applied outer type,
-- its normalized constructors and, for the classy flavour, the name of the
-- type constructor.
--
-- * A single constructor without a context, non-classy: one iso.
--
-- * Any other non-classy case: a signature and a definition per constructor.
--
-- * Classy: the class declaration followed by its instance.
makeConsPrisms :: Type -> [NCon] -> Maybe Name -> DecsQ
makeConsPrisms t [con@(NCon _ Nothing _)] Nothing = makeConIso t con
makeConsPrisms t cons Nothing = concat <$> traverse mkTopLevel cons
  where
  -- Signature and definition for the optic of one constructor.
  mkTopLevel con =
    do stab <- computeOpticType t cons con
       let defName = prismName (con ^. nconName)
       sig <- sigD defName (close (stabToType stab))
       def <- valD (varP defName) (normalB (makeConOpticExp stab cons con)) []
       return [sig, def]
makeConsPrisms t cons (Just typeName) =
  do cls  <- makeClassyPrismClass t className methodName cons
     inst <- makeClassyPrismInstance t className methodName cons
     return [cls, inst]
  where
  className = mkName ("As" ++ nameBase typeName)
  methodName = prismName typeName
-- | The kind of optic generated for a constructor. 'computeOpticType' picks
-- 'ReviewType' for constructors that carry a context and 'PrismType' for
-- all others.
data OpticType = PrismType | ReviewType
-- | Description of an optic to generate: a context, the optic kind, and the
-- four types @s@, @t@, @a@ and @b@, in that order.
data Stab  = Stab Cxt OpticType Type Type Type Type
-- | Collapse a 'Stab' to its non-type-changing form by reusing @t@ for @s@
-- and @b@ for @a@. The context and optic kind are kept as they are.
simplifyStab :: Stab -> Stab
simplifyStab stab = case stab of
  Stab cx ty _ t _ b -> Stab cx ty t t b b
  
  
-- | 'True' when the optic does not change types, that is when @s@ equals
-- @t@ and @a@ equals @b@.
stabSimple :: Stab -> Bool
stabSimple (Stab _ _ s t a b) = (s, a) == (t, b)
-- | Render a 'Stab' as the type of the optic it describes.
--
-- Only the type variables that occur in the context are bound by the
-- generated forall; all other variables are left free for the caller to
-- deal with. A prism uses the two-argument form when it is simple and the
-- four-argument form otherwise; a review always uses @t@ and @b@.
stabToType :: Stab -> Type
stabToType stab@(Stab cx ty s t a b) = ForallT binders cx body
  where
  -- nub rather than a set so the binders keep first-occurrence order.
  binders = [ PlainTV v | v <- nub (toListOf typeVars cx) ]
  body = case ty of
    ReviewType -> reviewTypeName `conAppsT` [t,b]
    PrismType
      | stabSimple stab -> prism'TypeName `conAppsT` [t,b]
      | otherwise       -> prismTypeName `conAppsT` [s,t,a,b]
-- | Extract the optic kind recorded in a 'Stab'.
stabType :: Stab -> OpticType
stabType stab = case stab of
  Stab _ opticType _ _ _ _ -> opticType
-- | Decide which optic a constructor gets and compute its description.
--
-- A constructor that carries a context is described as a review; any other
-- constructor is described as a prism, computed against the constructor
-- list with the first occurrence of the target removed.
computeOpticType :: Type -> [NCon] -> NCon -> Q Stab
computeOpticType t cons con = maybe prismCase reviewCase (con ^. nconCxt)
  where
  others        = delete con cons
  prismCase     = computePrismType t others con
  reviewCase cx = computeReviewType t cx (con ^. nconTypes)
-- | Describe the review for a constructor with the given context and field
-- types.
--
-- @t@ is the outer type and @b@ is the tuple of the field types. @s@ and
-- @a@ are filled with fresh type variables, generated in that order.
computeReviewType :: Type -> Cxt -> [Type] -> Q Stab
computeReviewType s' cx tys =
  mkStab <$> freshVarT "s" <*> freshVarT "a" <*> toTupleT (map return tys)
  where
  freshVarT base = VarT <$> newName base
  mkStab s a b   = Stab cx ReviewType s s' a b
-- | Describe the prism for a target constructor, given the outer type and a
-- list of constructors to compare against.
--
-- Type variables of the outer type that occur in none of the given
-- constructors receive fresh names. @s@ and @a@ are the outer type and the
-- field tuple under that renaming; @t@ and @b@ are the originals. The
-- resulting context is empty.
computePrismType :: Type -> [NCon] -> NCon -> Q Stab
computePrismType t cons con =
  do let fieldTys = con ^. nconTypes
         free     = Set.difference (setOf typeVars t) (setOf typeVars cons)
     -- Map every renamable variable to a fresh name with the same base.
     sub <- traverse (newName . nameBase) (fromSet id free)
     b   <- toTupleT (return <$> fieldTys)
     a   <- toTupleT (return <$> substTypeVars sub fieldTys)
     return (Stab [] PrismType (substTypeVars sub t) t a b)
-- | Compute the closed type of the iso generated for a single-constructor
-- type, given the outer type and the types of the constructor fields.
--
-- Every type variable of the outer type receives a fresh name. When the
-- outer type has no type variables the two-argument form built from
-- iso'TypeName is used; otherwise the four-argument form built from
-- isoTypeName is used.
computeIsoType :: Type -> [Type] -> TypeQ
computeIsoType t' fields =
  do sub <- sequenceA (fromSet (newName . nameBase) (setOf typeVars t'))
     -- t and b are the original outer type and field tuple; s and a are the
     -- same under the fresh-variable renaming.
     let t = return                    t'
         s = return (substTypeVars sub t')
         b = toTupleT (map return                    fields)
         a = toTupleT (map return (substTypeVars sub fields))
#ifndef HLINT
         ty | Map.null sub = appsT (conT iso'TypeName) [t,b]
            | otherwise    = appsT (conT isoTypeName) [s,t,a,b]
#endif
     -- Quantify over every type variable left free in the result.
     close =<< ty
-- | Build the expression for a constructor's optic: a prism expression or a
-- review expression, according to the optic kind recorded in the 'Stab'.
makeConOpticExp :: Stab -> [NCon] -> NCon -> ExpQ
makeConOpticExp stab cons con = build (stabType stab)
  where
  build PrismType  = makeConPrismExp stab cons con
  build ReviewType = makeConReviewExp con
-- | Generate the signature and definition of the iso for a type with a
-- single constructor. The definition is named by applying 'prismName' to
-- the constructor name.
makeConIso :: Type -> NCon -> DecsQ
makeConIso s con =
  do sig <- sigD defName (computeIsoType s (con ^. nconTypes))
     def <- valD (varP defName) (normalB (makeConIsoExp con)) []
     return [sig, def]
  where
  defName = prismName (con ^. nconName)
-- | Build a prism expression: the value named by 'prismValName' applied to
-- a reviewer and a remitter.
--
-- A simple (non-type-changing) prism matches only the target constructor;
-- a type-changing prism matches every constructor.
makeConPrismExp ::
  Stab   {- ^ optic description  -} ->
  [NCon] {- ^ constructors       -} ->
  NCon   {- ^ target constructor -} ->
  ExpQ
makeConPrismExp stab cons con =
  appsE [varE prismValName, makeReviewer target arity, remitter]
  where
  target = con ^. nconName
  arity  = length (con ^. nconTypes)
  remitter =
    if stabSimple stab
      then makeSimpleRemitter target arity
      else makeFullRemitter cons target
-- | Build an iso expression: the value named by 'isoValName' applied first
-- to the remitter (constructor to tuple) and then to the reviewer (tuple to
-- constructor).
makeConIsoExp :: NCon -> ExpQ
makeConIsoExp con =
  appsE
    [ varE isoValName
    , makeIsoRemitter target arity
    , makeReviewer    target arity
    ]
  where
  target = con ^. nconName
  arity  = length (con ^. nconTypes)
-- | Build a review expression: the value named by 'untoValName' applied to
-- the reviewer of the constructor.
makeConReviewExp :: NCon -> ExpQ
makeConReviewExp con = varE untoValName `appE` makeReviewer target arity
  where
  target = con ^. nconName
  arity  = length (con ^. nconTypes)
-- | Build the reviewing half of an optic: a lambda that takes a tuple
-- pattern with one fresh variable per field and applies the constructor to
-- those variables.
makeReviewer :: Name -> Int -> ExpQ
makeReviewer conName fields =
  replicateM fields (newName "x") >>= \args ->
    lam1E (toTupleP (varP <$> args))
          (appsE1 (conE conName) (varE <$> args))
-- | Build the matching half of a non-type-changing prism.
--
-- The generated lambda inspects its argument: the target constructor
-- yields its fields as a tuple wrapped with 'rightDataName'; anything else
-- yields the original argument wrapped with 'leftDataName'.
makeSimpleRemitter :: Name -> Int -> ExpQ
makeSimpleRemitter conName fields =
  do x  <- newName "x"
     ys <- newNames "y" fields
     let hit  = match (conP conName (varP <$> ys))
                      (normalB (conE rightDataName `appE` toTupleE (varE <$> ys)))
                      []
         miss = match wildP (normalB (conE leftDataName `appE` varE x)) []
     lam1E (varP x) (caseE (varE x) [hit, miss])
-- | Build the matching half of a type-changing prism.
--
-- The generated lambda has one case alternative per constructor. The
-- target constructor yields its fields as a tuple wrapped with
-- 'rightDataName'; every other constructor is rebuilt from its own fields
-- and wrapped with 'leftDataName', which is what lets the outer type change.
makeFullRemitter :: [NCon] -> Name -> ExpQ
makeFullRemitter cons target =
  do scrutinee <- newName "x"
     lam1E (varP scrutinee)
           (caseE (varE scrutinee) [ alternative c | c <- cons ])
  where
  alternative (NCon conName _ tys) =
    do ys <- newNames "y" (length tys)
       let vars = map varE ys
           rhs | conName == target = conE rightDataName `appE` toTupleE vars
               | otherwise         = conE leftDataName `appE` appsE1 (conE conName) vars
       match (conP conName (map varP ys)) (normalB rhs) []
-- | Build the matching half of an iso: a lambda that matches directly on
-- the constructor and returns its fields as a tuple.
makeIsoRemitter :: Name -> Int -> ExpQ
makeIsoRemitter conName fields =
  newNames "x" fields >>= \xs ->
    lam1E (conP conName (varP <$> xs)) (toTupleE (varE <$> xs))
-- | Build the class declaration for the classy flavour.
--
-- The class is parameterised over a fresh variable r followed by the type
-- variables of the outer type; when the outer type has type variables, a
-- functional dependency from r to those variables is added. The class
-- holds one top method relating r to the outer type, plus one method per
-- constructor, each with a default definition.
makeClassyPrismClass ::
  -- outer type
  Type    ->
  -- class name
  Name    ->
  -- top method name
  Name    ->
  -- constructors
  [NCon]  ->
  DecQ
makeClassyPrismClass t className methodName cons =
  do r <- newName "r"
#ifndef HLINT
     let methodType = appsT (conT prism'TypeName) [varT r,return t]
#endif
     methodss <- traverse (mkMethod (VarT r)) cons'
     classD (cxt[]) className (map PlainTV (r : vs)) (fds r)
       ( sigD methodName methodType
       : map return (concat methodss)
       )
  where
  -- Signature and default definition for one constructor method. The
  -- constructor passed in comes from cons', so its name is already the
  -- method name. From the computed description only the context, the optic
  -- kind and b are kept; s and t are both replaced by r, and a by b.
  mkMethod r con =
    do Stab cx o _ _ _ b <- computeOpticType t cons con
       let stab' = Stab cx o r r b b
           defName = view nconName con
           -- Default body: composeValName applied to the top method and to
           -- the method itself.
           body    = appsE [varE composeValName, varE methodName, varE defName]
       sequenceA
         [ sigD defName        (return (stabToType stab'))
         , valD (varP defName) (normalB body) []
         ]
  -- Constructors with their names rewritten to the method names.
  cons'         = map (over nconName prismName) cons
  -- Type variables of the outer type, in Set order.
  vs            = Set.toList (setOf typeVars t)
  -- No functional dependency when there is nothing for r to determine.
  fds r
    | null vs   = []
    | otherwise = [FunDep [r] vs]
-- | Build the instance of the classy class for the outer type itself.
--
-- The instance head applies the class to the outer type followed by its
-- type variables. The top method is defined as the value named by
-- 'idValName'; each constructor method is defined as the simplified
-- (non-type-changing) optic for that constructor.
makeClassyPrismInstance ::
  Type   {- ^ outer type      -} ->
  Name   {- ^ class name      -} ->
  Name   {- ^ top method name -} ->
  [NCon] {- ^ constructors    -} ->
  DecQ
makeClassyPrismInstance s className methodName cons =
  instanceD (cxt []) (return instHead) (topMethod : map conMethod cons)
  where
  vars      = Set.toList (setOf typeVars s)
  instHead  = className `conAppsT` (s : map VarT vars)
  topMethod = valD (varP methodName) (normalB (varE idValName)) []
  conMethod con =
    do stab <- simplifyStab <$> computeOpticType s cons con
       valD (varP (prismName (con ^. nconName)))
            (normalB (makeConOpticExp stab cons con))
            []
-- | A constructor in normalized form, as produced by 'normalizeCon'.
data NCon = NCon
  -- Name of the constructor.
  { _nconName :: Name
  -- Context attached to the constructor; Nothing when it has none.
  -- A constructor with a context gets a review rather than a prism.
  , _nconCxt  :: Maybe Cxt
  -- Types of the constructor fields, in order.
  , _nconTypes :: [Type]
  }
  deriving (Eq)
-- | Type variables of a normalized constructor: those of its context
-- followed by those of its field types. The constructor name is untouched.
instance HasTypeVars NCon where
  typeVarsEx bound f (NCon conName cx tys) =
    liftA2 (NCon conName) (typeVarsEx bound f cx) (typeVarsEx bound f tys)
-- | Lens onto the constructor name of an 'NCon'.
nconName :: Lens' NCon Name
nconName f con = rebuild <$> f (_nconName con)
  where
  rebuild n = con { _nconName = n }
-- | Lens onto the optional context of an 'NCon'.
nconCxt :: Lens' NCon (Maybe Cxt)
nconCxt f con = rebuild <$> f (_nconCxt con)
  where
  rebuild cx = con { _nconCxt = cx }
-- | Lens onto the field types of an 'NCon'.
nconTypes :: Lens' NCon [Type]
nconTypes f con = rebuild <$> f (_nconTypes con)
  where
  rebuild tys = con { _nconTypes = tys }
-- | Normalize a 'Con' to constructor names, optional contexts and field
-- types.
--
-- The result is a list because a single GADT-style declaration can name
-- several constructors at once. A quantified constructor with neither
-- binders nor context is treated as its inner constructor; any other
-- quantified constructor contributes its context, placed in front of
-- whatever context the inner constructor already has.
normalizeCon :: Con -> [NCon]
normalizeCon (RecC    conName fields) =
  [NCon conName Nothing [ ty | (_, _, ty) <- fields ]]
normalizeCon (NormalC conName fields) =
  [NCon conName Nothing (map snd fields)]
normalizeCon (InfixC (_,lhs) conName (_,rhs)) =
  [NCon conName Nothing [lhs, rhs]]
normalizeCon (ForallC [] [] con) = normalizeCon con
normalizeCon (ForallC _ cx1 con) =
  map (over nconCxt (Just cx1 <>)) (normalizeCon con)
#if MIN_VERSION_template_haskell(2,11,0)
normalizeCon (GadtC conNames fields _)    =
  map (\conName -> NCon conName Nothing (map snd fields)) conNames
normalizeCon (RecGadtC conNames fields _) =
  map (\conName -> NCon conName Nothing [ ty | (_, _, ty) <- fields ]) conNames
#endif
-- | Compute the name of a generated optic from a constructor or type name.
--
-- A base name that starts with an upper-case character is prefixed with an
-- underscore; any other base name (an operator) is prefixed with a period.
-- An empty base name is reported with 'error'.
prismName :: Name -> Name
prismName n = go (nameBase n)
  where
  go [] = error "prismName: empty name base?"
  go base@(c:_)
    | isUpper c = mkName ('_' : base)
    | otherwise = mkName ('.' : base)
-- | Quantify over every type variable that occurs in a type, with an empty
-- context. The binders are listed in the order of the underlying set.
close :: Type -> TypeQ
close t = forallT binders (cxt []) (return t)
  where
  binders = [ PlainTV v | v <- Set.toList (setOf typeVars t) ]