-- SPDX-FileCopyrightText: 2021 Serokell <https://serokell.io/>
--
-- SPDX-License-Identifier: MPL-2.0

{-# LANGUAGE CPP #-}
{-# LANGUAGE InterruptibleFFI #-}

-- | Internal utilities for reading passwords.
module Data.SensitiveBytes.IO.Internal.Password
  ( readPassword
  ) where

import Control.Exception.Safe (MonadMask, bracket)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Text (Text)
import qualified Data.Text.IO as T
import Foreign.C.Error (eILSEQ, getErrno)
import Foreign.C.Types (CInt (..))
import Foreign.Ptr (Ptr)
import System.IO (Handle, hFlush)

#if defined(mingw32_HOST_OS)
#else
import Data.Coerce (coerce)
import qualified GHC.IO.FD as FD
import qualified GHC.IO.Handle.FD as HFD
import System.Posix.IO (handleToFd)
import System.Posix.Types (Fd (Fd))
import qualified System.Posix.Terminal as Term
#endif


-- | Raw binding to the C routine @readline_max@, which lives outside this
-- Haskell module. 'readLineMax' hands it a file descriptor (@0@ on Windows),
-- the target buffer and a length limit; 'readPassword' treats a non-negative
-- result as success and @-1@ as a failure described by @errno@.
--
-- The call is marked @interruptible@: if an asynchronous exception is thrown
-- to the calling thread while it is blocked here, the RTS attempts to
-- interrupt the call.
foreign import ccall interruptible "readline_max"
  c_readLineMax
    :: CInt    -- ^ File descriptor to read from.
    -> Ptr ()  -- ^ Target buffer.
    -> CInt    -- ^ Length limit (the buffer size passed down by 'readPassword').
    -> IO CInt

-- | A quick wrapper around the C function that turns the Haskell IO
-- 'Handle' into a system-dependent handle/fd.
--
-- On Windows the handle is ignored and @0@ is passed to @c_readLineMax@.
-- Elsewhere the handle's underlying file descriptor is passed. In both
-- cases the 'Handle' itself is left open and usable by the caller.
readLineMax
  :: Handle  -- ^ Input handle (ignored on Windows).
  -> Ptr ()  -- ^ Target buffer.
  -> CInt    -- ^ Length limit, forwarded unchanged to @c_readLineMax@.
  -> IO CInt
#if defined(mingw32_HOST_OS)
readLineMax _ bufPtr maxLength =
  c_readLineMax 0 bufPtr maxLength
#else
readLineMax hIn bufPtr maxLength = do
  -- base's 'HFD.handleToFd' only reads the descriptor out of the handle.
  -- 'System.Posix.IO.handleToFd' would instead put @hIn@ into the closed
  -- state, breaking any later use of it (e.g. 'stdin') by the caller.
  fdIn <- HFD.handleToFd hIn
  c_readLineMax (FD.fdFD fdIn) bufPtr maxLength
#endif


-- | Prompt for and read a password without echoing it.
--
-- The steps are:
--
-- 1. Write the prompt to the output handle.
-- 2. Disable echo on the input handle (see 'withEchoDisabled').
-- 3. Flush the output handle.
-- 4. Read one line of input into the target buffer.
--
-- On success a newline is written to the output handle, because the user's
-- Enter key was not echoed. The non-negative count reported by the C reader
-- is then returned; presumably this is the number of bytes stored in the
-- buffer (TODO confirm against @readline_max@).
--
-- Failures are reported with 'error'. @EILSEQ@ gets a
-- locale/terminal-misconfiguration message; any other failure gets a
-- generic read-error message.
readPassword
  :: Handle  -- ^ Input file handle.
  -> Handle  -- ^ Output file handle.
  -> Text  -- ^ Prompt.
  -> Ptr ()  -- ^ Target buffer.
  -> Int  -- ^ Target buffer size.
  -> IO Int
readPassword hIn hOut prompt bufPtr allocSize = do
  T.hPutStr hOut prompt
  withEchoDisabled hIn $ do
    hFlush hOut  -- need to flush _after_ echo is disabled
    -- TODO: Do we also want to install signal handlers?
    res <- readLineMax hIn bufPtr maxLength
    if res >= 0
      then do
        T.hPutStrLn hOut ""
        pure $ fromIntegral res
      else do
        errno <- getErrno
        -- TODO: Maybe return a Maybe or throw a proper exception?
        case res of
          -1
            | errno == eILSEQ ->
                error "readPassword: locale/terminal misconfiguration"
            | otherwise ->
                error "readPassword: read error"
          _ -> error $ "readPassword: impossible error happened: " <> show res
  where
    -- A plain 'fromIntegral' would silently wrap buffer sizes outside the
    -- 'CInt' range, possibly into a negative length. Capping is safe: a
    -- smaller limit never exceeds the real buffer.
    maxLength :: CInt
    maxLength = fromIntegral (min allocSize (fromIntegral (maxBound :: CInt)))

-- | Run an action with terminal echo off (and then restore it).
--
-- If the input handle is not attached to a terminal, the action runs
-- unchanged. Otherwise echo is switched off and the original terminal
-- attributes are put back via 'bracket', so they are restored even if the
-- action throws. The 'Handle' itself is left open.
withEchoDisabled :: (MonadIO m, MonadMask m) => Handle -> m r -> m r
#if defined(mingw32_HOST_OS)
withEchoDisabled _ = id  -- on Windows our @c_readLineMax@ does not echo anyway
#else
withEchoDisabled hIn act = do
  -- Use base's 'HFD.handleToFd', which only reads the descriptor.
  -- 'System.Posix.IO.handleToFd' would put @hIn@ into the closed state as a
  -- side effect, breaking any later use of the handle by the caller.
  fin <- liftIO $ Fd . FD.fdFD <$> HFD.handleToFd hIn
  isTerminal <- liftIO $ Term.queryTerminal fin
  if not isTerminal
    then act
    else do
      attrs <- liftIO $ Term.getTerminalAttributes fin
      let attrsNoEcho = Term.withoutMode attrs Term.EnableEcho
      bracket
        (liftIO $ Term.setTerminalAttributes fin attrsNoEcho Term.WhenFlushed)
        (\() -> liftIO $ Term.setTerminalAttributes fin attrs Term.Immediately)
        (\() -> act)
#endif