{-# LANGUAGE BangPatterns, PatternGuards #-}
-- | This module leverages the sequential builder interface in @Bio.PDB.StructureBuilder.Internals@
-- to run the parser and @StructureBuilder@ in parallel with linear speedup.
module Bio.PDB.StructureBuilder.Parallel(parseParallel, parseWithNParallel, joinStructure, joinResult)
where

import           Prelude hiding(String)
import           Bio.PDB.StructureBuilder.Internals
import           Bio.PDB.Structure
import           Bio.PDB.EventParser.PDBEvents(PDBEvent(PDBParseError))
import           GHC.Conc(numCapabilities)
import           Control.Parallel.Strategies
import           Bio.PDB.Util.ParFold(parFold1)
import           Control.Arrow((&&&))
import qualified Bio.PDB.Structure.List     as L
import qualified Data.ByteString.Char8      as BS
import qualified Control.Monad.ST           as ST
import           Control.Monad.State.Strict as State 
import           Data.STRef                 as STRef

-- | Parses one chunk of a PDB file on its own.
--
-- A fresh builder state is created, every event produced by 'parsePDBRec'
-- is fed to 'parseStep', and the structure is closed with 'closeStructure'.
-- The result is the structure built from the chunk, the finalized list of
-- recorded errors, and the value of the line counter after the chunk
-- (i.e. a line number relative to the start of the chunk).
partialParse :: FilePath -> String -> (Structure, L.List PDBEvent, Int)
partialParse fname contents = ST.runST (initializeState >>= State.evalStateT chunkParser)
  where
    chunkParser = do
      -- The event is forced before being handed to the builder step.
      parsePDBRec (BS.pack fname) contents (\() !ev -> parseStep ev) ()
      closeStructure
      struct    <- State.gets currentStructure
      errBuffer <- State.gets errors
      lineRef   <- State.gets lineNo
      lastLine  <- lift (STRef.readSTRef lineRef)
      errList   <- L.finalize errBuffer
      return (struct, errList, lastLine)

-- | Parse file in parallel with as many threads as we have capabilities.
--
-- This is 'parseWithNParallel' applied to 'numCapabilities'. It takes the
-- file name and the file contents, and returns the assembled 'Structure'
-- together with the list of events collected as errors.
parseParallel :: FilePath -> String -> (Structure, L.List PDBEvent)
parseParallel = parseWithNParallel numCapabilities
-- TODO: merging
-- | Intermediate result from parsing of PDB chunk.
--
-- The components are, in order: the 'Structure' built from the chunk,
-- the list of 'PDBEvent's recorded as errors, and a line number
-- counted relative to the beginning of the chunk.
type ParseResult = (Structure, L.List PDBEvent, Int)

-- | Joins 'ParseResult' from two different chunks and returns a single 'ParseResult'.
--
-- The first argument must come from the chunk (or run of chunks) that
-- directly precedes the second one in the input.
--
--   * Structures are merged with 'joinStructure'.
--
--   * Errors of the second result are shifted by the line count of the first
--     result, so that they refer to lines of the combined input.
--
--   * The line count of the combined result is the sum of both line counts.
--     Returning only the second count would make later joins shift errors by
--     the length of the last chunk alone; summing keeps the operation
--     associative, which is required when results are combined by a tree
--     fold.
--
-- NOTE(review): this assumes the 'Int' component is the number of lines
-- consumed by the chunk -- confirm against the builder's line counter.
joinResult :: ParseResult -> ParseResult -> ParseResult
joinResult (struct1, errs1, ln1) (struct2, errs2, ln2) = (resultStruct, resultErrs, resultLn)
  where
    resultStruct = struct1 `joinStructure` struct2
    resultErrs   = errs1 L.++ L.map (updateErrorLine ln1) errs2
    resultLn     = ln1 + ln2

-- | Merges the 'Structure's produced by two consecutive partial parses.
--
-- The last model of the first structure and the first model of the second
-- one are fused with 'joinModel' when the second model carries
-- 'defaultModelId', or when both models have the same id; otherwise the
-- model lists are simply appended.
joinStructure ::  Structure -> Structure -> Structure
joinStructure = joiner models replaceModels modelId sameModel joinModel
  where
    replaceModels struct ms = struct { models = ms }
    -- The default-id test comes first, exactly as in a clause-by-clause match.
    sameModel leftId rightId = rightId == defaultModelId || leftId == rightId

-- | Merges the 'Model's produced by two consecutive partial parses.
--
-- The boundary chains (last of the first model, first of the second) are
-- fused with 'joinChain' when their 'chainId's are equal.
joinModel :: Model -> Model -> Model
joinModel = joiner chains replaceChains chainId sameChain joinChain
  where
    replaceChains mdl cs = mdl { chains = cs }
    sameChain idA idB    = idA == idB

-- | Merges the 'Chain's produced by two consecutive partial parses.
--
-- The boundary residues are fused with 'joinResidue' when they agree on
-- residue name, sequence number and insertion code.
joinChain :: Chain -> Chain -> Chain
joinChain = joiner residues replaceResidues residueKey sameKey joinResidue
  where
    replaceResidues chn rs = chn { residues = rs }
    -- Identity of a residue: name, sequence number and insertion code.
    residueKey res         = (resName res, (resSeq res, insCode res))
    sameKey keyA keyB      = keyA == keyB

-- | Joins 'Residue's resulting from partial parses.
--
-- Atoms are never merged with each other: every atom gets the same
-- identifier @()@ and the matcher @(/=)@ is therefore always 'False', so
-- 'joiner' always appends the atom list of the second residue to the atom
-- list of the first. The sub-joiner is consequently never evaluated.
joinResidue :: Residue -> Residue -> Residue
joinResidue = joiner atoms (\r a -> r { atoms = a }) (const ()) (/=) unreachableAtomJoin
  where
    unreachableAtomJoin = error "joinResidue: impossible, atoms are never merged pairwise"

{-# INLINE joiner #-}
-- | Builds a @joinX@ function for one level of the structure hierarchy from:
--
--   * a getter for the list of subordinate components,
--
--   * a setter that replaces the list of subordinate components,
--
--   * a getter for the id of a subordinate component,
--
--   * a matcher deciding whether two subordinate ids denote the same
--     component (and hence have to be joined),
--
--   * a joining function for two subordinate components sharing an id.
--
-- The resulting function merges two values: if either side has no
-- subordinate components, or the last component of the left value does not
-- match the first component of the right value, the component lists are
-- appended. Otherwise those two boundary components are fused into one with
-- the sub-joiner and placed between the remaining components.
joiner :: (a -> L.List a1)
       -> (a -> L.List a1 -> t)
       -> (a1 -> t1)
       -> (t1 -> t1 -> Bool)
       -> (a1 -> a1 -> a1)
       -> a -> a -> t
joiner getter setter idGetter matcher subjoiner = merge
  where
    merge lhs rhs
      | L.length lhsItems == 0 || L.length rhsItems == 0 = appended
      | idGetter lastLeft `matcher` idGetter firstRight  = fused
      | otherwise                                        = appended
      where
        lhsItems   = getter lhs
        rhsItems   = getter rhs
        -- Boundary components; only demanded once both lists are non-empty.
        lastLeft   = L.last lhsItems
        firstRight = L.head rhsItems
        appended   = lhs `setter` (lhsItems L.++ rhsItems)
        fused      = lhs `setter` L.concat [ L.init      lhsItems
                                           , L.singleton (lastLeft `subjoiner` firstRight)
                                           , L.tail      rhsItems
                                           ]

-- | Shifts the line number stored in a 'PDBParseError' by the given offset.
-- Any other event is returned unchanged.
updateErrorLine ::  Int -> PDBEvent -> PDBEvent
updateErrorLine offset event = case event of
  PDBParseError errLine errCol errText -> PDBParseError (errLine + offset) errCol errText
  _                                    -> event

-- | Parse input file with N parallel threads.
--
-- The input is cut into roughly @sparks@ chunks ending at line boundaries
-- (see 'chunkString'), every chunk is parsed independently with
-- 'partialParse', and the partial results are combined with 'joinResult'.
--
-- Arguments: requested number of chunks, file name, file contents.
-- Returns the assembled 'Structure' and the list of events collected as
-- errors.
--
-- A requested count below one is treated as one, so that zero or negative
-- values yield a single chunk instead of a division by zero or a negative
-- chunk length.
parseWithNParallel :: Integral a => a -> FilePath -> String -> (Structure, L.List PDBEvent)
parseWithNParallel sparks fname input = (struct, errs)
  where
    nChunks :: Int
    nChunks  = max 1 (fromIntegral sparks)
    -- Integer ceiling of (input length / number of chunks).
    chunkLen = (BS.length input + nChunks - 1) `div` nChunks
    chunks   = chunkString chunkLen input
    pList    = map (partialParse fname) chunks
    -- Only the structure of each partial result is deeply evaluated by the
    -- strategy; the error list and line number are left unevaluated (r0).
    partialResults = pList `using` parList (evalTuple3 rdeepseq r0 r0)
    (struct, errs, _) = parFold1 joinResult partialResults
    -- Sequential alternative:
    --(struct, errs, ln)  = foldl joinResult (head partialResults) (tail partialResults)
-- TODO: correct line numbers! partial parse should return Structure + line number

-- | Splits a ByteString into chunks of given size, and ending at end of line.
--
-- Every chunk except possibly the last is at least @l@ bytes long and is
-- extended up to and including the next newline character. If no newline
-- follows, the remainder is returned as the final chunk. The result is
-- never empty: an input no longer than @l@ yields a single chunk, and when
-- a split falls exactly on the end of the input the last chunk is empty.
--
-- A negative @l@ is treated as zero. Without this clamp the computed split
-- position could be non-positive, so that no input would be consumed and
-- the recursion would not terminate.
chunkString ::  Int -> String -> [String]
chunkString l = go
  where
    minLen = max 0 l
    go s
      | BS.length s <= minLen = [s]
      | Just n <- BS.elemIndex '\n' (BS.drop minLen s)
      , (chunk, rest) <- BS.splitAt (minLen + n + 1) s
                              = chunk : go rest
      | otherwise             = [s]