module System.Posix.Pty (
spawnWithPty
, Pty
, PtyControlCode (..)
, createPty
, tryReadPty
, readPty
, writePty
, resizePty
, ptyDimensions
, getTerminalAttributes
, setTerminalAttributes
, sendBreak
, drainOutput
, discardData
, controlFlow
, getTerminalProcessGroupID
, getTerminalName
, getSlaveTerminalName
, module System.Posix.Terminal
) where
import Control.Applicative
import Control.Exception (bracket)
import Control.Monad
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import Unsafe.Coerce (unsafeCoerce)
import Foreign
import Foreign.C.String (CString, newCString, peekCString)
import Foreign.C.Types
import Foreign.C.Error (Errno(..), getErrno)
#if defined(linux_HOST_OS)
import Foreign.C.Error (eIO)
import System.IO.Error (catchIOError)
#endif
import System.IO (Handle)
import System.IO.Error (mkIOError, eofErrorType)
import System.Posix.IO.ByteString (fdToHandle, fdReadBuf)
import System.Posix.Types
import System.Process (ProcessHandle)
import System.Process.Internals (mkProcessHandle)
import qualified System.Posix.Terminal as T
import System.Posix.Terminal hiding
  ( getTerminalAttributes
  , setTerminalAttributes
  , sendBreak
  , drainOutput
  , discardData
  , controlFlow
  , getTerminalProcessGroupID
  , setTerminalProcessGroupID
  , queryTerminal
  , getTerminalName
  , openPseudoTerminal
  , getSlaveTerminalName)
-- | A pseudo-terminal: a file descriptor together with a 'Handle' created
-- from that same descriptor by 'fdToHandle' (see 'createPty' and
-- 'spawnWithPty'). Reads and terminal operations in this module use the raw
-- 'Fd'; 'writePty' goes through the 'Handle'.
data Pty = Pty !Fd !Handle
-- | Control events decoded from the leading status byte of a pty read (see
-- 'tryReadPty' and 'byteToControlCode'). Each constructor corresponds to one
-- bit of that status byte.
data PtyControlCode = FlushRead      -- ^ bit 'tiocPktFlushRead' (1)
                    | FlushWrite     -- ^ bit 'tiocPktFlushWrite' (2)
                    | OutputStopped  -- ^ bit 'tiocPktStop' (4)
                    | OutputStarted  -- ^ bit 'tiocPktStart' (8)
                    | DoStop         -- ^ bit 'tiocPktDoStop' (32)
                    | NoStop         -- ^ bit 'tiocPktNoStop' (16)
                    deriving (Eq, Read, Show)
-- | Adopt an existing file descriptor as a 'Pty'.
--
-- Yields 'Nothing' when 'T.queryTerminal' reports that the descriptor is not
-- a terminal. Otherwise the descriptor is paired with a fresh 'Handle'
-- obtained from 'fdToHandle'.
createPty :: Fd -> IO (Maybe Pty)
createPty fd = T.queryTerminal fd >>= \isTty ->
  if not isTty
    then return Nothing
    else Just . Pty fd <$> fdToHandle fd
-- | Read at most @n@ bytes from @fd@ with a single 'fdReadBuf' call.
--
-- * @n <= 0@: returns an empty 'ByteString' without touching the descriptor.
-- * A read that returns zero bytes raises an EOF 'IOError'
--   (detectable with 'System.IO.Error.isEOFError').
-- * On Linux, an I/O error whose errno is @EIO@ is reported as that same EOF
--   error. Any other error is re-thrown unchanged.
fdReadBS :: Fd -> ByteCount -> IO ByteString
fdReadBS fd n
  | n <= 0 = return BS.empty
  | otherwise = BSI.createAndTrim (fromIntegral n) fill
  where
    -- Fill the buffer allocated by createAndTrim and report how many bytes
    -- were actually written, so the result is trimmed to that length.
    fill buf = do
      rc <- wrap (fdReadBuf fd buf n)
      if rc == 0
        then eof
        else return (fromIntegral rc)

    wrap :: IO a -> IO a
#if defined(linux_HOST_OS)
    -- NOTE(review): presumably EIO is how Linux reports a pty whose other
    -- side has gone away; it is mapped to EOF here. The errno is re-read
    -- via getErrno after the exception was raised.
    wrap action = catchIOError action $ \ioE -> do
      errno <- getErrno
      if errno == eIO then eof else ioError ioE
#else
    wrap = id
#endif

    -- The error deliberately carries no Handle. Building one here with
    -- fdToHandle would create a second Handle over the pty's descriptor;
    -- GHC closes a Handle's descriptor when that Handle is garbage-collected,
    -- which would close the fd out from under the Pty's own Handle.
    eof :: IO a
    eof = ioError $ mkIOError eofErrorType "eof" Nothing Nothing
-- | Perform one read of up to 1024 bytes from the pty.
--
-- The first byte of every read is a status byte:
--
-- * @0@: the remaining bytes are ordinary output, returned as 'Right'.
-- * Any other value: decoded with 'byteToControlCode' and returned as 'Left'.
--   Nothing may follow it; trailing bytes raise a 'userError'.
--
-- An empty read also raises a 'userError'. Errors from 'fdReadBS'
-- (including EOF) propagate unchanged.
tryReadPty :: Pty -> IO (Either [PtyControlCode] ByteString)
tryReadPty (Pty fd _) = classify . BS.uncons =<< fdReadBS fd 1024
  where
    classify Nothing = ioError can'tHappen
    classify (Just (status, payload))
      | status == 0 = return (Right payload)
      | not (BS.null payload) = ioError can'tHappen
      | otherwise = return (Left (byteToControlCode status))
    can'tHappen = userError "Uh-oh! Something different went horribly wrong!"
-- | Read from the pty until a data packet arrives.
--
-- Control-code packets are discarded and the read is retried. Errors from
-- 'tryReadPty' (including EOF) propagate to the caller.
readPty :: Pty -> IO ByteString
readPty pty = tryReadPty pty >>= either (const (readPty pty)) return
-- | Send bytes to the pty by writing them to its 'Handle' with 'BS.hPut'.
writePty :: Pty -> ByteString -> IO ()
writePty (Pty _ hnd) bytes = BS.hPut hnd bytes
-- | Change the pty's dimensions via the C helper 'set_pty_size'.
--
-- The pair's components are passed through in the same order. The helper's
-- result is checked by 'throwCErrorOnMinus1' with the message
-- @"unable to set pty dimensions"@.
resizePty :: Pty -> (Int, Int) -> IO ()
resizePty (Pty fd _) (x, y) = do
  status <- set_pty_size fd x y
  throwCErrorOnMinus1 "unable to set pty dimensions" status
-- | Query the pty's dimensions via the C helper 'get_pty_size'.
--
-- The helper's result is checked by 'throwCErrorOnMinus1' with the message
-- @"unable to get pty size"@. The two values it stored are then returned as
-- a pair, in argument order.
ptyDimensions :: Pty -> IO (Int, Int)
ptyDimensions (Pty fd _) =
  alloca $ \xPtr ->
    alloca $ \yPtr -> do
      status <- get_pty_size fd xPtr yPtr
      throwCErrorOnMinus1 "unable to get pty size" status
      xVal <- peek xPtr
      yVal <- peek yPtr
      return (xVal, yVal)
-- | Start a program attached to a new pty and return the pty together with
-- a 'ProcessHandle' for the child.
--
-- * Environment: each @(key, value)@ pair is passed as @"key=value"@.
--   'Nothing' passes no explicit environment. Note that @Just []@ ends up
--   identical to 'Nothing', because 'forkExecWithPty' sends a null pointer
--   for an empty list.
-- * Search flag: passed to the C helper as a 'CInt'. Presumably it selects
--   a PATH search for the program; confirm against @fork_exec_with_pty.h@.
-- * Program path: 'forkExecWithPty' also uses it as the first argv entry.
-- * Arguments: the remaining argv entries.
-- * Dimensions: the initial size, forwarded unchanged to the C helper.
--
-- A failing C call is reported through 'throwCErrorOnMinus1' with the message
-- @"unable to fork or open new pty"@. The temporary C strings are released
-- even if the call throws.
spawnWithPty :: Maybe [(String, String)]
             -> Bool
             -> FilePath
             -> [String]
             -> (Int, Int)
             -> IO (Pty, ProcessHandle)
spawnWithPty env' search path' argv' (x, y) = do
  (ptyFd, cpid) <-
    bracket allocStrings freeStrings $ \(path, argv, env) ->
      forkExecWithPty x y path (fromBool search) argv env
  throwCErrorOnMinus1 "unable to fork or open new pty" ptyFd
  hnd <- fdToHandle ptyFd
  -- On POSIX, PHANDLE is CPid; convert numerically rather than relying on
  -- CInt and CPid sharing a runtime representation (the old unsafeCoerce).
  ph <- mkProcessHandle (fromIntegral cpid) False
  return (Pty ptyFd hnd, ph)
  where
    fuse (key, val) = newCString (key ++ "=" ++ val)

    -- Marshal the path, arguments and environment into malloc'd C strings.
    allocStrings = do
      path <- newCString path'
      argv <- mapM newCString argv'
      env <- maybe (return []) (mapM fuse) env'
      return (path, argv, env)

    -- Release every string produced by allocStrings.
    freeStrings (path, argv, env) = mapM_ free (path : env ++ argv)
-- | Project out the raw file descriptor stored in a 'Pty'.
getFd :: Pty -> Fd
getFd pty = case pty of
  Pty fd _ -> fd
-- | Raise an 'IOError' when a C call signalled failure by returning -1.
--
-- The error is a 'userError' whose message is the given context string,
-- followed by @": "@ and the 'strerror' text for the current errno. Any other
-- value is ignored. (This previously compared against 1, which let real
-- failures pass silently and rejected legitimate results such as fd 1.)
throwCErrorOnMinus1 :: (Eq a, Num a) => String -> a -> IO ()
throwCErrorOnMinus1 s i = when (i == -1) $ do
  -- Capture errno before doing anything else that could disturb it.
  Errno code <- getErrno
  errnoMsg <- peekCString (strerror code)
  ioError . userError $ s ++ ": " ++ errnoMsg
-- | Marshal arguments for, and call, the C helper 'fork_exec_with_pty'.
--
-- * @argv@ is built as a NULL-terminated array of @path@ followed by the
--   given arguments, so @path@ also serves as @argv[0]@.
-- * The environment becomes a NULL-terminated array, or a null pointer when
--   the list is empty.
--
-- Returns the 'Fd' produced by the C call and the 'CInt' it stored through
-- the pid pointer. NOTE(review): when the call fails, that stored value may
-- be meaningless; callers check the 'Fd' first.
--
-- Only the two arrays are freed here. The 'CString's themselves remain
-- owned by the caller.
forkExecWithPty :: Int
                -> Int
                -> CString
                -> CInt
                -> [CString]
                -> [CString]
                -> IO (Fd, CInt)
forkExecWithPty x y path search argv' env' = do
  argvArr <- newArray0 nullPtr (path : argv')
  envArr <- if null env' then return nullPtr else newArray0 nullPtr env'
  alloca $ \pidSlot -> do
    fdOrErr <- fork_exec_with_pty x y search path argvArr envArr pidSlot
    mapM_ free [argvArr, envArr]
    (,) fdOrErr <$> peek pidSlot
-- | Decode a status byte into every control code whose bit is set in it.
--
-- Codes are listed in a fixed order: FlushRead, FlushWrite, OutputStopped,
-- OutputStarted, DoStop, NoStop. A byte with none of these bits set yields
-- @[]@.
byteToControlCode :: Word8 -> [PtyControlCode]
byteToControlCode i =
  [ code | (mask, code) <- codeMapping, mask .&. i /= 0 ]
  where
    codeMapping :: [(Word8, PtyControlCode)]
    codeMapping =
      [ (tiocPktFlushRead, FlushRead)
      , (tiocPktFlushWrite, FlushWrite)
      , (tiocPktStop, OutputStopped)
      , (tiocPktStart, OutputStarted)
      , (tiocPktDoStop, DoStop)
      , (tiocPktNoStop, NoStop)
      ]
-- Bit masks for the status byte decoded by 'byteToControlCode'. The names
-- indicate they are meant to mirror the TIOCPKT_* packet-mode bits; confirm
-- against the target platform's tty headers if they ever change.

-- | Status bit reported as 'FlushRead'.
tiocPktFlushRead :: Word8
tiocPktFlushRead = 0x01

-- | Status bit reported as 'FlushWrite'.
tiocPktFlushWrite :: Word8
tiocPktFlushWrite = 0x02

-- | Status bit reported as 'OutputStopped'.
tiocPktStop :: Word8
tiocPktStop = 0x04

-- | Status bit reported as 'OutputStarted'.
tiocPktStart :: Word8
tiocPktStart = 0x08

-- | Status bit reported as 'DoStop'.
tiocPktDoStop :: Word8
tiocPktDoStop = 0x20

-- | Status bit reported as 'NoStop'.
tiocPktNoStop :: Word8
tiocPktNoStop = 0x10
-- | C @strerror@, imported as a pure function returning a 'CString' that
-- points into storage owned by the C library. 'throwCErrorOnMinus1' copies
-- the text out immediately with 'peekCString'.
-- NOTE(review): strerror is not required to be thread-safe, and this import
-- is both @unsafe@ and pure. Consider @strerror_r@ or
-- 'Foreign.C.Error.errnoToIOError' if concurrent failures matter.
foreign import ccall unsafe "string.h"
  strerror :: CInt -> CString
-- | C helper (declared in pty_size.h, not visible here) used by 'resizePty'
-- to set the pty's dimensions. 'resizePty' passes its result to
-- 'throwCErrorOnMinus1'.
-- NOTE(review): Haskell 'Int' is passed here; confirm the C prototype's
-- parameter types match.
foreign import ccall "pty_size.h"
  set_pty_size :: Fd -> Int -> Int -> IO CInt
-- | C helper (declared in pty_size.h, not visible here) used by
-- 'ptyDimensions'. It stores the two dimensions through the given pointers,
-- which 'ptyDimensions' reads back after checking the result with
-- 'throwCErrorOnMinus1'.
-- NOTE(review): the pointers are @Ptr Int@ (a Haskell Int slot). If the C
-- prototype uses @int *@, only part of each slot would be written; confirm.
foreign import ccall "pty_size.h"
  get_pty_size :: Fd -> Ptr Int -> Ptr Int -> IO CInt
-- | C helper (declared in fork_exec_with_pty.h, not visible here) called by
-- 'forkExecWithPty' as @fork_exec_with_pty x y search path argv env pidPtr@.
--
-- * @argv@ is a NULL-terminated array whose first entry is @path@.
-- * @env@ is either a NULL-terminated array or a null pointer.
--
-- The returned 'Fd' becomes the 'Pty' in 'spawnWithPty', which treats -1 as
-- failure via 'throwCErrorOnMinus1'. A child identifier is read back from
-- @pidPtr@ after the call. Imported as a safe call (no @unsafe@ keyword).
foreign import ccall "fork_exec_with_pty.h"
  fork_exec_with_pty :: Int
                     -> Int
                     -> CInt
                     -> CString
                     -> Ptr CString
                     -> Ptr CString
                     -> Ptr CInt
                     -> IO Fd
-- | 'T.getTerminalAttributes' applied to the pty's file descriptor.
getTerminalAttributes :: Pty -> IO TerminalAttributes
getTerminalAttributes pty = T.getTerminalAttributes (getFd pty)
-- | 'T.setTerminalAttributes' applied to the pty's file descriptor, with the
-- attributes and the 'TerminalState' (when the change takes effect) passed
-- through unchanged.
setTerminalAttributes :: Pty -> TerminalAttributes -> TerminalState -> IO ()
setTerminalAttributes pty attrs tstate =
  T.setTerminalAttributes (getFd pty) attrs tstate
-- | 'T.sendBreak' applied to the pty's file descriptor; the 'Int' duration
-- argument is passed through unchanged.
sendBreak :: Pty -> Int -> IO ()
sendBreak pty duration = T.sendBreak (getFd pty) duration
-- | 'T.drainOutput' applied to the pty's file descriptor.
drainOutput :: Pty -> IO ()
drainOutput pty = T.drainOutput (getFd pty)
-- | 'T.discardData' applied to the pty's file descriptor, for the queue(s)
-- chosen by the 'QueueSelector'.
discardData :: Pty -> QueueSelector -> IO ()
discardData pty queue = T.discardData (getFd pty) queue
-- | 'T.controlFlow' applied to the pty's file descriptor with the given
-- 'FlowAction'.
controlFlow :: Pty -> FlowAction -> IO ()
controlFlow pty action = T.controlFlow (getFd pty) action
-- | 'T.getTerminalProcessGroupID' applied to the pty's file descriptor.
getTerminalProcessGroupID :: Pty -> IO ProcessGroupID
getTerminalProcessGroupID pty = T.getTerminalProcessGroupID (getFd pty)
-- | 'T.getTerminalName' applied to the pty's file descriptor.
getTerminalName :: Pty -> IO FilePath
getTerminalName pty = T.getTerminalName (getFd pty)
-- | 'T.getSlaveTerminalName' applied to the pty's file descriptor.
getSlaveTerminalName :: Pty -> IO FilePath
getSlaveTerminalName pty = T.getSlaveTerminalName (getFd pty)