module Agda.Compiler.Treeless.Identity
( detectIdentityFunctions ) where
import Prelude hiding ((!!))
import Control.Applicative ( Alternative((<|>), empty) )
import Data.Semigroup
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List as List
import Agda.Syntax.Treeless
import Agda.TypeChecking.Monad
import Agda.Utils.List
import Agda.Utils.Impossible
-- | If 'isIdentity' recognises the body @t@ of definition @q@ as always
--   returning one of its parameters, replace it by the plain projection
--   @λ xₙ₋₁ … x₀ → xₖ@ and mark @q@ for inlining (via 'markInline').
--   Any other term is returned unchanged.
detectIdentityFunctions :: QName -> TTerm -> TCM TTerm
detectIdentityFunctions q t =
  case isIdentity q t of
    Nothing     -> return t
    Just (n, k) -> do
      -- The new body is a trivial projection, so request inlining of q.
      markInline True q
      return $ mkTLam n $ TVar k
-- | @isIdentity q t == Just (n, k)@ means the definition @q = t@ is
--   equivalent to @q = λ xₙ₋₁ … x₀ → xₖ@: @t@ has @n@ leading lambdas and
--   returns its parameter with de Bruijn index @k@.
--   'trivialIdentity' is consulted first; 'recursiveIdentity' only if that
--   fails.
isIdentity :: QName -> TTerm -> Maybe (Int, Int)
isIdentity q t = viaForwarding <|> viaRebuilding
  where
    -- Every path returns a parameter directly or through a recursive call.
    viaForwarding = trivialIdentity q t
    -- A case split whose branches rebuild the matched value.
    viaRebuilding = recursiveIdentity q t
-- | Recognise functions that case-split on one of their parameters and, in
--   every branch, rebuild exactly the value that was matched (recursing on
--   the fields via @q@ itself is allowed).
--
--   After stripping the @n@ leading lambdas, the body must be a @case@ on a
--   variable @x@ whose default is @TError TUnreachable@.  On success the
--   result is @Just (n, x)@.
recursiveIdentity :: QName -> TTerm -> Maybe (Int, Int)
recursiveIdentity q t =
  case body of
    TCase x _ (TError TUnreachable) alts
      | all (identityBranch x) alts -> pure (n, x)
    _ -> empty  -- TODO: look through lets?
  where
    (n, body) = tLamView t

    -- Does the alternative return a value equal to the scrutinee @x@?
    identityBranch :: Int -> TAlt -> Bool
    identityBranch _ TALit{}   = False
    identityBranch _ TAGuard{} = False
    identityBranch x (TACon c arity rhs) =
      case rhs of
        -- A nullary constructor returned as-is is the scrutinee itself.
        TCon c'             -> c == c' && arity == 0
        -- The same constructor re-applied to (images of) its own fields.
        TApp (TCon c') args -> c == c' && identityArgs args
        -- The scrutinee variable itself; inside the branch it is shifted
        -- by the @arity@ variables bound by the constructor pattern.
        TVar y              -> y == x + arity
        _                   -> False  -- TODO: nested cases
      where
        -- Fields are matched right to left: the last argument must be the
        -- field bound to index 0, the one before it index 1, and so on.
        identityArgs :: Args -> Bool
        identityArgs args =
          length args == arity && and (zipWith sameField (reverse args) [0 ..])

        -- Argument of a (saturated) recursive call in the scrutinee's
        -- parameter position.
        scrutineeArg :: [a] -> a
        scrutineeArg args = indexWithDefault __IMPOSSIBLE__ (reverse args) x

        -- Is the term equal to the constructor field bound to index @i@?
        sameField :: TTerm -> Int -> Bool
        -- An erased field is accepted unconditionally (presumably sound
        -- because erased values carry no runtime content).
        sameField TErased              _ = True
        sameField (TVar z)             i = z == i
        -- A saturated recursive call applied to field @i@ in the scrutinee
        -- position returns field @i@ (by the identity hypothesis).
        sameField (TApp (TDef f) args) i =
          f == q && length args == n && sameField (scrutineeArg args) i
        sameField _                    _ = False
-- | Candidate parameter positions (de Bruijn indices relative to the
--   function's leading lambdas) that a term returns unchanged.  An empty
--   list means the term is not an identity on any parameter.
newtype IdentityIn = IdIn [Int]
-- | Result for terms that return none of the parameters unchanged: no
--   candidate positions at all.
notId :: IdentityIn
notId = IdIn mempty
-- | Combining two results keeps only the positions both agree on: list
--   intersection, in the order of the left operand.
instance Semigroup IdentityIn where
  (<>) (IdIn lhs) (IdIn rhs) = IdIn (lhs `List.intersect` rhs)
-- | Recognise functions that always return one of their parameters
--   unchanged, either directly or by passing it, in the same argument
--   position, to a recursive call of @q@.
--
--   Returns @Just (n, k)@ where @n@ is the number of leading lambdas of @t@
--   and @k@ is the de Bruijn index (relative to those lambdas) of the
--   parameter that is returned.
trivialIdentity :: QName -> TTerm -> Maybe (Int, Int)
trivialIdentity q t =
  case go 0 body of
    IdIn [x]     -> pure (n, x)
    IdIn []      -> Nothing
    -- A variable leaf yields a single candidate and every other leaf yields
    -- none, so several survivors means every path is a recursive call.
    IdIn (_:_:_) -> Nothing
  where
    (n, body) = tLamView t

    -- Positions of parameters that @u@, under @k@ extra local binders,
    -- returns unchanged.
    go :: Int -> TTerm -> IdentityIn
    go k u =
      case u of
        TVar x
          | x >= k    -> IdIn [x - k]
          | otherwise -> notId            -- a locally bound variable
        TLet _ rhs         -> go (k + 1) rhs
        TCase _ _ d alts   -> sconcat (go k d :| map (goAlt k) alts)
        TApp (TDef f) args
          -- Positional reasoning is only valid for a saturated recursive
          -- call: with fewer or extra arguments, @reverse args !! y@ is not
          -- the call's parameter @y@ (this mirrors 'recursiveIdentity').
          | f == q && length args == n ->
              -- Positions @y@ where the call receives our own parameter @y@.
              IdIn [ y | (TVar x, y) <- zip (reverse args) [0 ..], y + k == x ]
        TCoerce v          -> go k v
        TApp{}             -> notId
        TLam{}             -> notId
        TLit{}             -> notId
        TDef{}             -> notId
        TCon{}             -> notId
        TPrim{}            -> notId
        TUnit{}            -> notId
        TSort{}            -> notId
        TErased{}          -> notId
        TError{}           -> notId

    -- A constructor alternative binds @arity@ fields, shifting indices.
    goAlt :: Int -> TAlt -> IdentityIn
    goAlt k (TALit _ rhs)       = go k rhs
    goAlt k (TAGuard _ rhs)     = go k rhs
    goAlt k (TACon _ arity rhs) = go (k + arity) rhs