{-# language NoMonomorphismRestriction #-}
{-# language ScopedTypeVariables #-}
{-# language DataKinds #-}
{-# language ForeignFunctionInterface #-}
{-# language RecursiveDo #-}
module CodeGen.X86.Utils where

import           Data.Char
import           Data.Monoid
import           Control.Monad
import           Foreign
import           System.Environment
import           Debug.Trace

import           CodeGen.X86.Asm
import           CodeGen.X86.CodeGen
import           CodeGen.X86.CallConv

-------------------------------------------------------------- derived constructs

-- | @unless cc body@ emits @body@ guarded by a conditional jump: if condition
-- @cc@ holds, control jumps over @body@, so @body@ only runs when @cc@ is
-- false.
--
-- Note: this name clashes with 'Control.Monad.unless' (imported unqualified
-- in this module); code that imports both must qualify one of them.
unless :: Condition -> CodeM a -> Code
unless cc body = mdo
  j cc skip          -- condition true: jump past the body
  _ <- body
  skip <- label
  return ()

-- | @doWhile cc body@ emits a bottom-tested loop: @body@ is emitted once, and
-- after it a conditional jump returns to its start whenever condition @cc@
-- holds at that point. The body therefore always runs at least once.
doWhile :: Condition -> CodeM a -> Code
doWhile cc body = do
  top <- label
  _ <- body
  j cc top           -- loop back while the condition holds

-- | @if_ cc thenCode elseCode@ emits a two-armed conditional: @thenCode@ runs
-- when condition @cc@ holds, @elseCode@ runs otherwise.
if_ :: Condition -> CodeM a -> CodeM a -> Code
if_ cc thenCode elseCode = mdo
  j (N cc) elseStart  -- condition false: go straight to the else branch
  _ <- thenCode
  jmp done            -- then branch finished: skip the else branch
  elseStart <- label
  _ <- elseCode
  done <- label
  return ()

-- | @leaData r d@ embeds the bytes of @d@ (as produced by 'toBytes') directly
-- in the instruction stream and loads their address into @r@ with a
-- RIP-relative @lea@.
--
-- The bytes are placed between an unconditional @jmp@ and its target label,
-- so execution skips over them rather than running them as code.
leaData :: (IsSize s, HasBytes a) => Operand RW s -> a -> Code
leaData r d = mdo
  lea r (ipRel8 dataStart)
  jmp dataEnd
  dataStart <- label
  db (toBytes d)
  dataEnd <- label
  return ()

------------------------------------------------------------------------------ 

-- | Address of the C library's @printf@, for calling from generated code
-- (e.g. through 'callFun'). The pointer type is left polymorphic so it can be
-- used at whatever function-pointer type the caller needs.
foreign import ccall "static stdio.h &printf"
  printf :: FunPtr a

------------------------------------------------------------------------------ 
-- * utils

-- | @mov' dst src@ moves @src@ into @dst@ after first resizing @dst@ (via
-- 'resizeOperand') to the size @s'@ of the source, so the two operands of
-- the underlying 'mov' always agree in size.
mov' :: forall s s' r . IsSize s' => Operand RW s -> Operand r s' -> Code
mov' dst = mov (resizeOperand dst :: Operand RW s')

-- | A Haskell 'String' to be embedded in generated code as a NUL-terminated
-- C string.
newtype CString = CString String

-- | Serialises the string as UTF-8 followed by a single terminating NUL byte.
--
-- Every code point is encoded in full rather than truncated to its low
-- 8 bits, so non-ASCII text reaches C unchanged. A NUL character inside the
-- string still ends it early as far as C functions are concerned.
instance HasBytes CString where
  toBytes (CString cs) = foldMap toBytes (concatMap utf8 cs ++ [0])
    where
      -- UTF-8 encoding of a single code point (RFC 3629 layout)
      utf8 :: Char -> [Word8]
      utf8 c
        | n < 0x80    = [fromIntegral n]
        | n < 0x800   = [0xC0 .|. lead 6, cont 0]
        | n < 0x10000 = [0xE0 .|. lead 12, cont 6, cont 0]
        | otherwise   = [0xF0 .|. lead 18, cont 12, cont 6, cont 0]
        where
          n = ord c
          lead, cont :: Int -> Word8
          -- leading byte payload: the bits above position k
          lead k = fromIntegral (n `shiftR` k)
          -- continuation byte: 10xxxxxx carrying 6 bits starting at k
          cont k = 0x80 .|. (fromIntegral (n `shiftR` k) .&. 0x3F)

-- | All 64-bit general-purpose registers except @rsp@, in the order
-- @rax, rcx, rdx, rbx, rbp, rsi, rdi, r8 .. r15@ (the classic PUSHA order
-- without @rsp@, extended with @r8@-@r15@).
--
-- PUSHA/POPA are invalid in 64-bit mode, so 'push_all' and 'pop_all' emulate
-- them by pushing/popping this list one register at a time. @rsp@ is left
-- out because it is the stack pointer that those pushes and pops modify.
{- HLINT ignore all_regs_except_rsp -}
all_regs_except_rsp :: [Operand rw S64]
all_regs_except_rsp =
  [ rax, rcx, rdx, rbx
  , {- rsp, -} rbp, rsi, rdi
  , r8, r9, r10, r11, r12, r13, r14, r15
  ]

-- | Push every register of 'all_regs_except_rsp' onto the stack, in list
-- order (15 pushes, 120 bytes). Undo with 'pop_all'.
{- HLINT ignore push_all -}
push_all :: Code
push_all = mapM_ push all_regs_except_rsp

-- | Pop the registers of 'all_regs_except_rsp' in reverse list order, i.e.
-- the exact inverse of 'push_all'.
{- HLINT ignore pop_all -}
pop_all :: Code
pop_all = mapM_ pop (reverse all_regs_except_rsp)

-- | @traceReg conv r@ emits code that prints the run-time value of @r@ with
-- C @printf@, as the line @"\<show r\> = %\<len\>\<conv\>\\n"@.
--
-- * @\<len\>@ is the printf length modifier for @r@'s width: @hh@ (8 bits),
--   @h@ (16 bits), none (32 bits) or @ll@ (64 bits).
-- * @conv@ is the conversion supplied by the caller, e.g. @"d"@, @"u"@ or
--   @"x"@.
--
-- 64-bit values use @ll@ because @long long@ is 64 bits on every x86-64 ABI,
-- whereas @long@ is only 32 bits on Windows (LLP64).
--
-- Flags and all general-purpose registers except @rsp@ are saved before the
-- call and restored afterwards. Caveats:
--
-- * Tracing @rsp@ itself shows a value 128 bytes lower than at the trace
--   point, because @pushf@ plus 15 register pushes happen first.
-- * Vector (XMM) registers are not saved; printf may clobber caller-saved
--   ones.
--
-- NOTE(review): assumes 'callFun' takes care of ABI stack alignment and any
-- shadow space required by the platform — confirm.
traceReg :: IsSize s => String -> Operand RW s -> Code
traceReg conv r = do
  pushf
  push_all
  -- read r before leaData overwrites arg1 (r may itself be arg1)
  mov' arg2 r
  leaData arg1 (CString fmt)
  -- variadic call: al = number of vector registers used for arguments (0)
  xor_ rax rax
  callFun r11 printf
  pop_all
  popf
 where
  -- '%' in the operand's name is doubled so printf prints it literally
  fmt = concatMap escapePct (show r) ++ " = %" ++ lenMod ++ conv ++ "\n"

  escapePct '%' = "%%"
  escapePct c   = [c]

  lenMod = case size r of
    S8  -> "hh"
    S16 -> "h"
    S32 -> ""
    S64 -> "ll"
    _   -> error "traceReg: unsupported operand size"