module Render.Unlit.TileMap.Code
  ( vert
  , frag
  ) where

import RIO

import Render.Code (Code, glsl)
import Render.Samplers qualified as Samplers
import Render.DescSets.Set0.Code (set0binding0, set0binding1, set0binding2)

-- | Vertex shader for the unlit tile-map pipeline.
--
-- Transforms each vertex by the per-instance @iModel@ matrix and then by
-- @scene.projection * scene.view@. @scene@ is expected to be declared by
-- the spliced 'set0binding0' (TODO confirm against "Render.DescSets.Set0.Code").
--
-- It also derives two coordinates for the fragment stage:
--
--   * @fPixCoord@: the vertex texture coordinate scaled by @iViewportSize@
--     and shifted by @iViewOffset@.
--   * @fTexCoord@: @fPixCoord@ divided by @iMapTextureSize@ and
--     @iTileSize@, used to sample the map texture.
--
-- The per-instance tile-map parameters (@iTextureIds@,
-- @iTilesetTextureSize@, @iTileSize@) are forwarded unchanged as @flat@
-- outputs.
vert :: Code
vert = fromString
  [glsl|
    #version 450

    ${set0binding0}

    // vertexPos
    layout(location = 0) in vec3 vPosition;
    // vertexAttrs
    layout(location = 1) in vec2 vTexCoord;
    // tilemapParams
    layout(location = 2) in ivec4 iTextureIds; // combined: tileset, tileset sampler, map, repeat
    layout(location = 3) in vec2  iViewOffset;
    layout(location = 4) in vec2  iViewportSize;
    layout(location = 5) in vec2  iMapTextureSize;
    layout(location = 6) in vec2  iTilesetTextureSize;
    layout(location = 7) in vec2  iTileSize;

    // transformMat
    layout(location = 8) in mat4 iModel;

    layout(location = 0)      out  vec2 fTexCoord;
    layout(location = 1)      out  vec2 fPixCoord;
    layout(location = 2) flat out ivec4 fTextureIds;
    layout(location = 3) flat out  vec2 fTilesetTextureSize;
    layout(location = 4) flat out  vec2 fTileSize;

    void main() {
      vec4 fPosition = iModel * vec4(vPosition, 1.0);

      gl_Position
        = scene.projection
        * scene.view
        * fPosition;

      fPixCoord = (vTexCoord * iViewportSize) + iViewOffset;
      fTexCoord = fPixCoord / iMapTextureSize / iTileSize;

      fTextureIds = iTextureIds;
      fTilesetTextureSize = iTilesetTextureSize;
      fTileSize = iTileSize;
    }
  |]

-- | Fragment shader for the unlit tile-map pipeline.
--
-- The shader works in three steps:
--
--   1. When the repeat flag (@fTextureIds[3]@) is 0, it discards any
--      fragment whose map UV lies outside @[0, 1]@.
--   2. It samples the map texture (@fTextureIds[2]@) with the sampler at
--      index 'samplerId'.
--   3. It scales the map texel's @.xy@ by 256 and floors it to get a tile
--      position inside the tileset. It then adds the pixel position within
--      the current tile and samples the tileset texture (@fTextureIds[0]@)
--      with the sampler chosen per instance (@fTextureIds[1]@).
--
-- The @textures@ and @samplers@ arrays are presumably declared by the
-- spliced 'set0binding1' / 'set0binding2' (TODO confirm).
-- NOTE(review): the @* 256.0@ decode looks like it assumes an 8-bit UNORM
-- map texture. Verify against the map texture's format.
frag :: Code
frag = fromString
  [glsl|
    #version 450
    #extension GL_EXT_nonuniform_qualifier : enable

    ${set0binding1}
    ${set0binding2}

    layout(location = 0) in vec2 fTexCoord;
    layout(location = 1) in vec2 fPixCoord;

    // combined: tileset, tileset sampler, map, repeat
    layout(location = 2) flat in ivec4 fTextureIds;
    layout(location = 3) flat in vec2 fTilesetTextureSize;
    layout(location = 4) flat in vec2 fTileSize;

    layout(location = 0) out vec4 oColor;

    int tilesetTextureIx = fTextureIds[0];
    int tilesetSamplerIx = fTextureIds[1];
    int mapTextureIx     = fTextureIds[2];
    int repeatTiles      = fTextureIds[3];

    // TODO
    // const vec4 fTextureGamma = vec4(1.0);

    void main() {
      if (repeatTiles == 0 && (fTexCoord.x < 0.0 || fTexCoord.y < 0.0 || fTexCoord.x > 1.0 || fTexCoord.y > 1.0)) {
        discard;
      }

      vec4 map = texture(
        sampler2D(
          textures[nonuniformEXT(mapTextureIx)],
          samplers[$samplerId]
        ),
        fTexCoord
      );

      vec2 spriteOffset = floor(map.xy * 256.0) * fTileSize;
      vec2 spriteCoord = mod(fPixCoord, fTileSize);
      vec2 spriteUV = round(spriteOffset + spriteCoord) / fTilesetTextureSize;

      oColor = texture(
        sampler2D(
          textures[nonuniformEXT(tilesetTextureIx)],
          samplers[nonuniformEXT(tilesetSamplerIx)]
        ),
        spriteUV
      );
    }
  |]
  where
    -- Index of the @nearest@ entry in 'Samplers.indices', spliced into the
    -- shader as the map-texture sampler. Presumably this is point
    -- filtering, so that encoded tile indices are not interpolated
    -- between texels (TODO confirm in "Render.Samplers").
    samplerId :: Int32
    samplerId = Samplers.nearest Samplers.indices