{-# LINE 1 "System/Linux/Netlink/C.hsc" #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-|
Module      : System.Linux.Netlink.C
Description : Bridges the Haskell netlink code to the underlying C socket calls
Maintainer  : ongy
Stability   : testing
Portability : Linux

I consider this module internal.
The documentation may be a bit sparse.
-}
module System.Linux.Netlink.C
    ( makeSocket
    , makeSocketGeneric
    , closeSocket
    , sendmsg
    , recvmsg
    , joinMulticastGroup
    , leaveMulticastGroup
    )
where



{-# LINE 25 "System/Linux/Netlink/C.hsc" #-}

{-# LINE 28 "System/Linux/Netlink/C.hsc" #-}

import Control.Exception (onException)
import Control.Monad (when)
import Data.ByteString (ByteString)
import Data.ByteString.Internal (createAndTrim, toForeignPtr)
import Data.Word (Word32)
import Foreign.C.Error (throwErrnoIf, throwErrnoIfMinus1, throwErrnoIfMinus1_)
import Foreign.C.Types
import Foreign.ForeignPtr (touchForeignPtr, withForeignPtr)
import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr)
import Foreign.Marshal.Array (withArrayLen)
import Foreign.Marshal.Utils (with)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import Foreign.Storable (Storable(..))

import System.Linux.Netlink.Constants (eAF_NETLINK, eNETLINK_ADD_MEMBERSHIP, eNETLINK_DROP_MEMBERSHIP)







-- FFI declarations for the libc syscall wrappers.
-- These are imported 'unsafe' (the cheaper calling convention): they are
-- expected to return quickly and they never call back into Haskell.
foreign import ccall unsafe "socket" c_socket :: CInt -> CInt -> CInt -> IO CInt
-- The address length argument is a socklen_t (a 32-bit integer), so it is
-- declared as CInt rather than the platform-sized Int, like c_setsockopt's.
foreign import ccall unsafe "bind" c_bind :: CInt -> Ptr SockAddrNetlink -> CInt -> IO CInt
foreign import ccall unsafe "close" c_close :: CInt -> IO CInt
foreign import ccall unsafe "setsockopt" c_setsockopt :: CInt -> CInt -> CInt -> Ptr a -> CInt -> IO CInt
-- memset's length argument is a size_t, hence CSize; the returned pointer
-- is not needed and is ignored by declaring the result as ().
foreign import ccall unsafe "memset" c_memset :: Ptr a -> CInt -> CSize -> IO ()

-- sendmsg and recvmsg can block for a while, so they are imported with the
-- (default) 'safe' calling convention instead of 'unsafe'; with the threaded
-- RTS this lets other Haskell threads keep running during the call.
foreign import ccall safe "sendmsg" c_sendmsg :: CInt -> Ptr MsgHdr -> CInt -> IO CInt
foreign import ccall safe "recvmsg" c_recvmsg :: CInt -> Ptr MsgHdr -> CInt -> IO CInt

-- |A netlink socket address carrying a single 'Word32'. It is marshalled as
-- a 12-byte struct: the address family ('eAF_NETLINK', as a 'CShort') at
-- offset 0 and the 'Word32' (presumably @nl_pid@ of @struct sockaddr_nl@)
-- at offset 4; all other bytes are zero.
data SockAddrNetlink = SockAddrNetlink Word32

instance Storable SockAddrNetlink where
    sizeOf _    = 12
    alignment _ = 4
    -- Refuses (via 'fail') any address whose family is not 'eAF_NETLINK'.
    peek ptr = do
        fam <- peekByteOff ptr 0 :: IO CShort
        if fam /= eAF_NETLINK
            then fail "Bad address family"
            else do
                portId <- peekByteOff ptr 4 :: IO CUInt
                return (SockAddrNetlink (fromIntegral portId))
    poke ptr (SockAddrNetlink portId) = do
        zero ptr
        pokeByteOff ptr 0 (eAF_NETLINK :: CShort)
        pokeByteOff ptr 4 (fromIntegral portId :: CUInt)
{-# LINE 76 "System/Linux/Netlink/C.hsc" #-}

-- |One scatter/gather buffer: a base pointer and a length in bytes.
-- Marshalled as a 16-byte struct with the pointer at offset 0 and a 'CSize'
-- length at offset 8 (presumably @struct iovec@'s @iov_base@ / @iov_len@
-- on a 64-bit target).
data IoVec = IoVec (Ptr (), Int)

instance Storable IoVec where
    sizeOf    _ = 16
    -- The struct holds a pointer and a size_t, so buffers for it (and the
    -- arrays built by 'withArrayLen') must be at least pointer-aligned;
    -- a hard-coded 4 under-reports that on 64-bit targets.
    alignment _ = alignment (undefined :: Ptr ())
    peek p = do
        addr <- peekByteOff p 0
        len  <- peekByteOff p 8 :: IO CSize
        return $ IoVec (addr, fromIntegral len)
    poke p (IoVec (addr, len)) = do
        zero p
        pokeByteOff p 0 addr
        pokeByteOff p 8 (fromIntegral len :: CSize)
{-# LINE 90 "System/Linux/Netlink/C.hsc" #-}

-- |A minimal message header that only carries an I/O vector pointer and the
-- number of vector entries. It is marshalled into a 56-byte struct
-- (presumably @struct msghdr@ on a 64-bit target) with the vector pointer
-- at offset 16 and a 'CSize' entry count at offset 24. 'poke' zeroes the
-- whole struct first, so every other field is zero.
data MsgHdr = MsgHdr (Ptr (), Int)

instance Storable MsgHdr where
    sizeOf    _ = 56
    -- The struct contains pointer and size_t fields, so it needs at least
    -- pointer alignment; a hard-coded 4 under-reports it on 64-bit targets.
    alignment _ = alignment (undefined :: Ptr ())
    peek p = do
        iov    <- peekByteOff p 16
        iovlen <- peekByteOff p 24 :: IO CSize
        return $ MsgHdr (iov, fromIntegral iovlen)
    poke p (MsgHdr (iov, iovlen)) = do
        zero p
        pokeByteOff p 16 iov
        pokeByteOff p 24 (fromIntegral iovlen :: CSize)
{-# LINE 104 "System/Linux/Netlink/C.hsc" #-}


-- |Create a netlink socket. For legacy reasons this uses netlink protocol 0,
-- the route family.
makeSocket :: IO CInt
makeSocket = makeSocketGeneric routeFamily
  where
    routeFamily = 0
{-# LINE 109 "System/Linux/Netlink/C.hsc" #-}

-- TODO maybe readd the unique thingy (look at git log)
-- |Create any netlink socket.
--
-- The socket is created with address family 'eAF_NETLINK', socket type 3
-- (presumably @SOCK_RAW@, inlined by hsc2hs) and the given protocol. It is
-- then bound to a 'SockAddrNetlink' with port id 0. Failures of either call
-- raise an 'IOError'. If the bind fails, the new descriptor is closed before
-- the error propagates, so it is not leaked.
makeSocketGeneric
  :: Int -- ^The netlink family to use
  -> IO CInt
makeSocketGeneric prot = do
  fd <- throwErrnoIfMinus1 "makeSocket.socket" $
          c_socket eAF_NETLINK 3 (fromIntegral prot)
  -- we need to bind or joining multicast groups will be useless
  bindZero fd `onException` c_close fd
  return fd
  where
    -- Bind to port id 0; 12 is the marshalled size of 'SockAddrNetlink'.
    bindZero sock =
      with (SockAddrNetlink 0) $ \addr ->
        throwErrnoIfMinus1_ "makeSocket.bind" $
          c_bind sock (castPtr addr) 12
{-# LINE 122 "System/Linux/Netlink/C.hsc" #-}


-- |Close a socket that is no longer needed.
-- Raises an 'IOError' if @close@ returns -1.
closeSocket :: CInt -> IO ()
closeSocket = throwErrnoIfMinus1_ "closeSocket" . c_close

-- |Send a message over a socket.
--
-- Each 'ByteString' chunk becomes one I/O vector entry pointing straight into
-- its buffer, so the chunks are not first copied into a single buffer.
-- Raises an 'IOError' if @sendmsg@ returns -1.
sendmsg :: CInt -> [ByteString] -> IO ()
sendmsg fd bs = useManyAsPtrLen bs sendVectors
  where
    sendVectors chunks =
      withArrayLen (map IoVec chunks) $ \count vecs ->
        with (MsgHdr (castPtr vecs, count)) $ \hdr ->
          throwErrnoIfMinus1_ "sendmsg" (c_sendmsg fd (castPtr hdr) (0 :: CInt))

-- |Receive a message over a socket into a fresh buffer of @len@ bytes.
--
-- The returned 'ByteString' is trimmed to the byte count reported by
-- @recvmsg@. Raises an 'IOError' when @recvmsg@ returns 0 or a negative value.
recvmsg :: CInt -> Int -> IO ByteString
recvmsg fd len = createAndTrim len fill
  where
    fill buf =
      with (IoVec (castPtr buf, len)) $ \vecPtr ->
        with (MsgHdr (castPtr vecPtr, 1)) $ \hdrPtr -> do
          got <- throwErrnoIf (<= 0) "recvmsg" $
                   c_recvmsg fd (castPtr hdrPtr) (0 :: CInt)
          return (fromIntegral got)

-- |Run an action with the raw (pointer, length) view of every 'ByteString'
-- in the list, in the same order as the input.
--
-- No bytes are copied: each pointer refers into its 'ByteString''s own
-- buffer. Every buffer is held alive with 'withForeignPtr' for the whole
-- duration of the action. The previous pattern, 'unsafeForeignPtrToPtr'
-- followed by a trailing 'touchForeignPtr', is not reliable. GHC may drop
-- the trailing touch when the action is known to diverge (for example when
-- it always throws), which would let a buffer be freed while still in use.
-- 'withForeignPtr' avoids this on GHC >= 9.0, where it is built on
-- @keepAlive#@.
useManyAsPtrLen :: [ByteString] -> ([(Ptr (), Int)] -> IO a) -> IO a
useManyAsPtrLen bs act = go bs id
  where
    -- Nest one 'withForeignPtr' per chunk. @acc@ is a difference list of the
    -- pairs collected so far, so the final list keeps the input order.
    go [] acc = act (acc [])
    go (b : rest) acc =
      let (fptr, off, len) = toForeignPtr b
      in withForeignPtr fptr $ \base ->
           go rest (acc . ((plusPtr base off, len) :))

-- |Size in bytes of the type a pointer points to, as any integral type.
-- The pointer is never dereferenced; it only selects the 'Storable' instance.
sizeOfPtr :: (Storable a, Integral b) => Ptr a -> b
sizeOfPtr ptr = fromIntegral (sizeOf (pointee ptr))
  where
    pointee :: Ptr c -> c
    pointee _ = undefined

-- |Overwrite the storage behind a pointer with zero bytes. The number of
-- bytes cleared is the 'sizeOf' of the pointed-to type.
zero :: Storable a => Ptr a -> IO ()
zero p = do
  let byteCount = sizeOfPtr p
  c_memset (castPtr p) 0 byteCount

-- |Run a monadic action and discard its result. This is a local stand-in for
-- 'Control.Monad.void', stated for 'Monad' instead of 'Functor'.
void :: Monad m => m a -> m ()
void act = act >>= \_ -> return ()


-- |Set membership of a socket in a netlink multicast group.
--
-- @beMember@ selects the option: 'eNETLINK_ADD_MEMBERSHIP' for 'True',
-- 'eNETLINK_DROP_MEMBERSHIP' for 'False'. @fid@ is the group id, passed by
-- pointer to @setsockopt@ at level 270 (SOL_NETLINK). Raises an 'IOError'
-- labelled "joinMulticast" or "leaveMulticast" if @setsockopt@ returns -1.
joinOrLeaveMulticastGroup :: Bool -> CInt -> Word32 -> IO ()
joinOrLeaveMulticastGroup beMember fd fid =
  throwErrnoIfMinus1_ label $
    with fid $ \ptr ->
      c_setsockopt fd sol_netlink value (castPtr ptr) size
  where
    -- The option length is the size of the value actually passed (fid).
    size = fromIntegral (sizeOf fid)
    sol_netlink = 270 :: CInt
    value
      | beMember  = eNETLINK_ADD_MEMBERSHIP
      | otherwise = eNETLINK_DROP_MEMBERSHIP
    -- Label the error after the operation actually attempted.
    label
      | beMember  = "joinMulticast"
      | otherwise = "leaveMulticast"

-- |Join a netlink multicast group: add socket @fd@ to group @groupId@.
joinMulticastGroup :: CInt -> Word32 -> IO ()
joinMulticastGroup fd groupId = joinOrLeaveMulticastGroup True fd groupId

-- |Leave a netlink multicast group: drop socket @fd@ from group @groupId@.
leaveMulticastGroup :: CInt -> Word32 -> IO ()
leaveMulticastGroup fd groupId = joinOrLeaveMulticastGroup False fd groupId