{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

module Concurrent.Worker
    (WorkMsg(..), StateCmd(..)
    , WorkEntry(..), WorkGroup(..)
    , WorkItems(..), WorkId, WorkState(..)
    , WorkControls(..), worker
    , DispBlk(..)
    )
where

import qualified Data.ByteString as S
import           Data.Conduit
import           Data.List hiding (lines)
import           Data.Ratio ((%))
import           Data.String.Conversions (cs)
import qualified Data.Text as T
import           Data.Time.Clock
import           Data.Time.Format
import           Data.Conduit.List (sourceList, consume, sourceNull)
import qualified Data.Conduit.List as CL
import qualified Data.Conduit.Binary as CB
import           Data.Conduit.Text (decode, utf8)
import           Control.Concurrent (ThreadId, Chan, newChan
                                    , readChan, writeChan, forkIO)
import qualified Control.Concurrent as CC
import           Control.Monad (forever, when, unless, void)
import           Control.Monad.IO.Class (MonadIO, liftIO)
import           Control.Monad.Trans.Class (lift)
import           Control.Monad.Trans.State.Lazy (StateT, evalStateT
                                                , gets, get, put, modify)
import           Data.Conduit.RemoteOp
import           System.Exit (ExitCode(..))
import           Data.Monoid ((<>))


-- | User-adjustable settings that control how commands are run.
data WorkControls = WorkControls
    { numParallel    :: Int       -- ^ how many remote commands a run starts at once
    , useDirectSSH   :: Bool      -- ^ passed as the first argument of 'remoteOp'
    , queryResponses :: [T.Text]  -- ^ responses written when a password prompt is seen
    }


-- | Messages accepted by the worker loop ('doWork').
data WorkMsg
    = ShowSelInfo Int             -- ^ show stored command and output of entry N
    | ClearAll                    -- ^ forget all output and all query responses
    | StartRun [T.Text] [WorkId]  -- ^ command lines, and the entries to run them on
    | FinishedErr Int             -- ^ entry N's command exited with a failure code
    | FinishedGood Int            -- ^ entry N's command exited successfully
    | AddOutput Int DispBlk       -- ^ more output for entry N
    | StartResponse               -- ^ begin a new query response
    | AddToResponse T.Text        -- ^ append text to the last query response
    | EndResponse                 -- ^ end the query response being entered
    | AbandonResponse             -- ^ drop the last query response
    | IncrParallel                -- ^ raise the parallel-command count by one
    | DecrParallel                -- ^ lower the parallel-command count (minimum 1)
    deriving Show

-- | Different types of information that can be displayed in the output
-- region; each constructor carries the lines of text to show.
data DispBlk
    = DispOut [T.Text]   -- ^ standard output of a remote command
    | DispErr [T.Text]   -- ^ standard error of a remote command
    | DispInp [T.Text]   -- ^ the command (input) being run
    | DispInfo [T.Text]  -- ^ informational text such as elapsed time or debug output
    deriving Show

-- | Commands the worker sends to the output side through a 'GenOutput'
-- callback.
data StateCmd
    = NewItems [WorkItems] (Maybe (WorkId -> WorkState -> T.Text))
      -- ^ the work item structure, with an optional per-entry describer
    | ResetUI
    | ChgState WorkState WorkId   -- ^ an entry moved to a new state
    | DisplaySet DispBlk          -- ^ set the output region to this block
    | DisplayAdd DispBlk          -- ^ add this block to the output region
    | SetProgress T.Text Int      -- ^ progress label and percentage
    | LogInfo T.Text              -- ^ informational log message
    | EndOfInformation            -- ^ end of a group of display updates
    -- removed: | SelItem Int


-- | Name of a group of work items.
type WorkGroupName = T.Text
-- | Index of a work entry within its 'WorkGroup'.
type WorkId = Int
-- | A count of work items.
type NumWorkItems = Int

-- | How work items are grouped: a named group wrapping further items, or a
-- plain count of items.
data WorkItems
    = WorkGroup WorkGroupName WorkItems
    | WorkItems NumWorkItems
    deriving Show

-- | Progress state of a single work entry.
data WorkState
    = NoWork
    | WorkDone
    | WorkFailed
    | WorkInProgress
    deriving (Show, Eq)

-- | A single unit of remote work.
class WorkEntry a where
    -- | A name for the entry (not used within this module).
    name :: a -> T.Text
    -- | Text describing the entry in the given state; 'worker' hands this to
    -- the output side through 'NewItems'.
    identify :: a -> WorkState -> T.Text
    -- | Remote address the entry's command is run against (given to 'remoteOp').
    rmtaddr :: a -> T.Text

-- | An indexed collection of 'WorkEntry' values.
class WorkEntry (GroupEntry g) => WorkGroup g where
    -- | Type of the entries held by the group.
    type GroupEntry g
    -- | Grouping structure, reported to the output side via 'NewItems'.
    getItems :: g -> [WorkItems]
    -- | Number of entries in the group.
    numEntries :: g -> Int
    -- | Entry at an index (presumably @0 .. numEntries g - 1@; TODO confirm).
    getEntry :: g -> Int -> GroupEntry g


-- | Progress of the current run, or 'Idle' when no run is active.
--
-- NOTE: the record fields are partial; they exist only for 'RC'.
data RunCompl
    = RC { total     :: Int                 -- ^ number of entries in the run
         , completed :: Int                 -- ^ entries finished so far
         , failures  :: Int                 -- ^ entries that finished as 'WorkFailed'
         , command   :: Int -> IO ThreadId  -- ^ starts the command for an entry index
         , startTime :: UTCTime             -- ^ when the run was started
         , pending   :: [Int]               -- ^ entries not started yet
         }
    | Idle


-- | Entry point of the worker thread.
--
-- Announces the work items of @entries@ through @genOutput@, then runs
-- 'doWork' forever, reading messages from @inpWorkchan@.  Every entry
-- starts with no output, an empty command, and no run in progress.
worker :: WorkGroup g => GenOutput -> Chan WorkMsg -> WorkControls -> g -> IO ()
worker genOutput inpWorkchan config entries = do
    genOutput $ NewItems (getItems entries) (Just describe)
    evalStateT doWork initial
  where
    count    = numEntries entries
    describe = identify . getEntry entries
    initial  = WS entries config (replicate count []) (replicate count T.empty)
                  0 Idle genOutput inpWorkchan

-- | Callback the worker uses to send 'StateCmd's to the output side.
type GenOutput = StateCmd -> IO ()

-- | State threaded through the worker loop ('doWork').
data WorkerState e = WS
    { entrieS     :: e             -- ^ the work entries (constrained as: WorkGroup e)
    , cfg         :: WorkControls  -- ^ current settings
    , outs        :: [[DispBlk]]   -- ^ accumulated output, one list per entry
    , commands    :: [T.Text]      -- ^ last command run, one per entry
    , currSel     :: Int           -- ^ index of the currently selected entry
    , pct         :: RunCompl      -- ^ progress of the active run
    , genOutput   :: GenOutput     -- ^ output callback
    , inpWorkChan :: Chan WorkMsg  -- ^ channel the worker reads its messages from
    }

-- | Main loop of the worker: forever read a 'WorkMsg' from 'inpWorkChan' and
-- update the 'WorkerState' and/or notify the output side.
--
-- A selection of an unknown entry is logged and ignored, and response edits
-- with no response present no longer fail.  Previously '!!', 'init' and
-- 'last' could kill the worker here.
doWork :: WorkGroup g => StateT (WorkerState g) IO ()
doWork = forever $
              do wmsg <- lift . readChan =<< gets inpWorkChan
                 case wmsg of
                   ShowSelInfo n -> showSelInfo n
                   ClearAll -> modify $ \w ->
                               let cfg' = (cfg w) { queryResponses = [] }
                                   outs' = replicate (numEntries $ entrieS w) []
                               in w { outs = outs', cfg = cfg' }
                   StartRun cmd ss -> startRun cmd ss
                   FinishedErr n  -> updFinish n WorkFailed
                   FinishedGood n -> updFinish n WorkDone
                   AddOutput n ls -> addOuts n ls
                   StartResponse -> modWSQR (\qr -> qr ++ [T.empty])
                   AddToResponse r  -> modWSQR (appendToLast r)
                   -- NOTE(review): EndResponse also appends an empty entry
                   -- (respCount subtracts it), so queryResponses can hold
                   -- empty strings; confirm this is intended.
                   EndResponse -> do modWSQR (\qr -> qr ++ [T.empty])
                                     s <- get
                                     let addedMsg = cs $ "Added response; " <>
                                                    show respCount <> " total."
                                         respCount = (length . queryResponses . cfg $ s) - 1
                                     lift $ genOutput s $ LogInfo addedMsg
                   AbandonResponse -> modWSQR dropLast
                   IncrParallel -> modifyParallelCnt succ
                   DecrParallel -> modifyParallelCnt pred
    where
      -- Total version of @xs !! i@.
      atIdx i xs
        | i < 0     = Nothing
        | otherwise = case drop i xs of
                        (x:_) -> Just x
                        []    -> Nothing

      -- Display the stored command and output of entry n and make it the
      -- current selection.  An unknown index is logged and otherwise ignored.
      showSelInfo n =
        do s <- get
           case (atIdx n (commands s), atIdx n (outs s)) of
             (Just cmdTxt, Just blks) ->
               do lift $ genOutput s $ DisplaySet $ DispInp $ T.lines cmdTxt
                  lift $ mapM_ (genOutput s . DisplayAdd) blks
                  lift $ genOutput s EndOfInformation
                  put $ s { currSel = n }
             _ -> lift $ genOutput s $ LogInfo $ cs $
                    "Ignoring selection of unknown entry " ++ show n

      -- Extend the last response.  If there is none, start one.
      appendToLast r qr = case reverse qr of
                            (cur:older) -> reverse ((cur `T.append` r) : older)
                            []          -> [r]

      -- Drop the last response, if any.
      dropLast [] = []
      dropLast qr = init qr


-- | Apply @op@ to the query responses stored in the worker's settings.
modWSQR :: Monad m => ([T.Text] -> [T.Text]) -> StateT (WorkerState e) m ()
modWSQR op = modify withNewResponses
  where withNewResponses w =
          let c = cfg w
          in w { cfg = c { queryResponses = op (queryResponses c) } }

-- | Change the number of remote commands run in parallel.
--
-- Applies @op@ to the current count, never letting it drop below 1, and logs
-- the resulting value.
--
-- NOTE(review): the count is read by 'startRun'.  A run already in progress
-- keeps starting one pending command per completion.
modifyParallelCnt :: MonadIO m => (Int -> Int) -> StateT (WorkerState e) m ()
modifyParallelCnt op = do s <- get
                          let np' = max 1 $ op $ numParallel $ cfg s
                              cfg' = (cfg s) { numParallel = np' }
                              -- Report the value now in effect.  The old
                              -- count used to be logged here.
                              newParMsg = cs $ "Parallel remotes: " <> show np'
                          liftIO $ genOutput s $ LogInfo newParMsg
                          put $ s { cfg = cfg' }


-- | Start running command @cml@ on the entries whose indices are in @marked@.
--
-- @marked@ is expected in ascending order, as 'withMarked' requires.
--
-- The steps are:
--
-- * display the command;
-- * clear the marked entries' output and record the command for them;
-- * start up to 'numParallel' commands (at least one) right away;
-- * leave the rest in 'pending' for 'updFinish' to start as others complete.
--
-- With no marked entries the run is finished immediately rather than left
-- "Running".
startRun :: WorkGroup g => [T.Text] -> [Int] -> StateT (WorkerState g) IO ()
startRun cml marked =
    do s <- get
       let cmd = preProcCmd cml
       liftIO $ genOutput s $ DisplaySet $ DispInp cmd
       liftIO $ genOutput s $ SetProgress "Running" 0
       now <- lift getCurrentTime
       let ents = entrieS s
           workchan = inpWorkChan s
           numE = length marked
           -- Clamp like 'modifyParallelCnt' does.  With 0, nothing would
           -- ever be started and the run would hang.
           nPar = max 1 $ numParallel $ cfg s
           (startss, remss) = splitAt nPar marked
           pct' = RC numE 0 0
                  (runCmd (cfg s) (writeChan workchan) cmd ents) now remss
           msg  = "Running on " ++ show numE ++ " targets beginning at "
                  ++ formatTime defaultTimeLocale rfc822DateFormat now
           outs' = withMarked marked (const []) (outs s)
           cmds' = withMarked marked (const (T.unlines cmd)) (commands s)
       lift $ genOutput s $ LogInfo $ T.pack msg
       lift $ mapM_ (command pct') startss
       modify $ \w -> w { outs = outs', commands = cmds', pct = pct' }
       -- No targets means no completion will ever arrive.  Finish the run now.
       when (numE == 0) $
         do lift $ allDone (genOutput s) pct'
            modify $ \w -> w { pct = Idle }


-- | Preprocess command lines before they are run.
--
-- Currently the identity.  TODO (was marked KWQ): handle lines containing a
-- @{LOCALFILE:@ placeholder.  The unused filter that selected such lines was
-- removed until that is implemented.
preProcCmd :: [T.Text] -> [T.Text]
preProcCmd cmds = cmds


-- | Apply @f@ to the elements of @lst@ at the (0-based) indices in @marked@
-- and leave the others unchanged.
--
-- @marked@ is walked in order alongside the list.  It must therefore be
-- ascending: an index smaller than one already passed is never matched, and
-- neither is anything after it.  The result is produced lazily.
withMarked :: [Int] -> (a -> a) -> [a] -> [a]
withMarked marked f lst = go 0 marked lst
    where go _ _ [] = []
          go _ [] rest = rest
          go i ms@(m:ms') (x:xs)
            | i == m    = f x : go (i + 1) ms' xs
            | otherwise = x : go (i + 1) ms xs


-- | Append output block @ls@ to entry @n@'s stored output.  If entry @n@ is
-- the current selection, the block is also sent to the display immediately.
--
-- NOTE(review): when @n@ is at or past the end of the output list, the block
-- is stored as a new last entry rather than at index @n@.
addOuts :: Int -> DispBlk -> StateT (WorkerState e) IO ()
addOuts n ls =
    do s <- get
       let (before, atAndAfter) = splitAt n (outs s)
           (current, after) = case atAndAfter of
                                (x:rest) -> (x, rest)
                                []       -> ([], [])
           updated = before ++ (current ++ [ls]) : after
       when (currSel s == n) $ lift $ genOutput s $ DisplayAdd ls
       put $ s { outs = updated }

-- | Join words into one 'T.Text', separated by single spaces.
statement :: [String] -> T.Text
statement = T.pack . unwords

-- | Record that entry @n@ finished in state @nS@ ('WorkDone' or 'WorkFailed').
--
-- This tells the output side about the new state and advances the run's
-- counters.  If entries remain, it updates the progress bar and starts the
-- next pending entry.  Once every entry has completed, it reports the summary
-- via 'allDone' and returns to 'Idle'.  A completion received while 'Idle'
-- is only logged as a warning.
updFinish :: Int -> WorkState -> StateT (WorkerState g) IO ()
updFinish n nS =
    do genO <- gets genOutput
       lift $ genO $ ChgState nS n
       progress <- gets pct
       case progress of
         RC d c f op _ ps -> do
             let c' = c + 1
                 f' = if nS == WorkFailed then f + 1 else f
                 -- Next entry to start (if any) and what stays pending.
                 (next, ps') = case ps of
                                 (p:rest) -> (Just p, rest)
                                 []       -> (Nothing, [])
                 progress' = progress { completed = c'
                                      , failures = f'
                                      , pending = ps'
                                      }
                 pbarVal = ceiling $ (c' % d) * 100
             -- '>=' rather than '==': a stray extra completion still ends
             -- the run instead of leaving it "Running" beyond 100%.
             if c' >= d
               then do lift $ allDone genO progress'
                       modify $ \w -> w { pct = Idle }
               else do let ptext = cs $ "Running - " ++ show c' ++ "/" ++ show d
                       lift $ genO $ SetProgress ptext pbarVal
                       lift $ mapM_ (void . op) next
                       modify $ \w -> w { pct = progress' }
         Idle -> lift $ genO $ LogInfo $ statement
                     [ "Warning:", show nS
                     , "completion for", show n
                     , "rcvd in Idle mode"
                     ]

-- | Report the end of a run: reset the progress display to "Idle", then log
-- a summary of targets, successes, failures and elapsed time, then send
-- 'EndOfInformation'.
--
-- Only defined for 'RC'.  Passing 'Idle' is a pattern-match failure.
allDone :: GenOutput -> RunCompl -> IO ()
allDone genO RC{..} = do
    genO $ SetProgress "Idle" 0
    finishedAt <- getCurrentTime
    let elapsed = diffUTCTime finishedAt startTime
    genO $ LogInfo $ statement
        [ "Run completed on", show total ++ " targets"
        , "with", show (total - failures) ++ " successes"
        , "and", show failures ++ " failures"
        , "in", show elapsed
        ]
    genO EndOfInformation


-- | Run command @cmdop@ against work entry @entryNum@ in a new thread
-- ('forkIO') and return that thread's 'ThreadId'.
--
-- The remote address is 'rmtaddr' of the entry.  The command's output passes
-- through the stages below and ends in 'reportOnWork'.  That sink sends
-- 'AddOutput' and then 'FinishedGood' / 'FinishedErr' through @reqwork@.
--
-- Password answers written to @inpChan@ by 'detectPasswordReq' are handed to
-- 'remoteOp' through 'sourceChan'.  Presumably they become the remote
-- command's stdin; TODO confirm in Data.Conduit.RemoteOp.
--
-- NOTE(review): @ctl@ (and hence 'queryResponses') is the value captured when
-- the run was started.  Responses added later are not seen by this command.
-- NOTE(review): no exception handler is installed in the forked thread.  If
-- the pipeline throws, no Finished* message is sent and the run never
-- completes.
-- NOTE(review): 'onStdOut'/'onStdErr' restart their stage for every element,
-- so 'CB.lines' and 'decode' do not see data across chunk boundaries.
runCmd :: WorkGroup g => WorkControls -> (WorkMsg -> IO ()) -> [T.Text] -> g -> Int -> IO ThreadId
runCmd ctl reqwork cmdop entryState entryNum = do
  -- (queryResponses ctl) has the responses for password queries
  inpChan <- newChan
  let sshOp = remoteOp (useDirectSSH ctl) hostaddr cmdop (sourceChan inpChan)
  forkIO $ sshOp
             $= onStdOut CB.lines   -- split raw output into lines
             $= onStdErr CB.lines
             $= onStdOut (decode utf8)   -- bytes -> Text
             $= onStdErr (decode utf8)
             $= onStdErr (detectPasswordReq (queryResponses ctl) inpChan)   -- answer prompts, drop the prompt line
             $= onStdOut (detectPasswordReq (queryResponses ctl) inpChan)
             $= onStdOut (suppressPasswordEcho (queryResponses ctl))   -- hide echoed passwords
             $= onStdErr (suppressPasswordEcho (queryResponses ctl))
             $$ reportOnWork reqwork entryNum
    where hostaddr = rmtaddr $ getEntry entryState entryNum

             -- For debugging, "$= dbgOut" can be added to the pipeline above.

-- | Debugging stage.  For each element, first emits a 'DebugOut' carrying its
-- 'show' text, then passes the element itself through unchanged.
dbgOut :: (Monad m, Show e, Show o) => Conduit (OpOutputType o e) m (OpOutputType o e)
dbgOut = awaitForever $ \o -> mapM_ yield [DebugOut (cs (show o)), o]

-- | Source reading from @inputChan@.  Yields the payload of every 'Just' in
-- order and stops at the first 'Nothing'.  Blocks in 'readChan' while the
-- channel is empty.
sourceChan :: MonadIO m => Chan (Maybe o) -> ConduitM i o m ()
sourceChan inputChan = loop
  where loop = liftIO (readChan inputChan)
               >>= maybe (return ()) (\v -> yield v >> loop)


-- | Run @doFunc@ over the payloads of 'StdOut' elements and rewrap its
-- results as 'StdOut'.  Every other element is passed through unchanged.
--
-- NOTE(review): @doFunc@ is run as a fresh sub-pipeline for each 'StdOut'
-- element (@yield l $= doFunc@).  Any state it keeps, such as a partial line
-- or a partial UTF-8 sequence, is not carried into the next element.
onStdOut :: Monad m => Conduit i m o -> Conduit (OpOutputType i e) m (OpOutputType o e)
onStdOut doFunc = awaitForever route
  where route (StdOut l)   = yield l $= doFunc $= awaitForever (yield . StdOut)
        route (StdErr e)   = yield (StdErr e)
        route StdOutEnd    = yield StdOutEnd
        route StdErrEnd    = yield StdErrEnd
        route (Ended r)    = yield (Ended r)
        route (DebugOut t) = yield (DebugOut t)

-- | Run @doFunc@ over the payloads of 'StdErr' elements and rewrap its
-- results as 'StdErr'.  Every other element is passed through unchanged.
--
-- NOTE(review): @doFunc@ is run as a fresh sub-pipeline for each 'StdErr'
-- element (@yield l $= doFunc@).  Any state it keeps, such as a partial line
-- or a partial UTF-8 sequence, is not carried into the next element.
onStdErr :: Monad m => Conduit i m o -> Conduit (OpOutputType s i) m (OpOutputType s o)
onStdErr doFunc = awaitForever route
  where route (StdOut e)   = yield (StdOut e)
        route (StdErr l)   = yield l $= doFunc $= awaitForever (yield . StdErr)
        route StdOutEnd    = yield StdOutEnd
        route StdErrEnd    = yield StdErrEnd
        route (Ended r)    = yield (Ended r)
        route (DebugOut t) = yield (DebugOut t)


-- | Sink for the output of one remote command (see 'runCmd').
--
-- For entry @entryNum@, the sink sends messages through @reqwork@ as follows:
--
-- * each stdout, stderr and debug element is split into lines and sent as an
--   'AddOutput';
-- * when the command ends, an elapsed-time line is added, followed by
--   'FinishedGood' or 'FinishedErr' depending on the exit code;
-- * the 'StdOutEnd' / 'StdErrEnd' markers carry nothing to report and are
--   ignored.  They used to fall through the 'case' and kill the sink before
--   'Ended' was seen.
--
-- NOTE(review): no signature.  The inferred one is roughly
-- @MonadIO m => (WorkMsg -> IO ()) -> Int -> Sink (OpOutputType a b) m ()@
-- with payloads convertible to 'T.Text' via 'cs'; confirm before adding it.
reportOnWork reqwork entryNum = do
  start <- liftIO getCurrentTime
  awaitForever $ liftIO . report start
    where report st inp =
              case inp of
                StdOut t -> addOut $ DispOut $ T.lines $ cs t
                StdErr t -> addOut $ DispErr $ T.lines $ cs t
                DebugOut t -> addOut $ DispInfo $ T.lines $ cs t
                StdOutEnd -> return ()
                StdErrEnd -> return ()
                Ended ExitSuccess -> do showElapsed st
                                        reqwork $ FinishedGood entryNum
                Ended (ExitFailure _) -> do showElapsed st
                                            reqwork $ FinishedErr entryNum
          addOut = reqwork . AddOutput entryNum
          showElapsed st = do
              end <- getCurrentTime
              let elapsed = diffUTCTime end st
                  emsg = cs $ "[Elapsed time: " ++ show elapsed ++ "]"
              addOut $ DispInfo $ T.lines emsg


-- | Sometimes passwords offered by 'detectPasswordReq' end up echoed to
-- stdout.  Catch those and suppress them!
--
-- Every occurrence of each non-empty entry of @pwlst@ is removed from each
-- line.  Empty entries are skipped because 'T.replace' raises an error on an
-- empty needle.  The response list can contain empty strings, for example
-- the one appended by 'StartResponse' before any text is added.
suppressPasswordEcho :: MonadIO m => [T.Text] -> Conduit T.Text m T.Text
suppressPasswordEcho pwlst = awaitForever $ \l -> yield (foldr scrub l needles)
  where needles = filter (not . T.null) pwlst
        scrub pw o = T.replace pw "" o

-- | Watch a stream of text lines for password prompts.
--
-- On a line that looks like a prompt, every entry of @pwlst@ is written to
-- @inpChan@ (each followed by a newline), and the prompt line itself is not
-- passed on.  All other lines are passed through unchanged.
--
-- After answering, the stream keeps being processed.  Previously it ended
-- there, dropping all later lines.  The unused round-robin helper
-- (@sendPW@) was dead code and has been removed.
detectPasswordReq :: MonadIO m => [T.Text] -> Chan (Maybe S.ByteString) -> Conduit T.Text m T.Text
detectPasswordReq pwlst inpChan = loop
    where loop = await >>= maybe (return ()) check
          check s
            | isPwPrompt s = sendPws >> loop
            | otherwise    = yield s >> loop
          isPwPrompt s = s == "Password:"
                         || s == "Password: "
                         || "[sudo] password" `T.isPrefixOf` s
                         || " password: " `T.isSuffixOf` s
                         || " password:" `T.isSuffixOf` s
          -- Never write Nothing here: it closes the remote stdin, so a later
          -- prompt could no longer be answered.
          sendPws = liftIO $ mapM_ (\p -> writeChan inpChan $ Just $ cs $ p `T.append` "\n") pwlst