{-# LANGUAGE TemplateHaskell #-}
module Pinchot.SyntaxTree.Optics where

import Data.Coerce (coerce)
import Data.Foldable (toList)
import Data.Maybe (catMaybes)
import Data.Sequence (Seq)
import qualified Control.Lens as Lens
import qualified Language.Haskell.TH as T
import qualified Language.Haskell.TH.Syntax as Syntax

import Pinchot.Rules
import Pinchot.Types
import Pinchot.Intervals

-- | Creates optics declarations for a 'Rule', if optics can
-- be made for the 'Rule':
--
-- * 'Pinchot.terminal' gets a single 'Lens.Prism'
--
-- * 'Pinchot.nonTerminal' gets a 'Lens.Prism' for each constructor
--
-- * 'Pinchot.record' gets a single 'Lens.Lens'
--
-- * 'Pinchot.wrap', 'Pinchot.opt', 'Pinchot.star',
-- and 'Pinchot.plus' do not get optics.
--
-- Each rule in the sequence of 'Rule', as well as all ancestors of
-- those 'Rule's, will be handled.
--
-- Example: "Pinchot.Examples.RulesToOptics".
rulesToOptics
  :: Syntax.Lift t
  => Qualifier
  -- ^ Qualifier for module containing the data types that will get
  -- optics
  -> T.Name
  -- ^ Type name for the terminal
  -> Seq (Rule t)
  -> T.Q [T.Dec]
rulesToOptics qual termName
  = fmap concat
  . traverse (ruleToOptics qual termName)
  . families

-- | Creates optics declarations for a single 'Rule', if optics can
-- be made for the 'Rule':
--
-- * 'Terminal' gets a single 'Lens.Prism'
--
-- * 'NonTerminal' gets a 'Lens.Prism' for each constructor
--
-- * 'Terminals' gets a single 'Lens.Prism'
--
-- * 'Record' gets a single 'Lens.Lens'
--
-- * 'Wrap', 'Opt', 'Star', and 'Plus' do not get optics.
ruleToOptics
  :: Syntax.Lift t
  => Qualifier
  -- ^ Qualifier for module containing the data type that will get
  -- optics
  -> T.Name
  -- ^ Type name for the terminal
  -> Rule t
  -> T.Q [T.Dec]
ruleToOptics qual termName (Rule nm _ ty) = case ty of
  Terminal ivls -> terminalToOptics qual termName nm ivls
  NonTerminal b1 bs -> sequence $ nonTerminalToOptics qual nm b1 bs
  Record sq -> sequence $ recordsToOptics qual nm sq
  _ -> return []
  

-- | Creates a prism for a terminal type.  Although a newtype wraps
-- each terminal, do not make a Wrapped or an Iso, because the
-- relationship between the outer type and the type that it wraps
-- typically is not isometric.  Thus, use a Prism instead, which
-- captures this relationship properly.
terminalToOptics
  :: Syntax.Lift t
  => Qualifier
  -- ^ Qualifier for module containing the data type that will get
  -- optics
  -> T.Name
  -- ^ Terminal type name
  -> String
  -- ^ Rule name
  -> Intervals t
  -> T.Q [T.Dec]
terminalToOptics qual termName nm ivls = do
  e1 <- T.sigD (T.mkName ('_':nm))
    $ T.forallT [ T.PlainTV (T.mkName "a")] (return [])
    [t| Lens.Prism' ( $(T.conT termName), $(anyType) )
                    ($(T.conT (quald qual nm)) $(T.conT termName) $(anyType))
    |]
  
  e2 <- T.valD prismName (T.normalB expn) []
  return [e1, e2]
  where
    anyType = T.varT (T.mkName "a")
    charType = T.varT (T.mkName "t")
    prismName = T.varP (T.mkName ('_' : nm))
    fetchPat = T.conP (quald qual nm) [T.varP (T.mkName "_x")]
    fetchName = T.varE (T.mkName "_x")
    ctor = T.conE (quald qual nm)
    expn = [| let fetch $fetchPat = $fetchName
                  store (term, a)
                    | inIntervals ivls term = Just ($(ctor) (term, a))
                    | otherwise = Nothing
              in Lens.prism' fetch store
           |]

-- | Creates prisms for each 'Branch'.
nonTerminalToOptics
  :: Qualifier
  -- ^ Qualifier for module containing the data type that will get
  -- optics
  -> String
  -- ^ Rule name
  -> Branch t
  -> Seq (Branch t)
  -> [T.Q T.Dec]
nonTerminalToOptics qual nm b1 bsSeq
  = concat $ makePrism b1 : fmap makePrism bs
  where
    bs = toList bsSeq
    makePrism (Branch inner rulesSeq) = [ signature, binding ]
      where
        charType = T.varT (T.mkName "t")
        anyType = T.varT (T.mkName "a")
        rules = toList rulesSeq
        prismName = T.mkName ('_' : inner)
        signature = T.sigD prismName
          (forallA [t| Lens.Prism'
              ($(T.conT (quald qual nm)) $(charType) $(anyType))
              $(fieldsType) |])
          where
            fieldsType = case rules of
              [] -> T.tupleT 0
              Rule r1 _ _ : [] -> [t| $(T.conT (quald qual r1))
                $(charType) $(anyType) |]
              rs -> foldl addType (T.tupleT (length rs)) rs
                where
                  addType soFar (Rule r _ _) = soFar `T.appT`
                    [t| $(T.conT (quald qual r)) $(charType) $(anyType) |]
        binding = T.valD (T.varP prismName) body []
          where
            body = T.normalB
              $ (T.varE 'Lens.prism)
              `T.appE` setter
              `T.appE` getter
              where
                setter = T.lamE [pat] expn
                  where
                    (pat, expn) = case rules of
                      [] -> (T.tupP [], T.conE (quald qual inner))
                      _ : [] -> (T.varP local,
                        T.conE (quald qual inner)
                        `T.appE` T.varE local)
                        where
                          local = T.mkName "_x"
                      ls -> (T.tupP pats, set)
                        where
                          pats = fmap (\i -> T.varP (T.mkName ("_x" ++ show i)))
                            . take (length ls) $ [(0 :: Int) ..]
                          set = foldl addVar start . take (length ls)
                            $ [(0 :: Int) ..]
                            where
                              addVar acc i = acc `T.appE`
                                (T.varE (T.mkName ("_x" ++ show i)))
                              start = T.conE (quald qual inner)

                getter = T.lamE [pat] expn
                  where
                    local = T.mkName "_x"
                    pat = T.varP local
                    expn = T.caseE (T.varE (T.mkName "_x")) $
                      T.match patCtor bodyCtor []
                      : rest
                      where
                        patCtor = T.conP (quald qual inner)
                          . fmap (\i -> T.varP (T.mkName $ "_y" ++ show i))
                          . take (length rules)
                          $ [(0 :: Int) ..]
                        bodyCtor = T.normalB . (T.conE 'Right `T.appE`)
                          $ case rules of
                          [] -> T.tupE []
                          _:[] -> T.varE (T.mkName "_y0")
                          _ -> T.tupE
                            . fmap (\i -> T.varE (T.mkName $ "_y" ++ show i))
                            . take (length rules)
                            $ [(0 :: Int) ..]
                        rest = case bs of
                          [] -> []
                          _ -> [T.match patBlank bodyBlank []]
                          where
                            patBlank = T.varP (T.mkName "_z")
                            bodyBlank = T.normalB
                              $ T.conE ('Left)
                              `T.appE` T.varE (T.mkName "_z")

recordsToOptics
  :: Qualifier
  -- ^ Qualifier for module containing the data type that will get
  -- optics
  -> String
  -- ^ Rule name
  -> Seq (Rule t)
  -> [T.Q T.Dec]
recordsToOptics qual nm
  = concat . zipWith makeLens [(0 :: Int) ..] . toList
  where
    makeLens index (Rule inner _ _) = [ signature, function ]
      where
        charType = T.varT (T.mkName "t")
        anyType = T.varT (T.mkName "a")
        fieldNm = recordFieldName index nm inner
        lensName = T.mkName fieldNm
        signature = T.sigD lensName (forallA
          [t| Lens.Lens' ($(T.conT (quald qual nm)) $(charType) $(anyType))
                         ($(T.conT (quald qual inner)) $(charType) $(anyType))
          |])

        function = T.funD lensName [T.clause [] (T.normalB body) []]
          where
            namedRec = T.mkName "_namedRec"
            namedNewVal = T.mkName "_namedNewVal"
            body = (T.varE 'Lens.lens) `T.appE` getter `T.appE` setter
              where
                getter = T.lamE [pat] expn
                  where
                    pat = T.varP namedRec
                    expn = (T.varE (quald qual ('_' : fieldNm)))
                      `T.appE` (T.varE namedRec)

                setter = T.lamE [patRec, patNewVal] expn
                  where
                    patRec = T.varP namedRec
                    patNewVal = T.varP namedNewVal
                    expn = T.recUpdE (T.varE namedRec)
                      [ return ( quald qual ('_' : fieldNm)
                               , T.VarE namedNewVal) ]

forallA :: T.TypeQ -> T.TypeQ
forallA = T.forallT [ T.PlainTV (T.mkName "t")
                    , T.PlainTV (T.mkName "a")] (return [])