{-# LANGUAGE CPP, ScopedTypeVariables #-}

-- | A channel module with transparent network communication.
module Control.CUtils.NetChan (NetSend, NetRecv, localHost, newNetChan, newNetSend, newNetRecv, send, receive, recv, recvSend, sendRecv, recvRecv, activateSend, activateRecv, Auth, authServer, authClient, example) where

-- This module has a strategy for routing around dead nodes. See 'routeAround'.

import System.IO
import System.Process
import Data.List (find, isPrefixOf, (\\))
import Network
import Network.Socket (socketToHandle, SockAddr(..))
import Network.BSD
import Control.Concurrent
import Control.Monad
import Data.ByteString.Lazy hiding (map, isPrefixOf, dropWhile, drop, head, split)
import qualified Data.ByteString as B
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put
import qualified Data.Map as M
import Data.Maybe
import Data.Char
import Data.IORef
import Data.Bits
import Control.Exception
import System.IO.Unsafe
import Crypto.Hash.SHA512
import Codec.Crypto.RSA.Pure
import Crypto.Random
import Reexport.Crypto.Random
import Data.SecureMem
import Foreign.Storable
import Data.Tagged
import Prelude hiding (lookup, length, splitAt, catch)

import Control.CUtils.Split

-- | Channel identifier bytes. Built by 'identifier' and tagged with a 4-byte
--   "main"/"back" prefix by 'modifyIdent'; the server reads a tagged ident as
--   the first 12 bytes of each new connection.
type Ident = ByteString

-- | Whether this process's listening 'server' has been started.
--
--   A process-wide global created with 'unsafePerformIO': the NOINLINE pragma
--   stops GHC from inlining it (which could create several 'MVar's), and the
--   explicit monomorphic signature guarantees there is exactly one shared value.
{-# NOINLINE serverup #-}
serverup :: MVar Bool
serverup = unsafePerformIO (newMVar False)

-- | Listeners for every channel end on this host, keyed by tagged 'Ident'.
--   The server hands each record arriving on a connection to the listener
--   registered for that connection's ident.
--
--   The dummy entry under the empty key presumably keeps 'newNetChan's serial
--   (the table size) from ever being 0, the serial 'newNetRecv' uses — confirm
--   before removing it.
{-# NOINLINE table #-}
table :: MVar (M.Map Ident (ByteString -> IO ()))
table = unsafePerformIO . newMVar $ M.fromList [(empty, const (return ()))]

-- | One outgoing TCP connection of a send end: a flag saying whether a delayed
--   flush is already scheduled for it (see '__send'), and the socket handle.
data ChannelFibre = ChannelFibre (MVar Bool) Handle

-- | Send end of a channel carrying values of type @t@: the target host name,
--   the tagged 'Ident', the upstream hosts learned over the back channel
--   (drained by 'routeAround'), and the list of live connections.
data NetSend t = NetSend HostName Ident (MVar [HostName]) (MVar [ChannelFibre])

-- | Receive end of a channel: the tagged 'Ident', the send end that forwards
--   received values to downstream receivers, the back-channel send end, and
--   the queue of raw received records read by 'receive'.
data NetRecv t = NetRecv Ident (NetSend t) (NetSend HostName) (Chan ByteString)

-- Fibres are told apart by their handle alone; the flush flag is ignored.
instance Eq ChannelFibre where
	(==) (ChannelFibre _ h1) (ChannelFibre _ h2) = h1 == h2

-- Two send ends are equal exactly when their tagged idents are.
instance Eq (NetSend t) where
	(==) (NetSend _ a _ _) (NetSend _ b _ _) = a == b

-- Two receive ends are equal exactly when their tagged idents are.
instance Eq (NetRecv t) where
	(==) (NetRecv a _ _ _) (NetRecv b _ _ _) = a == b

-- | TCP port every node listens on (see 'listenOn' in '__newNetRecv') and
--   connects to (see 'connectTo' in '__addConnection').
port = 2999

-- | Pack a dotted-quad IPv4 address into a 'Word32', first octet in the
--   lowest byte (so "1.2.3.4" becomes 0x04030201).
--
--   Fails with a descriptive 'error' unless the string is exactly four
--   decimal octets, each in the range 0-255. (Previously an out-of-range
--   octet silently overlapped its neighbour's bits.)
getIPAddress :: String -> Word32
getIPAddress ip =
	case map parseOctet (splitDots ip) of
		[Just n1, Just n2, Just n3, Just n4] ->
			shiftL n4 24 .|. shiftL n3 16 .|. shiftL n2 8 .|. n1
		_ -> error ("getIPAddress: not a dotted-quad IPv4 address: " ++ show ip)
	where
		-- Prelude names are qualified because this module also imports the
		-- lazy ByteString versions of them unqualified.
		splitDots s = case Prelude.break (== '.') s of
			(chunk, _ : rest) -> chunk : splitDots rest
			(chunk, []) -> [chunk]
		-- A non-empty run of decimal digits whose value fits in one byte.
		parseOctet digits@(_ : _)
			| Prelude.all isDigit digits, value <= 255 = Just (fromInteger value)
			where value = read digits :: Integer
		parseOctet _ = Nothing

-- | Hack - just gets the local IP address, as reported by Windows' @ipconfig@.
--
--   Takes the first output line starting with "   IPv4" and returns the text
--   after its first ':' with surrounding whitespace removed. If that line has
--   no ':', falls back to dropping its first 39 characters, as before.
--   Throws an 'IOError' when there is no such line, instead of the previous
--   deferred "Prelude.head: empty list" crash.
localHost :: IO String
localHost = do
	output <- readProcess "ipconfig" [] []
	case dropWhile (not . isPrefixOf "   IPv4") (lines output) of
		line : _ -> return (addressOf line)
		[] -> ioError (userError "localHost: no \"   IPv4\" line in ipconfig output")
	where
		-- Prelude names are qualified because this module also imports the
		-- lazy ByteString versions of them unqualified.
		addressOf line = case Prelude.break (== ':') line of
			(_, _ : rest) -> trim rest
			_ -> drop 39 line
		trim = Prelude.reverse . dropWhile isSpace . Prelude.reverse . dropWhile isSpace

-- | The identifier of a channel is determined by the originating host and a
--   host-unique serial number: the serial, followed by the host's address
--   packed by 'getIPAddress', each written with its 'Binary' instance.
identifier :: String -> Word32 -> Ident
identifier ip entry = runPut $ do
	put entry
	put (getIPAddress ip)

--- Channel creation.

-- | Creates a new channel, with receive and send ends.
--
--   The channel's serial number is the current size of 'table', and its
--   address is this host's 'localHost'.
--
--   NOTE(review): two concurrent calls can read the same table size before
--   either registers its listener, and so build the same 'Ident' — confirm
--   callers serialise channel creation.
newNetChan :: (Binary t) => IO (NetRecv t, NetSend t)
newNetChan = do
	listenerCount <- fmap M.size (readMVar table)
	host <- localHost
	let chanIdent = identifier host (fromIntegral listenerCount)
	recvEnd <- __newNetRecv True Nothing chanIdent
	sendEnd <- __newNetSend True host chanIdent
	return (recvEnd, sendEnd)

-- | Prefix an identifier with a 4-byte tag: "main" for a main channel end,
--   "back" for a back channel.
modifyIdent :: Bool -> ByteString -> ByteString
modifyIdent b ident = tag `append` ident
	where
		tag = pack (map (fromIntegral . ord) (if b then "main" else "back"))

-- | Build a send end for @modifyIdent b ident@ on @hostName@ with no
--   connections yet.
--
--   When @b@ is True, a back-channel receive end (forwarding to @backDown@)
--   is also created. A thread is forked that prepends every host heard on it
--   to the send end's upstream buffer. This fills the buffer immediately, so
--   this host gets the data before downstreams die. When @b@ is False,
--   @backDown@ is not used.
__emptyNetSend :: Bool -> NetSend HostName -> HostName -> Ident -> IO (NetSend t)
__emptyNetSend b backDown hostName ident = do
	upstreams <- newMVar []
	when b $ do
		backR <- __newNetRecv False (Just backDown) ident
		void $ forkIO $ forever $ do
			host <- recv backR
			modifyMVar_ upstreams (return . (host :))
	fibres <- newMVar []
	return (NetSend hostName (modifyIdent b ident) upstreams fibres)

-- | Open a new fibre from a send end to @hostName@.
--
--   Connects to the host's 'port', sends the send end's tagged 'Ident', then
--   one length-prefixed record holding @hostName@ followed by the upstream
--   hosts currently in the send end's buffer. It then registers the
--   connection as a 'ChannelFibre' with no flush pending. If the handshake
--   fails after connecting, the socket is closed before the exception
--   propagates, so it no longer leaks.
__addConnection :: NetSend t -> HostName -> IO ()
__addConnection (NetSend _ ident buffer fibres) hostName = do
	pending <- newMVar False

	-- Open a TCPIP socket to send
	hdl <- withSocketsDo $ connectTo hostName (PortNumber port)
	handshake hdl `onException` (hClose hdl `catch` \(_ :: IOException) -> return ())

	modifyMVar_ fibres (return . (ChannelFibre pending hdl :))
	where
		handshake hdl = do
			hSetBuffering hdl (BlockBuffering (Just 1024))

			-- Send identifier
			hPut hdl ident

			-- Send list of upstreams, framed like every other record
			upstreams <- readMVar buffer
			let bs = encode (hostName : upstreams)
			hPut hdl $ encode $ length bs
			hPut hdl bs

			hFlush hdl

-- | Build a send end for @ident@ on @hostName@ and open its first connection.
--
--   For a main end (@b@ is True), a fresh "back"-tagged send end is created
--   first. It serves as the downstream of the new end's back channel.
--   Otherwise '__emptyNetSend' never looks at that argument, so it is left
--   'undefined'.
__newNetSend :: Bool -> HostName -> Ident -> IO (NetSend t)
__newNetSend b hostName ident = do
	backDown <- if b
		then __emptyNetSend False undefined "" ident
		else return undefined
	sendEnd <- __emptyNetSend b backDown hostName ident
	__addConnection sendEnd hostName
	return sendEnd

-- | Open a channel to another host: a send end for that host's serial-0
--   channel (the one 'newNetRecv' creates there). The host is also fed to
--   'identifier', so it must be given as a dotted-quad IPv4 address.
newNetSend :: HostName -> IO (NetSend t)
newNetSend hostName = __newNetSend True hostName $ identifier hostName 0

-- | Read length-prefixed records from a handle and pass each payload to @f@,
--   until the peer closes the connection.
--
--   Each record is an 8-byte 'encode'd length followed by that many bytes, as
--   written by '__send' and '__addConnection'. A short header or a truncated
--   payload ends the loop instead of crashing in 'decode' or delivering a
--   partial record. The handle is always closed on exit, including when @f@
--   or a read throws.
readLoop :: (ByteString -> IO a) -> Handle -> IO ()
readLoop f hdl = loop `finally` hClose hdl
	where
		loop = do
			header <- hGet hdl 8
			when (length header == 8) $ do
				let n = decode header
				payload <- hGet hdl n
				when (length payload == fromIntegral n) $
					f payload >> loop

-- | Accept loop for this node's listening socket.
--
--   Each connection is handled on its own thread. The peer first sends a
--   12-byte tagged 'Ident'; the rest of the stream is consumed by 'readLoop'
--   through the listener registered for that ident in 'table'. Running the
--   handshake off the accept thread means a slow or failing peer can no longer
--   stall or kill the server. The handle is closed if the ident is unknown or
--   the connection's handler throws.
server socket = withSocketsDo $ forever $ do
	(hdl, host, _) <- accept socket
	void $ forkIO $ withSocketsDo $ handleConnection hdl host `onException` hClose hdl
	where
		handleConnection hdl host = do
			ident <- hGet hdl 12
			listeners <- readMVar table
			case M.lookup ident listeners of
				Just f -> readLoop f hdl
				Nothing -> do
					hPutStrLn stderr ("The host " ++ host ++ " used an invalid Ident: " ++ show ident)
					hClose hdl

-- | Create a receive end for the tagged ident @modifyIdent b ident@, register
--   its listener in 'table', and start this process's 'server' if it is not
--   running yet.
--
--   b - True for a main channel end (tag "main"), False for a back channel
--       (tag "back"). Only a main end gets a back-channel send end of its own;
--       when b is False that field of the result is 'undefined'.
--
--   may - An existing send end to forward received values to. With Nothing,
--         a fresh one is made by '__emptyNetSend'.
--
--   ident - The untagged channel identifier.
--
--   The first record the listener sees is taken to be a host list, as written
--   by '__addConnection'. Every later record is queued for 'receive' and
--   forwarded to the downstream send end.
--
--   NOTE(review): 'gotUpstreams' is shared by every connection for this
--   ident, but '__addConnection' sends a host list on each connection. A
--   second connection's host list would therefore be delivered as data.
--   Fixing this needs per-connection state in 'server' / 'readLoop'.
__newNetRecv :: (Binary t) => Bool -> Maybe (NetSend t) -> Ident -> IO (NetRecv t)
__newNetRecv b may ident = do
	-- Queue of raw records handed out by 'receive'.
	chan <- newChan

	-- Create a back channel
	--
	-- The downstream of the back channel is the upstream of the main channel.
	backS <- if b then
			__emptyNetSend False undefined "" ident
		else
			return undefined

	-- Received values are forwarded through the given send end, or a fresh one.
	downstream <- maybe
		(__emptyNetSend b backS "" ident)
		return
		may

	let ident' = modifyIdent b ident

	-- Set once the first record (the host list) has been handled.
	gotUpstreams <- newIORef False
	let listener bs = do
		got <- readIORef gotUpstreams
		if got then do
				writeChan chan bs

				-- Send the value to downstream receive ends.
				__send downstream bs
			else do
				writeIORef gotUpstreams True
				-- Lazy pattern: it is only forced, and so can only fail on an
				-- empty or undecodable list, when b is True.
				let x:xs = decode bs
				when b $ do
					-- Connect the back channel to the head of the host list and
					-- keep the rest as its upstream buffer.
					let NetSend _ _ buffer _ = backS
					modifyMVar_ buffer (\_ -> return xs)
					__addConnection backS x

	-- Put a listener in the table.
	modifyMVar_ table (return . M.insert ident' listener)

	-- Start the server singleton. (The lambda's b is the 'serverup' flag; it
	-- shadows the parameter b.)
	modifyMVar_ serverup (\b -> unless b (withSocketsDo $ listenOn (PortNumber port) >>= forkIO . server >> return ()) >> return True)

	return (NetRecv ident' downstream backS chan)

-- | Creates a receive end of this host's channel (serial 0 on 'localHost').
--   Type unsafe! Nothing checks that remote senders write values of type @t@;
--   'recv' simply 'decode's whatever arrives.
newNetRecv :: (Binary t) => IO (NetRecv t)
newNetRecv = do
	host <- localHost
	__newNetRecv True Nothing (identifier host 0)

--- Send and receive.

-- | If send fails, route around the node.
--
--   The dead fibre is removed from the send end and its handle closed, which
--   it never was before. The send end's upstream buffer is then drained, and
--   a new connection is opened to every host that was in it.
routeAround :: ChannelFibre -> NetSend t -> IO ()
routeAround fib@(ChannelFibre _ deadHdl) s@(NetSend _ _ buffer fibres) = do
	-- Drop the dead fibre first so later sends no longer pick it up.
	modifyMVar_ fibres (return . (\\ [fib]))
	-- The connection is already broken, so closing it may fail; that is harmless.
	hClose deadHdl `catch` \(_ :: IOException) -> return ()
	hosts <- modifyMVar buffer (\ls -> return ([], ls))
	mapM_ (__addConnection s) hosts

-- | Write an already-encoded message to every fibre (connection) of a send end.
--
--   Each record is framed as its 8-byte 'encode'd length followed by the
--   bytes. Writes are buffered: the first write since a fibre's last flush
--   forks a thread that flushes it 100ms later, so messages sent close
--   together share one flush.
--
--   When a write or flush fails with an 'IOException', the fibre is replaced
--   via 'routeAround'. The message then goes only to the fibres added by the
--   reroute. Previously it was re-sent to every fibre, so the others got it
--   twice. Asynchronous exceptions are no longer swallowed.
__send :: NetSend t -> ByteString -> IO ()
__send sendEnd@(NetSend _ _ _ fibres) s = readMVar fibres >>= mapM_ deliver
	where
		deliver fib@(ChannelFibre pending hdl) = do
			wasPending <- modifyMVar pending (\p ->
				s `seq` catch (hPut hdl (encode (length s)) >> hPut hdl s) (\(_ :: IOException) -> reroute fib)
				>> return (True, p))
			-- Buffering: only the first write since the last flush schedules one.
			unless wasPending $ void $ forkIO $ do
				threadDelay 100000
				modifyMVar_ pending (\_ -> return False)
				catch (hFlush hdl) (\(_ :: IOException) -> reroute fib)
		-- Route around a failed fibre, then deliver to the fibres it added.
		reroute fib = do
			before <- readMVar fibres
			routeAround fib sendEnd
			after <- readMVar fibres
			mapM_ deliver (after \\ before)

-- | Sends something on a channel: the value is 'encode'd once and written to
--   every connection of the send end.
send :: (Binary t) => NetSend t -> t -> IO ()
send sendEnd = __send sendEnd . encode

-- | Receives the next record on a channel as raw, still-encoded bytes,
--   blocking until one arrives.
receive :: NetRecv t -> IO ByteString
receive (NetRecv _ _ _ incoming) = readChan incoming

-- | Receives something from a channel, decoding the next raw record as @t@.
recv :: (Binary t) => NetRecv t -> IO t
recv = fmap decode . receive

--- Sending and receiving channels.

-- | Receives the send end of a channel, on a channel, and activates it.
recvSend :: NetRecv (NetSend t) -> IO (NetSend t)
recvSend r = activateSend =<< recv r

-- | Sends the receive end of a channel, on a channel.
--
--   After sending @r@ over @s@, this node keeps forwarding. It connects @r@'s
--   downstream send end to @s@'s host, and it sends that host name over
--   @r@'s back channel.
sendRecv :: NetSend (NetRecv t) -> NetRecv t -> IO ()
sendRecv s@(NetSend destination _ _ _) r@(NetRecv _ downstream upstreamBack _) = do
	send s r
	-- This node is now responsible for passing on messages to the destination(s).
	__addConnection downstream destination
	-- Inform upstream of this
	send upstreamBack destination

-- | Receives the receive end of a channel, on a channel, and activates it.
recvRecv :: (Binary t) => NetRecv (NetRecv t) -> IO (NetRecv t)
recvRecv r = activateRecv =<< recv r

--- Channel data utilities.

-- | Placeholder for the fields of a channel end rebuilt by 'get'. Only the
--   host name and 'Ident' travel over the wire. Forcing one of these means
--   the decoded end was used before 'activateSend' / 'activateRecv'.
--   Previously these were bare 'undefined's, which gave no hint of the cause.
inactiveField :: String -> a
inactiveField what = error ("Control.CUtils.NetChan: " ++ what ++ " used before activation (see activateSend / activateRecv)")

-- | Only the host name and tagged 'Ident' are serialised.
instance Binary (NetSend t) where
	put (NetSend hostName ident _ _) = put hostName >> put ident
	get = do
		hostName <- get
		ident <- get
		return (NetSend hostName ident (inactiveField "NetSend upstream buffer") (inactiveField "NetSend connections"))

-- | Only the tagged 'Ident' is serialised.
instance Binary (NetRecv t) where
	put (NetRecv ident _ _ _) = put ident
	get = do
		ident <- get
		return (NetRecv ident (inactiveField "NetRecv downstream") (inactiveField "NetRecv back channel") (inactiveField "NetRecv queue"))

-- | Strip the 4-byte "main"/"back" tag that 'modifyIdent' prepends. Channel
--   ends store (and serialise) tagged idents, but the constructors below
--   expect the raw identifier. Passing the tagged one made them tag it a
--   second time (16 bytes), and the server reads 12, so activated ends never
--   matched a listener.
untagIdent :: Ident -> Ident
untagIdent = snd . splitAt 4

-- | Make a send end received over the network usable: set up its back channel
--   and connect it to its host.
activateSend :: NetSend t -> IO (NetSend t)
activateSend (NetSend hostName ident _ _) = __newNetSend True hostName (untagIdent ident)

-- | Make a receive end received over the network usable: register it on this
--   host under its channel identifier.
activateRecv :: (Binary t) => NetRecv t -> IO (NetRecv t)
activateRecv (NetRecv ident _ _ _) = __newNetRecv True Nothing (untagIdent ident)

-- | Run an action over and over, stopping only if the monad short-circuits
--   (or, in 'IO', an exception escapes).
repeatM :: Monad m => m a -> m b
repeatM m = go where go = m >> go

-- | An authenticated command for a value of type @t@. It holds a 64-byte
--   hash and a 100-byte salt (both strict), followed by the 'encode'd
--   command. 'authClient' computes the hash over the certificate, the salt
--   and the payload, and 'authServer' checks it.
data Auth t = Auth B.ByteString B.ByteString ByteString

-- | Write every chunk of a lazy 'ByteString' as-is, with no length prefix.
putLazy :: ByteString -> Put
putLazy bytes = forM_ (toChunks bytes) putByteString

-- | Hash, salt and payload back to back, with no framing of their own. The
--   decoder relies on the fixed 64- and 100-byte sizes and takes the rest as
--   the payload.
instance Binary (Auth t) where
	put (Auth digest salt payload) = do
		putByteString digest
		putByteString salt
		putLazy payload
	get = do
		digest <- getByteString 64
		salt <- getByteString 100
		payload <- getRemainingLazyByteString
		return (Auth digest salt payload)

-- | Orphan instance that lets the RSA code draw randomness directly from an
--   'EntropyPool'.
--
--   'genBytes' reads the pool through 'grabEntropy', which is used here as a
--   pure call, and returns the same pool as the "next" generator.
--   NOTE(review): two such calls with the same length and pool could in
--   principle be shared by the optimiser — confirm this is acceptable.
instance CryptoRandomGen EntropyPool where
	-- A pool has no seed format; create one with 'createEntropyPool'.
	newGen _ = error "newGen: unsupported on EntropyPool"
	genSeedLength = Tagged maxBound
	genBytes l g = entropy `seq` Right (readSecureMem entropy, g) where
		entropy = grabEntropy l g
	-- Reseeding is a no-op: the generator is returned unchanged.
	reseed _ = Right
	reseedInfo _ = Never
	reseedPeriod _ = Never

-- | Copy the contents of a 'SecureMem' into an ordinary strict
--   'B.ByteString', one byte at a time.
readSecureMem mem = unsafePerformIO $ withSecureMemPtrSz mem copyOut where
	copyOut size ptr = fmap B.pack (forM [0 .. size - 1] (peekByteOff ptr))

-- | Remote exercise of authority. Commands are transmitted in the clear,
--   but authenticated.
--
--   auth - The authority to be served (runs on a separate thread).
--
--   r - The receive end from the host.
--
--   s - The send end to the host.
--
--   publicKey - The public key of the intended recipient.
--
--   A random 100-byte certificate is encrypted with @publicKey@ and sent on
--   @s@. A record later received on @r@ is accepted only if its leading 64
--   bytes equal the SHA-512 hash of the certificate followed by the rest of
--   the record (salt and payload). The payload after the salt is then decoded
--   and passed to @auth@. The hash comparison now runs in constant time.
--
--   NOTE(review): nothing stops a captured record from being replayed, and a
--   secret-prefix SHA-512 is open to length extension. An HMAC with a nonce
--   would fix both, but needs a matching change in 'authClient'.
authServer :: (Binary t) => (t -> IO ()) -> NetRecv (Auth t) -> NetSend ByteString -> PublicKey -> IO ()
authServer auth r s publicKey = do
	pool <- createEntropyPool

	-- Engage in crypto to agree on a random certificate.
	(certChunk, g) <- either throwIO return $ genBytes 100 pool
	let cert = fromChunks [certChunk]
	(enc, _) <- either throwIO return $ encrypt g publicKey cert
	send s enc

	-- Accept requests.
	_ <- forkIO $ repeatM $ do
		command <- receive r
		let (tk, dr) = splitAt 64 command
		let x = runGet (getByteString 100 >> get) dr
		-- Check the hash before approving the command.
		when (authDigestsMatch (fromChunks [hashlazy $ append cert dr]) tk) $ auth x

	return ()

-- | Compare two byte strings without stopping at the first differing byte,
--   so the time taken does not reveal how much of a forged hash was right.
authDigestsMatch :: ByteString -> ByteString -> Bool
authDigestsMatch a b = length a == length b && Prelude.foldr (.|.) 0 (B.zipWith xor (strict a) (strict b)) == 0
	where
		strict = B.concat . toChunks

-- | privateKey - The private key for this host.
--
--   Returns a function that can be used to send messages.
--
--   First waits for the certificate on @r@ and decrypts it with @privateKey@.
--   Each message is then sent on @s@ as an 'Auth' record carrying a fresh
--   100-byte salt, the 'encode'd message, and the SHA-512 hash of the
--   certificate, salt and message.
authClient :: (Binary t) => NetRecv ByteString -> NetSend (Auth t) -> PrivateKey -> IO (t -> IO ())
authClient r s privateKey = do
	-- Decrypt the certificate.
	cert <- recv r >>= either throwIO return . decrypt privateKey
	pool <- createEntropyPool
	return $ \x -> do
		salt <- fmap readSecureMem (grabEntropyIO 100 pool)
		let payload = encode x
		send s $ Auth (hashlazy (cert `append` fromChunks [salt] `append` payload)) salt payload
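-- The format of an authenticated record on the wire is:
--
-- * An eight-byte record length, in bytes (the framing '__send' adds to every
--   record; it is not part of the 'Auth' encoding itself)
--
-- * A 64-byte hash: SHA-512 of the certificate, the salt and the data
--
-- * A 100-byte salt
--
-- * The remainder of the record contains the data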

-- | Demonstration on a single host. It sets up an authenticated command
--   channel pair, with 'print' as the served authority, and then sends the
--   commands 1, 2 and 5 through the client function. A failed RSA key
--   generation is now rethrown as its 'RSAError' instead of crashing on an
--   irrefutable 'Right' pattern.
example :: IO ()
example = do
	pool <- createEntropyPool
	(pub, priv, _) <- either throwIO return $ generateKeyPair pool 1024
	(r :: NetRecv ByteString, s) <- newNetChan
	(r2 :: NetRecv (Auth Int), s2) <- newNetChan
	authServer print r2 s pub
	threadDelay 100000
	f <- authClient r s2 priv
	f 1
	f 2
	f 5