{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TypeSynonymInstances   #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE BangPatterns           #-}
module Data.SBV.Core.Splittable (Splittable(..), FromBits(..), checkAndConvert) where
import Data.Bits (Bits(..))
import Data.Word (Word8, Word16, Word32, Word64)
import Data.SBV.Core.Operations
import Data.SBV.Core.Data
import Data.SBV.Core.Model
-- | Values of type @whole@ that can be taken apart into two @part@ values
-- and put back together again. The functional dependency means the part
-- type alone determines the whole type.
class Splittable whole part | part -> whole where
  -- The join operator associates to the right at precedence 5.
  infixr 5 #

  -- | Take a value apart into a pair of pieces. In the instances defined in
  -- this module the first component holds the most significant bits.
  split :: whole -> (part, part)

  -- | Combine two pieces into one value. In the instances defined in this
  -- module the left operand supplies the most significant bits.
  (#) :: part -> part -> whole

  -- | Widen a piece to the whole type. Every instance in this module
  -- implements this as a join with a zero left operand.
  extend :: part -> whole
-- | Cut an integral value into two fields of @width@ bits each, returned as
-- @(upper, lower)@: the lower field is bits @0 .. width-1@ and the upper
-- field is bits @width .. 2*width-1@. Any bits above that are discarded.
-- The work is done on the 'Integer' image of the input, and each field is
-- converted to the result type with 'fromIntegral'.
genSplit :: (Integral a, Num b) => Int -> a -> (b, b)
genSplit width whole = (field (wide `shiftR` width), field wide)
  where
    -- Unbounded view of the input, so shifting never loses bits early.
    wide = toInteger whole
    -- A run of @width@ one-bits.
    lowBits = 2 ^ width - 1
    -- Keep the low @width@ bits and convert to the result type.
    field v = fromIntegral (v .&. lowBits)
-- | Concatenate two @ss@-bit fields into one value: @x@ supplies the upper
-- @ss@ bits and @y@ the lower @ss@ bits. The result is converted to the
-- target type with 'fromIntegral'.
--
-- Both inputs are reduced to their low @ss@ bits before being combined,
-- mirroring the masking done by 'genSplit'. Without this, a negative input
-- of a signed type would sign-extend when converted to 'Integer', and a
-- negative low half would overwrite every bit of the high half. For
-- non-negative inputs that already fit in @ss@ bits (such as the unsigned
-- word types) the mask changes nothing.
genJoin :: (Integral b, Num a) => Int -> b -> b -> a
genJoin ss x y = fromIntegral ((ix `shiftL` ss) .|. iy)
  where ix   = toInteger x .&. mask
        iy   = toInteger y .&. mask
        -- A run of @ss@ one-bits.
        mask = 2 ^ ss - 1
-- | A 64-bit word is made of two 32-bit words.
instance Splittable Word64 Word32 where
  split w = genSplit 32 w
  hi # lo = genJoin 32 hi lo
  -- Widening: join with an all-zero upper word.
  extend  = (0 #)
-- | A 32-bit word is made of two 16-bit words.
instance Splittable Word32 Word16 where
  split w = genSplit 16 w
  hi # lo = genJoin 16 hi lo
  -- Widening: join with an all-zero upper word.
  extend  = (0 #)
-- | A 16-bit word is made of two 8-bit words.
instance Splittable Word16 Word8 where
  split w = genSplit 8 w
  hi # lo = genJoin 8 hi lo
  -- Widening: join with an all-zero upper word.
  extend  = (0 #)
-- | Symbolic 64-bit words split into, and are joined from, symbolic 32-bit
-- words.
instance Splittable SWord64 SWord32 where
  split (SBV w) = (slice 63 32, slice 31 0)
    where
      -- NOTE(review): presumably @svExtract hi lo@ selects bits hi down to
      -- lo inclusive; confirm against its definition.
      slice hi lo = SBV (svExtract hi lo w)
  (#) (SBV hi) (SBV lo) = SBV (hi `svJoin` lo)
  -- Widening: join with a literal zero as the left operand.
  extend = (0 #)
-- | Symbolic 32-bit words split into, and are joined from, symbolic 16-bit
-- words.
instance Splittable SWord32 SWord16 where
  split (SBV w) = (slice 31 16, slice 15 0)
    where
      -- NOTE(review): presumably @svExtract hi lo@ selects bits hi down to
      -- lo inclusive; confirm against its definition.
      slice hi lo = SBV (svExtract hi lo w)
  (#) (SBV hi) (SBV lo) = SBV (hi `svJoin` lo)
  -- Widening: join with a literal zero as the left operand.
  extend = (0 #)
-- | Symbolic 16-bit words split into, and are joined from, symbolic 8-bit
-- words.
instance Splittable SWord16 SWord8 where
  split (SBV w) = (slice 15 8, slice 7 0)
    where
      -- NOTE(review): presumably @svExtract hi lo@ selects bits hi down to
      -- lo inclusive; confirm against its definition.
      slice hi lo = SBV (svExtract hi lo w)
  (#) (SBV hi) (SBV lo) = SBV (hi `svJoin` lo)
  -- Widening: join with a literal zero as the left operand.
  extend = (0 #)
-- | Types that can be assembled from a list of symbolic booleans.
-- Instances must supply 'fromBitsLE'; 'fromBitsBE' has a default.
class FromBits a where
  -- | Build a value from a list of bits.
  -- NOTE(review): the LE/BE names suggest least-significant-bit-first and
  -- most-significant-bit-first input respectively; the class itself does
  -- not fix the order, so confirm against the instances.
  fromBitsLE :: [SBool] -> a

  -- | Build a value from a list of bits given in the opposite order to
  -- 'fromBitsLE'. By default this reverses the list and defers to
  -- 'fromBitsLE'.
  fromBitsBE :: [SBool] -> a
  fromBitsBE bits = fromBitsLE (reverse bits)
-- | Fold a list of symbolic bits into a symbolic word. The element at list
-- position @i@ controls bit @i@ of the result: starting from zero, each
-- step uses 'ite' to choose between the accumulator with that bit set and
-- the accumulator unchanged. No length check is performed here.
fromBinLE :: (Num a, Bits a, SymWord a) => [SBool] -> SBV a
fromBinLE bits = build 0 (zip [0 ..] bits)
  where
    -- The accumulator is forced on every step, as a strict left fold would.
    build !word []                   = word
    build !word ((pos, flag) : rest) = build (ite flag (setBit word pos) word) rest
-- | Convert a list of symbolic bits to a symbolic word via 'fromBinLE',
-- insisting that the list has exactly the expected number of elements.
-- Calls 'error' when the length differs. Note that the error text always
-- names the type as @SWord<size>@, including when the caller is one of the
-- signed instances.
checkAndConvert :: (Num a, Bits a, SymWord a) => Int -> [SBool] -> SBV a
checkAndConvert expected bits
  | actual == expected = fromBinLE bits
  | otherwise          = error (concat [ "SBV.fromBits.SWord", width
                                       , ": Expected ", width
                                       , " elements, got: ", show actual
                                       ])
  where
    actual = length bits
    width  = show expected
-- | A symbolic boolean is built from a list holding exactly one bit; any
-- other length is reported with 'error'.
instance FromBits SBool where
  fromBitsLE bits = case bits of
    [only] -> only
    _      -> error $ "SBV.fromBits.SBool: Expected 1 element, got: " ++ show (length bits)
-- Fixed-width symbolic words, unsigned and signed: each instance passes its
-- bit width to 'checkAndConvert', which rejects lists of any other length.
instance FromBits SWord8 where
  fromBitsLE = checkAndConvert 8

instance FromBits SInt8 where
  fromBitsLE = checkAndConvert 8

instance FromBits SWord16 where
  fromBitsLE = checkAndConvert 16

instance FromBits SInt16 where
  fromBitsLE = checkAndConvert 16

instance FromBits SWord32 where
  fromBitsLE = checkAndConvert 32

instance FromBits SInt32 where
  fromBitsLE = checkAndConvert 32

instance FromBits SWord64 where
  fromBitsLE = checkAndConvert 64

instance FromBits SInt64 where
  fromBitsLE = checkAndConvert 64