module Agda.Compiler.Treeless.Identity
  ( detectIdentityFunctions ) where

import Prelude hiding ((!!))  -- don't use partial functions

import Control.Applicative ( Alternative((<|>), empty) )
import Data.Semigroup
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List as List

import Agda.Syntax.Treeless
import Agda.TypeChecking.Monad

import Agda.Utils.List

import Agda.Utils.Impossible

-- | Simplify a function whose body is recognised by 'isIdentity'.
--
-- When @isIdentity q t@ yields @Just (n, k)@, the function @q@ is flagged
-- with @markInline True@ and its body is replaced by @n@ lambdas around
-- @TVar k@.  Otherwise the body @t@ is returned unchanged.
detectIdentityFunctions :: QName -> TTerm -> TCM TTerm
detectIdentityFunctions q t =
  case isIdentity q t of
    Nothing     -> return t
    Just (n, k) -> do
      markInline True q
      -- The looked-up definition is not used; the lookup is kept only so
      -- that the sequence of 'TCM' actions stays the same.
      _ <- getConstInfo q
      return $ mkTLam n $ TVar k

-- | If @isIdentity f t = Just (n, k)@ then
-- @f = t@ is equivalent to @f = λ xn₋₁ .. x₀ → xk@.
--
-- 'trivialIdentity' is tried first; 'recursiveIdentity' is only consulted
-- when that returns 'Nothing'.
isIdentity :: QName -> TTerm -> Maybe (Int, Int)
isIdentity q t =
  trivialIdentity q t <|> recursiveIdentity q t

-- | Does the function recurse on an argument, rebuilding the same value again.
--
-- The body (after stripping leading lambdas with 'tLamView') must be a
-- single case on variable @x@ whose default is @TError TUnreachable@ and
-- whose every alternative passes 'identityBranch'.  On success the result is
-- @Just (n, x)@, where @n@ is the count returned by 'tLamView'.
recursiveIdentity :: QName -> TTerm -> Maybe (Int, Int)
recursiveIdentity q t =
  case body of
    TCase x _ (TError TUnreachable) alts
      | all (identityBranch x) alts -> pure (n, x)
    _ -> empty -- TODO: lets?
  where
    (n, body) = tLamView t

    -- Does this alternative give back the value that was matched on?
    -- Literal and guard alternatives are never accepted.
    identityBranch :: Int -> TAlt -> Bool
    identityBranch _ TALit{}   = False
    identityBranch _ TAGuard{} = False
    identityBranch x (TACon c arity rhs) =
      case rhs of
        TApp (TCon c') conArgs -> c == c' && identityArgs arity conArgs
        -- The scrutinee itself, shifted past the @arity@ pattern variables.
        TVar y                 -> y == x + arity  -- from @-pattern recovery
        _                      -> False -- TODO: nested cases
      where
        -- The constructor must be applied to exactly @a@ arguments, and the
        -- arguments, read from last to first, must match indices 0, 1, 2, ...
        identityArgs :: Int -> Args -> Bool
        identityArgs a as =
          length as == a && and (zipWith match (reverse as) [0..])

        -- Element @i@ of the reversed list, i.e. counting from the last
        -- element.  An out-of-range index is treated as impossible.
        proj :: Int -> [a] -> a
        proj i xs = indexWithDefault __IMPOSSIBLE__ (reverse xs) i

        -- Does the constructor argument correspond to pattern variable @y@?
        -- Accepted forms: an erased argument, the variable itself, or a call
        -- to @q@ with exactly @n@ arguments whose argument in the scrutinee's
        -- position (selected with @proj x@) matches in turn.
        match :: TTerm -> Int -> Bool
        match TErased                  _ = True
        match (TVar z)                 y = z == y
        match (TApp (TDef f) callArgs) y =
          f == q && length callArgs == n && match (proj x callArgs) y
        match _                        _ = False

-- | The argument positions that a term is known to return unchanged.
-- An empty list means the term is not known to be an identity in any
-- argument.
newtype IdentityIn = IdIn [Int]

-- | The result for a term that is not an identity in any argument.
--
-- Note that this is the absorbing element of the 'Semigroup' instance
-- (intersecting with the empty list gives the empty list), not a unit.
notId :: IdentityIn
notId = IdIn []

-- | Combining two results keeps only the positions present in both, so a
-- position survives only if every combined branch agrees on it.
instance Semigroup IdentityIn where
  IdIn xs <> IdIn ys = IdIn $ List.intersect xs ys

-- | Does the function always return one of its arguments unchanged (possibly
-- through recursive calls).
--
-- The body (after stripping leading lambdas with 'tLamView') is analysed by
-- @go@.  The result is @Just (n, x)@ exactly when the analysis ends with the
-- single position @x@; @n@ is the count returned by 'tLamView'.
trivialIdentity :: QName -> TTerm -> Maybe (Int, Int)
trivialIdentity q t =
  case go 0 body of
    IdIn [x]     -> pure (n, x)
    IdIn []      -> Nothing
    IdIn (_:_:_) -> Nothing  -- only happens for empty functions (which will never be called)
  where
    (n, body) = tLamView t

    -- @go k tm@: @k@ counts the binders passed on the way down to @tm@
    -- (let bindings and constructor pattern variables).  Variables below @k@
    -- are local; the rest are shifted down by @k@ before being reported.
    go :: Int -> TTerm -> IdentityIn
    go k tm =
      case tm of
        TVar x | x >= k    -> IdIn [x - k]
               | otherwise -> notId
        TLet _ rhs         -> go (k + 1) rhs
        -- The default branch and all alternatives must agree.
        TCase _ _ d alts   -> sconcat (go k d :| map (goAlt k) alts)
        -- A call to @q@ itself: report every position @y@ whose argument
        -- (counting from the last one) is the variable @y + k@.  If the
        -- guard fails, the catch-all @TApp{}@ alternative below applies.
        TApp (TDef f) args
          | f == q         -> IdIn [ y | (TVar x, y) <- zip (reverse args) [0..], y + k == x ]
        TCoerce v          -> go k v
        TApp{}             -> notId
        TLam{}             -> notId
        TLit{}             -> notId
        TDef{}             -> notId
        TCon{}             -> notId
        TPrim{}            -> notId
        TUnit{}            -> notId
        TSort{}            -> notId
        TErased{}          -> notId
        TError{}           -> notId

    -- Analyse the body of one alternative.  A constructor alternative adds
    -- its arity to the binder count; literal and guard alternatives add
    -- nothing.
    goAlt :: Int -> TAlt -> IdentityIn
    goAlt k (TALit _ rhs)       = go k rhs
    goAlt k (TAGuard _ rhs)     = go k rhs
    goAlt k (TACon _ arity rhs) = go (k + arity) rhs