{-# LANGUAGE RecordWildCards #-}

-- | Passive-aggressive optimization. Mainly based on:
--
-- Zakov, Shay and Goldberg, Yoav and Elhadad, Michael and Ziv-Ukelson, Michal
-- "Rich Parameterization Improves RNA Structure Prediction"
-- RECOMB 2011
--
-- and
--
-- Crammer, Koby and (et al)
-- "Online Passive-Aggressive Algorithms"
-- Journal of Machine Learning Research (2006)
--
-- TODO as always: move out of here and put in its own library

module BioInf.PassiveAggressive where

import qualified Data.Vector.Unboxed as VU
import Data.List as L
import Data.Set as S
import Control.Arrow
import Data.Map as M

import Biobase.TrainingData
import BioInf.Keys

import qualified BioInf.Params as P
import qualified BioInf.Params.Import as P
import qualified BioInf.Params.Export as P

import Statistics.ConfusionMatrix
import Statistics.PerformanceMetrics

import Data.PrimitiveArray as PA
import Data.PrimitiveArray.Ix



-- | Default implementation of P/A: one passive-aggressive update step for a
-- single training instance.
--
-- Arguments, in order: the aggressiveness (upper bound on the step size
-- @tau@), the current parameters, and the training instance.
--
-- The result is @(params', tau, sty, changes)@:
--
-- * @params'@ - the parameters after adding @changes@ to them,
-- * @tau@ - the step size that was applied,
-- * @sty@ - the F-measure obtained from the instance's confusion matrix
--   (1 if 'fmeasure' yields a 'Left'),
-- * @changes@ - deltas per feature index: @-tau@ for every feature found
--   only in the known structure, @+tau@ for every feature found only in
--   the predicted structure.
--
-- If neither structure has a feature the other one lacks, or if
-- @sty >= 0.999@, nothing is updated and @(params, 0, 1, [])@ is returned.
--
-- Calls 'error' when the computed step size is negative, and when the
-- updated parameters do not survive a @toList@/@fromList@ round trip.

defaultPA :: Double -> P.Params -> TrainingData -> (P.Params,Double,Double,[(Int,Double)])
defaultPA aggressiveness params td@TrainingData{..}
  | L.null pOnly && L.null kOnly = (params,0,1,[])
  | sty >= 0.999                 = (params,0,1,[])
  | otherwise                    = (updated, tau, sty, changes)
  where
    -- Flat view of the current parameters, indexed by feature number.
    cur = VU.fromList . P.toList $ params
    -- Features of the predicted and of the known structure.
    pFeatures = featureVector primary predicted
    kFeatures = featureVector primary secondary
    -- Features occurring on one side only (list difference, so multiplicity
    -- is respected).
    pOnly = pFeatures L.\\ kFeatures
    kOnly = kFeatures L.\\ pFeatures
    numChanges = genericLength $ pOnly ++ kOnly
    -- Score of a structure: sum of the current values of its features.
    score = sum . L.map (cur VU.!)
    pScore = score pFeatures
    kScore = score kFeatures
    -- currently optimizing using F_1
    sty = case fmeasure (mkConfusionMatrix td) of
            Left  _ -> 1
            Right v -> v
    -- Step size. No separate @sty >= 0.999@ case is needed here: 'tau' is
    -- only ever demanded from the @otherwise@ branch above, which is reached
    -- after that condition has already been ruled out.
    tau
      | step < 0  = error $ "val<0 \n" ++ diagnostics
      | otherwise = step
      where
        step = min aggressiveness $ (kScore - pScore + sqrt (1-sty)) / (numChanges ^ 2)
        -- Dump for the error case: scores, sequence, comments, then the
        -- known-only, predicted-only and all predicted features together
        -- with their current values.
        diagnostics =
          show (kScore, pScore, kScore - pScore) ++ "\n"
          ++ primary ++ "\n"
          ++ L.intercalate "\n" comments ++ "\n"
          ++ dump kOnly ++ " <<<\n"
          ++ dump pOnly ++ " ALL\n"
          ++ dump pFeatures
        dump = L.concatMap (\x -> show x ++ "\n") . L.map (describe &&& (cur VU.!))
        -- Key behind a feature index, paired with its value in the structured
        -- parameters (-1 for key kinds that are not looked up here).
        describe i = let key = vks M.! i in (key, currentValue key)
        currentValue (HairpinClose k)  = P.hairpinClose params PA.! k
        currentValue (HairpinLength l) = P.hairpinLength params PA.! l
        currentValue _                 = (-1)
    changes = zip kOnly (repeat $ negate tau) ++ zip pOnly (repeat tau)
    -- Parameters with all changes added in.
    applied = P.fromList . VU.toList $ VU.accum (+) cur changes
    -- Sanity check: flattening and rebuilding the new parameters must give
    -- the same flat list again.
    reparsed = P.fromList (P.toList applied)
    updated
      | P.toList applied == P.toList reparsed = applied
      | otherwise = error "BioInf.PassiveAggressive.defaultPA: updated parameters do not survive a toList/fromList round trip"

-- | Confusion-matrix view of a training instance; this is what makes the
-- statistical measures (e.g. 'fmeasure') available for 'TrainingData'.
--
-- The entries of @secondary@ (known) and @predicted@ are compared as sets.
-- The total number of candidates is taken to be @n*(n-1)/2@, where @n@ is
-- the length of @primary@.
--
-- NOTE Unfortunately, StatisticalMethods has heavy dependencies.

instance MkConfusionMatrix TrainingData where
  mkConfusionMatrix TrainingData{..} = ConfusionMatrix
    { tp = count (S.intersection known guessed)
    , fp = count (S.difference guessed known)
    , fn = count (S.difference known guessed)
    , tn = Right (fromIntegral (pairTotal - S.size (S.union known guessed)))
    }
    where
      known   = S.fromList secondary
      guessed = S.fromList predicted
      -- Size of a set, wrapped the way the matrix fields expect it.
      count s = Right (fromIntegral (S.size s))
      n         = length primary
      pairTotal = (n * (n - 1)) `div` 2