module Control.Lens.Internal.PrismTH
  ( makePrisms
  , makeClassyPrisms
  , makeDecPrisms
  ) where
import Control.Applicative
import Control.Lens.Getter
import Control.Lens.Internal.TH
import Control.Lens.Lens
import Control.Lens.Setter
import Control.Lens.Tuple
import Control.Monad
import Data.Char (isUpper)
import Data.List
import Data.Monoid
import Data.Set.Lens
import Data.Traversable (for,sequenceA,traverse)
import Language.Haskell.TH
import Language.Haskell.TH.Lens
import qualified Data.Map as Map
import qualified Data.Set as Set
-- | Derive optics for every constructor of the named data type or newtype
-- as top-level definitions: an iso when there is exactly one constructor
-- without a context, otherwise one prism (or review, for constructors
-- carrying a context) per constructor.
makePrisms :: Name -> DecsQ
makePrisms typeName = makePrisms' True typeName
-- | Derive the classy variant for the named type. This is a class named
-- @As@ followed by the type name, with a method for the type itself and
-- one per constructor, plus the instance for the type itself.
makeClassyPrisms :: Name -> DecsQ
makeClassyPrisms typeName = makePrisms' False typeName
-- | Reify a type name and pass its declaration on to 'makeDecPrisms'.
--
-- The flag selects top-level definitions ('True') or the classy variant
-- ('False'). A name that does not reify to a type constructor makes the
-- splice fail.
makePrisms' :: Bool -> Name -> DecsQ
makePrisms' normal typeName = reify typeName >>= fromInfo
  where
  fromInfo (TyConI dec) = makeDecPrisms normal dec
  fromInfo _            = fail "makePrisms: expected type constructor name"
-- | Generate prisms from a data, newtype, data instance or newtype
-- instance declaration.
--
-- The outer type is the type constructor applied to its arguments: the
-- binder variables for plain declarations, or the instance types for
-- family instances. With the flag unset, the type name is passed on as
-- the classy target.
makeDecPrisms :: Bool -> Dec -> DecsQ
makeDecPrisms normal dec =
  case dec of
    DataD        _ tyName binders cons _ -> go tyName (fromBinders binders) cons
    NewtypeD     _ tyName binders con  _ -> go tyName (fromBinders binders) [con]
    DataInstD    _ tyName argTys  cons _ -> go tyName argTys cons
    NewtypeInstD _ tyName argTys  con  _ -> go tyName argTys [con]
    _ -> fail "makePrisms: expected type constructor dec"
  where
  fromBinders = map (VarT . bndrName)
  go tyName argTys cons =
    makeConsPrisms (conAppsT tyName argTys) (map normalizeCon cons) classTarget
    where
    classTarget = if normal then Nothing else Just tyName
-- | Generate the declarations for the normalized constructors of type @t@.
--
-- * Not classy, exactly one constructor without a context: an iso.
--
-- * Not classy otherwise: a signature and a definition per constructor,
--   named by 'prismName'.
--
-- * Classy (@Just typeName@): the @As@ class and its instance for @t@.
makeConsPrisms :: Type -> [NCon] -> Maybe Name -> DecsQ
makeConsPrisms t [con@(NCon _ Nothing _)] Nothing = makeConIso t con
makeConsPrisms t cons Nothing = concat <$> mapM declare cons
  where
  declare con =
    do stab <- computeOpticType t cons con
       let defName = prismName (view nconName con)
       sigDec <- sigD defName (close (stabToType stab))
       valDec <- valD (varP defName) (normalB (makeConOpticExp stab cons con)) []
       return [sigDec, valDec]
makeConsPrisms t cons (Just typeName) =
  do classDec    <- makeClassyPrismClass t className methodName cons
     instanceDec <- makeClassyPrismInstance t className methodName cons
     return [classDec, instanceDec]
  where
  className  = mkName ("As" ++ nameBase typeName)
  methodName = prismName typeName
-- | Kind of optic generated for a constructor. 'computeOpticType' picks
-- 'ReviewType' for constructors that carry a context and 'PrismType' for
-- all others.
data OpticType = PrismType | ReviewType
-- | A generated optic type. It holds the context to quantify over, the
-- optic kind, and the four type parameters in @s t a b@ order
-- (see 'stabToType').
data Stab  = Stab Cxt OpticType Type Type Type Type
-- | Collapse a 'Stab' to its non-type-changing form by reusing @t@ as @s@
-- and @b@ as @a@. @t@ and @b@ are kept because in the review case @s@ and
-- @a@ are only fresh placeholder variables.
simplifyStab :: Stab -> Stab
simplifyStab (Stab context optic _ outer _ inner) =
  Stab context optic outer outer inner inner
  
  
-- | A 'Stab' is simple when it cannot change types, i.e. @s == t@ and
-- @a == b@.
stabSimple :: Stab -> Bool
stabSimple (Stab _ _ s t a b) = (s, a) == (t, b)
-- | Render a 'Stab' as a TH type.
--
-- Simple stabs use the primed type name applied to @t b@. Others use the
-- unprimed name applied to @s t a b@. Only the variables of the context
-- are quantified here; callers close over the rest.
stabToType :: Stab -> Type
stabToType stab@(Stab cx optic s t a b) =
  ForallT binders cx (conAppsT opticName opticArgs)
  where
  binders = map PlainTV (Set.toList (setOf typeVars cx))
  simple  = stabSimple stab
  opticArgs
    | simple    = [t, b]
    | otherwise = [s, t, a, b]
  opticName =
    case optic of
      PrismType  | simple    -> prism'TypeName
                 | otherwise -> prismTypeName
      ReviewType | simple    -> review'TypeName
                 | otherwise -> reviewTypeName
-- | The optic kind recorded in a 'Stab'.
stabType :: Stab -> OpticType
stabType stab =
  case stab of
    Stab _ optic _ _ _ _ -> optic
-- | Compute the optic type for constructor @con@ of outer type @t@.
--
-- A constructor with a context only gets a review. Any other constructor
-- gets a prism computed against the remaining constructors. 'delete'
-- removes @con@ from @cons@ only if it occurs there unchanged.
computeOpticType :: Type -> [NCon] -> NCon -> Q Stab
computeOpticType t cons con =
  maybe (computePrismType t (delete con cons) con)
        (\context -> computeReviewType t context (view nconTypes con))
        (view nconCxt con)
-- | Build the 'Stab' for a constructor that carries a context.
--
-- @t@ is the outer type and @b@ the tuple of the field types. @s@ and @a@
-- are fresh type variables.
computeReviewType :: Type -> Cxt -> [Type] -> Q Stab
computeReviewType outer context fieldTys =
  do s <- VarT <$> newName "s"
     a <- VarT <$> newName "a"
     b <- toTupleT (map return fieldTys)
     return (Stab context ReviewType s outer a b)
-- | Build the 'Stab' for a prism focusing on constructor @con@ of @t@.
--
-- Each type variable of @t@ that does not occur in @cons@ (the other
-- constructors) is paired with a fresh name. Applying that substitution
-- gives @s@ and @a@, while @t@ and @b@ keep the original variables. @b@ is
-- the tuple of the constructor's field types.
computePrismType :: Type -> [NCon] -> NCon -> Q Stab
computePrismType t cons con =
  do let fieldTys = view nconTypes con
         freeVars = setOf typeVars t Set.\\ setOf typeVars cons
     sub <- sequenceA (fromSet (newName . nameBase) freeVars)
     b   <- toTupleT (map return fieldTys)
     a   <- toTupleT (map return (substTypeVars sub fieldTys))
     return (Stab [] PrismType (substTypeVars sub t) t a b)
-- | Build the closed type of the iso generated for a single-constructor
-- outer type @t'@ whose constructor has the given field types.
--
-- @sub@ pairs every type variable of @t'@ with a fresh name; presumably
-- that is what 'fromSet' with 'newName' yields (confirm). @t@ and @b@ use
-- the original variables, and @s@ and @a@ use the substituted ones. When
-- @sub@ is empty, meaning @t'@ mentions no type variables, the primed
-- iso type name is applied to @t b@. Otherwise the unprimed one is applied
-- to @s t a b@.
computeIsoType :: Type -> [Type] -> TypeQ
computeIsoType t' fields =
  do sub <- sequenceA (fromSet (newName . nameBase) (setOf typeVars t'))
     let t = return                    t'
         s = return (substTypeVars sub t')
         b = toTupleT (map return                    fields)
         a = toTupleT (map return (substTypeVars sub fields))
#ifndef HLINT
         ty | Map.null sub = appsT (conT iso'TypeName) [t,b]
            | otherwise    = appsT (conT isoTypeName) [s,t,a,b]
#endif
     -- NOTE(review): ty is only bound when the HLint macro is undefined;
     -- the guard presumably hides this binding from HLint. Confirm.
     close =<< ty
-- | Build the expression for a constructor optic, dispatching on the
-- optic kind recorded in its 'Stab'.
makeConOpticExp :: Stab -> [NCon] -> NCon -> ExpQ
makeConOpticExp stab cons con = build (stabType stab)
  where
  build PrismType  = makeConPrismExp stab cons con
  build ReviewType = makeConReviewExp con
-- | Declare the iso for the single constructor of outer type @s@. This
-- produces a signature from 'computeIsoType' and a definition from
-- 'makeConIsoExp', both named by 'prismName'.
makeConIso :: Type -> NCon -> DecsQ
makeConIso s con =
  do sigDec <- sigD isoName (computeIsoType s (view nconTypes con))
     valDec <- valD (varP isoName) (normalB (makeConIsoExp con)) []
     return [sigDec, valDec]
  where
  isoName = prismName (view nconName con)
-- | Build the prism expression (via @prismValName@) for constructor @con@,
-- applied to a reviewer and a remitter.
--
-- A simple 'Stab' only has to match the target constructor. Otherwise
-- every constructor in @cons@ is matched, so the others can be rebuilt on
-- the 'Left' side at the new type.
makeConPrismExp :: Stab -> [NCon] -> NCon -> ExpQ
makeConPrismExp stab cons con =
  appsE [varE prismValName, makeReviewer conName arity, remitter]
  where
  conName = view nconName con
  arity   = length (view nconTypes con)
  remitter
    | stabSimple stab = makeSimpleRemitter conName arity
    | otherwise       = makeFullRemitter cons conName
-- | Build the iso expression (via @isoValName@) for a constructor. The
-- first argument is the remitter (constructor to tuple of fields) and the
-- second is the reviewer (tuple of fields to constructor).
makeConIsoExp :: NCon -> ExpQ
makeConIsoExp con =
  appsE [ varE isoValName
        , makeIsoRemitter conName arity
        , makeReviewer    conName arity
        ]
  where
  conName = view nconName con
  arity   = length (view nconTypes con)
-- | Build the review-only expression for a constructor: the function
-- named by @untoValName@ applied to the constructor's reviewer.
makeConReviewExp :: NCon -> ExpQ
makeConReviewExp con = varE untoValName `appE` reviewer
  where
  reviewer = makeReviewer (view nconName con) (length (view nconTypes con))
-- | Build the constructing half of an optic. This is a lambda whose
-- argument is a tuple pattern over @fields@ fresh variables (see
-- 'toTupleP') and whose body applies the constructor to them.
makeReviewer :: Name -> Int -> ExpQ
makeReviewer conName fields =
  replicateM fields (newName "x") >>= \names ->
    lam1E (toTupleP (map varP names))
          (appsE1 (conE conName) (map varE names))
-- | Build the matching half of a simple prism:
--
-- > \x -> case x of { Con y1 .. yn -> Right <tuple of ys>; _ -> Left x }
--
-- The tuple is built with 'toTupleE'.
makeSimpleRemitter :: Name -> Int -> ExpQ
makeSimpleRemitter conName fields =
  do scrutinee <- newName "x"
     ys        <- replicateM fields (newName "y")
     let onHit  = match (conP conName (map varP ys))
                        (normalB (conE rightDataName `appE` toTupleE (map varE ys)))
                        []
         onMiss = match wildP (normalB (conE leftDataName `appE` varE scrutinee)) []
     lam1E (varP scrutinee) (caseE (varE scrutinee) [onHit, onMiss])
-- | Build the matching half of a type-changing prism. Every constructor in
-- @cons@ gets a case alternative. @target@ yields 'Right' with its fields
-- tupled, and any other constructor is rebuilt from its own fields inside
-- 'Left'.
makeFullRemitter :: [NCon] -> Name -> ExpQ
makeFullRemitter cons target =
  do scrutinee <- newName "x"
     lam1E (varP scrutinee) (caseE (varE scrutinee) (map branch cons))
  where
  branch (NCon conName _ fieldTys) =
    do ys <- replicateM (length fieldTys) (newName "y")
       let vars = map varE ys
           rhs | conName == target = conE rightDataName `appE` toTupleE vars
               | otherwise         = conE leftDataName `appE` appsE1 (conE conName) vars
       match (conP conName (map varP ys)) (normalB rhs) []
-- | Build the deconstructing half of an iso: a lambda that matches the
-- constructor's @fields@ arguments and returns them as a tuple
-- (see 'toTupleE').
makeIsoRemitter :: Name -> Int -> ExpQ
makeIsoRemitter conName fields =
  do names <- replicateM fields (newName "x")
     let pat = conP conName (map varP names)
     lam1E pat (toTupleE (map varE names))
-- | Build the class declaration for the classy variant.
--
-- The class takes a fresh variable @r@ followed by every type variable of
-- the outer type @t@. When there are such variables, it has a functional
-- dependency from @r@ to them. The top method @methodName@ is typed with
-- @prism'TypeName@ applied to @r@ and @t@. Each constructor gets a method
-- named by 'prismName', whose type is the constructor's optic type with
-- @r@ as the outer type. Its default definition composes (via
-- @composeValName@) @methodName@ with that same method.
makeClassyPrismClass ::
  Type    {- ^ outer type      -} ->
  Name    {- ^ class name      -} ->
  Name    {- ^ top method name -} ->
  [NCon]  {- ^ constructors    -} ->
  DecQ
makeClassyPrismClass t className methodName cons =
  do r <- newName "r"
#ifndef HLINT
     let methodType = appsT (conT prism'TypeName) [varT r,return t]
#endif
     methodss <- traverse (mkMethod (VarT r)) cons
     classD (cxt[]) className (map PlainTV (r : vs)) (fds r)
       ( sigD methodName methodType
       : map return (concat methodss)
       )
  where
  -- Signature and default definition of the method for one constructor.
  -- The constructor is handed to computeOpticType under its original
  -- name. That function removes it from cons with 'delete', which would
  -- silently match nothing if the name had already been rewritten.
  mkMethod r con =
    do Stab cx o _ _ _ b <- computeOpticType t cons con
       let stab'   = Stab cx o r r b b
           defName = prismName (view nconName con)
           body    = appsE [varE composeValName, varE methodName, varE defName]
       sequence
         [ sigD defName        (return (stabToType stab'))
         , valD (varP defName) (normalB body) []
         ]
  vs            = Set.toList (setOf typeVars t)
  fds r
    | null vs   = []
    | otherwise = [FunDep [r] vs]
-- | Build the instance of the classy class for the outer type @s@ itself.
--
-- The instance head applies the class to @s@ followed by the type
-- variables of @s@. The top method is bound to @idValName@. Each
-- constructor method is the constructor's optic in its simplified form
-- (see 'simplifyStab').
makeClassyPrismInstance ::
  Type   {- ^ outer type      -} ->
  Name   {- ^ class name      -} ->
  Name   {- ^ top method name -} ->
  [NCon] {- ^ constructors    -} ->
  DecQ
makeClassyPrismInstance s className methodName cons =
  instanceD (cxt []) (return instanceHead) (topMethod : map conMethod cons)
  where
  instanceHead = className `conAppsT` (s : map VarT (Set.toList (setOf typeVars s)))
  topMethod    = valD (varP methodName) (normalB (varE idValName)) []
  conMethod con =
    do stab <- computeOpticType s cons con
       valD (varP (prismName (view nconName con)))
            (normalB (makeConOpticExp (simplifyStab stab) cons con))
            []
-- | A constructor reduced by 'normalizeCon' to the parts the generators
-- need. The derived equality compares all three fields; it is what
-- 'delete' relies on in 'computeOpticType'.
data NCon = NCon
  { _nconName :: Name        -- ^ constructor name
  , _nconCxt  :: Maybe Cxt   -- ^ 'Nothing' unless the constructor was wrapped
                             --   in a 'ForallC' with binders or a context
  , _nconTypes :: [Type]     -- ^ field types, in declaration order
  }
  deriving (Eq)
-- | The type variables of an 'NCon' are those of its context and its field
-- types. The constructor name is passed through untouched.
instance HasTypeVars NCon where
  typeVarsEx bound f (NCon name context fieldTys) =
    fmap (NCon name) (typeVarsEx bound f context) <*> typeVarsEx bound f fieldTys
-- | Lens onto the constructor name of an 'NCon'.
nconName :: Lens' NCon Name
nconName f ncon = setName <$> f (_nconName ncon)
  where
  setName name = ncon { _nconName = name }
-- | Lens onto the optional context of an 'NCon'.
nconCxt :: Lens' NCon (Maybe Cxt)
nconCxt f ncon = setCxt <$> f (_nconCxt ncon)
  where
  setCxt context = ncon { _nconCxt = context }
-- | Lens onto the field types of an 'NCon'.
nconTypes :: Lens' NCon [Type]
nconTypes f ncon = setTypes <$> f (_nconTypes ncon)
  where
  setTypes fieldTys = ncon { _nconTypes = fieldTys }
-- | Reduce a TH 'Con' to its name, optional context and field types.
--
-- Record and infix constructors keep only their field types. A 'ForallC'
-- with neither binders nor context is transparent. Any other 'ForallC'
-- contributes its context in front of whatever the inner constructor
-- carries, so even an empty context becomes @Just []@.
normalizeCon :: Con -> NCon
normalizeCon con =
  case con of
    NormalC conName fields           -> NCon conName Nothing (map snd fields)
    RecC    conName fields           -> NCon conName Nothing (map (\(_, _, ty) -> ty) fields)
    InfixC (_, lhs) conName (_, rhs) -> NCon conName Nothing [lhs, rhs]
    ForallC [] [] inner              -> normalizeCon inner
    ForallC _ context inner          ->
      let NCon conName innerCxt fieldTys = normalizeCon inner
      in  NCon conName (Just context <> innerCxt) fieldTys
-- | Name of the optic generated for a constructor or type name. An
-- underscore is prepended to names starting with an upper-case letter,
-- and a period to anything else (operator names such as @:+@).
prismName :: Name -> Name
prismName n =
  case nameBase n of
    base@(c:_)
      | isUpper c -> mkName ('_' : base)
      | otherwise -> mkName ('.' : base)
    [] -> error "prismName: empty name base?"
-- | Quantify a type over every type variable reported by 'typeVars', with
-- an empty context.
close :: Type -> TypeQ
close t = forallT binders (cxt []) (return t)
  where
  binders = [ PlainTV v | v <- Set.toList (setOf typeVars t) ]