module Geometry.Icosphere
  ( generateIndexed

  , icofaces
  , icopoints
  ) where

import RIO

import Geomancy.Vec3 (Vec3, vec3)
import Geomancy.Vec3 qualified as Vec3
import Data.Vector qualified as Vector
import Data.Vector.Mutable qualified as Mutable
import RIO.Vector.Partial ((!))
import RIO.Map qualified as Map
import RIO.Vector.Storable qualified as Storable
import Control.Monad.State.Strict (get, put, runState)
import Vulkan.NamedType ((:::))

import Geometry.Face (Face(..))

-- | Generate an indexed icosphere mesh by repeatedly subdividing the
-- icosahedron given by 'icofaces' / 'icopoints'.
--
-- Each subdivision level splits every triangle into four, inserting one new
-- point per edge. Midpoints are cached by the /undirected/ edge. The two
-- triangles sharing an edge (which traverse it in opposite directions) reuse
-- the same point index, so the final point count is exactly @10 * 4^n + 2@.
--
-- * @subdivisions@: number of subdivision levels (0 yields the bare
--   icosahedron).
-- * @initial@: builds the attribute of each of the 12 base points from its
--   position.
-- * @midpoint@: builds the attribute of a new midpoint. It receives:
--
--     * the scale @curLevel / subdivisions@, which is 1 on the first
--       level and @1 / subdivisions@ on the last;
--     * the midpoint position (@Vec3.lerp 0.5@ of the two parents, not
--       projected onto a sphere here);
--     * the two parent attributes, in the order the edge was first
--       encountered.
--
-- * @vertex@: turns the final points and faces into vertex data.
--
-- Returns storable position and attribute vectors plus a flat 'Word32' index
-- list (three indices per face).
generateIndexed
  :: ( Fractional scale
     , Storable pos
     , Storable vertexAttr
     )
  => "subdivisions"  ::: Natural
  -> "initial"       ::: (Vec3 -> pointAttr)
  -> "midpoint"      ::: (scale -> Vec3 -> pointAttr -> pointAttr -> pointAttr)
  -> "vertex"        ::: (Vector (Vec3, pointAttr) -> [Face Int] -> Vector (pos, vertexAttr))
  -> "model vectors" ::: (Storable.Vector pos, Storable.Vector vertexAttr, Storable.Vector Word32)
generateIndexed details mkInitialAttrs mkMidpointAttrs mkVertices =
  ( Storable.convert pv
  , Storable.convert av
  , Storable.fromList iv
  )
  where
    (pv, av) = Vector.unzip $ mkVertices finalPoints faces

    -- Flattened triangle list: three point indices per face.
    iv :: [Word32]
    iv = do
      face <- faces
      vert <- toList face
      pure $ fromIntegral vert

    -- State: (edge -> midpoint index cache, point buffer, points written so far).
    (faces, (_midpoints, pointBuffer, finalPointsCount)) =
      runState
        (go icofaces details details)
        ( mempty
        , initialPoints
        , Vector.length icopoints
        )

    -- The buffer is allocated with 'Mutable.new', whose unwritten boxed
    -- slots are error thunks. Only hand the written prefix to the caller.
    finalPoints = Vector.take finalPointsCount pointBuffer

    -- Exact point count after n levels: each level adds one point per edge.
    -- For F = 20 * 4^n faces and E = 30 * 4^n edges, Euler gives
    -- V = 2 + E - F = 10 * 4^n + 2.
    maxPoints :: Int
    maxPoints = 10 * 4 ^ details + 2

    -- Pre-sized point buffer seeded with the 12 base icosahedron points.
    initialPoints = Vector.create do
      v <- Mutable.new maxPoints
      Vector.imapM_
        (Mutable.unsafeWrite v)
        (Vector.map (id &&& mkInitialAttrs) icopoints)
      pure v

    -- Subdivide every face of the current level, counting curLevel down to 0.
    go curFaces maxLevel curLevel =
      case curLevel of
        0 ->
          pure curFaces
        _ -> do
          let scale = fromIntegral curLevel / fromIntegral maxLevel
          next <- traverse (subdivideFace scale) curFaces
          go (mconcat next) maxLevel (curLevel - 1)

    -- Split one triangle into four, creating (or reusing) its edge midpoints.
    subdivideFace scale (Face a b c) = do
      (mids, points, numPoints) <- get
      let
        (midsAB, extrasAB, ab) = midpoint scale mids   mempty   points numPoints (a, b)
        (midsBC, extrasBC, bc) = midpoint scale midsAB extrasAB points numPoints (b, c)
        (midsCA, extrasCA, ca) = midpoint scale midsBC extrasBC points numPoints (c, a)

      put
        ( midsCA
          -- In-place append. Writes only touch slots >= numPoints, which no
          -- earlier snapshot of the buffer is ever asked to read.
        , runST do
            buffer <- Vector.unsafeThaw points
            Vector.imapM_
              (\i point -> Mutable.unsafeWrite buffer (numPoints + i) point)
              extrasCA
            Vector.unsafeFreeze buffer
        , numPoints + Vector.length extrasCA
        )

      pure
        [ Face ab bc ca
        , Face ca a ab
        , Face ab b bc
        , Face bc c ca
        ]

    -- Look up or create the midpoint of edge (i, j).
    -- New points are queued in 'extras' and numbered after 'numPoints'.
    midpoint scale mids extras points numPoints (i, j) =
      case Map.lookup edge mids of
        Just knownIx ->
          (mids, extras, knownIx)
        Nothing ->
          let
            (pos1, attr1) = points ! i
            (pos2, attr2) = points ! j
            midPos = Vec3.lerp 0.5 pos1 pos2
            newIx = numPoints + Vector.length extras
            point = (midPos, mkMidpointAttrs scale midPos attr1 attr2)
          in
            ( Map.insert edge newIx mids
            , Vector.snoc extras point
            , newIx
            )
      where
        -- Adjacent faces traverse a shared edge in opposite directions,
        -- so the cache key must not depend on orientation.
        edge = (min i j, max i j)

-- | The 20 triangles of the base icosahedron, as indices into 'icopoints'.
--
-- All faces share a consistent winding. As a consequence, every edge is
-- traversed once in each direction by its two adjacent faces.
icofaces :: [Face Int]
icofaces =
  [ -- 5 faces around point 0
    Face  5 11  0
  , Face  1  5  0
  , Face  7  1  0
  , Face 10  7  0
  , Face 11 10  0

    -- 5 faces adjacent to the ring around point 0
  , Face  9  5  1
  , Face  4 11  5
  , Face  2 10 11
  , Face  6  7 10
  , Face  8  1  7

    -- 5 faces around point 3
  , Face  4  9  3
  , Face  2  4  3
  , Face  6  2  3
  , Face  8  6  3
  , Face  9  8  3

    -- 5 faces adjacent to the ring around point 3
  , Face  5  9  4
  , Face 11  4  2
  , Face 10  2  6
  , Face  7  6  8
  , Face  1  8  9
  ]

-- | The 12 vertices of a regular icosahedron: the cyclic permutations of
-- @(0, ±t, ±1)@, where @t@ is the golden ratio.
--
-- The points are not normalised. Each lies at distance @sqrt (1 + t^2)@
-- from the origin.
icopoints :: Vector Vec3
icopoints = Vector.fromList
  [ vec3 (-1)   0    t
  , vec3   1    0    t
  , vec3 (-1)   0  (-t)
  , vec3   1    0  (-t)

  , vec3   0  (-t) (-1)
  , vec3   0  (-t)   1
  , vec3   0    t  (-1)
  , vec3   0    t    1

  , vec3   t    1    0
  , vec3   t  (-1)   0
  , vec3 (-t)   1    0
  , vec3 (-t) (-1)   0
  ]
  where
    -- Golden ratio.
    t :: Float
    t = (1.0 + sqrt 5.0) / 2.0