module Huzzy.TypeTwo.ZSlices.Sets where

import Data.Function
import Data.List
import Huzzy.Base.Sets
import Huzzy.TypeOne.Sets
import Huzzy.TypeTwo.Interval.Sets

-- | A zSlices type-2 fuzzy set: an ordered list of interval type-2 slices,
-- one per z-level, together with a list of sampled domain points.
data T2ZSet a = T2ZS { zLevels :: Int
                       -- ^ Number of z-levels.  'contZT2', 'discZT2' and
                       -- 'unsafeZT2' set this to the number of slices;
                       -- 'cylExtT2' and 't2Tri' set it to their level argument.
                     , zSlices :: [IT2Set a]
                       -- ^ The slices, lowest z-level first: 'is' pairs the
                       -- i-th slice with the i-th value of 'zLevelAxis'
                       -- (which starts at 0), and 'support' reads only the
                       -- first slice.
                     , zdom    :: [a]
                       -- ^ Sampled domain points.  'contZT2' fills it with
                       -- @[minB, minB+res .. maxB]@; 'cylExtT2' leaves it empty.
                     }

-- | Connectives act slice by slice.  The binary operators pair the slices of
-- the two operands positionally with 'zipWith' (so the result has as many
-- slices as the shorter operand) and copy every other field, including
-- 'zLevels' and 'zdom', from the left operand.
-- NOTE(review): operands with different slice counts are silently truncated
-- while 'zLevels' keeps the left operand's count -- confirm this is intended.
instance Fuzzy (T2ZSet a) where
    x ?&& y = x { zSlices = zipWith (?&&) (zSlices x) (zSlices y) }
    x ?|| y = x { zSlices = zipWith (?||) (zSlices x) (zSlices y) }
    fnot x  = x { zSlices = fmap fnot (zSlices x) }

-- | A zSlices set viewed as a fuzzy set over @a@.
--
-- * 'support' is the support of the first slice.
-- * 'hedge' is applied to every slice.
-- * 'is' evaluates the value in every slice.  Each slice yields a pair of
--   results, and each result is paired with that slice's coordinate from
--   'zLevelAxis'.  The points, ordered by descending z, are handed to
--   'discrete'.
instance FSet (T2ZSet a) where
    type Value (T2ZSet a)    = a
    type Support (T2ZSet a)  = [(a,a)]
    type Returned (T2ZSet a) = MF Double
    support = support . head . zSlices
    hedge d s = s { zSlices = fmap (hedge d) (zSlices s) }
    x `is` s  = discrete (sortBy descendingLevel points)
      where
        slices      = zSlices s
        levels      = zLevelAxis (length slices)
        (lows, ups) = unzip [x `is` slice | slice <- slices]
        -- The first results of every slice come before the second results.
        -- sortBy is stable, so at equal z that order is kept.
        points      = zip lows levels ++ zip ups levels
        -- Highest z first.  The original author marks this as a hack "to
        -- ensure max is returned"; presumably 'discrete' keeps the first
        -- point it sees for a value -- TODO confirm.
        descendingLevel p q = compare (snd q) (snd p)


-- | @zLevelAxis n@ gives the z-coordinates of @n@ evenly spaced z-levels,
-- ascending from 0 (first level) to exactly 1 (last level).
--
-- Degenerate counts are total: a single level sits at 0, and a
-- non-positive count yields no levels.
zLevelAxis :: Int -> [Double]
zLevelAxis n
    | n <= 0    = []
    -- With one level there is no spacing to divide by (n - 1 == 0).
    | n == 1    = [0]
    | otherwise = [fromIntegral i / top | i <- [0 .. n - 1]]
  where
    -- Dividing each index by (n - 1) directly, rather than multiplying by a
    -- precomputed step 1/(n - 1), rounds only once.  This keeps the last
    -- level at exactly 1 (e.g. (1/49) * 49 /= 1 in Double).
    top = fromIntegral (n - 1)

-- | Build a zSlices type-2 set over the range from @minB@ to @maxB@ sampled
-- every @res@, i.e. over the domain @[minB, minB+res .. maxB]@.
--
-- This is 'discZT2' applied to that sampled domain.  It validates the first
-- slice's 'lmf' and 'umf' at every sample and raises an error if any value
-- lies outside [0, 1].  The number of z-levels is the number of slices.
-- Delegating keeps the two constructors' validation from drifting apart.
contZT2 :: (Enum a, Num a) => a -> a -> a -> [IT2Set a] -> T2ZSet a
contZT2 minB maxB res its = discZT2 [minB, minB + res .. maxB] its

-- | Build a zSlices type-2 set over the discrete domain @dom@.
--
-- The 'lmf' and 'umf' functions of the first slice (read with 'head') are
-- evaluated at every point of @dom@.  An error is raised if any value lies
-- outside [0, 1].  The remaining slices are not checked.  The number of
-- z-levels is the number of slices supplied.
discZT2 :: [a] -> [IT2Set a] -> T2ZSet a
discZT2 dom its
    | any (outside . lowerF) dom = error "Truth values must be in the range [0..1]"
    | any (outside . upperF) dom = error "Truth values must be in the range [0..1]"
    | otherwise                  = T2ZS { zLevels = length its
                                        , zSlices = its
                                        , zdom    = dom
                                        }
  where
    (MF lowerF, MF upperF) = (lmf (head its), umf (head its))
    -- True for a membership value that is not a valid truth value.
    outside v = v < 0 || v > 1

-- | Build a zSlices type-2 set from a domain and its slices without checking
-- any membership values.  The number of z-levels is the number of slices.
unsafeZT2 :: [a] -> [IT2Set a] -> T2ZSet a
unsafeZT2 dom its = T2ZS (length its) its dom

-- | Extend a type-1 set to a zSlices type-2 set with @z@ levels.
--
-- For every coordinate produced by @zLevelAxis z@, 'findCuts' is applied to
-- the type-1 set.  The pair it returns is turned into a slice with
-- 'cylExt'.  The domain field is left empty.
cylExtT2 :: T1Set Double -> Int -> T2ZSet Double
cylExtT2 s z = T2ZS { zLevels = z, zSlices = map sliceAt (zLevelAxis z), zdom = [] }
  where
    sliceAt lvl = case findCuts s lvl of
        (lo, hi) -> cylExt lo hi

-- | @t2Tri (a,a') (b,b') (c,c') z@ builds a zSlices type-2 set of @z@
-- triangular slices.
--
-- Slice 0 is @unsafeMkIT2 dom (tri a b c) (tri a' b' c')@ with
-- @dom = [min a a' .. max c c']@.  Each further slice moves the left feet
-- (@a@, @a'@) and the right feet (@c@, @c'@) of its two triangles towards
-- each other in equal steps.  In the last slice both triangles therefore
-- have their left foot at @(a+a')/2@ and their right foot at @(c+c')/2@.
-- The peaks @b@ and @b'@ are not moved.  Slice domains are enumerated as
-- @[lo .. hi]@ on 'Double', i.e. in steps of 1.
--
-- Raises an error if @z < 1@.
t2Tri :: (Double, Double) ->
         (Double, Double) ->
         (Double, Double) ->
         Int -> T2ZSet Double
t2Tri (a,a') (b,b') (c,c') z
    | z < 1     = error "t2Tri: the number of z-levels must be at least 1"
    | otherwise = T2ZS { zLevels = z
                       , zSlices = base : map slice [1 .. z - 1]
                       , zdom    = dom
                       }
  where
    dom  = [min a a' .. max c c']
    base = unsafeMkIT2 dom (tri a b c) (tri a' b' c')
    -- Per-level displacement of the feet.  After z-1 levels each pair of
    -- feet has closed half of its gap from both sides, i.e. the feet meet.
    -- These are only forced when z > 1, so (z - 1) is never zero here.
    stepA = ((a - a') / fromIntegral (z - 1)) / 2
    stepC = ((c - c') / fromIntegral (z - 1)) / 2
    -- Slice i: the first triangle's feet move by -i*step and the second
    -- triangle's by +i*step, so the two triangles close in on each other.
    -- Moving both by the same amount would only translate the footprint
    -- and never narrow it.
    slice i =
        let k  = fromIntegral i
            a1 = a  - k * stepA
            c1 = c  - k * stepC
            a2 = a' + k * stepA
            c2 = c' + k * stepC
        in  unsafeMkIT2 [min a1 a2 .. max c1 c2]
                        (tri a1 b  c1)
                        (tri a2 b' c2)