-- | Capnproto message canonicalization, per:
--
-- https://capnproto.org/encoding.html#canonicalization
{-# LANGUAGE BangPatterns     #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies     #-}
module Capnp.Canonicalize
    ( canonicalize
    ) where

-- Note [Allocation strategy]
--
-- The implementation makes use of knowledge of how we allocate values inside
-- a message; in particular, we assume objects are allocated sequentially,
-- and that if the first segment is big enough we will never allocate a second
-- segment.
--
-- If we ever make the allocator pluggable, we will have to revisit this and
-- ensure that our assumptions still hold.

-- Note [Other assumptions]
--
-- This code relies on the fact that Capnp.Pointer.serializePointer does the
-- canonicalization of zero-sized struct pointers for us; see the comments there
-- for more details.

import Data.Word

import Data.Foldable    (for_)
import Data.Maybe       (isNothing)
import Data.Traversable (for)

import           Capnp.Bits    (WordCount)
import qualified Capnp.Message as M
import qualified Capnp.Untyped as U

-- | Produce a canonicalized message whose root is a copy of the given struct.
--
-- The result is a @(message, segment)@ pair; the segment is segment 0 of the
-- returned message, which for a canonical message is its only segment.
--
-- Besides the usual ways reading a message can fail (traversal limit,
-- malformed input), the canonical form requires a single-segment message, so
-- an input whose canonical copy does not fit in one segment cannot be
-- canonicalized.
--
-- NOTE(review): no explicit single-segment check is visible in this function;
-- it appears to rely on the allocator behaviour described in
-- Note [Allocation strategy]. Confirm that overflow is reported as a failure.
canonicalize
    :: (U.RWCtx m s, M.Message m msgIn)
    => U.Struct msgIn
    -> m (M.MutMsg s, M.Segment (M.MutMsg s))
canonicalize rootStructIn = do
    -- Note [Allocation strategy]: use the total size of the input message as
    -- the size hint for the output message.
    sizeHint <- totalWords (U.message rootStructIn)
    msgOut <- M.newMessage (Just sizeHint)
    -- Clone the root (and, recursively, everything it points to), then make
    -- the clone the root of the output message.
    cloneCanonicalStruct rootStructIn msgOut >>= U.setRoot
    (,) msgOut <$> M.getSegment msgOut 0

-- | Total number of words in the message, summed over all of its segments.
totalWords :: U.ReadCtx m msg => msg -> m WordCount
totalWords msg = do
    -- Note [Allocation strategy]
    segCount <- M.numSegs msg
    sum <$> for [0 .. segCount - 1] (\segIndex -> M.getSegment msg segIndex >>= M.numWords)

-- | Allocate a struct in @msgOut@ sized to the canonical section counts of
-- @structIn@, fill it with a canonical copy of @structIn@'s contents, and
-- return the newly allocated struct.
cloneCanonicalStruct
    :: (U.RWCtx m s, M.Message m msgIn)
    => U.Struct msgIn
    -> M.MutMsg s
    -> m (U.Struct (M.MutMsg s))
cloneCanonicalStruct structIn msgOut = do
    (dataWords, ptrSlots) <- findCanonicalSectionCounts structIn
    -- The struct is allocated before its contents are copied, so anything it
    -- points to is allocated after it.
    structOut <- U.allocStruct msgOut (fromIntegral dataWords) (fromIntegral ptrSlots)
    structOut <$ copyCanonicalStruct structIn structOut

-- | Fill an already-allocated output struct from an input struct.
--
-- The loops are driven by the section sizes of @structOut@, not of
-- @structIn@: every data word and pointer slot of the output is written from
-- the same index of the input. Pointers are cloned canonically into the
-- output struct's message.
--
-- NOTE(review): when @structOut@ is larger than @structIn@ (possible for
-- struct-list elements), this presumably relies on 'U.getData' / 'U.getPtr'
-- yielding default values for out-of-range indices. Confirm against
-- "Capnp.Untyped".
copyCanonicalStruct
    :: (U.RWCtx m s, M.Message m msgIn)
    => U.Struct msgIn
    -> U.Struct (M.MutMsg s)
    -> m ()
copyCanonicalStruct structIn structOut = do
    for_ [0 .. dataWords - 1] copyWord
    for_ [0 .. ptrSlots - 1] copyPtr
  where
    msgOut = U.message structOut
    dataWords = fromIntegral (U.structWordCount structOut)
    ptrSlots = fromIntegral (U.structPtrCount structOut)

    -- Copy the data word at index i verbatim.
    copyWord i = U.getData i structIn >>= \word -> U.setData word i structOut

    -- Clone whatever the pointer at index i refers to, then store the clone.
    copyPtr i = do
        ptrOut <- U.getPtr i structIn >>= (`cloneCanonicalPtr` msgOut)
        U.setPtr ptrOut i structOut

-- | Compute the @(data words, pointers)@ section sizes a struct has once
-- trailing zero data words and trailing null pointers are dropped.
findCanonicalSectionCounts :: U.ReadCtx m msg => U.Struct msg -> m (Word16, Word16)
findCanonicalSectionCounts struct =
    (,) <$> dataWords <*> ptrSlots
  where
    -- A data word is "default" when it is zero.
    dataWords =
        canonicalSectionCount (== 0) (\i -> U.getData i struct) (fromIntegral (U.structWordCount struct))
    -- A pointer is "default" when it is null ('Nothing').
    ptrSlots =
        canonicalSectionCount isNothing (\i -> U.getPtr i struct) (fromIntegral (U.structPtrCount struct))

-- | Length of a section after trimming trailing default-valued entries.
--
-- Arguments, in order: a predicate recognising a default value, an action
-- fetching the entry at a given index, and the section's current length.
-- Entries are inspected from the last index downwards; the scan stops at the
-- first non-default entry, so earlier entries are never fetched.
canonicalSectionCount :: Monad m => (a -> Bool) -> (Int -> m a) -> Int -> m Word16
canonicalSectionCount isDefault getIndex = go
  where
    go 0 = pure 0
    go remaining = do
        lastValue <- getIndex (remaining - 1)
        if isDefault lastValue
            then go (remaining - 1)
            else pure (fromIntegral remaining)

-- | Clone the target of a (possibly null) pointer into @msgOut@.
--
-- A null pointer stays null. A capability is re-attached to @msgOut@ via its
-- client; structs and lists are cloned canonically.
cloneCanonicalPtr
    :: (U.RWCtx m s, M.Message m msgIn)
    => Maybe (U.Ptr msgIn)
    -> M.MutMsg s
    -> m (Maybe (U.Ptr (M.MutMsg s)))
cloneCanonicalPtr Nothing _ = pure Nothing
cloneCanonicalPtr (Just ptr) msgOut = Just <$> case ptr of
    U.PtrCap cap       -> U.PtrCap <$> (U.getClient cap >>= U.appendCap msgOut)
    U.PtrStruct struct -> U.PtrStruct <$> cloneCanonicalStruct struct msgOut
    U.PtrList list     -> U.PtrList <$> cloneCanonicalList list msgOut

-- | Clone a list into @msgOut@, dispatching on its element representation.
--
-- Lists of zero-sized elements only need allocating; data lists are copied
-- element by element; pointer lists and struct lists have their elements
-- cloned canonically.
cloneCanonicalList
    :: (U.RWCtx m s, M.Message m msgIn)
    => U.List msgIn
    -> M.MutMsg s
    -> m (U.List (M.MutMsg s))
cloneCanonicalList (U.List0 l) msgOut =
    U.List0 <$> U.allocList0 msgOut (U.length l)
cloneCanonicalList (U.List1 l) msgOut =
    U.List1 <$> (copyCanonicalDataList l =<< U.allocList1 msgOut (U.length l))
cloneCanonicalList (U.List8 l) msgOut =
    U.List8 <$> (copyCanonicalDataList l =<< U.allocList8 msgOut (U.length l))
cloneCanonicalList (U.List16 l) msgOut =
    U.List16 <$> (copyCanonicalDataList l =<< U.allocList16 msgOut (U.length l))
cloneCanonicalList (U.List32 l) msgOut =
    U.List32 <$> (copyCanonicalDataList l =<< U.allocList32 msgOut (U.length l))
cloneCanonicalList (U.List64 l) msgOut =
    U.List64 <$> (copyCanonicalDataList l =<< U.allocList64 msgOut (U.length l))
cloneCanonicalList (U.ListPtr l) msgOut =
    U.ListPtr <$> (copyCanonicalPtrList l =<< U.allocListPtr msgOut (U.length l))
cloneCanonicalList (U.ListStruct l) msgOut =
    U.ListStruct <$> cloneCanonicalStructList l msgOut

-- | Copy every element of a data list into an already-allocated output list,
-- index by index, and return the output list.
copyCanonicalDataList
    :: (U.RWCtx m s, M.Message m msgIn)
    => U.ListOf msgIn a
    -> U.ListOf (M.MutMsg s) a
    -> m (U.ListOf (M.MutMsg s) a)
copyCanonicalDataList listIn listOut =
    listOut <$ for_ [0 .. U.length listIn - 1] copyElem
  where
    copyElem i = U.index i listIn >>= \value -> U.setIndex value i listOut

-- | Fill an already-allocated pointer list: each input pointer's target is
-- cloned canonically into the output list's message and the resulting
-- pointer stored at the same index. Returns the output list.
copyCanonicalPtrList
    :: (U.RWCtx m s, M.Message m msgIn)
    => U.ListOf msgIn (Maybe (U.Ptr msgIn))
    -> U.ListOf (M.MutMsg s) (Maybe (U.Ptr (M.MutMsg s)))
    -> m (U.ListOf (M.MutMsg s) (Maybe (U.Ptr (M.MutMsg s))))
copyCanonicalPtrList listIn listOut =
    listOut <$ for_ [0 .. U.length listIn - 1] copyElem
  where
    msgOut = U.message listOut
    copyElem i = do
        ptrOut <- U.index i listIn >>= (`cloneCanonicalPtr` msgOut)
        U.setIndex ptrOut i listOut

-- | Clone a struct list into @msgOut@.
--
-- The output list is allocated with the canonical section sizes computed
-- for the input list as a whole and with the same length as the input; each
-- element is then copied into place.
cloneCanonicalStructList
    :: (U.RWCtx m s, M.Message m msgIn)
    => U.ListOf msgIn (U.Struct msgIn)
    -> M.MutMsg s
    -> m (U.ListOf (M.MutMsg s) (U.Struct (M.MutMsg s)))
cloneCanonicalStructList listIn msgOut = do
    (dataWords, ptrSlots) <- findCanonicalListSectionCounts listIn
    listOut <- U.allocCompositeList msgOut dataWords ptrSlots (U.length listIn)
    listOut <$ copyCanonicalStructList listIn listOut

-- | Copy each struct of the input list into the struct at the same index of
-- the already-allocated output list.
copyCanonicalStructList
    :: (U.RWCtx m s, M.Message m msgIn)
    => U.ListOf msgIn (U.Struct msgIn)
    -> U.ListOf (M.MutMsg s) (U.Struct (M.MutMsg s))
    -> m ()
copyCanonicalStructList listIn listOut =
    for_ [0 .. U.length listIn - 1] copyElem
  where
    copyElem i = do
        src <- U.index i listIn
        dst <- U.index i listOut
        copyCanonicalStruct src dst

-- | Canonical @(data words, pointers)@ section sizes for a struct list: the
-- element-wise maximum of the canonical section sizes of every struct in the
-- list, or @(0, 0)@ for an empty list.
findCanonicalListSectionCounts :: U.ReadCtx m msg => U.ListOf msg (U.Struct msg) -> m (Word16, Word16)
findCanonicalListSectionCounts list = scan 0 0 0
  where
    len = U.length list

    -- Walk the list front to back, keeping strict running maxima.
    scan idx !maxWords !maxPtrs
        | idx < len = do
            (elemWords, elemPtrs) <- U.index idx list >>= findCanonicalSectionCounts
            scan (idx + 1) (max maxWords elemWords) (max maxPtrs elemPtrs)
        | otherwise =
            pure (maxWords, maxPtrs)