{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE UndecidableInstances #-}

module Random (Gen (..), Split (..), Uniform (..),
               uniform,  range,  weighted,
               uniformM, rangeM, weightedM) where

import Control.Applicative
import Control.Monad.Primitive
import Control.Monad.Trans.Reader
import qualified Control.Monad.Trans.State as M
import Data.Bool
import Data.Foldable
import qualified Data.List as L
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NE
import Data.Primitive.MutVar
import Data.Ratio
import Data.Semigroup
import Data.Tuple (swap)
import Numeric.Natural

import Util

-- | A pseudo-random generator @g@ whose raw output has type @'Native' g@.
--
-- Instances must define 'uniformNative'; every other method has a default
-- written in terms of it. The defaults for the monadic methods additionally
-- require the mutable form of the generator ('Mut') to be a plain 'MutVar',
-- which is also the default choice of 'Mut'.
class Gen g where
    -- | Mutable cell holding a generator of type @g@ in state thread @s@.
    -- Declared injective, so @s@ and @g@ can be recovered from it.
    type Mut s g = m | m -> s g
    type instance Mut s g = MutVar s g
    -- | Type of one raw output of the generator.
    type Native g
    -- | Produce one raw output, advancing the generator state.
    uniformNative :: M.State g (Native g)
    -- | Produce one raw output from a mutable generator, updating it in place.
    uniformNativeM :: PrimMonad m => ReaderT (Mut (PrimState m) g) m (Native g)
    -- | Advance the generator past @n@ raw outputs.
    skip :: Natural -> g -> g
    -- | Advance a mutable generator past @n@ raw outputs.
    skipM :: PrimMonad m => Natural -> ReaderT (Mut (PrimState m) g) m ()

    default uniformNativeM :: (Mut (PrimState m) g ~ MutVar (PrimState m) g,
                               PrimMonad m) => ReaderT (Mut (PrimState m) g) m (Native g)
    -- Use the strict atomicModifyMutVar': the lazy variant stores an
    -- unevaluated successor state, so callers that discard the output (such as
    -- the default skipM) would pile up a chain of thunks inside the MutVar.
    uniformNativeM = ReaderT $ flip atomicModifyMutVar' $ swap . M.runState uniformNative

    -- Step the state one draw at a time, forcing each intermediate state to
    -- weak head normal form, so a large n does not build an n-deep nest of
    -- suspended execState applications.
    skip n g0 = L.foldl' (\ s () -> M.execState uniformNative s) g0 (L.genericReplicate n ())

    default skipM :: (Mut (PrimState m) g ~ MutVar (PrimState m) g,
                      PrimMonad m) => Natural -> ReaderT (Mut (PrimState m) g) m ()
    -- Discard the output of uniformNativeM; the repetition count is handled by
    -- mtimesA (from Util) — presumably it runs the action n times.
    skipM = flip mtimesA (() <$ uniformNativeM)

-- | Generators that can be divided into two generators.
class Split g where
    -- | Turn one generator into a pair of generators.
    split :: g -> (g, g)

-- | Types whose values can be drawn at random, given a monadic source of
-- uniformly distributed values of some 'Bounded' 'Enum' type.
class Uniform a where
    -- | Turn an action yielding @b@ values into one yielding an @a@.
    liftUniform :: (Bounded b, Enum b, Monad m) => m b -> m a

-- | Draw a value of any 'Uniform' type as a pure state transition on the
-- generator, using 'uniformNative' as the source of randomness.
uniform :: (Gen g, Bounded (Native g), Enum (Native g), Uniform a) => M.State g a
uniform = liftUniform draw
  where draw = uniformNative

-- | Draw a value of any 'Uniform' type from a mutable generator, using
-- 'uniformNativeM' as the source of randomness.
uniformM :: (Gen g, Bounded (Native g), Enum (Native g), Uniform a, PrimMonad m)
         => ReaderT (Mut (PrimState m) g) m a
uniformM = liftUniform draw
  where draw = uniformNativeM

-- | Fallback instance: any 'Bounded' 'Enum' type is drawn from its whole range
-- @[minBound .. maxBound]@. Marked overlappable so more specific instances win.
instance {-# OVERLAPPABLE #-} (Bounded a, Enum a) => Uniform a where
    liftUniform source = range' (minBound, maxBound) source

-- | Draw an element of the inclusive range @[lo .. hi]@ as a pure state
-- transition on the generator.
range :: (Gen g, Bounded (Native g), Enum (Native g), Enum a) => (a, a) -> M.State g a
range bounds = range' bounds uniformNative

-- | Draw an element of the inclusive range @[lo .. hi]@ from a mutable
-- generator.
rangeM :: (Gen g, Bounded (Native g), Enum (Native g), Enum a, PrimMonad m)
       => (a, a) -> ReaderT (Mut (PrimState m) g) m a
rangeM bounds = range' bounds uniformNativeM

-- | Draw an element of the inclusive enumeration @[lo .. hi]@ uniformly at
-- random, given an action that yields uniformly distributed values of the
-- 'Bounded' 'Enum' type @b@.
--
-- Each attempt performs @r@ draws, the fewest with @card_b ^ r >= card_a@,
-- and reads them as the digits of a base-@card_b@ number @n@ in
-- @[0, card_b ^ r)@. Values of @n@ at or above the largest multiple of
-- @card_a@ in that interval are rejected: the attempt yields 'Nothing', and
-- 'untilJust' is expected to repeat it. Otherwise @n `mod` card_a@ indexes
-- the range, so every element is equally likely.
--
-- Calls 'error' if the range is empty. It also calls 'error' if @b@ has fewer
-- than two values while the range has more than one, because no number of
-- draws could then cover it.
range' :: forall a b m . (Enum a, Bounded b, Enum b, Monad m) => (a, a) -> m b -> m a
range' (lo, hi)
  | card_a == 0 = error "Random.range': empty range"
  | card_b < 2, card_a > 1 = error "Random.range': source type has fewer than two values"
  | otherwise = untilJust . fmap (pick . combine) . replicateA @_ @[] r
  where
    candidates = [lo .. hi]

    card_a, card_b, r, limit :: Natural
    card_a = L.genericLength candidates
    card_b = card @b
    -- Fewest draws whose combined range [0, card_b ^ r) covers every
    -- candidate; grows logarithmically in the size of the range.
    r = L.genericLength . L.takeWhile (< card_a) $ L.iterate (* card_b) 1
    -- Largest multiple of card_a not exceeding card_b ^ r. Accepting only
    -- n < limit keeps n `mod` card_a free of modulo bias.
    limit = card_b ^ r `div` card_a * card_a

    -- The first draw is the least significant base-card_b digit.
    combine :: [b] -> Natural
    combine = foldr (\ m n -> card_b * n + fromEnum' m) 0

    pick :: Natural -> Maybe a
    pick n | n >= limit = Nothing
           | otherwise = candidates !!? (n `mod` card_a)

{-# INLINE[1] range' #-}
-- "range'/()" skips drawing entirely, since a range over () has one value.
-- NOTE(review): "range'" rewrites range' to return the raw source unchanged
-- whenever a ~ b, regardless of the bounds. That is correct for the full
-- (minBound, maxBound) range, as in the fallback Uniform instance, where it
-- avoids enumerating the whole native type. It would also fire on a narrower
-- range at the native type, e.g. range (1, 6) when a is Native g. Confirm no
-- caller does that, or move the shortcut to a dedicated full-range function.
{-# RULES "range'/()" range' = (pure . pure . pure) () #-}
{-# RULES "range'"    range' = pure id #-}

-- | Pick one of the @choices@ according to its weight, as a pure state
-- transition on the generator. The weights are presumably meant to sum to 1.
-- If they sum to less, the last choice absorbs the remainder.
weighted :: (Gen g, Bounded (Native g), Enum (Native g))
         => NonEmpty (a, Ratio Natural) -> M.State g a
weighted choices = weighted' range choices

-- | Pick one of the @choices@ according to its weight, drawing from a mutable
-- generator. The weights are presumably meant to sum to 1. If they sum to
-- less, the last choice absorbs the remainder.
weightedM :: (Gen g, Bounded (Native g), Enum (Native g), PrimMonad m)
          => NonEmpty (a, Ratio Natural) -> ReaderT (Mut (PrimState m) g) m a
weightedM choices = weighted' rangeM choices

-- | Shared implementation of 'weighted' and 'weightedM'.
--
-- @draw@ yields a 'Natural' from an inclusive range. One value @k@ is drawn
-- from @[0, denom - 1]@, where @denom@ is a common multiple of the weights'
-- denominators, and scaled to the fraction @k / denom@ in @[0, 1)@. The
-- choices are then walked in order, subtracting each weight from the fraction.
-- The first choice whose weight exceeds what remains is returned; the last
-- choice is returned unconditionally (its own weight is not consulted).
weighted' :: Functor f
          => ((Natural, Natural) -> f Natural) -> NonEmpty (a, Ratio Natural) -> f a
weighted' draw choices = (\ k -> walk (k % denom) choices) <$> draw (0, denom - 1)
  where
    denom = lcms (denominator . snd <$> choices)

    walk frac ys = case NE.uncons ys of
        ((y, _), Nothing)   -> y
        -- frac - w is only evaluated when frac >= w, so it cannot underflow.
        ((y, w), Just rest) -> bool (walk (frac - w) rest) y (frac < w)

-- | Least common multiple of a non-empty collection of naturals.
--
-- Folds 'lcm' pairwise, starting from 1. Dividing the product by the overall
-- 'gcd' is the lcm only for exactly two numbers. For [2, 4, 8] it gives 32
-- rather than 8, and for a single d it gives 1 rather than d.
lcms :: NonEmpty Natural -> Natural
lcms = foldl' lcm 1