-- | Projected subgradient iteration for constrained problems.
--
-- The module exports the general iteration ('projSubgrad'), a variant
-- specialised to a linear objective ('linearProjSubgrad'), a handful of
-- step-size schedules ('StepSched'), and a projection built from a list of
-- linear constraints ('linearProjection').
module Optimization.Constrained.ProjectedSubgradient
(
projSubgrad
, linearProjSubgrad
, StepSched
, optimalStepSched
, constStepSched
, sqrtKStepSched
, invKStepSched
, Constraint(..)
, linearProjection
) where
import Linear
import Data.Traversable
import Data.Function (on)
import Data.List (maximumBy)
-- | A step-size schedule: one entry per iteration. Each entry receives the
-- (sub)gradient and the objective value at the current iterate and returns
-- the step length to use. The iteration functions in this module consume one
-- entry per step and stop when the list is exhausted.
type StepSched f a = [f a -> a -> a]
-- | Projected subgradient iteration.
--
-- From the current iterate, move against the supplied (sub)gradient by the
-- length the schedule prescribes, then pass the result through the
-- projection. The returned list holds the iterate produced by each step (the
-- starting point itself is not included) and is exactly as long as the
-- schedule.
projSubgrad :: (Additive f, Traversable f, Metric f, Ord a, Fractional a)
            => StepSched f a     -- ^ step-size schedule, one entry per step
            -> (f a -> f a)      -- ^ projection applied after every step
            -> (f a -> f a)      -- ^ (sub)gradient of the objective
            -> (f a -> a)        -- ^ objective function
            -> f a               -- ^ starting point
            -> [f a]             -- ^ iterates, one per schedule entry
projSubgrad stepSizes proj df f start =
    -- 'scanl' also emits the seed; drop it so only produced iterates remain.
    drop 1 (scanl advance start stepSizes)
  where
    -- One step: descend along the negated subgradient, then project.
    advance current stepRule =
      let descent = negated (df current)
          len     = stepRule (df current) (f current)
      in proj (current ^+^ len *^ descent)
-- | Projected subgradient iteration specialised to the linear objective
-- @f x = a `dot` x - b@, whose gradient is the constant vector @a@.
--
-- Each step moves against @a@ by the length the schedule prescribes (the
-- schedule is given @a@ and the current objective value) and then passes the
-- result through the projection. The returned list holds the iterate produced
-- by each step (the starting point itself is not included).
linearProjSubgrad :: (Additive f, Traversable f, Metric f, Ord a, Fractional a)
                  => StepSched f a     -- ^ step-size schedule, one entry per step
                  -> (f a -> f a)      -- ^ projection applied after every step
                  -> f a               -- ^ coefficient vector @a@ of the objective
                  -> a                 -- ^ constant offset @b@ of the objective
                  -> f a               -- ^ starting point
                  -> [f a]             -- ^ iterates, one per schedule entry
linearProjSubgrad stepSizes proj a b = go stepSizes
  where
    go (alpha:rest) x0 =
      let p    = negated $ df x0
          step = alpha a (f x0)
          x1   = proj $ x0 ^+^ step *^ p
      in x1 : go rest x1
    go [] _ = []
    -- The gradient of a linear objective does not depend on the point.
    df _ = a
    -- Objective value; the subtraction of the offset was missing, which left
    -- the vector @x@ being applied to @b@ as if it were a function.
    f x = a `dot` x - b
-- | Step schedule that uses the gap between the current objective value and a
-- reference value, divided by the squared norm of the (sub)gradient:
-- @(f_k - fStar) / quadrance g_k@.
--
-- NOTE(review): @fStar@ is presumably the known optimal objective value
-- (Polyak's step size) - confirm with callers. The step is undefined
-- (division by zero) when the (sub)gradient is the zero vector.
optimalStepSched :: (Fractional a, Metric f)
                 => a               -- ^ reference objective value @fStar@
                 -> StepSched f a
optimalStepSched fStar =
    repeat $ \gk fk -> (fk - fStar) / quadrance gk
-- | Schedule that returns the same step length at every iteration, ignoring
-- both the (sub)gradient and the objective value. The schedule is infinite.
constStepSched :: a               -- ^ the fixed step length
               -> StepSched f a
constStepSched gamma = repeat (const (const gamma))
-- | Diminishing schedule whose @k@-th step length is @gamma / sqrt k@,
-- ignoring both the (sub)gradient and the objective value. The schedule is
-- infinite.
--
-- The iteration index starts at 1: starting at 0 would make the very first
-- step @gamma / sqrt 0@, a division by zero.
sqrtKStepSched :: Floating a
               => a               -- ^ scale factor @gamma@
               -> StepSched f a
sqrtKStepSched gamma =
    [ \_ _ -> gamma / sqrt (fromIntegral k) | k <- [1 :: Int ..] ]
-- | Diminishing schedule whose @k@-th step length is @gamma / k@, ignoring
-- both the (sub)gradient and the objective value. The schedule is infinite.
--
-- The iteration index starts at 1: starting at 0 would make the very first
-- step @gamma / 0@, a division by zero.
invKStepSched :: Fractional a
              => a               -- ^ scale factor @gamma@
              -> StepSched f a
invKStepSched gamma =
    [ \_ _ -> gamma / fromIntegral k | k <- [1 :: Int ..] ]
-- | A single linear constraint, as consumed by 'linearProjection'.
--
-- @Constr rel b a@ relates @a `dot` x@ to the bound @b@: 'EQ' asks for
-- equality, 'GT' for "at least" and 'LT' for "at most" (each up to the
-- tolerance used by 'linearProjection').
data Constraint f a
    = Constr
        Ordering  -- ^ relation between @a `dot` x@ and the bound
        a         -- ^ the bound @b@
        (f a)     -- ^ the coefficient vector @a@
    deriving (Show)
-- | Map a point to one that satisfies every constraint in the list.
--
-- While any constraint is unmet (beyond a tolerance of @1e-4@), one unmet
-- constraint is selected, the point is moved along that constraint's
-- coefficient vector onto the hyperplane @a `dot` x = b@, and the check is
-- repeated. A point that already satisfies everything is returned unchanged.
--
-- NOTE(review): nothing here bounds the number of repetitions, so this
-- presumably relies on the constraints being jointly satisfiable - confirm.
-- A constraint with an all-zero coefficient vector divides by zero.
linearProjection :: (Fractional a, Ord a, RealFloat a, Metric f)
                 => [Constraint f a]  -- ^ constraints to satisfy
                 -> f a               -- ^ point to project
                 -> f a
linearProjection constraints x =
    case unmet of
      [] -> x
      _  -> linearProjection constraints
              $ fixConstraint x
              -- 'maximumBy' under the flipped ordering picks the unmet
              -- constraint with the smallest signed residual.
              $ maximumBy (flip compare `on` (`residual` x)) unmet
  where
    unmet = filter (not . met x) constraints

    -- Signed residual @a `dot` c - b@. The subtraction was missing in all
    -- three places this quantity was spelled out, which left the point being
    -- applied to the bound as if it were a function.
    residual (Constr _ b a) c = a `dot` c - b

    negligible y = abs y < 1e-4

    -- Is the constraint satisfied at @c@, up to the tolerance?
    met c constr@(Constr t _ _) =
      let y = residual constr c
      in case t of
           EQ -> negligible y
           GT -> y >= 0 || negligible y
           LT -> y <= 0 || negligible y

    -- Move @c@ along @a@ so that @a `dot` c = b@ holds.
    fixConstraint c constr@(Constr _ _ a) =
      c ^-^ residual constr c *^ a ^/ quadrance a