module Data.Nimber (
Nimber(fromNimber),
toNimber, nimRecip
)
where
import Data.Bits
import Data.List
import Data.Maybe
import Data.Ratio
import Control.Monad
import qualified Data.MemoCombinators as Memo
import qualified Data.Set as S
-- | A nimber @*n@, stored as its non-negative 'Integer' index @n@.
--
-- Only the type and the 'fromNimber' accessor are exported, so outside this
-- module values are built through 'toNimber', which rejects negative input.
newtype Nimber = Nimber {fromNimber :: Integer}
  deriving (Eq, Ord)
-- | Memoise a function of one 'Nimber' argument.
--
-- The memo table is an "Data.MemoCombinators" integral table keyed on the
-- underlying 'Integer'; 'fromNimber' / 'toNimber' convert between the two.
memoNimber :: (Nimber -> r) -> Nimber -> r
memoNimber f = Memo.wrap toNimber fromNimber Memo.integral f
-- | Smart constructor: the nimber @*n@ for a non-negative 'Integer' @n@.
-- Calls 'error' when @n@ is negative.
toNimber :: Integer -> Nimber
toNimber n
  | n >= 0 = Nimber n
  | otherwise = error "negative nimbers not defined"
-- | Nimbers are rendered as a star followed by their index, e.g. @*5@.
-- Precedence is ignored, matching a plain 'show'-based definition.
instance Show Nimber where
  showsPrec _ (Nimber n) = showChar '*' . shows n
-- | Enumeration follows the underlying non-negative 'Integer' index.
--
-- The range methods are generated over 'Integer' directly instead of the
-- default definitions, which round-trip through 'fromEnum' / 'toEnum' and
-- are therefore confined to the 'Int' range.
instance Enum Nimber where
  -- Goes through 'toNimber', so @pred *0@ is an error rather than a
  -- negative (invalid) nimber.
  pred (Nimber x) = toNimber (x - 1)
  succ (Nimber x) = Nimber (x + 1)
  -- Negative 'Int's are rejected by 'toNimber'.
  toEnum = toNimber . toInteger
  -- Delegates to 'fromEnum' on the underlying 'Integer'.
  fromEnum = fromEnum . fromNimber
  -- Ascending from a non-negative start: every element is non-negative.
  enumFrom (Nimber x) = map Nimber [x ..]
  enumFromTo (Nimber x) (Nimber y) = map Nimber [x .. y]
  -- An open descending range would eventually drop below zero; 'toNimber'
  -- reports that as an error when the offending element is demanded.
  enumFromThen (Nimber x) (Nimber y) = map toNimber [x, y ..]
  -- Bounded by a non-negative limit (descending) or start (ascending), so
  -- every element is non-negative.
  enumFromThenTo (Nimber x) (Nimber y) (Nimber z) = map Nimber [x, y .. z]
-- | Nim-arithmetic.  Addition is bitwise XOR; multiplication is
-- nim-multiplication, computed by distributing over the set bits of both
-- operands and multiplying the resulting powers of two.
instance Num Nimber where
  abs = id
  -- Every nimber is its own additive inverse (x + x = x `xor` x = 0).
  negate = id
  Nimber x + Nimber y = toNimber (x `xor` y)
  signum 0 = 0
  signum _ = 1
  fromInteger = toNimber
  a * b =
    sum
      [ pow2mult (bitProduct (toBits i) (toBits j))
      | i <- setBitPositions (fromNimber a)
      , j <- setBitPositions (fromNimber b)
      ]
    where
      -- Binary digits of a non-negative integer, least significant first;
      -- lsbBits 0 == [].
      lsbBits :: Integer -> [Integer]
      lsbBits =
        unfoldr (\n -> if n == 0 then Nothing else Just (n `rem` 2, n `div` 2))

      -- Binary digits of a non-negative integer, most significant first.
      toBits :: Integer -> [Integer]
      toBits = reverse . lsbBits

      -- Positions of the 1-bits of a non-negative integer (0 = least
      -- significant), so that n is the sum of 2^i over these positions.
      setBitPositions :: Integer -> [Integer]
      setBitPositions n = [i | (i, 1) <- zip [0 ..] (lsbBits n)]

      -- Pointwise sum of two MSB-first digit lists, right-aligned by padding
      -- the shorter one with leading zeros.  Applied to the exponents i and j,
      -- it gives, per Fermat 2-power F_k = 2^(2^k), its multiplicity in the
      -- nim-product 2^i * 2^j.
      bitProduct :: [Integer] -> [Integer] -> [Integer]
      bitProduct xs ys = zipWith (+) (pad xs) (pad ys)
        where
          width = max (length xs) (length ys)
          pad zs = replicate (width - length zs) 0 ++ zs

      -- Nim-product of Fermat 2-powers.  In a list of length L+1 the head is
      -- the multiplicity of F_L = 2^(2^L); the tail describes F_(L-1) .. F_0.
      pow2mult :: [Integer] -> Nimber
      pow2mult [] = 1
      pow2mult (0 : xs) = pow2mult xs
      -- pow2mult xs is below F_L, so the nim-product with F_L is the
      -- ordinary integer product.
      pow2mult (1 : xs) =
        toNimber (2 ^ (2 ^ length xs :: Integer) * fromNimber (pow2mult xs))
      -- Multiplicity >= 2: rewrite one F_L * F_L as F_L + F_L/2, where
      -- F_L/2 = 2^(2^L - 1) contributes one extra factor of every smaller
      -- Fermat 2-power (hence map (+ 1) over the tail).
      pow2mult (x : xs) =
        pow2mult (x - 1 : xs) + pow2mult (x - 2 : map (+ 1) xs)
-- | Nim-division.  'recip' finds the inverse by exhaustive search over
-- @*1, *2, ...@, memoised per argument via 'memoNimber'; division uses the
-- class default @x / y = x * recip y@.
instance Fractional Nimber where
  recip = memoNimber recip'
    where
      -- *0 has no inverse; without this clause the search below never ends.
      recip' 0 = error "recip: *0 has no reciprocal"
      recip' a = fromJust $ find (\n -> n * a == 1) [1 ..]
  -- A negative numerator is rejected by 'toNimber'.
  fromRational r = toNimber (numerator r) / toNimber (denominator r)
-- | Reciprocal of a non-zero nimber by the mex characterisation:
-- @1/a@ is the least nimber not contained in the smallest set @S@ that holds
-- @0@ and is closed under @b ↦ (1 + (a' + a) * b) * (1/a')@ for all
-- @0 < a' < a@.  Recursive calls only ever use arguments smaller than @a@,
-- and results are memoised via 'memoNimber'.  Calls 'error' on @*0@.
nimRecip :: Nimber -> Nimber
nimRecip = memoNimber nimRecip'
  where
    -- Without this clause the range [1 .. pred 0] is empty and the
    -- computation would wrongly return *1.
    nimRecip' 0 = error "nimRecip: *0 has no reciprocal"
    nimRecip' a = mex (closure (S.singleton 0))
      where
        -- Apply 'enlarge' until one more round adds nothing.
        closure s =
          let s' = enlarge s
           in if s' == s then s else closure s'
        -- Least nimber not in the set.
        mex s = fromJust $ find (`S.notMember` s) [0 ..]
        enlarge s =
          s `S.union` S.fromList [step a' b | a' <- [1 .. pred a], b <- S.toList s]
        step a' b = (1 + (a' + a) * b) * nimRecip a'