module Agda.Compiler.Treeless.Identity
( detectIdentityFunctions ) where
import Prelude hiding ((!!))
import Control.Applicative ( Alternative((<|>), empty) )
import Data.Semigroup
import qualified Data.List as List
import Agda.Syntax.Treeless
import Agda.TypeChecking.Monad
import Agda.Utils.List
import Agda.Utils.List1 (pattern (:|))
import Agda.Utils.Impossible
-- | Replace the compiled body of @q@ by a plain projection when 'isIdentity'
--   recognises it as an identity function.
--
--   If 'isIdentity' yields @(n, k)@, the definition is marked for inlining
--   and the body becomes @n@ lambdas around the variable with de Bruijn
--   index @k@.  Otherwise the term is returned unchanged.
detectIdentityFunctions :: QName -> TTerm -> TCM TTerm
detectIdentityFunctions q t = maybe (return t) toProjection (isIdentity q t)
  where
    toProjection (arity, ix) = do
      markInline True q
      -- The definition is looked up but its result is not used; the lookup
      -- is kept so the sequence of monadic actions is unchanged.
      _ <- theDef <$> getConstInfo q
      return (mkTLam arity (TVar ix))
-- | Decide whether the body @t@ of function @q@ is an identity function.
--
--   The cheap check 'trivialIdentity' is tried first; only if it fails is
--   'recursiveIdentity' consulted.  A result @(n, k)@ pairs the number of
--   leading lambdas of @t@ with the de Bruijn index (under those lambdas)
--   of the argument that is returned.
isIdentity :: QName -> TTerm -> Maybe (Int, Int)
isIdentity q t =
  case trivialIdentity q t of
    Nothing -> recursiveIdentity q t
    found   -> found
-- | Recognise functions that case-split on one argument and rebuild the
--   very same constructor application in every branch.
--
--   After stripping the leading lambdas the body must be a 'TCase' on some
--   variable @scrut@ whose default is @TError TUnreachable@, and every
--   alternative must satisfy 'identityBranch'.  On success the result is
--   @(arity, scrut)@: the number of lambdas and the scrutinised variable.
recursiveIdentity :: QName -> TTerm -> Maybe (Int, Int)
recursiveIdentity q t
  | TCase scrut _ (TError TUnreachable) alts <- body
  , all (identityBranch scrut) alts = pure (arity, scrut)
  | otherwise                       = empty
  where
    (arity, body) = tLamView t

    -- Only constructor alternatives can be identity branches; literal and
    -- guard alternatives are rejected outright.
    identityBranch _ TALit{}   = False
    identityBranch _ TAGuard{} = False
    identityBranch scrut (TACon con fields rhs) =
      case rhs of
        -- Same constructor applied to (something equivalent to) its fields.
        TApp (TCon con') es -> con == con' && rebuildsFields es
        -- The branch returns the variable @scrut + fields@, i.e. the
        -- scrutinee index shifted by the alternative's arity.
        TVar v              -> v == scrut + fields
        _                   -> False
      where
        -- Exactly @fields@ arguments, and the i-th one (counting from the
        -- last argument, starting at 0) must correspond to variable i.
        rebuildsFields es =
          length es == fields
            && all (uncurry sameField) (zip (reverse es) [0 ..])

        -- Argument of a call sitting at the scrutinee's position, counting
        -- from the last argument.
        scrutineeArg es = indexWithDefault __IMPOSSIBLE__ (reverse es) scrut

        -- An argument is acceptable for position i if it is erased, is the
        -- variable i itself, or is a call to @q@ with exactly @arity@
        -- arguments whose scrutinee-position argument is itself acceptable.
        sameField TErased            _ = True
        sameField (TVar v)           i = v == i
        sameField (TApp (TDef f) es) i =
          f == q && length es == arity && sameField (scrutineeArg es) i
        sameField _                  _ = False
-- | Candidate argument positions a term may be returning unchanged.
--   The 'Int's are variable indices relative to the function's lambda body
--   (see @go@ in 'trivialIdentity', which subtracts the number of locally
--   bound variables).  An empty list means "not an identity".
data IdentityIn = IdIn [Int]
-- | The term is not an identity in any argument: no candidate positions.
notId :: IdentityIn
notId = IdIn mempty
-- | Combining two results keeps only the positions both agree on, so a
--   single 'notId' component makes the whole combination 'notId'.
instance Semigroup IdentityIn where
  IdIn xs <> IdIn ys = IdIn (xs `List.intersect` ys)
-- | Recognise @f x1 .. xn = xi@, where some code paths may instead be
--   recursive calls to @f@ that pass @xi@ along unchanged.
--
--   The leading lambdas of @t@ are stripped and the body is analysed by
--   @go@.  The result is @Just (n, i)@ (number of lambdas, de Bruijn index
--   of the returned argument) only when exactly one candidate position
--   survives on every path.
trivialIdentity :: QName -> TTerm -> Maybe (Int, Int)
trivialIdentity q t =
  case go 0 body of
    IdIn [x]         -> pure (arity, x)
    IdIn []          -> Nothing
    -- Several candidates remain (possible when recursive calls pass more
    -- than one parameter through unchanged): ambiguous, so give up.
    IdIn (_ : _ : _) -> Nothing
  where
    (arity, body) = tLamView t

    -- @go k v@: candidate positions returned by @v@, where @k@ counts the
    -- variables bound between the function's lambdas and @v@.
    go :: Int -> TTerm -> IdentityIn
    go k v =
      case v of
        TVar x
          | x >= k    -> IdIn [x - k]   -- a function parameter
          | otherwise -> notId          -- a locally bound variable
        TLet _ e         -> go (k + 1) e
        -- Every alternative and the default must agree (intersection).
        TCase _ _ d alts -> sconcat (go k d :| map (goAlt k) alts)
        TApp (TDef f) args
          -- Only a call with exactly @arity@ arguments lines the reversed
          -- argument positions up with the parameters' de Bruijn indices
          -- (the same saturation check 'recursiveIdentity' performs).
          -- Under- or over-applied calls fall through to 'notId' below.
          | f == q, length args == arity ->
              IdIn [ y
                   | (TVar x, y) <- zip (reverse args) [0 ..]
                   , y + k == x
                   ]
        TCoerce e -> go k e
        TApp{}    -> notId
        TLam{}    -> notId
        TLit{}    -> notId
        TDef{}    -> notId
        TCon{}    -> notId
        TPrim{}   -> notId
        TUnit{}   -> notId
        TSort{}   -> notId
        TErased{} -> notId
        TError{}  -> notId

    -- Analyse the body of a case alternative; constructor alternatives
    -- bind as many additional variables as the constructor's arity.
    goAlt :: Int -> TAlt -> IdentityIn
    goAlt k (TALit _ e)      = go k e
    goAlt k (TAGuard _ e)    = go k e
    goAlt k (TACon _ ar e)   = go (k + ar) e