{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.UnpackTuple
-- Description : unpacks and flattens tuples. / タプルを展開し平坦にします。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.Core.Convert.UnpackTuple
  ( run,

    -- * internal rules
    rule,
  )
where

import Jikka.Common.Alpha
import Jikka.Common.Error
import qualified Jikka.Core.Convert.Alpha as Alpha
import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.Lint
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util

-- | The rewrite rules of this pass. Each alternative removes a tuple that is
-- built only to be taken apart again:
--
-- * @(fun (x : (t1, ..., tn)) -> body) (tuple e1 ... en)@ becomes
--   @(fun x1 ... xn -> body[x := tuple x1 ... xn]) e1 ... en@, where the names
--   @x1 ... xn@ are generated with 'genVarName'.
-- * @tuple (proj0 e)@ on a 1-tuple becomes @e@.
-- * @proj_i (tuple e0 ... e{n-1})@ becomes @e_i@, and
--   @proj_i (if c a b)@ becomes @if c (proj_i a) (proj_i b)@.
-- * @foldl@ and @scanl@ whose accumulator is a 1-tuple, and @iterate@ whose
--   state is a 1-tuple, are rewritten to work on the unwrapped value; the
--   result is wrapped back into a 1-tuple afterwards.
--
-- An internal error is thrown when the types or sizes of a tuple do not match,
-- or when a projection index is out of range. These checks are expected not
-- to fire on well-typed input.
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule =
  let return' = return . Just
   in RewriteRule $ \_ -> \case
        -- A lambda over a tuple applied to a tuple literal: pass the
        -- components as separate arguments instead of building the tuple.
        App (Lam x (TupleTy ts) body) e -> case curryApp e of
          (Tuple' ts', es) -> do
            when (ts /= ts') $ do
              throwInternalError "the types of tuple don't match"
            when (length ts /= length es) $ do
              throwInternalError "the sizes of tuple don't match"
            xs <- replicateM (length ts) (genVarName x)
            body' <- substitute x (uncurryApp (Tuple' ts) (map Var xs)) body
            return' $ uncurryApp (curryLam (zip xs ts) body') es
          _ -> return Nothing
        -- Re-packing the only component of a 1-tuple gives the tuple back.
        App (Tuple' [_]) (Proj' [_] 0 e) -> return' e
        Proj' ts i e
          -- Report a bad index as an internal error instead of letting (!!)
          -- crash with an uncatchable exception.
          | i < 0 || i >= length ts -> throwInternalError "the index of proj is out of range"
          | otherwise -> case curryApp e of
              (Tuple' _, es)
                | length es /= length ts -> throwInternalError "the sizes of tuple don't match"
                | otherwise -> return' $ es !! i
              -- Push the projection into both branches of an if-expression.
              (Lit (LitBuiltin (If _)), [e1, e2, e3]) ->
                return' $ If' (ts !! i) e1 (Proj' ts i e2) (Proj' ts i e3)
              _ -> return Nothing
        -- foldl with a 1-tuple accumulator (the first lambda argument): fold
        -- over the unwrapped value and wrap the final result once.
        Foldl' t2 (TupleTy [t1]) (Lam x1 (TupleTy [_]) (Lam x2 _ body)) e es -> do
          body' <- substitute x1 (App (Tuple' [t1]) (Var x1)) (Proj' [t1] 0 body)
          return' $ App (Tuple' [t1]) (Foldl' t2 t1 (Lam2 x1 t1 x2 t2 body') (Proj' [t1] 0 e) es)
        -- scanl with a 1-tuple accumulator: scan over the unwrapped value and
        -- re-wrap every element of the resulting list. The accumulator is the
        -- first lambda argument, as in the foldl case; the element argument
        -- (x2) may have any type, since only x1 is substituted.
        Scanl' t2 (TupleTy [t1]) (Lam x1 _ (Lam x2 _ body)) e es -> do
          body' <- substitute x1 (App (Tuple' [t1]) (Var x1)) (Proj' [t1] 0 body)
          let e' = Scanl' t2 t1 (Lam2 x1 t1 x2 t2 body') (Proj' [t1] 0 e) es
          y <- genVarName'
          let f = Map' t1 (TupleTy [t1]) (Lam y t1 (App (Tuple' [t1]) (Var y)))
          return' $ f e'
        -- iterate on a 1-tuple state: iterate on the unwrapped value and wrap
        -- the result once.
        Iterate' (TupleTy [t]) n (Lam x (TupleTy [_]) body) base -> do
          body' <- substitute x (App (Tuple' [t]) (Var x)) (Proj' [t] 0 body)
          return' $ uncurryApp (Tuple' [t]) [Iterate' t n (Lam x t body') (Proj' [t] 0 base)]
        _ -> return Nothing

-- | Rewrites the whole program with 'rule', delegating the traversal to
-- 'applyRewriteRuleProgram''.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram = applyRewriteRuleProgram' rule

-- | `run` removes unnecessary introductions and eliminations of tuples.
-- For example, this converts the following:
--
-- > (fun xs -> (proj0 xs) + (proj1 xs)) (tuple 2 1)
--
-- to the following:
--
-- > (fun x0 x1 -> x0 + x1) 2 1
--
-- This can also remove 1-tuples over higher-order functions.
-- For example, this converts the following:
--
-- > foldl (fun xs y -> tuple (proj0 xs + y)) (tuple 0) [1, 2, 3]
--
-- to the following:
--
-- > tuple (foldl (fun x y -> x + y) 0 [1, 2, 3])
--
-- The pass works in four steps:
--
-- 1. Check that the input is well-typed.
-- 2. Alpha-convert the input with "Jikka.Core.Convert.Alpha".
-- 3. Rewrite it with 'rule'.
-- 4. Check that the result is still well-typed.
--
-- Errors are wrapped with the tag @"Jikka.Core.Convert.UnpackTuple"@.
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.UnpackTuple" $ do
  precondition (ensureWellTyped prog)
  result <- runProgram =<< Alpha.run prog
  postcondition (ensureWellTyped result)
  pure result