module Data.DeriveTraversal(
TraveralType(..), defaultTraversalType,
traversalDerivation1,
traversalInstance, traversalInstance1,
deriveTraversal
) where
import Language.Haskell.TH.All
import Data.List
import qualified Data.Set as S
import Control.Monad.Writer
import Control.Applicative
-- | 'Applicative' for 'Writer', expressed through its 'Monad' operations.
--
-- NOTE(review): this is an orphan instance; presumably the mtl version this
-- module targets does not provide one itself -- confirm before building
-- against a different mtl.
instance Monoid w => Applicative (Writer w) where
    pure x = return x
    wf <*> wx = wf `ap` wx
-- | A traversal is represented by the expression ('Exp') that performs it.
type Trav = Exp
-- | Description of one kind of traversal to derive. The fields are the
-- building blocks that 'deriveTraversal' and its helpers plug together.
data TraveralType = TraveralType
    { traversalArg :: Int
      -- ^ Position of the type argument being traversed; compared against
      -- the positions returned by an 'ArgPositions' function.
    , traversalCo :: Bool
      -- ^ Variance flag. It is negated when descending into the argument
      -- side of a function type, and a direct occurrence of the traversed
      -- variable is rejected while it is 'True'.
    , traversalName :: String
      -- ^ Base name of the traversal function (see 'traversalNameN' for the
      -- numeric suffix).
    , traversalId :: Trav
      -- ^ Traversal used for types that do not involve the traversed
      -- variable.
    , traversalDirect :: Trav
      -- ^ Traversal used for a direct occurrence of the traversed variable.
    , traversalFunc :: String -> Trav -> Trav
      -- ^ Given a traversal function name and a sub-traversal, build the
      -- traversal of a type-constructor argument (used by 'traverseArg').
    , traversalPlus :: Trav -> Trav -> Trav
      -- ^ Combine two non-identity traversals into one.
    , traverseArrow :: Trav -> Trav -> Trav
      -- ^ Combine the traversals of the argument and result of a function
      -- type.
    , traverseTuple :: [Exp] -> Exp
      -- ^ Build the result for a tuple from its already-applied component
      -- traversals.
    , traverseCtor :: String -> [Exp] -> Exp
      -- ^ Build the result for a constructor from its name and the
      -- already-applied field traversals.
    , traverseFunc :: Pat -> Exp -> Clause
      -- ^ Build one clause of the traversal function from a constructor
      -- pattern and a body.
    }
-- | Baseline 'TraveralType' intended to be customised with record update
-- syntax.
--
-- 'traversalName' and 'traverseFunc' have no sensible default and must be
-- overridden; forcing them unmodified raises an error naming the field.
-- 'traverseArrow' rejects function types unless it is overridden.
defaultTraversalType :: TraveralType
defaultTraversalType = TraveralType
    { traversalArg    = 1
    , traversalCo     = False
    , traversalName   = error "defaultTraversalType: traversalName must be overridden"
    , traversalId     = id'
    , traversalDirect = l0 "f"
    , traversalFunc   = l1
    , traversalPlus   = (.:)
      -- 'error' rather than 'fail': the function monad has no 'MonadFail'
      -- instance, and where 'fail' defaulted to 'error' the result is the same.
    , traverseArrow   = error "Cannot derive traversal over function types"
    , traverseTuple   = TupE
    , traverseCtor    = lK
    , traverseFunc    = error "defaultTraversalType: traverseFunc must be overridden"
    }
-- | A class constraint the generated instance depends on. One is recorded
-- whenever a type variable is used as a type constructor whose argument has
-- a non-identity traversal (see 'deriveTraversalApp').
data RequiredInstance = RequiredInstance
    { requiredDataArg :: Name
      -- ^ The type variable that appears in type-constructor position.
    , requiredPosition :: Int
      -- ^ Which of its arguments is traversed, numbered from 1 starting at
      -- the last argument.
    }
    deriving (Eq, Ord)
-- | A computation that also accumulates the set of 'RequiredInstance's
-- needed by the code it generates.
type WithInstances a = Writer (S.Set RequiredInstance) a
-- | Build a 'Derivation' for a traversal class that has a single method.
--
-- The derivation is labelled with the class name: @nm@ itself, or @nm@
-- followed by the argument position when that position is greater than 1.
traversalDerivation1 :: TraveralType -> String -> Derivation
traversalDerivation1 tt nm = derivation mkInstance instanceClass
  where
    mkInstance = traversalInstance1 tt nm
    position   = traversalArg tt
    instanceClass
        | position > 1 = nm ++ show position
        | otherwise    = nm
-- | Instance declaration for a traversal class with exactly one method,
-- whose body is produced by 'deriveTraversal'.
traversalInstance1 :: TraveralType -> String -> DataDef -> [Dec]
traversalInstance1 tt nm dat = traversalInstance tt nm dat methods
  where
    methods = [deriveTraversal tt dat]
-- | Assemble the instance declaration of a traversal class for a data type.
--
-- * @tt@       - the traversal being derived; its 'traversalArg' selects
--                which type argument the class abstracts over.
-- * @nameBase@ - class name without the numeric suffix.
-- * @dat@      - the data type to derive for.
-- * @bodyM@    - the method declarations, each possibly requesting
--                'RequiredInstance's.
--
-- Returns no declarations when the data type has no type arguments;
-- otherwise a single instance whose context holds one constraint per
-- requested 'RequiredInstance'.
traversalInstance :: TraveralType -> String -> DataDef -> [WithInstances Dec] -> [Dec]
traversalInstance tt nameBase dat bodyM
    | dataArity dat == 0 = []
    | otherwise          = [InstanceD ctx instHead body]
  where
    (body, required) = runWriter (sequence bodyM)
    -- A constraint at position p needs p - 1 further type arguments after
    -- the constrained variable.
    ctx = [ lK (className p) (VarT n : vars 's' (p - 1))
          | RequiredInstance n p <- S.toList required
          ]
    vrs = vars 't' (dataArity dat)
    -- Positions count from the last type argument, so the traversed
    -- variable sits 'traversalArg' places from the end; it is dropped and
    -- the variables around it are kept.
    (vrsBefore, _ : vrsAfter) = splitAt (length vrs - traversalArg tt) vrs
    className n = nameBase ++ (if n > 1 then show n else "")
    -- Named instHead to avoid shadowing Prelude's 'head'.
    instHead = lK (className (traversalArg tt)) (lK (dataName dat) vrsBefore : vrsAfter)
-- | Derive the traversal function for a data type: one clause per
-- constructor, bundled into a function declaration named after the
-- traversal and its argument position.
deriveTraversal :: TraveralType -> DataDef -> WithInstances Dec
deriveTraversal tt dat = fmap (funN name) clauses
  where
    name    = traversalNameN tt (traversalArg tt)
    clauses = mapM (deriveTraversalCtor tt (argPositions dat)) (dataCtors dat)
-- | Derive the clause of the traversal function that handles one
-- constructor: every field gets the traversal derived from its type applied
-- to it, and the results are recombined with 'traverseCtor'.
deriveTraversalCtor :: TraveralType -> ArgPositions -> CtorDef -> WithInstances Clause
deriveTraversalCtor tt positions ctor =
    fmap mkClause (mapM (deriveTraversalType tt positions) (ctorTypes ctor))
  where
    mkClause fieldTravs =
        traverseFunc tt (ctp ctor 'a')
            (traverseCtor tt (ctorName ctor) (zipWith AppE fieldTravs (ctv ctor 'a')))
-- | Derive the traversal for a single type.
--
-- * Function types combine the traversals of argument and result with
--   'traverseArrow'; the argument is derived with 'traversalCo' negated.
-- * Other type applications are delegated to 'deriveTraversalApp'.
-- * Types without variables (constructors, the list constructor, unit)
--   get the identity traversal.
-- * A type variable gets 'traversalDirect' when it is the traversed
--   argument, and the identity traversal otherwise.
--
-- Raises an error for @forall@ types, for the traversed variable occurring
-- while 'traversalCo' is set, and for any other unsupported type.
deriveTraversalType :: TraveralType -> ArgPositions -> Type -> WithInstances Trav
-- 'error' is used instead of 'fail': it does not depend on a 'MonadFail'
-- instance for 'Writer', and matches what a defaulted 'fail' would do.
deriveTraversalType _ _ (ForallT _ _ _) = error "forall not supported in traversal deriving"
deriveTraversalType tt pos (AppT (AppT ArrowT a) b)
    = traverseArrow tt <$> deriveTraversalType tt{traversalCo = not $ traversalCo tt} pos a
                       <*> deriveTraversalType tt pos b
deriveTraversalType tt pos (AppT a b) = deriveTraversalApp tt pos a [b]
deriveTraversalType tt _ ListT = return $ traversalId tt
deriveTraversalType tt _ (ConT _) = return $ traversalId tt
-- The unit type mentions no type variable, so there is nothing to traverse.
deriveTraversalType tt _ (TupleT 0) = return $ traversalId tt
deriveTraversalType tt pos (VarT n)
    | pos n /= traversalArg tt = return $ traversalId tt
    | traversalCo tt           = error "tyvar used in covariant position"
    | otherwise                = return $ traversalDirect tt
-- Anything else would otherwise die with an opaque pattern-match failure.
deriveTraversalType _ _ _ = error "unsupported type in traversal deriving"
-- | Derive the traversal for a type application, given its head and the
-- arguments collected so far.
--
-- Nested applications are unrolled first. Tuples are traversed
-- component-wise, collapsing to the identity traversal when every component
-- is the identity. For any other head, each argument with a non-identity
-- traversal is lifted with 'traverseArg' (arguments are numbered from 1
-- starting at the last one) and the results are combined with
-- 'traversalPlus'. When the head is a type variable, a 'RequiredInstance'
-- is recorded for each such argument.
deriveTraversalApp :: TraveralType -> ArgPositions -> Type -> [Type] -> WithInstances Trav
deriveTraversalApp tt positions (AppT f x) acc = deriveTraversalApp tt positions f (x : acc)
deriveTraversalApp tt positions tycon args
    | isTupleT tycon = do
        compTravs <- mapM (deriveTraversalType tt positions) args
        let arity = length args
        return $ if all (== traversalId tt) compTravs
            then traversalId tt
            else LamE [TupP (vars 't' arity)]
                      (traverseTuple tt (zipWith AppE compTravs (vars 't' arity)))
    | otherwise = do
        -- The head's own traversal is not used, but deriving it is kept as
        -- part of the computation.
        _ <- deriveTraversalType tt positions tycon
        argTravs <- mapM (deriveTraversalType tt positions) args
        -- Arguments that actually need traversing, paired with their position.
        let active = [ (i, t)
                     | (t, i) <- zip (reverse argTravs) [1..]
                     , t /= traversalId tt
                     ]
        case tycon of
            VarT n
                | positions n == traversalArg tt ->
                    fail "kind error: type used type constructor"
                | otherwise ->
                    tell $ S.fromList [ RequiredInstance n i | (i, _) <- active ]
            _ -> return ()
        return $ case [ traverseArg tt i t | (i, t) <- active ] of
            []    -> traversalId tt
            steps -> foldl1 (traversalPlus tt) steps
-- | Lift a sub-traversal through the n-th argument of a type constructor,
-- using the traversal function named by 'traversalNameN'.
traverseArg :: TraveralType -> Int -> Trav -> Trav
traverseArg tt n e = liftThrough e
  where
    liftThrough = traversalFunc tt (traversalNameN tt n)
-- | Name of the traversal function for argument position @n@: the base
-- name, with @n@ appended when it is greater than 1.
traversalNameN :: TraveralType -> Int -> String
traversalNameN tt n
    | n > 1     = traversalName tt ++ show n
    | otherwise = traversalName tt
-- | Maps a type variable name to its argument position (an 'Int').
type ArgPositions = Name -> Int
-- | Position of each type argument of a data type, numbered from 1 starting
-- at the last argument: a variable at index @k@ (from 0) in a list of
-- @length args@ arguments is at position @length args - k@.
--
-- Raises an error for a name that is not one of the data type's arguments.
argPositions :: DataDef -> Name -> Int
argPositions dat = \nm ->
    maybe (error "impossible: tyvar not in scope")
          (\k -> length args - k)
          (elemIndex nm args)
  where
    -- Bound outside the lambda so the argument list is shared across lookups.
    args = ex_args dat