{-|
  Copyright   :  (C) 2012-2016, University of Twente
  License     :  BSD2 (see the file LICENSE)
  Maintainer  :  Christiaan Baaij

  Assortment of utility functions used in the Clash library
-}

{-# LANGUAGE CPP               #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MagicHash         #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types        #-}
{-# LANGUAGE TupleSections     #-}

{-# OPTIONS_GHC -fno-warn-orphans #-}

module Clash.Util
  ( module Clash.Util
  , module X
  , makeLenses
  , SrcSpan
  , noSrcSpan
  , HasCallStack
  )
where

import Control.Applicative            as X (Applicative,(<$>),(<*>),pure)
import Control.Arrow                  as X ((***),(&&&),first,second)
import qualified Control.Exception    as Exception
import Control.Monad                  as X ((<=<),(>=>))
import Control.Monad.State            (MonadState,State,StateT,runState)
import qualified Control.Monad.State  as State
import Data.Typeable                  (Typeable)
import Data.Function                  as X (on)
import Data.Hashable                  (Hashable)
import Data.HashMap.Lazy              (HashMap)
import qualified Data.HashMap.Lazy    as HashMapL
import Data.Maybe                     (fromMaybe)
import Data.Text.Prettyprint.Doc
import Data.Text.Prettyprint.Doc.Render.String
import Data.Version                   (Version)
import qualified Data.Time.Format     as Clock
import qualified Data.Time.Clock      as Clock
import Data.Time.Clock                (UTCTime)
import Control.Lens
import Debug.Trace                    (trace)
import GHC.Base                       (Int(..),isTrue#,(==#),(+#))
import GHC.Integer.Logarithms         (integerLogBase#)
import GHC.Stack                      (HasCallStack, callStack, prettyCallStack)
import Type.Reflection                (tyConPackage, typeRepTyCon, typeOf)
import qualified Language.Haskell.TH  as TH
import SrcLoc                         (SrcSpan, noSrcSpan)
import Clash.Unique

#ifdef CABAL
import qualified Paths_clash_lib      (version)
#endif

-- | Exception carrying a source span, a message, and an optional extra
-- string.
data ClashException = ClashException SrcSpan String (Maybe String)

-- | Shows the message followed by a newline and, when present, the extra
-- string. The source span is not shown.
instance Show ClashException where
  show (ClashException _ message extra) = message ++ "\n" ++ fromMaybe "" extra

instance Exception.Exception ClashException

-- | Throw a 'ClashException' reporting a failed assertion at the given file
-- and line.
assertPanic
  :: String
  -- ^ File name
  -> Int
  -- ^ Line number
  -> a
assertPanic file ln = Exception.throw (ClashException noSrcSpan message Nothing)
 where
  message = "ASSERT failed! file " ++ file ++ ", line " ++ show ln

-- | Throw a 'ClashException' reporting a failed assertion; the rendered
-- message is followed by the current call stack. The file and line arguments
-- are ignored.
assertPprPanic
  :: HasCallStack
  => String
  -> Int
  -> Doc ann
  -> a
assertPprPanic _file _line msg =
  pprPanic "ASSERT failed!" (sep [msg, callStackDoc])

-- | Throw a 'ClashException' whose message is the rendered heading followed
-- by the (nested) document.
pprPanic
  :: String
  -- ^ Heading
  -> Doc ann
  -- ^ Message
  -> a
pprPanic heading prettyMsg = pprDebugAndThen panic (pretty heading) prettyMsg
 where
  panic rendered = Exception.throw (ClashException noSrcSpan rendered Nothing)

-- | The current call stack as a document, one line per stack entry.
callStackDoc
  :: HasCallStack
  => Doc ann
callStackDoc = "Call stack:" <+> hang 4 stackLines
 where
  stackLines = vcat . map pretty . lines $ prettyCallStack callStack

-- | Emit a warning through 'trace' and return the final argument.
--
-- The warning is only emitted when 'debugIsOn' holds and the trigger is
-- 'True'; otherwise the value is returned untouched. The trigger is not
-- inspected when 'debugIsOn' is 'False'.
warnPprTrace
  :: HasCallStack
  => Bool
  -- ^ Trigger warning?
  -> String
  -- ^ File name
  -> Int
  -- ^ Line number
  -> Doc ann
  -- ^ Message
  -> a
  -- ^ Pass value (like trace)
  -> a
warnPprTrace triggered file ln msg x
  | debugIsOn && triggered = pprDebugAndThen trace (vcat [heading0, heading1]) msg x
  | otherwise              = x
 where
  heading0 = hsep ["WARNING: file", pretty file <> comma, "line", pretty ln]
  heading1 = "WARNING CALLSTACK:" <> line <> pretty (prettyCallStack callStack)

-- | 'trace' a heading followed by a document, then return the last argument.
pprTrace
  :: String
  -> Doc ann
  -> a
  -> a
pprTrace str = pprDebugAndThen trace (pretty str)

-- | Like 'pprTrace', but only traces when 'debugIsOn' holds.
pprTraceDebug
  :: String
  -> Doc ann
  -> a
  -> a
pprTraceDebug str doc x =
  if debugIsOn
    then pprTrace str doc x
    else x

-- | Render a heading and a (nested) message to a 'String' and hand the result
-- to the given continuation.
pprDebugAndThen
  :: (String -> a)
  -- ^ Continuation receiving the rendered text
  -> Doc ann
  -- ^ Heading
  -> Doc ann
  -- ^ Message
  -> a
pprDebugAndThen cont heading prettyMsg =
  cont . renderString . layoutPretty defaultLayoutOptions $
    sep [heading, nest 2 prettyMsg]

-- | A class that can generate unique numbers
class MonadUnique m where
  -- | Get a new unique
  getUniqueM :: m Int

-- | Returns the current counter and stores its successor as the new state.
instance Monad m => MonadUnique (StateT Int m) where
  getUniqueM = do
    current <- State.get
    State.put (current + 1)
    return current

-- | Create a TH expression that returns a formatted string containing the
-- name of the module 'curLoc' is spliced into, and the line where it was spliced.
curLoc :: TH.Q TH.Exp
curLoc = do
  (TH.Loc _ _ modName (startPosL,_) _) <- TH.location
  TH.litE (TH.StringL $ modName ++ "(" ++ show startPosL ++ "): ")

-- | Cache the result of a monadic action
--
-- On a cache hit the stored value is returned and the action is not run; on a
-- miss the action is run and its result is inserted under the given key.
makeCached
  :: (MonadState s m, Hashable k, Eq k)
  => k
  -- ^ The key the action is associated with
  -> Lens' s (HashMap k v)
  -- ^ The Lens to the HashMap that is the cache
  -> m v
  -- ^ The action to cache
  -> m v
makeCached key l create = do
  cache <- use l
  case HashMapL.lookup key cache of
    Just value -> return value
    Nothing -> do
      value <- create
      l %= HashMapL.insert key value
      return value

-- | Cache the result of a monadic action using a 'UniqMap'
makeCachedU
  :: (MonadState s m, Uniquable k)
  => k
  -- ^ Key the action is associated with
  -> Lens' s (UniqMap v)
  -- ^ Lens to the cache
  -> m v
  -- ^ Action to cache
  -> m v
makeCachedU key l create = do
  cache <- use l
  case lookupUniqMap key cache of
    Just value -> return value
    Nothing -> do
      value <- create
      l %= extendUniqMap key value
      return value

-- | Run a State-action using the State that is stored in a higher-layer Monad
liftState
  :: (MonadState s m)
  => Lens' s s'
  -- ^ Lens to the State in the higher-layer monad
  -> State s' a
  -- ^ The State-action to perform
  -> m a
liftState l m = do
  s <- use l
  let (a,s') = runState m s
  l .= s'
  return a

-- | Functorial version of 'Control.Arrow.first'
firstM
  :: Functor f
  => (a -> f c)
  -> (a, b)
  -> f (c, b)
firstM f (x,y) = (,y) <$> f x

-- | Functorial version of 'Control.Arrow.second'
secondM
  :: Functor f
  => (b -> f c)
  -> (a, b)
  -> f (a, c)
secondM f (x,y) = (x,) <$> f y

-- | Apply one effectful function to each component of a pair, running the
-- effect for the first component before the effect for the second.
combineM
  :: (Applicative f)
  => (a -> f b)
  -> (c -> f d)
  -> (a,c)
  -> f (b,d)
combineM f g (x,y) = (,) <$> f x <*> g y

-- | Performs trace when first argument evaluates to 'True'
traceIf :: Bool -> String -> a -> a
traceIf True  msg = trace msg
traceIf False _   = id
{-# INLINE traceIf #-}

-- | Monadic version of 'Data.List.partition'
partitionM
  :: Monad m
  => (a -> m Bool)
  -> [a]
  -> m ([a], [a])
partitionM _ []     = return ([], [])
partitionM p (x:xs) = do
  test      <- p x
  (ys, ys') <- partitionM p xs
  return $ if test then (x:ys, ys') else (ys, x:ys')

-- | Monadic version of 'Data.List.mapAccumL'
mapAccumLM
  :: (Monad m)
  => (acc -> x -> m (acc,y))
  -> acc
  -> [x]
  -> m (acc,[y])
mapAccumLM _ acc [] = return (acc,[])
mapAccumLM f acc (x:xs) = do
  (acc',y) <- f acc x
  (acc'',ys) <- mapAccumLM f acc' xs
  return (acc'',y:ys)

-- | if-then-else as a function on an argument
ifThenElse
  :: (a -> Bool)
  -> (a -> b)
  -> (a -> b)
  -> a
  -> b
ifThenElse t f g a = if t a then f a else g a

infixr 5 <:>
-- | Applicative version of 'GHC.Types.(:)'
(<:>)
  :: Applicative f
  => f a
  -> f [a]
  -> f [a]
x <:> xs = (:) <$> x <*> xs

-- | Safe indexing, returns a 'Nothing' if the index does not exist
--
-- A negative index never exists, so it yields 'Nothing' immediately instead
-- of walking the whole list (which would not terminate on an infinite list).
-- The index is not inspected when the list is empty.
indexMaybe
  :: [a]
  -> Int
  -> Maybe a
indexMaybe [] _ = Nothing
indexMaybe (x:xs) n
  | n < 0     = Nothing
  | n == 0    = Just x
  | otherwise = indexMaybe xs (n-1)

-- | Unsafe indexing, return a custom error message when indexing fails
indexNote
  :: String
  -- ^ Error message used when the index does not exist
  -> [a]
  -> Int
  -> a
indexNote note = \xs i -> fromMaybe (error note) (indexMaybe xs i)

-- | Safe version of 'head'
headMaybe :: [a] -> Maybe a
headMaybe (a:_) = Just a
headMaybe _     = Nothing

-- | Safe version of 'tail'
tailMaybe :: [a] -> Maybe [a]
tailMaybe (_:as) = Just as
tailMaybe _      = Nothing

-- | Split the second list at the length of the first list
--
-- When the second list is shorter than the first, the whole second list is
-- returned as the first component and the second component is empty.
splitAtList :: [b] -> [a] -> ([a], [a])
splitAtList [] xs         = ([], xs)
splitAtList _ xs@[]       = (xs, xs)
splitAtList (_:xs) (y:ys) = (y:ys', ys'')
  where
    (ys', ys'') = splitAtList xs ys

-- | Version of the library: taken from the Cabal-generated @Paths_clash_lib@
-- module when @CABAL@ is defined, and an 'error' otherwise.
clashLibVersion :: Version
#ifdef CABAL
clashLibVersion = Paths_clash_lib.version
#else
clashLibVersion = error "development version"
#endif

-- | Return number of occurrences of an item in a list
countEq
  :: Eq a
  => a
  -- ^ Needle
  -> [a]
  -- ^ Haystack
  -> Int
  -- ^ Times needle was found in haystack
countEq a as = length (filter (== a) as)

-- | \x y -> floor (logBase x y), x > 1 && y > 0
--
-- Returns 'Nothing' when the precondition does not hold.
flogBase :: Integer -> Integer -> Maybe Int
flogBase x y | x > 1 && y > 0 = Just (I# (integerLogBase# x y))
flogBase _ _ = Nothing

-- | \x y -> ceiling (logBase x y), x > 1 && y > 0
--
-- Returns 'Nothing' when the precondition does not hold.
clogBase :: Integer -> Integer -> Maybe Int
clogBase x y | x > 1 && y > 0 =
  case y of
    1 -> Just 0
    _ -> let z1 = integerLogBase# x y
             z2 = integerLogBase# x (y-1)
         in  -- Equal floors for y and y-1 mean y is not an exact power of x,
             -- so the ceiling lies one above the floor.
             if isTrue# (z1 ==# z2)
                then Just (I# (z1 +# 1#))
                else Just (I# z1)
clogBase _ _ = Nothing

-- | Determine whether two lists are of equal length
equalLength
  :: [a] -> [b] -> Bool
equalLength []     []     = True
equalLength (_:as) (_:bs) = equalLength as bs
equalLength _      _      = False

-- | Determine whether two lists are not of equal length
neLength
  :: [a] -> [b] -> Bool
neLength []     []     = False
neLength (_:as) (_:bs) = neLength as bs
neLength _      _      = True

-- | Zip two lists of equal length
--
-- NB Errors out for a DEBUG compiler when the two lists are not of equal length
zipEqual
  :: [a] -> [b] -> [(a,b)]
#if !defined(DEBUG)
zipEqual = zip
#else
zipEqual []     []     = []
zipEqual (a:as) (b:bs) = (a,b) : zipEqual as bs
zipEqual _      _      = error "zipEqual"
#endif

-- | Is this a DEBUG compiler?
debugIsOn
  :: Bool
#if defined(DEBUG)
debugIsOn = True
#else
debugIsOn = False
#endif

-- | Short-circuiting monadic version of 'any'
anyM
  :: (Monad m)
  => (a -> m Bool)
  -> [a]
  -> m Bool
anyM _ []     = return False
anyM p (x:xs) = do
  q <- p x
  if q then
    return True
  else
    anyM p xs

-- | Get the package id of the type of a value
--
-- >>> pkgIdFromTypeable (undefined :: TopEntity)
-- "clash-prelude-0.99.3-64904d90747cb49e17166bbc86fec8678918e4ead3847193a395b258e680373c"
pkgIdFromTypeable :: Typeable a => a -> String
pkgIdFromTypeable = tyConPackage . typeRepTyCon . typeOf

-- | Format the difference between two points in time.
--
-- The difference is computed as the /first/ argument minus the /second/
-- argument, so the later time has to be passed first to obtain a positive
-- duration. Hours and minutes are only included in the output once the
-- difference reaches 3600 and 60 seconds respectively.
--
-- NOTE(review): the parameter names suggest @start@ precedes @end@, which
-- would make the difference negative; confirm the argument order used by
-- callers before renaming or swapping anything here.
reportTimeDiff :: UTCTime -> UTCTime -> String
reportTimeDiff start end =
  Clock.formatTime Clock.defaultTimeLocale fmt
    (Clock.UTCTime (toEnum 0) (fromRational (toRational diff)))
  where
    -- 'Clock.diffUTCTime' a b is a - b
    diff = Clock.diffUTCTime start end
    fmt  | diff >= 3600
         = "%-Hh%-Mm%-S%03Qs"
         | diff >= 60
         = "%-Mm%-S%03Qs"
         | otherwise
         = "%-S%03Qs"

-- | Converts a curried function to a function on a triple
uncurry3 :: (a -> b -> c -> d) -> (a,b,c) -> d
uncurry3 = \f (a,b,c) -> f a b c
{-# INLINE uncurry3 #-}