module Data.Count.Counter where
import Control.Applicative ((<$>), (<*>))
import Data.Tuple (swap)
-- | A numbering of the values of a type: a bijection between those values
-- and the positions @0, 1, 2, ...@ (all naturals when the type is
-- infinite, otherwise the first 'cCount' of them).
--
-- Nothing in the constructor checks that 'cToPos' and 'cFromPos' are
-- mutually inverse or agree with 'cCount' (presumably the reason for the
-- @Unsafe@ in its name).
--
-- The fields should stay lazy: 'listCounter' defines a counter in terms of
-- itself, and its count has to be readable before the recursive
-- conversion functions are evaluated.
data Counter a = UnsafeMkCounter
  { cCount :: Maybe Integer
    -- ^ Number of values, or 'Nothing' when there are infinitely many.
  , cToPos :: a -> Integer
    -- ^ Position assigned to a value.
  , cFromPos :: Integer -> a
    -- ^ Value found at a position.
  }
-- | Counter for the unit type: exactly one value, @()@, at position 0.
--
-- Asking for any other position is an error (reported with the offending
-- position rather than as an anonymous pattern-match failure).
unitCounter :: Counter ()
unitCounter =
  UnsafeMkCounter
    { cCount = Just 1
    , cToPos = \() -> 0
    , cFromPos = \n ->
        if n == 0
          then ()
          else error ("unitCounter: position out of range: " ++ show n)
    }
-- | Counter with no values at all, usable at any type.
--
-- There are no valid values or positions, so both conversion functions
-- fail with a descriptive error if they are ever called.
voidCounter :: Counter a
voidCounter =
  UnsafeMkCounter
    { cCount = Just 0
    , cToPos = \_ -> error "voidCounter: cToPos called, but the counter has no values"
    , cFromPos = \_ -> error "voidCounter: cFromPos called, but the counter has no positions"
    }
-- | Counter for the natural numbers, where every number is its own
-- position and the count is unbounded.
--
-- Negative integers are outside the intended domain; 'cToPos' simply
-- passes them through unchanged.
natCounter :: Counter Integer
natCounter =
  UnsafeMkCounter
    { cFromPos = id
    , cToPos = id
    , cCount = Nothing
    }
-- | Skip the first @skip@ positions of a counter and renumber the remaining
-- values from 0.
--
-- A finite count shrinks by @skip@ but never drops below 0. Values that
-- sat in the skipped prefix get negative positions from 'cToPos'.
dropCounter :: Integer -> Counter a -> Counter a
dropCounter skip aC =
  UnsafeMkCounter
    { cCount = fmap (\total -> max 0 (total - skip)) (cCount aC)
    , cToPos = \v -> cToPos aC v - skip
    , cFromPos = \pos -> cFromPos aC (pos + skip)
    }
-- | Counter for the disjoint union of two countable types.
--
-- Positions are assigned as follows:
--
-- * both sides infinite: 'Left' values take the even positions and
--   'Right' values the odd ones;
-- * left side finite with @k@ values: 'Left' values occupy positions
--   @0 .. k-1@ and 'Right' values follow from position @k@;
-- * only the right side finite: the arguments are swapped and the result
--   mirrored back, so the finite side still comes first.
sumCounter :: Counter a -> Counter b -> Counter (Either a b)
sumCounter aC bC =
  UnsafeMkCounter
    { cCount = (+) <$> cCount aC <*> cCount bC
    , cToPos = case (cCount aC, cCount bC) of
        (Nothing, Nothing) -> \ab -> case ab of
          Left a -> 2 * cToPos aC a
          Right b -> 2 * cToPos bC b + 1
        (Just acount, _) -> \ab -> case ab of
          Left a -> cToPos aC a
          Right b -> acount + cToPos bC b
        (Nothing, Just _) -> cToPos (sumCounter bC aC) . invert
    , cFromPos = case (cCount aC, cCount bC) of
        (Nothing, Nothing) -> \n ->
          -- Even positions belong to the left side, odd ones to the right.
          let (n', parity) = n `divMod` 2
          in if parity == 0
               then Left (cFromPos aC n')
               else Right (cFromPos bC n')
        (Just acount, _) -> \n ->
          if n < acount
            then Left (cFromPos aC n)
            else Right (cFromPos bC (n - acount))
        (Nothing, Just _) -> invert . cFromPos (sumCounter bC aC)
    }
  where
    -- Swap the two sides of an 'Either'.
    invert :: Either x y -> Either y x
    invert = either Right Left
-- | Counter for the Cartesian product of two countable types.
--
-- The product is empty if either side is empty. Otherwise, writing
-- @(i, j)@ for the component positions of a pair:
--
-- * both sides infinite: pairs are walked diagonal by diagonal, so
--   @(i, j)@ lands at @tri (i + j) + i@;
-- * right side finite with @k@ values: row-major order, @(i, j)@ lands at
--   @i * k + j@;
-- * only the left side finite: the arguments are swapped and the pair
--   mirrored, reusing the row-major case.
prodCounter :: Counter a -> Counter b -> Counter (a, b)
prodCounter aC bC =
  UnsafeMkCounter
    { cCount =
        if Just 0 `elem` [cCount aC, cCount bC]
          then Just 0
          else (*) <$> cCount aC <*> cCount bC
    , cToPos = case (cCount aC, cCount bC) of
        (Nothing, Nothing) -> posf $ \(an, bn) -> tri (an + bn) + an
        (_, Just bcount) -> posf $ \(an, bn) -> an * bcount + bn
        (Just _, Nothing) -> cToPos (prodCounter bC aC) . swap
    , cFromPos = case (cCount aC, cCount bC) of
        (Nothing, Nothing) -> \n ->
          -- diag = i + j picks the diagonal, an = i is the offset along it.
          let (diag, an) = rtri n
          in (cFromPos aC an, cFromPos bC (diag - an))
        (_, Just bcount) -> \n ->
          let (an, bn) = n `divMod` bcount
          in (cFromPos aC an, cFromPos bC bn)
        (Just _, Nothing) -> swap . cFromPos (prodCounter bC aC)
    }
  where
    -- Apply f to the component positions of a pair of values.
    posf f (a, b) = f (cToPos aC a, cToPos bC b)

    -- The n-th triangular number, 0 + 1 + ... + n.
    tri :: Integer -> Integer
    tri n = n * (n + 1) `div` 2

    -- Inverse of the diagonal walk: for position n returns (d, n - tri d),
    -- where d is the largest value with tri d <= n.
    rtri :: Integer -> (Integer, Integer)
    rtri n = (r, n - tri r)
      where
        -- tri r <= n  <=>  (2r + 1)^2 <= 1 + 8n
        r = (squareRoot (1 + 8 * n) - 1) `div` 2

        sq :: Integer -> Integer
        sq x = x * x

        -- Integer square root: the largest s with s * s <= m.
        -- Newton's method, seeded with a recursive estimate scaled from the
        -- largest of 2, 4, 16, 256, ... (each the square of the previous)
        -- that does not exceed m. Negative input can only come from a
        -- negative (invalid) position and is rejected explicitly.
        squareRoot :: Integer -> Integer
        squareRoot m
          | m < 0 = error "prodCounter: square root of a negative number (negative position?)"
          | m < 2 = m
          | otherwise =
              let (lowerRoot, lowerM) =
                    until (\(_, k) -> sq k > m) (\(_, k) -> (k, sq k)) (1, 2)
                  newtonStep x = (x + m `div` x) `div` 2
                  isRoot x = sq x <= m && m < sq (x + 1)
              in until isRoot newtonStep (squareRoot (m `div` lowerM) * lowerRoot)
-- | Counter for a bounded enumeration. 'minBound' sits at position 0, and
-- every value sits at its 'fromEnum' offset from 'minBound', up to
-- 'maxBound'.
boundedEnumCounter :: (Bounded a, Enum a) => Counter a
boundedEnumCounter = counter
  where
    -- 'asTypeOf' pins the bounds to the element type of 'counter', because
    -- the signature's type variable is not in scope in this where clause.
    lo = enumIndex (minBound `asTypeOf` cFromPos counter 0)
    hi = enumIndex (maxBound `asTypeOf` cFromPos counter 0)

    enumIndex :: Enum e => e -> Integer
    enumIndex = toInteger . fromEnum

    counter =
      UnsafeMkCounter
        { cCount = Just (hi - lo + 1)
        , cToPos = \v -> enumIndex v - lo
        , cFromPos = \n -> toEnum (fromInteger (lo + n))
        }
-- | Transfer a counter from @a@ to @b@ through a pair of conversion
-- functions, which are expected (but not checked) to be mutually inverse.
-- The count is inherited unchanged.
--
-- The argument counter is only accessed through its field selectors, so
-- the result can be built without evaluating it.
isoCounter :: Counter a -> (b -> a) -> (a -> b) -> Counter b
isoCounter aC b2a a2b =
  UnsafeMkCounter
    { cFromPos = \pos -> a2b (cFromPos aC pos)
    , cToPos = \val -> cToPos aC (b2a val)
    , cCount = cCount aC
    }
-- | Counter for optional values: the sum of the element counter and
-- 'unitCounter', with 'Just' values on the left and 'Nothing' on the right.
maybeCounter :: Counter a -> Counter (Maybe a)
maybeCounter aC = isoCounter (sumCounter aC unitCounter) toSum fromSum
  where
    toSum = maybe (Right ()) Left
    fromSum = either Just (\() -> Nothing)
-- | Counter for finite lists over a countable element type.
--
-- A list is viewed as @Either (a, [a]) ()@ (a cons cell or nil), so the
-- counter is defined in terms of itself through 'sumCounter' and
-- 'prodCounter'. Its count is computed separately, so it is available
-- without unfolding that recursion. An empty element type admits exactly
-- one list (the empty one); any other element type admits infinitely many.
listCounter :: Counter a -> Counter [a]
listCounter aC = self
  where
    self =
      UnsafeMkCounter
        { cCount = total
        , cToPos = cToPos viaSum . toSum
        , cFromPos = fromSum . cFromPos viaSum
        }
    total = case cCount aC of
      Just 0 -> Just 1
      _ -> Nothing
    viaSum = sumCounter (prodCounter aC self) unitCounter
    toSum xs = case xs of
      [] -> Right ()
      (x : rest) -> Left (x, rest)
    fromSum s = case s of
      Left (x, rest) -> x : rest
      Right () -> []
-- | Counter for all integers, alternating signs by magnitude:
-- 0, 1, -1, 2, -2, ... occupy positions 0, 1, 2, 3, 4, ...
--
-- Positive @i@ sits at the odd position @2i - 1@; zero and negative @i@
-- sit at the even position @2|i|@.
integerCounter :: Counter Integer
integerCounter =
  UnsafeMkCounter
    { cCount = Nothing
    , cToPos = \i -> if i > 0 then 2 * i - 1 else 2 * abs i
    , cFromPos = \n ->
        let (q, r) = (n + 1) `divMod` 2
        in if r == 0 then q else negate q
    }
-- | Every value of a counter's type, in position order.
--
-- The result is finite when the counter has a known count (positions
-- @0 .. count-1@) and infinite otherwise.
allValuesFor :: Counter a -> [a]
allValuesFor aC = map (cFromPos aC) positions
  where
    positions = maybe [0 ..] (\n -> [0 .. n - 1]) (cCount aC)