{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE NoForeignFunctionInterface #-}

module Resource.Texture.Ktx2
  ( load
  , loadBytes
  , loadKtx2
  ) where

import RIO

import Codec.Compression.Zstd.FFI qualified as Zstd
import Codec.Ktx.KeyValue qualified as KeyValue
import Codec.Ktx2.Header qualified as Header
import Codec.Ktx2.Level qualified as Level
import Codec.Ktx2.Read qualified as Read
import Data.Kind (Type)
import Data.Vector qualified as Vector
import Foreign qualified
import GHC.Stack (withFrozenCallStack)
import UnliftIO.Resource (MonadResource)
import Vulkan.Core10 qualified as Vk
import VulkanMemoryAllocator qualified as VMA

import Engine.Vulkan.Types (HasVulkan(..), MonadVulkan, Queues)
import Resource.Image qualified as Image
import Resource.Source (Source(..))
import Resource.Source qualified as Source
import Resource.Texture (Texture(..), TextureLayers(..))
import Resource.Texture qualified as Texture

-- | Load a KTX2 texture from an arbitrary 'Source'.
--
-- File sources are handed to 'loadFile'; every other source is first
-- resolved to a 'ByteString' by 'Source.load' and then decoded with
-- 'loadBytes'. The call stack is frozen so that callees see the call site
-- of 'load' itself.
load
  :: forall (a :: Type) env m
  .  ( TextureLayers a
     , MonadVulkan env m
     , MonadResource m
     , MonadThrow m
     , HasLogFunc env
     , Typeable a
     , HasCallStack
     )
  => Queues Vk.CommandPool
  -> Source
  -> m (Texture a)
load pool source = withFrozenCallStack
  case source of
    Source.File fileLabel filePath ->
      -- XXX: the codec has a more efficient loader for files
      loadFile fileLabel pool filePath
    _inMemory ->
      Source.load
        (\contents -> loadBytes source.label pool contents)
        source

-- | Load a KTX2 texture from a file on disk.
--
-- The file context is opened with 'Read.open' and closed with 'Read.close'
-- via 'bracket', so it is released even if decoding throws. When no label
-- is supplied, the file path is used as the label.
loadFile
  :: ( TextureLayers a
     , MonadVulkan env m
     , MonadResource m
     , MonadThrow m
     , HasLogFunc env
     )
  => Maybe Text
  -> Queues Vk.CommandPool
  -> FilePath
  -> m (Texture a)
loadFile label pool path =
  bracket (Read.open path) Read.close \context ->
    loadKtx2 labelOrPath pool context
  where
    -- An explicit label wins; otherwise fall back to the path.
    labelOrPath = label <|> Just (fromString path)

-- | Load a KTX2 texture from an in-memory 'ByteString'.
--
-- Wraps the bytes in a reader context with 'Read.bytes' and passes it on
-- to 'loadKtx2'.
loadBytes
  :: ( TextureLayers a
     , MonadVulkan env m
     , MonadResource m
     , MonadThrow m
     , HasLogFunc env
     )
  => Maybe Text
  -> Queues Vk.CommandPool
  -> ByteString
  -> m (Texture a)
loadBytes label pool bytes = do
  context <- Read.bytes bytes
  loadKtx2 label pool context

-- | Upload the contents of an opened KTX2 container into a texture.
--
-- Steps, in order:
--
-- 1. Log the header and the textual key/value metadata at debug level.
-- 2. Read the level index and validate it: it must be non-empty, match
--    @header.levelCount@, and the layer count must equal @textureLayers \@a@.
-- 3. Allocate the destination image and a staging buffer large enough for
--    all levels (sized by their uncompressed byte lengths).
-- 4. Fill the staging buffer level by level, either by reading directly
--    (no supercompression) or by reading into a scratch buffer and
--    decompressing with Zstandard.
-- 5. Flush the staging allocation and copy it to the destination image.
--
-- Throws a string exception when the level index is empty or inconsistent
-- with the header, when a level cannot be read, or when Zstandard
-- decompression fails or yields an unexpected size. Throws
-- 'Texture.ArrayError' when the layer count does not match the texture type.
-- Calls 'error' on an unsupported supercompression scheme.
loadKtx2
  :: forall a m env src
  .  ( TextureLayers a
     , MonadVulkan env m
     , MonadResource m
     , MonadThrow m
     , HasLogFunc env
     , Read.ReadChunk src
     , Read.ReadLevel src
     )
  => Maybe Text
  -> Queues Vk.CommandPool
  -> Read.Context src
  -> m (Texture a)
loadKtx2 label pool ktx2@(Read.Context _src header) = do
  logDebug $ displayShow (label, header)

  kvd <- Read.keyValueData ktx2
  logDebug $ displayShow (label, format, extent, numLayers, KeyValue.textual kvd)

  levels <- Read.levels ktx2

  unless (Vector.length levels > 0) $
    throwString "Ktx2 contains no image levels"

  unless (Vector.length levels == fromIntegral header.levelCount) $
    throwString $ "Ktx2 level count mismatch " <> show (Vector.length levels, header.levelCount)

  unless (numLayers == textureLayers @a) $
    throwM $ Texture.ArrayError (textureLayers @a) numLayers

  let
    -- Size of each level once it sits in the staging buffer.
    levelSizes = fmap (fromIntegral . (.uncompressedByteLength)) levels
    totalSize = Vector.sum levelSizes
    -- Start offset of each level: running sum of the preceding sizes.
    -- scanl' yields one extra trailing element (the total), hence init.
    offsets = Vector.init $ Vector.scanl' (+) 0 levelSizes

  dst <- Image.allocateDst
    pool
    label
    extent
    (fromIntegral $ Vector.length levels)
    numLayers
    format

  vma <- asks getAllocator
  VMA.withBuffer vma (Texture.stageBufferCI totalSize) Texture.stageAllocationCI bracket \(staging, stage, stageInfo) -> do
    liftIO case header.supercompressionScheme of
      Header.SC_NONE ->
        -- Levels are stored raw: read each one straight into its slot.
        Vector.forM_ (Vector.zip offsets levels) \(offset, level) ->
          readLevelTo level $
            Foreign.plusPtr (VMA.mappedData stageInfo) offset

      Header.SC_ZSTANDARD -> do
        let
          -- The scratch buffer holds the *compressed* payload of a level and
          -- the decompressor reads level.byteLength bytes from it, so it must
          -- be sized by the largest stored byteLength. Sizing it by the
          -- uncompressed length under-allocates whenever a level is stored
          -- larger than its decoded size.
          -- Non-emptiness of levels was checked above, so maximum is total.
          scratchSize :: Int
          scratchSize = Vector.maximum $ fmap (fromIntegral . (.byteLength)) levels
        Foreign.allocaBytesAligned scratchSize 16 \src ->
          Vector.forM_ (Vector.zip offsets levels) \(offset, level) -> do
            let expected = fromIntegral level.uncompressedByteLength
            readLevelTo level src

            res <-
              Zstd.checkError $
                Zstd.decompress
                  (Foreign.plusPtr (VMA.mappedData stageInfo) offset)
                  expected
                  src
                  (fromIntegral level.byteLength)
            case res of
              Right size | size == expected ->
                pure ()
              Right unexpected ->
                throwString $
                  "Zstd decompressed unexpected amount of bytes: " <> show (unexpected, expected)
              Left err ->
                throwString err

      huh ->
        error $ "Unexpected supercompression scheme: " ++ show huh

    VMA.flushAllocation vma stage 0 Vk.WHOLE_SIZE

    final <- Image.copyBufferToDst
      pool
      staging
      dst
      offsets

    pure Texture
      { tFormat         = format
      , tMipLevels      = header.levelCount
      , tLayers         = numLayers
      , tAllocatedImage = final
      }
  where
    format = Vk.Format (fromIntegral header.vkFormat)

    extent = Vk.Extent3D
      { width = header.pixelWidth
      , height = header.pixelHeight
      , depth = max 1 header.pixelDepth
      }

    -- XXX: can be flat array or a cubemap
    numLayers = max header.faceCount header.layerCount

    -- Read one level's stored bytes to the given pointer, failing loudly
    -- instead of silently uploading partial data.
    -- NOTE(review): treats a False result from Read.levelToPtr as a failed
    -- or short read -- confirm against the ktx-codec documentation.
    readLevelTo level ptr = do
      ok <- Read.levelToPtr ktx2 level ptr
      unless ok $
        throwString "Ktx2 level data could not be read"