module Codec.Ktx where

import Data.Binary (Binary(..), decodeFileOrFail)
import Data.Binary.Get (Get, ByteOffset, getWord32le, getWord32be, getByteString, skip)
import Data.Binary.Put (Put, putByteString, putWord32le)
import Data.ByteString (ByteString)
import Data.Coerce (coerce)
import Data.Foldable (for_)
import Data.Map.Strict (Map)
import Data.Text (Text)
import Data.Vector (Vector)
import Data.Word (Word32)
import GHC.Generics (Generic)

import qualified Data.Text.Encoding as Text
import qualified Data.Map.Strict as Map
import qualified Data.Vector as Vector
import qualified Data.ByteString as BS

-- | Read a file from disk and decode it as a KTX container.
--
-- On failure the result carries the byte offset at which decoding stopped
-- together with the decoder's error message.
fromFile :: FilePath -> IO (Either (ByteOffset, String) Ktx)
fromFile path = decodeFileOrFail path

-- | A whole KTX container: the fixed-size header, the key/value metadata
-- section and the image data, in the order they are stored.
data Ktx = Ktx
  { header :: Header       -- ^ Fixed-size file header.
  , kvs    :: KeyValueData -- ^ Decoded key/value metadata.
  , images :: MipLevels    -- ^ Image data, one entry per mip level.
  } deriving (Show, Generic)

-- | The header is decoded first because both following sections are read
-- using its fields. Encoding writes the header exactly as given; it does not
-- recompute 'bytesOfKeyValueData' from the key/value map.
instance Binary Ktx where
  get = do
    hdr <- get
    Ktx hdr <$> getKeyValueData hdr <*> getImages hdr

  put (Ktx hdr keyValues mipLevels) =
    put hdr *> putKeyValueData keyValues *> putImages mipLevels

-- * Header

-- | The fixed-size header at the start of a KTX 1.1 file.
--
-- Fields are listed in the order they are stored (see the 'Binary'
-- instance). Numeric fields hold decoded values, whichever byte order the
-- file used. The gl* and pixelWidth/pixelHeight fields are carried through
-- unchanged and are not interpreted by the code here.
data Header = Header
  { identifier            :: ByteString -- ^ 12-byte signature; decoding fails unless it equals 'canonicalIdentifier'.
  , endianness            :: Word32     -- ^ Raw endianness marker, read little-endian; when it equals 'endiannessLE' the other fields are read little-endian.
  , glType                :: Word32
  , glTypeSize            :: Word32
  , glFormat              :: Word32
  , glInternalFormat      :: Word32
  , glBaseInternalFormat  :: Word32
  , pixelWidth            :: Word32
  , pixelHeight           :: Word32
  , pixelDepth            :: Word32     -- ^ 'getImages' treats 0 as a single z-slice.
  , numberOfArrayElements :: Word32     -- ^ 'getImages' treats 0 as a single array element.
  , numberOfFaces         :: Word32     -- ^ A value of 6 changes how 'getImages' interprets each level's imageSize.
  , numberOfMipmapLevels  :: Word32     -- ^ 'getImages' treats 0 as a single mip level.
  , bytesOfKeyValueData   :: Word32     -- ^ Byte length of the key/value section; 'getKeyValueData' reads until it is used up.
  } deriving (Show, Generic)

instance Binary Header where
  -- Decoding checks the signature and the endianness marker, then reads the
  -- remaining twelve fields in the byte order the marker selects.
  get = do
    identifier <- getByteString 12
    if identifier == canonicalIdentifier then
      pure ()
    else
      fail $ "KTX identifier mismatch: " <> show identifier

    -- The marker is read little-endian: a little-endian file yields
    -- 'endiannessLE' and a big-endian file yields its byte-swapped form.
    -- Any other value means the file is corrupt, so reject it instead of
    -- guessing a byte order.
    endianness <- getWord32le
    let
      endiannessBE = 0x01020304 :: Word32
    getNext <-
      if endianness == endiannessLE then
        pure getWord32le
      else if endianness == endiannessBE then
        pure getWord32be
      else
        fail $ "KTX endianness marker not recognized: " <> show endianness

    glType                <- getNext
    glTypeSize            <- getNext
    glFormat              <- getNext
    glInternalFormat      <- getNext
    glBaseInternalFormat  <- getNext
    pixelWidth            <- getNext
    pixelHeight           <- getNext
    pixelDepth            <- getNext
    numberOfArrayElements <- getNext
    numberOfFaces         <- getNext
    numberOfMipmapLevels  <- getNext
    bytesOfKeyValueData   <- getNext

    pure Header{..}

  -- Encoding always produces little-endian output. The marker is written as
  -- 'endiannessLE' rather than the stored 'endianness'. Otherwise a header
  -- decoded from a big-endian file would be re-encoded as claiming
  -- big-endian while its fields are written little-endian.
  put Header{..} = do
    putByteString identifier
    putWord32le endiannessLE
    putWord32le glType
    putWord32le glTypeSize
    putWord32le glFormat
    putWord32le glInternalFormat
    putWord32le glBaseInternalFormat
    putWord32le pixelWidth
    putWord32le pixelHeight
    putWord32le pixelDepth
    putWord32le numberOfArrayElements
    putWord32le numberOfFaces
    putWord32le numberOfMipmapLevels
    putWord32le bytesOfKeyValueData

-- | The endianness marker as it reads when decoded little-endian from a file
-- whose fields are stored little-endian.
endiannessLE :: Word32
endiannessLE = 0x04030201

-- | The 12-byte signature every KTX 1.1 file starts with: the byte 0xAB,
-- the ASCII text "KTX 11", the byte 0xBB, then CR, LF, 0x1A and LF.
canonicalIdentifier :: ByteString
canonicalIdentifier =
  BS.pack $
    [0xAB] ++ map (fromIntegral . fromEnum) "KTX 11" ++ [0xBB, 0x0D, 0x0A, 0x1A, 0x0A]

-- * Key-value data

-- | Key/value metadata from the section after the header, indexed by key.
type KeyValueData = Map Key Value

-- | A metadata key; stored UTF-8 encoded in the file and decoded to 'Text'.
newtype Key = Key Text
  deriving (Eq, Ord, Show, Generic)

-- | The raw bytes stored for a 'Key'; not interpreted by this module.
newtype Value = Value ByteString
  deriving (Show, Generic)

-- | Decode the key/value section that follows the header.
--
-- The header's 'bytesOfKeyValueData' is the total length of the section. That
-- length counts each entry's 4-byte size field, its key-and-value bytes and
-- its zero padding up to a 4-byte boundary. Each entry holds a UTF-8 key, a
-- NUL separator and the value bytes. The separator is not included in the
-- returned 'Value'. A key with no NUL terminator is taken to span the whole
-- entry, with an empty value. If a key repeats, its first occurrence in the
-- file wins.
--
-- Fails if fewer than 4 bytes remain where an entry is expected, if an entry
-- overruns the section, or if a key is not valid UTF-8.
getKeyValueData :: Header -> Get KeyValueData
getKeyValueData Header{..} = Map.fromList <$> go bytesOfKeyValueData []
  where
    -- 'acc' is built newest-first. 'Map.fromList' keeps the last duplicate
    -- in the list, which is the earliest one in the file.
    go remains acc
      | remains == 0 =
          pure acc

      | remains < 4 =
          fail $ "KTX key/value data: " <> show remains <> " stray bytes at end of section"

      | otherwise = do
          keyAndValueByteSize <- getSize
          let
            afterSize = remains - 4
            valuePadding = (4 - keyAndValueByteSize `rem` 4) `rem` 4

          -- Compare before subtracting: the counters are unsigned and would
          -- wrap around instead of going negative.
          if keyAndValueByteSize > afterSize
            || valuePadding > afterSize - keyAndValueByteSize then
            fail $ "KTX key/value entry of " <> show keyAndValueByteSize
              <> " bytes overruns the key/value section"
          else do
            keyAndValue <- getByteString (fromIntegral keyAndValueByteSize)
            skip (fromIntegral valuePadding)

            let (keyBytes, separatorAndValue) = BS.break (== 0x00) keyAndValue
            key <- either
              (fail . ("KTX key is not valid UTF-8: " <>) . show)
              pure
              (Text.decodeUtf8' keyBytes)

            go (afterSize - keyAndValueByteSize - valuePadding) $
              ( Key key
              , Value (BS.drop 1 separatorAndValue)
              ) : acc

    -- Size fields follow the byte order chosen by the header's marker.
    getSize =
      if endianness == endiannessLE then
        getWord32le
      else
        getWord32be

-- | Encode key/value metadata as a little-endian key/value section.
--
-- Each entry is written as:
--
-- * its 4-byte size;
-- * the UTF-8 key, a NUL separator and the value bytes;
-- * zero padding up to a 4-byte boundary, as the reader expects.
--
-- Entries are written in ascending key order. The header's
-- 'bytesOfKeyValueData' is not touched here and must match the number of
-- bytes this writes.
putKeyValueData :: Map Key Value -> Put
putKeyValueData kvs =
  for_ (Map.toList kvs) \(Key key, Value value) -> do
    let
      keyAndValue = Text.encodeUtf8 key <> BS.singleton 0x00 <> value
      keyAndValueByteSize = BS.length keyAndValue
      valuePadding = (4 - keyAndValueByteSize `rem` 4) `rem` 4
    putWord32le (fromIntegral keyAndValueByteSize)
    putByteString keyAndValue
    putByteString (BS.replicate valuePadding 0x00)

-- * Images

-- | Image data for every mip level, in the order they are stored.
type MipLevels = Vector MipLevel

-- | One mip level: its stored size field plus its pixel data, nested as
-- array elements, then faces, then z-slices.
data MipLevel = MipLevel
  { imageSize     :: Word32               -- ^ The level's imageSize field exactly as read from the file.
  , arrayElements :: Vector ArrayElement  -- ^ Pixel data, one entry per array element.
  }
  deriving (Show, Generic)

-- | The faces of one array element.
newtype ArrayElement = ArrayElement
  { faces :: Vector Face
  }
  deriving (Show, Generic)

-- | The z-slices of one face.
newtype Face = Face
  { zSlices :: Vector ZSlice
  }
  deriving (Show, Generic)

-- | The raw bytes of one z-slice. 'getImages' gives every slice of a level
-- the same length. The 'Show' instance prints an abbreviated form.
newtype ZSlice = ZSlice
  { block :: ByteString
  }
  deriving (Generic)

-- | Abbreviated rendering: the slice's byte length, then at most its first
-- 32 bytes, so large textures stay readable when shown.
instance Show ZSlice where
  show (ZSlice bytes) =
    "ZSlice (" <> show (BS.length bytes) <> ") " <> show (BS.take 32 bytes)

-- | Decode the image data for every mip level described by the header.
--
-- Zero counts for mip levels, array elements and depth are treated as 1.
-- Each level starts with its imageSize field, read in the header's byte
-- order. How that size is used depends on the texture:
--
-- * Non-array cubemap (6 faces, numberOfArrayElements 0): imageSize is the
--   size of one face. Each face is followed by zero padding to a 4-byte
--   boundary.
-- * Everything else: imageSize is the size of the whole level. It is split
--   evenly into z-slices, and the level is followed by padding to a 4-byte
--   boundary.
--
-- Fails if the header declares zero faces.
getImages :: Header -> Get MipLevels
getImages Header{..}
  | numberOfFaces == 0 =
      fail "KTX numberOfFaces is 0; expected at least one face"
  | otherwise =
      replicateV numberOfMipmapLevels' do
        imageSize <- getImageSize

        let
          sliceSize = fromIntegral $
            if isNonArrayCubemap then
              imageSize `div` pixelDepth'
            else
              imageSize
                `div` numberOfArrayElements'
                `div` numberOfFaces
                `div` pixelDepth'

          cubePadding
            | isNonArrayCubemap = paddingTo4 imageSize
            | otherwise         = 0

          -- Per-face padding already keeps a non-array cubemap level aligned.
          mipPadding
            | isNonArrayCubemap = 0
            | otherwise         = paddingTo4 imageSize

        elements <- replicateV numberOfArrayElements' $
          replicateV numberOfFaces do
            slices <- replicateV pixelDepth' $
              ZSlice <$> getByteString sliceSize
            skip cubePadding
            pure slices

        skip mipPadding

        pure MipLevel
          { imageSize     = imageSize
          , arrayElements = coerce elements
          }

  where
    -- Run an action a header-given number of times, collecting the results.
    replicateV n = Vector.replicateM (fromIntegral n)

    -- Zero bytes needed to round a size up to a multiple of 4.
    paddingTo4 :: Word32 -> Int
    paddingTo4 size = fromIntegral ((4 - size `rem` 4) `rem` 4)

    -- Only this layout stores a per-face imageSize; array cubemaps store the
    -- size of the whole level like every other texture.
    isNonArrayCubemap = numberOfFaces == 6 && numberOfArrayElements == 0

    numberOfMipmapLevels'
      | numberOfMipmapLevels == 0 = 1
      | otherwise                 = numberOfMipmapLevels

    numberOfArrayElements'
      | numberOfArrayElements == 0 = 1
      | otherwise                  = numberOfArrayElements

    pixelDepth'
      | pixelDepth == 0 = 1
      | otherwise       = pixelDepth

    getImageSize =
      if endianness == endiannessLE then
        getWord32le
      else
        getWord32be

-- | Encode the image data of every mip level, little-endian.
--
-- Each level is written as:
--
-- * its stored 'imageSize';
-- * the bytes of every z-slice, walked array elements, then faces, then
--   slices;
-- * zero padding so the next level starts on a 4-byte boundary.
--
-- NOTE(review): per-face padding for non-array cubemaps is not emitted,
-- because the header needed to recognise that layout is not available here.
-- That padding is empty whenever a face's byte length is a multiple of 4.
putImages :: MipLevels -> Put
putImages mipLevels = Vector.forM_ mipLevels \MipLevel{..} -> do
  -- Written explicitly little-endian. The 'Binary' 'put' for 'Word32' is
  -- big-endian, which would disagree with every other field written here.
  putWord32le imageSize

  let
    slices =
      [ bytes
      | ArrayElement faceVector <- Vector.toList arrayElements
      , Face sliceVector <- Vector.toList faceVector
      , ZSlice bytes <- Vector.toList sliceVector
      ]
    written = sum (map BS.length slices)

  for_ slices putByteString
  putByteString (BS.replicate ((4 - written `rem` 4) `rem` 4) 0x00)