{-# LANGUAGE DefaultSignatures, FlexibleContexts, FlexibleInstances,
PolyKinds, ScopedTypeVariables, TypeApplications, TypeOperators #-}
module TreeSitter.Unmarshal
( parseByteString
, FieldName(..)
, Unmarshal(..)
, SymbolMatching(..)
, step
, push
, goto
, peekNode
, peekFieldName
, getFields
) where
import Control.Applicative
import Control.Effect hiding ((:+:))
import Control.Effect.Reader
import Control.Effect.Fail
import Control.Monad (void)
import Control.Monad.IO.Class
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.Map as Map
import qualified Data.Text as Text
import Data.Text.Encoding
import Foreign.C.String
import Foreign.Marshal.Alloc
import Foreign.Marshal.Utils
import Foreign.Ptr
import Foreign.Storable
import GHC.Generics
import TreeSitter.Cursor as TS
import TreeSitter.Language as TS
import TreeSitter.Node as TS
import TreeSitter.Parser as TS
import TreeSitter.Tree as TS
import Data.Proxy
import Prelude hiding (fail)
import Data.Maybe (fromMaybe, maybeToList)
import Data.List.NonEmpty (NonEmpty (..))
-- | Parse @bytestring@ with the given tree-sitter language and unmarshal the
-- tree's root node into a @t@, starting from a cursor created on that root.
--
-- Returns @Left@ with a message if the parser hands back a null tree pointer,
-- or if unmarshalling calls 'fail'.
parseByteString :: Unmarshal t => Ptr TS.Language -> ByteString -> IO (Either String t)
parseByteString language bytestring =
  withParser language $ \parser ->
    withParseTree parser bytestring $ \treePtr ->
      if treePtr == nullPtr
        then pure (Left "error: didn't get a root node")
        else
          withRootNode treePtr $ \rootPtr ->
            withCursor (castPtr rootPtr) $ \cursor ->
              -- The source bytes and the cursor are both supplied as readers;
              -- the cursor's current node (if any) is what gets unmarshalled.
              runM . runFail . runReader cursor . runReader bytestring $
                peekNode >>= unmarshalNodes . maybeToList
-- | Types that can be built from a list of tree-sitter nodes, with the source
-- 'ByteString' and the tree 'Cursor' available as readers.
class Unmarshal a where
  unmarshalNodes
    :: ( MonadFail m
       , Carrier sig m
       , Member (Reader ByteString) sig
       , Member (Reader (Ptr Cursor)) sig
       , MonadIO m
       )
    => [Node]
    -> m a
  -- The generic default accepts exactly one node: it resets the cursor onto
  -- that node, then builds the value through its 'Generic' representation.
  default unmarshalNodes
    :: ( MonadFail m
       , Carrier sig m
       , GUnmarshal (Rep a)
       , Generic a
       , Member (Reader ByteString) sig
       , Member (Reader (Ptr Cursor)) sig
       , MonadIO m
       )
    => [Node]
    -> m a
  unmarshalNodes nodes = case nodes of
    [] -> fail "expected a node but didn't get one"
    [x] -> goto (nodeTSNode x) *> (to <$> gunmarshalNode x)
    _ -> fail "expected a node but got multiple"
-- | A 'Text.Text' is the UTF-8 decoding of the source bytes spanned by exactly
-- one node.
--
-- Decoding uses 'decodeUtf8'' so that malformed UTF-8 is reported through
-- 'fail' (and so surfaces as a @Left@ from 'parseByteString'). It is not
-- thrown as an impure exception at whatever later point the text is forced.
instance Unmarshal Text.Text where
  unmarshalNodes [node] = do
    bytestring <- ask
    let start = fromIntegral (nodeStartByte node)
        end = fromIntegral (nodeEndByte node)
    case decodeUtf8' (slice start end bytestring) of
      Left err -> fail ("invalid UTF-8 in node text: " ++ show err)
      Right text -> pure text
  unmarshalNodes [] = fail "expected a node but didn't get one"
  unmarshalNodes _ = fail "expected a node but got multiple"
-- | No nodes means the value is absent. Otherwise the nodes are unmarshalled
-- as the wrapped type.
instance Unmarshal a => Unmarshal (Maybe a) where
  unmarshalNodes nodes
    | null nodes = pure Nothing
    | otherwise = fmap Just (unmarshalNodes nodes)
-- | Exactly one node is required. It becomes a 'Left' if its symbol matches
-- @a@; otherwise it becomes a 'Right' if its symbol matches @b@. If neither
-- matches, this fails with the combined failure description.
instance (Unmarshal a, Unmarshal b, SymbolMatching a, SymbolMatching b) => Unmarshal (Either a b) where
  unmarshalNodes nodes = case nodes of
    [node]
      | symbolMatch (Proxy @a) node -> Left <$> unmarshalNodes @a [node]
      | symbolMatch (Proxy @b) node -> Right <$> unmarshalNodes @b [node]
      | otherwise -> fail (showFailure (Proxy @(Either a b)) node)
    [] -> fail "expected a node of type (Either a b) but didn't get one"
    _ -> fail "expected a node of type (Either a b) but got multiple"
-- | Every node is unmarshalled on its own as one element, left to right. An
-- empty node list yields an empty list.
instance Unmarshal a => Unmarshal [a] where
  unmarshalNodes = traverse (\node -> unmarshalNodes [node])
-- | The first node becomes the head and the remaining nodes are unmarshalled
-- as a list. Zero nodes is a failure.
instance Unmarshal a => Unmarshal (NonEmpty a) where
  unmarshalNodes nodes = case nodes of
    [] -> fail "expected a node but didn't get one"
    n : ns -> (:|) <$> unmarshalNodes [n] <*> unmarshalNodes ns
-- | Types that can say, from a node alone, whether they can be unmarshalled
-- from it.
class SymbolMatching a where
  -- | Whether the node is one this type accepts.
  symbolMatch :: Proxy a -> Node -> Bool
  -- | Text used as the 'fail' message when a node is not accepted.
  showFailure :: Proxy a -> Node -> String

-- | An optional value accepts whatever the wrapped type accepts.
instance SymbolMatching a => SymbolMatching (Maybe a) where
  symbolMatch _ node = symbolMatch (Proxy :: Proxy a) node
  showFailure _ node = showFailure (Proxy :: Proxy a) node

-- | An 'Either' accepts a node if either side does. Its failure text
-- combines both sides' messages.
instance (SymbolMatching a, SymbolMatching b) => SymbolMatching (Either a b) where
  symbolMatch _ node = symbolMatch (Proxy :: Proxy a) node || symbolMatch (Proxy :: Proxy b) node
  showFailure _ node = showFailure (Proxy :: Proxy a) node `sep` showFailure (Proxy :: Proxy b) node

-- | A list accepts whatever its element type accepts.
instance SymbolMatching a => SymbolMatching [a] where
  symbolMatch _ node = symbolMatch (Proxy :: Proxy a) node
  showFailure _ node = showFailure (Proxy :: Proxy a) node
-- | A constructor with a single field accepts whatever that field's type
-- accepts.
instance SymbolMatching k => SymbolMatching (M1 C c (M1 S s (K1 i k))) where
  symbolMatch _ node = symbolMatch (Proxy :: Proxy k) node
  showFailure _ node = showFailure (Proxy :: Proxy k) node

-- | A generic sum accepts a node if either alternative does. Its failure
-- text combines both alternatives' messages.
instance (SymbolMatching f, SymbolMatching g) => SymbolMatching (f :+: g) where
  symbolMatch _ node = symbolMatch (Proxy :: Proxy f) node || symbolMatch (Proxy :: Proxy g) node
  showFailure _ node = showFailure (Proxy :: Proxy f) node `sep` showFailure (Proxy :: Proxy g) node
-- | Join two failure messages, separated by a full stop and a space.
sep :: String -> String -> String
sep a b = concat [a, ". ", b]
-- | Move the cursor to the next sibling of the current node. Returns the
-- 'Bool' from @ts_tree_cursor_goto_next_sibling@; 'getFields' treats
-- 'False' as "no more siblings".
step :: (Carrier sig m, Member (Reader (Ptr Cursor)) sig, MonadIO m) => m Bool
step = do
  cursor <- ask
  liftIO (ts_tree_cursor_goto_next_sibling cursor)
-- | Run an action with the cursor moved down to the first child of the
-- current node, then move the cursor back up to the parent.
--
-- If the current node has no children, the descent does not happen. The
-- action then runs with the cursor where it already was, and no ascent is
-- made afterwards. This keeps the cursor from ending up above the node it
-- started on. If the action itself fails, the cursor is not moved back.
push :: (Carrier sig m, Member (Reader (Ptr Cursor)) sig, MonadIO m) => m a -> m a
push m = do
  cursor <- ask
  descended <- liftIO (ts_tree_cursor_goto_first_child cursor)
  result <- m
  -- Only undo a descent that actually happened. Ascending after a failed
  -- descent would leave the cursor one level above where it started.
  if descended
    then void (liftIO (ts_tree_cursor_goto_parent cursor))
    else pure ()
  pure result
-- | Reset the cursor onto the given node. The node is marshalled into a
-- temporary pointer for the duration of the FFI call.
goto :: (Carrier sig m, Member (Reader (Ptr Cursor)) sig, MonadIO m) => TSNode -> m ()
goto node = ask >>= \cursor -> liftIO (with node (ts_tree_cursor_reset_p cursor))
-- | Read the node the cursor is currently on.
--
-- Returns 'Nothing' when @ts_tree_cursor_current_node_p@ reports that there
-- is no valid current node. Otherwise the raw 'TSNode' is converted into a
-- 'Node' via @ts_node_poke_p@ and read back.
peekNode :: (Carrier sig m, Member (Reader (Ptr Cursor)) sig, MonadIO m) => m (Maybe Node)
peekNode = do
  cursor <- ask
  liftIO . alloca $ \tsNodePtr -> do
    hasNode <- ts_tree_cursor_current_node_p cursor tsNodePtr
    if not hasNode
      then pure Nothing
      else Just <$> alloca (\nodePtr -> ts_node_poke_p tsNodePtr nodePtr *> peek nodePtr)
-- | The field name the cursor reports for its current node. Returns
-- 'Nothing' when the C API returns a null pointer.
peekFieldName :: (Carrier sig m, Member (Reader (Ptr Cursor)) sig, MonadIO m) => m (Maybe FieldName)
peekFieldName = do
  cursor <- ask
  liftIO $ do
    namePtr <- ts_tree_cursor_current_field_name cursor
    if namePtr == nullPtr
      then pure Nothing
      else Just . FieldName <$> peekCString namePtr
-- | Collect the current node and all of its following siblings, grouped by
-- the field name the cursor reports for each one.
--
-- A node with no field name goes under @FieldName "extraChildren"@ if it is a
-- named node ('nodeIsNamed' non-zero); otherwise it is dropped. Within each
-- group, nodes are returned in the order the cursor visited them.
getFields :: (Carrier sig m, Member (Reader (Ptr Cursor)) sig, MonadIO m) => m (Map.Map FieldName [Node])
getFields = go Map.empty
  where
    go fs = do
      node <- peekNode
      case node of
        Nothing -> pure (inVisitOrder fs)
        Just node' -> do
          fieldName <- peekFieldName
          keepGoing <- step
          let fs' = case fieldName of
                Just fieldName' -> Map.insertWith (++) fieldName' [node'] fs
                Nothing
                  | nodeIsNamed node' /= 0 -> Map.insertWith (++) (FieldName "extraChildren") [node'] fs
                  | otherwise -> fs
          if keepGoing then go fs' else pure (inVisitOrder fs')
    -- @Map.insertWith (++) k [new] fs@ computes @[new] ++ old@, a cheap
    -- prepend. Each group is therefore built newest-first and must be
    -- reversed once to restore visit order (returning it as built reversed
    -- multi-node fields).
    inVisitOrder = Map.map reverse
-- | @slice start end bs@ is the bytes of @bs@ from offset @start@ (inclusive)
-- up to offset @end@ (exclusive). Out-of-range offsets are clamped by
-- 'B.drop' and 'B.take'.
slice :: Int -> Int -> ByteString -> ByteString
slice start end bytes = B.take (end - start) (B.drop start bytes)
-- | A tree-sitter field name, as reported by the cursor. 'getFields' also
-- uses @FieldName "extraChildren"@ for named children that have no field.
newtype FieldName = FieldName
  { getFieldName :: String
  }
  deriving (Show, Eq, Ord)
-- | Generic unmarshalling of a 'Generic' representation from the node the
-- value is being built from.
class GUnmarshal f where
  gunmarshalNode
    :: ( MonadFail m
       , Carrier sig m
       , Member (Reader ByteString) sig
       , Member (Reader (Ptr Cursor)) sig
       , MonadIO m
       )
    => Node
    -> m (f a)

-- Datatype metadata adds nothing to read: unwrap and continue.
instance GUnmarshal f => GUnmarshal (M1 D c f) where
  gunmarshalNode = fmap M1 . gunmarshalNode

-- Constructor metadata adds nothing to read: unwrap and continue.
instance GUnmarshal f => GUnmarshal (M1 C c f) where
  gunmarshalNode = fmap M1 . gunmarshalNode

-- A nullary constructor needs nothing from the node.
instance GUnmarshal U1 where
  gunmarshalNode = const (pure U1)
-- | A 'Text.Text' field is unmarshalled directly from the node being built
-- from, without descending into its children.
--
-- This is the more specific of the two selector instances and is meant to win
-- over the general one. OVERLAPPING states that intent; the OVERLAPPABLE
-- pragma previously used here belongs on the general instance.
instance {-# OVERLAPPING #-} GUnmarshal (M1 S s (K1 c Text.Text)) where
  gunmarshalNode node = M1 . K1 <$> unmarshalNodes [node]
-- | A single non-'Text.Text' field: descend into the children, group them by
-- field name, and read the field from that map by its selector name.
instance {-# OVERLAPPABLE #-} (Selector s, Unmarshal k) => GUnmarshal (M1 S s (K1 c k)) where
  gunmarshalNode _ = push (getFields >>= gunmarshalProductNode)

-- | A sum delegates to 'gunmarshalSumNode', which picks the alternative
-- whose symbol matches the node.
instance (GUnmarshalSum f, GUnmarshalSum g, SymbolMatching f, SymbolMatching g) => GUnmarshal (f :+: g) where
  gunmarshalNode node = gunmarshalSumNode @(f :+: g) node

-- | A product descends into the children, groups them by field name, and
-- fills in each field from that map.
instance (GUnmarshalProduct f, GUnmarshalProduct g) => GUnmarshal (f :*: g) where
  gunmarshalNode _ = push (gunmarshalProductNode @(f :*: g) =<< getFields)
-- | Generic unmarshalling of one alternative of a sum type from a node.
class GUnmarshalSum f where
  gunmarshalSumNode
    :: (MonadFail m, Carrier sig m, Member (Reader ByteString) sig, Member (Reader (Ptr Cursor)) sig, MonadIO m)
    => Node
    -> m (f a)

-- | A single-field constructor alternative: the whole node is unmarshalled as
-- the field's type.
instance (Unmarshal k, SymbolMatching k) => GUnmarshalSum (M1 C c (M1 S s (K1 i k))) where
  gunmarshalSumNode node = fmap (M1 . M1 . K1) (unmarshalNodes [node])
-- | Try the left alternative first, then the right, choosing whichever's
-- symbol matches the node. Fails with both alternatives' messages if neither
-- matches.
instance (GUnmarshalSum f, GUnmarshalSum g, SymbolMatching f, SymbolMatching g) => GUnmarshalSum (f :+: g) where
  gunmarshalSumNode node
    | symbolMatch (Proxy @f) node = L1 <$> gunmarshalSumNode @f node
    | symbolMatch (Proxy @g) node = R1 <$> gunmarshalSumNode @g node
    | otherwise = fail (showFailure (Proxy @f) node `sep` showFailure (Proxy @g) node)
-- | Generic unmarshalling of a product's fields from the children of a node,
-- grouped by field name.
class GUnmarshalProduct f where
  gunmarshalProductNode
    :: (MonadFail m, Carrier sig m, Member (Reader ByteString) sig, Member (Reader (Ptr Cursor)) sig, MonadIO m)
    => Map.Map FieldName [Node]
    -> m (f a)

-- | Both halves of a product read from the same field map, left half first.
instance (GUnmarshalProduct f, GUnmarshalProduct g) => GUnmarshalProduct (f :*: g) where
  gunmarshalProductNode fields = do
    lhs <- gunmarshalProductNode @f fields
    rhs <- gunmarshalProductNode @g fields
    pure (lhs :*: rhs)

-- | A record field is unmarshalled from the nodes stored under its selector
-- name. A field with no entry in the map is given an empty node list.
instance (Unmarshal k, Selector c) => GUnmarshalProduct (M1 S c (K1 i k)) where
  gunmarshalProductNode fields =
    -- The argument to 'selName' only fixes its type via the type application.
    let key = FieldName (selName @c undefined)
        nodes = Map.findWithDefault [] key fields
     in M1 . K1 <$> unmarshalNodes nodes