{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE NamedFieldPuns #-}
-- | This module is intended to be imported qualified
--
-- @
-- import qualified Ldap.Client as Ldap
-- @
module Ldap.Client
  ( with
  , Host(..)
  , defaultTlsSettings
  , insecureTlsSettings
  , PortNumber
  , Ldap
  , LdapError(..)
  , ResponseError(..)
  , Type.ResultCode(..)
    -- * Bind
  , Password(..)
  , bind
  , externalBind
    -- * Search
  , search
  , SearchEntry(..)
    -- ** Search modifiers
  , Search
  , Mod
  , Type.Scope(..)
  , scope
  , size
  , time
  , typesOnly
  , Type.DerefAliases(..)
  , derefAliases
  , Filter(..)
    -- * Modify
  , modify
  , Operation(..)
    -- * Add
  , add
    -- * Delete
  , delete
    -- * ModifyDn
  , RelativeDn(..)
  , modifyDn
    -- * Compare
  , compare
    -- * Extended
  , Oid(..)
  , extended
    -- * Miscellanous
  , Dn(..)
  , Attr(..)
  , AttrValue
  , AttrList
    -- * Re-exports
  , NonEmpty
  ) where

#if __GLASGOW_HASKELL__ < 710
import           Control.Applicative ((<$>))
#endif
import qualified Control.Concurrent.Async as Async
import           Control.Concurrent.STM (atomically, throwSTM)
import           Control.Concurrent.STM.TMVar (putTMVar)
import           Control.Concurrent.STM.TQueue (TQueue, newTQueueIO, writeTQueue, readTQueue)
import           Control.Exception (Exception, Handler(..), bracket, throwIO, catch, catches)
import           Control.Monad (forever)
import qualified Data.ASN1.BinaryEncoding as Asn1
import qualified Data.ASN1.Encoding as Asn1
import qualified Data.ASN1.Error as Asn1
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as ByteString.Lazy
import           Data.Foldable (asum)
import           Data.Function (fix)
import           Data.List.NonEmpty (NonEmpty((:|)))
import qualified Data.Map.Strict as Map
import           Data.Monoid (Endo(appEndo))
import           Data.Text (Text)
#if __GLASGOW_HASKELL__ < 710
import           Data.Traversable (traverse)
#endif
import           Data.Typeable (Typeable)
import           Network.Connection (Connection)
import qualified Network.Connection as Conn
import           Prelude hiding (compare)
import qualified System.IO.Error as IO

import           Ldap.Asn1.ToAsn1 (ToAsn1(toAsn1))
import           Ldap.Asn1.FromAsn1 (FromAsn1, parseAsn1)
import qualified Ldap.Asn1.Type as Type
import           Ldap.Client.Internal
import           Ldap.Client.Bind (Password(..), bind, externalBind)
import           Ldap.Client.Search
  ( search
  , Search
  , Mod
  , scope
  , size
  , time
  , typesOnly
  , derefAliases
  , Filter(..)
  , SearchEntry(..)
  )
import           Ldap.Client.Modify (Operation(..), modify, RelativeDn(..), modifyDn)
import           Ldap.Client.Add (add)
import           Ldap.Client.Delete (delete)
import           Ldap.Client.Compare (compare)
import           Ldap.Client.Extended (Oid(..), extended, noticeOfDisconnectionOid)

{-# ANN module ("HLint: ignore Use first" :: String) #-}


newLdap :: IO Ldap
newLdap = Ldap
  <$> newTQueueIO

-- | Various failures that can happen when working with LDAP.
data LdapError =
    IOError !IOError             -- ^ Network failure.
  | ParseError !Asn1.ASN1Error   -- ^ Invalid ASN.1 data received from the server.
  | ResponseError !ResponseError -- ^ An LDAP operation failed.
  | DisconnectError !Disconnect  -- ^ Notice of Disconnection has been received.
    deriving (Show, Eq)

newtype WrappedIOError = WrappedIOError IOError
    deriving (Show, Eq, Typeable)

instance Exception WrappedIOError

data Disconnect = Disconnect !Type.ResultCode !Dn !Text
    deriving (Show, Eq, Typeable)

instance Exception Disconnect

-- | The entrypoint into LDAP.
--
-- It catches all LDAP-related exceptions.
with :: Host -> PortNumber -> (Ldap -> IO a) -> IO (Either LdapError a)
with host port f = do
  context <- Conn.initConnectionContext
  bracket (Conn.connectTo context params) Conn.connectionClose (\conn ->
    bracket newLdap unbindAsync (\l -> do
      inq  <- newTQueueIO
      outq <- newTQueueIO
      as   <- traverse Async.async
        [ input inq conn
        , output outq conn
        , dispatch l inq outq
        , f l
        ]
      fmap (Right . snd) (Async.waitAnyCancel as)))
 `catches`
  [ Handler (\(WrappedIOError e) -> return (Left (IOError e)))
  , Handler (return . Left . ParseError)
  , Handler (return . Left . ResponseError)
  ]
 where
  params = Conn.ConnectionParams
    { Conn.connectionHostname =
        case host of
          Plain h -> h
          Tls   h _ -> h
    , Conn.connectionPort = port
    , Conn.connectionUseSecure =
        case host of
          Plain  _ -> Nothing
          Tls _ settings -> pure settings
    , Conn.connectionUseSocks = Nothing
    }

defaultTlsSettings :: Conn.TLSSettings
defaultTlsSettings = Conn.TLSSettingsSimple
  { Conn.settingDisableCertificateValidation = False
  , Conn.settingDisableSession = False
  , Conn.settingUseServerName = False
  }

insecureTlsSettings :: Conn.TLSSettings
insecureTlsSettings = Conn.TLSSettingsSimple
  { Conn.settingDisableCertificateValidation = True
  , Conn.settingDisableSession = False
  , Conn.settingUseServerName = False
  }

input :: FromAsn1 a => TQueue a -> Connection -> IO b
input inq conn = wrap . flip fix [] $ \loop chunks -> do
  chunk <- Conn.connectionGet conn 8192
  case ByteString.length chunk of
    0 -> throwIO (IO.mkIOError IO.eofErrorType "Ldap.Client.input" Nothing Nothing)
    _ -> do
      let chunks' = chunk : chunks
      case Asn1.decodeASN1 Asn1.BER (ByteString.Lazy.fromChunks (reverse chunks')) of
        Left  Asn1.ParsingPartial
                   -> loop chunks'
        Left  e    -> throwIO e
        Right asn1 -> do
          flip fix asn1 $ \loop' asn1' ->
            case parseAsn1 asn1' of
              Nothing -> return ()
              Just (asn1'', a) -> do
                atomically (writeTQueue inq a)
                loop' asn1''
          loop []

output :: ToAsn1 a => TQueue a -> Connection -> IO b
output out conn = wrap . forever $ do
  msg <- atomically (readTQueue out)
  Conn.connectionPut conn (encode (toAsn1 msg))
 where
  encode x = Asn1.encodeASN1' Asn1.DER (appEndo x [])

dispatch
  :: Ldap
  -> TQueue (Type.LdapMessage Type.ProtocolServerOp)
  -> TQueue (Type.LdapMessage Request)
  -> IO a
dispatch Ldap { client } inq outq =
  flip fix (Map.empty, 1) $ \loop (!req, !counter) ->
    loop =<< atomically (asum
      [ do New new var <- readTQueue client
           writeTQueue outq (Type.LdapMessage (Type.Id counter) new Nothing)
           return (Map.insert (Type.Id counter) ([], var) req, counter + 1)
      , do Type.LdapMessage mid op _
               <- readTQueue inq
           res <- case op of
             Type.BindResponse {}          -> done mid op req
             Type.SearchResultEntry {}     -> saveUp mid op req
             Type.SearchResultReference {} -> return req
             Type.SearchResultDone {}      -> done mid op req
             Type.ModifyResponse {}        -> done mid op req
             Type.AddResponse {}           -> done mid op req
             Type.DeleteResponse {}        -> done mid op req
             Type.ModifyDnResponse {}      -> done mid op req
             Type.CompareResponse {}       -> done mid op req
             Type.ExtendedResponse {}      -> probablyDisconnect mid op req
             Type.IntermediateResponse {}  -> saveUp mid op req
           return (res, counter)
      ])
 where
  saveUp mid op res =
    return (Map.adjust (\(stack, var) -> (op : stack, var)) mid res)

  done mid op req =
    case Map.lookup mid req of
      Nothing -> return req
      Just (stack, var) -> do
        putTMVar var (op :| stack)
        return (Map.delete mid req)

  probablyDisconnect (Type.Id 0)
                     (Type.ExtendedResponse
                       (Type.LdapResult code
                                        (Type.LdapDn (Type.LdapString dn))
                                        (Type.LdapString reason)
                                        _)
                       moid _)
                     req =
    case moid of
      Just (Type.LdapOid oid)
        | Oid oid == noticeOfDisconnectionOid -> throwSTM (Disconnect code (Dn dn) reason)
      _ -> return req
  probablyDisconnect mid op req = done mid op req

wrap :: IO a -> IO a
wrap m = m `catch` (throwIO . WrappedIOError)