{-# OPTIONS_GHC -Wno-orphans #-}

module Data.Function.FastMemo.Char () where

import Data.Bits (complement, countLeadingZeros)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.UTF8 as UTF8
import Data.Function.FastMemo.Class (Memoizable (..))
import Data.Function.FastMemo.Util (memoizeFixedLen)
import Data.Function.FastMemo.Word ()
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NonEmpty
import Data.Word (Word8)

-- | Memoizes functions on 'Char' by translating each character to its UTF-8
-- byte sequence ('charToCodePoint') and delegating to the 'CodePoint'
-- instance.
--
-- We want ASCII Chars to require only a single Vector lookup, so let's encode
-- as UTF-8: an ASCII character encodes to exactly one byte.
instance Memoizable Char where
  -- The inner 'memoize' is applied to @f@ alone (not under a lambda over the
  -- 'Char'), so the memo structure it builds is shared by every call of the
  -- resulting function.
  memoize f = memoize (f . codePointToChar) . charToCodePoint

-- | The bytes of a single encoded character: a leading byte followed by zero
-- or more further bytes. Values built by 'charToCodePoint' hold the UTF-8
-- encoding of one 'Char'; the type itself only guarantees non-emptiness.
newtype CodePoint = CodePoint {getCodePoint :: NonEmpty Word8}

-- | Two-level memoization keyed first on the leading byte and then on the
-- remaining bytes.
--
-- In UTF-8, the first byte of a codepoint tells us how many more bytes that
-- codepoint contains. We can use this fact to reduce lookups: once the leading
-- byte is known, the tail is memoized as a list of exactly 'extraBytes' bytes
-- via 'memoizeFixedLen' rather than as a list of arbitrary length.
instance Memoizable CodePoint where
  memoize f =
    -- Built once per @memoize f@ and captured by the returned lambda, so all
    -- lookups share the same memo structure.
    let f' =
          memoize
            ( \w ->
                -- NOTE(review): 'memoizeFixedLen' carries 'HasCallStack', so it
                -- presumably fails if the tail length differs from
                -- @extraBytes w@ (e.g. a hand-built 'CodePoint' that is not
                -- valid UTF-8) -- confirm against its definition.
                memoizeFixedLen (extraBytes w) (f . CodePoint . (w :|))
            )
     in \(CodePoint (w :| ws)) -> f' w ws

-- | Number of bytes that follow a given leading byte, derived from its count
-- of leading 1 bits.
--
-- * @0xxxxxxx@ (no leading ones) yields 0.
-- * A byte with @n >= 1@ leading ones yields @n - 1@, so @110xxxxx@ gives 1,
--   @1110xxxx@ gives 2 and @11110xxx@ gives 3.
--
-- Bytes that cannot start a UTF-8 sequence are not rejected here: @10xxxxxx@
-- yields 0 and bytes with five or more leading ones yield 4 to 7.
extraBytes :: Word8 -> Int
extraBytes x = case countLeadingOnes x of
  0 -> 0
  n -> n - 1

-- | Number of consecutive 1 bits at the most significant end of a byte,
-- computed as the number of leading zeros of its bitwise complement.
-- Ranges from 0 (for bytes below @0x80@) to 8 (for @0xFF@).
countLeadingOnes :: Word8 -> Int
countLeadingOnes = countLeadingZeros . complement

-- | Encodes a single 'Char' as UTF-8 (via 'UTF8.fromString' on a one-character
-- string) and wraps the resulting bytes as a 'CodePoint'.
charToCodePoint :: Char -> CodePoint
charToCodePoint c =
  case ByteString.unpack (UTF8.fromString [c]) of
    w : ws -> CodePoint (w :| ws)
    -- Encoding one character is expected to produce at least one byte; fail
    -- with a located message rather than an anonymous empty-list error.
    [] ->
      error
        "Data.Function.FastMemo.Char.charToCodePoint: impossible, UTF-8 encoding of a Char was empty"

-- | Decodes the bytes of a 'CodePoint' with 'UTF8.toString' and returns the
-- first decoded character; any further decoded characters are ignored.
--
-- For values produced by 'charToCodePoint' this is intended to recover the
-- original 'Char'.
codePointToChar :: CodePoint -> Char
codePointToChar cp =
  case UTF8.toString (ByteString.pack (NonEmpty.toList (getCodePoint cp))) of
    c : _ -> c
    -- The input is non-empty, so the decoder is expected to yield at least one
    -- character; fail with a located message instead of a bare 'head' error.
    [] ->
      error
        "Data.Function.FastMemo.Char.codePointToChar: impossible, decoding a non-empty byte sequence yielded no characters"