{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE NamedFieldPuns       #-}
{-# LANGUAGE OverlappingInstances #-}
{-# LANGUAGE RecordWildCards      #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Language.ArrayForth.Synthesis where

import           Control.Arrow                   (first)
import           Control.Monad.Random            (Random, random, randomR)

import           Data.Function                   (on)
import           Data.Functor                    ((<$>))
import           Data.List                       (elemIndices, genericLength, (\\))
import           Data.Monoid                     (Monoid (..))

import           Language.ArrayForth.Distance
import           Language.ArrayForth.Interpreter
import           Language.ArrayForth.Opcode
import           Language.ArrayForth.Program
import           Language.ArrayForth.State

import           Language.Synthesis.Distribution (Distr (..), mix,
                                                  negativeInfinity, randInt,
                                                  uniform)
import           Language.Synthesis.Mutations    hiding (mix)
import qualified Language.Synthesis.Mutations    as M
import           Language.Synthesis.Synthesis    (Score (..))

import           Text.Printf

-- | A two-part score. The first 'Double' is the correctness value and
-- the second is the performance value. The derived ordering compares
-- correctness first and only then performance.
data DefaultScore
  = DefaultScore Double Double
  deriving (Eq, Ord)

-- | Collapses the two components into a single number: the
-- correctness value plus the performance value weighted by 0.1.
instance Score DefaultScore where
  toScore (DefaultScore correctness performance) = correctness + weighted
    where weighted = 0.1 * performance

-- | Renders both components with two decimal places, as
-- @\<correctness, performance\>@.
instance Show DefaultScore where
  show (DefaultScore correctness performance) =
    printf "<%.2f, %.2f>" correctness performance

-- | Scores combine by adding the correctness values and the
-- performance values separately; the identity is a score of zero in
-- both components.
instance Monoid DefaultScore where
  mempty = DefaultScore 0 0
  mappend (DefaultScore ca pa) (DefaultScore cb pb) =
    DefaultScore (ca + cb) (pa + pb)

-- | Builds an evaluation function from a spec, a list of input states
-- and a function that compares two program traces.
--
-- Both the spec and the candidate program are loaded and stepped on
-- every input. Each candidate trace is passed through 'throttle' with
-- the length of the matching spec trace (whether 'throttle' answered
-- 'Left' or 'Right' is discarded here), then compared against that
-- spec trace. The per-input scores are combined with 'mconcat'.
trace :: Monoid score => Program -> [State] -> (Trace -> Trace -> score) -> Program -> score
trace spec inputs score program =
  mconcat [ score expected (clip expected actual)
          | (expected, actual) <- zip (runAll spec) (runAll program) ]
  where
    -- One trace per input state for the given program.
    runAll prog = map (stepProgram . load prog) inputs
    -- Limit a candidate trace using the spec trace's length.
    clip expected actual = either id id (throttle (length expected) actual)

-- | Turns a correctness measure over traces into one that yields a
-- 'DefaultScore', adding a performance component.
--
-- The result trace is passed through 'throttle' with the length of
-- the spec trace. Correctness is the given measure applied to the spec
-- and the throttled trace, converted with 'toScore'. Performance is
-- based on the difference in 'countTime' between the spec and the
-- throttled trace: divided by 10 when 'throttle' returns 'Right', and
-- reduced by 1e10 when it returns 'Left'.
-- NOTE(review): presumably 'Left' means the candidate did not finish
-- within the spec's trace length -- confirm against 'throttle'.
withPerformance :: Score s => (Trace -> Trace -> s) -> (Trace -> Trace -> DefaultScore)
withPerformance score spec result = DefaultScore correctness performance
  where
    outcome     = throttle (length spec) result
    correctness = toScore (score spec (either id id outcome))
    -- Time the spec takes beyond the given (throttled) trace.
    saved steps = countTime spec - countTime steps
    performance = case outcome of
      Right finished -> saved finished / 10
      Left  cutOff   -> saved cutOff - 1e10

-- | Given a specification program and some inputs, evaluate a program
-- against the specification for both performance and correctness.
-- Traces are compared by applying the distance function to their
-- final states, and the summed score is normalized by the number of
-- test cases.
--
-- With no inputs there is nothing to average, so the result is
-- 'mempty' instead of the NaN components that dividing zero by zero
-- would produce.
--
-- NOTE(review): 'last' is partial; this assumes traces are never
-- empty -- confirm against 'stepProgram' and 'throttle'.
evaluate :: Program -> [State] -> (State -> State -> Distance) -> Program -> DefaultScore
evaluate spec inputs distance
  | null inputs = const mempty
  | otherwise   =
      normalize . trace spec inputs (withPerformance (distance `on` last))
  where normalize (DefaultScore c p) = DefaultScore (c / len) (p / len)
        -- Number of test cases, as a 'Double' for the division above.
        len = genericLength inputs

-- This instance is needed to get a distribution over Forth words.
-- Ranged draws are done on 'Integer' bounds and the drawn value is
-- converted back with 'fromInteger'; an unbounded draw covers 0 up to
-- 'maxBound'.
instance Random F18Word where
  random = randomR (0, maxBound)
  randomR (lo, hi) = first fromInteger . randomR (toInteger lo, toInteger hi)

-- | The default distribution of instructions. Jumps are not supported
-- for now: every opcode for which 'isJump' holds is left out, and so
-- are 'Unext' and 'Nop'. The remaining opcodes, a numeric constant
-- and an unused slot are mixed so that each is equally likely (the
-- opcode distribution is weighted by the number of opcodes in it).
-- The value of a constant is drawn from 'randInt' over 0 to
-- 'maxBound'.
defaultOps :: Distr Instruction
defaultOps = mix [ (constants, 1.0)
                 , (uniform [Unused], 1.0)
                 , (uniform instrs, genericLength instrs) ]
  where
    usable = filter (not . isJump) opcodes \\ [Unext, Nop]
    instrs = Opcode <$> usable

    -- Distribution over the underlying words, lifted to 'Number'
    -- instructions below.
    wordDistr = randInt (0, maxBound)

    -- Only 'Number' instructions have any probability under
    -- 'constants'.
    numberLogProb (Number n) = logProbability wordDistr n
    numberLogProb _          = negativeInfinity

    constants = Distr { sample         = Number <$> sample wordDistr
                      , logProbability = numberLogProb }

-- | Pairs of opcodes, each wrapped as an 'Instruction', that
-- 'removePairs' searches a program for.
pairs :: [(Instruction, Instruction)]
pairs =
  [ (Opcode a, Opcode b)
  | (a, b) <- [(SetA, ReadA), (Push, Pop), (Over, Drop)] ]

-- | A mutation aimed at the instruction pairs listed in 'pairs'. For
-- every pair, each index where its first instruction occurs is
-- combined with each index where its second occurs, and
-- 'mutateInstructionsAt' is applied at those two indices. All the
-- resulting distributions are mixed with equal weight.
--
-- NOTE(review): if the program contains none of the pairs, 'mix' is
-- called on an empty list -- confirm that this is well defined.
-- NOTE(review): the two indices are not required to be adjacent or in
-- order -- confirm that this is intended.
removePairs :: Distr Instruction -> Mutation Program
removePairs instrDistr program =
  mix [ (mutateInstructionsAt instrDistr [i, j] program, 1.0)
      | (a, b) <- pairs
      , i <- elemIndices a program
      , j <- elemIndices b program ]

-- | The default mutations to try. For now this either changes an
-- instruction (drawing from 'defaultOps') or swaps two instructions
-- in the program; both choices get the same weight.
defaultMutations :: Mutation Program
defaultMutations =
  M.mix [ (mutation, 1)
        | mutation <- [mutateInstruction defaultOps, swapInstructions] ]