{-# LANGUAGE GADTs #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module Language.LSP.Client.Decoding where

import Control.Exception (catch, throw, throwIO)
import Control.Monad (liftM2)
import Data.Aeson (Result (Error, Success), Value, decode)
import Data.Aeson.Types (parse)
import Data.ByteString.Lazy (LazyByteString)
import Data.ByteString.Lazy.Char8 qualified as LazyByteString
import Data.Dependent.Map (DMap)
import Data.Dependent.Map qualified as DMap
import Data.Functor
import Data.Functor.Const
import Data.Functor.Product (Product (Pair))
import Data.IxMap (IxMap)
import Data.IxMap qualified as IxMap
import Data.Maybe (fromJust, fromMaybe)
import Language.LSP.Client.Exceptions
import Language.LSP.Protocol.Message (FromServerMessage, FromServerMessage' (FromServerMess, FromServerRsp), LspId, MessageDirection (..), MessageKind (..), Method, SClientMethod, SMethod, TNotificationMessage, TResponseMessage (..), parseServerMessage)
import System.IO (Handle, hGetLine)
import System.IO.Error (isEOFError)
import Text.Read (readMaybe)
import Prelude hiding (id)

{- | Fetches the next message bytes based on
 the Content-Length header.

 Reads the header block with 'getHeaders', then reads exactly as many bytes
 from the handle as the @Content-Length@ header announces.

 Throws 'NoContentLengthHeader' when the header is absent, or when its value
 is not a non-negative decimal integer.
-}
getNextMessage :: Handle -> IO LazyByteString
getNextMessage h = do
    headers <- getHeaders h
    case lookup "Content-Length" headers >>= parseContentLength of
        Nothing -> throwIO NoContentLengthHeader
        Just size -> LazyByteString.hGet h size
  where
    -- 'hGetLine' strips only the '\n', so the value usually still ends in
    -- '\r'. Cut at the first '\r' rather than blindly dropping the last
    -- character: that would eat a digit from a server using bare '\n' endings.
    -- A negative size would make 'LazyByteString.hGet' fail, so reject it too.
    parseContentLength :: String -> Maybe Int
    parseContentLength raw = do
        size <- readMaybe (takeWhile (/= '\r') raw)
        if size >= 0 then Just size else Nothing

{- | Reads header lines from the handle until a line with no @':'@ in it,
 such as the blank (CR-only) line that ends an LSP header block.

 Each header line is split at its first @':'@. The name is everything before
 it, and the value is everything after it with leading spaces removed.
 Values keep any trailing @'\r'@, because 'hGetLine' only strips the @'\n'@.

 Throws 'UnexpectedServerTermination' if end-of-file is reached while
 reading. Any other 'IOError' is rethrown unchanged.
-}
getHeaders :: Handle -> IO [(String, String)]
getHeaders h = do
    line <- hGetLine h `catch` eofHandler
    case break (== ':') line of
        -- No separator: end of the header block.
        (_, []) -> pure []
        -- Tolerate "Name:value" as well as "Name: value"; a fixed
        -- 'drop 2' would eat the first character of the value in the
        -- former case.
        (name, _colon : rest) ->
            ((name, dropWhile (== ' ') rest) :) <$> getHeaders h
  where
    eofHandler :: IOError -> IO a
    eofHandler e
        | isEOFError e = throwIO UnexpectedServerTermination
        | otherwise = ioError e

-- | Callbacks for requests this client has sent, indexed by the 'LspId'
-- each request was sent with.
type RequestMap = IxMap LspId RequestCallback

-- | A 'RequestMap' with no registered callbacks.
emptyRequestMap :: RequestMap
emptyRequestMap = IxMap.emptyIxMap

-- | Everything needed to handle the response to one client request.
data RequestCallback (m :: Method 'ClientToServer 'Request) = RequestCallback
    { requestCallback :: TResponseMessage m -> IO ()
    -- ^ Action to run on the matching response.
    , requestMethod :: SClientMethod m
    -- ^ The request's method, used when decoding the typed response.
    }

-- | Notification callbacks, indexed by the method singleton they handle.
type NotificationMap = DMap SMethod NotificationCallback

-- | A 'NotificationMap' with no callbacks registered for any method.
emptyNotificationMap :: NotificationMap
emptyNotificationMap = DMap.empty

-- | Action to run for each server notification of method @m@.
newtype NotificationCallback (m :: Method 'ServerToClient 'Notification) = NotificationCallback
    { notificationCallback :: TNotificationMessage m -> IO ()
    -- ^ Handler invoked with the incoming notification.
    }

-- | Combining two callbacks yields one that passes the same notification to
-- the left-hand callback first, then to the right-hand one.
instance Semigroup (NotificationCallback m) where
    NotificationCallback runFirst <> NotificationCallback runSecond =
        NotificationCallback (liftM2 (*>) runFirst runSecond)

-- | The neutral callback ignores its notification and does nothing.
instance Monoid (NotificationCallback m) where
    mempty = NotificationCallback (\_ -> pure ())

-- | Registers the callback for a request that was sent with the given id.
--
-- If 'IxMap.insertIxMap' refuses the insertion (a duplicate id), the result
-- is an 'error' thunk with the message below. Because it is lazy, it only
-- fires when the returned map is forced.
updateRequestMap :: LspId m -> RequestCallback m -> RequestMap -> RequestMap
updateRequestMap lid callback reqMap =
    fromMaybe
        (error "updateRequestMap: duplicate key registration")
        (IxMap.insertIxMap lid callback reqMap)

-- | Adds a callback for the given notification method. If callbacks already
-- exist for that method, the two are combined with '(<>)'.
--
-- NOTE(review): per dependent-map's documentation, 'DMap.insertWith'' calls
-- the combining function as @f new old@. That means the newly added callback
-- appears to run before the existing ones; confirm that this order is what
-- "append" intends.
appendNotificationCallback :: SMethod m -> NotificationCallback m -> NotificationMap -> NotificationMap
appendNotificationCallback method callback notifications =
    DMap.insertWith' (<>) method callback notifications

-- | Drops every callback registered for the given notification method.
removeNotificationCallback :: SMethod (m :: Method 'ServerToClient 'Notification) -> NotificationMap -> NotificationMap
removeNotificationCallback method notifications = DMap.delete method notifications

{- | Decodes one server message from its raw JSON bytes.

 When the parser looks up a response id, 'IxMap.pickFromIxMap' supplies the
 pending 'RequestCallback' (whose method types the response) and the updated
 map. The result contains:

 * the decoded message;
 * the action to run for it: the request's callback for a response, a
   no-op otherwise;
 * the request map to continue with.

 Calls 'error' if the bytes are not valid JSON, or if they do not parse as a
 server message. The latter presumably includes a response whose id has no
 pending request, since the lookup then yields 'Nothing'.
-}
decodeFromServerMsg :: LazyByteString -> RequestMap -> ((FromServerMessage, IO ()), RequestMap)
decodeFromServerMsg bytes reqMap = unP $ parse p obj
  where
    -- The bytes come from an external process. Report what was received
    -- instead of failing with the uninformative "Maybe.fromJust: Nothing".
    obj :: Value
    obj =
        fromMaybe
            (error $ "decodeFromServerMsg: invalid JSON: " <> LazyByteString.unpack bytes)
            (decode bytes)
    -- Response ids are resolved against the pending-request map. The map
    -- returned by the lookup travels inside the parse result via 'Const'.
    p = parseServerMessage $ \(lid :: LspId m) ->
        let (maybeCallback, newMap) = IxMap.pickFromIxMap lid reqMap
         in maybeCallback <&> \c -> (c.requestMethod, Pair c (Const newMap))
    unP (Success (FromServerMess m msg)) = ((FromServerMess m msg, pure ()), reqMap)
    unP (Success (FromServerRsp (Pair c (Const newMap)) msg)) =
        ((FromServerRsp c.requestMethod msg, c.requestCallback msg), newMap)
    unP (Error e) = error $ "Error decoding " <> show obj <> " :" <> e