{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module HaskellWorks.Data.BalancedParens.ParensSeq
( ParensSeq(..)
, mempty
, size
, fromWord64s
, fromPartialWord64s
, toPartialWord64s
, fromBools
, toBools
, splitAt
, take
, drop
, firstChild
, nextSibling
, (<|), (><), (|>)
) where
import Data.Coerce
import Data.Foldable
import Data.Word
import HaskellWorks.Data.BalancedParens.Internal.ParensSeq (Elem (Elem), ParensSeq (ParensSeq), ParensSeqFt)
import HaskellWorks.Data.Bits.BitWise
import HaskellWorks.Data.FingerTree (ViewL (..), ViewR (..), (<|), (><), (|>))
import HaskellWorks.Data.Positioning
import Prelude hiding (drop, max, min, splitAt, take)
import qualified Data.List as L
import qualified HaskellWorks.Data.BalancedParens.Internal.ParensSeq as PS
import qualified HaskellWorks.Data.BalancedParens.Internal.Word as W
import qualified HaskellWorks.Data.FingerTree as FT
empty :: ParensSeq
empty = ParensSeq FT.empty
size :: ParensSeq -> Count
size (ParensSeq parens) = PS.size (FT.measure parens :: PS.Measure)
fromWord64s :: Traversable f => f Word64 -> ParensSeq
fromWord64s = foldl go empty
where go :: ParensSeq -> Word64 -> ParensSeq
go ps w = ParensSeq (PS.parens ps |> Elem w 64)
fromPartialWord64s :: Traversable f => f (Word64, Count) -> ParensSeq
fromPartialWord64s = foldl go empty
where go :: ParensSeq -> (Word64, Count) -> ParensSeq
go ps (w, n) = ParensSeq (PS.parens ps |> Elem w n)
toPartialWord64s :: ParensSeq -> [(Word64, Count)]
toPartialWord64s = L.unfoldr go . coerce
where go :: ParensSeqFt -> Maybe ((Word64, Count), ParensSeqFt)
go ft = case FT.viewl ft of
PS.Elem w n :< rt -> Just ((w, coerce n), rt)
FT.EmptyL -> Nothing
fromBools :: [Bool] -> ParensSeq
fromBools = go empty
where go :: ParensSeq -> [Bool] -> ParensSeq
go (ParensSeq ps) (b:bs) = case FT.viewr ps of
FT.EmptyR -> go (ParensSeq (FT.singleton (Elem b' 1))) bs
lt :> Elem w n ->
let newPs = if n >= 64
then ps |> Elem b' 1
else lt |> Elem (w .|. (b' .<. fromIntegral n)) (n + 1)
in go (ParensSeq newPs) bs
where b' = if b then 1 else 0 :: Word64
go ps [] = ps
toBools :: ParensSeq -> [Bool]
toBools ps = toBoolsDiff ps []
toBoolsDiff :: ParensSeq -> [Bool] -> [Bool]
toBoolsDiff ps = mconcat (fmap go (toPartialWord64s ps))
where go :: (Word64, Count) -> [Bool] -> [Bool]
go (w, n) = W.partialToBoolsDiff (fromIntegral n) w
drop :: Count -> ParensSeq -> ParensSeq
drop n ps = snd (splitAt n ps)
take :: Count -> ParensSeq -> ParensSeq
take n ps = fst (splitAt n ps)
splitAt :: Count -> ParensSeq -> (ParensSeq, ParensSeq)
splitAt n (ParensSeq parens) = case FT.split (PS.atSizeBelowZero n) parens of
(lt, rt) -> let
n' = n - PS.size (FT.measure lt :: PS.Measure)
u = 64 - n'
in case FT.viewl rt of
PS.Elem w nw :< rrt -> if n' >= nw
then (ParensSeq lt , ParensSeq rrt )
else (ParensSeq (lt |> PS.Elem ((w .<. u) .>. u) n'), ParensSeq (PS.Elem (w .>. n') (nw - n') <| rrt))
FT.EmptyL -> (ParensSeq lt, ParensSeq FT.empty)
firstChild :: ParensSeq -> Count -> Maybe Count
firstChild ps n = case FT.viewl ft of
PS.Elem w nw :< rt -> if nw >= 2
then case w .&. 3 of
3 -> Just (n + 1)
_ -> Nothing
else if nw >= 1
then case w .&. 1 of
1 -> case FT.viewl rt of
PS.Elem w' nw' :< _ -> if nw' >= 1
then case w' .&. 1 of
1 -> Just (n + 1)
_ -> Nothing
else Nothing
FT.EmptyL -> Nothing
_ -> Nothing
else Nothing
FT.EmptyL -> Nothing
where ParensSeq ft = drop (n - 1) ps
nextSibling :: ParensSeq -> Count -> Maybe Count
nextSibling (ParensSeq ps) n = do
let (lt0, rt0) = PS.ftSplit (PS.atSizeBelowZero (n - 1)) ps
_ <- case FT.viewl rt0 of
PS.Elem w nw :< _ -> if nw >= 1 && w .&. 1 == 1 then Just () else Nothing
FT.EmptyL -> Nothing
let (lt1, rt1) = PS.ftSplit (PS.atSizeBelowZero 1) rt0
let (lt2, rt2) = PS.ftSplit PS.atMinZero rt1
case FT.viewl rt2 of
PS.Elem w nw :< _ -> if nw >= 1 && w .&. 1 == 0 then Just () else Nothing
FT.EmptyL -> Nothing
let (lt3, rt3) = PS.ftSplit (PS.atSizeBelowZero 1) rt2
case FT.viewl rt3 of
PS.Elem w nw :< _ -> if nw >= 1 && w .&. 1 == 1 then Just () else Nothing
FT.EmptyL -> Nothing
return $ 1
+ PS.size (FT.measure lt0 :: PS.Measure)
+ PS.size (FT.measure lt1 :: PS.Measure)
+ PS.size (FT.measure lt2 :: PS.Measure)
+ PS.size (FT.measure lt3 :: PS.Measure)