-- |
-- Module      :  Cryptol.Transform.MonoValues
-- Copyright   :  (c) 2013-2016 Galois, Inc.
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable
--
-- This module implements a transformation that tries to avoid an exponential
-- slowdown in some cases.  What's the problem?  Consider the following (common)
-- pattern:
--
-- >    fibs = [0,1] # [ x + y | x <- fibs, y <- drop`{1} fibs ]
--
-- The type of @fibs@ is:
--
-- >    {a} (a >= 1, fin a) => [inf][a]
--
-- Here @a@ is the number of bits to be used in the values computed by @fibs@.
-- When we evaluate @fibs@, @a@ becomes a parameter to @fibs@, which works
-- except that now @fibs@ is a function, and we don't get any of the memoization
-- we might expect!  What looked like an efficient implementation has all
-- of a sudden become exponential!
--
-- Note that this is only a problem for polymorphic values: if @fibs@ were
-- already a function, it would not be that surprising that it does not
-- get cached.
--
-- So, to avoid this, we try to spot recursive polymorphic values,
-- where the recursive occurrences have the exact same type parameters
-- as the definition.  For example, this is the case in @fibs@: each
-- recursive call to @fibs@ is instantiated with exactly the same
-- type parameter (i.e., @a@).  The rewrite we do is as follows:
--
-- >    fibs : {a} (a >= 1, fin a) => [inf][a]
-- >    fibs = \{a} (a >= 1, fin a) -> fibs'
-- >      where fibs' : [inf][a]
-- >            fibs' = [0,1] # [ x + y | x <- fibs', y <- drop`{1} fibs' ]
--
-- After the rewrite, the recursion is monomorphic (i.e., we are always using
-- the same type).  As a result, @fibs'@ is an ordinary recursive value,
-- where we get the benefit of caching.
--
-- The rewrite is a bit more complex when there are multiple mutually
-- recursive values.  Here is an example:
--
-- >    zig : {a} (a >= 2, fin a) => [inf][a]
-- >    zig = [1] # zag
-- >
-- >    zag : {a} (a >= 2, fin a) => [inf][a]
-- >    zag = [2] # zig
--
-- This gets rewritten to:
--
-- >    newName : {a} (a >= 2, fin a) => ([inf][a], [inf][a])
-- >    newName = \{a} (a >= 2, fin a) -> (zig', zag')
-- >      where
-- >      zig' : [inf][a]
-- >      zig' = [1] # zag'
-- >
-- >      zag' : [inf][a]
-- >      zag' = [2] # zig'
-- >
-- >    zig : {a} (a >= 2, fin a) => [inf][a]
-- >    zig = \{a} (a >= 2, fin a) -> (newName a <> <> ).0
-- >
-- >    zag : {a} (a >= 2, fin a) => [inf][a]
-- >    zag = \{a} (a >= 2, fin a) -> (newName a <> <> ).1
--
-- NOTE:  We are assuming that no capture would occur with binders.
-- For values, this is because we replace them with freshly chosen variables.
-- For types, this should be because there should be no shadowing in the types.
-- XXX: Make sure that this really is the case for types!!

{-# LANGUAGE PatternGuards, FlexibleInstances, MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE CPP #-}
module Cryptol.Transform.MonoValues (rewModule) where

import Cryptol.ModuleSystem.Name
        (SupplyT,liftSupply,Supply,mkDeclared,NameSource(..))
import Cryptol.Parser.Position (emptyRange)
import Cryptol.TypeCheck.AST hiding (splitTApp) -- XXX: just use this one
import Cryptol.TypeCheck.TypeMap
import Cryptol.Utils.Ident (ModName)
import Data.List(sortBy,groupBy)
import Data.Either(partitionEithers)
import Data.Map (Map)
import MonadLib hiding (mapM)

import Prelude ()
import Prelude.Compat

-- | A rewrite map, stored as a three-level trie.  The levels are keyed by
-- the name of a definition, the type arguments it is instantiated with,
-- and the number of proof arguments it is applied to.
--
-- An entry @(f,ts,n) |--> x@ means that when we spot an instantiation of
-- @f@ with type arguments @ts@ and @n@ proof arguments, we should replace
-- it with @EVar x@.
newtype RewMap' a = RM (Map Name (TypesMap (Map Int a)))

-- | The rewrite map used by this pass: each recorded instantiation maps to
-- the name that replaces it.
type RewMap = RewMap' Name

-- | 'RewMap'' is a trie over @(name, type arguments, proof count)@: every
-- operation walks (or builds) the three nested levels in that order.
instance TrieMap RewMap' (Name,[Type],Int) where
  emptyTM = RM emptyTM

  nullTM (RM m) = nullTM m

  lookupTM (name,tys,proofs) (RM m) =
    lookupTM name m >>= lookupTM tys >>= lookupTM proofs

  alterTM (name,tys,proofs) f (RM m) = RM (alterTM name atTypes m)
    where
    -- Innermost level (proof count).  When it is missing, the new entry
    -- (if any) becomes a singleton map.
    atProofs Nothing         = fmap (\a -> insertTM proofs a emptyTM) (f Nothing)
    atProofs (Just byProofs) = Just (alterTM proofs f byProofs)

    -- Middle level (type arguments).  When it is missing, wrap whatever
    -- the proof level builds from scratch in a singleton map.
    atTypes Nothing          = fmap (\byProofs -> insertTM tys byProofs emptyTM)
                                    (atProofs Nothing)
    atTypes (Just byTys)     = Just (alterTM tys atProofs byTys)

  unionTM f (RM a) (RM b) = RM (unionTM mergeTypes a b)
    where
    mergeTypes = unionTM (unionTM f)

  toListTM (RM m) =
    [ ((name,tys,proofs), v)
    | (name,byTys)    <- toListTM m
    , (tys,byProofs)  <- toListTM byTys
    , (proofs,v)      <- toListTM byProofs
    ]

  mapMaybeWithKeyTM f (RM m) = RM (mapWithKeyTM atName m)
    where
    atName name         = mapWithKeyTM (atTypes name)
    atTypes name tys    = mapMaybeWithKeyTM (\proofs a -> f (name,tys,proofs) a)

-- | Apply the rewrite described in the module header to every declaration
-- group of a module, threading the name supply through.
--
-- Note that this assumes that this pass will be run only once for each
-- module, otherwise we will get name collisions.
rewModule :: Supply -> Module -> (Module,Supply)
rewModule s m = runM rewriteAll (mName m) s
  where
  setDecls groups = m { mDecls = groups }
  rewriteAll      = setDecls <$> traverse (rewDeclGroup emptyTM) (mDecls m)

--------------------------------------------------------------------------------

-- | The monad of this pass: a reader over 'RO' layered on a name supply.
type M  = ReaderT RO (SupplyT Id)

-- | Read-only environment: the name of the module being rewritten
-- ('rewModule' runs the pass with @mName m@).
type RO = ModName

-- | Produce a fresh top-level name.
--
-- The name is declared in the module currently being rewritten, which is
-- read from the reader environment.  It is tagged with 'SystemName' and
-- has no fixity and an empty source range.
newName :: M Name
newName =
  do ns <- ask   -- the module name supplied by 'rewModule'
     liftSupply (mkDeclared ns SystemName "$mono" Nothing emptyRange)

-- | Produce a fresh name for a binding at any level.  Top-level and local
-- bindings are named the same way in this pass, so this amounts to
-- 'newName'.
newTopOrLocalName :: M Name
newTopOrLocalName = inLocal newName

-- | Not really any distinction between global and local: all names get the
-- module prefix added, and a unique id.  Entering a local scope is
-- therefore a no-op.
inLocal :: M a -> M a
inLocal action = action

--------------------------------------------------------------------------------
-- | Rewrite an expression.  Instantiations of a definition recorded in the
-- rewrite map (same name, same type arguments, same number of proof
-- arguments) are replaced by a variable for the recorded name.  All other
-- expressions are traversed structurally.
rewE :: RewMap -> Expr -> M Expr   -- XXX: not IO
rewE rews = rewrite
  where
  -- The replacement for an instantiation, if the rewrite map has one.
  monoVersion (EVar f, tys, proofs) = EVar <$> lookupTM (f, tys, proofs) rews
  monoVersion _                     = Nothing

  rewrite expr =
    case expr of

      -- Interesting cases: instantiations, where a rewrite may fire.
      ETApp e t ->
        maybe (flip ETApp t <$> rewrite e) return
              (monoVersion (splitTApp expr 0))
      EProofApp e ->
        maybe (EProofApp <$> rewrite e) return
              (monoVersion (splitTApp e 1))

      -- Everything else is a plain structural traversal.
      EList es t        -> flip EList t <$> mapM rewrite es
      ETuple es         -> ETuple <$> mapM rewrite es
      ERec fs           -> ERec <$> traverse rewrite fs
      ESel e s          -> flip ESel s <$> rewrite e
      ESet e s v        -> (\e' v' -> ESet e' s v') <$> rewrite e <*> rewrite v
      EIf c th el       -> EIf <$> rewrite c <*> rewrite th <*> rewrite el

      EComp len t e mss -> EComp len t <$> rewrite e
                                       <*> mapM (mapM (rewM rews)) mss
      EVar _            -> return expr

      ETAbs x e         -> ETAbs x <$> rewrite e

      EApp fun arg      -> EApp <$> rewrite fun <*> rewrite arg
      EAbs x t e        -> EAbs x t <$> rewrite e

      EProofAbs x e     -> EProofAbs x <$> rewrite e

      EWhere e dgs      -> EWhere <$> rewrite e
                                  <*> inLocal (mapM (rewDeclGroup rews) dgs)

-- | Rewrite the expression or declaration inside one comprehension match.
rewM :: RewMap -> Match -> M Match
rewM rews (From x len t e) = From x len t <$> rewE rews e
rewM rews (Let d)          = Let <$> rewD rews d   -- These are not recursive.

-- | Rewrite the definition of a single declaration.  All other fields of
-- the declaration are left unchanged.
rewD :: RewMap -> Decl -> M Decl
rewD rews d = withDef <$> rewDef rews (dDefinition d)
  where
  withDef def = d { dDefinition = def }

-- | Rewrite a declaration body.  Primitives have no body and are returned
-- unchanged.
rewDef :: RewMap -> DeclDef -> M DeclDef
rewDef rews def =
  case def of
    DExpr e -> DExpr <$> rewE rews e
    DPrim   -> return DPrim

-- | Rewrite a declaration group.
--
-- Non-recursive groups are simply traversed.  In a recursive group, a
-- declaration is a candidate for the rewrite described in the module header
-- when it has type parameters and its body under the type and proof
-- abstractions is not a lambda (see 'notFun').  Candidates are grouped by
-- their proof obligations and type parameters, and each group is rewritten
-- together by 'rewSame'.  All other declarations are traversed with the
-- current rewrite map.
rewDeclGroup :: RewMap -> DeclGroup -> M DeclGroup
rewDeclGroup rews dg =
  case dg of
    NonRecursive d -> NonRecursive <$> rewD rews d
    Recursive ds ->
      do let (leave,rew) = partitionEithers (map consider ds)
             rewGroups   = groupBy sameTParams
                         $ sortBy compareTParams rew
         ds1 <- mapM (rewD rews) leave
         ds2 <- mapM rewSame rewGroups
         return $ Recursive (ds1 ++ concat ds2)

  where
  -- Candidates are sorted and grouped on (proof obligations, type params).
  groupKey (_,tps,props,_) = (props,tps)
  sameTParams    a b       = groupKey a == groupKey b
  compareTParams a b       = compare (groupKey a) (groupKey b)

  -- Right: a polymorphic value to rewrite, with its split-off type
  -- parameters, proof obligations, and body.  Left: leave it alone.
  consider d =
    case dDefinition d of
      DPrim   -> Left d
      DExpr e -> let (tps,props,e') = splitTParams e
                 in if not (null tps) && notFun e'
                      then Right (d, tps, props, e')
                      else Left d

  -- Rewrite one group of candidates sharing type parameters and proof
  -- obligations.  'groupBy' never yields an empty group; the empty case is
  -- here only so that the function is total (the original used a partial
  -- irrefutable pattern).
  rewSame [] = return []
  rewSame ds@((_,tps,props,_) : _) =
    do new <- forM ds $ \(d,_,_,e) ->
                do x <- newName
                   return (d, x, e)
       let tys            = map (TVar . tpVar) tps
           proofNum       = length props
           -- Recursive uses at exactly the definition's own type parameters
           -- are redirected to the new monomorphic name.
           addRew (d,x,_) = insertTM (dName d,tys,proofNum) x
           newRews        = foldr addRew rews new

       newDs <- forM new $ \(d,newN,e) ->
                  do e1 <- rewE newRews e
                     return ( d
                            , d { dName       = newN
                                , dSignature  = (dSignature d)
                                                  { sVars = [], sProps = [] }
                                , dDefinition = DExpr e1
                                }
                            )

       case newDs of
         -- A single value:  f = \{tps} props -> f'  where f' = ...
         [(f,f')] ->
           do let newE = EWhere (EVar (dName f')) [ Recursive [f'] ]
              return [ f { dDefinition =
                             DExpr (foldr ETAbs (foldr EProofAbs newE props) tps)
                         } ]

         -- Several values: bundle the monomorphic versions into one tuple.
         _ ->
           do tupName <- newTopOrLocalName
              let (polyDs,monoDs) = unzip newDs

                  tupAr  = length monoDs
                  addTPs = flip (foldr ETAbs)     tps
                         . flip (foldr EProofAbs) props

                  -- tuple = \{a} p -> (f',g')
                  --                where f' = ...
                  --                      g' = ...
                  tupD = Decl
                    { dName       = tupName
                    , dSignature  =
                        Forall tps props $
                          TCon (TC (TCTuple tupAr))
                               (map (sType . dSignature) monoDs)
                    , dDefinition =
                        DExpr $
                          EWhere (ETuple [ EVar (dName d) | d <- monoDs ])
                                 [ Recursive monoDs ]
                    , dPragmas    = [] -- ?
                    , dInfix      = False
                    , dFixity     = Nothing
                    , dDoc        = Nothing
                    }

                  mkProof e _ = EProofApp e

                  -- f = \{a} (p) -> (tuple @a p). n
                  mkFunDef n f =
                    f { dDefinition =
                          DExpr $
                            addTPs $ ESel ( flip (foldl mkProof) props
                                          $ flip (foldl ETApp) tys
                                          $ EVar tupName
                                          ) (TupleSel n (Just tupAr))
                      }

              return (tupD : zipWith mkFunDef [ 0 .. ] polyDs)

--------------------------------------------------------------------------------
-- | Peel the leading type abstractions off an expression, and then the
-- leading proof abstractions.  Returns the type parameters, the proof
-- obligations, and the remaining body.
splitTParams :: Expr -> ([TParam], [Prop], Expr)
splitTParams e = (tps, props, body)
  where
  (tps, afterTAbs) = splitWhile splitTAbs e
  (props, body)    = splitWhile splitProofAbs afterTAbs

-- | Decompose an instantiation.  Returns the head expression, the type
-- arguments it is applied to, and how many proof applications were
-- peeled off.
--
-- Outer 'EProofApp's are stripped first, and each one adds one to the
-- count.  The 'Int' argument is the starting count (the number of proof
-- applications the caller has already removed).  After that, the 'ETApp'
-- chain is collected, with type arguments in application order.
splitTApp :: Expr -> Int -> (Expr, [Type], Int)
splitTApp expr proofs =
  case expr of
    EProofApp inner ->
      let proofs' = proofs + 1
      in proofs' `seq` splitTApp inner proofs'
    _ ->
      let (hd, tys) = collect expr []
      in (hd, tys, proofs)
  where
  collect (ETApp e t) acc = collect e (t : acc)
  collect e           acc = (e, acc)

-- | 'True' unless the expression is a lambda ('EAbs'), looking through any
-- proof abstractions around it.
notFun :: Expr -> Bool
notFun expr =
  case expr of
    EAbs {}         -> False
    EProofAbs _ e   -> notFun e
    _               -> True

```