{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE UnboxedTuples #-}

-- |
-- Module      : Network.DNS.Internal
-- Description : Internal DNS types and definitions
--
-- This module is not part of public API and may change even between patch versions.
module Network.DNS.Internal
  ( DomainLabel(..)
  , Domain(..)
  , DList(..)
  , toDList
  , fromDList
  , singleton
  , sbsMap
  , sbsSingleton
  , isLitChar
  , pprLabelsUtf16
  ) where

import Data.ByteString.Internal (c2w)
import qualified Data.ByteString.Short as SBS
import Data.Foldable (foldl')
import Data.Function (on)
import qualified Data.Text as T
import GHC.Word

import Data.ByteString.Short.Internal (ShortByteString(SBS))
import qualified Data.Text.Array as T
import qualified Data.Text.Internal as T
import GHC.ST (ST(..), runST)

-- NOTE(review): both CPP branches currently import the same names; the split
-- is kept so the two bytestring configurations can diverge independently.
#if !MIN_VERSION_bytestring(0,11,3)
import GHC.Exts (Int(..), Int#, MutableByteArray#, indexWord8Array#, isTrue#, newByteArray#, unsafeFreezeByteArray#, writeWord8Array#, (+#), (<#))
#else
import GHC.Exts (Int(..), Int#, MutableByteArray#, indexWord8Array#, isTrue#, newByteArray#, unsafeFreezeByteArray#, writeWord8Array#, (+#), (<#))
#endif

import Network.DNS.Internal.Prim

-- | Whether a byte is an ASCII letter (@a-z@, @A-Z@), an ASCII digit (@0-9@),
-- an underscore or a hyphen.
isLitChar :: Word8 -> Bool
isLitChar c =
     (c >= c2w 'a' && c <= c2w 'z')
  || (c >= c2w '0' && c <= c2w '9')
  || (c >= c2w 'A' && c <= c2w 'Z')
  || (c == c2w '_')
  || (c == c2w '-')

-- | Domain label with case-insensitive 'Eq' and 'Ord' as per [RFC4343](https://datatracker.ietf.org/doc/html/rfc4343#section-3).
data DomainLabel = DomainLabel
  { getDomainLabel_   :: !SBS.ShortByteString
    -- ^ The label bytes.
  , getDomainLabelCF_ :: !SBS.ShortByteString
    -- ^ Companion form of the label; presumably the case-folded bytes used
    -- for comparisons (the field is not populated in this module).
  }

-- | A domain parsed into labels. Each label is a 'SBS.ShortByteString' rather than 'T.Text' or 'String' because a label can contain arbitrary bytes.
-- However, the 'Ord' and 'Eq' instances do limited case-folding according to [RFC4343](https://datatracker.ietf.org/doc/html/rfc4343#section-3).
newtype Domain = Domain [DomainLabel] deriving (Eq, Ord)

-- | Ordering looks only at 'getDomainLabelCF_'.
instance Ord DomainLabel where
  (<=) = (<=) `on` getDomainLabelCF_
  compare = compare `on` getDomainLabelCF_

-- | Equality looks only at 'getDomainLabelCF_'.
instance Eq DomainLabel where
  (==) = (==) `on` getDomainLabelCF_

-- | Difference list à la Hughes
newtype DList a = DList ([a] -> [a])

-- | Turn a list into 'DList'
{-# INLINE toDList #-}
toDList :: [a] -> DList a
toDList = DList . (++)

-- | Turn 'DList' back into a list.
{-# INLINE fromDList #-}
fromDList :: DList a -> [a]
fromDList (DList dl) = dl []

-- | Create a 'DList' containing just the specified element
{-# INLINE singleton #-}
singleton :: a -> DList a
singleton = DList . (:)

-- | Appending two 'DList's composes the underlying functions.
instance Semigroup (DList a) where
  {-# INLINE (<>) #-}
  DList l <> DList r = DList (l . r)

-- | The empty 'DList' is the identity function.
instance Monoid (DList a) where
  {-# INLINE mempty #-}
  mempty = DList id

-- | Build a one-byte 'SBS.ShortByteString'. Uses 'SBS.singleton' when the
-- bytestring package provides it, and a hand-written primop version otherwise.
{-# INLINE sbsSingleton #-}
sbsSingleton :: Word8 -> SBS.ShortByteString
#if MIN_VERSION_bytestring(0,11,3)
sbsSingleton = SBS.singleton
#else
sbsSingleton (W8# w) = runST $ ST $ \s1 ->
  case newByteArray# 1# s1 of
    (# s2, mba #) -> case writeWord8Array# mba 0# w s2 of
      s3 -> case unsafeFreezeByteArray# mba s3 of
        (# s4, ma #) -> (# s4, SBS ma #)
#endif

-- | Apply a function to every byte of a 'SBS.ShortByteString'. Uses 'SBS.map'
-- when the bytestring package provides it, and a hand-written loop otherwise.
sbsMap :: (Word8 -> Word8) -> SBS.ShortByteString -> SBS.ShortByteString
#if MIN_VERSION_bytestring(0,11,3)
sbsMap = SBS.map
#else
sbsMap m sbs@(SBS ba) = runST $ ST $ \s1 ->
  case newByteArray# l# s1 of
    (# s2, mba #) -> case go mba 0# l# of
      ST f -> case f s2 of
        (# s3, _ #) -> case unsafeFreezeByteArray# mba s3 of
          (# s4, ma #) -> (# s4, SBS ma #)
  where
    !(I# l#) = SBS.length sbs

    -- Copy bytes i..l-1 from the source array into mba, applying m to each.
    go :: MutableByteArray# s -> Int# -> Int# -> ST s ()
    go !mba !i !l
      | I# i >= I# l = return ()
      | otherwise = (ST $ \s ->
          let !(W8# w') = m (W8# (indexWord8Array# ba i))
          in (# writeWord8Array# mba i w' s, () #)
        ) >> go mba (i +# 1#) l
#endif

-- | Pretty-print a list of labels as 'T.Text': every label is escaped by
-- 'labelWriterUtf16' and followed by a @.@. The empty list prints as @"."@.
--
-- The buffer is handed to 'T.text' wrapped in 'T.Array' with a length counted
-- in two-byte units, i.e. this assumes a text version whose internal
-- representation is UTF-16 -- TODO confirm against the package bounds.
pprLabelsUtf16 :: [SBS.ShortByteString] -> T.Text
pprLabelsUtf16 xs@(_:_) =
  -- The bang only makes explicit that binding the unlifted array is strict.
  let !(SBS ba) = createSBS (codePoints * 2) (go xs 0#)
  in T.text (T.Array ba) 0 codePoints
  where
    -- We adjust for an extra codepoint to account per label for the dot separators.
    -- The case of root zone domain names (empty list, pretty-prints to ".") is handled
    -- by the separate equation for the empty list below.
    codePoints = foldl' (\a x -> a + pprLabelCodepoints x + 1) 0 xs

    -- Write each label followed by a dot; offsets are byte offsets and each
    -- written unit advances the offset by 2.
    go (a:as) off mba = do
      I# off' <- labelWriterUtf16 a off mba
      writeWord8Array0 mba off' (c2w '.')
      go as (off' +# 2#) mba
    go [] _off _mba = pure ()
pprLabelsUtf16 [] = T.pack "."

-- | Encoding of the root domain (empty label list): the single byte @.@.
--
-- Only the empty list is supported; any other input is a programming error
-- and fails with an explicit message instead of a pattern-match failure.
domainEncoderUtf16 :: [a] -> ShortByteString
domainEncoderUtf16 [] = SBS.pack [c2w '.']
domainEncoderUtf16 (_:_) =
  error "Network.DNS.Internal.domainEncoderUtf16: only defined for the empty label list"

-- | Allocate a mutable byte array via 'newByteArray' with the given size, run
-- the fill action on it, freeze it and wrap it as a 'SBS.ShortByteString'.
-- The fill action's result is discarded.
createSBS :: Int -> (forall s. MBA s -> ST s a) -> SBS.ShortByteString
createSBS len fill = runST $ do
  mba <- newByteArray len
  _ <- fill mba
  BA# ba# <- unsafeFreezeByteArray mba
  pure (SBS ba#)

-- | Escape a single label into a freshly allocated buffer of
-- @2 * 'pprLabelCodepoints'@ bytes, written by 'labelWriterUtf16'.
pprLabelUtf16 :: SBS.ShortByteString -> SBS.ShortByteString
pprLabelUtf16 bs = createSBS (pprLabelCodepoints bs * 2) (labelWriterUtf16 bs 0#)

-- | Number of two-byte units 'labelWriterUtf16' emits for a label.
--
-- This must agree exactly with 'labelWriterUtf16', because the result is used
-- to size the output buffer: one unit for a literal byte ('isLitChar'), two
-- for an escaped backslash or dot, and four for any other byte (a backslash
-- plus three octal digits).
pprLabelCodepoints :: SBS.ShortByteString -> Int
pprLabelCodepoints sbs@(SBS ba) = go 0# 0#
  where
    !(I# len) = SBS.length sbs

    -- The input index and the output count are tracked separately: the index
    -- always advances by one byte, the count by that byte's escaped width.
    go :: Int# -> Int# -> Int
    go i# n#
      | isTrue# (i# <# len) =
          let a = W8# (indexWord8Array# ba i#)
          in go (i# +# 1#) (n# +# escapedWidth a)
      | otherwise = I# n#

    escapedWidth :: Word8 -> Int#
    escapedWidth a
      | isLitChar a = 1#
      | a == c2w '\\' = 2#
      | a == c2w '.' = 2#
      | otherwise = 4#

-- Create an ST action to write to a mutable byte array, starting at some offset. Returns
-- the new offset where we can resume writing to.
--
-- Offsets are byte offsets; every emitted unit advances the offset by 2.
-- 'writeWord8Array0' comes from "Network.DNS.Internal.Prim" and is not visible
-- here; presumably it stores the byte followed by a zero byte -- TODO confirm.
labelWriterUtf16 :: SBS.ShortByteString -> Int# -> MBA s -> ST s Int
labelWriterUtf16 sbs@(SBS ba) off0 mba = go 0# off0
  where
    !(I# len) = SBS.length sbs

    go i# off
      | isTrue# (i# <# len) =
          let a = W8# (indexWord8Array# ba i#)
          in case () of
               _ | isLitChar a -> do
                     -- Literal bytes are copied through unchanged.
                     writeWord8Array0 mba off a
                     go (i# +# 1#) (off +# 2#)
                 | a == c2w '\\' -> do
                     -- A backslash is escaped as two backslashes.
                     writeWord8Array0 mba off (c2w '\\')
                     writeWord8Array0 mba (off +# 2#) (c2w '\\')
                     go (i# +# 1#) (off +# 4#)
                 | a == c2w '.' -> do
                     -- A dot inside a label is escaped as backslash-dot.
                     writeWord8Array0 mba off (c2w '\\')
                     writeWord8Array0 mba (off +# 2#) (c2w '.')
                     go (i# +# 1#) (off +# 4#)
                 | otherwise -> do
                     -- Any other byte becomes a backslash and three octal digits.
                     writeWord8Array0 mba off (c2w '\\')
                     writeWord8Array0 mba (off +# 2#) (c2w '0' + o1)
                     writeWord8Array0 mba (off +# 4#) (c2w '0' + o2)
                     writeWord8Array0 mba (off +# 6#) (c2w '0' + o3)
                     go (i# +# 1#) (off +# 8#)
                 where
                   -- Octal digits of the byte, most significant first.
                   (# o1, o2, o3 #) = case quotRem a 8 of
                     (v1, r3) -> case quotRem v1 8 of
                       (v2, r2) -> case quotRem v2 8 of
                         (_, r1) -> (# r1, r2, r3 #)
      | otherwise = pure (I# off)