{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE ViewPatterns #-}
module Data.TypeRepMap.Internal where
import Prelude hiding (lookup)
import Control.Monad.ST (ST, runST)
import Control.Monad.Zip (mzip)
import Data.Function (on)
import Data.Kind (Type)
import Data.List (intercalate, nubBy)
import Data.Primitive.Array (Array, MutableArray, indexArray, mapArray', readArray, sizeofArray,
thawArray, unsafeFreezeArray, writeArray)
import Data.Primitive.PrimArray (PrimArray, indexPrimArray, sizeofPrimArray)
import Data.Semigroup (Semigroup (..))
import GHC.Base (Any, Int (..), Int#, (*#), (+#), (<#))
import GHC.Exts (IsList (..), inline, sortWith)
import GHC.Fingerprint (Fingerprint (..))
import GHC.Prim (eqWord#, ltWord#)
import GHC.Word (Word64 (..))
import Type.Reflection (SomeTypeRep (..), TypeRep, Typeable, typeRep, withTypeable)
import Type.Reflection.Unsafe (typeRepFingerprint)
import Unsafe.Coerce (unsafeCoerce)
import qualified Data.Map.Strict as Map
import qualified GHC.Exts as GHC (fromList, toList)
-- | A map whose keys are types: the entry for key @a@ holds a value of type
-- @f a@.
--
-- The four arrays are parallel; index @i@ in each of them describes the same
-- entry ('fromTriples' builds all four from one list of triples and
-- 'toTriples' zips them back together). Entries are stored in the order
-- produced by 'fromSortedList', which is the layout 'cachedBinarySearch'
-- walks.
data TypeRepMap (f :: k -> Type) =
    TypeRepMap
        { fingerprintAs :: {-# UNPACK #-} !(PrimArray Word64)  -- ^ first 'Word64' of each key's 'Fingerprint'
        , fingerprintBs :: {-# UNPACK #-} !(PrimArray Word64)  -- ^ second 'Word64' of each key's 'Fingerprint'
        , trAnys :: {-# UNPACK #-} !(Array Any)  -- ^ values (each an @f a@ coerced to 'Any')
        , trKeys :: {-# UNPACK #-} !(Array Any)  -- ^ keys (each a @TypeRep a@ coerced to 'Any')
        }
-- | Renders only the keys (the stored type representations), in the map's
-- internal storage order; the values need no 'Show' instance.
instance Show (TypeRepMap f) where
    show tm = "TypeRepMap [" ++ renderedKeys ++ "]"
      where
        renderedKeys :: String
        renderedKeys = intercalate ", " (toList (mapArray' keyToString (trKeys tm)))

        -- Each key is a 'TypeRep' hidden behind 'Any'.
        keyToString :: Any -> String
        keyToString k = show (anyToTypeRep k)
-- | '<>' is 'union': for a key present in both maps the left map's value is
-- kept.
instance Semigroup (TypeRepMap f) where
    (<>) :: TypeRepMap f -> TypeRepMap f -> TypeRepMap f
    left <> right = union left right
    {-# INLINE (<>) #-}
-- | The identity is the map with no entries (all four arrays empty).
instance Monoid (TypeRepMap f) where
    mempty = TypeRepMap
        { fingerprintAs = mempty
        , fingerprintBs = mempty
        , trAnys        = mempty
        , trKeys        = mempty
        }
    {-# INLINE mempty #-}

    mappend left right = left <> right
    {-# INLINE mappend #-}
-- | The key fingerprints in storage order, reassembled from the two parallel
-- 'Word64' arrays.
toFingerprints :: TypeRepMap f -> [Fingerprint]
toFingerprints tm = zipWith Fingerprint firstWords secondWords
  where
    firstWords  = GHC.toList (fingerprintAs tm)
    secondWords = GHC.toList (fingerprintBs tm)
-- | A map with no entries (the same value as 'mempty').
empty :: TypeRepMap f
empty = TypeRepMap mempty mempty mempty mempty
{-# INLINE empty #-}
-- | A map holding exactly one entry, keyed by the type @a@.
one :: forall a f . Typeable a => f a -> TypeRepMap f
one = (`insert` empty)
{-# INLINE one #-}
-- | Insert a value for key @a@, replacing any value already stored for @a@.
-- The whole map is rebuilt from its triples.
insert :: forall a f . Typeable a => f a -> TypeRepMap f -> TypeRepMap f
insert x tm = fromTriples (newTriple : withoutOld)
  where
    newFp :: Fingerprint
    newFp = calcFp @a

    -- The value and the key's 'TypeRep' are both stored as 'Any'.
    newTriple :: (Fingerprint, Any, Any)
    newTriple = (newFp, toAny x, unsafeCoerce (typeRep @a))

    withoutOld :: [(Fingerprint, Any, Any)]
    withoutOld = deleteByFst newFp (toTriples tm)
{-# INLINE insert #-}
-- | The kind of a type. Used in signatures such as 'delete' and 'member',
-- where the key @a@ appears in no argument type, to tie the kind of @f@'s
-- argument to the kind of @a@.
type KindOf (a :: k) = k
-- | Remove the entry for key @a@; a map without such an entry is rebuilt
-- with the same contents.
delete :: forall a (f :: KindOf a -> Type) . Typeable a => TypeRepMap f -> TypeRepMap f
delete tm = fromTriples (deleteByFst target (toTriples tm))
  where
    target :: Fingerprint
    target = typeFp @a
{-# INLINE delete #-}
-- | Apply a function to the value stored for key @a@. If there is no such
-- entry the map is returned unchanged. Otherwise only the value array is
-- replaced (by a modified copy); fingerprints and keys are shared.
adjust :: forall a f . Typeable a => (f a -> f a) -> TypeRepMap f -> TypeRepMap f
adjust fun tr =
    maybe tr updateAt (cachedBinarySearch (typeFp @a) (fingerprintAs tr) (fingerprintBs tr))
  where
    updateAt :: Int -> TypeRepMap f
    updateAt i = tr { trAnys = changeAnyArr i (trAnys tr) }

    -- Copy the array, overwrite slot @i@ with the transformed value, freeze.
    changeAnyArr :: Int -> Array Any -> Array Any
    changeAnyArr i arr = runST $ do
        copy <- thawArray arr 0 (sizeofArray arr)
        old <- readArray copy i
        writeArray copy i (toAny (fun (fromAny old)))
        unsafeFreezeArray copy
{-# INLINE adjust #-}
-- | Apply a natural transformation to every stored value. Fingerprints, keys
-- and layout are reused unchanged.
hoist :: (forall x. f x -> g x) -> TypeRepMap f -> TypeRepMap g
hoist nat (TypeRepMap fpA fpB vals ks) = TypeRepMap fpA fpB (mapArray' transform vals) ks
  where
    transform v = toAny (nat (fromAny v))
{-# INLINE hoist #-}
-- | Effectful version of 'hoist'. The effects are sequenced in the map's
-- internal storage order, and the fingerprint and key arrays are reused.
hoistA :: (Applicative t) => (forall x. f x -> t (g x)) -> TypeRepMap f -> t (TypeRepMap g)
hoistA nat (TypeRepMap fpA fpB vals ks) = rebuild <$> traverse (nat . fromAny) (toList vals)
  where
    rebuild results = TypeRepMap fpA fpB (fromList (map toAny results)) ks
{-# INLINE hoistA #-}
-- | Like 'hoist', but the transformation also receives 'Typeable' evidence
-- for each entry's key. The evidence is rebuilt from the 'TypeRep' stored
-- alongside the value.
hoistWithKey :: forall f g. (forall x. Typeable x => f x -> g x) -> TypeRepMap f -> TypeRepMap g
hoistWithKey nat (TypeRepMap fpA fpB vals ks) = TypeRepMap fpA fpB (mapArray' step (mzip vals ks)) ks
  where
    -- Pair each value with its key and apply 'nat' under that key's evidence.
    step :: (Any, Any) -> Any
    step (v, k) = toAny (applyWith (unsafeCoerce k) (fromAny v))

    applyWith :: forall x. TypeRep x -> f x -> g x
    applyWith rep = withTypeable rep nat
{-# INLINE hoistWithKey #-}
-- | Union of two maps. For a key present in both, the values are combined
-- with @f@ (left value first) and the left map's key is kept.
unionWith :: (forall x. f x -> f x -> f x) -> TypeRepMap f -> TypeRepMap f -> TypeRepMap f
unionWith f m1 m2 = fromTriples (flatten merged)
  where
    merged :: Map.Map Fingerprint (Any, Any)
    merged = Map.unionWith combine (index m1) (index m2)

    -- Fingerprint -> (value, key) view of a map.
    index :: TypeRepMap f -> Map.Map Fingerprint (Any, Any)
    index tm = Map.fromList [ (fp, (v, k)) | (fp, v, k) <- toTriples tm ]

    flatten :: Map.Map Fingerprint (Any, Any) -> [(Fingerprint, Any, Any)]
    flatten m = [ (fp, v, k) | (fp, (v, k)) <- Map.toList m ]

    combine :: (Any, Any) -> (Any, Any) -> (Any, Any)
    combine (lv, lk) (rv, _) = (toAny (f (fromAny lv) (fromAny rv)), lk)
{-# INLINE unionWith #-}
-- | Left-biased union: on a key clash the left map's value wins.
union :: TypeRepMap f -> TypeRepMap f -> TypeRepMap f
union l r = unionWith (\x _ -> x) l r
{-# INLINE union #-}
-- | Whether the map has an entry for key @a@.
member :: forall a (f :: KindOf a -> Type) . Typeable a => TypeRepMap f -> Bool
member tm = maybe False (const True) (lookup @a tm)
{-# INLINE member #-}
-- | The value stored for key @a@, if any, located by 'cachedBinarySearch'
-- on @a@'s fingerprint.
lookup :: forall a f . Typeable a => TypeRepMap f -> Maybe (f a)
lookup tVect = do
    i <- cachedBinarySearch (typeFp @a) (fingerprintAs tVect) (fingerprintBs tVect)
    pure (fromAny (indexArray (trAnys tVect) i))
{-# INLINE lookup #-}
-- | Number of entries (the length of the fingerprint arrays).
size :: TypeRepMap f -> Int
size tm = sizeofPrimArray (fingerprintAs tm)
{-# INLINE size #-}
-- | The keys of the map, in internal storage order.
keys :: TypeRepMap f -> [SomeTypeRep]
keys tm = [ SomeTypeRep (anyToTypeRep k) | k <- toList (trKeys tm) ]
{-# INLINE keys #-}
-- | Find the index of a fingerprint in the parallel arrays @fpAs@ / @fpBs@,
-- or 'Nothing' if it is absent.
--
-- The arrays are expected in the layout written by 'fromSortedList':
-- everything smaller than the entry at index @i@ lives in the subtree rooted
-- at @2*i+1@, everything larger in the subtree rooted at @2*i+2@.
-- Fingerprints compare lexicographically: the first words (@a@ against
-- @fpAs@) first, then the second words (@b@ against @fpBs@). The loop works
-- directly on unboxed 'Int#' / 'Word#' values.
cachedBinarySearch :: Fingerprint -> PrimArray Word64 -> PrimArray Word64 -> Maybe Int
cachedBinarySearch (Fingerprint (W64# a) (W64# b)) fpAs fpBs = inline (go 0#)
  where
    -- Descend from the root (index 0); stepping past the end means "absent".
    go :: Int# -> Maybe Int
    go i = case i <# len of
        0# -> Nothing
        _  -> let !(W64# valA) = indexPrimArray fpAs (I# i) in case a `ltWord#` valA of
            -- a >= valA
            0# -> case a `eqWord#` valA of
                -- a > valA: descend right
                0# -> go (2# *# i +# 2#)
                -- first words equal: decide on the second words
                _ -> let !(W64# valB) = indexPrimArray fpBs (I# i) in case b `eqWord#` valB of
                    0# -> case b `ltWord#` valB of
                        -- b > valB: descend right
                        0# -> go (2# *# i +# 2#)
                        -- b < valB: descend left
                        _  -> go (2# *# i +# 1#)
                    -- both words equal: found
                    _ -> Just (I# i)
            -- a < valA: descend left
            _ -> go (2# *# i +# 1#)

    -- Number of entries, unboxed.
    len :: Int#
    len = let !(I# l) = sizeofPrimArray fpAs in l
{-# INLINE cachedBinarySearch #-}
-- | Erase the type of a stored value. Unchecked: only 'fromAny' at the same
-- type gets it back.
toAny :: f a -> Any
toAny v = unsafeCoerce v
-- | Inverse of 'toAny'. Unchecked: the caller must ask for the type that was
-- actually stored.
fromAny :: Any -> f a
fromAny v = unsafeCoerce v
-- | Recover a 'TypeRep' that was stored as 'Any' (as in 'trKeys').
-- Unchecked.
anyToTypeRep :: Any -> TypeRep f
anyToTypeRep v = unsafeCoerce v
-- | The fingerprint of type @a@'s 'TypeRep'; used as the lookup key.
typeFp :: forall a . Typeable a => Fingerprint
typeFp = typeRepFingerprint (typeRep :: TypeRep a)
{-# INLINE typeFp #-}
-- | Unpack the parallel arrays into (fingerprint, value, key) triples, in
-- storage order.
toTriples :: TypeRepMap f -> [(Fingerprint, Any, Any)]
toTriples tm = zip3 fps vals ks
  where
    fps  = toFingerprints tm
    vals = GHC.toList (trAnys tm)
    ks   = GHC.toList (trKeys tm)
-- | Drop every triple whose first component equals the given value.
deleteByFst :: Eq a => a -> [(a, b, c)] -> [(a, b, c)]
deleteByFst x = filter (\(a, _, _) -> a /= x)
-- | Keep only the first triple for each distinct first component,
-- preserving order.
nubByFst :: (Eq a) => [(a, b, c)] -> [(a, b, c)]
nubByFst = nubBy (\(x, _, _) (y, _, _) -> x == y)
-- | First component of a triple.
fst3 :: (a, b, c) -> a
fst3 t = case t of
    (a, _, _) -> a
-- | An existential wrapper: a value of type @f a@ packed together with the
-- 'Typeable' evidence for @a@. It is the element type of the 'IsList'
-- instance of 'TypeRepMap'.
data WrapTypeable f where
    WrapTypeable :: Typeable a => f a -> WrapTypeable f
-- | Shows the fingerprint of the wrapped value's type; the value itself is
-- not shown.
instance Show (WrapTypeable f) where
    show (WrapTypeable (_ :: f a)) = show (typeRepFingerprint (typeRep @a))
-- | Convert between a 'TypeRepMap' and a list of 'WrapTypeable' entries.
instance IsList (TypeRepMap f) where
    type Item (TypeRepMap f) = WrapTypeable f

    -- | Build a map from a list. When several elements have the same type,
    -- 'fromTriples' keeps the first one in the list.
    fromList :: [WrapTypeable f] -> TypeRepMap f
    fromList = fromTriples . map (\x -> (fp x, an x, k x))
      where
        fp :: WrapTypeable f -> Fingerprint
        fp (WrapTypeable (_ :: f a)) = calcFp @a

        an :: WrapTypeable f -> Any
        an (WrapTypeable x) = toAny x

        k :: WrapTypeable f -> Any
        k (WrapTypeable (_ :: f a)) = unsafeCoerce $ typeRep @a

    -- | All entries, in the map's internal storage order.
    toList :: TypeRepMap f -> [WrapTypeable f]
    toList = map toWrapTypeable . toTriples
      where
        -- The stored value is an @f a@, not a 'WrapTypeable'. It must be
        -- wrapped with the constructor, using the 'Typeable' evidence rebuilt
        -- from the stored 'TypeRep', rather than coerced into the wrapper
        -- directly.
        toWrapTypeable :: (Fingerprint, Any, Any) -> WrapTypeable f
        toWrapTypeable (_, an, k) = wrapWith (unsafeCoerce k) (fromAny an)

        -- The shared @a@ ties the key's 'TypeRep' to the value's type.
        wrapWith :: TypeRep a -> f a -> WrapTypeable f
        wrapWith tr x = withTypeable tr (WrapTypeable x)
-- | The fingerprint of type @a@'s 'TypeRep'.
calcFp :: forall a . Typeable a => Fingerprint
calcFp = let rep = typeRep @a in typeRepFingerprint rep
-- | Build a 'TypeRepMap' from @(fingerprint, value, key)@ triples.
--
-- When several triples share a fingerprint, the first one in the input is
-- kept. The survivors are sorted by fingerprint and laid out in the order
-- produced by 'fromSortedList', the layout 'cachedBinarySearch' expects.
fromTriples :: [(Fingerprint, Any, Any)] -> TypeRepMap f
fromTriples kvs = TypeRepMap (GHC.fromList fpAs) (GHC.fromList fpBs) (GHC.fromList ans) (GHC.fromList ks)
  where
    (fpAs, fpBs) = unzip $ map (\(Fingerprint a b) -> (a, b)) fps

    -- 'sortWith' is stable, so triples with equal fingerprints end up
    -- adjacent and in their input order. Keeping the head of each run
    -- therefore equals "nub, then sort", in O(n log n) instead of the
    -- O(n^2) of 'nubBy'.
    (fps, ans, ks) = unzip3 $ fromSortedList $ dropAdjacentDups $ sortWith fst3 kvs

    -- Keep the first triple of every run of equal fingerprints.
    dropAdjacentDups :: [(Fingerprint, Any, Any)] -> [(Fingerprint, Any, Any)]
    dropAdjacentDups [] = []
    dropAdjacentDups (t@(fp, _, _) : rest) =
        t : dropAdjacentDups (dropWhile ((== fp) . fst3) rest)
-- | Rearrange a sorted list into implicit binary-search-tree order.
--
-- Index @0@ is the root and the children of index @i@ are at @2*i+1@ and
-- @2*i+2@. An in-order walk of those indices, filling each slot with the
-- next element of the sorted input, yields the layout 'cachedBinarySearch'
-- descends through.
fromSortedList :: forall a . [a] -> [a]
fromSortedList sorted = runST build
  where
    count :: Int
    count = length sorted

    source :: Array a
    source = fromListN count sorted

    build :: forall s . ST s [a]
    build = do
        target <- thawArray source 0 count
        -- Fill the subtree rooted at @node@, taking elements from the
        -- sorted source starting at position @next@. Returns the position
        -- of the first element not yet consumed.
        let fill :: Int -> Int -> ST s Int
            fill node next
                | node >= count = pure next
                | otherwise = do
                    afterLeft <- fill (2 * node + 1) next
                    writeArray target node (indexArray source afterLeft)
                    fill (2 * node + 2) (afterLeft + 1)
        _ <- fill 0 0
        toList <$> unsafeFreezeArray target