{-# LANGUAGE UndecidableInstances, LambdaCase, ParallelListComp, ViewPatterns #-}
module Data.Syntax where

import Definitive
import qualified Prelude as P
import Language.Syntax.Regex

-- | A syntax value suspended in the functor @f@.
type ThunkT f = f (SyntaxT f)

-- | An environment: names bound to suspended syntax values.
type Env f = Map String (ThunkT f)

-- | Syntax tree nodes, parameterised by the functor @f@ in which
-- sub-terms are suspended.
data SyntaxT f
  = ValList [ThunkT f]
    -- ^ A list of thunks; 'reduce' treats a non-empty one as an application.
  | Dictionary (Env f)
    -- ^ A collection of named, suspended entries.
  | Text String
    -- ^ A string literal (interpreted as a regex when used as a pattern).
  | Quote (SyntaxT f)
    -- ^ A term that 'reduce' hands back without evaluating it.
  | Function (ThunkT f -> ThunkT f)
    -- ^ A host-level function from thunk to thunk.
-- | Debug rendering of a syntax tree. Dictionaries print their key/value
-- pairs between braces, quotes get a leading @'@, and functions are opaque
-- (@<fun>@).
instance Show (ThunkT f) => Show (SyntaxT f) where
  show = \case
    ValList items      -> show items
    Dictionary entries -> "{" + show (toList (map show entries ^. keyed)) + "}"
    Text str           -> show str
    Quote inner        -> "'" + show inner
    Function _         -> "<fun>"

-- | Focus on the environment stored in a 'Dictionary' node. Every other
-- constructor has no focus, and setting through it returns the node
-- unchanged.
dict :: Traversal' (SyntaxT f) (Env f)
dict = prism fromNode toNode
  where
    fromNode node = case node of
      Dictionary env -> Right env
      _              -> Left node
    toNode node env' = case node of
      Dictionary _ -> Dictionary env'
      _            -> node

-- | The empty list. 'lambda' returns it when a pattern fails, and
-- 'lambdaSum' treats it as "try the next alternative".
nil :: SyntaxT f
nil = ValList []
-- | A dictionary node describing a named binding: key @\"name\"@ holds the
-- name as 'Text', key @\"value\"@ holds the bound value.
variable :: Unit f => String -> SyntaxT f -> SyntaxT f
variable n v = Dictionary (fromList fields)
  where fields = [ ("name", pure (Text n))
                 , ("value", pure v)
                 ]
-- | A two-element list; 'reduce' evaluates it by applying the first thunk
-- to the second.
funcall :: ThunkT f -> ThunkT f -> SyntaxT f
funcall f x = ValList (f : x : [])
-- | Wrap a host function over thunks as a suspended 'Function' value.
builtin :: Unit m => (ThunkT m -> ThunkT m) -> ThunkT m
builtin fn = pure (Function fn)
-- | Curried two-argument builtin: applying it to one argument yields a
-- one-argument 'builtin'.
builtin2 :: Unit m => (ThunkT m -> ThunkT m -> ThunkT m) -> ThunkT m
builtin2 fn = builtin (\a -> builtin (fn a))
-- | Curried three-argument builtin: applying it to one argument yields a
-- 'builtin2'.
builtin3 :: Unit m => (ThunkT m -> ThunkT m -> ThunkT m -> ThunkT m) -> ThunkT m
builtin3 fn = builtin (\a -> builtin2 (fn a))

-- | Name of a node's constructor, for diagnostics. The empty list is
-- reported as @\"Nil\"@ rather than @\"ValList\"@.
shape :: SyntaxT f -> String
shape node = case node of
  ValList []    -> "Nil"
  ValList (_:_) -> "ValList"
  Text _        -> "Text"
  Dictionary _  -> "Dictionary"
  Quote _       -> "Quote"
  Function _    -> "Function"

-- | Evaluate one syntax node in a reader monad that carries the current
-- environment.
--
-- * A non-empty 'ValList' is a function application. Every element thunk
--   is chained with 'reduce'. The head is run to obtain the function, and
--   the remaining (still suspended) thunks are passed to it one at a time
--   (curried). Each intermediate result must itself be a 'Function';
--   anything else is a runtime 'error'.
-- * A 'Dictionary' reduces each entry under @local (d' +)@, where @d'@ is
--   the resulting dictionary itself (tied with 'fix'), so entries can
--   refer to one another recursively.
-- * A 'Quote' yields its contents without evaluating them.
-- * Anything else ('Text', 'Function', and the empty list, which fails the
--   @(fun:args)@ view pattern) is returned unchanged.
reduce :: MonadReader (Env m) m => SyntaxT m -> ThunkT m
reduce (ValList (map (>>= reduce) -> (fun:args))) = fun >>= \f -> foldlM call f args
  where call (Function f) x = f x
        -- NOTE(review): partial; the message does not say which shape was
        -- applied (e.g. via 'shape').
        call _ _ = error "Invalid function call"
-- NOTE(review): assumes '+' on the environment Map is a left-biased union,
-- so the dictionary's own entries shadow outer bindings; confirm.
reduce (Dictionary d) = pure $ Dictionary $ fix (\d' -> map (local (d'+) . (>>= reduce)) d)
reduce (Quote s) = pure s
reduce a = pure a

-- | Build a one-argument function from a pattern @pat@ and a body @e@.
--
-- The argument thunk is run and its value is matched against @pat@ with
-- 'matchPat'. On success, each captured @(name, value)@ pair is inserted
-- into the environment (via @local@) and @e@ is reduced in that extended
-- environment. On failure the result is 'nil', which 'lambdaSum' treats as
-- "try the next alternative".
lambda :: MonadReader (Env m) m => SyntaxT Id -> SyntaxT m -> (ThunkT m -> ThunkT m)
lambda pat e = tryAlt
  where tryAlt x = x >>= match >>= maybe (pure nil) bind
          -- Extend the environment with every captured binding, then
          -- evaluate the body.
          where bind vars = local (compose (_insert<$>c'list vars)) (reduce e)
                  where _insert (s,v) = insert s (pure v)
                match = matchPat pat
-- | Match a value against a pattern. On success the result is the list of
-- captured @(name, value)@ bindings in 'Just'; failure is signalled with
-- @zero@, presumably 'Nothing' for 'Maybe' (the failure branches below
-- rely on this).
--
-- Pattern sub-terms live in 'Id' and are unwrapped with @yb i'Id@.
matchPat :: Monad f => SyntaxT Id -> (SyntaxT f -> f (Maybe [(String,SyntaxT f)]))
-- A 'Text' pattern is a regex run (via 'runRegex') against a 'Text' value.
-- Only the first result is used. Its pairs are bound as 'Text' values, and
-- the second component (presumably the whole matched text) is bound to
-- "&". A non-'Text' value, or no regex result, fails. No effects are run.
matchPat (Text re) = pure.matchText
  where matchText (Text t) | ((a,wh):_) <- match t = pure $ map2 Text (("&",wh):a)
        matchText _ = zero
        match = runRegex re
-- A 'Dictionary' pattern matches a 'Dictionary' value key by key. Every
-- key of the pattern must be present in the value (a missing key fails),
-- its thunk is run and matched against the sub-pattern, and the bindings
-- of all keys are concatenated. Any sub-failure fails the whole match.
-- Extra keys in the value are ignored.
matchPat (Dictionary d) = matchDict
  where matchDict (Dictionary d') = 
          traverse (matches d') (toList pats) <&> map concat.sequence
        matchDict _ = pure zero
        pats = (matchPat.yb i'Id<$>d)^.keyed
        matches d' (k,m) = maybe (pure zero) (>>= m) (d'^.at k)
-- A 'ValList' pattern matches a 'ValList' value element-wise (the parallel
-- comprehension zips pattern and value) and concatenates the bindings.
-- `length (take n l')` checks the value has at least n elements without
-- forcing the rest of the list.
-- NOTE(review): a value with MORE elements than the pattern still matches;
-- the extras are ignored. Confirm that prefix matching is intended.
matchPat (ValList l) = matchList
  where n = length l
        matchList (ValList l') | length (take n l') == n =
          sequence [matchPat p =<< e | p <- yb i'Id<$>l | e <- l'] <&> map concat.sequence
        matchList _ = pure zero
-- Remaining patterns ('Quote', 'Function').
-- NOTE(review): the outer `pure` here is the function Applicative (i.e.
-- `const`), so this is `const (pure zero)`. If `zero :: Maybe _` is
-- Nothing, as the failure branches above imply, these patterns NEVER
-- match. A wildcard would be `const (pure (Just []))`. Confirm the
-- intended behaviour.
matchPat _ = pure (pure zero)

-- | Combine alternative functions into one. Each alternative is tried in
-- order on the same argument, and the first result that is not the empty
-- list wins. If every alternative yields the empty list (or there are
-- none), the result is 'nil'.
lambdaSum :: Monad m => [ThunkT m -> ThunkT m] -> ThunkT m -> ThunkT m
lambdaSum [] = \_ -> pure nil
lambdaSum (alt:alts) = \arg -> alt arg >>= \result -> case result of
  ValList [] -> lambdaSum alts arg
  _          -> pure result