module System.FilePath.Cryptographic
( CryptoID(..)
, CryptoFileName
, module Data.Binary.SerializationLength
, encrypt
, decrypt
, CryptoIDError(..)
) where
import Data.CryptoID
import Data.CryptoID.Poly hiding (encrypt, decrypt)
import Data.CryptoID.ByteString (cipherBlockSize)
import qualified Data.CryptoID.Poly as Poly (encrypt, decrypt)
import System.FilePath (FilePath)
import qualified Codec.Binary.Base32 as Base32
import Data.CaseInsensitive (CI)
import qualified Data.CaseInsensitive as CI
import Data.Binary
import Data.Binary.SerializationLength
import Data.Encoding.UTF8
import Data.Encoding (decodeStrictByteString, encodeStrictByteString)
import Data.Char (toUpper)
import Data.Ratio ((%))
import Data.List
import qualified Data.ByteString as ByteString
import Control.Monad
import Control.Monad.Catch
import Data.Proxy
import GHC.TypeLits
-- | A 'CryptoID' in the given type-level @namespace@ whose ciphertext is
-- carried as a case-insensitive 'FilePath'.  'encrypt' in this module fills
-- it with Base32 text from which the trailing @=@ padding has been removed.
type CryptoFileName (namespace :: Symbol) = CryptoID namespace (CI FilePath)
-- | Round a byte count up to the nearest multiple of 'cipherBlockSize'.
-- Counts that are already a whole number of blocks come back unchanged;
-- e.g. with a 16-byte block size, @paddedLength 20 == 32@ and
-- @paddedLength 32 == 32@.
paddedLength :: Integral a => a -> a
paddedLength len =
  let blockSize  = fromIntegral cipherBlockSize
      -- exact rational division, rounded up to a whole number of blocks
      blockCount = ceiling (len % blockSize)
  in blockCount * blockSize
-- | Encrypt a value into a 'CryptoFileName'.
--
-- The value's 'Binary' serialisation must be exactly
-- @'SerializationLength' a@ bytes long; any other length is rejected with
-- 'CiphertextConversionFailed' (carrying the offending serialisation).
-- Otherwise the length handed to "Data.CryptoID.Poly"'s @encrypt@ is the
-- serialisation length rounded up by 'paddedLength' (presumably the padded
-- plaintext size; see that module).  The resulting ciphertext is
-- Base32-encoded, stripped of trailing @=@ characters and wrapped
-- case-insensitively.
--
-- NOTE(review): the length check concerns the /plaintext/ serialisation,
-- yet it reports 'CiphertextConversionFailed'; kept as-is because callers
-- may catch that constructor.
encrypt :: forall a m namespace.
           ( KnownSymbol namespace
           , Binary a
           , MonadThrow m
           , HasFixedSerializationLength a
           ) => CryptoIDKey -> a -> m (CryptoFileName namespace)
encrypt = Poly.encrypt targetLength renderName
  where
    -- Serialisation length mandated by the type @a@.
    expectedLength :: Integer
    expectedLength = natVal (Proxy :: Proxy (SerializationLength a))

    -- Validate the serialised plaintext and ask for block-aligned padding.
    targetLength bytes
      | toInteger actual /= expectedLength =
          throwM $ CiphertextConversionFailed bytes
      | otherwise = return . Just $ paddedLength actual
      where
        actual = ByteString.length bytes

    -- Render ciphertext as unpadded Base32; named so it does not shadow
    -- Data.Binary.encode.
    renderName bytes =
      let base32Text = decodeStrictByteString UTF8 (Base32.encode bytes)
      in return . CI.mk $ dropWhileEnd (== '=') base32Text
-- | Recover a value from a 'CryptoFileName'.
--
-- The name is upper-cased, the Base32 @=@ padding that 'encrypt' removed is
-- restored, and the text is Base32-decoded.  If decoding fails,
-- 'CiphertextConversionFailed' is thrown with the (padded) bytes that were
-- fed to the decoder.  The decoded ciphertext is then passed to
-- "Data.CryptoID.Poly"'s @decrypt@.
--
-- The amount of padding is derived from the type alone, assuming the
-- ciphertext is @'paddedLength' ('SerializationLength' a)@ bytes long —
-- TODO confirm the cipher preserves the padded plaintext length.
decrypt :: forall a m namespace.
           ( KnownSymbol namespace
           , Binary a
           , MonadThrow m
           , HasFixedSerializationLength a
           ) => CryptoIDKey -> CryptoFileName namespace -> m a
decrypt = Poly.decrypt fromName
  where
    -- Expected ciphertext size in bytes: the fixed serialisation length
    -- rounded up to whole cipher blocks.
    cipherLength :: Integer
    cipherLength = paddedLength $ natVal (Proxy :: Proxy (SerializationLength a))

    -- RFC 4648 Base32 appends 0, 6, 4, 3 or 1 '=' characters when the
    -- payload length is 0, 1, 2, 3 or 4 modulo 5 respectively.
    padCount :: Integer
    padCount = [0, 6, 4, 3, 1] `genericIndex` (cipherLength `mod` 5)

    restorePadding text = text ++ genericReplicate padCount '='

    -- Upper-case first: the name is case-insensitive, while the RFC 4648
    -- Base32 alphabet is upper-case.
    fromName name =
      let bytes = encodeStrictByteString UTF8
                . restorePadding
                . map toUpper
                $ CI.original name
      in case Base32.decode bytes of
           Left _        -> throwM $ CiphertextConversionFailed bytes
           Right decoded -> return decoded