{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE FlexibleContexts #-}
module Lambdabot.Plugin.Haskell.Pl.Transform (
    transform,
  ) where

import Lambdabot.Plugin.Haskell.Pl.Common

import qualified Data.Map as M

import Data.Graph (stronglyConnComp, flattenSCC, flattenSCCs)
import Control.Monad.State

-- | Does a name occur in a pattern?
--
-- Walks every variable bound by the pattern, descending into both halves
-- of tuple and cons patterns.
occursP :: String -> Pattern -> Bool
occursP name = go
  where
    go (PVar bound) = bound == name
    go (PTuple l r) = go l || go r
    go (PCons  l r) = go l || go r

-- | How often does the given name occur free in an expression?
--
-- A lambda whose pattern binds the name hides every occurrence in its body.
-- A let that declares the name hides every occurrence in its body and in
-- its right-hand sides, since the bindings are treated as recursive.
freeIn :: String -> Expr -> Int
freeIn v expr = case expr of
  Var _ name -> if name == v then 1 else 0
  App fun arg -> freeIn v fun + freeIn v arg
  Lambda pat body
    | v `occursP` pat -> 0
    | otherwise -> freeIn v body
  Let decls body
    | v `elem` map declName decls -> 0
    | otherwise ->
        freeIn v body + sum [freeIn v rhs | Define _ rhs <- decls]

-- | Does a name occur free in an expression?
--
-- True exactly when 'freeIn' counts at least one free occurrence.
isFreeIn :: String -> Expr -> Bool
isFreeIn v = (> 0) . freeIn v

-- | Build a right-nested tuple expression from a non-empty list.
--
-- @[a, b, c]@ becomes the equivalent of @(a, (b, c))@, matching the shape
-- produced by 'tupleP'. Errors (via 'foldr1') on an empty list.
tuple :: [Expr] -> Expr
tuple = foldr1 pair
  where
    pair l r = App (App (Var Inf ",") l) r

-- | Build a right-nested tuple pattern binding the given names.
--
-- @[a, b, c]@ becomes @PTuple a (PTuple b c)@. Errors (via 'foldr1') on an
-- empty list.
tupleP :: [String] -> Pattern
tupleP = foldr1 PTuple . map PVar

-- | The subset of ds that d depends on.
--
-- These are the declarations whose name occurs free in d's right-hand
-- side. The relative order of @ds@ is preserved.
dependsOn :: [Decl] -> Decl -> [Decl]
dependsOn ds d = filter mentioned ds
  where
    rhs = declExpr d
    mentioned candidate = declName candidate `isFreeIn` rhs

-- | Convert recursive lets to lambdas with tuple patterns and fix calls.
--
-- The declarations of a @let@ are grouped into strongly connected
-- components using 'dependsOn'. 'stronglyConnComp' returns components in
-- reverse topological order, so the first one does not depend on the rest.
--
-- That first group is bound through one tuple pattern, whose value is
-- 'fix'' applied to a lambda over the same tuple. The remaining
-- declarations stay in a smaller @let@ inside the body, and the result is
-- processed again until no @let@ is left.
unLet :: Expr -> Expr
unLet (Var f x) = Var f x
unLet (Lambda pat body) = Lambda pat (unLet body)
unLet (App e1 e2) = App (unLet e1) (unLet e2)
unLet (Let [] body) = unLet body
unLet (Let decls body) = unLet (App binder fixedPoint)
  where
    -- decls is non-empty here (the previous equation handles []), so there
    -- is at least one component and head/tail cannot fail.
    components = stronglyConnComp [(d, d, dependsOn decls d) | d <- decls]
    firstGroup = flattenSCC (head components)
    others     = flattenSCCs (tail components)

    groupPat   = tupleP (map declName firstGroup)
    groupTuple = tuple (map declExpr firstGroup)

    binder     = Lambda groupPat (Let others body)
    fixedPoint = App fix' (Lambda groupPat groupTuple)

-- | State threaded through 'alphaRename'.
--
-- The map sends each lambda-bound source name to the fresh name it was
-- renamed to. The 'Int' is the index used to build the next fresh name
-- ("$" followed by the number).
type Env = (M.Map String String, Int)
-- note: The second component is the environment size, counting duplicate
-- variables.

-- | Rename all variables to (locally) unique fresh ones.
--
-- It's a pity we still need that for the pointless transformation.
-- Otherwise a newly created id/const/... could be captured by a lambda
-- that binds the same name, e.g.
-- transform' (\id x -> x) ==> transform' (\id -> id) ==> id
--
-- Behaviour:
--
-- * Every lambda-bound variable is renamed to "$" followed by a counter;
--   free variables keep their names.
-- * State changes made under a lambda are discarded when leaving it.
--   Sibling lambdas may therefore reuse fresh names, but a binder never
--   reuses the name of an enclosing binder.
-- * The input must not contain 'Let'; run 'unLet' first.
alphaRename :: Expr -> Expr
alphaRename expr = evalState (rename expr) (M.empty, 0)
  where
    rename :: Expr -> State Env Expr
    rename (Var fx name) = do
      (renamings, _) <- get
      return $ Var fx (M.findWithDefault name name renamings)
    rename (App fun arg) = App <$> rename fun <*> rename arg
    rename (Let _ _) = assert False undefined
    rename (Lambda pat body) =
      scoped (Lambda <$> renamePat pat <*> rename body)

    -- Run an action against the current state but throw away the state it
    -- ends with, so bindings made under a lambda do not leak out of it
    -- (reader-like behaviour).
    scoped :: State s a -> State s a
    scoped act = state (\s -> (evalState act s, s))

    -- Bind each variable of the pattern to the next fresh name.
    renamePat :: Pattern -> State Env Pattern
    renamePat (PVar name) = do
      (renamings, next) <- get
      let fresh = "$" ++ show next
      put (M.insert name fresh renamings, next + 1)
      return (PVar fresh)
    renamePat (PTuple l r) = PTuple <$> renamePat l <*> renamePat r
    renamePat (PCons  l r) = PCons  <$> renamePat l <*> renamePat r

-- | Make an expression point-free.
--
-- The expression goes through three stages:
--
-- 1. 'unLet' removes let bindings.
-- 2. 'alphaRename' gives every bound variable a fresh name.
-- 3. 'transform'' replaces the lambdas with combinators.
transform :: Expr -> Expr
transform expr = transform' (alphaRename (unLet expr))

-- | Transform patterns to:
--     fst/snd for tuple patterns
--     head/tail for cons patterns
--     id/const/flip/. for variable patterns
--
-- Expects input without 'Let' (see 'unLet') whose bound variables have
-- been renamed apart (see 'alphaRename').
transform' :: Expr -> Expr
transform' (Let {}) = assert False undefined
transform' (Var f v) = Var f v
transform' (App e1 e2) = App (transform' e1) (transform' e2)
transform' (Lambda (PTuple p1 p2) e) = splitPattern "fst" "snd" p1 p2 e
transform' (Lambda (PCons p1 p2) e) = splitPattern "head" "tail" p1 p2 e
transform' (Lambda (PVar v) e) = transform' $ getRidOfV e where
  -- getRidOfV body builds a lambda-free term standing for (\v -> body).
  getRidOfV (Var f v') | v == v'   = id'
                       | otherwise = const' `App` Var f v'
  -- Eliminate the inner binder first. The assertion checks that it does
  -- not shadow v, which alphaRename is expected to guarantee.
  getRidOfV l@(Lambda pat _) = assert (not $ v `occursP` pat) $
    getRidOfV $ transform' l
  getRidOfV (Let {}) = assert False bt
  getRidOfV e'@(App e1 e2)
    -- v in both halves: scomb (\v -> e1) (\v -> e2)
    | fr1 && fr2 = scomb `App` getRidOfV e1 `App` getRidOfV e2
    -- v only in the function: flip' (\v -> e1) e2
    | fr1 = flip' `App` getRidOfV e1 `App` e2
    -- eta reduction: \v -> e1 v  ==>  e1  (v is not free in e1 here)
    | Var _ v' <- e2, v' == v = e1
    -- v only in the argument: comp e1 (\v -> e2)
    | fr2 = comp `App` e1 `App` getRidOfV e2
    -- v does not occur at all
    | True = const' `App` e'
    where
      fr1 = v `isFreeIn` e1
      fr2 = v `isFreeIn` e2

-- Name of the variable that 'splitPattern' binds to the whole tuple/list.
--
-- It used to be "z", which captured any *free* user variable called z,
-- because alphaRename only renames bound variables. For example,
-- \(a,b) -> z did not come out as const z.
--
-- A '$' followed by a letter is not a legal Haskell identifier, and
-- alphaRename only generates '$' followed by digits, so this name cannot
-- clash with any of them.
patternScrutinee :: String
patternScrutinee = "$z"

-- Desugar a lambda over a two-part pattern. Given the selectors sel1/sel2,
--   \(p1 `op` p2) -> e
-- becomes
--   \$z -> (\p1 -> \p2 -> e) (sel1 $z) (sel2 $z)
-- and the result is transformed further by transform'.
splitPattern :: String -> String -> Pattern -> Pattern -> Expr -> Expr
splitPattern sel1 sel2 p1 p2 e =
  transform' $ Lambda (PVar patternScrutinee) $
    Lambda p1 (Lambda p2 e) `App` project sel1 `App` project sel2
  where
    project sel = Var Pref sel `App` Var Pref patternScrutinee