{-# LANGUAGE FlexibleContexts #-}
module Futhark.Internalise.Bindings
  (
  -- * Internalising bindings
    bindingParams
  , bindingLambdaParams
  , stmPattern
  , MatchPattern
  )
  where

import Control.Monad.State  hiding (mapM)
import Control.Monad.Reader hiding (mapM)
import Control.Monad.Writer hiding (mapM)

import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Loc
import Data.Traversable (mapM)

import Language.Futhark as E
import qualified Futhark.Representation.SOACS as I
import Futhark.MonadFreshNames

import Futhark.Internalise.Monad
import Futhark.Internalise.TypesValues
import Futhark.Internalise.AccurateSizes
import Futhark.Util

bindingParams :: [E.TypeParam] -> [E.Pattern]
              -> (ConstParams -> [I.FParam] -> [[I.FParam]] -> InternaliseM a)
              -> InternaliseM a
bindingParams tparams params m = do
  flattened_params <- mapM flattenPattern params
  let (params_idents, params_types) = unzip $ concat flattened_params
      bound = boundInTypes tparams
      param_names = M.fromList [ (E.identName x, y) | (x,y) <- params_idents ]
  (params_ts, cm) <- internaliseParamTypes bound param_names params_types
  let num_param_idents = map length flattened_params
      num_param_ts = map (sum . map length) $ chunks num_param_idents params_ts

  (params_ts', unnamed_shape_params) <-
    fmap unzip $ forM params_ts $ \param_ts -> do
      (param_ts', param_unnamed_dims) <- instantiateShapesWithDecls mempty param_ts

      return (param_ts',
              param_unnamed_dims)

  let named_shape_params = [ I.Param v $ I.Prim I.int32 | E.TypeParamDim v _ <- tparams ]
      shape_params = named_shape_params ++ concat unnamed_shape_params
      shape_subst = M.fromList [ (I.paramName p, [I.Var $ I.paramName p]) | p <- shape_params ]
  bindingFlatPattern params_idents (concat params_ts') $ \valueparams ->
    I.localScope (I.scopeOfFParams $ shape_params++concat valueparams) $
    substitutingVars shape_subst $ m cm shape_params $ chunks num_param_ts (concat valueparams)

bindingLambdaParams :: [E.TypeParam] -> [E.Pattern] -> [I.Type]
                    -> (ConstParams -> [I.LParam] -> InternaliseM a)
                    -> InternaliseM a
bindingLambdaParams tparams params ts m = do
  (params_idents, params_types) <-
    unzip . concat <$> mapM flattenPattern params
  let bound = boundInTypes tparams
      param_names = M.fromList [ (E.identName x, y) | (x,y) <- params_idents ]
  (params_ts, cm) <- internaliseParamTypes bound param_names params_types

  let ascript_substs = lambdaShapeSubstitutions (concat params_ts) ts

  bindingFlatPattern params_idents ts $ \params' ->
    local (\env -> env { envSubsts = ascript_substs `M.union` envSubsts env }) $
    I.localScope (I.scopeOfLParams $ concat params') $ m cm $ concat params'

processFlatPattern :: Show t => [(E.Ident,VName)] -> [t]
                   -> InternaliseM ([[I.Param t]], VarSubstitutions)
processFlatPattern x y = processFlatPattern' [] x y
  where
    processFlatPattern' pat []       _  = do
      let (vs, substs) = unzip pat
          substs' = M.fromList substs
          idents = reverse vs
      return (idents, substs')

    processFlatPattern' pat ((p,name):rest) ts = do
      (ps, subst, rest_ts) <- handleMapping ts <$> internaliseBindee (p, name)
      processFlatPattern' ((ps, (E.identName p, map (I.Var . I.paramName) subst)) : pat) rest rest_ts

    handleMapping ts [] =
      ([], [], ts)
    handleMapping ts (r:rs) =
        let (ps, reps, ts')    = handleMapping' ts r
            (pss, repss, ts'') = handleMapping ts' rs
        in (ps++pss, reps:repss, ts'')

    handleMapping' (t:ts) (vname,_) =
      let v' = I.Param vname t
      in ([v'], v', ts)
    handleMapping' [] _ =
      error $ "processFlatPattern: insufficient identifiers in pattern." ++ show (x, y)

    internaliseBindee :: (E.Ident, VName) -> InternaliseM [(VName, I.DeclExtType)]
    internaliseBindee (bindee, name) = do
      -- XXX: we gotta be screwing up somehow by ignoring the extra
      -- return values.  If not, why not?
      (tss, _) <- internaliseParamTypes nothing_bound mempty
                  [flip E.setAliases () $ E.vacuousShapeAnnotations $
                   E.unInfo $ E.identType bindee]
      case concat tss of
        [t] -> return [(name, t)]
        tss' -> forM tss' $ \t -> do
          name' <- newVName $ baseString name
          return (name', t)

    -- Fixed up later.
    nothing_bound = boundInTypes []

bindingFlatPattern :: Show t => [(E.Ident, VName)] -> [t]
                   -> ([[I.Param t]] -> InternaliseM a)
                   -> InternaliseM a
bindingFlatPattern idents ts m = do
  (ps, substs) <- processFlatPattern idents ts
  local (\env -> env { envSubsts = substs `M.union` envSubsts env}) $
    m ps

-- | Flatten a pattern.  Returns a list of identifiers.  The
-- structural type of each identifier is returned separately.
flattenPattern :: MonadFreshNames m => E.Pattern -> m [((E.Ident, VName), E.StructType)]
flattenPattern = flattenPattern'
  where flattenPattern' (E.PatternParens p _) =
          flattenPattern' p
        flattenPattern' (E.Wildcard t loc) = do
          name <- newVName "nameless"
          flattenPattern' $ E.Id name t loc
        flattenPattern' (E.Id v (Info t) loc) = do
          new_name <- newVName $ baseString v
          return [((E.Ident v (Info (E.removeShapeAnnotations t)) loc,
                    new_name),
                   t `E.setAliases` ())]
        flattenPattern' (E.TuplePattern pats _) =
          concat <$> mapM flattenPattern' pats
        flattenPattern' (E.RecordPattern fs loc) =
          flattenPattern' $ E.TuplePattern (map snd $ sortFields $ M.fromList fs) loc
        flattenPattern' (E.PatternAscription p _ _) =
          flattenPattern' p
        flattenPattern' (E.PatternLit _ t loc) =
          flattenPattern' $ E.Wildcard t loc

type MatchPattern = SrcLoc -> [I.SubExp] -> InternaliseM [I.SubExp]

stmPattern :: [E.TypeParam] -> E.Pattern -> [I.ExtType]
           -> (ConstParams -> [VName] -> MatchPattern -> InternaliseM a)
           -> InternaliseM a
stmPattern tparams pat ts m = do
  (pat', pat_types) <- unzip <$> flattenPattern pat
  (ts',_) <- instantiateShapes' ts
  (pat_types', cm) <- internaliseParamTypes (boundInTypes tparams) mempty pat_types
  let pat_types'' = map I.fromDecl $ concat pat_types'
      tparam_names = S.fromList $ map E.typeParamName tparams
  let addShapeStms l =
        m cm (map I.paramName $ concat l) (matchPattern tparam_names pat_types'')
  bindingFlatPattern pat' ts' addShapeStms

matchPattern :: S.Set VName -> [I.ExtType] -> MatchPattern
matchPattern tparam_names exts loc ses =
  forM (zip exts ses) $ \(et, se) -> do
  se_t <- I.subExpType se
  et' <- unExistentialise tparam_names et se_t
  ensureExtShape asserting (I.ErrorMsg [I.ErrorString "value cannot match pattern"])
    loc et' "correct_shape" se

unExistentialise :: S.Set VName -> I.ExtType -> I.Type -> InternaliseM I.ExtType
unExistentialise tparam_names et t = do
  new_dims <- zipWithM inspectDim (I.shapeDims $ I.arrayShape et) (I.arrayDims t)
  return $ t `I.setArrayShape` I.Shape new_dims
  where inspectDim (I.Free (I.Var v)) d
          | v `S.member` tparam_names = do
              letBindNames_ [v] $ I.BasicOp $ I.SubExp d
              return $ I.Free $ I.Var v
        inspectDim ed _ = return ed

instantiateShapesWithDecls :: MonadFreshNames m =>
                              M.Map Int I.Ident
                           -> [I.DeclExtType]
                           -> m ([I.DeclType], [I.FParam])
instantiateShapesWithDecls ctx ts =
  runWriterT $ instantiateShapes instantiate ts
  where instantiate x
          | Just v <- M.lookup x ctx =
            return $ I.Var $ I.identName v

          | otherwise = do
            v <- lift $ nonuniqueParamFromIdent <$> newIdent "size" (I.Prim I.int32)
            tell [v]
            return $ I.Var $ I.paramName v

lambdaShapeSubstitutions :: [I.TypeBase I.ExtShape Uniqueness]
                         -> [I.Type]
                         -> VarSubstitutions
lambdaShapeSubstitutions param_ts ts =
  mconcat $ zipWith matchTypes param_ts ts
  where matchTypes pt t =
          mconcat $ zipWith matchDims (I.shapeDims $ I.arrayShape pt) (I.arrayDims t)
        matchDims (I.Free (I.Var v)) d = M.singleton v [d]
        matchDims _ _ = mempty

nonuniqueParamFromIdent :: I.Ident -> I.FParam
nonuniqueParamFromIdent (I.Ident name t) =
  I.Param name $ I.toDecl t Nonunique