{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
module Futhark.Analysis.HORepresentation.MapNest
  ( Nesting (..)
  , MapNest (..)
  , typeOf
  , params
  , inputs
  , setInputs
  , fromSOAC
  , toSOAC
  )
where

import Control.Monad
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S

import qualified Futhark.Analysis.HORepresentation.SOAC as SOAC
import Futhark.Analysis.HORepresentation.SOAC (SOAC)

import qualified Futhark.Representation.SOACS.SOAC as Futhark
import Futhark.Transform.Substitute
import Futhark.Representation.AST hiding (typeOf)
import Futhark.MonadFreshNames
import Futhark.Construct

-- | One enclosing level of a 'MapNest'.  The @lore@ type parameter is
-- not mentioned by any of the fields.
data Nesting lore = Nesting {
    nestingParamNames   :: [VName]
    -- ^ Names of the parameters of the lambda at this level ('toSOAC'
    -- turns them into the lambda's parameters).
  , nestingResult       :: [VName]
    -- ^ Names bound to the results of the SOAC nested inside this
    -- level; 'toSOAC' also uses them as the result of the lambda body.
  , nestingReturnType   :: [Type]
    -- ^ Return type of the lambda at this level.
  , nestingWidth        :: SubExp
    -- ^ Width of the map nested directly inside this level.
  } deriving (Eq, Ord, Show)

-- | A nest of maps.  The fields are, in order: the width of the
-- outermost map, the innermost lambda, one 'Nesting' per enclosing
-- level (the head of the list is the outermost level), and the inputs
-- of the outermost map.
data MapNest lore = MapNest SubExp (Lambda lore) [Nesting lore] [SOAC.Input]
                  deriving (Show)

-- | The types of the values produced by the whole nest: the row types
-- of the outermost level (the head 'Nesting' if there is one, otherwise
-- the lambda itself), each turned into an array of the outer width.
typeOf :: MapNest lore -> [Type]
typeOf (MapNest outerWidth lam nests _) =
  [ rowType `arrayOfRow` outerWidth | rowType <- rowTypes ]
  where
    rowTypes = case nests of
      []        -> lambdaReturnType lam
      outer : _ -> nestingReturnType outer

-- | The parameter names of the outermost level of the nest: those of
-- the head 'Nesting' when the nest has one, otherwise those of the
-- lambda.
params :: MapNest lore -> [VName]
params (MapNest _ lam nests _) =
  case nests of
    []        -> paramName <$> lambdaParams lam
    outer : _ -> nestingParamNames outer

-- | The inputs stored in the map nest.
inputs :: MapNest lore -> [SOAC.Input]
inputs nest =
  case nest of
    MapNest _ _ _ nestInputs -> nestInputs

-- | Replace the inputs of a map nest.  With an empty input list the
-- widths are left alone.  Otherwise the outer width becomes the
-- outermost dimension of the first new input, and the remaining
-- dimensions of that input are zipped onto the nestings as their new
-- widths (so the nesting list is cut to the shorter of the two).
setInputs :: [SOAC.Input] -> MapNest lore -> MapNest lore
setInputs newInputs (MapNest oldWidth lam nests _) =
  case newInputs of
    [] -> MapNest oldWidth lam nests []
    firstInput : _ ->
      let firstType = SOAC.inputType firstInput
          outerWidth = arraySize 0 firstType
          innerWidths = drop 1 $ arrayDims firstType
          withWidth nest width = nest { nestingWidth = width }
      in MapNest outerWidth lam (zipWith withWidth nests innerWidths) newInputs

-- | Try to view a SOAC as a 'MapNest'.  This starts the recursive
-- worker with no identifiers bound by enclosing lambdas; the result is
-- 'Nothing' when the SOAC is not of a form the worker handles.
fromSOAC :: (Bindable lore, MonadFreshNames m,
             LocalScope lore m,
             Op lore ~ Futhark.SOAC lore) =>
            SOAC lore -> m (Maybe (MapNest lore))
fromSOAC soac = fromSOAC' [] soac

-- | Worker for 'fromSOAC'.  The first argument holds the identifiers
-- bound by the parameters of the enclosing lambdas traversed so far
-- (extended with this lambda's parameters as @bound'@ when recursing).
--
-- Only a 'SOAC.Screma' whose form matches
-- @ScremaForm (_, []) (_, _, []) lam@ is handled (presumably a plain
-- map, i.e. no scan or reduce operators -- confirm against the
-- definition of @ScremaForm@); any other SOAC yields 'Nothing'.
fromSOAC' :: (Bindable lore, MonadFreshNames m,
              LocalScope lore m,
              Op lore ~ Futhark.SOAC lore) =>
             [Ident]
          -> SOAC lore
          -> m (Maybe (MapNest lore))

fromSOAC' bound (SOAC.Screma w (SOAC.ScremaForm (_, []) (_, _, []) lam) inps) = do
  -- If the lambda body is a single statement whose bound names are
  -- exactly the body result, try to read that statement as a SOAC and
  -- recurse into it (with the lambda parameters in scope).  A 'Left'
  -- from 'SOAC.fromExp' is passed through unchanged; any other body
  -- shape gives @Right Nothing@.
  maybenest <- case (stmsToList $ bodyStms $ lambdaBody lam,
                     bodyResult $ lambdaBody lam) of
    ([Let pat _ e], res) | res == map Var (patternNames pat) ->
      localScope (scopeOfLParams $ lambdaParams lam) $
      SOAC.fromExp e >>=
      either (return . Left) (fmap (Right . fmap (pat,)) . fromSOAC' bound')
    _ ->
      return $ Right Nothing

  case maybenest of
    -- Do we have a nested MapNest?
    Right (Just (pat, mn@(MapNest inner_w body' ns' inps'))) -> do
      -- Line up the child's parameters and inputs with our own
      -- parameters and inputs, producing the parameter names and
      -- inputs of this (outer) level.
      (ps, inps'') <-
        unzip <$>
        fixInputs w (zip (map paramName $ lambdaParams lam) inps)
        (zip (params mn) inps')
      -- The new outermost nesting; the child's width becomes the
      -- width of the map nested inside this level.
      let n' = Nesting {
            nestingParamNames   = ps
            , nestingResult     = patternNames pat
            , nestingReturnType = typeOf mn
            , nestingWidth      = inner_w
            }
      return $ Just $ MapNest w body' (n':ns') inps''
    -- No nested MapNest it seems.  This also covers the case where
    -- 'SOAC.fromExp' returned a 'Left'.
    _ -> do
      -- Identifiers bound by enclosing lambdas that occur free in this
      -- lambda.
      let isBound name
            | Just param <- find ((name==) . identName) bound =
              Just param
            | otherwise =
              Nothing
          boundUsedInBody =
            mapMaybe isBound $ S.toList $ freeInLambda lam
      -- Each such identifier gets a fresh counterpart (named with a
      -- @_wasfree@ suffix) that becomes an extra lambda parameter.
      newParams <- mapM (newIdent' (++"_wasfree")) boundUsedInBody
      let subst = M.fromList $
                  zip (map identName boundUsedInBody) (map identName newParams)
          -- The original inputs (with the renaming applied), followed
          -- by one input per formerly-free identifier, replicated to
          -- the width of this map.
          inps' = map (substituteNames subst) inps ++
                  map (SOAC.addTransform (SOAC.Replicate mempty $ Shape [w]) . SOAC.identInput)
                  boundUsedInBody
          -- The lambda with the body renamed and the new parameters
          -- appended, in the same order as the extra inputs.
          lam' =
            lam { lambdaBody =
                    substituteNames subst $ lambdaBody lam
                , lambdaParams =
                    lambdaParams lam ++ [ Param name t
                                        | Ident name t <- newParams ]
                }
      return $ Just $ MapNest w lam' [] inps'
  where bound' = bound <> map paramIdent (lambdaParams lam)

-- Any SOAC not matching the pattern above is not a map nest.
fromSOAC' _ _ = return Nothing

-- | Turn a map nest back into a SOAC.  Without nestings this is just a
-- map of the lambda over the inputs.  Otherwise the outermost nesting
-- becomes a map whose lambda takes the nesting's parameters (typed by
-- the row types of the inputs), whose body binds the nesting's result
-- names to the expression of the recursively built inner SOAC, and
-- which returns those names.
toSOAC :: (MonadFreshNames m, HasScope lore m,
           Bindable lore, BinderOps lore, Op lore ~ Futhark.SOAC lore) =>
          MapNest lore -> m (SOAC lore)
toSOAC (MapNest w lam nests inps) =
  case nests of
    [] ->
      return $ SOAC.Screma w (Futhark.mapSOAC lam) inps
    Nesting pnames resnames rettype inner_w : inner_nests -> do
      let outer_params = zipWith Param pnames (SOAC.inputRowType <$> inps)
          -- The inner nest reads directly from the outer parameters.
          inner_inps = map (SOAC.identInput . paramIdent) outer_params
          inner_nest = MapNest inner_w lam inner_nests inner_inps
      -- Build the inner SOAC as an expression, collecting any
      -- statements produced along the way.
      (inner_exp, inner_stms) <-
        runBinder $ localScope (scopeOfLParams outer_params) $
        toSOAC inner_nest >>= SOAC.toExp
      result_stm <- mkLetNames resnames inner_exp
      let outer_body = mkBody (inner_stms <> oneStm result_stm) (Var <$> resnames)
          outer_lam = Lambda { lambdaParams = outer_params
                             , lambdaBody = outer_body
                             , lambdaReturnType = rettype
                             }
      return $ SOAC.Screma w (Futhark.mapSOAC outer_lam) inps

-- | Compute the parameter names and inputs of an outer map from the
-- inputs of the map nested inside it.
--
-- The arguments are the width of the outer map, the outer lambda's
-- parameter names paired with the outer inputs, and the child's
-- parameter names paired with the child's inputs.  The result has one
-- pair per child input, in the order of the child inputs (pairs are
-- accumulated by consing and the list is reversed at the end).
fixInputs :: MonadFreshNames m =>
             SubExp -> [(VName, SOAC.Input)] -> [(VName, SOAC.Input)]
          -> m [(VName, SOAC.Input)]
fixInputs w ourInps childInps =
  -- The fold state is (outer parameters not yet consumed, pairs
  -- produced so far in reverse order).
  reverse . snd <$> foldM inspect (ourInps, []) childInps
  where
    isParam x (y, _) = x == y

    -- Succeeds only when exactly one entry carries the given name, and
    -- then also returns the list without that entry.
    findParam :: [(VName, SOAC.Input)]
              -> VName
              -> Maybe ((VName, SOAC.Input), [(VName, SOAC.Input)])
    findParam remPs v
      | ([ourP], remPs') <- partition (isParam v) remPs = Just (ourP, remPs')
      | otherwise                                       = Nothing

    inspect :: MonadFreshNames m =>
               ([(VName, SOAC.Input)], [(VName, SOAC.Input)])
            -> (VName, SOAC.Input)
            -> m ([(VName, SOAC.Input)], [(VName, SOAC.Input)])
    inspect (remPs, newInps) (_, SOAC.Input ts v _)
      -- The child input reads an outer parameter that has not been
      -- consumed yet: reuse that parameter, apply the child's
      -- transforms to the rows of its outer input, and consume it.
      | Just ((p,pInp), remPs') <- findParam remPs v =
          let pInp' = SOAC.transformRows ts pInp
          in return (remPs',
                     (p, pInp') : newInps)

      -- NOTE(review): the transforms @ts@ of this child input are not
      -- applied in this branch; the reused @pInp@ is whatever was
      -- recorded at the first use -- confirm that this is intended.
      | Just ((p,pInp), _) <- findParam newInps v = do
          -- The input corresponds to a variable that has already
          -- been used.
          p' <- newNameFromString $ baseString p
          return (remPs, (p', pInp) : newInps)

    -- Reached when neither guard above matches: the child input does
    -- not read any of our parameters.  It gets a fresh parameter name
    -- (with a @_rep@ suffix) and a 'SOAC.Replicate' to the outer width
    -- is appended to its transforms.
    inspect (remPs, newInps) (param, SOAC.Input ts a t) = do
      param' <- newNameFromString (baseString param ++ "_rep")
      return (remPs, (param',
                      SOAC.Input (ts SOAC.|> SOAC.Replicate mempty (Shape [w])) a t) : newInps)