{-# language BangPatterns #-}
{-# language DeriveFunctor #-}
{-# language DerivingStrategies #-}
{-# language LambdaCase #-}
{-# language MagicHash #-}
{-# language UnboxedTuples #-}
{-# language RankNTypes #-}
{-# language ScopedTypeVariables #-}
module Automata.Nfst
(
Nfst
, evaluate
, evaluateAscii
, union
, toDfst
, toNfsa
, rejection
, Builder
, State
, build
, state
, transition
, epsilon
, accept
) where
import Automata.Internal (State(..),Epsilon(..),Nfsa(..),Dfsa(..),TransitionNfsa(..),toDfsaMapping)
import Automata.Internal.Transducer (Nfst(..),Dfst(..),TransitionNfst(..),MotionDfst(..),Edge(..),EdgeDest(..),epsilonClosure,rejection,union)
import Control.Monad.ST (runST)
import Data.ByteString (ByteString)
import Data.Foldable (for_,fold)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe)
import Data.Monoid (Any(..))
import Data.Primitive (Array,indexArray)
import Data.Set (Set)
import Debug.Trace
import qualified Data.ByteString.Char8 as BC
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Data.Set.Unboxed as SU
import qualified Data.Map.Interval.DBTSLL as DM
import qualified Data.Map.Lifted.Unlifted as MLN
import qualified Data.Primitive.Contiguous as C
import qualified Data.Foldable as F
-- | Debugging hook used by 'toDfst'. It is currently the identity.
--
-- The 'Show' constraint is not needed by this definition. Presumably it is
-- kept so the body can be swapped for a tracing function such as
-- 'traceShowId' (Debug.Trace is imported) without touching call sites.
debugTrace :: Show a => a -> a
debugTrace value = value
-- | Run an NFST over a sequence of input tokens and collect every output
-- produced by a path that consumes all of the input and ends in a final
-- state.
--
-- An empty result means the input was rejected. Each consumed token
-- prepends exactly one output token, so:
--
-- * every returned list has one element per input token, and
-- * the tokens in each list are in reverse order of emission.
--
-- Evaluation starts from the epsilon set stored for state 0. For an 'Nfst'
-- made by 'build', that set already contains state 0 and its epsilon
-- closure.
evaluate :: forall f t m. (Foldable f, Ord t, Ord m) => Nfst t m -> f t -> Set [m]
evaluate (Nfst transitions finals) tokens = S.unions (M.elems accepted)
  where
  -- Active states after the last token, restricted to final states.
  accepted :: Map Int (Set [m])
  accepted = M.filterWithKey (\st _ -> SU.member st finals) (F.foldl' advance start tokens)

  -- Every state reachable from state 0 without input starts with no output.
  start :: Map Int (Set [m])
  start = M.unionsWith (<>)
    [ M.singleton st (S.singleton [])
    | st <- SU.toList (transitionNfstEpsilon (C.index transitions 0))
    ]

  -- Consume one token from every active state.
  advance :: Map Int (Set [m]) -> t -> Map Int (Set [m])
  advance current token = M.unionsWith (<>) (M.foldlWithKey' visit [] current)
    where
    -- Prepending the same token to every list preserves their relative
    -- order, which is what makes 'S.mapMonotonic' valid here.
    visit acc st outs = MLN.foldlWithKey'
      (\acc' emitted targets -> spread (S.mapMonotonic (emitted :) outs) targets : acc')
      acc
      (DM.lookup token (transitionNfstConsume (C.index transitions st)))

  -- Give the same output set to each destination state.
  spread :: Set [m] -> SU.Set Int -> Map Int (Set [m])
  spread outs targets = M.unionsWith (<>) [ M.singleton st outs | st <- SU.toList targets ]
-- | Variant of 'evaluate' for a strict 'ByteString'. Each byte is read as a
-- 'Char' via "Data.ByteString.Char8".
--
-- As with 'evaluate', an empty result means the input was rejected. Every
-- output list has one token per input byte, and the tokens appear in
-- reverse order of emission because each step prepends.
evaluateAscii :: forall m. Ord m => Nfst Char m -> ByteString -> Set [m]
evaluateAscii (Nfst transitions finals) tokens =
  let initial :: Map Int (Set [m])
      initial = M.unionsWith (<>)
        [ M.singleton q (S.singleton [])
        | q <- SU.toList (transitionNfstEpsilon (C.index transitions 0))
        ]
      reached :: Map Int (Set [m])
      reached = BC.foldl' step initial tokens
   in S.unions (M.elems (M.filterWithKey (\q _ -> SU.member q finals) reached))
  where
  -- Advance every active state over a single character.
  step :: Map Int (Set [m]) -> Char -> Map Int (Set [m])
  step active c = M.unionsWith (<>) (M.foldlWithKey' fromState [] active)
    where
    fromState pending q outputs = MLN.foldlWithKey' (viaOutput outputs) pending
      (DM.lookup c (transitionNfstConsume (C.index transitions q)))
    -- One map per (emitted token, destination set) pair; the outer
    -- 'M.unionsWith' merges them.
    viaOutput outputs pending emitted targets =
      M.unionsWith (<>)
        [ M.singleton q' (S.mapMonotonic (emitted :) outputs) | q' <- SU.toList targets ]
        : pending
-- | Convert an 'Nfst' into a 'Dfst'.
--
-- The state structure comes from 'toDfsaMapping', applied to the
-- output-free automaton that 'toNfsa' produces. @mapping@ relates sets of
-- old NFST states to new DFST state numbers. This is presumably the subset
-- construction; confirm against "Automata.Internal".
--
-- Outputs are then reattached edge by edge. An edge from new state @source@
-- to new state @dest@ emits the '<>' of every NFST output found on a consume
-- edge that leaves some old state of @source@ and reaches some old state of
-- @dest@.
--
-- NOTE(review): each edge's output is gathered with 'DM.foldMap' over the
-- whole consume map of each old source, not just the input range this DFST
-- edge covers. Outputs from other token ranges between the same old states
-- are therefore combined in as well. Confirm this over-approximation is
-- intended.
--
-- Calls 'error' if a new state number is missing from the inverted mapping.
toDfst :: forall t m. (Ord t, Bounded t, Enum t, Monoid m) => Nfst t m -> Dfst t m
toDfst x@(Nfst tx _) =
  let (mapping,Dfsa t0 f) = toDfsaMapping (toNfsa x)
      -- Debug hook; identity as long as debugTrace is the identity.
      mapping' = debugTrace mapping
      -- Invert the mapping: new state -> union of the old state sets mapped to it.
      revMapping :: Map Int (SU.Set Int)
      revMapping = debugTrace $ M.foldlWithKey' (\acc k v -> M.insertWith (<>) v k acc) M.empty mapping'
      -- Keep the DFSA edges (source, input ranges, dest) and attach an output to each.
      t1 = C.imap
        (\source m -> DM.mapBijection
          (\dest ->
            let oldSources = fromMaybe (error "Automata.Nfst.toDfst: missing old source") (M.lookup source revMapping)
                oldDests = fromMaybe (error "Automata.Nfst.toDfst: missing old dest") (M.lookup dest revMapping)
                -- For every old source, combine the outputs of all of its
                -- consume edges whose destination set intersects oldDests.
                newOutput = SU.foldMap (\oldSource -> DM.foldMap (MLN.foldMapWithKey' (\output oldDestStates -> if getAny (SU.foldMap (\oldDest -> Any (SU.member oldDest oldDests)) oldDestStates) then output else mempty)) (transitionNfstConsume (indexArray tx oldSource))) oldSources
             in MotionDfst dest newOutput
          ) m
        ) t0
   in Dfst t1 f
-- | Forget the outputs of an 'Nfst', keeping states, epsilon sets and
-- accepting states unchanged.
--
-- For every input range, the destination sets of all outputs are merged
-- into a single destination set.
toNfsa :: Nfst t m -> Nfsa t
toNfsa (Nfst transitions finals) = Nfsa (fmap dropOutputs transitions) finals
  where
  -- Keep the epsilon set; collapse each output->destinations map.
  dropOutputs (TransitionNfst eps consume) = TransitionNfsa eps (DM.map mergeDests consume)
  -- Union of the destination sets, ignoring which output each came from.
  mergeDests perOutput = MLN.foldlWithKey' (\acc _ dests -> acc <> dests) mempty perOutput
-- | A state-threading builder for 'Nfst's. Running it takes, in order:
--
-- * the next fresh state number,
-- * the labelled edges collected so far,
-- * the epsilon moves collected so far, and
-- * the accepting states collected so far.
--
-- The parameter @s@ does not appear in the representation. It is the same
-- @s@ carried by @State s@. Presumably, together with the rank-2 argument
-- type of 'build', it stops states from one build being used in another.
newtype Builder t m s a = Builder (Int -> [Edge t m] -> [Epsilon] -> [Int] -> Result t m a)
  deriving stock (Functor)
-- | The outcome of running a 'Builder'. It holds the updated state counter
-- and the edge, epsilon and accepting-state lists, with newer entries at
-- the front. These fields are strict (to WHNF). The last field is the
-- builder's result value, which is lazy.
data Result t m a = Result !Int ![Edge t m] ![Epsilon] ![Int] a
  deriving stock (Functor)
-- | 'pure' leaves the counter and collections untouched. '<*>' runs the
-- function builder first, then the argument builder, threading state
-- left to right.
instance Applicative (Builder t m s) where
  pure a = Builder (\next edges eps finals -> Result next edges eps finals a)
  Builder runF <*> Builder runX = Builder $ \next edges eps finals ->
    case runF next edges eps finals of
      -- The derived Functor for Result applies f to the final field only.
      Result next' edges' eps' finals' f -> fmap f (runX next' edges' eps' finals')
-- | Sequencing runs the first builder, then the builder chosen from its
-- result. The counter and collections are threaded from one to the next.
instance Monad (Builder t m s) where
  Builder first >>= continue = Builder $ \next edges eps finals ->
    case first next edges eps finals of
      Result next' edges' eps' finals' a ->
        let Builder second = continue a
         in second next' edges' eps' finals'
-- | Allocate a fresh state numbered by the builder's counter, which is then
-- incremented. 'build' starts the counter at 0, so states are numbered
-- 0, 1, 2, ... in allocation order.
state :: Builder t m s (State s)
state = Builder fresh
  where
  fresh next edges eps finals = Result (next + 1) edges eps finals (State next)
-- | Mark a state as accepting. Repeated marks are harmless, because 'build'
-- turns the collected list into a set.
accept :: State s -> Builder t m s ()
accept (State n) = Builder markFinal
  where
  markFinal next edges eps finals = Result next edges eps (n : finals) ()
-- | Add a labelled edge from the source state to the destination state. The
-- edge consumes an input token in the range @lo@ to @hi@ and emits
-- @output@. 'build' turns the range into an interval with 'DM.singleton'.
transition ::
     t -- ^ lower bound of the accepted input range
  -> t -- ^ upper bound of the accepted input range
  -> m -- ^ output emitted when the edge is taken
  -> State s -- ^ source state
  -> State s -- ^ destination state
  -> Builder t m s ()
transition lo hi output (State source) (State dest) = Builder addEdge
  where
  addEdge next edges eps finals = Result next (Edge source dest lo hi output : edges) eps finals ()
-- | Add an epsilon move (one that consumes no input) between two states.
-- A move from a state to itself is not recorded.
epsilon ::
     State s -- ^ source state
  -> State s -- ^ destination state
  -> Builder t m s ()
epsilon (State source) (State dest) = Builder $ \next edges eps finals ->
  let eps'
        | source == dest = eps
        | otherwise = Epsilon source dest : eps
   in Result next edges eps' finals ()
-- | Construct an 'Nfst' by running a 'Builder'.
--
-- The start state is allocated before the user's builder runs, so it is
-- always state 0. The user's builder receives that start state.
--
-- After the builder finishes:
--
-- 1. Epsilon moves and labelled edges are grouped by source state.
-- 2. Each state's labelled edges are combined with 'mconcat' into one
--    interval map from input ranges to (output -> destination states).
--    A state with no edges therefore gets 'mempty'.
-- 3. Every epsilon set and every destination set is replaced by its
--    'epsilonClosure' over that intermediate automaton. This is presumably
--    the set of states reachable through epsilon moves. Each state's own
--    epsilon set is seeded with the state itself.
--
-- Because @s@ is quantified inside the argument type and @a@ is bound
-- outside it, no 'State' can be returned from the builder.
build :: forall t m a. (Bounded t, Ord t, Enum t, Monoid m, Ord m) => (forall s. State s -> Builder t m s a) -> Nfst t m
build fromStartState =
  case state >>= fromStartState of
    Builder f -> case f 0 [] [] [] of
      Result totalStates edges epsilons final _ ->
        let ts0 = runST $ do
              transitions <- C.replicateM totalStates (TransitionNfst SU.empty (DM.pure mempty))
              -- Per-source lists of labelled edges and epsilon targets.
              outbounds <- C.replicateM totalStates []
              epsilonArr <- C.replicateM totalStates []
              for_ epsilons $ \(Epsilon source destination) -> do
                edgeDests0 <- C.read epsilonArr source
                let !edgeDests1 = destination : edgeDests0
                C.write epsilonArr source edgeDests1
              -- epsilonArr is not written again after this freeze.
              (epsilonArr' :: Array [Int]) <- C.unsafeFreeze epsilonArr
              for_ edges $ \(Edge source destination lo hi output) -> do
                edgeDests0 <- C.read outbounds source
                let !edgeDests1 = EdgeDest destination lo hi output : edgeDests0
                C.write outbounds source edgeDests1
              -- outbounds is not written again after this freeze.
              (outbounds' :: Array [EdgeDest t m]) <- C.unsafeFreeze outbounds
              -- Overwrite every placeholder transition with the real one for that state.
              flip C.imapMutable' transitions $ \i (TransitionNfst _ _) ->
                let dests = C.index outbounds' i
                    eps = C.index epsilonArr' i
                 in TransitionNfst
                      ( SU.fromList eps )
                      ( mconcat
                          ( map
                              (\(EdgeDest dest lo hi output) ->
                                DM.singleton mempty lo hi (MLN.singleton output (SU.singleton dest)) :: DM.Map t (MLN.Map m (SU.Set Int))
                              )
                              dests
                          )
                      )
              C.unsafeFreeze transitions
            -- Close every epsilon set (seeded with the state itself) and every destination set.
            ts1 = C.imap (\s (TransitionNfst eps consume) -> TransitionNfst (epsilonClosure ts0 (SU.singleton s <> eps)) (DM.map (MLN.map (epsilonClosure ts0)) consume)) ts0
         in Nfst ts1 (SU.fromList final)