{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}

module Posix.Directory
  ( getCurrentWorkingDirectory
  ) where

import Control.Monad.Primitive (touch)
import Data.Primitive (ByteArray)
import Foreign.C.Error (Errno, eRANGE, getErrno)
import Foreign.C.Types (CChar, CSize (..))
import Foreign.Ptr (nullPtr)
import GHC.Exts (Ptr (..))

import qualified Data.Primitive as PM
import qualified Foreign.Storable as FS

-- | Raw binding to POSIX @getcwd(3)@. Writes the NUL-terminated path of the
-- current working directory into the supplied buffer of the given size.
-- Returns that buffer on success, or a null pointer on failure (with
-- @errno@ set, e.g. to @ERANGE@ when the buffer is too small). Imported as
-- @safe@ because the call may block.
foreign import ccall safe "getcwd"
  c_getcwd :: Ptr CChar -> CSize -> IO (Ptr CChar)

{- | Get the current working directory without using the system locale
  to convert it to text. This is implemented with a safe FFI call
  since it may block.

  On success the result holds the raw bytes of the path, excluding the
  terminating NUL byte. On failure the @errno@ reported by @getcwd@ is
  returned in 'Left'. If @getcwd@ fails with @ERANGE@ (buffer too small),
  the call is retried with a buffer twice as large.
-}
getCurrentWorkingDirectory :: IO (Either Errno ByteArray)
getCurrentWorkingDirectory = go (4096 - chunkOverhead)
  where
    go :: Int -> IO (Either Errno ByteArray)
    go !sz = do
      -- It may be nice to add a variant of getCurrentWorkingDirectory that
      -- allows the user to supply an initial pinned buffer. I'm not sure
      -- how many other POSIX functions there are that could benefit
      -- from this. Calls to getCurrentWorkingDirectory are extremely rare,
      -- so there would be little benefit here, but there may be other
      -- functions where these repeated 4KB allocations might trigger
      -- GC very quickly.
      marr <- PM.newPinnedByteArray sz
      let !(Ptr addr) = PM.mutableByteArrayContents marr
      ptr <- c_getcwd (Ptr addr) (intToCSize sz)
      -- getcwd writes through a raw address derived from marr, and a safe
      -- FFI call allows the GC to run while it is in progress. The failure
      -- branch below never mentions marr again, so keep it reachable
      -- explicitly until the foreign call has returned.
      touch marr
      if ptr /= nullPtr
        then do
          strSize <- findNullByte ptr
          dst <- PM.newByteArray strSize
          PM.copyMutableByteArray dst 0 marr 0 strSize
          Right <$> PM.unsafeFreezeByteArray dst
        else do
          -- Read errno exactly once and report that same value.
          errno <- getErrno
          if errno == eRANGE
            then go (2 * sz)
            else pure (Left errno)

-- | Number of bytes subtracted from the 4096-byte initial buffer size in
-- 'getCurrentWorkingDirectory': two machine words. Presumably this
-- reserves room for the byte array's heap-object header so the whole
-- object fits in 4096 bytes (TODO confirm).
chunkOverhead :: Int
chunkOverhead = wordBytes + wordBytes
  where
    wordBytes = PM.sizeOf (0 :: Int)

-- | Convert an 'Int' buffer length into the 'CSize' expected by C.
-- No range check is performed: a negative input wraps around like any
-- 'fromIntegral' conversion to an unsigned type.
intToCSize :: Int -> CSize
intToCSize = CSize . fromIntegral

{- | Return the offset of the first NUL byte at or after the given pointer,
  i.e. the length of the C string it points to. The result is nonnegative.

  Precondition: a NUL byte must occur in readable memory starting at the
  pointer. Otherwise the scan runs past the end of the buffer.
-}
findNullByte :: Ptr CChar -> IO Int
findNullByte base = scan 0
  where
    scan :: Int -> IO Int
    scan offset = do
      byte <- FS.peekElemOff base offset
      if byte == 0
        then pure offset
        else scan $! offset + 1