module Network.Fancy
    (
     HostName, Address(..),
     
     withStream, connectStream,
     
     connectDgram, withDgram, StringLike, recv,send, closeSocket,
     
     ServerSpec(..), serverSpec, 
     Threading(..), Reverse(..),
     streamServer, dgramServer, sleepForever,
     
     getCurrentHost,
     Socket
    ) where
import Control.Concurrent
import Control.Exception as E(bracket, finally, onException, try, SomeException)
import Control.Monad(when, forM)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Internal as B
import qualified Data.ByteString.Unsafe as B
import qualified Data.ByteString.Lazy  as L
import Data.List(intercalate)
import Data.Typeable(Typeable)
import Foreign
import Foreign.C
import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr)
import Numeric(showHex)
import System.IO(Handle, hClose, IOMode(ReadWriteMode))
import System.IO.Unsafe(unsafeInterleaveIO)
import System.Posix.Internals hiding(c_close)
import GHC.IO.Device
import GHC.IO.Handle.FD(fdToHandle')
-- | Switch a file descriptor into non-blocking mode, so that waits happen
-- through 'threadWaitRead' / 'threadWaitWrite' instead of in the syscall.
setNonBlockingFD' :: FD -> IO ()
setNonBlockingFD' fd = System.Posix.Internals.setNonBlockingFD fd True
-- | A host name or numeric address, as accepted by getaddrinfo.
type HostName = String

-- | A network endpoint.
data Address
  = IP   HostName Int  -- ^ host and port; address family left unspecified
  | IPv4 HostName Int  -- ^ host and port, resolved as IPv4 only
  | IPv6 HostName Int  -- ^ host and port, resolved as IPv6 only
  | Unix FilePath      -- ^ filesystem path of a Unix-domain socket
  deriving (Eq, Ord, Show, Typeable)
-- | Payload types that can be turned into, and rebuilt from, a strict
-- 'B.ByteString' for transmission over a socket.
class StringLike string where
    toBS   :: string -> B.ByteString  -- ^ encode a payload for sending
    fromBS :: B.ByteString -> string  -- ^ rebuild a payload after receiving

-- | Uses the "Data.ByteString.Char8" interface: one byte per 'Char'.
instance StringLike String where
    toBS str  = B.pack str
    fromBS bs = B.unpack bs

-- | Lazy strings are flattened to a single strict chunk and back.
instance StringLike L.ByteString where
    toBS lazy     = B.concat (L.toChunks lazy)
    fromBS strict = L.fromChunks (pure strict)

-- | Strict byte strings pass through untouched.
instance StringLike B.ByteString where
    toBS bs   = bs
    fromBS bs = bs
-- | Send one message on a connected socket.  If the call would block, the
-- calling thread waits for writability and retries.  Throws if the kernel
-- accepts fewer bytes than the message contains.
send :: StringLike string => Socket -> string -> IO ()
send (Socket s) msg =
  B.unsafeUseAsCStringLen (toBS msg) $ \(buf, n) -> do
    sent <- throwErrnoIfMinus1RetryMayBlock "send"
              (c_send s (castPtr buf) (fromIntegral n) 0)
              (threadWaitWrite (fromIntegral s))
    when (sent /= fromIntegral n) $ fail "send: partial packet sent!"
-- | Receive at most @len@ bytes from a socket.  If the call would block, the
-- calling thread waits for readability and retries.  The result is trimmed
-- to the number of bytes actually read.
recv :: StringLike string => Socket -> Int -> IO string
recv (Socket s) len = fmap fromBS (B.createAndTrim len fill)
  where
    fill buf = fmap fromIntegral $
      throwErrnoIfMinus1RetryMayBlock "recv"
        (c_recv s (castPtr buf) (fromIntegral len) 0)
        (threadWaitRead (fromIntegral s))
-- | Receive one datagram of at most @buflen@ bytes, together with the
-- sender's address.  The 'SocketAddress' argument only supplies the size of
-- the address buffer.  The sender is returned in a fresh 'SocketAddress' of
-- that same size.
recvFrom :: StringLike string => Socket -> Int -> SocketAddress -> IO (string,SocketAddress)
recvFrom (Socket s) buflen (SA _ salen) = do
  sa  <- mallocForeignPtrBytes salen
  str <- withForeignPtr sa $ \sa_ptr -> do
    -- mallocForeignPtrBytes leaves memory uninitialised.  Zero the buffer so
    -- that any bytes the kernel does not write decode deterministically
    -- instead of as heap garbage.  NOTE(review): e.g. for an unbound
    -- Unix-domain sender.
    _ <- B.memset (castPtr sa_ptr) 0 (fromIntegral salen)
    B.createAndTrim buflen $ \ptr ->
      with (fromIntegral salen) $ \salen_ptr ->
        fmap fromIntegral $ throwErrnoIfMinus1RetryMayBlock "recvfrom"
          (c_recvfrom s ptr (fromIntegral buflen) 0 sa_ptr salen_ptr)
          (threadWaitRead (fromIntegral s))
  return (fromBS str, SA sa salen)
-- | Send one datagram to the given destination address.  If the call would
-- block, the calling thread waits for writability and retries.  Throws if
-- the kernel accepts fewer bytes than the datagram contains.
sendTo :: StringLike string => SocketAddress -> Socket -> string -> IO ()
sendTo (SA sa salen) (Socket s) str =
  withForeignPtr sa $ \dest ->
    B.unsafeUseAsCStringLen (toBS str) $ \(buf, n) -> do
      sent <- throwErrnoIfMinus1RetryMayBlock "sendTo"
                (c_sendto s (castPtr buf) (fromIntegral n) 0 dest (fromIntegral salen))
                (threadWaitWrite (fromIntegral s))
      when (sent /= fromIntegral n) $ fail "sendTo: partial packet sent!"
-- Data-transfer syscalls used by send, recv, sendTo and recvFrom.
-- CALLCONV and SAFE_ON_WIN are CPP macros defined outside this view.
-- NOTE(review): Int64 stands in for C ssize_t, which assumes a 64-bit
-- target; confirm.
foreign import CALLCONV SAFE_ON_WIN "recv"
  c_recv :: CInt -> Ptr Word8 -> CSize -> CInt -> IO Int64
foreign import CALLCONV SAFE_ON_WIN "send"
  c_send :: CInt -> Ptr Word8 -> CSize -> CInt -> IO Int64
foreign import CALLCONV SAFE_ON_WIN "recvfrom"
  c_recvfrom :: CInt -> Ptr Word8 -> CSize -> CInt -> Ptr () -> Ptr SLen -> IO Int64
foreign import CALLCONV SAFE_ON_WIN "sendto"
  c_sendto :: CInt -> Ptr Word8 -> CSize -> CInt -> Ptr () -> SLen -> IO Int64
-- | Close a socket's descriptor, throwing an IO error if close reports -1.
closeSocket :: Socket -> IO ()
closeSocket (Socket fd) = throwErrnoIfMinus1_ "close" (c_close fd)
-- Socket setup syscalls.  Callers treat a -1 result as failure and read
-- errno (through throwErrnoIfMinus1 and friends, or getErrno in connect).
-- NOTE(review): socket(2) takes C ints, but CFamily and CType are Haskell
-- Int; this works on common 64-bit ABIs.  Confirm before porting.
foreign import CALLCONV unsafe "socket"
  c_socket :: CFamily -> CType -> CInt -> IO CInt
foreign import CALLCONV unsafe "bind"
  c_bind :: CInt -> Ptr () -> SLen -> IO CInt
foreign import CALLCONV unsafe "listen"
  c_listen :: CInt -> CInt -> IO CInt
foreign import CALLCONV SAFE_ON_WIN "connect"
  c_connect :: CInt -> Ptr () -> SLen -> IO CInt
-- Local binding for close; the c_close exported by System.Posix.Internals
-- is hidden at import so that this one is used.
foreign import CALLCONV unsafe "close"
  c_close :: CInt -> IO CInt

-- | An open socket, represented by its raw file descriptor.
newtype Socket
  = Socket CInt
-- | Connect a datagram socket to the address, run the action with it, and
-- close the socket afterwards, even if the action throws.
withDgram :: Address -> (Socket -> IO a) -> IO a
withDgram addr action = bracket (connectDgram addr) closeSocket action
-- | Open a stream connection to the address, run the action on its 'Handle',
-- and close the handle afterwards, even if the action throws.
withStream :: Address -> (Handle -> IO a) -> IO a
withStream addr action = bracket (connectStream addr) hClose action
-- | Resolve the address and connect a stream socket to the first candidate
-- that accepts.  The socket is returned wrapped in a 'Handle'.
connectStream :: Address -> IO Handle
connectStream addr = do
  candidates <- a2sas sockStream aiNumericserv addr
  sock       <- csas (connect sockStream) candidates
  socketToHandle sock
-- | Resolve the address and connect a datagram socket to the first
-- candidate that succeeds.
connectDgram  :: Address -> IO Socket
connectDgram addr = do
  candidates <- a2sas sockDgram aiNumericserv addr
  csas (connect sockDgram) candidates
-- | Wrap a socket descriptor in a read/write 'Handle' of device type
-- 'GHC.IO.Device.Stream', named after the descriptor number.
-- NOTE(review): the two True flags are presumably is_socket and binary
-- mode; see GHC.IO.Handle.FD.fdToHandle'.
socketToHandle :: Socket -> IO Handle
socketToHandle (Socket fd) =
  let name = show fd
  in fdToHandle' (fromIntegral fd) (Just GHC.IO.Device.Stream) True name ReadWriteMode True
-- | Create a socket of the given type for the address's family and connect
-- it.  The socket is made non-blocking first.  An in-progress connect is
-- awaited with 'threadWaitWrite', and its outcome is read back with
-- 'getsockopt_error'.  On any failure the descriptor is closed before the
-- exception propagates, so 'csas' can move on to the next candidate
-- address without leaking descriptors.  Errors are IO errors carrying
-- the errno.
connect :: CType -> SocketAddress -> IO Socket
connect stype (SA sa len) = do
  fam <- getFamily (SA sa len)
  s   <- throwErrnoIfMinus1 "socket" $ c_socket fam stype 0
  establish s `onException` c_close s
  where
    establish s = do
      setNonBlockingFD' s
      r <- withForeignPtr sa $ \ptr -> c_connect s ptr (fromIntegral len)
      if r /= -1
        then return (Socket s)
        else do
          err <- getErrno
          -- connect signals failure with -1.  EINPROGRESS is the normal
          -- non-blocking case.  After EINTR, POSIX says the attempt carries
          -- on asynchronously (calling connect again would only report
          -- EALREADY), so both cases are resolved by waiting for
          -- writability and then reading the pending socket error.
          if err == eINPROGRESS || err == eINTR
            then do threadWaitWrite (fromIntegral s)
                    soe <- getsockopt_error s
                    if soe == 0 then return (Socket s) else connectError (Errno soe)
            else connectError err
    connectError :: Errno -> IO a
    connectError e = ioError (errnoToIOError "connect" e Nothing Nothing)
-- C helper defined outside this file.  connect treats its result as the
-- pending socket error (0 means connected).  NOTE(review): presumably wraps
-- getsockopt with SOL_SOCKET / SO_ERROR; confirm.
foreign import ccall unsafe getsockopt_error :: CInt -> IO CInt
-- | Read the address-family field of a socket address: a 16-bit value at
-- byte offset 0, widened to 'CFamily'.
getFamily :: SocketAddress -> IO CFamily
getFamily (SA sa _) =
  withForeignPtr sa $ \p -> do
    fam <- peekByteOff p 0 :: IO Word16
    return (fromIntegral fam)
-- | Try an action on each candidate address in order and return the first
-- success.  Failures on all but the last candidate are discarded.  The last
-- candidate's exception propagates.  An empty list fails with
-- "No such host".
csas :: (SocketAddress -> IO a) -> [SocketAddress] -> IO a
csas _       []         = fail "No such host"
csas attempt (sa : rest)
  | null rest = attempt sa
  | otherwise = try' (attempt sa) >>= either (\_ -> csas attempt rest) return
-- | 'E.try' specialised to catch every exception.
try' :: IO a -> IO (Either SomeException a)
try' action = E.try action
-- | Wrapper around resolver calls.  It currently just runs the action.
-- NOTE(review): presumably a hook for serialising non-reentrant resolvers.
withResolverLock :: IO a -> IO a
withResolverLock action = action
-- | A raw C socket address: a buffer owned by the GC plus its length in bytes.
data SocketAddress
  = SA !(ForeignPtr ())  -- ^ sockaddr bytes
       !Int              -- ^ buffer length in bytes
  deriving (Show)
-- | Opaque byte type used for pointers to struct addrinfo.
type AddrInfoT = Word8
-- | Address family and socket type, kept as Haskell Ints.
type CFamily   = Int
type CType     = Int

-- Address families.  NOTE(review): these are the Linux values (AF_INET6 is
-- 10 there); confirm for other platforms.
afUnspec, afLocal, afInet, afInet6 :: CFamily
afUnspec = 0
afLocal  = 1
afInet   = 2
afInet6  = 10

-- Socket types.
sockStream, sockDgram :: CType
sockStream = 1
sockDgram  = 2
-- | Resolve an 'Address' to candidate socket addresses for socket type @t@,
-- using getaddrinfo flags @f@.  'IP' leaves the family unspecified, while
-- 'IPv4' and 'IPv6' restrict it.  A 'Unix' path is encoded directly into a
-- sockaddr_un buffer of 'salLocal' bytes without any lookup.  Paths too
-- long for that buffer are rejected.
a2sas :: CType -> CInt -> Address -> IO [SocketAddress]
a2sas t f (IP   hn p) = getAddrInfo hn (show p) f afUnspec t
a2sas t f (IPv4 hn p) = getAddrInfo hn (show p) f afInet   t
a2sas t f (IPv6 hn p) = getAddrInfo hn (show p) f afInet6  t
a2sas _ _ (Unix fp)   = do
  -- The path starts at byte 2 of the structure and needs room for its
  -- terminating NUL, so at most salLocal - 2 - 1 characters fit.
  let pathOffset = 2
      maxSize    = salLocal - pathOffset
  when (length fp >= maxSize) $ fail "Too long address for Unix socket"
  sa <- mallocForeignPtrBytes salLocal
  withForeignPtr sa $ \sa_ptr -> do
    -- mallocForeignPtrBytes does not initialise memory, and all salLocal
    -- bytes are handed to the kernel as the address, so zero them first.
    _ <- B.memset (castPtr sa_ptr) 0 (fromIntegral salLocal)
    -- The family field is 16 bits wide (getFamily reads it as Word16).
    -- Poking the Int-typed afLocal directly would write a full machine Int.
    pokeByteOff sa_ptr 0 (fromIntegral afLocal :: Word16)
    pokeArray0 0 (sa_ptr `plusPtr` pathOffset) (map tw fp)
  return [SA sa salLocal]
  where
    -- One byte per Char.  NOTE(review): code points above 255 make toEnum
    -- fail at runtime.
    tw :: Char -> Word8
    tw = toEnum . fromEnum
-- | Size in bytes of the Unix-domain socket address buffer (sockaddr_un).
salLocal :: Int
salLocal = 110

-- | getaddrinfo hint flags (AI_PASSIVE, AI_NUMERICSERV).
aiPassive, aiNumericserv :: CInt
aiPassive     = 1
aiNumericserv = 1024
-- | Call getaddrinfo and copy every returned address into memory we own.
-- Each result therefore stays valid after freeaddrinfo.  An empty host or
-- service string is passed as NULL.  A non-zero getaddrinfo status is
-- raised as an IO error carrying gai_strerror's message.
--
-- The struct addrinfo offsets used below are hard-coded constants as
-- generated for the build platform.  NOTE(review): they match the Linux
-- layout; confirm elsewhere.
getAddrInfo :: String    -- ^ host name or numeric address ("" for none)
            -> String    -- ^ service / port ("" for none)
            -> CInt      -- ^ ai_flags hint
            -> CFamily   -- ^ ai_family hint
            -> CType     -- ^ ai_socktype hint
            -> IO [SocketAddress]
getAddrInfo host serv flags fam typ =
    withResolverLock (bracket getAI c_freeaddrinfo unai)
  where
    -- Walk the ai_next chain, copying each ai_addr into a fresh buffer.
    unai :: Ptr AddrInfoT -> IO [SocketAddress]
    unai ai
      | ai == nullPtr = return []
      | otherwise     = do
          sal' <- peekByteOff ai 16 :: IO SLen   -- ai_addrlen
          src  <- peekByteOff ai 24              -- ai_addr
          let sal = fromIntegral sal'
          sa <- mallocForeignPtrBytes sal
          withForeignPtr sa $ \dst -> copyBytes dst src sal
          next <- peekByteOff ai 40              -- ai_next
          rest <- unai next
          return (SA sa sal : rest)

    -- Build a zeroed hints structure and run getaddrinfo.
    getAI :: IO (Ptr AddrInfoT)
    getAI = allocaBytes 48 $ \hints -> do
      _ <- B.memset hints 0 48
      -- ai_flags, ai_family and ai_socktype are C ints.  CFamily and CType
      -- are Haskell Ints, so narrow them first; poking them directly would
      -- write a full machine Int and overrun the next field.
      pokeByteOff hints 0 flags
      pokeByteOff hints 4 (fromIntegral fam :: CInt)
      pokeByteOff hints 8 (fromIntegral typ :: CInt)
      withStr host $ \host_buf ->
        withStr serv $ \serv_buf ->
          with nullPtr $ \result -> do
            throwGAIErrorIf $ c_getaddrinfo host_buf serv_buf hints result
            peek result

    -- Marshal "" as NULL, as getaddrinfo expects for an absent argument.
    withStr :: String -> (CString -> IO a) -> IO a
    withStr "" fun = fun nullPtr
    withStr s  fun = withCString s fun
-- Resolver entry points used by getAddrInfo.  getaddrinfo is imported as a
-- "safe" call.  NOTE(review): presumably because it may block on network
-- lookups.
foreign import CALLCONV unsafe "freeaddrinfo" c_freeaddrinfo
  :: Ptr AddrInfoT -> IO ()
foreign import CALLCONV safe "getaddrinfo" c_getaddrinfo
  :: Ptr CChar -> Ptr CChar -> Ptr AddrInfoT -> Ptr (Ptr AddrInfoT) -> IO CInt
-- | Run an action that returns a getaddrinfo-style status code.  A zero code
-- means success; any other code is raised via 'fail' with the message from
-- 'gaiError'.
throwGAIErrorIf :: IO CInt -> IO ()
throwGAIErrorIf comp = comp >>= check
  where
    check 0    = return ()
    check code = gaiError code >>= fail
-- | Human-readable text for a getaddrinfo / getnameinfo status code, as
-- returned by gai_strerror.
gaiError :: CInt -> IO String
gaiError code = do
  msg <- c_gai_strerror code
  peekCString msg
foreign import CALLCONV unsafe "gai_strerror"
  c_gai_strerror :: CInt -> IO (Ptr CChar)
-- | Return the local host name as reported by gethostname.  Throws an IO
-- error if gethostname fails.
getCurrentHost :: IO HostName
getCurrentHost =
  allocaArray bufSize $ \buffer -> do
    -- POSIX leaves it unspecified whether a truncated name is
    -- NUL-terminated.  Hide the last byte from gethostname and terminate
    -- the buffer ourselves, so peekCString never reads past it.
    throwErrnoIfMinus1_ "gethostname" $
      c_gethostname buffer (fromIntegral (bufSize - 1))
    pokeElemOff buffer (bufSize - 1) 0
    peekCString buffer
  where
    bufSize :: Int
    bufSize = 256
foreign import CALLCONV unsafe "gethostname" c_gethostname :: Ptr CChar -> CSize -> IO CInt
-- | Where a server runs its handler for each connection or packet.
data Threading
  = Threaded  -- ^ fork a new thread for each connection / packet
  | Inline    -- ^ run the handler on the accept / receive loop thread

-- | How peer addresses are presented to server handlers.
data Reverse
  = ReverseNumeric  -- ^ numeric form, no name lookup (see rnumeric)
  | ReverseName     -- ^ host names via getnameinfo (see rname)

-- | Configuration for 'streamServer' and 'dgramServer'.
data ServerSpec = ServerSpec
  { address         :: Address    -- ^ address to bind
  , reverseAddress  :: Reverse    -- ^ how the peer address is rendered
  , threading       :: Threading  -- ^ where handlers run
  , closeConnection :: Bool       -- ^ stream servers: close the Handle after the handler
  , recvSize        :: Int        -- ^ datagram servers: receive buffer size in bytes
  }
-- | Default server configuration:
--
-- * empty host (passed to getaddrinfo as NULL) and port 0
-- * numeric peer addresses
-- * a thread per request
-- * connections closed after the handler
-- * 4096-byte receive buffer
serverSpec :: ServerSpec
serverSpec = ServerSpec
  { recvSize        = 4096
  , closeConnection = True
  , threading       = Threaded
  , reverseAddress  = ReverseNumeric
  , address         = IP "" 0
  }
-- | Start a stream server on every address that 'address' resolves to.
-- Each listening socket gets its own accept-loop thread, and the returned
-- list holds those thread ids.  For every accepted connection the handler
-- receives a 'Handle' and the peer 'Address'.  The address is computed
-- lazily, according to 'reverseAddress'.
streamServer :: ServerSpec -> (Handle -> Address -> IO ()) -> IO [ThreadId]
streamServer ss sfun = do
  sas <- a2sas sockStream (aiNumericserv .|. aiPassive) (address ss)
  when (null sas) $ fail "No address for server!"
  forM sas $ \sa -> do
     fam  <- getFamily sa
     sock <- throwErrnoIfMinus1 "socket" $ c_socket fam sockStream 0
     let socket = Socket sock
         on :: CInt
         on = 1
         os = fromIntegral $ sizeOf on
         setup = do
           setNonBlockingFD' sock
           -- Level 1 / option 2: presumably SOL_SOCKET / SO_REUSEADDR on
           -- Linux.  The result is deliberately ignored (best effort).
           _ <- with on $ \onptr -> c_setsockopt sock 1 2 onptr os
           bind socket sa
           listen socket 128
     -- Close the fresh descriptor if any setup step throws (for example
     -- bind failing because the address is in use); otherwise it leaks.
     setup `onException` c_close sock
     let loop = do (s,psa) <- accept socket sa
                   a <- unsafeInterleaveIO $ case reverseAddress ss of
                          ReverseNumeric -> rnumeric psa
                          ReverseName    -> rname psa
                   ha <- socketToHandle s
                   sf ha a
                   loop
     forkIO loop
  where
    -- Run the user handler in a new thread or inline, per 'threading'.
    sf ha psa = case threading ss of
                  Threaded -> forkIO (clo ha $ sfun ha psa) >> return ()
                  Inline   -> clo ha $ sfun ha psa
    -- Optionally close the connection handle once the handler finishes.
    clo ha
      | closeConnection ss = \x -> x `E.finally` hClose ha
      | otherwise          = id
foreign import CALLCONV unsafe "setsockopt" c_setsockopt ::
  CInt -> CInt -> CInt -> Ptr a -> CInt -> IO CInt
-- | Bind a socket to a local address, throwing an IO error on failure.
bind :: Socket -> SocketAddress -> IO ()
bind (Socket sock) (SA sa len) =
  withForeignPtr sa $ \addr ->
    throwErrnoIfMinus1_ "bind" (c_bind sock addr (fromIntegral len))
-- | Mark a socket as listening with the given backlog length, throwing an
-- IO error on failure.
listen :: Socket -> Int -> IO ()
listen (Socket s) backlog =
  throwErrnoIfMinus1_ "listen" $ c_listen s (toEnum backlog)
-- | Accept one connection on a listening socket.  If no connection is
-- pending, the calling thread waits for readability.  The 'SocketAddress'
-- argument only supplies the size of the peer-address buffer.  The peer is
-- returned in a fresh 'SocketAddress' of that size, and the accepted
-- socket is made non-blocking.
accept :: Socket -> SocketAddress -> IO (Socket, SocketAddress)
accept (Socket lfd) (SA _ len) = do
  sa <- mallocForeignPtrBytes len
  s  <- withForeignPtr sa $ \sa_ptr -> do
          -- mallocForeignPtrBytes leaves memory uninitialised, and rnumeric
          -- / rname later read a Unix path from this buffer as a
          -- NUL-terminated string.  Zero it so unwritten bytes read back
          -- as NULs.  NOTE(review): the kernel may fill only part of the
          -- buffer, e.g. for an unnamed Unix-domain peer.
          _ <- B.memset (castPtr sa_ptr) 0 (fromIntegral len)
          with (fromIntegral len) $ \len_ptr ->
            throwErrnoIfMinus1RetryMayBlock "accept" (c_accept lfd sa_ptr len_ptr) (threadWaitRead (fromIntegral lfd))
  -- Do not leak the accepted descriptor if switching its mode fails.
  setNonBlockingFD' s `onException` c_close s
  return (Socket s,SA sa len)
foreign import CALLCONV SAFE_ON_WIN "accept"  c_accept  :: CInt -> Ptr () -> Ptr (SLen) -> IO CInt
-- | Start a datagram server on every address that 'address' resolves to.
-- Each received packet (at most 'recvSize' bytes) is passed to the handler
-- together with the sender's address, computed lazily per
-- 'reverseAddress'.  Every packet the handler returns is sent back to that
-- sender.  The result lists the receive-loop thread ids, one per bound
-- socket.
dgramServer  :: StringLike packet => ServerSpec
             -> (packet -> Address -> IO [packet])
             -> IO [ThreadId]
dgramServer ss sfun = do
  sas <- a2sas sockDgram (aiNumericserv .|. aiPassive) (address ss)
  when (null sas) $ fail "No address for server!"
  forM sas $ \sa -> do
     fam  <- getFamily sa
     sock <- throwErrnoIfMinus1 "socket" $ c_socket fam sockDgram 0
     let socket = Socket sock
         on :: CInt
         on = 1
         os = fromIntegral $ sizeOf on
         setup = do
           setNonBlockingFD' sock
           -- Level 1 / option 2: presumably SOL_SOCKET / SO_REUSEADDR on
           -- Linux.  The result is deliberately ignored (best effort).
           _ <- with on $ \onptr -> c_setsockopt sock 1 2 onptr os
           bind socket sa
     -- Close the fresh descriptor if setup throws (for example bind
     -- failing because the address is in use); otherwise it leaks.
     setup `onException` c_close sock
     let loop = do (str,psa) <- recvFrom socket (recvSize ss) sa
                   a <- unsafeInterleaveIO $ case reverseAddress ss of
                          ReverseNumeric -> rnumeric psa
                          ReverseName    -> rname psa
                   case threading ss of
                    Threaded -> forkIO (mapM_ (sendTo psa socket) =<< sfun str a) >> return ()
                    Inline   -> mapM_ (sendTo psa socket) =<< sfun str a
                   loop
     forkIO loop
-- | Convert a raw socket address into an 'Address'.
--
-- * 'rnumeric' performs no lookups.  IPv4 is rendered as a dotted quad,
--   IPv6 in RFC 5952 compressed hex form, and Unix sockets as their path.
-- * 'rname' reverse-resolves host names with getnameinfo.
--
-- Both fail for any other address family.
rnumeric, rname :: SocketAddress -> IO Address
rnumeric (SA sa len) = do
  f <- getFamily (SA sa len)
  withForeignPtr sa $ \sa_ptr ->
    case () of
      _ | f == afInet  -> do
            -- sockaddr_in: sin_port at offset 2, sin_addr at offset 4.
            n <- fmap v4fmt $ peekArray 4 (sa_ptr `plusPtr` 4)
            p <- ntohs =<< peekByteOff sa_ptr 2
            return $ IPv4 n (fromIntegral p)
        | f == afInet6 -> do
            -- sockaddr_in6: sin6_port at offset 2, sin6_flowinfo at 4,
            -- sin6_addr at offset 8.
            n <- fmap showIPv6 $ peekArray 16 (sa_ptr `plusPtr` 8)
            p <- ntohs =<< peekByteOff sa_ptr 2
            return $ IPv6 n (fromIntegral p)
        | f == afLocal -> fmap Unix (peekSunPath sa_ptr)
        | otherwise    -> fail "Unsupported address family!"
  where
    v4fmt :: [Word8] -> String
    v4fmt = intercalate "." . map show

    -- Read the Unix path at offset 2, stopping at the first NUL but never
    -- past the end of the len-byte buffer.
    peekSunPath :: Ptr () -> IO String
    peekSunPath base = do
      let path = base `plusPtr` 2 :: CString
      bytes <- peekArray (max 0 (len - 2)) (castPtr path) :: IO [Word8]
      peekCStringLen (path, length (takeWhile (/= 0) bytes))

    -- RFC 5952 text form of a 16-byte IPv6 address.  Groups are lower-case
    -- hex without leading zeros.  The longest run of two or more zero
    -- groups (the leftmost on a tie) is replaced by "::".
    showIPv6 :: [Word8] -> String
    showIPv6 bytes =
        case foldr pick Nothing (zeroRuns 0 groups) of
          Nothing         -> hexJoin groups
          Just (start, n) -> hexJoin (take start groups) ++ "::"
                             ++ hexJoin (drop (start + n) groups)
      where
        groups :: [Word16]
        groups = pairUp bytes
        pairUp (hi:lo:rest) = (fromIntegral hi `shiftL` 8 .|. fromIntegral lo) : pairUp rest
        pairUp _            = []
        hexJoin :: [Word16] -> String
        hexJoin = intercalate ":" . map (`showHex` "")
        -- (start index, length) of every maximal run of zero groups.
        zeroRuns :: Int -> [Word16] -> [(Int, Int)]
        zeroRuns _ []            = []
        zeroRuns i gs@(g : rest)
          | g == 0    = let k = length (takeWhile (== 0) gs)
                        in (i, k) : zeroRuns (i + k) (drop k gs)
          | otherwise = zeroRuns (i + 1) rest
        -- Folding from the right lets a run win ties against later runs.
        pick run@(_, k) best
          | k < 2     = best
          | otherwise = case best of
              Just (_, m) | m > k -> best
              _                   -> Just run
foreign import CALLCONV unsafe ntohs :: Word16 -> IO Word16
rname (SA sa len) = getFamily (SA sa len) >>= dispatch
  where
    dispatch fam
      | fam == afInet  = withForeignPtr sa (resolve IPv4)
      | fam == afInet6 = withForeignPtr sa (resolve IPv6)
      -- Unix paths involve no lookup; the numeric decoder handles them.
      | fam == afLocal = rnumeric (SA sa len)
      | otherwise      = fail "Unsupported address family!"
    -- Reverse-resolve the host with getnameinfo (256-byte buffer), then
    -- read the network-order port at offset 2.
    resolve mk ptr = do
      host <- allocaArray 256 $ \buf -> do
        throwGAIErrorIf $ getnameinfo ptr (fromIntegral len) buf 256 nullPtr 0 0
        peekCString buf
      port <- ntohs =<< peekByteOff ptr 2
      return (mk host (fromIntegral port))
-- | C socklen_t.  NOTE(review): assumed to be 32 bits; confirm per platform.
type SLen = Word32
foreign import CALLCONV safe getnameinfo
  :: Ptr () -> SLen -> Ptr CChar -> SLen -> Ptr CChar -> SLen -> CInt -> IO CInt
-- | Block the calling thread forever, sleeping in maxBound-microsecond
-- chunks.
sleepForever :: IO ()
sleepForever = do
  threadDelay maxBound
  sleepForever