module Network.Protocol.OAuth.Request (Request(..),HTTPMethod(..),Parameter,PercentEncoding(encode,encodes,decode,decodes),append_param,show_url,show_oauthurl,show_oauthheader,show_urlencoded,read_urlencoded,(>>+)) where
import Data.Bits as B
import qualified Data.ByteString.Lazy as B1
import qualified Data.ByteString.Lazy.UTF8 as B2
import qualified Data.ByteString.Lazy.Char8 as B3
import qualified Data.Word as W
import qualified Data.Char as C
import qualified Data.List as L
-- | A request parameter: a name paired with an optional value. A parameter
-- whose value is 'Nothing' is serialised as @name=@ with nothing after the
-- equals sign.
type Parameter = (String, Maybe String)
-- | HTTP request methods. When a URL is rendered, only 'GET' requests carry
-- their parameters in the query string.
data HTTPMethod
  = GET
  | POST
  | DELETE
  | PUT
  deriving (Show, Read, Eq)
-- | Values that can be percent-encoded to bytes and decoded back.
--
-- 'encode' and 'decode' have no default definitions. 'encodes' and 'decodes'
-- default to repeated use of 'encode' / 'decode'.
class PercentEncoding a where
  -- | Encode a single value.
  encode :: a -> B1.ByteString

  -- | Encode a list of values by concatenating the encoding of each one.
  encodes :: [a] -> B1.ByteString
  encodes xs = B1.concat (map encode xs)

  -- | Decode one value from the front of the input. Returns the value and
  -- the remaining, not-yet-decoded input.
  decode :: B1.ByteString -> (a, B1.ByteString)

  -- | Decode values one after another until the input is empty. The default
  -- only terminates if every 'decode' call consumes at least one byte.
  decodes :: B1.ByteString -> [a]
  decodes = L.unfoldr step'
    where
      step' bs
        | B1.null bs = Nothing
        | otherwise = Just (decode bs)
-- | Description of an HTTP request. There is a single constructor, 'HTTP'.
data Request = HTTP
  { ssl    :: Bool         -- ^ render the URL with @https://@ instead of @http://@
  , method :: HTTPMethod   -- ^ request method
  , host   :: String       -- ^ host name, written verbatim into the URL
  , port   :: Int          -- ^ port; left out of the URL when it is 443 (https) or 80 (http)
  , path   :: String       -- ^ URL path; each segment is percent-encoded when rendered
  , params :: [Parameter]  -- ^ parameters, most recently appended first
  }
  deriving (Show, Read, Eq)
-- | Add one parameter to a request by prepending it to 'params'. The order of
-- this list does not affect serialisation, because the urlencoded and header
-- renderers sort the parameters.
append_param :: Request -> (String,Maybe String) -> Request
append_param r kv = r { params = kv : params r }
-- | Parse a @k1=v1&k2=v2@ byte string into parameters. Keys and values are
-- percent-decoded with 'decodes'.
--
-- A segment with no @=@, or with nothing after it, yields a 'Nothing' value.
-- Empty segments (from @&&@ or a leading/trailing @&@) are skipped. Without
-- this, they would show up as a spurious parameter with an empty name.
--
-- Only percent-escapes are decoded; @+@ is left as a literal plus sign.
read_urlencoded :: B1.ByteString -> [Parameter]
read_urlencoded = map (param' . keyval') . filter (not . B1.null) . B1.split 0x26
  where
    -- Split at the first '=' (0x3d) and drop the '=' itself.
    keyval' seg = let (k, rest) = B1.break (== 0x3d) seg
                  in (k, B1.drop 1 rest)
    param' (k, v)
      | B1.null v = (decodes k, Nothing)
      | otherwise = (decodes k, Just (decodes v))
-- | Render a request as an absolute URL.
--
-- The URL is built from these parts:
--
-- * The scheme, taken from 'ssl'.
-- * The host.
-- * The port, left out when it is the scheme's default (443 or 80).
-- * The path, with each segment percent-encoded.
-- * For 'GET' requests that have parameters, a @?@-prefixed query string.
show_url :: Request -> B1.ByteString
show_url (HTTP s m h p0 p1 ps) = B1.concat [endpoint', path', query']
  where
    (scheme', defaultPort') = if s then ("https://", 443) else ("http://", 80)
    portSuffix' = if p0 == defaultPort' then "" else ':' : show p0
    endpoint' = B3.pack (scheme' ++ h ++ portSuffix')
    path' = B1.cons 0x2f (B1.intercalate (B1.singleton 0x2f) (map encodes (_path_comp p1)))
    query'
      | m == GET && not (null ps) = B1.cons 0x3f (show_urlencoded ps)
      | otherwise = B1.empty
-- | Render the request URL via 'show_url' after removing every parameter
-- whose name starts with @oauth_@.
show_oauthurl :: Request -> B1.ByteString
show_oauthurl req = show_url (req { params = kept' })
  where
    kept' = [kv | kv@(k, _) <- params req, not ("oauth_" `L.isPrefixOf` k)]
-- | Render an OAuth @Authorization@ header value.
--
-- The result is @OAuth realm="<realm>"@. If the request has any parameters
-- whose names start with @oauth_@, a comma and those parameters follow as
-- comma-separated @name="value"@ pairs, with names and values
-- percent-encoded.
--
-- NOTE(review): the realm is inserted verbatim, so a @"@ inside it would
-- break the quoting.
show_oauthheader :: String
                 -> Request
                 -> B1.ByteString
show_oauthheader realm req
  | B1.null fields' = realm'
  | otherwise = B1.concat [realm', B1.singleton 0x2c, fields']
  where
    realm' = B3.pack ("OAuth realm=\"" ++ realm ++ "\"")
    oauthOnly' = [kv | kv@(k, _) <- params req, "oauth_" `L.isPrefixOf` k]
    -- Wrap each encoded value in double quotes (0x22).
    quoted' v = B1.concat [B1.singleton 0x22, encodes v, B1.singleton 0x22]
    fields' = _urlencode quoted' 0x2c oauthOnly'
-- | Serialise parameters as a sorted, percent-encoded @k=v&k=v@ string.
show_urlencoded :: [Parameter] -> B1.ByteString
show_urlencoded ps = _urlencode encodes 0x26 ps
-- | Infix alias for 'append_param', convenient for chaining:
-- @req >>+ ("a", Just "1") >>+ ("b", Nothing)@.
(>>+) :: Request -> (String,Maybe String) -> Request
req >>+ kv = append_param req kv
-- | Percent-encoding of characters. A character is first UTF-8 encoded. Each
-- byte outside the unreserved set (@A-Z a-z 0-9 - . _ ~@) is then written as
-- @%XY@ with upper-case hex digits.
instance PercentEncoding Char where
  encode = B1.pack . concatMap enc' . B1.unpack . B2.fromString . (: [])
    where
      enc' b
        | b `elem` whitelist' = [b]
        | otherwise = [0x25, hex' (B.shiftR b 4), hex' (b .&. 0x0F)]
      -- Map a nibble (0-15) to the ASCII code of its upper-case hex digit.
      hex' :: W.Word8 -> W.Word8
      hex' = fromIntegral . C.ord . C.toUpper . C.intToDigit . fromIntegral
      whitelist' :: [W.Word8]
      whitelist' = [0x61 .. 0x7a] ++ [0x41 .. 0x5a] ++ [0x30 .. 0x39] ++ [0x2d, 0x2e, 0x5f, 0x7e]

  -- Decode the first character and return the rest of the *raw* input.
  -- The number of raw bytes consumed is counted from the units that made up
  -- the character. It cannot be recovered by re-encoding the character: the
  -- input may contain unescaped reserved bytes, raw UTF-8, or invalid
  -- sequences, and these re-encode to a different length.
  decode bytes = case B2.decode (B1.pack (map fst units')) of
      Nothing -> error "Network.Protocol.OAuth.Request.decode: empty input"
      Just (c, used) ->
        let width' = sum (map snd (take (fromIntegral used) units'))
        in (c, B1.drop (fromIntegral width') bytes)
    where
      units' = _pct_units (B1.unpack bytes)

  -- Undo the percent-escapes, then read the resulting bytes as UTF-8.
  -- Invalid UTF-8 becomes the replacement character.
  decodes = B2.toString . B1.pack . map fst . _pct_units . B1.unpack

-- | Split raw percent-encoded bytes into units. Each unit pairs a decoded
-- byte with the number of raw bytes it occupied:
--
-- * A well-formed @%XY@ escape (either hex case) is one byte of width 3.
-- * Any other byte stands for itself with width 1. This includes a @%@ that
--   is not followed by two hex digits.
_pct_units :: [W.Word8] -> [(W.Word8, Int)]
_pct_units (0x25 : hi : lo : rest)
  | isHex' hi && isHex' lo = (B.shiftL (hexVal' hi) 4 .|. hexVal' lo, 3) : _pct_units rest
  where
    isHex' :: W.Word8 -> Bool
    isHex' = C.isHexDigit . C.chr . fromIntegral
    hexVal' :: W.Word8 -> W.Word8
    hexVal' = fromIntegral . C.digitToInt . C.chr . fromIntegral
_pct_units (b : rest) = (b, 1) : _pct_units rest
_pct_units [] = []
-- | Serialise parameters as @key=value@ pairs joined by the separator byte
-- @s@.
--
-- Keys are always percent-encoded with 'encodes'. Values go through the
-- caller-supplied @ve@, which lets the header renderer wrap them in quotes.
-- A parameter with no value is written as @key=@.
--
-- Pairs are ordered by the bytes of their percent-encoded name, then their
-- percent-encoded value, as RFC 5849 §3.4.1.3.2 requires. Sorting the raw
-- 'String's gives a different order whenever an escaped character is
-- involved. For example, @"a."@ sorts before @"a/"@ as text, but after it
-- once encoded (@"a."@ vs @"a%2F"@).
_urlencode :: (String -> B1.ByteString) -> W.Word8 -> [Parameter] -> B1.ByteString
_urlencode ve s = B1.intercalate (B1.singleton s) . map pair' . L.sortBy byEncoded'
  where
    byEncoded' a b = compare (key' a) (key' b)
    -- Sort key: the encoded name and the encoded value, with a missing value
    -- treated as empty.
    key' (k, v) = (encodes k, maybe B1.empty encodes v)
    pair' (k, Nothing) = B1.concat [encodes k, B1.singleton 0x3d]
    pair' (k, Just v)  = B1.concat [encodes k, B1.singleton 0x3d, ve v]
-- | Split a URL path into its segments, dropping empty ones. Leading,
-- trailing and repeated @/@ therefore collapse.
--
-- If the path ends in @/@, one empty segment is appended, so rejoining the
-- segments with @/@ keeps the trailing slash. The empty path yields no
-- segments; the previous check using 'last' crashed on it.
_path_comp :: String -> [String]
_path_comp p = filter (not . null) (L.unfoldr unfold' p) ++ trailing'
  where
    unfold' p1 = case break (== '/') p1 of
      ([], []) -> Nothing
      (l, r)   -> Just (l, drop 1 r)
    trailing'
      | "/" `L.isSuffixOf` p = [[]]
      | otherwise = []