-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Singletons.Deriving.Bounded
-- Copyright   :  (C) 2015 Richard Eisenberg
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  Ryan Scott
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Implements deriving of Bounded instances
--
----------------------------------------------------------------------------

module Data.Singletons.Deriving.Bounded where

import Language.Haskell.TH.Ppr
import Language.Haskell.TH.Desugar
import Data.Singletons.Names
import Data.Singletons.Util
import Data.Singletons.Syntax
import Data.Singletons.Deriving.Infer
import Data.Singletons.Deriving.Util
import Control.Monad

-- | Derive a 'Bounded' instance for a datatype.
--
-- Following Section 11 of the Haskell 2010 Language Report, an instance can
-- only be derived when the datatype is either
--
--   * an enumeration (two or more constructors, all of them nullary), in which
--     case 'minBound' is the first constructor and 'maxBound' is the last; or
--
--   * a single-constructor type, in which case 'minBound' (resp. 'maxBound')
--     is that constructor applied to 'minBound' (resp. 'maxBound') once per
--     field.
--
-- Any other datatype, including one with no constructors at all, is rejected
-- with 'fail'. The function is monadic only for that failure and for
-- parallelism with the other instance-generating functions.
mkBoundedInstance :: DsMonad q => DerivDesc q
mkBoundedInstance mb_ctxt ty (DataDecl _ _ cons) = do
  -- Select the constructors that supply minBound and maxBound. The case
  -- analysis replaces the head/last/tail calls (and the order-sensitive
  -- guard that protected them) with total pattern matching.
  (minCon, maxCon) <- case cons of
    []     -> cannotDerive
    [con]  -> return (con, con)
    (con:rest)
      | all isNullary cons -> return (con, lastOf con rest)
      | otherwise          -> cannotDerive
  let DCon _ _ minName minFields _ = minCon
      DCon _ _ maxName _ _         = maxCon
      -- In the enumeration case every constructor is nullary, so this is 0.
      -- Otherwise minCon and maxCon are the same (sole) constructor.
      fieldsCount = length (tysOfConFields minFields)
      -- Apply a constructor to the given bound once per field. With zero
      -- fields this is just the bare constructor.
      boundRHS nm bndName =
        foldExp (DConE nm) (replicate fieldsCount (DVarE bndName))
      mk_rhs rhs = UFunction [DClause [] rhs]
  constraints <- inferConstraintsDef mb_ctxt (DConPr boundedName) ty cons
  return $ InstDecl
    { id_cxt     = constraints
    , id_name    = boundedName
    , id_arg_tys = [ty]
    , id_sigs    = mempty
    , id_meths   =
        [ (minBoundName, mk_rhs (boundRHS minName minBoundName))
        , (maxBoundName, mk_rhs (boundRHS maxName maxBoundName)) ]
    }
  where
    -- Error raised for datatypes outside the two derivable shapes.
    cannotDerive = fail ("Can't derive Bounded instance for "
                         ++ pprint (typeToTH ty) ++ ".")

    -- A constructor is nullary when it has no field types.
    isNullary (DCon _ _ _ f _) = null (tysOfConFields f)

    -- Total 'last' for a list known to start with the given element.
    lastOf x []     = x
    lastOf _ (y:ys) = lastOf y ys