{-# OPTIONS_GHC -Wunused-imports #-}

module Agda.Compiler.MAlonzo.Coerce (addCoercions, erasedArity) where

import Agda.Syntax.Common (Nat)
import Agda.Syntax.Treeless

import Agda.TypeChecking.Monad
  ( HasConstInfo
  , getErasedConArgs
  , getTreeless
  )

-- | Insert unsafeCoerce (in the form of 'TCoerce') everywhere it's needed in
--   the right-hand side of a definition.
--
--   Leading lambdas of the definition are left as plain 'TLam's; the term
--   underneath them is rewritten by the local @coerce@ traversal.
addCoercions :: HasConstInfo m => TTerm -> m TTerm
addCoercions = coerceTop
  where
    -- Don't coerce top-level lambdas.
    coerceTop (TLam b) = TLam <$> coerceTop b
    coerceTop t        = coerce t

    -- Coerce a term `t`. The result (when translated to Haskell) has type
    -- `forall a. a`.
    coerce t =
      case t of
        -- Atomic terms are simply wrapped in a 'TCoerce'.
        TVar{}    -> return $ TCoerce t
        TPrim{}   -> return $ TCoerce t
        TDef{}    -> return $ TCoerce t
        TCon{}    -> return $ TCoerce t
        TLit{}    -> return $ TCoerce t
        TUnit{}   -> return $ TCoerce t
        TSort{}   -> return $ TCoerce t
        -- These are returned unchanged (an existing 'TCoerce' is not
        -- wrapped a second time).
        TErased{} -> return t
        TCoerce{} -> return t
        TError{}  -> return t
        TApp f vs -> do
          ar <- funArity f
          if length vs > ar
            -- More arguments than 'funArity' reports for the head: coerce
            -- the head itself and only soft-coerce the arguments.
            then TApp (TCoerce f) <$> mapM softCoerce vs
            -- Otherwise keep the head as is, coerce every argument, and
            -- coerce the application as a whole.
            else TCoerce . TApp f <$> mapM coerce vs
        TLam b         -> TCoerce . TLam <$> softCoerce b
        -- The bound expression is only soft-coerced; the body carries the
        -- coercion for the whole 'TLet'.
        TLet e b       -> TLet <$> softCoerce e <*> coerce b
        -- The default branch and every alternative are coerced; the
        -- 'TCase' node itself is not wrapped.
        TCase x t d bs -> TCase x t <$> coerce d <*> mapM coerceAlt bs

    -- Coerce the right-hand side of a case alternative (and, for
    -- 'TAGuard', the guard as well).
    coerceAlt (TACon c a b) = TACon c a <$> coerce b
    coerceAlt (TAGuard g b) = TAGuard   <$> coerce g <*> coerce b
    coerceAlt (TALit l b)   = TALit l   <$> coerce b

    -- Insert TCoerce in subterms. When translated to Haskell, the resulting
    -- term is well-typed with some arbitrary type.
    softCoerce t =
      case t of
        -- Atomic terms, and already-coerced terms, are returned unchanged.
        TVar{}    -> return t
        TPrim{}   -> return t
        TDef{}    -> return t
        TCon{}    -> return t
        TLit{}    -> return t
        TUnit{}   -> return t
        TSort{}   -> return t
        TErased{} -> return t
        TCoerce{} -> return t
        TError{}  -> return t
        TApp f vs -> do
          ar <- funArity f
          if length vs > ar
            then TApp (TCoerce f) <$> mapM softCoerce vs
            -- Unlike in `coerce`, the application itself is not wrapped.
            else TApp f <$> mapM coerce vs
        TLam b         -> TLam <$> softCoerce b
        TLet e b       -> TLet <$> softCoerce e <*> softCoerce b
        TCase x t d bs -> TCase x t <$> coerce d <*> mapM coerceAlt bs

-- | The arity assumed for the head of an application when deciding where
--   coercions go.
--
--   * 'TDef': the first component of 'tLamView' applied to the treeless
--     code returned by 'getTreeless' (presumably the number of leading
--     lambdas -- confirm against 'tLamView'), or 0 if there is no such code.
--   * 'TCon': the constructor's 'erasedArity'.
--   * 'TPrim': a fixed bound of 3.
--   * anything else: 0.
funArity :: HasConstInfo m => TTerm -> m Nat
funArity (TDef q)  = maybe 0 (fst . tLamView) <$> getTreeless q
funArity (TCon q)  = erasedArity q
funArity (TPrim _) = return 3 -- max arity of any primitive
funArity _         = return 0

-- | The number of retained arguments after erasure.
--
--   Computed as the number of 'False' entries in the list returned by
--   'getErasedConArgs' for the given name.
erasedArity :: HasConstInfo m => QName -> m Nat
erasedArity q = length . filter not <$> getErasedConArgs q