module Statistics.Regression
(
olsRegress
, ols
, rSquare
, bootstrapRegress
) where
import Control.Applicative ((<$>))
import Control.Concurrent (forkIO)
import Control.Concurrent.Chan (newChan, readChan, writeChan)
import Control.DeepSeq (rnf)
import Control.Exception (SomeException, evaluate, throwIO, try)
import Control.Monad (forM_, replicateM)
import GHC.Conc (getNumCapabilities)
import Prelude hiding (pred, sum)
import Statistics.Function as F
import Statistics.Matrix hiding (map)
import Statistics.Matrix.Algorithms (qr)
import Statistics.Resampling (splitGen)
import Statistics.Types (Estimate(..),ConfInt,CL,estimateFromInterval,significanceLevel)
import Statistics.Sample (mean)
import Statistics.Sample.Internal (sum)
import System.Random.MWC (GenIO, uniformR)
import qualified Data.Vector as V
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
-- | Perform an ordinary least-squares regression on a set of predictors.
-- Returns the regression coefficients together with R², the coefficient of
-- determination (computed by 'rSquare').
--
-- A column of ones is appended after the predictors before solving. The
-- coefficient vector therefore has one more element than there are
-- predictors, and its last element is the intercept.
--
-- Calls 'error' in three cases:
--
-- * no predictors are given;
-- * the predictors differ in length;
-- * the responder length differs from the predictor length.
olsRegress :: [Vector]
           -- ^ Non-empty list of predictor vectors, all of the same length.
           -> Vector
           -- ^ Responder vector, the same length as each predictor.
           -> (Vector, Double)
olsRegress [] _ = error "no predictors given"
olsRegress preds@(firstPred:_) resps
  | not (all (== rowCount) predLens) =
      error $ "predictor vector length mismatch " ++ show predLens
  | respLen /= rowCount =
      error $ "responder/predictor length mismatch " ++ show (respLen, rowCount)
  | otherwise = (beta, rSquare design resps beta)
  where
    predLens = map G.length preds
    rowCount = G.length firstPred
    respLen  = G.length resps
    -- Predictor columns followed by the constant (intercept) column.
    columns  = preds ++ [G.replicate rowCount 1]
    -- The columns are laid out as the rows of a (p+1) x n matrix and then
    -- transposed into the n x (p+1) design matrix.
    design   = transpose (fromVector (length columns) rowCount (G.concat columns))
    beta     = ols design resps
-- | Compute the ordinary least-squares solution @x@ of @A x ≈ b@.
--
-- The solver takes the QR decomposition of @A@ and solves @R x = Qᵀ b@.
-- It calls 'error' if @A@ has fewer rows than columns.
ols :: Matrix  -- ^ Design matrix @A@ (rows are observations).
    -> Vector  -- ^ Responder vector @b@.
    -> Vector
ols a b =
  case dimension a of
    dims@(nRows, nCols)
      | nRows < nCols -> error $ "fewer rows than columns " ++ show dims
      | otherwise     -> solve r (transpose q `multiplyV` b)
  where
    (q, r) = qr a
-- | Solve @R x = b@ for @x@, where @R@ is upper triangular.
--
-- Only entries on or above the diagonal of @R@ are read. Each step divides
-- by the diagonal entry and subtracts the solved component from the entries
-- above it. This is back substitution, assuming @rfor n 0@ visits rows
-- @n-1@ down to @0@ and @for 0 i@ visits @0 .. i-1@ (TODO confirm against
-- "Statistics.Function").
--
-- No check is made for zero diagonal entries. Calls 'error' if the row
-- count of @R@ differs from the length of @b@.
solve :: Matrix  -- ^ Upper-triangular matrix @R@.
      -> Vector  -- ^ Right-hand side @b@.
      -> Vector
solve r b
  | rowCount /= len = error $ "row/vector mismatch " ++ show (rowCount, len)
  | otherwise = U.create $ do
      x <- U.thaw b
      rfor rowCount 0 $ \i -> do
        xi <- M.unsafeRead x i
        let xi' = xi / unsafeIndex r i i
        M.unsafeWrite x i xi'
        -- Eliminate the contribution of x_i from the earlier components.
        for 0 i $ \j ->
          F.unsafeModify x j (\v -> v - unsafeIndex r j i * xi')
      return x
  where
    rowCount = rows r
    len      = U.length b
-- | Compute R², the coefficient of determination of a linear model:
-- @1 - SS_res / SS_tot@.
--
-- * @SS_res@ is the sum of squared differences between each response and
--   its prediction (row @i@ of the predictor matrix dotted with the
--   coefficients).
-- * @SS_tot@ is the sum of squared deviations of the responses from their
--   mean.
--
-- If every response is equal, @SS_tot@ is zero and the result is not finite
-- (NaN or negative infinity).
rSquare :: Matrix   -- ^ Predictors (regressors); row @i@ is observation @i@.
        -> Vector   -- ^ Responders.
        -> Vector   -- ^ Regression coefficients.
        -> Double
rSquare pred resp coeff = 1 - r / t
  where
    -- Residual sum of squares.
    r   = sum $ flip U.imap resp $ \i x -> square (x - p i)
    -- Total sum of squares around the mean (hoisted so it is computed once).
    t   = sum $ flip U.map resp $ \x -> square (x - mu)
    mu  = mean resp
    -- Predicted response for observation i.
    p i = sum . flip U.imap coeff $ \j -> (* unsafeIndex pred i j)
-- | Run a regression on the full sample and estimate a confidence interval
-- for every coefficient and for R² by resampling rows with replacement.
--
-- The resamples are split across 'getNumCapabilities' worker threads. Each
-- worker uses its own generator obtained from 'splitGen'. A resample draws
-- @n@ row indices uniformly from @[0, n-1]@ (where @n@ is the responder
-- length), permutes the responders and every predictor with them, and
-- re-runs the regression.
--
-- Each interval uses percentiles of the sorted resampled values, centred
-- on the point estimate from the full sample. If a worker throws (for
-- example an 'error' from the regression function), the exception is
-- rethrown in the calling thread.
bootstrapRegress
  :: GenIO
  -- ^ Generator used to seed one independent generator per worker.
  -> Int
  -- ^ Number of resamples to compute; must be positive.
  -> CL Double
  -- ^ Confidence level of the intervals.
  -> ([Vector] -> Vector -> (Vector, Double))
  -- ^ Regression function returning coefficients and R², e.g. 'olsRegress'.
  -> [Vector]
  -- ^ Predictor vectors.
  -> Vector
  -- ^ Responder vector.
  -> IO (V.Vector (Estimate ConfInt Double), Estimate ConfInt Double)
bootstrapRegress gen0 numResamples cl rgrss preds0 resp0
  | numResamples < 1 = error $ "bootstrapRegress: number of resamples " ++
                               "must be positive"
  | otherwise = do
      caps <- getNumCapabilities
      gens <- splitGen caps gen0
      done <- newChan
      forM_ (zip gens (balance caps numResamples)) $ \(gen, count) -> forkIO $ do
        -- Forward any failure to the main thread. Without this, an exception
        -- kills the worker silently, the main thread blocks on 'readChan',
        -- and the original error is lost.
        res <- try (runWorker gen count)
                 :: IO (Either SomeException (V.Vector (Vector, Double)))
        writeChan done res
      -- Exactly one message arrives per worker; rethrow the first failure.
      chunks <- mapM (either throwIO return) =<< replicateM caps (readChan done)
      let (coeffsv, r2v) = G.unzip (V.concat chunks)
          (coeffss, r2s) = rgrss preds0 resp0
          coeffs = flip G.imap (G.convert coeffss) $ \i x ->
                   est x . U.generate numResamples $ \k -> (coeffsv G.! k) G.! i
          r2 = est r2s (G.convert r2v)
      return (coeffs, r2)
  where
    nObs = U.length resp0

    -- Run 'count' resamples with one generator and force the results fully,
    -- so that the regression work and any errors happen in this thread.
    runWorker gen count = do
      v <- V.replicateM count $ do
        ixs <- U.replicateM nObs $ uniformR (0, nObs - 1) gen
        let resp  = U.backpermute resp0 ixs
            preds = map (`U.backpermute` ixs) preds0
        return $ rgrss preds resp
      evaluate (rnf v)
      return v

    -- Point estimate 's' with a percentile interval read from the sorted
    -- resampled values 'v'.
    est s v = estimateFromInterval s (w G.! lo, w G.! hi) cl
      where
        w    = F.sort v
        lo   = round c
        -- Clamp so that a significance level of 0 (c == 0) cannot index
        -- one past the last element.
        hi   = min (numResamples - 1) (truncate (nRes - c))
        nRes = fromIntegral numResamples :: Double
        c    = nRes * (significanceLevel cl / 2)
-- | Split @numItems@ into @numSlices@ chunk sizes. The first
-- @numItems `rem` numSlices@ chunks get one extra item, so for non-negative
-- @numItems@ the sizes differ by at most one and sum to @numItems@.
--
-- For example, @balance 4 10 == [3,3,2,2]@. A non-positive @numSlices@
-- yields @[]@.
balance :: Int -> Int -> [Int]
balance numSlices numItems =
  [ base + (if k < extra then 1 else 0) | k <- [0 .. numSlices - 1] ]
  where
    (base, extra) = numItems `quotRem` numSlices