module Servant.Common.BaseUrl (
  
    BaseUrl (..)
  , InvalidBaseUrlException
  , Scheme (..)
  
  , parseBaseUrl
  , showBaseUrl
) where
import           Control.Monad.Catch (Exception, MonadThrow, throwM)
import           Data.Char (toLower)
import           Data.List
import           Data.Typeable
import           GHC.Generics
import           Network.URI hiding (path)
import           Safe
import           Text.Read
-- | URI scheme of a 'BaseUrl'.
data Scheme =
    Http  -- ^ @http://@
  | Https -- ^ @https://@
  deriving (Show, Eq, Ord, Generic)

-- | Scheme, host, port and path prefix of the server that requests are
-- sent to.
--
-- 'Eq' and 'Ord' ignore a single leading @/@ on 'baseUrlPath', so
-- @BaseUrl Http h p "/api"@ and @BaseUrl Http h p "api"@ are equal and
-- compare as 'EQ'.
data BaseUrl = BaseUrl
  { baseUrlScheme :: Scheme -- ^ URI scheme
  , baseUrlHost   :: String -- ^ host name, e.g. @"example.com"@
  , baseUrlPort   :: Int    -- ^ port number, e.g. @80@
  , baseUrlPath   :: String -- ^ path prefix, e.g. @"/api"@
  } deriving (Show, Generic)

-- Key shared by the 'Eq' and 'Ord' instances below. Deriving 'Ord' while
-- hand-writing 'Eq' made them disagree: paths differing only by a leading
-- slash were '==' yet 'compare' did not return 'EQ', which breaks the
-- law that @Data.Map@ / @Data.Set@ depend on. Exactly one leading @/@ is
-- dropped, matching the previous 'Eq' behaviour.
baseUrlKey :: BaseUrl -> (Scheme, String, Int, String)
baseUrlKey (BaseUrl sch host port urlPath) = (sch, host, port, dropSlash urlPath)
  where
    dropSlash ('/' : rest) = rest
    dropSlash rest         = rest

instance Eq BaseUrl where
    x == y = baseUrlKey x == baseUrlKey y

-- | Orders by scheme, then host, then port, then the path with one leading
-- @/@ removed; always consistent with '=='.
instance Ord BaseUrl where
    compare x y = compare (baseUrlKey x) (baseUrlKey y)
-- | Render a 'BaseUrl' as a URL string.
--
-- The port is left out when it is the scheme's default (80 for 'Http',
-- 443 for 'Https'). A @/@ is placed between the authority and the path
-- unless the path is empty or already begins with one.
showBaseUrl :: BaseUrl -> String
showBaseUrl (BaseUrl sch hostName portNum urlPath) =
  concat [prefix, "//", hostName, portPart, separator, urlPath]
  where
    prefix = case sch of
      Http  -> "http:"
      Https -> "https:"
    usesDefaultPort = (sch, portNum) `elem` [(Http, 80), (Https, 443)]
    portPart
      | usesDefaultPort = ""
      | otherwise       = ':' : show portNum
    separator
      | null urlPath || "/" `isPrefixOf` urlPath = ""
      | otherwise                                = "/"
-- | Raised by 'parseBaseUrl' when its input is not an acceptable base URL.
-- The 'String' payload is a human-readable error message.
data InvalidBaseUrlException
  = InvalidBaseUrlException String
  deriving (Typeable, Show)

-- | Uses only the class's default methods.
instance Exception InvalidBaseUrlException
-- | Parse a string into a 'BaseUrl'.
--
-- Accepted input is an absolute @http@ or @https@ URL with a host, an
-- optional port and an optional path. User info, a query or a fragment
-- cause rejection. The scheme is matched case-insensitively, as RFC 3986
-- section 3.1 requires, so @"HTTPS://example.com"@ is accepted. A missing
-- port defaults to 80 for 'Http' and 443 for 'Https'. One trailing @/@ is
-- stripped before parsing.
--
-- If the input contains no @://@ it is retried once with @http://@
-- prepended, so @"example.com:8080"@ yields
-- @BaseUrl Http "example.com" 8080 ""@.
--
-- Throws 'InvalidBaseUrlException' via 'throwM' when parsing fails. The
-- message quotes the caller's original input, not the internally
-- prefixed retry string.
parseBaseUrl :: MonadThrow m => String -> m BaseUrl
parseBaseUrl s = case parseAbsolute s of
  Just url -> return url
  Nothing
    | "://" `isInfixOf` s -> invalid
    -- No scheme separator at all: assume plain HTTP and try exactly once more.
    | otherwise -> maybe invalid return (parseAbsolute ("http://" ++ s))
 where
  invalid = throwM (InvalidBaseUrlException $ "Invalid base URL: " ++ s)

  -- Parse a URL that is expected to carry its own scheme. Any shape other
  -- than scheme + host + optional port + path (empty user info, query and
  -- fragment) yields Nothing through the failing pattern bind.
  parseAbsolute :: String -> Maybe BaseUrl
  parseAbsolute str = do
    URI rawScheme (Just (URIAuth "" host portPart)) urlPath "" ""
      <- parseURI (removeTrailingSlash str)
    (sch, defaultPort) <- schemeDefaults (map toLower rawScheme)
    port <- case portPart of
      ""           -> Just defaultPort
      ':' : digits -> readMaybe digits
      _            -> Nothing
    return (BaseUrl sch host port urlPath)

  -- Maps a lower-cased scheme, which network-uri reports with its trailing
  -- colon, to our 'Scheme' and that scheme's default port.
  schemeDefaults :: String -> Maybe (Scheme, Int)
  schemeDefaults "http:"  = Just (Http, 80)
  schemeDefaults "https:" = Just (Https, 443)
  schemeDefaults _        = Nothing

  -- 'init' is safe here: lastMay has just shown the string is non-empty.
  removeTrailingSlash str = case lastMay str of
    Just '/' -> init str
    _        -> str