{-# LANGUAGE ScopedTypeVariables, CPP, BangPatterns #-}

module Bio.Phylogeny.PhyBin.RFDistance
       (
         -- * Types
         DenseLabelSet, DistanceMatrix,

         -- * Bipartition (Bip) utilities
         allBips, foldBips, dispBip,
         consensusTree, bipsToTree, filterCompatible, compatibleWith,

         -- * ADT for dense sets
         mkSingleDense, mkEmptyDense, bipSize,
         denseUnions, denseDiff, invertDense, markLabel,
         
        -- * Methods for computing distance matrices
        naiveDistMatrix, hashRF, 

        -- * Output
        printDistMat)
       where

import           Control.Monad
import           Control.Monad.ST
import           Data.Function       (on)
import           Data.Word
import qualified Data.Vector                 as V
import qualified Data.Vector.Mutable         as MV
import qualified Data.Vector.Unboxed.Mutable as MU
import qualified Data.Vector.Unboxed         as U
import           Text.PrettyPrint.HughesPJClass hiding (char, Style)
import           System.IO      (hPutStrLn, hPutStr, Handle)
import           System.IO.Unsafe

-- import           Control.LVish
-- import qualified Data.LVar.Set   as IS
-- import qualified Data.LVar.SLSet as SL

-- import           Data.LVar.Map   as IM
-- import           Data.LVar.NatArray as NA

import           Bio.Phylogeny.PhyBin.CoreTypes
import           Bio.Phylogeny.PhyBin.PreProcessor (pruneTreeLeaves)
-- import           Data.BitList
import qualified Data.Set as S
import qualified Data.List as L
import qualified Data.IntSet as SI
import qualified Data.Map.Strict as M
import qualified Data.Foldable as F
import qualified Data.Traversable as T
import           Data.Monoid
import           Prelude as P
import           Debug.Trace

#ifdef BITVEC_BIPS
import qualified Data.Vector.Unboxed.Bit     as UB
import qualified Data.Bit                    as B
#endif

-- I don't understand WHY, but I seem to get the same answers WITHOUT this.
-- Normalization and symmetric difference do make things somewhat slower (e.g. 1.8
-- seconds vs. 2.2 seconds for 150 taxa / 100 trees)
#define NORMALIZATION
-- define BITVEC_BIPS

--------------------------------------------------------------------------------
-- A data structure choice
--------------------------------------------------------------------------------

-- type DenseLabelSet s = BitList


-- | Dense sets of taxa, aka Bipartitions or BiPs
--   We assume that taxa labels have been mapped onto a dense, contiguous range of integers [0,N).
-- 
--   NORMALIZATION Rule: Bipartitions are really two disjoint sets.  But as long as
--   the parent set (the union of the partitions, aka "all taxa") then a bipartition
--   can be represented just by *one* subset.  Yet we must choose WHICH subset for
--   consistency.  We use the rule that we always choose the SMALLER.  Thus the
--   DenseLabelSet should always be half the size or less, compared to the total
--   number of taxa.
-- 
--   A set that is more than a majority of the taxa can be normalized by "flipping",
--   i.e. taking the taxa that are NOT in that set.
#ifdef BITVEC_BIPS

#  if 1
type DenseLabelSet = UB.Vector B.Bit
markLabel lab = UB.modify (\vec -> MU.write vec lab (B.fromBool True)) 
mkEmptyDense  size = UB.replicate size (B.fromBool False)
mkSingleDense size ind = markLabel ind (mkEmptyDense size)
denseUnions        = UB.unions
bipSize            = UB.countBits
denseDiff          = UB.difference
invertDense size bip = UB.invert bip
dispBip labs bip = show$ map (\(ix,_) -> (labs M.! ix)) $
                        filter (\(_,bit) -> B.toBool bit) $
                        zip [0..] (UB.toList bip)
denseIsSubset a b = UB.or (UB.difference b a)
traverseDense_ fn bip =
  U.ifoldr' (\ ix bit acc ->
              (if B.toBool bit
               then fn ix
               else return ()) >> acc)
        (return ()) bip

#  else
-- TODO: try tracking the size:
data DenseLabelSet = DLS {-# UNPACK #-} !Int (UB.Vector B.Bit)
markLabel lab (DLS _ vec)= DLS (UB.modify (\vec -> return (MU.write vec lab (B.fromBool True))) ) vec
-- ....
#  endif

#else
type DenseLabelSet = SI.IntSet
markLabel lab set   = SI.insert lab set 
mkEmptyDense _size  = SI.empty
mkSingleDense _size = SI.singleton
denseUnions _size   = SI.unions 
bipSize             = SI.size
denseDiff           = SI.difference
denseIsSubset       = SI.isSubsetOf

dispBip labs bip = "[" ++ unwords strs ++ "]"
  where strs = map (labs M.!) $ SI.toList bip
invertDense size bip = loop SI.empty (size-1)
  where -- There's nothing for it but to iterate and test for membership:
    loop !acc ix | ix < 0           = acc
                 | SI.member ix bip = loop acc (ix-1)
                 | otherwise        = loop (SI.insert ix acc) (ix-1)
traverseDense_ fn bip =
  -- FIXME: need guaranteed non-allocating way to do this.
  SI.foldr' (\ix acc ->  fn ix >> acc) (return ()) bip
#endif

markLabel    :: Label -> DenseLabelSet -> DenseLabelSet
mkEmptyDense :: Int -> DenseLabelSet
mkSingleDense :: Int -> Label -> DenseLabelSet
denseUnions  :: Int -> [DenseLabelSet] -> DenseLabelSet
bipSize      :: DenseLabelSet -> Int

-- | Print a BiPartition in a pretty form
dispBip      :: LabelTable -> DenseLabelSet -> String

-- | Assume that total taxa are 0..N-1, invert membership:
invertDense  :: Int -> DenseLabelSet -> DenseLabelSet

traverseDense_ :: Monad m => (Int -> m ()) -> DenseLabelSet -> m ()


--------------------------------------------------------------------------------
-- Dirt-simple reference implementation
--------------------------------------------------------------------------------

type DistanceMatrix = V.Vector (U.Vector Int)

-- | Returns a triangular distance matrix encoded as a vector.
--   Also return the set-of-BIPs representation for each tree.
--
--   This uses a naive method, directly computing the pairwise
--   distance between each pair of trees.
--
--   This method is TOLERANT of differences in the laba/taxa sets between two trees.
--   It simply prunes to the intersection before doing the distance comparison.
--   Other scoring methods may be added in the future.  (For example, penalizing for
--   missing taxa.)
naiveDistMatrix :: [NewickTree DefDecor] -> (DistanceMatrix, V.Vector (S.Set DenseLabelSet))
naiveDistMatrix lst = 
   let sz = P.length lst
       treeVect  = V.fromList lst
       labelSets = V.map treeLabels treeVect
       eachbips  = V.map allBips    treeVect
       mat = V.generate sz $ \ i ->        
             U.generate i  $ \ j ->
             let 
                 inI = (labelSets V.! i)
                 inJ = (labelSets V.! j)
                 inBoth = S.intersection inI inJ

                 -- Match will always succeed due to size==0 test below:
                 Just prI = pruneTreeLeaves inBoth (treeVect V.! i)
                 Just prJ = pruneTreeLeaves inBoth (treeVect V.! j)
                   
                 -- Memoization: If we are using it at its full size we can use the cached one:
                 bipsI = if S.size inBoth == S.size inI
                         then (eachbips V.! i)
                         else allBips prI
                 bipsJ = if S.size inBoth == S.size inJ
                         then (eachbips V.! j)
                         else allBips prJ

                 diff1 = S.size (S.difference bipsI bipsJ)
                 diff2 = S.size (S.difference bipsJ bipsI) -- symettric difference
             in if S.size inBoth == 0
                then 0 -- This is weird, but what other answer could we give?
                else diff1 + diff2
   in (mat, eachbips)

 where
   treeLabels :: NewickTree a -> S.Set Label
   treeLabels (NTLeaf _ lab)  = S.singleton lab
   treeLabels (NTInterior _ ls) = S.unions (map treeLabels ls)

-- | The number of bipartitions implied by a tree is one per EDGE in the tree.  Thus
-- each interior node carries a list of BiPs the same length as its list of children.
labelBips :: NewickTree a -> NewickTree (a, [DenseLabelSet])
labelBips tr =
--    trace ("labelbips "++show allLeaves++" "++show size) $
#ifdef NORMALIZATION  
    fmap (\(a,ls) -> (a,map (normBip size) ls)) $
#endif
    loop tr
  where    
    size = numLeaves tr
    zero = mkEmptyDense size
    loop (NTLeaf dec lab) = NTLeaf (dec, [markLabel lab zero]) lab      
    loop (NTInterior dec chlds) =
      let chlds' = map loop chlds
          sets   = map (denseUnions size . snd . get_dec) chlds' in
      NTInterior (dec, sets) chlds'

    allLeaves = leafSet tr
    leafSet (NTLeaf _ lab)    = mkSingleDense size lab
    leafSet (NTInterior _ ls) = denseUnions size $ map leafSet ls

-- normBip :: DenseLabelSet -> DenseLabelSet -> DenseLabelSet
--    normBip allLeaves bip =
normBip :: Int -> DenseLabelSet -> DenseLabelSet    
normBip totsize bip =
  let -- size     = bipSize allLeaves
      halfSize = totsize `quot` 2
--      flipped  = denseDiff allLeaves bip
      flipped  = invertDense totsize bip 
  in 
  case compare (bipSize bip) halfSize of
    LT -> bip 
    GT -> flipped -- Flip it
    EQ -> min bip flipped -- This is a painful case, we need a tie-breaker
    

foldBips :: Monoid m => (DenseLabelSet -> m) -> NewickTree a -> m
foldBips f tr = F.foldMap f' (labelBips tr)
 where f' (_,bips) = F.foldMap f bips
  
-- | Get all non-singleton BiPs implied by a tree.
allBips :: NewickTree a -> S.Set DenseLabelSet
allBips tr = S.filter ((> 1) . bipSize) $ foldBips S.insert tr S.empty

--------------------------------------------------------------------------------
-- Optimized, LVish version
--------------------------------------------------------------------------------
-- First, necessary types:

-- UNFINISHED:
#if 0
-- | A collection of all observed bipartitons (bips) with a mapping of which trees
-- contain which Bips.
type BipTable s = IMap DenseLabelSet s (SparseTreeSet s)
-- type BipTable = IMap BitList (U.Vector Bool)
-- type BipTable s = IMap BitList s (NA.NatArray s Word8)

-- | Sets of taxa (BiPs) that are expected to be sparse.
type SparseTreeSet s = IS.ISet s TreeID
-- TODO: make this a set of numeric tree IDs...
-- NA.NatArray s Word8

type TreeID = AnnotatedTree
-- | Tree's are identified simply by their order within the list of input trees.
-- type TreeID = Int
#endif

--------------------------------------------------------------------------------
-- Alternate way of slicing the problem: HashRF
--------------------------------------------------------------------------------

-- The distance matrix is an atomically-bumped matrix of numbers.
-- type DistanceMat s = NA.NatArray s Word32
-- Except... bump isn't supported by our idempotent impl.

-- | This version slices the problem a different way.  A single pass over the trees
-- populates the table of bipartitions.  Then the table can be processed (locally) to
-- produce (non-localized) increments to a distance matrix.
hashRF :: Int -> [NewickTree a] -> DistanceMatrix
hashRF num_taxa trees = build M.empty (zip [0..] trees)
  where
    num_trees = length trees
    -- First build the table:
    build acc [] = ingest acc
    build acc ((ix,hd):tl) =
      let bips = allBips hd
          acc' = S.foldl' fn acc bips
          fn acc bip = M.alter fn2 bip acc
          fn2 (Just membs) = Just (markLabel ix membs)
          fn2 Nothing      = Just (mkSingleDense num_taxa ix)
      in      
      build acc' tl

    -- Second, ingest the table to construct the distance matrix:
    ingest :: M.Map DenseLabelSet DenseLabelSet -> DistanceMatrix
    ingest bipTable = runST theST
      where
       theST :: forall s0 . ST s0 DistanceMatrix
       theST = do 
        -- Triangular matrix, starting narrow and widening:
        matr <- MV.new num_trees
        -- Too bad MV.replicateM is insufficient.  It should pass index.  
        -- Instead we write this C-style:
        for_ (0,num_trees) $ \ ix -> do 
          row <- MU.replicate ix (0::Int)
          MV.write matr ix row
          return ()

        unsafeIOToST$ putStrLn$" Built matrix for dim "++show num_trees

        let bumpMatr i j | j < i     = incr i j
                         | otherwise = incr j i
            incr :: Int -> Int -> ST s0 ()
            incr i j = do -- Not concurrency safe yet:
--                          unsafeIOToST$ putStrLn$" Reading at position "++show(i,j)
                          row <- MV.read matr i
                          elm <- MU.read row j
                          MU.write row j (elm+1)
                          return ()
            fn bipMembs =
              -- Here we quadratically consider all pairs of trees and ask whether
              -- their edit distance is increased based on this particular BiP.
              -- Actually, as an optimization, it is sufficient to consider only the
              -- cartesian product of those that have and those that don't.
              let haveIt   = bipMembs
                  -- Depending on how invertDense is written, it could be useful to
                  -- fuse this in and deforest "dontHave".
                  dontHave = invertDense num_trees bipMembs
                  fn1 trId = traverseDense_ (fn2 trId) dontHave
                  fn2 trId1 trId2 = bumpMatr trId1 trId2
              in
--                 trace ("Computed donthave "++ show dontHave) $ 
                 traverseDense_ fn1 haveIt
        F.traverse_ fn bipTable
        v1 <- V.unsafeFreeze matr
        T.traverse (U.unsafeFreeze) v1


--------------------------------------------------------------------------------
-- Miscellaneous Helpers
--------------------------------------------------------------------------------

instance Pretty a => Pretty (S.Set a) where
 pPrint s = pPrint (S.toList s)
 

printDistMat :: Handle -> V.Vector (U.Vector Int) -> IO () 
printDistMat h mat = do
  hPutStrLn h "Robinson-Foulds distance (matrix format):"
  hPutStrLn h "-----------------------------------------"
  V.forM_ mat $ \row -> do 
    U.forM_ row $ \elem -> do
      hPutStr h (show elem)
      hPutStr h " "
    hPutStr h "0\n"          
  hPutStrLn h "-----------------------------------------"

-- My own forM for numeric ranges (not requiring deforestation optimizations).
-- Inclusive start, exclusive end.
{-# INLINE for_ #-}
for_ :: Monad m => (Int, Int) -> (Int -> m ()) -> m ()
for_ (start, end) _fn | start > end = error "for_: start is greater than end"
for_ (start, end) fn = loop start
  where
   loop !i | i == end  = return ()
           | otherwise = do fn i; loop (i+1)

-- | Which of a set of trees are compatible with a consensus?
filterCompatible :: NewickTree a -> [NewickTree b] -> [NewickTree b]
filterCompatible consensus trees =
    let cbips = allBips consensus in
    [ tr | tr <- trees
         , cbips `S.isSubsetOf` allBips tr ]

-- | `compatibleWith consensus tree` -- Is a tree compatible with a consensus?
--   This is more efficient if partially applied then used repeatedly.
-- 
-- Note, tree compatibility is not the same as an exact match.  It's
-- like (<=) rather than (==).  The "star topology" is consistent with the
-- all trees, because it induces the empty set of bipartitions.  
compatibleWith :: NewickTree a -> NewickTree b -> Bool
compatibleWith consensus =
  let consBips = allBips consensus in 
  \ newTr -> S.isSubsetOf consBips (allBips newTr)

-- | Consensus between two trees, which may even have different label maps.
consensusTreeFull (FullTree n1 l1 t1) (FullTree n2 l2 t2) =
  error "FINISHME - consensusTreeFull"

-- | Take only the bipartitions that are agreed on by all trees.
consensusTree :: Int -> [NewickTree a] -> NewickTree ()
consensusTree _ [] = error "Cannot take the consensusTree of the empty list"
consensusTree num_taxa (hd:tl) = bipsToTree num_taxa intersection
  where
    intersection = L.foldl' S.intersection (allBips hd) (map allBips tl)
--     intersection = loop (allBips hd) tl
--     loop :: S.Set DenseLabelSet -> [NewickTree a] -> S.Set DenseLabelSet
--     loop !remain []      = remain
--     -- Was attempting to use foldBips here as an optimization:
-- --     loop !remain (hd:tl) = loop (foldBips S.delete hd remain) tl
--     loop !remain (hd:tl) = loop (S.difference remain (allBips hd)) tl    
      
-- | Convert from bipartitions BACK to a single tree.
bipsToTree :: Int -> S.Set DenseLabelSet -> NewickTree ()
bipsToTree num_taxa origbip =
--  trace ("Doing bips in order: "++show sorted++"\n") $ 
  loop lvl0 sorted
  where
    -- We consider each subset in increasing size order.
    -- FIXME: If we tweak the order on BIPs, then we can just use S.toAscList here:
    sorted = L.sortBy (compare `on` bipSize) (S.toList origbip)

    lvl0 = [ (mkSingleDense num_taxa ix, NTLeaf () ix)
           | ix <- [0..num_taxa-1] ]

    -- VERY expensive!  However, due to normalization issues this is necessary for now:
    -- TODO: in the future make it possible to definitively denormalize.
    -- isMatch bip x = denseIsSubset x bip || denseIsSubset x (invertDense num_taxa bip)
    isMatch bip x = denseIsSubset x bip 

    -- We recursively glom together subtrees until we have a complete tree.
    -- We only process larger subtrees after we have processed all the smaller ones.
    loop !subtrees [] =
      case subtrees of
        []    -> error "bipsToTree: internal error"
        [(_,one)] -> one
        lst   -> NTInterior () (map snd lst)
    loop !subtrees (bip:tl) =
--      trace (" -> looping, subtrees "++show subtrees) $ 
      let (in_,out) = L.partition (isMatch bip. fst) subtrees in
      case in_ of
        [] -> error $"bipsToTree: Internal error!  No match for bip: "++show bip
              ++" out is\n "++show out++"\n and remaining bips "++show (length tl)
              ++"\n when processing orig bip set:\n  "++show origbip
          -- loop out tl
        _ -> 
         -- Here all subtrees that match the current bip get merged:
         loop ((denseUnions num_taxa (map fst in_),
                NTInterior ()        (map snd in_)) : out) tl