module Bio.Bam.Index (
    BamIndex(..),
    readBamIndex,
    readBaiIndex,
    readTabix,
    Region(..),
    Subsequence(..),
    eneeBamRefseq,
    eneeBamSubseq,
    eneeBamRegions,
    eneeBamUnaligned
) where
import Bio.Bam.Header
import Bio.Bam.Reader
import Bio.Bam.Rec
import Bio.Bam.Regions              ( Region(..), Subsequence(..) )
import Bio.Iteratee
import Bio.Prelude
import System.Directory             ( doesFileExist )
import System.FilePath              ( dropExtension, takeExtension, (<.>) )
import qualified Bio.Bam.Regions                as R
import qualified Data.IntMap.Strict             as M
import qualified Data.ByteString                as B
import qualified Data.Vector                    as V
import qualified Data.Vector.Mutable            as W
import qualified Data.Vector.Unboxed            as U
import qualified Data.Vector.Unboxed.Mutable    as N
import qualified Data.Vector.Algorithms.Intro   as N
-- | In-memory form of a BAI, CSI or tabix index.
--
-- The type parameter carries format-specific extras: @()@ for indices
-- read by 'readBaiIndex', 'TabMeta' for those read by 'readTabix'.
data BamIndex a = BamIndex
    { minshift        :: !Int
      -- ^ Bins of the deepest level span @2^minshift@ positions
      -- (14 for BAI and tabix, read from the header for CSI).
    , depth           :: !Int
      -- ^ Number of bin levels below the root bin (5 for BAI and tabix).
    , unaln_off       :: !Int64
      -- ^ Offset 'eneeBamUnaligned' seeks to before scanning for
      -- unaligned records; 0 means "do not seek".
    , extensions      :: a
      -- ^ Format-specific extras (deliberately a lazy field).
    , refseq_bins     :: !(V.Vector Bins)
      -- ^ Indexed by reference sequence: bin number to that bin's
      -- segments.
    , refseq_ckpoints :: !(V.Vector Ckpoints)
      -- ^ Indexed by reference sequence: position to file offset.
      -- Segment starts are clamped to the checkpoint at or before a
      -- query's start.
    }
  deriving Show
-- | Binning index of one reference sequence: bin number to its segments.
type Bins = M.IntMap Segments

-- | Segments of one bin as @(begin, end)@ file offset pairs
-- ('getSegmentArray' stores them sorted).
type Segments = U.Vector (Int64, Int64)

-- | Checkpoints of one reference sequence: position to file offset.
type Ckpoints = M.IntMap Int64

-- | A stretch of the file to scan, as consumed by 'eneeBamSegment'.
data Segment = Segment
    !Int64  -- ^ start offset, compared against 'virt_offset'
    !Int64  -- ^ end offset; scanning stops at records past it
    !Int    -- ^ largest alignment position of interest; scanning stops past it
  deriving Show
-- | For every interval of the subsequence, the segments that may hold
-- records overlapping it.  A reference sequence missing from either the
-- bins or the checkpoints of the index yields no lists at all.
segmentLists :: BamIndex a -> Refseq -> R.Subsequence -> [[Segment]]
segmentLists bi (Refseq ref) (R.Subsequence imap) =
    case refseq_bins bi V.!? ix of
        Nothing   -> []
        Just bins -> case refseq_ckpoints bi V.!? ix of
            Nothing   -> []
            Just cpts -> [ rgnToSegments bi beg end bins cpts
                         | (beg, end) <- M.toList imap ]
  where
    ix = fromIntegral ref
-- | Segments to scan for records overlapping the region from @beg@ to
-- @end@.
--
-- The result is sorted by start offset, and overlapping or touching
-- segments are fused.  Segments are gathered bin by bin, and chunks of
-- different bins interleave in the file.  Without sorting they would be
-- visited out of order, and shared stretches would be read repeatedly.
-- Every start offset is raised to the checkpoint at or before @beg@;
-- segments that become empty are dropped.  Each segment's position limit
-- is @end@.
rgnToSegments :: BamIndex a -> Int -> Int -> Bins -> Ckpoints -> [Segment]
rgnToSegments bi beg end bins cpts = fuse (U.toList ordered)
  where
    -- no record overlapping the region starts before this offset
    !cpt = maybe 0 snd $ M.lookupLE beg cpts

    -- candidate (start, end) pairs, sorted by start offset
    ordered = U.modify N.sort $ U.fromList
        [ (boff', eoff)
        | bin <- binList bi beg end
        , (boff, eoff) <- maybe [] U.toList $ M.lookup bin bins
        , let boff' = max boff cpt
        , boff' < eoff ]

    -- input is sorted, so a successor starting no later than the current
    -- end overlaps or touches it; fuse and keep going with the union
    fuse ((a, b) : (u, v) : rest)
        | u <= b = fuse ((a, max b v) : rest)
    fuse ((a, b) : rest) = Segment a b end : fuse rest
    fuse []              = []
-- | Bin numbers whose range intersects the positions from @beg@ up to
-- (excluding) @end@, listed level by level starting at the root bin.
--
-- On level @l@ a position maps to bin @t + pos `shiftR` s@.  Here @s@
-- starts at @minshift + 3*depth@ and shrinks by 3 per level, and @t@ is
-- the number of bins on all shallower levels (0, 1, 9, 73, ...).
binList :: BamIndex a -> Int -> Int -> [Int]
binList BamIndex{..} beg end = binlist' 0 (minshift + 3*depth) 0
  where
    binlist' l s t = if l > depth then [] else [b..e] ++ go
      where
        b  = t + (beg `shiftR` s)
        -- end is exclusive, so the last covered position is end - 1
        e  = t + ((end - 1) `shiftR` s)
        go = binlist' (l + 1) (s - 3) (t + (1 `shiftL` (3*l)))
-- | Merges two segment lists into one list sorted by start offset, fusing
-- segments that overlap or touch.
--
-- Both inputs must already be sorted by start offset.  A fused segment
-- starts where the earlier one starts and ends at the later of the two
-- ends.  It keeps the larger of the two maximum positions.  Fused
-- segments go back through the fusing step, so a chain of overlaps
-- collapses into a single segment.
infix 4 ~~
(~~) :: [Segment] -> [Segment] -> [Segment]
xs ~~ ys = fuse (interleave xs ys)
  where
    -- one merge step of merge sort, keyed on the start offset; on ties the
    -- left list's element comes first
    interleave [] rs = rs
    interleave ls [] = ls
    interleave ls@(l@(Segment p _ _) : ls') rs@(r@(Segment q _ _) : rs')
        | p <= q    = l : interleave ls' rs
        | otherwise = r : interleave ls rs'

    -- the successor starts inside (or exactly at the end of) the current
    -- segment: replace both by their union and continue with it
    fuse (Segment a b e : Segment u v f : rest)
        | a <= u && u <= b = fuse (Segment a (max b v) (max e f) : rest)
    fuse (s : rest) = s : fuse rest
    fuse []         = []
-- | Reads an index for the given file.
--
-- A path ending in @.bai@ or @.csi@ is read directly.  Otherwise, in order,
-- @fp.bai@, @fp@ with its extension replaced by @.bai@, @fp.csi@ and the
-- replaced-extension @.csi@ are tried.  The first one that exists and
-- parses wins.  A candidate that exists but fails with an
-- 'IterStringException' falls through to the next one.  If every
-- candidate fails, @fp@ itself is read as an index.
--
-- NOTE(review): the CSI specification describes .csi files as
-- BGZF-compressed, but they are handed to 'readBaiIndex' as-is here.
-- Confirm that compressed CSI files are handled.
readBamIndex :: FilePath -> IO (BamIndex ())
readBamIndex fp
    | ext == ".bai" = load fp
    | ext == ".csi" = load fp
    | otherwise     = foldr tryIx (load fp)
                        [ fp <.> "bai", dropExtension fp <.> "bai"
                        , fp <.> "csi", dropExtension fp <.> "csi" ]
  where
    ext = takeExtension fp

    load f = enumFile defaultBufSize f readBaiIndex >>= run

    -- try one candidate file; @k@ is the action for the remaining ones
    tryIx f k = do
        present <- doesFileExist f
        if not present then k else do
            r <- enumFile defaultBufSize f readBaiIndex >>= tryRun
            case r of
                Right ix                     -> return ix
                Left (IterStringException _) -> k
-- | Reads a BAI or CSI index from the raw byte stream.  The format is
-- chosen by the 4-byte magic number; any other magic fails with an
-- 'iterStrExc' error.
--
-- BAI uses the fixed geometry @minshift = 14@, @depth = 5@ and has a
-- linear index per reference (read by 'getIntervals').  CSI stores
-- @minshift@ and @depth@ in its header and has one checkpoint offset per
-- bin instead.
readBaiIndex :: MonadIO m => Iteratee Bytes m (BamIndex ())
readBaiIndex = iGetString 4 >>= switch
  where
    switch "BAI\1" = do nref <- fromIntegral `liftM` endianRead4 LSB
                        getIndexArrays nref 14 5 (const return) getIntervals
    switch "CSI\1" = do minshift <- fromIntegral `liftM` endianRead4 LSB
                        depth <- fromIntegral `liftM` endianRead4 LSB
                        endianRead4 LSB >>= dropStreamBS . fromIntegral   -- skip a length-prefixed block unused here
                        nref <- fromIntegral `liftM` endianRead4 LSB
                        getIndexArrays nref minshift depth (addOneCheckpoint minshift depth) return
    switch magic   = throwErr . iterStrExc $ "index signature " ++ show magic ++ " not recognized"

    -- CSI: every bin is followed by an 8-byte offset.  It is recorded as a
    -- checkpoint at the bin's first position; when two bins share that
    -- position, the smaller offset wins.
    addOneCheckpoint minshift depth bin cp = do
            loffset <- fromIntegral `liftM` endianRead8 LSB
            let key = llim (fromIntegral bin) (3*depth) minshift
            return $! M.insertWith min key loffset cp

    -- First position covered by bin @bin@.  @dp@ is 3 times the level
    -- being tested (deepest first) and @sf@ is the matching shift.
    -- @ix = (2^dp - 1) / 7@ is the number of bins on all shallower levels,
    -- i.e. the first bin number of that level.
    llim bin dp sf | dp  ==  0 = 0
                   | bin >= ix = (bin - ix) `shiftL` sf
                   | otherwise = llim bin (dp - 3) (sf + 3)
            where ix = ((1 `shiftL` dp) - 1) `div` 7
-- | An index read from a tabix file by 'readTabix'.
type TabIndex = BamIndex TabMeta

-- | Tabix-specific header fields, stored in 'extensions'.
data TabMeta = TabMeta
    { format       :: TabFormat      -- ^ decoded format code
    , col_seq      :: Int            -- ^ column number of the sequence name, as stored in the header
    , col_beg      :: Int            -- ^ column number of the start position, as stored in the header
    , col_end      :: Int            -- ^ column number of the end position, as stored in the header
    , comment_char :: Char           -- ^ comment marker character from the header
    , skip_lines   :: Int            -- ^ line count from the header (presumably lines to skip — confirm)
    , names        :: V.Vector Bytes -- ^ reference sequence names from the header
    }
  deriving Show

-- | Format code of a tabix index, as decoded by 'readTabix'.
data TabFormat
    = Generic
    | SamFormat
    | VcfFormat
    | ZeroBased
  deriving Show
-- | Reads a tabix (.tbi) index.  The stream is passed through
-- 'decompressBgzf' first; an unknown magic number fails with an
-- 'iterStrExc' error.
--
-- The trailing 8-byte offset after the per-reference arrays is optional.
-- If the stream ends there, 'unaln_off' keeps the value computed by
-- 'getIndexArrays'.
readTabix :: MonadIO m => Iteratee Bytes m TabIndex
readTabix = joinI $ decompressBgzf $ iGetString 4 >>= switch
  where
    switch "TBI\1" = do nref <- fromIntegral `liftM` endianRead4 LSB
                        format       <- liftM toFormat     (endianRead4 LSB)
                        col_seq      <- liftM fromIntegral (endianRead4 LSB)
                        col_beg      <- liftM fromIntegral (endianRead4 LSB)
                        col_end      <- liftM fromIntegral (endianRead4 LSB)
                        comment_char <- liftM (chr . fromIntegral) (endianRead4 LSB)
                        skip_lines   <- liftM fromIntegral (endianRead4 LSB)
                        names        <- liftM splitNames . iGetString . fromIntegral =<< endianRead4 LSB
                        ix <- getIndexArrays nref 14 5 (const return) getIntervals
                        fin <- isFinished
                        if fin then return $! ix { extensions = TabMeta{..} }
                               else do unaln <- fromIntegral `liftM` endianRead8 LSB
                                       return $! ix { unaln_off = unaln, extensions = TabMeta{..} }
    switch magic   = throwErr . iterStrExc $ "index signature " ++ show magic ++ " not recognized"

    toFormat 1 = SamFormat
    toFormat 2 = VcfFormat
    toFormat x = if testBit x 16 then ZeroBased else Generic

    -- Per the tabix format every name in the names block is NUL-terminated.
    -- Splitting on NUL as-is would append a spurious empty name after the
    -- last terminator, so a trailing NUL is dropped first.
    splitNames bs = V.fromList . B.split 0 $
        if not (B.null bs) && B.last bs == 0 then B.init bs else bs
-- | Reads a linear index: a 4-byte count followed by that many 8-byte
-- offsets, one per 0x4000-position window.
--
-- Each nonzero offset is recorded as a checkpoint at its window's first
-- position.  The running maximum is raised to the largest offset read,
-- including zeros.
getIntervals :: Monad m => (IntMap Int64, Int64) -> Iteratee Bytes m (IntMap Int64, Int64)
getIntervals (cp, mx0) = do
    count <- fromIntegral `liftM` endianRead4 LSB
    let step (!acc, !best) w = do
            off <- fromIntegral `liftM` endianRead8 LSB
            let acc' | off == 0  = acc
                     | otherwise = M.insert (w * 0x4000) off acc
            return (acc', max best off)
    reduceM 0 count (cp, mx0) step
-- | Reads the per-reference section shared by BAI, CSI and tabix indices.
--
-- For each of the @nref@ references the input holds a 4-byte bin count.
-- Each bin then consists of its 4-byte number, whatever
-- @addOneCheckpoint@ consumes, and its segment array ('getSegmentArray').
-- After the bins comes whatever @addManyCheckpoints@ consumes.
--
-- The result carries the given @minshift@ and @depth@.  Its 'unaln_off'
-- is the largest end offset tracked over all bins, possibly raised
-- further by @addManyCheckpoints@.  @nref < 1@ reads nothing and yields
-- an empty index.
getIndexArrays :: MonadIO m => Int -> Int -> Int
               -> (Word32 -> Ckpoints -> Iteratee Bytes m Ckpoints)
               -> ((Ckpoints, Int64) -> Iteratee Bytes m (Ckpoints, Int64))
               -> Iteratee Bytes m (BamIndex ())
getIndexArrays nref minshift depth addOneCheckpoint addManyCheckpoints
    | nref  < 1 = return $ BamIndex minshift depth 0 () V.empty V.empty
    | otherwise = do
        -- one slot per reference; every index 0 .. nref-1 is written below
        -- before the vectors are frozen
        rbins  <- liftIO $ W.new nref
        rckpts <- liftIO $ W.new nref
        mxR <- reduceM 0 nref 0 $ \mx0 r -> do
                nbins <- endianRead4 LSB
                (!bins,!cpts,!mx1) <- reduceM 0 nbins (M.empty,M.empty,mx0) $ \(!im,!cp,!mx) _ -> do
                        bin <- endianRead4 LSB -- bin number
                        cp' <- addOneCheckpoint bin cp
                        segsarr <- getSegmentArray
                        -- track the end offset of the bin's last (highest-starting) segment
                        let !mx' = if U.null segsarr then mx else max mx (snd (U.last segsarr))
                        return (M.insert (fromIntegral bin) segsarr im, cp', mx')
                -- format-specific trailer of this reference (e.g. a linear index)
                (!cpts',!mx2) <- addManyCheckpoints (cpts,mx1)
                liftIO $ W.write rbins r bins >> W.write rckpts r cpts'
                return mx2
        liftM2 (BamIndex minshift depth mxR ()) (liftIO $ V.unsafeFreeze rbins) (liftIO $ V.unsafeFreeze rckpts)
-- | Reads one bin's segments: a 4-byte count followed by that many pairs
-- of 8-byte begin/end offsets.  Returns them sorted (by begin, then end).
getSegmentArray :: MonadIO m => Iteratee Bytes m Segments
getSegmentArray = do
    count <- fromIntegral `liftM` endianRead4 LSB
    raw   <- U.replicateM count $
                 liftM2 (,) (fromIntegral `liftM` endianRead8 LSB)
                            (fromIntegral `liftM` endianRead8 LSB)
    return $! U.modify N.sort raw
-- | Monadic left fold over the indices @beg@, @succ beg@, ... stopping
-- just before @end@.  @end@ must be reachable from @beg@ by repeated
-- 'succ'.
reduceM :: (Monad m, Enum ix, Eq ix) => ix -> ix -> a -> (a -> ix -> m a) -> m a
reduceM beg end acc cons = go beg acc
  where
    go i z
        | i /= end  = cons z i >>= go (succ i)
        | otherwise = return z
-- | Runs an action for each index @beg@, @succ beg@, ... stopping just
-- before @end@.  @end@ must be reachable from @beg@ by repeated 'succ'.
loopM :: (Monad m, Enum ix, Eq ix) => ix -> ix -> (ix -> m ()) -> m ()
loopM beg end k = go beg
  where
    go i
        | i /= end  = k i >> go (succ i)
        | otherwise = return ()
-- | Streams the records of one reference sequence.
--
-- Seeks to the checkpoint offset with the smallest key for that reference
-- and passes records on until one names a different reference.  If the
-- index has no such checkpoint, or it is 0, nothing is read and the
-- iteratee is returned unchanged.
eneeBamRefseq :: Monad m => BamIndex b -> Refseq -> Enumeratee [BamRaw] [BamRaw] m a
eneeBamRefseq bi (Refseq r) iter =
    case refseq_ckpoints bi V.!? fromIntegral r >>= M.minView of
        Just (voff, _) | voff /= 0 -> do
            seek (fromIntegral voff)
            breakE (\br -> b_rname (unpackBam br) /= Refseq r) iter
        _ -> return iter
-- | Streams the records that do not name a valid reference sequence.
-- Seeks to 'unaln_off' first, unless it is 0.
eneeBamUnaligned :: Monad m => BamIndex b -> Enumeratee [BamRaw] [BamRaw] m a
eneeBamUnaligned bi iter = do
    let start = unaln_off bi
    when (start /= 0) (seek (fromIntegral start))
    filterStream unplaced iter
  where
    unplaced br = not (isValidRefseq (b_rname (unpackBam br)))
-- | Streams the records of one segment.
--
-- If the next record already lies at or after the segment start, and less
-- than 0x8000 past it, reading continues without a seek.  Otherwise the
-- stream seeks to the segment start.  Records are passed on while their
-- offset is at most the segment end and their position at most the
-- segment's position limit.
eneeBamSegment :: Monad m => Segment -> Enumeratee [BamRaw] [BamRaw] m r
eneeBamSegment (Segment beg end mpos) out = do
    upcoming <- peekStream
    when (not (closeEnough upcoming)) $ seek (fromIntegral beg)
    takeWhileE wanted out
  where
    closeEnough (Just br) = let o = fromIntegral (virt_offset br)
                            in beg <= o && o < beg + 0x8000
    closeEnough Nothing   = False

    wanted br = virt_offset br <= fromIntegral end && b_pos (unpackBam br) <= mpos
-- | Streams the records of one reference sequence that overlap the given
-- subsequence.  The segment lists of all its intervals are merged with
-- '~~' and read in turn.  The result is filtered down to records on the
-- reference whose aligned span overlaps the subsequence.
eneeBamSubseq :: Monad m => BamIndex b -> Refseq -> R.Subsequence -> Enumeratee [BamRaw] [BamRaw] m a
eneeBamSubseq bi ref subs =
    foldr (\seg rest -> eneeBamSegment seg >=> rest) return merged ><> filterStream wanted
  where
    merged = foldr (~~) [] (segmentLists bi ref subs)

    wanted br = let BamRec{..} = unpackBam br
                in b_rname == ref && R.overlaps b_pos (b_pos + alignedLength b_cigar) subs
-- | Streams the records overlapping any of the given regions.  The
-- regions are normalised through 'R.fromList' and 'R.toList', and each
-- reference sequence is handled by 'eneeBamSubseq' in the resulting
-- order.
eneeBamRegions :: Monad m => BamIndex b -> [R.Region] -> Enumeratee [BamRaw] [BamRaw] m a
eneeBamRegions bi rgns = foldr chain return (R.toList (R.fromList rgns))
  where
    chain (ref, subs) k = eneeBamSubseq bi ref subs >=> k