-----------------------------------------------------------------------------
-- |
-- Module      :  Data.F2
-- Copyright   :  (c) Marcel Fourné 2011-2012
-- License     :  BSD3
-- Maintainer  :  Marcel Fourné (hecc@bitrot.dyndns.org)
--
-- A timing attack resistant F(2^e) backend, all operations on little-endian data in unboxed bit vectors
--
-----------------------------------------------------------------------------

module Data.F2 (
  Data.F2.F2,
  Data.F2.add,
  Data.F2.shift,
  Data.F2.mul,
  Data.F2.reduceBy,
  Data.F2.pow,
  Data.F2.fromInteger,
  Data.F2.toInteger,
  Data.F2.length,
  Data.F2.even,
  Data.F2.odd,
  Data.F2.div,
  Data.F2.bininv
  )
       where

import Numeric
import Data.Char
import Data.Maybe
import qualified Data.List(map)
import qualified Data.Vector.Unboxed as V
import qualified Data.Bits as B
import qualified Data.Bit as Bit
import qualified Data.Vector.Unboxed.Bit as VB

-- | A polynomial over F(2), stored as an unboxed bit vector in little-endian
-- order: index 0 holds the least significant bit (the coefficient of x^0).
-- Values are not necessarily normalised; e.g. 'add' may leave leading zero
-- bits, which 'elimFalses' strips.
type F2 = V.Vector Bit.Bit

-- | Base-2 digits of an 'Integer', most significant digit first
-- (e.g. 6 becomes "110", 0 becomes "0"). Negative numbers are rejected by
-- 'showIntAtBase', which calls 'error' for them.
binary :: Integer -> [Char]
binary n = showIntAtBase 2 intToDigit n ""

-- | Addition of @a@ and @b@ in F(2)[x], i.e. bitwise XOR.
-- The shorter operand is zero-padded to the length of the longer one, so the
-- result has the maximum of both lengths; leading zero bits are not stripped.
add :: F2 -> F2 -> F2
add a b
  | lenA <= lenB = VB.zipWords B.xor (VB.pad lenB a) b
  | otherwise    = VB.zipWords B.xor a (VB.pad lenA b)
  where
    lenA = V.length a
    lenB = V.length b

-- | Shift helper for the little-endian layout.
--
-- * @n > 0@: prepends @n@ zero bits at index 0, i.e. multiplies by x^n
--   (a left shift in numeric terms).
-- * @n < 0@: keeps only the first @length a + n@ bits, i.e. drops the @-n@
--   most significant bits (truncation, not a numeric right shift).
-- * @n == 0@: returns @a@ unchanged.
shift :: F2 -> Int -> F2
shift a n
  | n > 0     = V.replicate n (Bit.fromBool False) V.++ a
  | n < 0     = V.take (V.length a + n) a
  | otherwise = a

-- | Multiplication of @a@ and @b@ in F(2)[x] (carry-less multiplication).
--
-- The product is not reduced by any modulus, but leading zero bits are stripped
-- with 'elimFalses'. The bits of @a@ are visited from the most significant
-- index down to 0:
--
-- * for a set bit, @b@ shifted by that index is XORed into the accumulator;
-- * for a clear bit, an all-zero vector of @b@'s length is shifted and XORed
--   instead.
--
-- Both cases therefore perform the same shift/add work (intended as
-- timing-attack resistance).
mul :: F2 -> F2 -> F2
mul a b = elimFalses $ go (V.length a - 1) start
  where
    zero = Bit.fromBool False
    -- accumulator long enough to hold every shifted partial product
    start = V.replicate (V.length a + V.length b) zero
    dummy = V.replicate (V.length b) zero
    go i acc
      | i < 0 = acc
      | otherwise =
          let addend = if a V.! i == Bit.fromBool True then b else dummy
          in go (i - 1) (acc `add` shift addend i)

-- | Polynomial reduction of @a@ modulo @r@ in F(2)[x].
--
-- Returns the remainder of @a@ divided by @r@, with leading zero bits removed.
-- @r@ is normalised first with 'elimFalses', so that its top stored bit is its
-- leading term.
--
-- Special cases:
--
-- * an @r@ equal to 1 yields 0;
-- * an @r@ equal to 0 (or empty) returns @a@ unchanged.
--
-- Each step looks at the current top bit of the working value @z@. If that bit
-- is set, @r@ is shifted so that its leading term lands exactly on that bit
-- (shift amount @length z - length r@) and XORed in, which clears the bit. If
-- the bit is clear, an all-zero vector is XORed instead, so that both cases do
-- the same work (timing-attack resistance). Either way the now-zero top bit is
-- then dropped.
reduceBy :: F2 -> F2 -> F2
reduceBy a r0
  | lr == 1 && V.head r == Bit.fromBool True  = V.singleton $ Bit.fromBool False
  | lr == 1 && V.head r == Bit.fromBool False = a
  | otherwise = elimFalses $ go a
  where
    r = elimFalses r0
    lr = V.length r
    pseudo = V.replicate lr (Bit.fromBool False)
    go z
      | V.length z >= lr =
          let top = V.length z - 1
              -- aligns index (lr - 1) of r with index top of z; never negative here
              s = V.length z - lr
              -- real branch uses r, the other XORs zeros for timing-attack resistance
              addend = if V.last z == Bit.fromBool True then r else pseudo
          in go (V.take top $ add z (shift addend s))
      | otherwise = z

-- | The power function @b@ ^ @k@ in F(2)[x].
--
-- Small exponents 0..3 are hardcoded to avoid overhead; all other exponents use
-- a Montgomery ladder over the bits of @k@. The result is built only from 'mul',
-- so it is not reduced by any modulus.
--
-- @k@ is normalised with 'elimFalses' first. The ladder consumes the most
-- significant stored bit of @k@ through its initial pair (b, b^2) and assumes
-- that bit is set, so leading zero bits (as 'add' can produce) would otherwise
-- give a wrong result. An empty @k@ is treated as 0.
pow :: F2 -> F2 -> F2
pow b k0
  | k == Data.F2.fromInteger 0 = V.singleton $ Bit.fromBool True
  | k == Data.F2.fromInteger 1 = b
  | k == Data.F2.fromInteger 2 = mul b b
  | k == Data.F2.fromInteger 3 = mul b $ mul b b
  | otherwise = ladder b (power2 b) (V.length k - 2)
  where
    k = elimFalses k0
    power2 x = mul x x
    -- ladder invariant: p2 == p1 * b
    ladder p1 p2 i
      | i < 0 = p1
      | not $ Bit.toBool $ k V.! i = ladder (power2 p1) (mul p1 p2) (i - 1)
      | otherwise = ladder (mul p1 p2) (power2 p2) (i - 1)

-- | Keep only the first @n@ entries of @a@, i.e. its @n@ least significant
-- bits in the little-endian layout. If @n@ is at least the length of @a@, the
-- whole vector is kept.
shortenTo :: F2 -> Int -> F2
shortenTo = flip V.take
                
-- | Strip all leading (most significant) zero bits from @a@. This normalises,
-- for example, the unreduced results of multiplications. An all-zero or empty
-- input yields the single-bit zero vector.
elimFalses :: F2 -> F2
elimFalses a =
  case VB.first (Bit.fromBool True) (V.reverse a) of
    Nothing   -> V.singleton $ VB.fromBool False
    -- lead = number of zero bits above the highest set bit
    Just lead -> shortenTo a (V.length a - lead)

-- | Conversion helper: turn a non-negative 'Integer' into its little-endian bit
-- vector (index 0 is the least significant bit). 0 becomes the single-bit zero
-- vector. Negative input is rejected by 'showIntAtBase', which calls 'error'.
fromInteger :: Integer -> F2
fromInteger z = V.fromList [Bit.fromBool (digit == '1') | digit <- reverse (binary z)]

-- | Conversion helper: interpret a little-endian bit vector as a non-negative
-- 'Integer', with the bit at index i contributing 2^i. The empty vector maps
-- to 0.
toInteger :: F2 -> Integer
toInteger = V.foldr step 0
  where
    -- Horner's scheme, consumed from the most significant end
    step bit acc = (if Bit.toBool bit then 1 else 0) + 2 * acc

-- | Number of stored bits of an F2 value, including any leading zero bits.
length :: F2 -> Int
length = V.length

-- | Is the number even? Defined as the complement of 'Data.F2.odd'.
even :: F2 -> Bool
even = not . Data.F2.odd

-- | Is the number odd?
--
-- With the little-endian layout used by this module (see 'fromInteger'), the
-- least significant bit sits at index 0, so that bit decides. The empty vector
-- is treated as zero, i.e. not odd.
odd :: F2 -> Bool
odd a = not (V.null a) && Bit.toBool (V.head a)

-- | Division @k@ / @f@ with respect to the modulus @m@. It multiplies @k@ by
-- the inverse of @f@ modulo @m@, as computed by 'bininv'.
--
-- Note: the final product comes straight from 'mul' and is not reduced modulo
-- @m@. Apply 'reduceBy' if a canonical representative is needed.
div :: F2 -> F2 -> F2 -> F2
div k f m = mul k inverse
  where
    inverse = bininv f m

-- | Modular inverse of @f@ modulo @m@ in F(2)[x], computed by binary inversion
-- (a polynomial extended-Euclid variant).
--
-- Loop invariant: @u ≡ g1 * f@ and @v ≡ g2 * f@ (mod @m@), starting from
-- @(u, v, g1, g2) = (f, m, 1, 0)@. Each step cancels the leading term of the
-- longer of @u@/@v@ by adding a copy of the other, shifted by the length
-- difference. The roles swap when @v@ is longer, so the update always lands
-- in the first slot. When @u@ becomes 1, @g1@ is the inverse.
--
-- Vector length stands in for degree, so @f@ and @m@ are normalised with
-- 'elimFalses' first. If @f@ has no inverse modulo @m@ (e.g. @f@ is 0 or
-- shares a factor with @m@), @u@ eventually becomes 0 and the loop would never
-- terminate; an 'error' is raised instead. A zero modulus is rejected the same
-- way.
bininv :: F2 -> F2 -> F2
bininv f m
  | modulus == zero = error "Data.F2.bininv: modulus is zero"
  | otherwise = go (elimFalses f) modulus one zero
  where
    one = Data.F2.fromInteger 1
    zero = Data.F2.fromInteger 0
    modulus = elimFalses m
    go u v g1 g2
      | u == one = g1
      | u == zero = error "Data.F2.bininv: argument is not invertible modulo m"
      | j < 0 =
          go (elimFalses (v `add` shift u (negate j))) u
             (elimFalses (g2 `add` shift g1 (negate j))) g1
      | otherwise =
          go (elimFalses (u `add` shift v j)) v
             (elimFalses (g1 `add` shift g2 j)) g2
      where
        j = V.length u - V.length v
                
-- add (mul (Data.F2.fromInteger 371) (Data.F2.fromInteger 1794)) (mul (Data.F2.fromInteger 203) (Data.F2.fromInteger 2053))
-- Data.F2.toInteger $ bininv (Data.F2.fromInteger 371) (Data.F2.fromInteger 2053)