module Boltzmann.Data.Oracle where
import Control.Applicative
import Control.Monad
import Control.Monad.Fix
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.Data
import Data.Hashable ( Hashable )
import Data.HashMap.Lazy ( HashMap )
import qualified Data.HashMap.Lazy as HashMap
import Data.Maybe ( fromJust, isJust )
import Data.Monoid
import qualified Data.Vector as V
import GHC.Generics ( Generic )
import Numeric.AD
import Boltzmann.Data.Common
import Boltzmann.Data.Types
import Boltzmann.Solver
-- | Information collected about all the types reachable from a root type,
-- used to build oracles and generators.
data DataDef m = DataDef
  { count :: Int
    -- ^ Number of (non-aliased) types indexed so far; type indices range
    -- over @0 .. count - 1@.
  , points :: Int
    -- ^ Number of times 'point' has been applied.
  , index :: HashMap TypeRep (Either Aliased Ix)
    -- ^ How every type met is indexed: @Right@ a plain type index, or
    -- @Left@ the key of the alias producing that type.
  , xedni :: HashMap Ix SomeData'
    -- ^ Inverse of 'index' for plain types.
  , xedni' :: HashMap Aliased (Ix, Alias m)
    -- ^ Every alias, with the type index it has been resolved to.
  , types :: HashMap C [(Integer, Constr, [C'])]
    -- ^ Constructors of each type at each pointing level, each with an
    -- integer weight and references to the constructor's fields.
  , lTerm :: HashMap Ix (Nat, Integer)
    -- ^ Leading term @(order, coefficient)@ of each type, combined with
    -- 'lSum' and 'lMul'.
  , degree :: HashMap Ix Int
    -- ^ Degree of each type, only present when it is known (Just) at the
    -- end of its traversal.
  } deriving Show
-- | @C i k@: the type with index @i@ at pointing level @k@.
data C = C Ix Int
  deriving (Eq, Ord, Show, Generic)

instance Hashable C

-- | @AC j k@: the alias with key @j@ at pointing level @k@.
data AC = AC Aliased Int
  deriving (Eq, Ord, Show, Generic)

instance Hashable AC

-- | Reference to a constructor field: the alias through which its type is
-- reached (if any), and the type itself.
type C' = (Maybe Aliased, C)

-- | Key of an alias: its position in the list of aliases given to
-- 'dataDef'.
newtype Aliased = Aliased Int
  deriving (Eq, Ord, Show, Generic)

instance Hashable Aliased

-- | Index of a (non-aliased) type.
type Ix = Int
-- | Lazy Peano naturals. Because they are lazy they can also represent
-- 'infinity'.
data Nat = Zero | Succ Nat
  deriving (Eq, Ord, Show)

-- | Addition. It is lazy in its second argument once the first is a
-- 'Succ', so adding to 'infinity' is productive.
--
-- Since base 4.11 (GHC 8.4), 'Semigroup' is a superclass of 'Monoid'. A
-- 'Monoid' instance alone no longer compiles, and '<>' (used on 'Nat'
-- elsewhere) comes from this instance.
instance Semigroup Nat where
  Zero <> m = m
  Succ n <> m = Succ (n <> m)

-- | 'Zero' is the identity of addition; 'mappend' defaults to '<>'.
instance Monoid Nat where
  mempty = Zero

-- | Convert a finite 'Nat' to an 'Int'. Diverges on 'infinity'.
natToInt :: Nat -> Int
natToInt = go 0
  where
    -- Strict accumulator: no stack growth on long chains of 'Succ'.
    go acc Zero = acc
    go acc (Succ n) = let acc' = acc + 1 in acc' `seq` go acc' n

-- | The infinite natural @Succ (Succ (Succ ...))@.
infinity :: Nat
infinity = Succ infinity
-- | Initial 'DataDef', knowing only the given aliases.
--
-- Each alias gets an 'Aliased' key (its position in the list). It is
-- registered in 'index' under the type its function produces. Its entry in
-- 'xedni'' starts at 'unresolvedIx'. 'chaseType' replaces that entry with
-- the index of the type the alias's function is applied to, once that type
-- is reached.
dataDef :: [Alias m] -> DataDef m
dataDef as = DataDef
  { count = 0
  , points = 0
  , index = index
  , xedni = HashMap.empty
  , xedni' = xedni'
  , types = HashMap.empty
  , lTerm = HashMap.empty
  , degree = HashMap.empty
  } where
    xedni' = HashMap.fromList (fmap (\(i, a) -> (i, (unresolvedIx, a))) as')
    index = HashMap.fromList (fmap (\(i, a) -> (ofType a, Left i)) as')
    as' = zip (fmap Aliased [0 ..]) as
    ofType (Alias f) = typeRep (f undefined)

-- | Placeholder type index of an alias whose type has not been reached yet.
--
-- It must not be a valid type index. Real indices are allocated from 0
-- upwards by 'chaseType', so a non-negative sentinel (such as 1) would make
-- an alias resolved to that index indistinguishable from an unresolved one.
unresolvedIx :: Ix
unresolvedIx = -1

-- | Build the 'DataDef' of every type reachable from @a@, starting from the
-- given aliases.
collectTypes :: Data a => [Alias m] -> proxy a -> DataDef m
collectTypes as a = collectTypesM a `execState` dataDef as

-- | Order of the leading term of primitive types, as an 'Int' (used as
-- their degree and as the exponent in 'phi').
primOrder :: Int
primOrder = 1

-- | Order of the leading term of primitive types, as a 'Nat'.
primOrder' :: Nat
primOrder' = Succ Zero

-- | Coefficient of the leading term of primitive types.
primlCoef :: Integer
primlCoef = 1

-- | Type of the step function taken by 'gunfold'.
type GUnfold m = forall b r. Data b => m (b -> r) -> m r

-- | Type of the 'xedni'' field of a 'DataDef'.
type AMap m = HashMap Aliased (Ix, Alias m)

-- | 'chaseType' with a continuation that leaves 'xedni'' unchanged.
collectTypesM :: Data a => proxy a
  -> State (DataDef m) (Either Aliased Ix, ((Nat, Integer), Maybe Int))
collectTypesM a = chaseType a (const id)

-- | Register the type of @a@ if it is new. Return how it is indexed
-- (@Right i@ for a plain type, @Left j@ for an aliased one), together with
-- its leading term ('lTerm') and degree.
--
-- The continuation @k@ updates 'xedni'' with the alias and type index once
-- the underlying type index is known. Alias chasing below uses it to record
-- the resolution of every alias met on the way.
chaseType :: Data a => proxy a
  -> ((Maybe (Alias m), Ix) -> AMap m -> AMap m)
  -> State (DataDef m) (Either Aliased Ix, ((Nat, Integer), Maybe Int))
chaseType a k = do
  let t = typeRep a
  dd@DataDef{..} <- get
  let
    -- Result for an already-registered type index. The degree is absent
    -- while the type is still being traversed.
    lookup i r =
      let
        lTerm_i = lTerm #! i
        degree_i = HashMap.lookup i degree
      in return (r, (lTerm_i, degree_i))
  case HashMap.lookup t index of
    -- New type: allocate the next index, then traverse its constructors.
    Nothing -> do
      let i = count
      put dd
        { count = i + 1
        , index = HashMap.insert t (Right i) index
        , xedni = HashMap.insert i (someData' a) xedni
        , xedni' = k (Nothing, i) xedni'
        }
      traverseType a i
    -- Known plain type.
    Just (Right i) -> do
      put dd { xedni' = k (Nothing, i) xedni' }
      lookup i (Right i)
    -- Aliased type: alias j was registered for the result type of its
    -- function f.
    Just (Left j) ->
      case xedni' #! j of
        (i, Alias f) | i == unresolvedIx -> do
          -- Not resolved yet: chase the argument type of f, composing with
          -- any alias found there (composeCastM — presumably composition of
          -- the two casts; confirm), and record j's resolution.
          (_, ld) <- chaseType (ofType f) $ \(alias, i') ->
            let
              alias' = case alias of
                Nothing -> Alias f
                Just (Alias g) -> Alias (composeCastM f g)
            in
            k (Just alias', i') . HashMap.insert j (i', alias')
          return (Left j, ld)
        (i, _) -> lookup i (Left j)
  where
    ofType :: (m a -> m b) -> m a
    ofType _ = undefined
-- | Traverse the constructors of the type of @a@, registered under index
-- @i@. Record its constructor list at pointing level 0 ('types'), its
-- leading term ('lTerm') and, when known, its degree ('degree').
--
-- The leading term is tied with 'mfix'. The lazily computed result is
-- inserted into 'lTerm' before the traversal, so recursive occurrences of
-- @i@ met during the traversal can look it up.
traverseType
  :: Data a => proxy a -> Ix
  -> State (DataDef m) (Either Aliased Ix, ((Nat, Integer), Maybe Int))
traverseType a i = do
  let d = withProxy dataTypeOf a
  -- The lazy pattern is required: lTerm_i0 is this computation's own result.
  mfix $ \ ~(_, (lTerm_i0, _)) -> do
    modify $ \dd@DataDef{..} -> dd
      { lTerm = HashMap.insert i lTerm_i0 lTerm
      }
    (types_i, ld@(_, degree_i)) <- traverseType' a d
    modify $ \dd@DataDef{..} -> dd
      { types = HashMap.insert (C i 0) types_i types
      -- The degree is only stored when it is known (Just).
      , degree = maybe id (HashMap.insert i) degree_i degree
      }
    return (Right i, ld)
-- | Describe the constructors of a type, given its 'DataType' @d@.
--
-- For an algebraic type, every constructor becomes @(1, constr, fields)@,
-- where @fields@ references each field's type (with its alias, if any).
--
-- * A constructor's leading term is the product ('lMul') of its fields'
--   leading terms.
-- * A constructor's degree is @1@ plus the sum of its fields' degrees.
-- * The type's leading term is the sum ('lSum') of its constructors'
--   terms, with the order increased by one.
-- * The type's degree is the maximum ('maxDegree') over its constructors.
--
-- Any other type is treated as primitive: no constructors, leading term
-- @(primOrder', primlCoef)@ and degree @primOrder@.
traverseType'
  :: Data a => proxy a -> DataType
  -> State (DataDef m)
    ([(Integer, Constr, [(Maybe Aliased, C)])], ((Nat, Integer), Maybe Int))
traverseType' a d | isAlgType d = do
  let
    constrs = dataTypeConstrs d
    -- One 'gunfold' step: register the next field's type, and accumulate
    -- its index, leading term and degree. The innermost step receives the
    -- constructor itself and therefore handles the first field, so the
    -- field list is accumulated in reverse order.
    collect
      :: GUnfold (StateT
        ([Either Aliased Ix], (Nat, Integer), Maybe Int)
        (State (DataDef m)))
    collect mkCon = do
      f <- mkCon
      let ofType :: (b -> a) -> Proxy b
          ofType _ = Proxy
          b = ofType f
      (j, (lTerm_, degree_)) <- lift (collectTypesM b)
      modify $ \(js, lTerm', degree') ->
        (j : js, lMul lTerm_ lTerm', liftA2 (+) degree_ degree')
      return (withProxy f b)
  tlds <- forM constrs $ \constr -> do
    (js, lTerm', degree') <-
      gunfold collect return constr `proxyType` a
        `execStateT` ([], (Zero, 1), Just 1)
    dd <- get
    let
      -- Aliased fields refer to the type index recorded for their alias
      -- (set by chaseType while the field was collected).
      c (Left j) = (Just j, C (fst (xedni' dd #! j)) 0)
      c (Right i) = (Nothing, C i 0)
    return ((1, constr, [ c j | j <- js]), lTerm', degree')
  let
    (types_i, ls, ds) = unzip3 tlds
    lTerm_i = first Succ (lSum ls)
    degree_i = maxDegree ds
  return (types_i, (lTerm_i, degree_i))
traverseType' _ _ =
  return ([], ((primOrder', primlCoef), Just primOrder))
-- | Sum of two leading terms @(order, coefficient)@: the term of lower
-- order wins, and coefficients are added when the orders are equal.
lPlus :: (Nat, Integer) -> (Nat, Integer) -> (Nat, Integer)
lPlus (o, c) (o', c') =
  case (o, o') of
    (Zero, Zero) -> (Zero, c + c')
    (Zero, _) -> (Zero, c)
    (_, Zero) -> (Zero, c')
    (Succ n, Succ n') -> first Succ (lPlus (n, c) (n', c'))

-- | Sum of many leading terms. The empty sum has infinite order and a zero
-- coefficient.
lSum :: [(Nat, Integer)] -> (Nat, Integer)
lSum ls = case ls of
  [] -> (infinity, 0)
  l : rest -> foldl lPlus l rest

-- | Product of two leading terms: orders add up, coefficients multiply.
lMul :: (Nat, Integer) -> (Nat, Integer) -> (Nat, Integer)
lMul (o, c) (o', c') = (o <> o', c * c')

-- | Product of many leading terms, starting from the unit @(Zero, 1)@.
lProd :: [(Nat, Integer)] -> (Nat, Integer)
lProd ls = foldl lMul (Zero, 1) ls
-- | Maximum of the given degrees, or 'Nothing' if any of them is 'Nothing'.
-- The maximum of no degrees is 'minBound'.
maxDegree :: [Maybe Int] -> Maybe Int
maxDegree = fmap (foldr max minBound) . sequenceA
-- | Add one pointing level to a 'DataDef'.
--
-- For every type index @i@ in @[0 .. count - 1]@, the constructor list at
-- the new level @points + 1@ is derived from the level-0 list. Each result
-- of @partitions points' (length js)@ assigns a level to every field, and
-- @multinomial points' ps@ becomes the constructor's weight.
-- NOTE(review): presumably 'partitions' enumerates the ways to split
-- @points'@ among the fields; confirm in "Boltzmann.Data.Common".
point :: DataDef m -> DataDef m
point dd@DataDef{..} = dd
  { points = points'
  -- Type indices are 0 .. count - 1.
  , types = foldl g types [0 .. count - 1]
  } where
    points' = points + 1
    g acc i = HashMap.insert (C i points') (types' i) acc
    types' i = types #! C i 0 >>= h
    h (_, constr, js) = do
      ps <- partitions points' (length js)
      let
        mult = multinomial points' ps
        js' = zipWith (\(j', C i _) p -> (j', C i p)) js ps
      return (mult, constr, js')
-- | Value of every generating function @C i k@, as computed by 'makeOracle'.
type Oracle = HashMap C Double
-- | Compute the 'Oracle' for the root type @t@: the value of every
-- generating function @C i k@ at the parameter selected by 'search'.
--
-- With @Just size@, the 'DataDef' is pointed once more. The expected size
-- of @t@ is then estimated as @y(C i (k + 1)) / y(C i k)@, and a parameter
-- is accepted when that estimate does not exceed @size@. With 'Nothing',
-- any parameter whose system has a solution is accepted. Solutions with a
-- negative component are always rejected.
--
-- NOTE(review): the strategy of 'search' and 'fixedPoint' lives in
-- "Boltzmann.Solver"; the comments below about them are assumptions to
-- confirm there.
makeOracle :: DataDef m -> TypeRep -> Maybe Double -> Oracle
makeOracle dd0 t size' =
  -- Force the solution vector so that a failed search surfaces here.
  seq v
  HashMap.fromList (zip cs (V.toList v))
  where
    dd@DataDef{..} = if isJust size' then point dd0 else dd0
    -- All generating functions, in the same order as 'listCs' and '?'.
    cs = flip C <$> [0 .. points] <*> [0 .. count - 1]
    -- Number of unknowns in the system.
    m = count * (points + 1)
    -- When size' is Just: the level just below the one added by 'point'.
    k = points - 1
    -- Type index of the root type, through its alias if it has one.
    i = case index #! t of
      Left j -> fst (xedni' #! j)
      Right i -> i
    checkSize _ (Just ys) | V.any (< 0) ys = False
    checkSize (Just size) (Just ys) =
      size >= size_
      where
        size_ = ys V.! j' / ys V.! j
        j = dd ? C i k
        j' = dd ? C i (k + 1)
    checkSize Nothing (Just _) = True
    checkSize _ Nothing = False
    -- Right-hand side of the equation of each generating function.
    phis :: Num a => V.Vector (a -> V.Vector a -> a)
    phis = V.fromList [ phi dd c (types #! c) | c <- listCs dd ]
    -- Solve y = phis(x, y) starting from the zero vector (Nothing when
    -- fixedPoint finds no solution).
    eval' :: Double -> Maybe (V.Vector Double)
    eval' x = fixedPoint defSolveArgs phi' (V.replicate m 0)
      where
        phi' :: (Mode a, Scalar a ~ Double) => V.Vector a -> V.Vector a
        phi' y = fmap (\f -> f (auto x) y) phis
    v = (fromJust . snd) (search eval' (checkSize size'))
-- | Equation of one generating function.
--
-- @phi dd c tyInfo x y@ evaluates the right-hand side for @c@, given the
-- parameter @x@ and the vector @y@ of all generating functions (indexed
-- with '?').
--
-- * Without constructors, an algebraic type gives @0@. Any other type is
--   primitive and gives @primlCoef * x ^ primOrder@.
-- * Otherwise the result is @x@ times the sum, over constructors, of their
--   weight times the product of their fields' generating functions.
phi :: Num a => DataDef m -> C -> [(Integer, constr, [C'])]
  -> a -> V.Vector a -> a
phi DataDef{..} (C i _) [] =
  case xedni #! i of
    SomeData a ->
      case (dataTypeRep . withProxy dataTypeOf) a of
        AlgRep _ -> \_ _ -> 0
        _ -> \x _ -> fromInteger primlCoef * x ^ primOrder
phi dd _ tyInfo = \x y ->
  let
    fieldValue (_, j) = y V.! (dd ? j)
    weighted (w, _, js) = fromInteger w * product (fmap fieldValue js)
  in
  x * sum (fmap weighted tyInfo)
-- | Generators for aliases ('AC' keys) and for types ('C' keys), at every
-- pointing level; built by 'makeGenerators'.
type Generators m = (HashMap AC (SomeData m), HashMap C (SomeData m))
-- | Build generators for every type and alias at every pointing level, from
-- a 'DataDef' and its 'Oracle'.
--
-- The two maps are defined in terms of each other: a generator looks up
-- its fields' generators lazily. 'seq' forces the oracle (to weak head
-- normal form) before the pair is returned.
makeGenerators
  :: forall m. MonadRandomLike m
  => DataDef m -> Oracle -> Generators m
makeGenerators DataDef{..} oracle =
  seq oracle
  (generatorsL, generatorsR)
  where
    -- Generator for C i k: run 'incr', then either use 'defGen' (no
    -- constructors) or pick a constructor with frequencyWith, weighted by g.
    -- NOTE(review): presumably incr counts generated size; confirm.
    f (C i _) tyInfo = case xedni #! i of
      SomeData a -> SomeData $ incr >>
        case tyInfo of
          [] -> defGen
          _ -> frequencyWith doubleR (fmap g tyInfo) `proxyType` a
    -- Weight of a constructor: its integer weight times the oracle values
    -- of its fields. Its generator fills the fields with their generators.
    g :: Data a => (Integer, Constr, [C']) -> (Double, m a)
    g (v, constr, js) =
      ( fromInteger v * w
      , gunfold generate return constr `runReaderT` gs)
      where
        gs = fmap (\(j', i) -> m j' i) js
        -- Fields reached through an alias use that alias's generator.
        m = maybe (generatorsR #!) m'
        m' j (C _ k) = (generatorsL #! AC j k)
        w = product $ fmap ((oracle #!) . snd) js
    -- Alias generator at level k: the alias's function applied (through
    -- applyCast) to the generator of the type it resolved to, same level.
    h (j, (i, Alias f)) k =
      (AC j k, applyCast f (generatorsR #! C i k))
    generatorsL = HashMap.fromList (liftA2 h (HashMap.toList xedni') [0 .. points])
    generatorsR = HashMap.mapWithKey f types
-- | Generators of smallest values, for aliases and for type indices; built
-- by 'smallGenerators'.
type SmallGenerators m =
  (HashMap Aliased (SomeData m), HashMap Ix (SomeData m))
-- | Build generators of smallest values for every type and alias.
--
-- Only constructors whose size matches the type's minimal order (from
-- 'lTerm') are kept. Each is weighted by the product of its fields'
-- leading coefficients.
smallGenerators
  :: forall m. MonadRandomLike m => DataDef m -> SmallGenerators m
smallGenerators DataDef{..} = (generatorsL, generatorsR)
  where
    f i (SomeData a) = SomeData $ incr >>
      case types #! C i 0 of
        [] -> defGen
        tyInfo ->
          let gs = (tyInfo >>= g (fst (lTerm #! i))) in
            frequencyWith integerR gs `proxyType` a
    -- A constructor is kept iff the type's order is one more (the 'Succ'
    -- added by traverseType') than the order of the product of its fields.
    g :: Data a => Nat -> (Integer, Constr, [C']) -> [(Integer, m a)]
    g minSize (_, constr, js) =
      guard (minSize == Succ size) *>
      [(weight, gunfold generate return constr `runReaderT` gs)]
      where
        (size, weight) = lProd [ lTerm #! i | (_, C i _) <- js ]
        gs = fmap lookup js
        -- Fields reached through an alias use that alias's generator.
        lookup (j', C i _) = maybe (generatorsR #! i) (generatorsL #!) j'
    -- Alias generator: the alias's function applied (through applyCast) to
    -- the generator of the type it resolved to.
    h (j, (i, Alias f)) = (j, applyCast f (generatorsR #! i))
    generatorsL = (HashMap.fromList . fmap h . HashMap.toList) xedni'
    generatorsR = HashMap.mapWithKey f xedni
-- | 'gunfold' step for generators: each field of the constructor takes its
-- generator from the environment list.
--
-- The outermost step supplies the constructor's last argument and consumes
-- the head of the list, so the list is ordered last field first. This
-- matches the reversed field lists built by traverseType'.
--
-- An empty list means the constructor has more fields than generators.
-- That breaks an internal invariant, and it is reported explicitly rather
-- than through an anonymous pattern-match failure.
generate :: Applicative m => GUnfold (ReaderT [SomeData m] m)
generate rest = ReaderT $ \gens -> case gens of
  g : gs -> rest `runReaderT` gs <*> unSomeData g
  [] -> error "Boltzmann.Data.Oracle.generate: not enough field generators"
-- | Generator used for types without constructor entries.
--
-- The result type is inspected through 'dataTypeOf'. Integral, floating
-- and character types are built from 'int', 'double' and 'char'
-- respectively. Algebraic types and types without a representation raise
-- an error.
defGen :: (Data a, MonadRandomLike m) => m a
defGen = generator
  where
    generator = case dataTypeRep rep of
      IntRep -> fromConstr . mkIntegralConstr rep <$> int
      FloatRep -> fromConstr . mkRealConstr rep <$> double
      CharRep -> fromConstr . mkCharConstr rep <$> char
      AlgRep _ -> error "Cannot generate for empty type."
      NoRep -> error "No representation."
    -- Only the type of 'generator' is used here, never its value.
    rep = withProxy dataTypeOf generator
-- | Position of @C i k@ in the flat vector of generating functions. The
-- @count dd@ types of each level @k@ occupy one contiguous segment.
(?) :: DataDef m -> C -> Int
dd ? C i k = i + k * count dd

-- | Every 'C' of a 'DataDef', in increasing order of '?'.
-- Type indices are @0 .. count dd - 1@.
listCs :: DataDef m -> [C]
listCs dd = liftA2 (flip C) [0 .. points dd] [0 .. count dd - 1]

-- | Type index of a 'C'.
ix :: C -> Int
ix (C i _) = i

-- | Inverse of '?': recover @C i k@ from its position.
(?!) :: DataDef m -> Int -> C
dd ?! j = C i k
  where (k, i) = j `divMod` count dd
-- | Generator of the type of @a@ at pointing level @k@, taken from the alias
-- map when the type is aliased and from the type map otherwise.
getGenerator :: Data a => DataDef m -> Generators m -> proxy a -> Int -> m a
getGenerator dd (l, r) a k =
  unSomeData (either fromAlias fromType (index dd #! typeRep a))
  where
    fromAlias j = l #! AC j k
    fromType i = r #! C i k

-- | Generator of smallest values of the type of @a@, taken from the alias
-- map when the type is aliased and from the type map otherwise.
getSmallGenerator :: Data a => DataDef m -> SmallGenerators m -> proxy a -> m a
getSmallGenerator dd (l, r) a =
  unSomeData (either (l #!) (r #!) (index dd #! typeRep a))
-- | Lookup of a key the caller expects to be present. It fails (through
-- 'HashMap.!') when the key is missing.
(#!) :: (Eq k, Hashable k)
  => HashMap k v -> k -> v
hm #! key = hm HashMap.! key