module Bio.PDB.StructureBuilder.Parallel(parseParallel, parseWithNParallel, joinStructure, joinResult)
where
import Prelude hiding(String)
import Bio.PDB.StructureBuilder.Internals
import Bio.PDB.Structure
import Bio.PDB.EventParser.PDBEvents(PDBEvent(PDBParseError))
import GHC.Conc(numCapabilities)
import Control.Parallel.Strategies
import Bio.PDB.Util.ParFold(parFold1)
import Control.Arrow((&&&))
import qualified Bio.PDB.Structure.List as L
import qualified Data.ByteString.Char8 as BS
import qualified Control.Monad.ST as ST
import Control.Monad.State.Strict as State
import Data.STRef as STRef
-- | Parse a single chunk of input on its own, starting from a fresh parser
-- state, entirely inside 'ST.runST'.
--
-- Returns the structure held in the parser state once the chunk has been
-- consumed, the finalized list of recorded errors, and the value read from
-- the state's line counter at that point.
partialParse :: FilePath -> String -> (Structure, L.List PDBEvent, Int)
partialParse fname contents = ST.runST (initializeState >>= State.evalStateT collect)
  where
    collect = do
      -- Feed every event of the chunk to 'parseStep'; the unit accumulator
      -- carries no information.
      parsePDBRec (BS.pack fname) contents (\() !ev -> parseStep ev) ()
      -- NOTE(review): presumably flushes the partially built model, chain
      -- and residue into the structure -- confirm in Internals.
      closeStructure
      built       <- State.gets currentStructure
      rawErrors   <- State.gets errors
      counterRef  <- State.gets lineNo
      linesSeen   <- lift (STRef.readSTRef counterRef)
      finalErrors <- L.finalize rawErrors
      return (built, finalErrors, linesSeen)
-- | Parse the input in parallel, asking 'parseWithNParallel' for as many
-- chunks as 'numCapabilities' reports.
parseParallel fname input = parseWithNParallel numCapabilities fname input
-- | Outcome of parsing one chunk, or of joining several chunk outcomes with
-- 'joinResult': the structure, the list of error events, and an 'Int' that
-- 'joinResult' uses as the line offset for errors of the following chunk.
type ParseResult = (Structure, L.List PDBEvent, Int)
-- | Combine the results of two adjacent pieces of input; the first argument
-- is the earlier piece.
--
-- The structures are merged with 'joinStructure'. Errors from the second
-- piece carry line numbers relative to that piece, so they are shifted by
-- the first piece's line count before being appended.
--
-- The line count of the combined result is the sum of both counts. This
-- makes the operation associative: a joined result can be used as the left
-- operand of a later join and still shift the following errors by the total
-- number of lines it covers, whatever shape the fold has.
--
-- NOTE(review): this assumes the 'Int' is the number of lines consumed by the
-- piece (i.e. the counter starts at 0) -- confirm against the parser state
-- initialisation.
joinResult :: ParseResult -> ParseResult -> ParseResult
joinResult (struct1, errs1, ln1) (struct2, errs2, ln2) = (resultStruct, resultErrs, ln1 + ln2)
  where
    resultStruct = struct1 `joinStructure` struct2
    resultErrs   = errs1 L.++ L.map (updateErrorLine ln1) errs2
-- | Join two structures by concatenating their model lists.
--
-- The last model of the first structure and the first model of the second
-- are merged with 'joinModel' when the second one's id is 'defaultModelId'
-- or when both ids are equal; otherwise the lists are simply appended.
joinStructure :: Structure -> Structure -> Structure
joinStructure = joiner models withModels modelId matchModelId joinModel
  where
    withModels structure ms = structure { models = ms }
    -- NOTE(review): a second id equal to 'defaultModelId' presumably means
    -- the later chunk started in the middle of a model -- confirm.
    earlier `matchModelId` later = later == defaultModelId || earlier == later
-- | Join two models by concatenating their chain lists; the boundary chains
-- are merged with 'joinChain' when their 'chainId's are equal.
joinModel :: Model -> Model -> Model
joinModel = joiner chains withChains chainId (==) joinChain
  where
    withChains model cs = model { chains = cs }
-- | Join two chains by concatenating their residue lists; the boundary
-- residues are merged with 'joinResidue' when they agree on residue name,
-- sequence number and insertion code.
joinChain :: Chain -> Chain -> Chain
joinChain = joiner residues withResidues residueKey (==) joinResidue
  where
    withResidues chain rs = chain { residues = rs }
    -- Same nested pair that @resName &&& resSeq &&& insCode@ would build.
    residueKey r = (resName r, (resSeq r, insCode r))
-- | Join two residues by appending the atoms of the second to the first.
--
-- Atoms are never merged with one another: every atom gets the same unit
-- key and the matching predicate always answers 'False', so 'joiner' only
-- ever concatenates and the sub-joiner is never evaluated.
joinResidue = joiner atoms withAtoms unitKey neverMatch (error "Never happens")
  where
    withAtoms residue as = residue { atoms = as }
    unitKey _ = ()
    neverMatch _ _ = False
-- | Generic "join at the seam" combinator shared by the structure, model,
-- chain and residue joins.
--
-- Arguments, in order: how to get the child list of a container, how to
-- replace it, how to get a child's key, when two keys count as matching,
-- and how to merge two matching children.
--
-- The result is always built from the first container via the setter. If
-- either child list is empty, or the last child of the first container does
-- not match the first child of the second, the child lists are appended.
-- Otherwise those two boundary children are merged into one.
joiner :: (a -> L.List a1)-> (a -> L.List a1 -> t)-> (a1 -> t1)-> (t1 -> t1 -> Bool)-> (a1 -> a1 -> a1)-> a-> a-> t
joiner getter setter idGetter matcher subjoiner = merge
  where
    merge left right
      | isEmpty lefts || isEmpty rights                 = appended
      | idGetter lastLeft `matcher` idGetter firstRight = setter left fused
      | otherwise                                       = appended
      where
        lefts      = getter left
        rights     = getter right
        isEmpty xs = L.length xs == 0
        appended   = setter left (lefts L.++ rights)
        -- Only demanded after both lists are known to be non-empty.
        lastLeft   = L.last lefts
        firstRight = L.head rights
        fused      = L.concat [ L.init lefts
                              , L.singleton (lastLeft `subjoiner` firstRight)
                              , L.tail rights
                              ]
-- | Shift the line number of a 'PDBParseError' by the given offset; every
-- other event is returned unchanged.
updateErrorLine :: Int -> PDBEvent -> PDBEvent
updateErrorLine offset event = case event of
  PDBParseError row col text -> PDBParseError (row + offset) col text
  _                          -> event
-- | Parse the input in parallel using roughly @sparks@ chunks.
--
-- The input is cut at line boundaries into chunks of at least
-- @ceiling (length / sparks)@ bytes, each chunk is parsed independently with
-- 'partialParse', and the partial results are folded together with
-- 'joinResult'. Returns the joined structure and error list.
--
-- @sparks@ may be any 'Integral' value. Values below 1 are treated as 1, so a
-- zero or negative request degrades to a single chunk instead of producing a
-- nonsensical chunk length.
parseWithNParallel sparks fname input = (struct, errs)
  where
    -- Clamp after conversion so that a wrapped or non-positive value still
    -- yields at least one worker.
    workers = max 1 (fromIntegral sparks) :: Int
    -- Integer ceiling division; equals the former floating-point
    -- computation for every positive worker count.
    chunkLen = case BS.length input `quotRem` workers of
                 (q, 0) -> q
                 (q, _) -> q + 1
    chunks = chunkString chunkLen input
    pList = map (partialParse fname) chunks
    -- Spark one evaluation per chunk: the structure is forced completely,
    -- the error list and line count are left unevaluated.
    partialResults = pList `using` parList (evalTuple3 rdeepseq r0 r0)
    (struct, errs, _) = parFold1 joinResult partialResults
-- | Split the input into chunks of at least @l@ bytes, each ending just
-- after a newline (except possibly the last one).
--
-- A chunk is cut at the first @'\n'@ found at or after offset @l@. When the
-- remaining input is no longer than @l@, or contains no newline from that
-- offset on, it becomes the final chunk. The result is never empty; if the
-- input ends exactly at a cut, the final chunk is the empty string.
--
-- A negative @l@ is treated as 0. Without the clamp, 'BS.take' and 'BS.drop'
-- with a non-positive count would consume nothing and the recursion would
-- never advance.
chunkString :: Int -> String -> [String]
chunkString l = go
  where
    minLen = max 0 l
    go s
      | BS.length s <= minLen = [s]
      | otherwise =
          case BS.elemIndex '\n' (BS.drop minLen s) of
            Just n ->
              -- minLen + n + 1 >= 1, so every step consumes input.
              let (chunk, rest) = BS.splitAt (minLen + n + 1) s
              in chunk : go rest
            Nothing -> [s]