module Network.AWS.Authentication (
runAction, isAmzHeader, preSignedURI,
S3Action(..),
mimeEncodeQP, mimeDecode
) where
import Control.Arrow
import qualified Control.Exception as E
import Data.Bits ((.&.))
import Data.Char (intToDigit, digitToInt, ord, chr, toLower)
import Data.List (sortBy, groupBy, intersperse, isInfixOf)
import Data.Maybe

import Codec.Binary.Base64 (encode, decode)
import qualified Codec.Binary.UTF8.String as US
import Codec.Utils (Octet)
import Control.Arrow.ArrowTree
import qualified Data.ByteString.Lazy.Char8 as L
import Data.ByteString.Char8 (pack, unpack)
import Data.HMAC
import Network.HTTP as HTTP hiding (simpleHTTP_)
import Network.HTTP.HandleStream (simpleHTTP_)
import Network.Stream (Result)
import Network.URI as URI
import System.Locale
import System.Time
import Text.Regex
import Text.XML.HXT.Arrow.ReadDocument
import Text.XML.HXT.Arrow.XmlArrow
import Text.XML.HXT.Arrow.XmlOptions
import Text.XML.HXT.Arrow.XmlState
import Text.XML.HXT.DOM.XmlKeywords

import Network.AWS.AWSConnection
import Network.AWS.AWSResult
import Network.AWS.ArrowUtils
-- | Everything needed to issue one S3 REST request.
data S3Action = S3Action
  { s3conn      :: AWSConnection       -- ^ credentials and endpoint (host, port)
  , s3bucket    :: String              -- ^ bucket name; may be empty (no bucket addressed)
  , s3object    :: String              -- ^ object key; the request path is @'/'@ followed by this
  , s3query     :: String              -- ^ query string, placed verbatim in the request URI
  , s3metadata  :: [(String, String)]  -- ^ extra headers (Content-Type/Length/MD5 map to standard headers)
  , s3body      :: L.ByteString        -- ^ request body
  , s3operation :: RequestMethod       -- ^ HTTP method
  } deriving (Show)
-- | Build the unsigned HTTP request for an action.
--
-- The request has:
--
-- * a relative URI whose path is @'/' : s3object@ and whose query is the
--   action's query string;
-- * a @Host@ header naming the (possibly bucket-qualified) host;
-- * the headers derived from the action's metadata;
-- * the action's body.
requestFromAction :: S3Action
                  -> HTTP.HTTPRequest L.ByteString
requestFromAction a =
  Request { rqURI     = relativeUri
          , rqMethod  = s3operation a
          , rqHeaders = hostHeader : headersFromAction a
          , rqBody    = s3body a
          }
  where
    hostHeader  = Header HdrHost (s3Hostname a)
    relativeUri = URI { uriScheme    = ""
                      , uriAuthority = Nothing
                      , uriPath      = '/' : s3object a
                      , uriQuery     = s3query a
                      , uriFragment  = ""
                      }
-- | Turn the action's metadata pairs into HTTP headers.
--
-- The keys @Content-Type@, @Content-Length@ and @Content-MD5@ (matched
-- exactly, case-sensitively) become the corresponding standard headers.
-- Every other key becomes a custom header. Its value is passed through
-- 'mimeEncodeQP', which MIME-encodes values that contain characters outside
-- printable ASCII.
headersFromAction :: S3Action
                  -> [Header]
headersFromAction = map toHeader . s3metadata
  where
    toHeader (k, v) = case k of
      "Content-Type"   -> Header HdrContentType v
      "Content-Length" -> Header HdrContentLength v
      "Content-MD5"    -> Header HdrContentMD5 v
      -- A wildcard, not a variable named 'otherwise' (which would shadow
      -- Prelude.otherwise).
      _                -> Header (HdrCustom k) (mimeEncodeQP v)
-- | Add a @Content-Length@ header equal to the body's length. A request that
-- already carries one is left unchanged.
addContentLengthHeader :: HTTP.HTTPRequest L.ByteString -> HTTP.HTTPRequest L.ByteString
addContentLengthHeader req =
  let bodyLength = L.length (rqBody req)
  in insertHeaderIfMissing HdrContentLength (show bodyLength) req
-- | Sign the request and insert an
-- @Authorization: AWS <access-key>:<signature>@ header.
addAuthenticationHeader :: S3Action
                        -> HTTP.HTTPRequest L.ByteString
                        -> HTTP.HTTPRequest L.ByteString
addAuthenticationHeader act req =
  let conn      = s3conn act
      sig       = makeSignature conn (stringToSign act req)
      authValue = concat ["AWS ", awsAccessKey conn, ":", sig]
  in insertHeader HdrAuthorization authValue req
-- | Base64 encoding of the HMAC-SHA1 of a string.
--
-- The key is the connection's secret key. Both the key and the message are
-- converted to octets with 'string2words' first.
makeSignature :: AWSConnection
              -> String
              -> String
makeSignature c = encode . hmac_sha1 secret . string2words
  where secret = string2words (awsSecretKey c)
-- | The string that gets signed. It is the concatenation of:
--
-- * the canonical standard headers,
-- * the canonical @x-amz-@ headers,
-- * the canonical resource.
stringToSign :: S3Action -> HTTP.HTTPRequest L.ByteString -> String
stringToSign a r =
  concat [ canonicalizeHeaders r
         , canonicalizeAmzHeaders r
         , canonicalizeResource a
         ]
-- | First part of the string to sign, one line each (every line ends in a
-- newline):
--
-- * the HTTP verb;
-- * the Content-MD5 header;
-- * the Content-Type header;
-- * the Expires header if present, otherwise the Date header.
--
-- A missing header contributes an empty line.
canonicalizeHeaders :: HTTP.HTTPRequest L.ByteString -> String
canonicalizeHeaders r =
  unlines [ show (rqMethod r)
          , headerOrEmpty HdrContentMD5
          , headerOrEmpty HdrContentType
          , fromMaybe (headerOrEmpty HdrDate) (findHeader HdrExpires r)
          ]
  where
    headerOrEmpty h = fromMaybe "" (findHeader h r)
-- | Canonical form of the request's @x-amz-@ headers, one
-- @name:value@ line per distinct lower-cased name.
--
-- The pipeline:
--
-- 1. keep only the amz headers;
-- 2. lower-case the names;
-- 3. sort by name;
-- 4. group equal names;
-- 5. join the values of each group with commas;
-- 6. render each pair.
canonicalizeAmzHeaders :: HTTP.HTTPRequest L.ByteString -> String
canonicalizeAmzHeaders r =
  unlines
    . map showHeader
    . combineHeaders
    . groupHeaders
    . sortHeaders
    . map headerToLCKeyValue
    . filter isAmzHeader
    $ rqHeaders r
-- | Render one canonical header as @name:value@.
--
-- The value is unfolded first and then trimmed at both ends.
showHeader :: (String, String) -> String
showHeader (k, v) = concat [k, ":", cleaned]
  where cleaned = removeLeadingTrailingWhitespace (fold_whitespace v)
-- | Unfold a folded (multi-line) header value.
--
-- Each line break that is followed by one or more spaces/tabs is replaced by
-- a single space. Folding in HTTP is CRLF (@\"\\r\\n\"@, not LF-CR)
-- followed by whitespace; a bare LF is accepted too. Line breaks that are
-- not followed by a space or tab are left alone.
fold_whitespace :: String -> String
fold_whitespace s = case s of
    '\r' : '\n' : c : rest | isLinearWs c -> ' ' : fold_whitespace (dropWhile isLinearWs rest)
    '\n' : c : rest        | isLinearWs c -> ' ' : fold_whitespace (dropWhile isLinearWs rest)
    c : rest                              -> c : fold_whitespace rest
    []                                    -> []
  where
    isLinearWs ch = ch == ' ' || ch == '\t'
-- | Strip whitespace (space, tab, CR, LF, form feed, vertical tab) from both
-- ends of the whole string. Interior whitespace, including embedded line
-- breaks, is kept.
--
-- Why no regexes: 'mkRegex' compiles in multi-line mode, so the old
-- @^\\s+@ / @\\s+$@ patterns also stripped whitespace around every embedded
-- newline. In addition, @\\s@ is not part of POSIX extended regex syntax.
removeLeadingTrailingWhitespace :: String -> String
removeLeadingTrailingWhitespace =
  reverse . dropWhile isWs . reverse . dropWhile isWs
  where
    isWs c = c `elem` " \t\n\r\f\v"
-- | Merge every group of same-named headers into one pair.
combineHeaders :: [[(String, String)]] -> [(String, String)]
combineHeaders groups = [ mergeSameHeaders g | g <- groups ]
-- | Collapse a group of pairs that share one name into a single pair.
--
-- The name is taken from the first element. The value is all the values,
-- in order, joined with commas.
--
-- 'groupBy' never produces empty groups, but the function is total anyway:
-- an empty group yields an empty name and an empty value, instead of a
-- pattern-match failure.
mergeSameHeaders :: [(String, String)] -> (String, String)
mergeSameHeaders []               = ("", "")
mergeSameHeaders h@((k, _) : _)   = (k, concat (intersperse "," (map snd h)))
-- | Group adjacent pairs that have the same name. Only neighbours are
-- grouped, so the input should already be sorted by name.
groupHeaders :: [(String, String)] -> [[(String, String)]]
groupHeaders = groupBy sameName
  where
    sameName (k1, _) (k2, _) = k1 == k2
-- | Sort pairs by name only. The sort is stable, so values for the same
-- name keep their original relative order.
sortHeaders :: [(String, String)] -> [(String, String)]
sortHeaders = sortBy byName
  where
    byName (k1, _) (k2, _) = compare k1 k2
-- | Convert a header to a (lower-cased name, value) pair.
--
-- The name is the 'show' of the header name, lower-cased.
headerToLCKeyValue :: Header -> (String, String)
headerToLCKeyValue (Header name value) = (lowercase (show name), value)
  where
    lowercase = map toLower
-- | True for custom headers whose name starts with 'amzHeader'
-- (@x-amz-@), compared case-insensitively. These are the headers that
-- 'canonicalizeAmzHeaders' includes in the string to sign.
--
-- The name is lower-cased before the prefix test for two reasons:
--
-- * HTTP header names are case-insensitive;
-- * 'headerToLCKeyValue' lower-cases names when building the canonical form.
--
-- So a header written as @X-Amz-Meta-Foo@ must be recognised too.
isAmzHeader :: Header -> Bool
isAmzHeader (Header (HdrCustom k) _) = isPrefix amzHeader (map toLower k)
isAmzHeader _                        = False
-- | @isPrefix p xs@ is True when @xs@ starts with @p@. The empty list is a
-- prefix of everything.
isPrefix :: Eq a => [a] -> [a] -> Bool
isPrefix []       _        = True
isPrefix _        []       = False
isPrefix (p : ps) (x : xs) = p == x && isPrefix ps xs
-- | Name prefix that identifies Amazon-specific headers, which are part of
-- the string to sign.
amzHeader :: String
amzHeader = "x-amz-"
-- | Canonical resource part of the string to sign. It is built from:
--
-- * @'/'@ plus the lower-cased bucket name, when a bucket is set;
-- * @'/'@ plus the object key;
-- * the first of the known sub-resources (@?versioning@, @?torrent@,
--   @?logging@, @?acl@, @?location@) that occurs anywhere in the query
--   string, if any.
canonicalizeResource :: S3Action -> String
canonicalizeResource a = bucketPart ++ ('/' : s3object a) ++ subresource
  where
    -- Guards instead of a catch-all variable named 'otherwise' (which
    -- shadowed Prelude.otherwise).
    bucketPart
      | null (s3bucket a) = ""
      | otherwise         = '/' : map toLower (s3bucket a)
    subresource = fromMaybe "" (listToMaybe matching)
    matching    = filter (`isInfixOf` s3query a)
                    ["?versioning", "?torrent", "?logging", "?acl", "?location"]
-- | Prepend a @Date@ header with the given value to the request's headers.
addDateToReq :: HTTP.HTTPRequest L.ByteString
             -> String
             -> HTTP.HTTPRequest L.ByteString
addDateToReq r d =
  let dateHeader = HTTP.Header HTTP.HdrDate d
  in r { HTTP.rqHeaders = dateHeader : HTTP.rqHeaders r }
-- | Prepend an @Expires@ header with the given value to the request's
-- headers.
addExpirationToReq :: HTTP.HTTPRequest L.ByteString -> String -> HTTP.HTTPRequest L.ByteString
addExpirationToReq r expiry = addHeaderToReq r (HTTP.Header HTTP.HdrExpires expiry)
-- | Prepend a header to the request's header list. Existing headers,
-- including ones with the same name, are kept.
addHeaderToReq :: HTTP.HTTPRequest L.ByteString -> Header -> HTTP.HTTPRequest L.ByteString
addHeaderToReq r h =
  let existing = HTTP.rqHeaders r
  in r { HTTP.rqHeaders = h : existing }
-- | Host name for the action, used both to connect and in the @Host@
-- header:
--
-- * @bucket.awsHost@ when a bucket is set;
-- * otherwise the connection's bare 'awsHost'.
s3Hostname :: S3Action -> String
s3Hostname a
  | null bucket = host
  | otherwise   = bucket ++ "." ++ host
  where
    bucket = s3bucket a
    host   = awsHost (s3conn a)
-- | The current time in UTC, formatted with 'rfc822DateFormat' under the
-- default locale. The time-zone name is forced to @GMT@.
httpCurrentDate :: IO String
httpCurrentDate = fmap render getClockTime
  where
    render t =
      formatCalendarTime defaultTimeLocale rfc822DateFormat
        ((toUTCTime t) { ctTZName = "GMT" })
-- | UTF-8 encode a 'String' into octets (via utf8-string's @encode@).
string2words :: String -> [Octet]
string2words s = US.encode s
-- | Execute an action against its own host (see 's3Hostname') and return
-- the response, or an error.
runAction :: S3Action -> IO (AWSResult (HTTPResponse L.ByteString))
runAction a = runAction' a $ s3Hostname a
-- | Execute an action against an explicit host. The steps are:
--
-- 1. stamp the request with the current date;
-- 2. add Content-Length and the Authorization signature;
-- 3. send it over a fresh TCP connection to @hostname@ on the connection's
--    port;
-- 4. interpret the response with 'createAWSResult'.
--
-- The connection is managed with 'E.bracket', so it is closed even if
-- sending the request throws; previously an exception leaked it.
runAction' :: S3Action -> String -> IO (AWSResult (HTTPResponse L.ByteString))
runAction' a hostname = do
  cd <- httpCurrentDate
  let aReq = addAuthenticationHeader a $
             addContentLengthHeader $
             addDateToReq (requestFromAction a) cd
  result <- E.bracket (openTCPConnection hostname (awsPort (s3conn a)))
                      close
                      (\c -> simpleHTTP_ c aReq)
  createAWSResult a result
-- | Build a query-string-authenticated ("pre-signed") @http:@ URL for the
-- action's bucket and object, using a path-style @/bucket/object@ path.
--
-- @e@ is used verbatim as the @Expires@ value, both in the URL and in the
-- signed string (presumably seconds since the Unix epoch — confirm). Any
-- existing 's3query' is kept, and the @AWSAccessKeyId@, @Expires@ and
-- @Signature@ parameters are appended to it.
--
-- NOTE(review): two things to confirm against callers:
--
-- * The signed resource is always @\/bucket\/object@. Sub-resources in
--   's3query' (e.g. @?torrent@) are not included in the string to sign.
-- * The object key is not percent-encoded in the path.
preSignedURI :: S3Action
             -> Integer
             -> URI
preSignedURI a e =
  URI { uriScheme    = "http:"
      , uriAuthority = Just (URIAuth "" (awsHost conn) (':' : show (awsPort conn)))
      , uriPath      = resourcePath
      , uriQuery     = queryPrefix ++ concat (intersperse "&" params)
      , uriFragment  = ""
      }
  where
    conn         = s3conn a
    expiry       = show e
    resourcePath = "/" ++ s3bucket a ++ "/" ++ s3object a
    toSign       = "GET\n\n\n" ++ expiry ++ "\n" ++ resourcePath
    queryPrefix
      | null (s3query a) = "?"
      | otherwise        = s3query a ++ "&"
    params = [ "AWSAccessKeyId=" ++ awsAccessKey conn
             , "Expires=" ++ expiry
             , "Signature=" ++ urlEncode (makeSignature conn toSign)
             ]
-- | Turn the raw HTTP result into an 'AWSResult'. The cases are:
--
-- * connection failure -> 'NetworkError';
-- * 2xx -> the response;
-- * 307 -> the request is re-run against the host named in the @Location@
--   header;
-- * 404 -> a fixed @NotFound@ error;
-- * anything else -> the S3 XML error document in the body is parsed.
--
-- A redirect without a Location header, or whose Location has no host,
-- becomes an 'AWSError'. Previously the missing-host case tried to connect
-- to an empty host name.
--
-- NOTE(review): redirects are followed without a hop limit.
createAWSResult :: S3Action -> Result (HTTPResponse L.ByteString) -> IO (AWSResult (HTTPResponse L.ByteString))
createAWSResult a b = case b of
    Left connErr -> return (Left (NetworkError connErr))
    Right rsp    -> classify rsp
  where
    classify rsp = case rspCode rsp of
      (2, _, _) -> return (Right rsp)
      (3, 0, 7) -> followRedirect (findHeader HdrLocation rsp)
      (4, 0, 4) -> return (Left $ AWSError "NotFound" "404 Not Found")
      _         -> fmap Left (parseRestErrorXML (L.unpack (rspBody rsp)))
    followRedirect Nothing =
      return (Left $ AWSError "Temporary Redirect" "Redirect without location header")
    followRedirect (Just loc) = case getHostname loc of
      ""   -> return (Left $ AWSError "Temporary Redirect" "Redirect location has no hostname")
      host -> runAction' a host
-- | Host (reg-name) part of an absolute URI string. Returns @""@ when the
-- string does not parse as a URI or has no authority.
getHostname :: String -> String
getHostname h = maybe "" uriRegName (parseURI h >>= uriAuthority)
-- | Parse an S3 REST error document (without validation) and return the
-- first error it contains. When no @Error@ element with Code/Message is
-- found, a @NoErrorInMsg@ error is returned.
--
-- Fixes the fallback message, which was missing a space
-- ("message bodydid not contain").
parseRestErrorXML :: String -> IO ReqError
parseRestErrorXML x = do
  errs <- runX (readString [withValidate no] x >>> processRestError)
  return $ case errs of
    []      -> AWSError "NoErrorInMsg"
                        ("HTTP Error condition, but message body "
                         ++ "did not contain error code.")
    err : _ -> err
-- | Arrow that finds every @Error@ element and builds an 'AWSError' from
-- the text of its @Code@ and @Message@ children.
processRestError =
      deep (isElem >>> hasName "Error")
  >>> split
  >>> first  (atTag "Code"    >>> text)
  >>> second (atTag "Message" >>> text)
  >>> unsplit AWSError
-- Shared signature of mimeEncodeQP (defined below) and mimeDecode.
mimeEncodeQP, mimeDecode :: String -> String
-- mimeDecode: decode a UTF-8 MIME encoded word.
--
-- * A value starting with @=?UTF-8?Q?@ is decoded as quoted-printable.
-- * A value starting with @=?UTF-8?B?@ is decoded as base64.
-- * Any other value is returned unchanged.
--
-- The last two characters (assumed to be the @?=@ terminator) are dropped
-- without being checked.
mimeDecode a
  | utf8qp `isPrefix` a  = mimeDecodeQP  (payloadAfter utf8qp)
  | utf8b64 `isPrefix` a = mimeDecodeB64 (payloadAfter utf8b64)
  | otherwise            = a
  where
    utf8qp  = "=?UTF-8?Q?"
    utf8b64 = "=?UTF-8?B?"
    payloadAfter prefix =
      let body = drop (length prefix) a
      in take (length body - 2) body
-- | Decode a Q-encoded payload: expand the @=XY@ escapes into bytes, then
-- decode those bytes as UTF-8.
mimeDecodeQP :: String -> String
mimeDecodeQP payload = US.decodeString (mimeDecodeQP' payload)
-- | Expand @=XY@ hexadecimal escapes (either case) into the character with
-- that code (0-255). All other characters are copied unchanged.
--
-- An @=@ that is not followed by two hex digits is now kept literally.
-- Previously it crashed inside 'digitToInt'.
mimeDecodeQP' :: String -> String
mimeDecodeQP' ('=' : hi : lo : rest)
  | isHex hi && isHex lo = chr (16 * digitToInt hi + digitToInt lo) : mimeDecodeQP' rest
  where
    isHex c = c `elem` "0123456789abcdefABCDEF"
mimeDecodeQP' (h : t) = h : mimeDecodeQP' t
mimeDecodeQP' []      = []
-- | Decode a base64 payload and interpret the bytes as UTF-8. Returns @""@
-- when the payload is not valid base64.
mimeDecodeB64 :: String -> String
mimeDecodeB64 = maybe "" US.decode . decode
-- | Encode a header value as a @=?UTF-8?Q?...?=@ encoded word when it
-- contains any 'reservedChar'; otherwise return it unchanged.
mimeEncodeQP s
  | any reservedChar s = concat ["=?UTF-8?Q?", mimeEncodeQP' (US.encodeString s), "?="]
  | otherwise          = s
-- | Q-encode an already UTF-8 encoded string (one byte per 'Char').
--
-- A character is written as @=XY@ (two lower-case hex digits) when either:
--
-- * it is a 'reservedChar' (outside printable ASCII), or
-- * it is one of the encoded-word specials @=@, @?@ or @_@.
--
-- All other characters are copied unchanged.
--
-- Escaping @=@ is required for a lossless round trip: 'mimeDecodeQP''
-- treats every @=@ followed by two hex digits as an escape, so a value
-- containing e.g. @=41@ used to come back as @A@. RFC 2047 also requires
-- @?@ and @_@ to be encoded in Q-encoded words.
mimeEncodeQP' :: String -> String
mimeEncodeQP' = concatMap encodeChar
  where
    encodeChar c
      | reservedChar c || c `elem` "=?_" = escape c
      | otherwise                        = [c]
    escape x =
      let y = ord x
      in [ '=', intToDigit ((y `div` 16) .&. 0xf), intToDigit (y .&. 0xf) ]
-- | True for characters outside printable ASCII, i.e. code points below
-- 0x20 (space) or above 0x7e (@~@).
reservedChar :: Char -> Bool
reservedChar x = not (printable (ord x))
  where
    printable n = n >= 0x20 && n <= 0x7e