module Text.PhonotacticLearner.Util.ConjugateGradient (
traceInline, regulaFalsiSearch, conjugateGradientSearch,
) where
import qualified Data.Map as M
import Data.List
import Data.Ix
import Debug.Trace
import qualified Data.Vector.Unboxed as V
import System.IO
import System.IO.Unsafe
import Numeric
import Data.Array.IArray
import Text.PhonotacticLearner.Util.Ring
-- | Length of the first trial step taken by 'regulaFalsiSearch' when it
-- doubles the step size outward to bracket a sign change.
rfInitSigma :: Double
rfInitSigma = 5.0e-2
-- | Write a string to standard error and return the second argument
-- unchanged.
--
-- The string is written when the result is demanded (forced to weak head
-- normal form). No newline is appended, and stderr is flushed right away,
-- so single-character progress markers appear as soon as they are produced.
--
-- 'unsafePerformIO' is used only for diagnostic output. The NOINLINE
-- pragma follows the "System.IO.Unsafe" guidance for functions that call
-- 'unsafePerformIO'. If the call were inlined, the write could run more
-- often, or less often, than the source suggests.
traceInline :: String -> a -> a
traceInline s x = unsafePerformIO $ do
    hPutStr stderr s
    hFlush stderr
    return x
{-# NOINLINE traceInline #-}
-- | Line search along a direction, using a modified regula falsi (false
-- position) method on the function @f'@.
--
-- The search direction @sdir@ is first normalised. Then @f'@ is sampled at
-- step lengths @rfInitSigma@, @2*rfInitSigma@, @4*rfInitSigma@, and so on,
-- until two consecutive step lengths (starting from 0) bracket a sign
-- change: @f' <= 0@ on the left and @f' >= 0@ on the right. The bracket is
-- then narrowed with secant steps. Narrowing stops when @f'@ is exactly 0 at
-- an endpoint, or when the bracket is narrower than @epsilon@.
--
-- If @f'@ at @xinit@ is already positive, @xinit@ is returned unchanged.
--
-- NOTE(review): @f'@ is presumably the directional derivative of the
-- objective. If it never becomes non-negative along the ray, the doubling
-- search keeps going and this function does not terminate.
regulaFalsiSearch :: Double                 -- ^ stop once the bracket is narrower than this
                  -> (Vec -> Vec -> Double) -- ^ called as @f' point unitDirection@
                  -> Vec                    -- ^ starting point
                  -> Vec                    -- ^ search direction (normalised internally)
                  -> Vec                    -- ^ point reached along the line
regulaFalsiSearch epsilon f' xinit sdir
    | dxinit > 0 = xinit
    | otherwise = pos (rfs a1 a2 0)
  where
    -- Use a unit direction, so step lengths are distances along the line.
    dir = normalizeVec sdir
    dxinit = f' xinit dir

    -- Point reached after a step of length alpha along dir.
    pos :: Double -> Vec
    pos alpha = xinit ⊕ (alpha ⊙ dir)

    -- (step, f') samples at rfInitSigma, 2*rfInitSigma, 4*rfInitSigma, ...
    doublingSearch = [(a, f' (pos a) dir) | a <- iterate (* 2) rfInitSigma]

    -- First pair of consecutive samples (starting from step 0) across which
    -- f' goes from <= 0 to >= 0. The sample list is infinite, so the empty
    -- case cannot happen.
    (a1, a2) =
        case filter brackets (zip ((0, dxinit) : doublingSearch) doublingSearch) of
            (p : _) -> p
            [] -> error "regulaFalsiSearch: impossible, the step sequence is infinite"
    brackets ((_, dl), (_, dr)) = dl <= 0 && dr >= 0

    -- Root of the straight line through (x, dx) and (y, dy).
    secant :: (Double, Double) -> (Double, Double) -> Double
    secant (!x, !dx) (!y, !dy) = (x * dy - y * dx) / (dy - dx)

    -- Narrow the bracket [x, y], where dx <= 0 <= dy. bal counts how many
    -- times in a row the same endpoint has been replaced. It is negative
    -- while the left end keeps moving and positive while the right end does.
    rfs :: (Double, Double) -> (Double, Double) -> Int -> Double
    rfs (!x, !dx) (!y, !dy) !bal
        | dx == 0 = x
        | dy == 0 = y
        | y - x < epsilon = secant (x, dx) (y, dy)
        | dz <= 0 = rfs (z, dz) (y, dy) (min bal 0 - 1)
        | otherwise = rfs (x, dx) (z, dz) (max bal 0 + 1)
      where
        -- After two or more consecutive moves of one endpoint, scale the
        -- f' value of the endpoint that stayed fixed by 0.707 per move.
        -- This pulls the secant estimate towards that stale endpoint.
        -- Both exponents are >= 2 here, so (^) never sees a negative power.
        sy = if bal <= (-2) then 0.707 ^ negate bal else 1
        sx = if bal >= 2 then 0.707 ^ bal else 1
        z = secant (x, sx * dx) (y, sy * dy)
        dz = f' (pos z) dir
-- | Nonlinear conjugate gradient descent with Polak-Ribière updates.
--
-- Each iteration does the following:
--
-- * takes the gradient at the current point;
-- * combines it with the previous search direction using the Polak-Ribière
--   coefficient;
-- * runs 'regulaFalsiSearch' along the resulting direction;
-- * passes the line-search result through the projection function.
--
-- The coefficient is forced to zero, giving a plain steepest-descent step,
-- in these cases:
--
-- * on the first iteration;
-- * when the coefficient is non-positive;
-- * once @dims@ conjugate steps have been taken since the last reset, where
--   @dims@ is the number of coordinates of the starting point;
-- * on the iteration after the projection reports a correction.
--
-- When tracing is on, @"+"@ is written to stderr for each reset step and
-- @"-"@ for each conjugate step.
--
-- The search stops when the previous step or the step just taken is shorter
-- than @e1@. It then returns the point produced by the final line search.
--
-- NOTE(review): that final point is not passed through the projection.
-- Confirm this is intended when the projection enforces constraints.
conjugateGradientSearch :: Bool                   -- ^ write progress markers to stderr
                        -> (Double, Double)       -- ^ (outer stopping distance, line-search bracket width)
                        -> (Vec -> (Vec, Bool))   -- ^ projection: corrected point, and whether it was changed
                        -> (Vec -> (Double, Vec)) -- ^ value and gradient (only the gradient is used here)
                        -> (Vec -> Vec -> Double) -- ^ passed to 'regulaFalsiSearch' as its @f'@
                        -> Vec                    -- ^ starting point
                        -> Vec
conjugateGradientSearch shouldtrace (e1, e2) conproj fstar f' start =
    go dims (start ⊕ vec [2 * e1]) zero zero start
  where
    -- Write a progress marker only when tracing is switched on.
    report :: String -> Vec -> Vec
    report
        | shouldtrace = traceInline
        | otherwise = \_ v -> v

    dims = length (coords start)

    -- Arguments: conjugate-step counter, previous point, previous
    -- direction, previous gradient, current point. The counter starts at
    -- dims so that the first step is a reset.
    go :: Int -> Vec -> Vec -> Vec -> Vec -> Vec
    go !counter !prevX !prevDir !prevGrad !curX
        | converged = lineMin
        | otherwise = go counter' curX dir grad projected
      where
        converged = normVec (prevX ⊖ curX) < e1 || normVec (curX ⊖ lineMin) < e1
        grad = snd (fstar curX)
        -- Polak-Ribière coefficient. It is only evaluated when no reset is
        -- already forced by the counter.
        prCoeff = innerProd grad (grad ⊖ prevGrad) / innerProd prevGrad prevGrad
        restart = counter >= dims || prCoeff <= 0
        beta = if restart then 0 else prCoeff
        nextCounter = if restart then 0 else counter + 1
        dir = (beta ⊙ prevDir) ⊖ grad
        lineMin = report (if beta <= 0 then "+" else "-") (regulaFalsiSearch e2 f' curX dir)
        (projected, corrected) = conproj lineMin
        -- A projection correction forces a reset on the next iteration.
        counter' = if corrected then dims else nextCounter