module Optimization.TrustRegion.Fista
  ( -- * Fast Iterative Shrinkage-Thresholding Algorithm
    fista
  , fistaProx
  ) where

import Linear

-- | Fast Iterative Shrinkage-Thresholding Algorithm (FISTA) with
-- constant stepsize, for a purely smooth objective.
--
-- There is no non-smooth term, so the proximal (shrinkage) step is the
-- identity: this is exactly @'fistaProx' id@.
--
-- The result is the infinite list of iterates @x_1, x_2, ...@; the caller
-- decides when to stop (e.g. with 'take' or a convergence test).
{-# INLINEABLE fista #-}
fista
  :: (Additive f, Floating a)
  => a            -- ^ @l@: inverse step size; each gradient step is @df y / l@
  -> (f a -> f a) -- ^ @df@: gradient of the objective
  -> f a          -- ^ starting point
  -> [f a]
fista = fistaProx id

-- | FISTA with a user-supplied proximal (shrinkage) step.
--
-- Each iteration takes a gradient step of size @1/l@ from the extrapolated
-- point @y@, applies @prox@, and then updates the momentum:
--
-- > x_k     = prox (y_k - df y_k / l)
-- > t_{k+1} = (1 + sqrt (1 + 4 * t_k^2)) / 2
-- > y_{k+1} = x_k + ((t_k - 1) / t_{k+1}) * (x_k - x_{k-1})
--
-- The recurrence starts from @x_0 = y_1 =@ the starting point and @t_1 = 1@.
-- The result is the infinite list @x_1, x_2, ...@.
--
-- In the usual FISTA setting, @prox@ is the proximal map of the non-smooth
-- term with parameter @1/l@ (for example, soft-thresholding by
-- @lambda / l@ for an L1 penalty). In that setting @l@ is an upper bound on
-- the Lipschitz constant of @df@. Passing @id@ gives plain accelerated
-- gradient descent ('fista').
{-# INLINEABLE fistaProx #-}
fistaProx
  :: (Additive f, Floating a)
  => (f a -> f a) -- ^ @prox@: applied to the result of every gradient step
  -> a            -- ^ @l@: inverse step size; each gradient step is @df y / l@
  -> (f a -> f a) -- ^ @df@: gradient of the smooth part of the objective
  -> f a          -- ^ starting point
  -> [f a]
fistaProx prox l df x0' = go x0' x0' 1
  where
    go xPrev y t =
      let x  = prox (y ^-^ (df y ^/ l))
          -- t * t rather than t ^ 2: same value, no Integer-defaulted exponent
          t' = (1 + sqrt (1 + 4 * (t * t))) / 2
          y' = x ^+^ ((t - 1) / t') *^ (x ^-^ xPrev)
      in -- Force the scalar momentum so it does not build a thunk chain
         -- when callers skip ahead (e.g. @iterates !! n@).
         t' `seq` (x : go x y' t')