{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE InterruptibleFFI #-}
{-# LANGUAGE EmptyDataDecls #-}

module System.IO.Uniform.Targets
  ( TlsSettings(..)
  , UniformIO(..)
  , SocketIO
  , FileIO
  , StdIO
  , TlsStream
  , BoundedPort
  , SomeIO(..)
  , connectTo
  , connectToHost
  , bindPort
  , accept
  , openFile
  , getPeer
  , closePort
  ) where

import Foreign
import Foreign.C.Types
import Foreign.C.String
import Foreign.C.Error
import qualified Data.IP as IP
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.List as L
import Control.Exception
import Control.Applicative ((<$>))
import qualified Network.Socket as Soc
import System.IO.Error
import Data.Default.Class
import System.Posix.Types (Fd(..))

-- | Settings for starttls functions.
--
-- Each field is a file name that is handed, as a C string, to the TLS layer
-- when 'startTls' is called on a socket.
data TlsSettings = TlsSettings
  { tlsPrivateKeyFile :: String
    -- ^ File holding the private key.
  , tlsCertificateChainFile :: String
    -- ^ File holding the certificate chain.
  , tlsDHParametersFile :: String
    -- ^ File holding the Diffie-Hellman parameters.
  } deriving (Read, Show)

-- | All three file names default to the empty string.
instance Default TlsSettings where
  def = TlsSettings "" "" ""

-- | Typeclass for uniform IO targets.
class UniformIO a where
  -- | @uRead fd n@
  --
  -- Reads a block of at most @n@ bytes of data from the IO target.
  -- Reading will block if there's no data available, but will return
  -- immediately if any amount of data is available.
  uRead :: a -> Int -> IO ByteString
  -- | @uPut fd text@
  --
  -- Writes all the bytes of @text@ into the IO target. Takes care of
  -- retrying if needed.
  uPut :: a -> ByteString -> IO ()
  -- | @uClose fd@
  --
  -- Closes the IO target, releasing any allocated resource. Resources may
  -- leak if this is not called for every opened target.
  uClose :: a -> IO ()
  -- | @startTls settings fd@
  --
  -- Starts a TLS connection over the IO target. May throw an 'IOError' if
  -- TLS cannot be started on it.
  startTls :: TlsSettings -> a -> IO TlsStream
  -- | @isSecure fd@
  --
  -- Indicates whether the data written to or read from @fd@ is secure at
  -- the transport level.
  isSecure :: a -> Bool

-- | A type that wraps any type in the UniformIO class.
data SomeIO = forall a. (UniformIO a) => SomeIO a

-- | Every operation is forwarded to the wrapped target.
instance UniformIO SomeIO where
  uRead (SomeIO s) n = uRead s n
  uPut (SomeIO s) t = uPut s t
  uClose (SomeIO s) = uClose s
  startTls set (SomeIO s) = startTls set s
  isSecure (SomeIO s) = isSecure s

-- | Opaque C-side listener handle (produced by @getPort@, released by
-- @closeHandler@).
data Nethandler

-- | A bound IP port from which SocketIO connections are accepted.
newtype BoundedPort = BoundedPort {lis :: (Ptr Nethandler)}

-- | Opaque C-side data stream, read and written through @recvDs@/@sendDs@.
data Ds

newtype SocketIO = SocketIO {sock :: (Ptr Ds)}

newtype FileIO = FileIO {fd :: (Ptr Ds)}

-- | Opaque C-side TLS stream.
data TlsDs

newtype TlsStream = TlsStream {tls :: (Ptr TlsDs)}

-- | Marker type for standard IO. It has no constructors, and its instance
-- never inspects the value. (Presumably the process's stdin/stdout — confirm
-- against the C @stdDsRecv@/@stdDsSend@.)
data StdIO

-- | UniformIO IP connections.
instance UniformIO SocketIO where
  uRead s n =
    allocaArray n
      ( \b -> do
          count <- c_recv (sock s) b (fromIntegral n)
          if count < 0
            then throwErrno "could not read"
            else BS.packCStringLen (b, fromIntegral count)
      )
  -- NOTE(review): a single sendDs call is made; any retry on partial
  -- writes must happen inside the C code — confirm.
  uPut s t =
    BS.useAsCStringLen t
      ( \(str, n) -> do
          count <- c_send (sock s) str $ fromIntegral n
          if count < 0
            then throwErrno "could not write"
            else return ()
      )
  -- The C side hands back the raw descriptor, which is then closed here.
  uClose s = do
    f <- Fd <$> c_prepareToClose (sock s)
    closeFd f
  startTls st s =
    withCString (tlsCertificateChainFile st) $ \cert ->
      withCString (tlsPrivateKeyFile st) $ \key ->
        withCString (tlsDHParametersFile st) $ \para -> do
          r <- c_startSockTls (sock s) cert key para
          if r == nullPtr
            then throwErrno "could not start TLS"
            else return . TlsStream $ r
  isSecure _ = False

-- | UniformIO target for standard input and output.
instance UniformIO StdIO where
  uRead _ n =
    allocaArray n
      ( \b -> do
          count <- c_recvStd b (fromIntegral n)
          if count < 0
            then throwErrno "could not read"
            else BS.packCStringLen (b, fromIntegral count)
      )
  uPut _ t =
    BS.useAsCStringLen t
      ( \(str, n) -> do
          count <- c_sendStd str $ fromIntegral n
          if count < 0
            then throwErrno "could not write"
            else return ()
      )
  uClose _ = return ()
  -- TLS over standard IO is not supported. Fail loudly rather than return a
  -- TlsStream over a null pointer, which would crash on first use and would
  -- wrongly report isSecure = True.
  startTls _ _ =
    throwIO $
      mkIOError
        illegalOperationErrorType
        "TLS is not supported on standard IO"
        Nothing
        Nothing
  isSecure _ = False

-- | UniformIO type for file IO.
instance UniformIO FileIO where
  -- Files share the generic C data-stream calls used for sockets.
  uRead s n = allocaArray n $ \b -> do
    count <- c_recv (fd s) b $ fromIntegral n
    if count < 0
      then throwErrno "could not read"
      else BS.packCStringLen (b, fromIntegral count)
  uPut s t = BS.useAsCStringLen t $ \(str, n) -> do
    count <- c_send (fd s) str $ fromIntegral n
    if count < 0
      then throwErrno "could not write"
      else return ()
  uClose s = do
    f <- Fd <$> c_prepareToClose (fd s)
    closeFd f
  -- TLS over files is not implemented yet. Fail loudly rather than return a
  -- TlsStream over a null pointer, which would crash on first use and would
  -- wrongly report isSecure = True.
  startTls _ _ =
    throwIO $
      mkIOError
        illegalOperationErrorType
        "TLS is not supported on files"
        Nothing
        Nothing
  isSecure _ = False

-- | UniformIO wrapper that applies TLS to communication on an IO target.
-- This type is constructed by calling startTls on other targets.
instance UniformIO TlsStream where
  uRead s n = allocaArray n $ \b -> do
    count <- c_recvTls (tls s) b $ fromIntegral n
    if count < 0
      then throwErrno "could not read"
      else BS.packCStringLen (b, fromIntegral count)
  uPut s t = BS.useAsCStringLen t $ \(str, n) -> do
    count <- c_sendTls (tls s) str $ fromIntegral n
    if count < 0
      then throwErrno "could not write"
      else return ()
  -- Tear down the TLS layer first. closeTls hands back the underlying data
  -- stream, whose descriptor is then closed the same way a plain socket is.
  uClose s = do
    d <- c_closeTls (tls s)
    f <- Fd <$> c_prepareToClose d
    closeFd f
  -- Already encrypted: starting TLS again is a no-op.
  startTls _ s = return s
  isSecure _ = True

-- | @connectToHost hostName port@
--
-- Resolves @hostName@ and connects to the given port on the first address
-- the resolver returns. Throws a "does not exist" 'IOError' when no address
-- is found or the first one is neither IPv4 nor IPv6.
connectToHost :: String -> Int -> IO SocketIO
connectToHost host port = getAddr >>= \ip -> connectTo ip port
  where
    getAddr :: IO IP.IP
    getAddr = do
      addrs <- Soc.getAddrInfo Nothing (Just host) Nothing
      case addrs of
        [] -> hostNotFound
        (a:_) -> case Soc.addrAddress a of
          Soc.SockAddrInet _ a' -> return . IP.IPv4 . IP.fromHostAddress $ a'
          Soc.SockAddrInet6 _ _ a' _ -> return . IP.IPv6 . IP.fromHostAddress6 $ a'
          _ -> hostNotFound

    hostNotFound :: IO b
    hostNotFound =
      throwIO $ mkIOError doesNotExistErrorType "host not found" Nothing Nothing

-- | @connectTo ipAddress port@
--
-- Connects to the given port of the host at the given IP address.
connectTo :: IP.IP -> Int -> IO SocketIO
connectTo host port =
  SocketIO <$> throwErrnoIfNull "could not connect to host" (dial host)
  where
    cPort :: CInt
    cPort = fromIntegral port

    -- Pick the C entry point matching the address family.
    dial :: IP.IP -> IO (Ptr Ds)
    dial (IP.IPv4 v4) = c_connect4 (fromIntegral (IP.toHostAddress v4)) cPort
    dial (IP.IPv6 v6) = withArray (v6Bytes v6) (\addr -> c_connect6 addr cPort)

    -- The 16 address bytes, each 32-bit word emitted most significant byte
    -- first.
    v6Bytes :: IP.IPv6 -> [CUChar]
    v6Bytes ip =
      let (w0, w1, w2, w3) = IP.toHostAddress6 ip
       in L.concatMap wordBytes [w0, w1, w2, w3]

    wordBytes :: Word32 -> [CUChar]
    wordBytes w = [fromIntegral (w `shiftR` k) | k <- [24, 16, 8, 0]]

-- | @bindPort port@
--
-- Binds to the given IP port, becoming ready to accept connections on it.
-- Binding to port numbers under 1024 will fail unless performed by the
-- superuser; once bound, a process can reduce its privileges and still
-- accept clients on that port.
bindPort :: Int -> IO BoundedPort
bindPort port =
  BoundedPort <$> throwErrnoIfNull "could not bind to port" (c_getPort (fromIntegral port))

-- | @accept port@
--
-- Accepts a client on a port previously bound with 'bindPort'.
accept :: BoundedPort -> IO SocketIO
accept port =
  SocketIO <$> throwErrnoIfNull "could not accept connection" (c_accept (lis port))

-- | Opens a file for bidirectional IO.
openFile :: String -> IO FileIO
openFile fileName =
  FileIO <$> throwErrnoIfNull "could not open file" (withCString fileName c_createFile)

-- | Gets the address (IP and port) of the peer of an internet connection.
getPeer :: SocketIO -> IO (IP.IP, Int)
getPeer s =
  allocaArray 16 $ \p6 ->
    alloca $ \p4 ->
      alloca $ \iptype -> do
        -- The C call returns the peer port, or -1 on failure. It fills in
        -- the family flag and exactly one of the two address buffers.
        port <- throwErrnoIfMinus1 "could not get peer address" (c_getPeer (sock s) p4 p6 iptype)
        family <- peek iptype
        addr <- case family of
          1 -> IP.IPv6 . IP.toIPv6b . map fromIntegral <$> peekArray 16 p6
          _ -> IP.IPv4 . IP.fromHostAddress . fromIntegral <$> peek p4
        return (addr, fromIntegral port)

-- | Closes a raw descriptor through the C helper.
closeFd :: Fd -> IO ()
closeFd (Fd raw) = c_closeFd raw

-- | Closes a BoundedPort, and releases any resource used by it.
closePort :: BoundedPort -> IO ()
closePort = c_closePort . lis

-- Opening listeners and streams.
foreign import ccall interruptible "getPort" c_getPort :: CInt -> IO (Ptr Nethandler)
foreign import ccall interruptible "createFromHandler" c_accept :: Ptr Nethandler -> IO (Ptr Ds)
foreign import ccall safe "createFromFileName" c_createFile :: CString -> IO (Ptr Ds)
foreign import ccall interruptible "createToIPv4Host" c_connect4 :: CUInt -> CInt -> IO (Ptr Ds)
foreign import ccall interruptible "createToIPv6Host" c_connect6 :: Ptr CUChar -> CInt -> IO (Ptr Ds)
foreign import ccall interruptible "startSockTls" c_startSockTls :: Ptr Ds -> CString -> CString -> CString -> IO (Ptr TlsDs)

-- Introspection.
foreign import ccall safe "getPeer" c_getPeer :: Ptr Ds -> Ptr CUInt -> Ptr CUChar -> Ptr CInt -> IO CInt
--foreign import ccall safe "getFd" c_getFd :: Ptr Ds -> IO CInt
--foreign import ccall safe "getTlsFd" c_getTlsFd :: Ptr TlsDs -> IO CInt

-- Teardown.
foreign import ccall safe "closeFd" c_closeFd :: CInt -> IO ()
foreign import ccall safe "prepareToClose" c_prepareToClose :: Ptr Ds -> IO CInt
foreign import ccall safe "closeHandler" c_closePort :: Ptr Nethandler -> IO ()
foreign import ccall safe "closeTls" c_closeTls :: Ptr TlsDs -> IO (Ptr Ds)

-- Data transfer.
foreign import ccall interruptible "sendDs" c_send :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "stdDsSend" c_sendStd :: Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "tlsDsSend" c_sendTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "recvDs" c_recv :: Ptr Ds -> Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "stdDsRecv" c_recvStd :: Ptr CChar -> CInt -> IO CInt
foreign import ccall interruptible "tlsDsRecv" c_recvTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt