-- | UniformIO functions for TCP connections
module System.IO.Uniform.Network (
  SocketIO,
  BoundedPort,
  connectTo,
  connectToHost,
  bindPort,
  accept,
  closePort,
  getPeer
  ) where

import System.IO.Uniform
import System.IO.Uniform.External

import Foreign
import Foreign.C.Types
import Foreign.C.String
import Foreign.C.Error
import qualified Data.IP as IP
import qualified Data.ByteString as BS
import qualified Data.List as L
import Control.Exception
import Control.Monad
import qualified Network.Socket as Soc
import System.IO.Error

import System.Posix.Types (Fd(..))

-- | 'UniformIO' instance for TCP connections, covering both the plain
--   ('SocketIO') and the TLS-wrapped ('TlsSocketIO') constructors.
instance UniformIO SocketIO where
  -- Reads into a temporary buffer of @n@ elements and copies out as many
  -- bytes as the C call reports. A negative count is raised via errno.
  uRead (SocketIO h) n = allocaArray n $ \buf -> do
    got <- c_recv h buf (fromIntegral n)
    when (got < 0) $ throwErrno "could not read"
    BS.packCStringLen (buf, fromIntegral got)
  uRead (TlsSocketIO h) n = allocaArray n $ \buf -> do
    got <- c_recvTls h buf (fromIntegral n)
    when (got < 0) $ throwErrno "could not read"
    BS.packCStringLen (buf, fromIntegral got)

  -- Hands the bytes of the 'BS.ByteString' to the C send routine.
  -- A negative result is raised via errno.
  uPut (SocketIO h) bytes = BS.useAsCStringLen bytes $ \(ptr, len) -> do
    sent <- c_send h ptr (fromIntegral len)
    when (sent < 0) $ throwErrno "could not write"
  uPut (TlsSocketIO h) bytes = BS.useAsCStringLen bytes $ \(ptr, len) -> do
    sent <- c_sendTls h ptr (fromIntegral len)
    when (sent < 0) $ throwErrno "could not write"

  -- Closing goes through the C side first, which yields the descriptor that
  -- is then closed with 'closeFd'. A TLS connection is first passed through
  -- 'c_closeTls', whose result is closed like a plain socket.
  uClose (SocketIO h) = c_prepareToClose h >>= closeFd . Fd
  uClose (TlsSocketIO h) = c_closeTls h >>= c_prepareToClose >>= closeFd . Fd

  -- Upgrades a plain connection using the certificate chain, private key and
  -- DH parameters files named in the settings. A null result from the C side
  -- is raised via errno. A connection that is already TLS is returned as is.
  startTls cfg (SocketIO h) =
    withCString (tlsCertificateChainFile cfg) $ \certPath ->
      withCString (tlsPrivateKeyFile cfg) $ \keyPath ->
        withCString (tlsDHParametersFile cfg) $ \dhPath -> do
          tls <- c_startSockTls h certPath keyPath dhPath
          when (tls == nullPtr) $ throwErrno "could not start TLS"
          return (TlsSocketIO tls)
  startTls _ secured@(TlsSocketIO _) = return secured

  -- Only the TLS constructor counts as secure.
  isSecure (SocketIO _) = False
  isSecure (TlsSocketIO _) = True


-- | connectToHost hostName port
--
--  Connects to the given host and port.
--
--  The host name is resolved with 'Soc.getAddrInfo'. Every IPv4 and IPv6
--  address in the answer is tried, in the order the resolver returned them,
--  until 'connectTo' succeeds on one of them; the first successful
--  connection is returned.
--
--  Throws an 'IOError' of type 'doesNotExistErrorType' when resolution
--  yields no IPv4 or IPv6 address. When every address fails to connect, the
--  'IOError' raised for the last address tried is propagated.
connectToHost :: String -> Int -> IO SocketIO
connectToHost host port = do
  -- The annotation pins the result to a list, which keeps this compiling
  -- with network versions where getAddrInfo is container-polymorphic.
  infos <- Soc.getAddrInfo Nothing (Just host) Nothing :: IO [Soc.AddrInfo]
  -- Without hints the resolver may repeat an address (for example once per
  -- socket type), so duplicates are dropped while keeping the original order.
  case L.nub (concatMap (toIP . Soc.addrAddress) infos) of
    [] -> throwIO hostNotFound
    ips -> connectFirst ips
  where
    hostNotFound :: IOError
    hostNotFound = mkIOError doesNotExistErrorType "host not found" Nothing Nothing

    -- Keeps only internet addresses; any other address family is skipped.
    toIP :: Soc.SockAddr -> [IP.IP]
    toIP (Soc.SockAddrInet _ a) = [IP.IPv4 (IP.fromHostAddress a)]
    toIP (Soc.SockAddrInet6 _ _ a _) = [IP.IPv6 (IP.fromHostAddress6 a)]
    toIP _ = []

    -- Tries the candidates in order. The last candidate is attempted without
    -- a handler, so its error is the one the caller sees.
    connectFirst :: [IP.IP] -> IO SocketIO
    connectFirst [] = throwIO hostNotFound
    connectFirst [ip] = connectTo ip port
    connectFirst (ip:rest) =
      tryIOError (connectTo ip port) >>= either (const (connectFirst rest)) return


-- | connectTo ipAddress port
--
--  Opens a connection to the given port at the given IP address, picking
--  the IPv4 or IPv6 C entry point according to the address constructor.
--  A null handle from the C side is raised via errno.
connectTo :: IP.IP -> Int -> IO SocketIO
connectTo host port = do
  nh <- case host of
    IP.IPv4 v4 -> c_connect4 (fromIntegral (IP.toHostAddress v4)) (fromIntegral port)
    IP.IPv6 v6 -> withArray (octets v6) $ \raw -> c_connect6 raw (fromIntegral port)
  let conn = SocketIO nh
  when (sock conn == nullPtr) $ throwErrno "could not connect to host"
  return conn
  where
    -- The 16 bytes of the address: the four 32-bit words in order, each one
    -- split most-significant byte first.
    octets :: IP.IPv6 -> [CUChar]
    octets ip =
      let (a, b, c, d) = IP.toHostAddress6 ip
      in concatMap bigEndian [a, b, c, d]
    -- 'fromIntegral' truncates to the low 8 bits, so shifting alone selects
    -- each byte.
    bigEndian :: Word32 -> [CUChar]
    bigEndian w = [fromIntegral (w `shiftR` s) | s <- [24, 16, 8, 0]]
  
-- | bindPort port
--
--  Binds to the given IP port, becoming ready to accept connections on it.
--  Binding to port numbers under 1024 will fail unless performed by the superuser;
--  once bound, a process can reduce its privileges and still accept clients on that port.
--  A null handle from the C side is raised via errno.
bindPort :: Int -> IO BoundedPort
bindPort port = do
  bound <- BoundedPort <$> c_getPort (fromIntegral port)
  when (lis bound == nullPtr) $ throwErrno "could not bind to port"
  return bound
  
-- | accept port
--
--  Accepts a client on a port previously bound with 'bindPort' and returns
--  the new connection. A null handle from the C side is raised via errno.
accept :: BoundedPort -> IO SocketIO
accept port = do
  client <- SocketIO <$> c_accept (lis port)
  when (sock client == nullPtr) $ throwErrno "could not accept connection"
  return client

-- | Gets the address of the peer socket of an internet connection.
--
--  The first component is the peer's IP address. The second is the value
--  returned by the C helper (presumably the peer's port number; confirm
--  against @c_getPeer@). A result of -1 from the C side is raised via errno.
--
--  NOTE(review): 'sock' is applied to the argument without matching on its
--  constructor; confirm it is also defined for TLS-wrapped connections.
getPeer :: SocketIO -> IO (IP.IP, Int)
getPeer s =
  allocaArray 16 $ \buf6 ->
    alloca $ \buf4 ->
      alloca $ \familyPtr -> do
        res <- c_getPeer (sock s) buf4 buf6 familyPtr
        when (res == -1) $ throwErrno "could not get peer address"
        family <- peek familyPtr
        -- A family flag of 1 marks IPv6 (16 bytes in buf6); anything else is
        -- read as an IPv4 host address from buf4.
        addr <- if family == 1
          then IP.IPv6 . IP.toIPv6b . map fromIntegral <$> peekArray 16 buf6
          else IP.IPv4 . IP.fromHostAddress . fromIntegral <$> peek buf4
        return (addr, fromIntegral res)