{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE BangPatterns, DeriveFunctor #-}
module Geom2D.CubicBezier.MetaPath
       (unmetaOpen, unmetaClosed, ClosedMetaPath(..), OpenMetaPath (..),
        MetaJoin (..), MetaNodeType (..), Tension (..))
where
import Geom2D
import Geom2D.CubicBezier.Basic
import Data.List
import Text.Printf
import qualified Data.Vector.Unboxed as V
import Geom2D.CubicBezier.Numeric
-- | An open meta path: a list of nodes, each pairing a point with the join
-- that leads from it to the next point, followed by the final point.
--
-- Derives 'Eq' like 'ClosedMetaPath' does, so open and closed meta paths
-- can be compared the same way.
data OpenMetaPath a = OpenMetaPath [(Point a, MetaJoin a)] (Point a)
                    deriving (Eq, Functor, Traversable, Foldable)
-- | A closed meta path.  Each node pairs a point with the join leading to
-- the next point; the join of the last node leads back to the first point.
data ClosedMetaPath a
  = ClosedMetaPath [(Point a, MetaJoin a)]
  deriving (Eq, Functor, Traversable, Foldable)
-- | The connection between a point of a meta path and the next point.
data MetaJoin a
  = -- | Control points still to be computed from the conditions and
    -- tensions at both ends of the join.
    MetaJoin
      { metaTypeL :: MetaNodeType a
        -- ^ condition at the start of the join (the point it leaves)
      , tensionL :: Tension a
        -- ^ tension at the start of the join
      , tensionR :: Tension a
        -- ^ tension at the end of the join
      , metaTypeR :: MetaNodeType a
        -- ^ condition at the end of the join (the point it arrives at)
      }
  | -- | Explicit control points of the cubic segment.
    Controls (Point a) (Point a)
  deriving (Show, Eq, Functor, Traversable, Foldable)
-- | The condition imposed on one end of a 'MetaJoin'.
data MetaNodeType a
  = -- | No condition.  Between two joins the direction is solved for; at
    -- the ends of an open segment it is treated like @Curl 1@.
    Open
  | -- | A curl condition with parameter 'curlgamma'.
    Curl { curlgamma :: a }
  | -- | A prescribed direction vector 'nodedir'.  A zero vector is
    -- treated as 'Open'.
    Direction { nodedir :: Point a }
  deriving (Eq, Show, Functor, Foldable, Traversable)
-- | The tension on one end of a 'MetaJoin'.
data Tension a
  = -- | A fixed tension value.
    Tension { tensionValue :: a }
  | -- | A tension value whose resulting control-point distance may be
    -- shortened further when converting to explicit joins (shown as
    -- @atleast@).
    TensionAtLeast { tensionValue :: a }
  deriving (Eq, Show, Functor, Foldable, Traversable)
-- | Shows the nodes in MetaPost-like notation, terminated by @cycle@.
instance (Show a, Real a) => Show (ClosedMetaPath a) where
  show (ClosedMetaPath ns) = concat [showPath ns, "cycle"]
-- | Shows the nodes in MetaPost-like notation, followed by the end point.
instance (Show a, Real a) => Show (OpenMetaPath a) where
  show (OpenMetaPath ns end) = concat [showPath ns, showPoint end]
-- | Render the nodes of a meta path in MetaPost-like notation.  Each node
-- is printed as its point followed by its join.  The caller appends the
-- final point (open paths) or @cycle@ (closed paths).
--
-- Tensions equal to @Tension 1@ on both sides are omitted.  Other tension
-- values and curl parameters are printed with three decimals.
showPath :: (Show a, Real a) => [(Point a, MetaJoin a)] -> String
showPath = concatMap showNodes
  where
    -- Explicit control points: @p..controls u and v..@.  The space before
    -- "and" matches the spacing used for "tension ... and ..." below.
    showNodes (p, Controls u v) =
      showPoint p ++ "..controls " ++ showPoint u ++ " and " ++ showPoint v ++ ".."
    showNodes (p, MetaJoin m1 t1 t2 m2) =
      showPoint p ++ typename m1 ++ ".." ++ tensions ++ typename m2
      where
        tensions
          | t1 == t2 && t1 == Tension 1 = ""
          | t1 == t2 = printf "tension %s.." (showTension t1)
          | otherwise = printf "tension %s and %s.."
                        (showTension t1) (showTension t2)
    showTension (TensionAtLeast t) = printf "atleast %.3f" (realToFrac t :: Double) :: String
    showTension (Tension t) = printf "%.3f" (realToFrac t :: Double) :: String
    -- End conditions are printed in braces; an open end prints nothing.
    typename Open = ""
    typename (Curl g) = printf "{curl %.3f}" (realToFrac g :: Double) :: String
    typename (Direction dir) = printf "{%s}" (showPoint dir) :: String
-- | Render a point as @(x, y)@ using 'show' for both coordinates.
showPoint :: Show a => Point a -> String
showPoint (Point x y) = concat ["(", show x, ", ", show y, ")"]
-- | Convert an open meta path into an open path with explicit joins.
-- Zero-length directions are first treated as open, and the end
-- conditions are normalised before the control points are computed.
unmetaOpen :: OpenMetaPath Double -> OpenPath Double
unmetaOpen (OpenMetaPath nodes endpoint) =
  unmetaOpen' normalised endpoint
  where
    normalised = sanitize (removeEmptyDirs nodes) endpoint
-- | Compute explicit joins for already sanitised nodes.  The nodes are
-- split into independently solvable sub-segments, and the resulting joins
-- are concatenated.
unmetaOpen' :: [(DPoint, MetaJoin Double)] -> DPoint -> OpenPath Double
unmetaOpen' nodes endpoint = OpenPath joined endpoint
  where
    joined = joinSegments (map unmetaSubSegment (openSubSegments nodes endpoint))
-- | Convert a closed meta path into a closed path with explicit joins.
--
-- Zero-length directions are first treated as open.  When every node is
-- open on both sides, and the path does not close with a zero-length
-- segment, the cyclic equation system is solved directly.  Otherwise the
-- path is rotated to start where a condition (or a zero-length segment)
-- occurs and is solved as an open path.  Calls 'error' on an empty path.
unmetaClosed :: ClosedMetaPath Double -> ClosedPath Double
unmetaClosed (ClosedMetaPath nodes) =
  choose (spanList bothOpen (removeEmptyDirs nodes))
  where
    choose ([], []) = error "empty metapath"
    -- every node is open on both sides
    choose (lead, [])
      | fst (last lead) == fst (head lead) = unmetaAsOpen lead []
      | otherwise = unmetaCyclic lead
    -- `m` is the first node that is not open on both sides
    choose (lead, rest@(m : n))
      | leftOpen rest = unmetaAsOpen (lead ++ [m]) n
      | otherwise = unmetaAsOpen lead rest
-- | Solve a closed path as an open one.  The nodes are rotated to
-- @m ++ l@ and solved as an open path that ends at its own first point.
-- The resulting joins are then rotated back into @l ++ m@ order.
unmetaAsOpen :: [(DPoint, MetaJoin Double)] -> [(DPoint, MetaJoin Double)] -> ClosedPath Double
unmetaAsOpen l m = ClosedPath (fromL ++ fromM)
  where
    rotated = m ++ l
    OpenPath solved _ =
      unmetaOpen' (sanitizeCycle rotated) (fst (head rotated))
    (fromM, fromL) = splitAt (length m) solved
-- | Split sanitised nodes into sub-segments that can be solved
-- independently.  A sub-segment ends at a node unless that node's join is
-- open on the right and the following join is open on the left (see
-- 'breakPoint').  Each sub-segment ends at the point of the next node, or
-- at @lastPoint@ for the final one.
openSubSegments :: [(DPoint, MetaJoin Double)] -> DPoint -> [OpenMetaPath Double]
openSubSegments [] _ = []
openSubSegments nodes end =
  case spanList (not . breakPoint) nodes of
    (_, []) -> error "openSubSegments': unexpected end of segments"
    (inner, brk : remaining) ->
      OpenMetaPath (inner ++ [brk]) (segmentEnd remaining)
        : openSubSegments remaining end
  where
    segmentEnd ((q, _) : _) = q
    segmentEnd [] = end
-- | Concatenate the node lists of several open paths, dropping each path's
-- end point.  For segments produced by 'openSubSegments', each dropped end
-- point is the first point of the following segment.
joinSegments :: [OpenPath Double] -> [(DPoint, PathJoin Double)]
joinSegments = concatMap (\(OpenPath ns _) -> ns)

-- | Solve a closed path whose nodes are all open on both sides, using the
-- cyclic tridiagonal system from 'eqsCycle'.
--
-- NOTE(review): row i of 'eqsCycle' uses the turn angles at points i+1 and
-- i+2, while join i below uses @thetas !! i@.  Confirm this matches the
-- row/unknown indexing of 'solveCyclicTriD'.
unmetaCyclic :: [(DPoint, MetaJoin Double)] -> ClosedPath Double
unmetaCyclic nodes =
  ClosedPath (zip points joins)
  where
    -- rotate left by one, wrapping around (infinite; zips truncate it)
    next xs = tail (cycle xs)
    points = map fst nodes
    tensionsA = map (tensionL . snd) nodes
    tensionsB = map (tensionR . snd) nodes
    -- chord i runs from point i to point i+1
    chords = zipWith (^-^) (next points) points
    -- turnAngles !! i is the turn from chord i to chord i+1 (at point i+1)
    turnAngles = zipWith turnAngle chords (next chords)
    thetas = solveCyclicTriD2 (eqsCycle tensionsA points tensionsB turnAngles)
    phis = zipWith (\psi theta -> -(psi + theta)) turnAngles (next thetas)
    joins = zipWith6 unmetaJoin points (next points)
              thetas phis tensionsA tensionsB
-- | Compute the explicit joins of one sub-segment.  A lone 'Controls' join
-- is copied verbatim.  Otherwise the angles are solved from 'eqsOpen', and
-- each join is built by 'unmetaJoin'.
unmetaSubSegment :: OpenMetaPath Double -> OpenPath Double
unmetaSubSegment (OpenMetaPath [(p, Controls u v)] q) =
  OpenPath [(p, JoinCurve u v)] q
unmetaSubSegment (OpenMetaPath nodes lastpoint) =
  OpenPath (zip points pathjoins) lastpoint
  where
    points = map fst nodes ++ [lastpoint]
    joins = map snd nodes
    tensionsA = map tensionL joins
    tensionsB = map tensionR joins
    -- chord i runs from point i to point i+1
    chords = zipWith (^-^) (tail points) points
    -- turn at each interior point, plus 0 for the final point
    turnAngles = zipWith turnAngle chords (tail chords) ++ [0]
    thetas = solveTriDiagonal2 $
      eqsOpen points joins chords turnAngles
        (map tensionValue tensionsA) (map tensionValue tensionsB)
    phis = zipWith (\psi theta -> -(psi + theta)) turnAngles (tail thetas)
    pathjoins =
      zipWith6 unmetaJoin points (tail points) thetas phis tensionsA tensionsB
-- | Replace zero-vector 'Direction' conditions with 'Open' on both sides of
-- every join.  'Controls' joins are left untouched.
removeEmptyDirs :: [(DPoint, MetaJoin Double)] -> [(DPoint, MetaJoin Double)]
removeEmptyDirs = map clearNode
  where
    clearNode (p, MetaJoin l tl tr r) =
      (p, MetaJoin (openIfZero l) tl tr (openIfZero r))
    clearNode node = node
    openIfZero (Direction (Point 0 0)) = Open
    openIfZero t = t
-- | True when the first node's join is open on both sides and does not
-- have zero length (its point differs from the next node's point).  A
-- single remaining open node counts as open.
bothOpen :: [(DPoint, MetaJoin Double)] -> Bool
bothOpen nodes = case nodes of
  (p, MetaJoin Open _ _ Open) : rest -> case rest of
    [] -> True
    (q, _) : _ -> p /= q
  _ -> False
-- | True when the first node's join is open on its left (start) side and
-- does not have zero length.  A single remaining left-open node counts as
-- open.
leftOpen :: [(DPoint, MetaJoin Double)] -> Bool
leftOpen nodes = case nodes of
  (p, MetaJoin Open _ _ _) : rest -> case rest of
    [] -> True
    (q, _) : _ -> p /= q
  _ -> False
-- | Sanitise the nodes of a closed path that is being solved as an open
-- one.  The nodes are sanitised as part of an endlessly repeating
-- sequence that starts one node early, so the first node sees its
-- predecessor.  Exactly one period of the result is kept.
sanitizeCycle :: [(DPoint, MetaJoin Double)] -> [(DPoint, MetaJoin Double)]
sanitizeCycle [] = []
sanitizeCycle nodes@((start, _) : _) =
  take len . drop 1 $ sanitize (drop (len - 1) (cycle nodes)) start
  where
    len = length nodes
-- | Normalise the conditions of an open sequence of meta nodes ending at
-- the point @r@, so that every sub-segment has usable end conditions:
--
-- * an open condition at the very end becomes @Curl 1@, and a zero-length
--   final join (its point equals @r@) becomes @Controls p p@;
-- * an open end right before a zero-length join whose left side is open
--   becomes @Curl 1@;
-- * a zero-length join becomes @Controls p p@, and its right-hand condition
--   is carried to the next join when that side is open;
-- * a condition on one side of a join boundary is copied to the open side
--   of the neighbouring join;
-- * next to explicit 'Controls', an open side gets the direction from the
--   shared node towards/from the adjacent control point, or @Curl 1@ if
--   that control point coincides with the node.
sanitize :: [(DPoint, MetaJoin Double)] -> DPoint -> [(DPoint, MetaJoin Double)]
sanitize [] _ = []
sanitize [(p, MetaJoin m t1 t2 Open)] r =
  if p == r
  then [(p, Controls p p)]
  else [(p, MetaJoin m t1 t2 (Curl 1))]
sanitize ((p, MetaJoin m1 tl tr Open): rest@(node2:node3:_)) r
  | fst node2 == fst node3 && leftIsOpen (snd node2) =
    (p, MetaJoin m1 tl tr (Curl 1)) : sanitize rest r
  where
    -- Pattern match instead of the partial selector 'metaTypeL', which
    -- crashed when node2 carries explicit 'Controls'.  Such nodes now fall
    -- through to the 'Controls' handling below.
    leftIsOpen (MetaJoin Open _ _ _) = True
    leftIsOpen _ = False
sanitize (node1@(p, MetaJoin m1 tl tr m2): node2@(q, MetaJoin n1 sl sr n2): rest) r
  | p == q =
    -- zero-length join: replace it, moving its right condition to q
    let newnode = (p, Controls p p)
    in case (m2, n1) of
      (Curl g, Open) ->
        newnode : sanitize ((q, MetaJoin (Curl g) sl sr n2):rest) r
      (Direction dir, Open) ->
        newnode : sanitize ((q, MetaJoin (Direction dir) sl sr n2) : rest) r
      (Open, Open) ->
        newnode : sanitize ((q, MetaJoin (Curl 1) sl sr n2) : rest) r
      _ -> newnode : sanitize (node2:rest) r
  | otherwise =
    -- copy a condition across the boundary at q to the open side
    case (m2, n1) of
      (Curl g, Open) ->
        node1 : sanitize ((q, MetaJoin (Curl g) sl sr n2):rest) r
      (Open, Curl g) ->
        (p, MetaJoin m1 tl tr (Curl g)) : sanitize (node2:rest) r
      (Direction dir, Open) ->
        node1 : sanitize ((q, MetaJoin (Direction dir) sl sr n2) : rest) r
      (Open, Direction dir) ->
        (p, MetaJoin m1 tl tr (Direction dir)) : sanitize (node2:rest) r
      _ -> node1 : sanitize (node2:rest) r
sanitize ((p, m): (q, n): rest) r =
  case (m, n) of
    -- explicit controls ending at q: leave q in the direction q - v
    (Controls _u v, MetaJoin Open t1 t2 mt2)
      | q == v    -> (p, m) : sanitize ((q, MetaJoin (Curl 1) t1 t2 mt2): rest) r
      | otherwise -> (p, m) : sanitize ((q, MetaJoin (Direction (q^-^v)) t1 t2 mt2): rest) r
    -- explicit controls starting at q: arrive at q in the direction u - q.
    -- (Previously u - p was used, which is the direction from the previous
    -- point and produced a kink at q.)
    (MetaJoin mt1 tl tr Open, Controls u _v)
      | u == q    -> (p, MetaJoin mt1 tl tr (Curl 1)) : sanitize ((q, n): rest) r
      | otherwise -> (p, MetaJoin mt1 tl tr (Direction (u^-^q))) : sanitize ((q, n): rest) r
    _ -> (p, m) : sanitize ((q, n) : rest) r
sanitize (n:l) r = n:sanitize l r
-- | Like 'span', but the predicate sees the whole remaining list rather
-- than a single element, so it can look ahead.  The prefix is produced
-- lazily.
spanList :: ([a] -> Bool) -> [a] -> ([a], [a])
spanList p = go
  where
    go [] = ([], [])
    go whole@(x : rest)
      | p whole = (x : fst next, snd next)
      | otherwise = ([], whole)
      where
        next = go rest
-- | True unless the first join is open on its right side and the next
-- join is open on its left side, i.e. unless the direction at the shared
-- point must be solved together with the following join.
breakPoint :: [(DPoint, MetaJoin Double)] -> Bool
breakPoint nodes = case nodes of
  (_, MetaJoin _ _ _ Open) : (_, MetaJoin Open _ _ _) : _ -> False
  _ -> True
-- | List front end for 'solveTriDiagonal'.  The first row's leading
-- coefficient is dropped; the remaining rows are passed as a vector.
solveTriDiagonal2 :: [(Double, Double, Double, Double)] -> [Double]
solveTriDiagonal2 eqs = case eqs of
  [] -> error "solveTriDiagonal: not enough equations"
  (_, b0, c0, d0) : rows ->
    V.toList (solveTriDiagonal (b0, c0, d0) (V.fromList rows))
-- | List front end for 'solveCyclicTriD'.
solveCyclicTriD2 :: [(Double, Double, Double, Double)] -> [Double]
solveCyclicTriD2 eqs = V.toList (solveCyclicTriD (V.fromList eqs))
-- | Angle of the second vector measured relative to the first one,
-- obtained by rotating it with the conjugate of the first vector.
-- NOTE(review): relies on the 'rotateVec' / '$*' semantics from "Geom2D".
-- Returns 0 when the first vector is zero.
turnAngle :: DPoint -> DPoint -> Double
turnAngle (Point 0 0) _ = 0
turnAngle (Point x y) q = vectorAngle (rotateVec conjugate $* q)
  where
    conjugate = Point x (negate y)
-- | Pair every element with its successor, wrapping around so that the
-- last element is paired with the first.
zipNext :: [b] -> [(b, b)]
zipNext [] = []
zipNext xs@(_ : rest) = zip xs (rest ++ xs)
-- | Tension equations for every point of a cyclic path.  Row @i@ combines
-- the tensions, chord lengths and turn angles at index @i@ with those at
-- index @i+1@, wrapping around at the end.
eqsCycle :: [Tension Double] -> [DPoint] -> [Tension Double]
         -> [Double] -> [(Double, Double, Double, Double)]
eqsCycle tensionsA points tensionsB turnAngles =
  zipWith4 eqTension alphaPairs distPairs turnPairs betaPairs
  where
    alphaPairs = zipNext (map tensionValue tensionsA)
    betaPairs = zipNext (map tensionValue tensionsB)
    turnPairs = zipNext turnAngles
    -- chord lengths between consecutive points, cyclically
    distPairs = zipNext (zipWith vectorDistance points (tail (cycle points)))
-- | Build the equations for the angles of an open sub-segment with n+1
-- @points@ and n @joins@ (as assembled by 'unmetaSubSegment').
--
-- Each row is @(a, b, c, d)@ as consumed by 'solveTriDiagonal2', which
-- ignores @a@ of the first row.  @chords@ are the differences between
-- consecutive points.  @turnAngles !! i@ is the turn at @points !! (i+1)@;
-- 'unmetaSubSegment' appends a trailing 0.  @tensionsA@ / @tensionsB@ are
-- the values of 'tensionL' / 'tensionR' of each join.
eqsOpen :: [DPoint] -> [MetaJoin Double] -> [DPoint] -> [Double]
        -> [Double] -> [Double] -> [(Double, Double, Double, Double)]
-- A single join: both rows come from the end conditions; an open end is
-- treated as curl 1.
eqsOpen _ [MetaJoin mt1 t1 t2 mt2] [delta] _ _ _ =
  let replaceType Open = Curl 1
      replaceType t = t
  in case (replaceType mt1, replaceType mt2) of
    (Curl g, Direction dir) ->
      [eqCurl0 g (tensionValue t1) (tensionValue t2) 0,
       (0, 1, 0, turnAngle delta dir)]
    (Direction dir, Curl g) ->
      [(0, 1, 0, turnAngle delta dir),
       eqCurlN g (tensionValue t1) (tensionValue t2)]
    (Direction dir, Direction dir2) ->
      [(0, 1, 0, turnAngle delta dir),
       (0, 1, 0, turnAngle delta dir2)]
    -- curl at both ends: both rows fix the angle to 0
    (Curl _, Curl _) ->
      [(0, 1, 0, 0), (0, 1, 0, 0)]
    -- unreachable: replaceType has removed every Open
    _ -> error "illegal end of open path"
-- General case: a start condition, one tension equation per interior
-- point, and an end condition.
eqsOpen points joins chords turnAngles tensionsA tensionsB =
  eq0 : restEquations joins tensionsA dists turnAngles tensionsB
  where
    dists = zipWith vectorDistance points (tail points)
    -- condition at the first point, taken from the left side of the first join
    eq0 = case head joins of
      (MetaJoin (Curl g) _ _ _) -> eqCurl0 g (head tensionsA) (head tensionsB) (head turnAngles)
      (MetaJoin (Direction dir) _ _ _) -> (0, 1, 0, turnAngle (head chords) dir)
      (MetaJoin Open _ _ _) -> eqCurl0 1 (head tensionsA) (head tensionsB) (head turnAngles)
      (Controls _ _) -> error "eqsOpen: illegal join"
    -- the last join supplies the condition at the final point (its right side)
    restEquations [lastnode] (tensionA:_) _ _ (tensionB:_) =
      case lastnode of
        MetaJoin _ _ _ (Curl g) -> [eqCurlN g tensionA tensionB]
        MetaJoin _ _ _ Open -> [eqCurlN 1 tensionA tensionB]
        MetaJoin _ _ _ (Direction dir) -> [(0, 1, 0, turnAngle (last chords) dir)]
        (Controls _ _) -> error "eqsOpen: illegal join"
    -- interior point: tension equation from the two joins meeting there
    restEquations (_:othernodes) (tensionA:restTA) (d:restD) (turn:restTurn) (tensionB:restTB) =
      eqTension (tensionA, head restTA) (d, head restD) (turn, head restTurn) (tensionB, head restTB) :
      restEquations othernodes restTA restD restTurn restTB
    restEquations _ _ _ _ _ = error "eqsOpen: illegal rest equations"
-- | Tension equation at an interior point, as a row @(a, b, c, d)@.  Each
-- argument pairs the value for the incoming join with the value for the
-- outgoing join: tensions on the left sides, chord lengths, turn angles,
-- and tensions on the right sides.
eqTension :: (Double, Double) -> (Double, Double)
          -> (Double, Double) -> (Double, Double)
          -> (Double, Double, Double, Double)
eqTension (tA0, tA1) (d0, d1) (psi0, psi1) (tB0, tB1) =
  (ca, cb + cc, cd, negate (cb * psi0) - cd * psi1)
  where
    ca = tB0 * tB0 / (tA0 * d0)
    cb = (3 - 1/tA0) * tB0 * tB0 / d0
    cc = (3 - 1/tB1) * tA1 * tA1 / d1
    cd = tA1 * tA1 / (tB1 * d1)
-- | Curl condition at the first point of a sub-segment, as a row
-- @(0, b, c, d)@, for curl @gamma@, the two tensions of the first join,
-- and the turn angle @psi@ at the following point.
eqCurl0 :: Double -> Double -> Double -> Double
        -> (Double, Double, Double, Double)
eqCurl0 gamma tensionA tensionB psi = (0, diag, upper, negate (upper * psi))
  where
    ratio = gamma * tensionB * tensionB / (tensionA * tensionA)
    diag = ratio / tensionA + 3 - 1 / tensionB
    upper = (3 - 1 / tensionA) * ratio + 1 / tensionB
-- | Curl condition at the last point of a sub-segment, as a row
-- @(a, b, 0, 0)@, for curl @gamma@ and the two tensions of the last join.
eqCurlN :: Double -> Double -> Double
        -> (Double, Double, Double, Double)
eqCurlN gamma tensionA tensionB = (lower, diag, 0, 0)
  where
    ratio = gamma * tensionA * tensionA / (tensionB * tensionB)
    lower = (3 - 1 / tensionB) * ratio + 1 / tensionA
    diag = ratio / tensionB + 3 - 1 / tensionA
-- | Compute the explicit join from @z0@ to @z1@ given the solved angles.
--
-- The first control point is @z0@ plus @z1 - z0@ rotated by @theta@ and
-- scaled by the velocity @rr@.  The second is @z1@ minus @z1 - z0@ rotated
-- by @-phi@ and scaled by @ss@.  When both angles are below @1e-4@ in
-- absolute value, a straight 'JoinLine' is produced instead.
--
-- For a 'TensionAtLeast' side, the velocity is additionally capped so the
-- control point stays inside the triangle formed by the chord and the two
-- tangent directions.  As in METAFONT's @set_controls@, the cap applies
-- only when @theta@ and @phi@ turn to the same side and the sine built from
-- @|sin theta|@ and @|sin phi|@ is positive.  Without that check, a
-- non-positive @sin (theta + phi)@ produced a negative velocity (a control
-- point pointing backwards) or NaN.
unmetaJoin :: DPoint -> DPoint -> Double -> Double -> Tension Double
           -> Tension Double -> PathJoin Double
unmetaJoin !z0 !z1 !theta !phi !alpha !beta
  | abs phi < 1e-4 && abs theta < 1e-4 = JoinLine
  | otherwise = JoinCurve u v
  where Point dx dy = z1^-^z0
        st = sin theta
        sf = sin phi
        ct = cos theta
        cf = cos phi
        -- both angles turn to the same side
        sameSide = (st >= 0 && sf >= 0) || (st <= 0 && sf <= 0)
        -- equals sin (theta + phi) when both sines are >= 0 and
        -- -sin (theta + phi) when both are <= 0
        sine = abs st * cf + abs sf * ct
        mayBound = sameSide && sine > 0
        rr' = velocity st sf ct cf alpha
        ss' = velocity sf st cf ct beta
        rr = case alpha of
          TensionAtLeast _ | mayBound ->
            min rr' (abs sf / sine)
          _ -> rr'
        ss = case beta of
          TensionAtLeast _ | mayBound ->
            min ss' (abs st / sine)
          _ -> ss'
        -- z1 - z0 rotated by theta
        u = z0 ^+^ rr *^ Point (dx*ct - dy*st) (dy*ct + dx*st)
        -- z1 - z0 rotated by -phi
        v = z1 ^-^ ss *^ Point (dx*cf + dy*sf) (dy*cf - dx*sf)
-- | @3 (sqrt 5 - 1) / 2@: coefficient of @cos theta@ in the denominator of
-- 'velocity'.
constant1 :: Double
constant1 = 3 * (sqrt 5 - 1) / 2

-- | @3 (3 - sqrt 5) / 2@: coefficient of @cos phi@ in the denominator of
-- 'velocity'.
constant2 :: Double
constant2 = 3 * (3 - sqrt 5) / 2

-- | The square root of two, used in the numerator of 'velocity'.
sqrt2 :: Double
sqrt2 = sqrt 2
-- | Distance of a control point from its end point, as a fraction of the
-- chord length.  It is computed from the sines and cosines of the two end
-- angles, divided by the tension value, and capped at 4.
velocity :: Double -> Double
         -> Double -> Double -> Tension Double -> Double
velocity st sf ct cf t = min 4 (num / den)
  where
    num = 2 + sqrt2 * (st - sf/16)*(sf - st/16)*(ct - cf)
    den = (3 + constant1*ct + constant2*cf) * tensionValue t