module System.IO.Uniform.Network (
SocketIO,
BoundedPort,
connectTo,
connectToHost,
bindPort,
accept,
closePort,
getPeer
) where
import System.IO.Uniform
import System.IO.Uniform.External
import Foreign
import Foreign.C.Types
import Foreign.C.String
import Foreign.C.Error
import qualified Data.IP as IP
import qualified Data.ByteString as BS
import qualified Data.List as L
import Control.Exception
import Control.Monad
import qualified Network.Socket as Soc
import System.IO.Error
import System.Posix.Types (Fd(..))
-- | Plain and TLS-wrapped sockets as uniform I/O handles.
--
-- Every method dispatches on the constructor. 'SocketIO' goes through the
-- plain C primitives (@c_recv@, @c_send@, ...) and 'TlsSocketIO' through
-- their TLS counterparts (@c_recvTls@, @c_sendTls@, ...). Failures reported
-- by the C side, a negative byte count or a null pointer, are turned into
-- 'IOError's built from the current @errno@ via 'throwErrno'.
instance UniformIO SocketIO where
  uRead (SocketIO s) n = readSocketWith (c_recv s) n
  uRead (TlsSocketIO s) n = readSocketWith (c_recvTls s) n

  uPut (SocketIO s) t = writeSocketWith (c_send s) t
  uPut (TlsSocketIO s) t = writeSocketWith (c_sendTls s) t

  -- The C helper hands back a raw descriptor, which is then closed as an 'Fd'.
  uClose (SocketIO s) = c_prepareToClose s >>= closeFd . Fd
  -- For TLS, the session is closed first; that yields the handle passed on
  -- to @c_prepareToClose@.
  uClose (TlsSocketIO s) = c_closeTls s >>= c_prepareToClose >>= closeFd . Fd

  -- Upgrading needs the certificate chain, private key and DH parameter
  -- file paths marshalled as C strings for the duration of the C call.
  startTls st (SocketIO s) =
    withCString (tlsCertificateChainFile st) $ \cert ->
      withCString (tlsPrivateKeyFile st) $ \key ->
        withCString (tlsDHParametersFile st) $ \para -> do
          session <- c_startSockTls s cert key para
          if session == nullPtr
            then throwErrno "could not start TLS"
            else pure (TlsSocketIO session)
  -- Already encrypted: the handle is returned unchanged.
  startTls _ s@(TlsSocketIO _) = return s

  isSecure io = case io of
    SocketIO _ -> False
    TlsSocketIO _ -> True

-- | Shared body of 'uRead'.
--
-- Allocates an @n@-byte buffer, lets the supplied receive call fill it, and
-- copies the bytes it reports into a 'BS.ByteString'. A negative count
-- raises an errno-based 'IOError' ("could not read"). A count of 0 yields
-- an empty string.
readSocketWith
  :: (Num len, Integral cnt)
  => (CString -> len -> IO cnt) -- ^ receive call: buffer, capacity -> count
  -> Int                          -- ^ maximum number of bytes to read
  -> IO BS.ByteString
readSocketWith recvC n = allocaArray n $ \buf -> do
  got <- recvC buf (fromIntegral n)
  if got < 0
    then throwErrno "could not read"
    else BS.packCStringLen (buf, fromIntegral got)

-- | Shared body of 'uPut'.
--
-- Hands the whole 'BS.ByteString' to the supplied send call. A negative
-- count raises an errno-based 'IOError' ("could not write"). Any
-- non-negative count is accepted.
--
-- NOTE(review): a short (non-negative but smaller than the buffer) count
-- is not retried here. Unless @c_send@/@c_sendTls@ loop internally, the
-- unsent tail would be dropped. Confirm against the C implementation.
writeSocketWith
  :: (Num len, Ord cnt, Num cnt)
  => (CString -> len -> IO cnt) -- ^ send call: buffer, length -> count
  -> BS.ByteString              -- ^ bytes to write
  -> IO ()
writeSocketWith sendC bytes = BS.useAsCStringLen bytes $ \(buf, len) -> do
  sent <- sendC buf (fromIntegral len)
  when (sent < 0) $ throwErrno "could not write"
-- | Resolve a host name and open a plain connection to it with 'connectTo'.
--
-- Only the first entry returned by 'Soc.getAddrInfo' is used. The call
-- fails with an 'IOError' of type 'doesNotExistErrorType'
-- ("host not found") in two cases:
--
-- * the lookup yields no entries;
-- * the first entry is neither an IPv4 nor an IPv6 address.
--
-- Exceptions raised by 'Soc.getAddrInfo' itself propagate unchanged.
connectToHost :: String -> Int -> IO SocketIO
connectToHost host port = resolve >>= \ip -> connectTo ip port
  where
    -- Error raised for anything that cannot be turned into an 'IP.IP'.
    notFound :: IO a
    notFound = throwIO $ mkIOError doesNotExistErrorType "host not found" Nothing Nothing

    resolve :: IO IP.IP
    resolve = do
      infos <- Soc.getAddrInfo Nothing (Just host) Nothing
      case Soc.addrAddress <$> infos of
        Soc.SockAddrInet _ v4 : _ -> return . IP.IPv4 $ IP.fromHostAddress v4
        Soc.SockAddrInet6 _ _ v6 _ : _ -> return . IP.IPv6 $ IP.fromHostAddress6 v6
        _ -> notFound
-- | Open a plain connection to the given IP address and port.
--
-- * IPv4 addresses are passed to @c_connect4@ as the value of
--   'IP.toHostAddress'.
-- * IPv6 addresses are serialised into a 16-byte array for @c_connect6@.
--   Each of the four 32-bit words from 'IP.toHostAddress6' is written most
--   significant byte first, in word order.
--
-- Throws an errno-based 'IOError' ("could not connect to host") when the
-- C side returns a null handle.
connectTo :: IP.IP -> Int -> IO SocketIO
connectTo host port = do
  conn <- fmap SocketIO $ case host of
    IP.IPv4 v4 -> c_connect4 (fromIntegral (IP.toHostAddress v4)) (fromIntegral port)
    IP.IPv6 v6 -> withArray (v6Bytes v6) $ \buf -> c_connect6 buf (fromIntegral port)
  if sock conn == nullPtr
    then throwErrno "could not connect to host"
    else return conn
  where
    -- All 16 address bytes, big-endian within each word.
    v6Bytes :: IP.IPv6 -> [CUChar]
    v6Bytes addr = case IP.toHostAddress6 addr of
      (w0, w1, w2, w3) -> concatMap bytesBE [w0, w1, w2, w3]

    -- The four bytes of a word, most significant first.
    bytesBE :: Word32 -> [CUChar]
    bytesBE w = [fromIntegral ((w `shiftR` k) .&. 0xff) | k <- [24, 16, 8, 0]]
-- | Bind (through @c_getPort@) to the given port. The result is the handle
-- that 'accept' takes connections from.
--
-- Throws an errno-based 'IOError' ("could not bind to port") when the C
-- side returns a null pointer.
bindPort :: Int -> IO BoundedPort
bindPort port = c_getPort (fromIntegral port) >>= ensureBound . BoundedPort
  where
    ensureBound bp
      | lis bp == nullPtr = throwErrno "could not bind to port"
      | otherwise = return bp
-- | Take the next incoming connection (through @c_accept@) from a port
-- obtained with 'bindPort', as a plain, non-TLS 'SocketIO'.
--
-- Throws an errno-based 'IOError' ("could not accept connection") when the
-- C side returns a null pointer.
accept :: BoundedPort -> IO SocketIO
accept port = c_accept (lis port) >>= ensureConnected . SocketIO
  where
    ensureConnected conn
      | sock conn == nullPtr = throwErrno "could not accept connection"
      | otherwise = return conn
-- | Look up the remote address and port of a socket.
--
-- @c_getPeer@ receives three out-parameters:
--
-- * a slot for an IPv4 address;
-- * a 16-element buffer for an IPv6 address;
-- * a type flag, where 1 means the IPv6 buffer was filled and anything
--   else means the IPv4 slot was.
--
-- Its return value is the peer port.
--
-- NOTE(review): this assumes @c_getPeer@ signals failure by returning -1,
-- the usual C convention (as for getpeername), with @errno@ set. The
-- previous check compared against 1. That reported a peer on the valid
-- port 1 as an error, and let real failures through with an address read
-- from uninitialised buffers. Confirm against the C implementation.
--
-- NOTE(review): if 'sock' is a record selector of the 'SocketIO'
-- constructor only, calling this on a 'TlsSocketIO' fails at runtime.
-- Verify before using it on TLS connections.
getPeer :: SocketIO -> IO (IP.IP, Int)
getPeer s =
  allocaArray 16 $ \p6 ->
  alloca $ \p4 ->
  alloca $ \iptype -> do
    p <- c_getPeer (sock s) p4 p6 iptype
    -- Only the C error sentinel means failure; any other value is a port.
    when (p == -1) $ throwErrno "could not get peer address"
    iptp <- peek iptype
    if iptp == 1
      then do -- IPv6: 16 raw address bytes
        add <- peekArray 16 p6
        return (IP.IPv6 . IP.toIPv6b $ map fromIntegral add, fromIntegral p)
      else do -- IPv4
        add <- peek p4
        return (IP.IPv4 . IP.fromHostAddress . fromIntegral $ add, fromIntegral p)