module UniqueLogic.ST.Rule (
   -- * Custom rules
   generic2,
   generic3,
   -- * Common rules
   equ, pair, max, add, mul, square, pow,
   ) where

import qualified UniqueLogic.ST.System as Sys

import Data.Monoid (Monoid, )

import qualified Prelude as P
import Prelude hiding (max)


{- |
Build a rule between two variables from a pair of functions.

The first function maps a value of the second variable's type
to a value of the first variable's type,
the second function maps in the opposite direction.
One 'Sys.assignment2' is registered for each of the two functions.
-}
generic2 ::
   (Sys.Var var, Monoid w) =>
   (b -> a) -> (a -> b) ->
   var w s a -> var w s b -> Sys.T w s ()
generic2 toX toY x y = do
   Sys.assignment2 toX y x
   Sys.assignment2 toY x y

{- |
Build a rule between three variables from a triple of functions.

Each function takes values of two of the variables' types
and yields a value of the remaining variable's type.
One 'Sys.assignment3' is registered per function,
with the variables rotated accordingly.
-}
generic3 ::
   (Sys.Var var, Monoid w) =>
   (b -> c -> a) -> (c -> a -> b) -> (a -> b -> c) ->
   var w s a -> var w s b -> var w s c -> Sys.T w s ()
generic3 toX toY toZ x y z =
   Sys.assignment3 toX y z x >>
   Sys.assignment3 toY z x y >>
   Sys.assignment3 toZ x y z


{- |
Relate two variables of the same type by the identity function
in both directions, using 'generic2'.
-}
equ ::
   (Sys.Var var, Monoid w) =>
   var w s a -> var w s a -> Sys.T w s ()
equ x y = generic2 id id x y

{- |
Register a single 'Sys.assignment3' that applies @P.max@
to the first two variables with the third variable as target.

In contrast to rules built with 'generic3',
no assignments for the other directions are registered here.
-}
max ::
   (Ord a, Sys.Var var, Monoid w) =>
   var w s a -> var w s a -> var w s a -> Sys.T w s ()
max x y z =
   Sys.assignment3 P.max x y z

{- |
Relate two variables to a third variable holding their pair:
one 'Sys.assignment3' builds the pair with @(,)@,
and two 'Sys.assignment2's extract the components with 'fst' and 'snd'.

You might be tempted to use the 'pair' rule to collect parameters
for rules with more than three arguments.
This is generally not a good idea, because you lose granularity that way.
To build rules with more than three arguments,
construct the corresponding assignments with 'Sys.arg' and 'Sys.runApply'
and bundle these assignments into rules.
This is how 'generic2' and 'generic3' work.
-}
pair ::
   (Sys.Var var, Monoid w) =>
   var w s a -> var w s b -> var w s (a,b) -> Sys.T w s ()
pair x y xy =
   sequence_
      [ Sys.assignment3 (,) x y xy
      , Sys.assignment2 fst xy x
      , Sys.assignment2 snd xy y
      ]

{- |
@add x y z@ sets up the rule @x + y = z@ via 'generic3':
@x@ is obtained as @z - y@, @y@ as @z - x@ and @z@ as @x + y@.
-}
add :: (Num a, Sys.Var var, Monoid w) =>
   var w s a -> var w s a -> var w s a -> Sys.T w s ()
add = generic3 firstSummand secondSummand total
   where
      firstSummand y z = z - y
      secondSummand z x = z - x
      total x y = x + y

{- |
@mul x y z@ sets up the rule @x * y = z@ via 'generic3':
@x@ is obtained as @z / y@, @y@ as @z / x@ and @z@ as @x * y@.
-}
mul :: (Fractional a, Sys.Var var, Monoid w) =>
   var w s a -> var w s a -> var w s a -> Sys.T w s ()
mul = generic3 firstFactor secondFactor productOf
   where
      firstFactor y z = z / y
      secondFactor z x = z / x
      productOf x y = x * y

{- |
@square x y@ sets up the rule @x ^ 2 = y@ via 'generic2':
@x@ is obtained as @sqrt y@ and @y@ as @x ^ 2@.
-}
square :: (Floating a, Sys.Var var, Monoid w) =>
   var w s a -> var w s a -> Sys.T w s ()
square = generic2 root squared
   where
      root y = sqrt y
      squared x = x ^ (2::Int)

{- |
@pow x y z@ sets up the rule @x ** y = z@ via 'generic3':
@x@ is obtained as @z ** recip y@,
@y@ as @logBase x z@
and @z@ as @x ** y@.
-}
pow :: (Floating a, Sys.Var var, Monoid w) =>
   var w s a -> var w s a -> var w s a -> Sys.T w s ()
pow = generic3 baseOf exponentOf powerOf
   where
      baseOf expo result = result ** recip expo
      exponentOf result base = logBase base result
      powerOf base expo = base ** expo