-- | Functions for visualising Stockholm alignments
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies     #-}
{-# LANGUAGE RankNTypes #-}

module Bio.StockholmDraw
    (
     drawStockholmLines,
     drawStockholm,
     convertWUSStoDotBracket,
     extractGapfreeStructure,
     extractGapfreeIndexedStructure,
     isGap
    ) where

import Diagrams.Prelude
import Diagrams.Backend.Cairo
import qualified Bio.StockholmData as S
import qualified Data.Text as T
import Data.Maybe
import qualified Data.Vector as V
import Data.List
import Graphics.SVGFonts
import Bio.StockholmFont

-- | Draw a Stockholm alignment wrapped into row blocks that fit @maxWidth@.
--
-- At most @entriesNumberCutoff@ sequence entries are drawn; the @SS_cons@
-- column annotation, when present, is appended as the last line of every block.
-- @columnComparisonLabels@ is a sparse list of (column index, colours) labels
-- used to colour the index line; missing columns are filled with white.
--
-- An alignment with no sequence entries and no @SS_cons@ annotation yields an
-- empty diagram instead of failing on the maximum of an empty collection.
drawStockholmLines :: Int -> Double -> V.Vector (Int, V.Vector (Colour Double)) -> S.StockholmAlignment -> QDiagram Cairo V2 Double Any
drawStockholmLines entriesNumberCutoff maxWidth columnComparisonLabels aln = alignmentBlocks
  where seqEntries = V.fromList (take entriesNumberCutoff (S.sequenceEntries aln))
        maybeConsensusStructureEntry = find ((T.pack "SS_cons"==) . S.tag) (S.columnAnnotations aln)
        consensusStructureEntry = maybe V.empty makeConsensusStructureVectorEntry maybeConsensusStructureEntry
        entryNumber = V.length vectorEntries
        seqVectorEntries = V.map makeVectorEntries seqEntries
        vectorEntries = seqVectorEntries V.++ consensusStructureEntry
        -- Seeding with 0 keeps both maxima total; lengths are never negative,
        -- so the result for non-empty input is unchanged.
        maxEntryLength = maximum (0 : map (V.length . snd) (V.toList vectorEntries))
        maxIdLength = maximum (0 : map (length . fst) (V.toList vectorEntries))
        letterWidth = 2.0 :: Double
        availableLettersPerRow = maxWidth / letterWidth
        blocks = makeLetterIntervals entryNumber availableLettersPerRow maxEntryLength
        -- comparison labels are sparse, because some columns are not directly assigned to a node, but modeled via indel
        fullComparisonColLabels = fillComparisonColLabels (maxEntryLength + 1) columnComparisonLabels
        alignmentBlocks = vcat' with { _sep = 6.0 } (map (drawStockholmRowBlock maxIdLength vectorEntries maxEntryLength fullComparisonColLabels) blocks)
        
-- | Project a dot-bracket structure onto one aligned sequence.
--
-- Base pairs with at least one partner on a gap column of @alignedSequence@
-- are replaced by @'.'@, then all gap columns are removed from the structure.
-- This is 'extractGapfreeIndexedStructure' with the alignment column indices
-- dropped, so both functions share a single implementation.
extractGapfreeStructure :: String -> String -> String
extractGapfreeStructure alignedSequence regularStructure1 =
  map snd (extractGapfreeIndexedStructure alignedSequence regularStructure1)


-- | Project a dot-bracket structure onto one aligned sequence, keeping the
-- original alignment column index of every surviving structure character.
--
-- Base pairs with at least one partner on a gap column of @alignedSequence@
-- are replaced by @'.'@; afterwards all gap columns are removed.
extractGapfreeIndexedStructure :: String -> String -> [(Int,Char)]
extractGapfreeIndexedStructure alignedSequence regularStructure1 =
  V.toList (V.filter (not . atGap . fst) (V.indexed openedStructure))
  where
    -- columns where the sequence carries a gap (after normalising gap characters to '-')
    gapColumns = elemIndices '-' (map convertToRegularGap alignedSequence)
    atGap column = column `elem` gapColumns
    -- base pairs that lost at least one partner to a gap
    brokenPairs = [ pair | pair@(i, j) <- basePairIndices regularStructure1 [] 0, atGap i || atGap j ]
    dotUpdates = V.fromList (concat [ [(i, '.'), (j, '.')] | (i, j) <- brokenPairs ])
    -- structure with the broken pairs converted to unpaired positions
    openedStructure = V.update (V.fromList regularStructure1) dotUpdates

-- | Collect the index pairs of matching brackets in a dot-bracket string.
--
-- The second argument is the stack of currently open bracket positions and the
-- third the index of the next character; callers start with @[]@ and @0@.
-- Pairs are emitted in the order their closing bracket is encountered.
-- Characters other than @'('@ and @')'@ only advance the position. A closing
-- bracket with no open partner is skipped rather than crashing on an empty stack.
basePairIndices :: String -> [Int] -> Int -> [(Int,Int)]
basePairIndices [] _ _ = []
basePairIndices (x:xs) ys counter =
  case (x, ys) of
    ('(', _)             -> basePairIndices xs (counter:ys) next
    (')', opening:rest)  -> (opening, counter) : basePairIndices xs rest next
    -- unpaired positions and unmatched ')' (empty stack)
    _                    -> basePairIndices xs ys next
  where next = counter + 1

-- | True for the characters treated as alignment gaps:
-- @'.'@, @'-'@, a space and a newline.
isGap :: Char -> Bool
isGap char = char `elem` ['.', '-', ' ', '\n']

-- | Normalise the alternative gap characters (@'.'@, space, newline) to @'-'@;
-- every other character is returned unchanged.
convertToRegularGap :: Char -> Char
convertToRegularGap char =
  case char of
    '.'  -> '-'
    ' '  -> '-'
    '\n' -> '-'
    _    -> char

-- | Translate a WUSS structure annotation to dot-bracket notation,
-- character by character.
convertWUSStoDotBracket :: T.Text -> T.Text
convertWUSStoDotBracket = T.pack . map convertWUSSCharToDotBracket . T.unpack

-- | Translate a single WUSS character to dot-bracket notation.
--
-- All opening bracket kinds become @'('@, all closing kinds @')'@, and the
-- listed unpaired symbols become @'.'@. Anything else is passed through.
convertWUSSCharToDotBracket :: Char -> Char
convertWUSSCharToDotBracket c
  | c `elem` "<([{"   = '('
  | c `elem` ">)]}"   = ')'
  | c `elem` "_-.~:," = '.'
  | otherwise         = c

-- | Expand sparse column labels into a dense vector of length @maxEntryLength@,
-- with one label per index from 0 upwards.
fillComparisonColLabels :: Int ->  V.Vector (Int, V.Vector (Colour Double)) ->  V.Vector (Int, V.Vector (Colour Double))
fillComparisonColLabels maxEntryLength sparseComparisonColLabels =
  V.generate maxEntryLength $ \colIndex ->
    makeFullComparisonColLabel sparseComparisonColLabels colIndex

-- | Look up the label for @colIndex@ among the sparse labels (first match on
-- the index component); fall back to a single white colour when none exists.
makeFullComparisonColLabel :: V.Vector (Int, V.Vector (Colour Double)) -> Int -> (Int, V.Vector (Colour Double))
makeFullComparisonColLabel sparseComparisonColLabels colIndex =
  case V.find ((colIndex ==) . fst) sparseComparisonColLabels of
    Just knownLabel -> knownLabel
    Nothing         -> (colIndex, V.singleton white)

-- | Draw one row block: the column index line, a vertical gap of 2.0, and one
-- line per letter interval, separated by 2.0.
--
-- The index line covers @startIndex@ up to @endIndex - 1@, capped at the last
-- alignment column (@maxEntryLength - 1@).
drawStockholmRowBlock :: Int ->  V.Vector (String, V.Vector Char) -> Int -> V.Vector (Int, V.Vector (Colour Double)) -> ((Int, Int), V.Vector (Int, Int, Int)) -> QDiagram Cairo V2 Double Any
drawStockholmRowBlock maxIdLength vectorEntries maxEntryLength comparisonColLabels ((startIndex,endIndex),letterIntervals) =
  header === strutY 2.0 === vcat' with { _sep = 2.0 } entryLines
  where
    lastColumn = min (endIndex - 1) (maxEntryLength - 1)
    header = drawStockholmIndexLine maxIdLength [startIndex .. lastColumn] comparisonColLabels
    entryLines = map (drawStockholmEntryLine maxIdLength vectorEntries) (V.toList letterIntervals)

-- | Draw the column index line of a row block.
--
-- The line starts with @maxIdLength + 3@ blank letters, the same width the
-- entry lines use for identifier plus padding, followed by one index column
-- per entry of @indices@. The column height is 2.5 per digit of the largest
-- displayed index (@1 + maximum indices@).
drawStockholmIndexLine :: Int -> [Int] -> V.Vector (Int, V.Vector (Colour Double)) -> QDiagram Cairo V2 Double Any
drawStockholmIndexLine maxIdLength indices comparisonColLabels =
  hcat (map setAlignmentLetter padding) ||| hcat (map indexColumn indices)
  where
    padding = replicate (maxIdLength + 3) ' '
    widestLabel = show (1 + maximum indices)
    boxHeight = fromIntegral (length widestLabel) * 2.5
    indexColumn = drawStockholmIndexLineCol comparisonColLabels boxHeight

-- | Draw a single column of the index line: the column number written as
-- vertically stacked digits, superimposed on a stack of colour boxes.
--
-- @totalBoxYlength@ is the full column height; it is divided evenly among the
-- colours of the column label.
drawStockholmIndexLineCol :: V.Vector (Int, V.Vector (Colour Double)) -> Double -> Int -> QDiagram Cairo V2 Double Any
drawStockholmIndexLineCol comparisonColLabels totalBoxYlength entryIndex = raisedLabel <> loweredBoxes
  where
    -- the label for a column is stored at position entryIndex + 1
    (columnNumber, colColours) = comparisonColLabels V.! (entryIndex + 1)
    digits = show columnNumber
    digitCount = fromIntegral (length digits)
    boxHeight = totalBoxYlength / fromIntegral (V.length colColours)
    -- blank spacer filling the height not taken by the digits (2.5 per digit)
    blankTop = rect 2 (totalBoxYlength - 2.5 * digitCount) # lw 0.0
    labelColumn = blankTop === vcat (map setAlignmentLetter digits)
    boxColumn = vcat (map (colorBox boxHeight) (V.toList colColours))
    raisedLabel = labelColumn # translate (r2 (0, 1.25 * digitCount))
    loweredBoxes = boxColumn # translate (r2 (0, negate ((boxHeight / 2) - (totalBoxYlength / 2))))

-- | A rectangle of width 2 and the given height, filled with the given colour
-- and outlined with line width 0.1.
colorBox :: Double -> Colour Double -> QDiagram Cairo V2 Double Any
colorBox singleBoxYLength colColour = lw 0.1 (fc colColour (rect 2 singleBoxYLength))

-- | Draw one alignment line of a row block: the entry identifier, padded with
-- spaces to @maxIdLength + 3@ letters, followed by the slice of the entry
-- that starts at @currentStart@ and is @safeLength@ letters long.
drawStockholmEntryLine :: Int -> V.Vector (String, V.Vector Char) -> (Int, Int, Int) -> QDiagram Cairo V2 Double Any
drawStockholmEntryLine maxIdLength aln (seqIndex,currentStart,safeLength) =
  hcat (map setAlignmentLetter (label ++ padding ++ window))
  where
    (label, residues) = aln V.! seqIndex
    window = V.toList (V.slice currentStart safeLength residues)
    padding = replicate ((maxIdLength + 3) - length label) ' '

-- | Draw a Stockholm alignment as one unwrapped line per sequence entry
-- (at most @entriesNumberCutoff@ entries), followed by the @SS_cons@
-- consensus structure line when the alignment has one.
drawStockholm :: Int -> S.StockholmAlignment -> QDiagram Cairo V2 Double Any
drawStockholm entriesNumberCutoff aln = alignTL (vcat' with { _sep = 1 } (map (drawStockholmEntry maxIdLength) currentEntries)) === consensusStructureEntry
   where currentEntries = take entriesNumberCutoff (S.sequenceEntries aln)
         -- Seeded with 0 so that an alignment without sequence entries (or a
         -- cutoff of 0) does not fail on the maximum of an empty list when
         -- the consensus line is drawn; identifier lengths are never negative.
         maxIdLength = maximum (0 : map (T.length . S.sequenceId) currentEntries)
         consensusStructureEntry = if null (S.columnAnnotations aln) then mempty else drawConsensusStructureEntry maxIdLength (S.columnAnnotations aln)

-- | Draw the @SS_cons@ consensus structure line, padded so that the
-- annotation starts @maxIdLength + 3@ letters from the left.
-- Yields an empty diagram when @entries@ has no @SS_cons@ annotation.
drawConsensusStructureEntry :: Int -> [S.AnnotationEntry] -> QDiagram Cairo V2 Double Any
drawConsensusStructureEntry maxIdLength entries =
  -- total pattern match instead of isJust/fromJust
  case find ((seqId ==) . S.tag) entries of
    Nothing -> mempty
    Just secStructureEntry ->
      hcat (map setAlignmentLetter (T.unpack (seqId `T.append` spacer `T.append` S.annotation secStructureEntry)))
  where seqId = T.pack "SS_cons"
        spacerLength = (maxIdLength + 3) - T.length seqId
        spacer = T.replicate spacerLength (T.pack " ")


-- | Draw one sequence entry as a single line: the identifier, padded with
-- spaces to @maxIdLength + 3@ letters, followed by the full sequence.
drawStockholmEntry :: Int -> S.SequenceEntry -> QDiagram Cairo V2 Double Any
drawStockholmEntry maxIdLength entry =
  hcat (map setAlignmentLetter (T.unpack paddedLine))
  where
    name = S.sequenceId entry
    padding = T.replicate ((maxIdLength + 3) - T.length name) (T.pack " ")
    paddedLine = name `T.append` padding `T.append` S.entrySequence entry

-- | Render a single character as a black SVG-font glyph laid over a 2x2 cell
-- with zero line width, so that letters can be concatenated on a fixed grid.
setAlignmentLetter :: Char -> QDiagram Cairo V2 Double Any
setAlignmentLetter echar = glyph <> cell
  where
    glyph, cell :: QDiagram Cairo V2 Double Any
    glyph = textSVG_ (TextOpts linLibertineFont INSIDE_H KERN False 2.5 2.5) [echar] # fc black # fillRule EvenOdd # lw 0.0 # translate (r2 (negate 0.75, negate 0.75))
    cell = rect 2 2 # lw 0.0
--setAlignmentLetter echar = alignedText 0.5 0.5 [echar] # fontSize 2.0 <> rect 2 2.5 # lw 0

-- LetterInterval (SeqNr,Start,Length)

-- | Split an alignment of @letterNumber@ columns and @seqNumber@ lines into
-- row blocks of at most @floor letterNumberPerRow@ columns each.
--
-- Each result element pairs the (start, end) column indices of a row with the
-- letter intervals of all lines in that row.
--
-- The row count is derived from the same floored letters-per-row value that
-- is used to cut the rows. Deriving it from the fractional value loses
-- trailing columns: 21 columns at 10.5 per row would give 2 rows of 10.
-- At least one letter per row is always used, so a width below one letter
-- cannot produce zero-width rows.
makeLetterIntervals :: Int -> Double -> Int -> [((Int,Int),V.Vector (Int,Int,Int))]
makeLetterIntervals seqNumber letterNumberPerRow letterNumber = rowIntervals
  where lettersPerRow = max 1 (floor letterNumberPerRow)
        -- integer ceiling of letterNumber / lettersPerRow
        rowNumber = (letterNumber + lettersPerRow - 1) `div` lettersPerRow
        rowList = [0..(rowNumber-1)]
        rowIntervals = map (setAlignmentInterval lettersPerRow letterNumber seqNumber) rowList

-- | Build the description of row @rowIndex@: its (start, end) column indices,
-- with end = start + @letterNumberPerRow@ (not capped at the alignment
-- length), and one letter interval per line index @0 .. seqNumber - 1@.
setAlignmentInterval :: Int -> Int -> Int -> Int -> ((Int,Int),V.Vector (Int,Int,Int))
setAlignmentInterval letterNumberPerRow letterNumber seqNumber rowIndex =
  ((rowStart, rowStart + letterNumberPerRow), V.map lineInterval lineIndices)
  where
    rowStart = rowIndex * letterNumberPerRow
    lineIndices = V.iterateN seqNumber (1+) 0
    lineInterval = setAlignmentLineInterval letterNumberPerRow letterNumber rowIndex
-- | Letter interval (line index, start column, length) of one line in row
-- @rowIndex@. The length is the row width, shortened for the last row so that
-- the interval ends at @letterNumber@.
setAlignmentLineInterval :: Int -> Int -> Int -> Int -> (Int,Int,Int)
setAlignmentLineInterval letterNumberPerRow letterNumber rowIndex seqIndex =
  (seqIndex, rowStart, min letterNumberPerRow remaining)
  where
    rowStart = rowIndex * letterNumberPerRow
    -- columns left from the start of this row to the end of the alignment
    remaining = letterNumber - rowStart

-- | Convert a sequence entry into its identifier as a 'String' and its
-- sequence as a vector of characters.
makeVectorEntries :: S.SequenceEntry -> (String, V.Vector Char)
makeVectorEntries entry =
  ( T.unpack (S.sequenceId entry)
  , V.fromList (T.unpack (S.entrySequence entry))
  )

-- | Convert an annotation entry into a one-element vector holding its tag as
-- a 'String' and its annotation as a vector of characters.
makeConsensusStructureVectorEntry :: S.AnnotationEntry -> V.Vector (String, V.Vector Char)
makeConsensusStructureVectorEntry entry =
  V.singleton
    ( T.unpack (S.tag entry)
    , V.fromList (T.unpack (S.annotation entry))
    )