{-# LANGUAGE CApiFFI #-}

module FnMatch (fnmatch, FnMatchFlags (..)) where

import Data.Bits ((.|.))
import Data.ByteString qualified as BS
import Data.ByteString.Unsafe qualified as BU
import Foreign.C
import GHC.IO (unsafePerformIO)

-- | Raw binding to the C function @fnmatch@ declared in @fnmatch.h@.
--
-- Takes two C strings and an integer flags argument and yields a C @int@
-- in 'IO'. The wrapper 'fnmatch' in this module passes the pattern first
-- and the string to test second, and treats a result of @0@ as a match.
foreign import capi "fnmatch.h fnmatch"
  c_fnmatch ::
    CString -> CString -> CInt -> IO CInt

-- | Options accepted by 'fnmatch'.
--
-- Each constructor stands for a single bit of the integer flags argument
-- that is passed to the C function; a list of these values is OR-ed
-- together before the call.
--
-- NOTE(review): the constructor names presumably mirror the C macros
-- @FNM_NOESCAPE@, @FNM_PATHNAME@, @FNM_PERIOD@, @FNM_LEADING_DIR@,
-- @FNM_CASEFOLD@ and @FNM_EXTMATCH@ -- confirm against @fnmatch.h@ on the
-- target platform (the last three are not part of POSIX).
data FnMatchFlags
  = FlagNoEscape
  | FlagPathName
  | FlagPeriod
  | FlagLeadingDir
  | FlagCaseFold
  | FlagExtMatch
  deriving (Eq, Show)

-- | Test whether a byte string matches a shell wildcard pattern by calling
-- the C library's @fnmatch@.
--
-- Arguments, in order: the pattern, the string to test, and the flags to
-- combine (bitwise OR) into the C @flags@ argument. An empty list means no
-- flags (@0@).
--
-- Returns 'True' exactly when the C call returns @0@; every other return
-- value is reported as 'False'.
--
-- Both inputs are copied into NUL-terminated buffers before the call,
-- because a 'BS.ByteString' carries no terminator of its own while the C
-- function reads until it finds one. As a consequence, an input containing
-- an embedded NUL byte is only seen by C up to that byte.
fnmatch :: BS.ByteString -> BS.ByteString -> [FnMatchFlags] -> Bool
fnmatch pat str flags =
  -- The call only reads the freshly copied arguments and hands back a status
  -- code, which is why it is exposed as a pure function.
  -- NOTE(review): the C function's result may also depend on the process
  -- locale; this assumes the locale is not changed while the program runs.
  unsafePerformIO $
    BS.useAsCString pat $ \cPat ->
      BS.useAsCString str $ \cStr ->
        (== 0) <$> c_fnmatch cPat cStr cFlags
  where
    -- All requested flags folded into one bit mask.
    cFlags :: CInt
    cFlags = foldr ((.|.) . flagToCInt) 0 flags

    -- The three POSIX flags come from the header, because their bit values
    -- are not the same on every libc.
    flagToCInt :: FnMatchFlags -> CInt
    flagToCInt FlagNoEscape = c_fnmNoEscape
    flagToCInt FlagPathName = c_fnmPathName
    flagToCInt FlagPeriod = c_fnmPeriod
    -- NOTE(review): the remaining flags are non-POSIX extensions that not
    -- every fnmatch.h declares, so their values stay hard-coded as before;
    -- verify them against the target platform's header.
    flagToCInt FlagLeadingDir = 8
    flagToCInt FlagCaseFold = 16
    flagToCInt FlagExtMatch = 32

-- | Value of the C macro @FNM_NOESCAPE@ as defined by the platform header.
foreign import capi "fnmatch.h value FNM_NOESCAPE"
  c_fnmNoEscape :: CInt

-- | Value of the C macro @FNM_PATHNAME@ as defined by the platform header.
foreign import capi "fnmatch.h value FNM_PATHNAME"
  c_fnmPathName :: CInt

-- | Value of the C macro @FNM_PERIOD@ as defined by the platform header.
foreign import capi "fnmatch.h value FNM_PERIOD"
  c_fnmPeriod :: CInt