module Language.Python.Pickle where
import Control.Applicative ((<$>), (<*), (*>))
import Control.Monad.State
import Control.Monad.Writer
import qualified Data.ByteString as S
import Data.Attoparsec hiding (parse, take)
import qualified Data.Attoparsec as A
import Data.Attoparsec.ByteString.Char8 (decimal, double, signed)
import Data.Int (Int32)
import Data.IntMap (IntMap)
import qualified Data.IntMap as IM
import Data.List (foldl')
import Data.Map (Map)
import qualified Data.Map as M
import Data.Serialize.Get (getWord16le, getWord32le, getWord64be, runGet)
import Data.Serialize.Put (runPut, putByteString, putWord8, putWord16le, putWord32le, putWord64be, Put)
import Data.Word (Word32, Word64)
import Foreign.Marshal.Utils (with)
import Foreign.Ptr (castPtr)
import Foreign.Storable (peek)
import System.IO.Unsafe (unsafePerformIO)
parse :: S.ByteString -> Either String [OpCode]
parse = parseOnly (many1 $ choice opcodes)
unpickle :: S.ByteString -> Either String Value
unpickle s = do
xs <- parse s
unpickle' xs
pickle :: Value -> S.ByteString
pickle value = runPut $ do
putByteString "\128\STX"
mapM_ serialize . runPickler $ pickle' value
putByteString "."
opcodes :: [Parser OpCode]
opcodes =
[ true, false, int, binint, binint1, binint2, long, long1, long4
, string', binstring, short_binstring
, none
, newtrue, newfalse
, unicode, binunicode
, float, binfloat
, empty_list, append, appends, list
, empty_tuple, tuple, tuple1, tuple2, tuple3
, empty_dict, dict, setitem, setitems
, pop, dup, mark, popmark
, get', binget, long_binget, put', binput, long_binput
, ext1, ext2, ext4
, global, reduce, build, inst, obj, newobj
, proto, stop
, persid, binpersid
]
int, binint, binint1, binint2, long, long1, long4 :: Parser OpCode
int = string "I" *> (INT <$> decimalInt)
binint = string "J" *> (BININT <$> int4)
binint1 = string "K" *> (BININT1 . fromIntegral <$> anyWord8)
binint2 = string "M" *> (BININT2 <$> uint2)
long = string "L" *> (LONG <$> decimalLong)
long1 = string "\138" *> (LONG1 <$> decodeLong1)
long4 = string "\139" *> (LONG4 <$> decodeLong4)
string', binstring, short_binstring :: Parser OpCode
string' = string "S" *> (STRING <$> stringnl)
binstring = string "T" *> (BINSTRING <$> undefined)
short_binstring = do
_ <- string "U"
i <- fromIntegral <$> anyWord8
s <- A.take i
return $ SHORT_BINSTRING s
none :: Parser OpCode
none = string "N" *> return NONE
true, false, newtrue, newfalse :: Parser OpCode
true = string "I01" *> return NEWTRUE
false = string "I00" *> return NEWFALSE
newtrue = string "\136" *> return NEWTRUE
newfalse = string "\137" *> return NEWFALSE
unicode, binunicode :: Parser OpCode
unicode = string "V" *> (UNICODE <$> undefined)
binunicode = string "X" *> (BINUNICODE <$> undefined)
float, binfloat :: Parser OpCode
float = string "F" *> (FLOAT <$> doubleFloat)
binfloat = string "G" *> (BINFLOAT <$> float8)
empty_list, append, appends, list :: Parser OpCode
empty_list = string "]" *> return EMPTY_LIST
append = string "a" *> return APPEND
appends = string "e" *> return APPENDS
list = string "l" *> return LIST
empty_tuple, tuple, tuple1, tuple2, tuple3 :: Parser OpCode
empty_tuple = string ")" *> return EMPTY_TUPLE
tuple = string "t" *> return TUPLE
tuple1 = string "\133" *> return TUPLE1
tuple2 = string "\134" *> return TUPLE2
tuple3 = string "\135" *> return TUPLE3
empty_dict, dict, setitem, setitems :: Parser OpCode
empty_dict = string "}" *> return EMPTY_DICT
dict = string "d" *> return DICT
setitem = string "s" *> return SETITEM
setitems = string "u" *> return SETITEMS
pop, dup, mark, popmark :: Parser OpCode
pop = string "0" *> return POP
dup = string "2" *> return DUP
mark = string "(" *> return MARK
popmark = string "1" *> return POP_MARK
get', binget, long_binget, put', binput, long_binput :: Parser OpCode
get' = string "g" *> (GET <$> decimalInt)
binget = string "h" *> (BINGET . fromIntegral <$> anyWord8)
long_binget = string "j" *> (LONG_BINGET <$> undefined)
put' = string "p" *> (PUT <$> decimalInt)
binput = string "q" *> (BINPUT . fromIntegral <$> anyWord8)
long_binput = string "r" *> (LONG_BINPUT <$> undefined)
ext1, ext2, ext4 :: Parser OpCode
ext1 = string "\130" *> (EXT1 <$> undefined)
ext2 = string "\131" *> (EXT2 <$> undefined)
ext4 = string "\132" *> (EXT4 <$> undefined)
global, reduce, build, inst, obj, newobj :: Parser OpCode
global = string "c" *> (uncurry GLOBAL <$> undefined)
reduce = string "R" *> return REDUCE
build = string "b" *> return BUILD
inst = string "i" *> (uncurry INST <$> undefined)
obj = string "o" *> return OBJ
newobj = string "\129" *> return NEWOBJ
proto, stop :: Parser OpCode
proto = string "\128" *> (PROTO . fromIntegral <$> anyWord8)
stop = string "." *> return STOP
persid, binpersid :: Parser OpCode
persid = string "P" *> (PERSID <$> undefined)
binpersid = string "Q" *> return BINPERSID
decimalInt :: Parser Int
decimalInt = signed decimal <* string "\n"
doubleFloat = double <* string "\n"
decimalLong :: Parser Int
decimalLong = signed decimal <* string "L\n"
decodeLong1 :: Parser Int
decodeLong1 = do
n <- fromIntegral <$> anyWord8
if n > 0
then do
ns <- A.take n
let a = fst . fst $ toLong ns
return $ if S.last ns > 127 then negate $ 256 ^ S.length ns a else a
else return 0
where toLong = S.mapAccumL (\(a, b) w -> ((a + 256 ^ b * fromIntegral w, b + 1), w)) (0, 0)
decodeLong4 = undefined
stringnl = choice
[ string "'" *> takeTill (== 39) <* string "'\n"
]
float8 :: Parser Double
float8 = do
w <- runGet getWord64be <$> A.take 8
case w of
Left err -> fail err
Right x -> return $ coerce x
where
coerce :: Word64 -> Double
coerce x = unsafePerformIO $ with x $ \p ->
peek (castPtr p) :: IO Double
int4 :: Parser Int
int4 = do
w <- runGet getWord32le <$> A.take 4
case w of
Left err -> fail err
Right x -> return . fromIntegral $ coerce x
where
coerce :: Word32 -> Int32
coerce x = unsafePerformIO $ with x $ \p ->
peek (castPtr p) :: IO Int32
uint2 :: Parser Int
uint2 = do
w <- runGet getWord16le <$> A.take 2
case w of
Left err -> fail err
Right x -> return $ fromIntegral x
serialize :: OpCode -> Put
serialize opcode = case opcode of
BINGET i -> putByteString "h" >> putWord8 (fromIntegral i)
BINPUT i -> putByteString "q" >> putWord8 (fromIntegral i)
BININT i -> putByteString "J" >> putWord32le (fromIntegral i)
BININT1 i -> putByteString "K" >> putWord8 (fromIntegral i)
BININT2 i -> putByteString "M" >> putUint2 i
NONE -> putByteString "N"
NEWTRUE -> putByteString "\136"
NEWFALSE -> putByteString "\137"
LONG1 0 -> putByteString "\138\NUL"
LONG1 i -> putByteString "\138" >> encodeLong1 i
BINFLOAT d -> putByteString "G" >> putFloat8 d
SHORT_BINSTRING s -> do
putByteString "U"
putWord8 . fromIntegral $ S.length s
putByteString s
EMPTY_DICT -> putByteString "}"
EMPTY_LIST -> putByteString "]"
EMPTY_TUPLE -> putByteString ")"
TUPLE -> putByteString "t"
TUPLE1 -> putByteString "\133"
TUPLE2 -> putByteString "\134"
TUPLE3 -> putByteString "\135"
MARK -> putByteString "("
SETITEM -> putByteString "s"
SETITEMS -> putByteString "u"
APPEND -> putByteString "a"
APPENDS -> putByteString "e"
x -> error $ "serialize: " ++ show x
putFloat8 :: Double -> Put
putFloat8 d = putWord64be (coerce d)
where
coerce :: Double -> Word64
coerce x = unsafePerformIO $ with x $ \p ->
peek (castPtr p) :: IO Word64
putUint2 :: Int -> Put
putUint2 d = putWord16le (fromIntegral d)
encodeLong1 :: Int -> Put
encodeLong1 i = do
let xs = f $ abs i
f j | j < 256 = [fromIntegral j]
| otherwise = let (n, r) = j `divMod` 256 in fromIntegral r : f n
e = 256 ^ length xs
if i < 0
then do
if (abs i) > e `div` 2
then do
putWord8 (fromIntegral $ length xs + 1)
mapM_ putWord8 $ f (e + i)
putWord8 255
else do
putWord8 (fromIntegral $ length xs)
mapM_ putWord8 $ f (e + i)
else
if i > e `div` 2 1
then do
putWord8 (fromIntegral $ length xs + 1)
mapM_ putWord8 xs
putWord8 0
else do
putWord8 (fromIntegral $ length xs)
mapM_ putWord8 xs
data OpCode =
INT Int
| BININT Int
| BININT1 Int
| BININT2 Int
| LONG Int
| LONG1 Int
| LONG4 Int
| STRING S.ByteString
| BINSTRING S.ByteString
| SHORT_BINSTRING S.ByteString
| NONE
| NEWTRUE
| NEWFALSE
| UNICODE S.ByteString
| BINUNICODE S.ByteString
| FLOAT Double
| BINFLOAT Double
| EMPTY_LIST
| APPEND
| APPENDS
| LIST
| EMPTY_TUPLE
| TUPLE
| TUPLE1
| TUPLE2
| TUPLE3
| EMPTY_DICT
| DICT
| SETITEM
| SETITEMS
| POP
| DUP
| MARK
| POP_MARK
| GET Int
| BINGET Int
| LONG_BINGET Int
| PUT Int
| BINPUT Int
| LONG_BINPUT Int
| EXT1 Int
| EXT2 Int
| EXT4 Int
| GLOBAL S.ByteString S.ByteString
| REDUCE
| BUILD
| INST S.ByteString S.ByteString
| OBJ
| NEWOBJ
| PROTO Int
| STOP
| PERSID S.ByteString
| BINPERSID
deriving Show
data Value =
Dict (Map Value Value)
| List [Value]
| Tuple [Value]
| None
| Bool Bool
| BinInt Int
| BinLong Int
| BinFloat Double
| BinString S.ByteString
| MarkObject
deriving (Eq, Ord, Show)
unpickle' :: [OpCode] -> Either String Value
unpickle' xs = execute xs [] (IM.empty)
type Stack = [Value]
type Memo = IntMap Value
execute :: [OpCode] -> Stack -> Memo -> Either String Value
execute [] [value] _ = Right value
execute (op:ops) stack memo = case executeOne op stack memo of
Left err -> Left err
Right (stack', memo') -> execute ops stack' memo'
execute _ _ _ = Left "`execute` unimplemented"
executePartial :: [OpCode] -> Stack -> Memo -> (Stack, Memo, [OpCode])
executePartial [] stack memo = (stack, memo, [])
executePartial (op:ops) stack memo = case executeOne op stack memo of
Left _ -> (stack, memo, op:ops)
Right (stack', memo') -> executePartial ops stack' memo'
executeOne :: OpCode -> Stack -> Memo -> Either String (Stack, Memo)
executeOne EMPTY_DICT stack memo = return (Dict M.empty: stack, memo)
executeOne EMPTY_LIST stack memo = return (List []: stack, memo)
executeOne EMPTY_TUPLE stack memo = return (Tuple []: stack, memo)
executeOne (PUT i) (s:stack) memo = return (s:stack, IM.insert i s memo)
executeOne (GET i) stack memo = executeLookup i stack memo
executeOne (BINPUT i) (s:stack) memo = return (s:stack, IM.insert i s memo)
executeOne (BINGET i) stack memo = executeLookup i stack memo
executeOne NONE stack memo = return (None:stack, memo)
executeOne NEWTRUE stack memo = return (Bool True:stack, memo)
executeOne NEWFALSE stack memo = return (Bool False:stack, memo)
executeOne (INT i) stack memo = return (BinInt i:stack, memo)
executeOne (BININT i) stack memo = return (BinInt i:stack, memo)
executeOne (BININT1 i) stack memo = return (BinInt i:stack, memo)
executeOne (BININT2 i) stack memo = return (BinInt i:stack, memo)
executeOne (LONG i) stack memo = return (BinLong i:stack, memo)
executeOne (LONG1 i) stack memo = return (BinLong i:stack, memo)
executeOne (FLOAT d) stack memo = return (BinFloat d:stack, memo)
executeOne (BINFLOAT d) stack memo = return (BinFloat d:stack, memo)
executeOne (STRING s) stack memo = return (BinString s:stack, memo)
executeOne (SHORT_BINSTRING s) stack memo = return (BinString s:stack, memo)
executeOne MARK stack memo = return (MarkObject:stack, memo)
executeOne TUPLE stack memo = executeTuple [] stack memo
executeOne TUPLE1 (a:stack) memo = return (Tuple [a]:stack, memo)
executeOne TUPLE2 (b:a:stack) memo = return (Tuple [a, b]:stack, memo)
executeOne TUPLE3 (c:b:a:stack) memo = return (Tuple [a, b, c]:stack, memo)
executeOne DICT stack memo = executeDict [] stack memo
executeOne SETITEM stack memo = executeSetitem stack memo
executeOne SETITEMS stack memo = executeSetitems [] stack memo
executeOne LIST stack memo = executeList [] stack memo
executeOne APPEND stack memo = executeAppend stack memo
executeOne APPENDS stack memo = executeAppends [] stack memo
executeOne (PROTO _) stack memo = return (stack, memo)
executeOne STOP stack memo = Right (stack, memo)
executeOne op _ _ = Left $ "Can't execute opcode " ++ show op ++ "."
executeLookup :: Int -> Stack -> Memo -> Either String (Stack, Memo)
executeLookup k stack memo = case IM.lookup k memo of
Nothing -> Left "Unknown memo key"
Just s -> Right (s:stack, memo)
executeTuple :: Monad m => [Value] -> Stack -> Memo -> m ([Value], Memo)
executeTuple l (MarkObject:stack) memo = return (Tuple l:stack, memo)
executeTuple l (a:stack) memo = executeTuple (a : l) stack memo
executeDict :: Monad m => [(Value, Value)] -> Stack -> Memo -> m ([Value], Memo)
executeDict l (MarkObject:stack) memo = return (l `addToDict` Dict M.empty:stack, memo)
executeDict l (a:b:stack) memo = executeDict ((b, a) : l) stack memo
executeList :: Monad m => [Value] -> Stack -> Memo -> m ([Value], Memo)
executeList l (MarkObject:stack) memo = return (List l:stack, memo)
executeList l (x:stack) memo = executeList (x : l) stack memo
executeSetitem :: Monad m => Stack -> Memo -> m ([Value], Memo)
executeSetitem (v:k:Dict d:stack) memo = return (Dict (M.insert k v d):stack, memo)
executeSetitems :: Monad m => [(Value, Value)] -> Stack -> Memo -> m ([Value], Memo)
executeSetitems l (MarkObject:Dict d:stack) memo = return (l `addToDict` Dict d:stack, memo)
executeSetitems l (a:b:stack) memo = executeSetitems ((b, a) : l) stack memo
executeAppend :: Monad m => Stack -> Memo -> m ([Value], Memo)
executeAppend (x:List xs:stack) memo = return (List (xs ++ [x]):stack, memo)
executeAppends :: Monad m => [Value] -> Stack -> Memo -> m ([Value], Memo)
executeAppends l (MarkObject:List xs:stack) memo = return (List (xs ++ l):stack, memo)
executeAppends l (x:stack) memo = executeAppends (x : l) stack memo
addToDict :: [(Value, Value)] -> Value -> Value
addToDict l (Dict d) = Dict $ foldl' add d l
where add d' (k, v) = M.insert k v d'
newtype Pickler a = Pickler { runP :: WriterT [OpCode] (State (Map Value Int)) a }
deriving (Functor, Applicative, Monad, MonadWriter [OpCode], MonadState (Map Value Int))
runPickler :: Pickler () -> [OpCode]
runPickler p = evalState (execWriterT (runP p)) M.empty
pickle' :: Value -> Pickler ()
pickle' value = do
m <- get
case M.lookup value m of
Just k -> tell [BINGET k]
Nothing -> case value of
Dict d -> pickleDict d
List xs -> pickleList xs
Tuple xs -> pickleTuple xs
None -> tell [NONE]
Bool True -> tell [NEWTRUE]
Bool False -> tell [NEWFALSE]
BinInt i -> pickleBinInt i
BinLong i -> pickleBinLong i
BinFloat d -> pickleBinFloat d
BinString s -> pickleBinString s
x -> error $ "TODO: pickle " ++ show x
binput' :: Value -> Pickler ()
binput' value = do
i <- gets M.size
m <- get
put (M.insert value i m)
tell [BINPUT i]
pickleDict :: Map Value Value -> Pickler ()
pickleDict d = do
tell [EMPTY_DICT]
binput' (Dict d)
let kvs = M.toList d
case kvs of
[] -> return ()
[(k,v)] -> pickle' k >> pickle' v >> tell [SETITEM]
_ -> do
tell [MARK]
mapM_ (\(k, v) -> pickle' k >> pickle' v) kvs
tell [SETITEMS]
pickleList :: [Value] -> Pickler ()
pickleList xs = do
tell [EMPTY_LIST]
binput' (List xs)
case xs of
[] -> return ()
[x] -> pickle' x >> tell [APPEND]
_ -> do
tell [MARK]
mapM_ pickle' xs
tell [APPENDS]
pickleTuple :: [Value] -> Pickler ()
pickleTuple [] = tell [EMPTY_TUPLE]
pickleTuple [a] = do
pickle' a
tell [TUPLE1]
binput' (Tuple [a])
pickleTuple [a, b] = do
pickle' a
pickle' b
tell [TUPLE2]
binput' (Tuple [a, b])
pickleTuple [a, b, c] = do
pickle' a
pickle' b
pickle' c
tell [TUPLE3]
binput' (Tuple [a, b, c])
pickleTuple xs = do
tell [MARK]
mapM_ pickle' xs
tell [TUPLE]
binput' (Tuple xs)
pickleBinInt :: Int -> Pickler ()
pickleBinInt i | i >= 0 && i < 256 = tell [BININT1 i]
| i >= 256 && i < 65536 = tell [BININT2 i]
| otherwise = tell [BININT i]
pickleBinLong :: Int -> Pickler ()
pickleBinLong i = tell [LONG1 i]
pickleBinFloat :: Double -> Pickler ()
pickleBinFloat d = do
tell [BINFLOAT d]
pickleBinString :: S.ByteString -> Pickler ()
pickleBinString s = do
tell [SHORT_BINSTRING s]
binput' (BinString s)
dictGet :: Value -> Value -> Either String (Maybe Value)
dictGet (Dict d) v = return $ M.lookup v d
dictGet _ _ = Left "dictGet: not a dict."
dictGet' :: Value -> Value -> Either String Value
dictGet' (Dict d) v = case M.lookup v d of
Just value -> return value
Nothing -> Left "dictGet': no such key."
dictGet' _ _ = Left "dictGet': not a dict."
dictGetString :: Value -> S.ByteString -> Either String S.ByteString
dictGetString (Dict d) s = case M.lookup (BinString s) d of
Just (BinString s') -> return s'
_ -> Left "dictGetString: not a dict, or no such key."