{-# LANGUAGE PatternGuards #-}
module Database.PostgreSQL.Simple.Arrays where
import           Control.Applicative (Applicative(..), Alternative(..), (<$>))
import           Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import           Data.Monoid
import           Data.Attoparsec.ByteString.Char8
-- | Parse one array-literal value. Three forms are tried in order:
-- a brace-enclosed nested array, an unquoted element, and a double-quoted
-- element. @delim@ is the character that separates elements.
arrayFormat :: Char -> Parser ArrayFormat
arrayFormat delim = nested <|> bare <|> inQuotes
  where
    nested   = fmap Array  (array delim)
    bare     = fmap Plain  (plain delim)
    inQuotes = fmap Quoted quoted
-- | Intermediate representation of a PostgreSQL array literal, as produced
-- by 'arrayFormat' and consumed by 'fmt'.
data ArrayFormat
  = Array [ArrayFormat]  -- ^ a brace-enclosed (possibly nested) array
  | Plain ByteString     -- ^ an element written without double quotes
  | Quoted ByteString    -- ^ the unescaped contents of a double-quoted element
  deriving (Eq, Show, Ord)
-- | Parse a brace-enclosed array literal and return its elements.
--
-- The braces may enclose one of three things:
--
-- * nothing, which gives the empty list;
-- * nested arrays separated by @','@;
-- * quoted or unquoted elements separated by @delim@.
--
-- NOTE(review): nested arrays are separated by @','@ whatever @delim@ is.
-- Confirm this against PostgreSQL's array output for element types whose
-- delimiter is not a comma.
array :: Char -> Parser [ArrayFormat]
array delim = do
    _ <- char '{'
    items <- option [] (subArrays <|> elements)
    _ <- char '}'
    return items
  where
    subArrays = (Array <$> array delim) `sepBy1` char ','
    elements  = element `sepBy1` char delim
    element   = (Quoted <$> quoted) <|> (Plain <$> plain delim)
    
-- | Parse a double-quoted array element. The result is its contents, with
-- the surrounding quotes removed and backslash escapes resolved.
--
-- A backslash escapes whichever character follows it, as in PostgreSQL's
-- array input syntax. 'esc' only ever escapes double quotes and
-- backslashes. Accepting any escaped character still means a literal that
-- escapes some other character parses to that character, where it
-- previously failed to parse.
quoted :: Parser ByteString
quoted = char '"' *> contents <* char '"'
  where
    -- 'many' succeeds after zero matches, so an empty element yields
    -- 'B.empty' without needing an explicit 'option'.
    contents  = B.concat <$> many (unescaped <|> escaped)
    -- a maximal run of characters that need no special treatment
    unescaped = takeWhile1 (notInClass "\"\\")
    -- a backslash followed by any character stands for that character
    escaped   = B.singleton <$> (char '\\' *> anyChar)
-- | Parse an unquoted array element. It is one or more characters, none of
-- which is the delimiter, a double quote, or a brace.
plain :: Char -> Parser ByteString
plain delim = takeWhile1 (not . special)
  where
    special = inClass (delim : "\"{}")
-- | Render an 'ArrayFormat' as a top-level value.
--
-- A top-level 'Quoted' value is emitted as-is, without surrounding quotes
-- or escaping. Elements nested inside an 'Array' are rendered with quoting
-- enabled.
fmt :: Char -> ArrayFormat -> ByteString
fmt delim value = fmt' False delim value
-- | Render a list of array elements. Each element is rendered with quoting
-- enabled.
--
-- Consecutive elements are separated by @c@, except that the separator
-- after a nested 'Array' is always @','@. An empty list renders as the
-- empty string.
--
-- The pieces are collected in a list and joined with a single 'B.concat'.
-- Appending step by step would re-copy the whole rendered tail for every
-- element, which is quadratic in the element count.
delimit :: Char -> [ArrayFormat] -> ByteString
delimit c = B.concat . pieces
  where
    -- each element's rendering, followed by its separator (except the last)
    pieces []       = []
    pieces [x]      = [fmt' True c x]
    pieces (x:rest) = fmt' True c x : B.singleton (separatorAfter x) : pieces rest

    separatorAfter (Array _) = ','
    separatorAfter _         = c
-- | Render an 'ArrayFormat'. The 'Bool' flag says what to do with a
-- 'Quoted' value:
--
-- * 'True': wrap it in double quotes and escape its contents with 'esc'.
-- * 'False': emit its contents unchanged.
--
-- Arrays are wrapped in braces around the output of 'delimit'. 'Plain' and
-- unquoted 'Quoted' payloads are returned via 'B.copy', so they get their
-- own storage.
fmt' :: Bool -> Char -> ArrayFormat -> ByteString
fmt' _     delim (Array items) = B.cons '{' (B.snoc (delimit delim items) '}')
fmt' _     _     (Plain bytes) = B.copy bytes
fmt' True  _     (Quoted q)    = B.cons '"' (B.snoc (esc q) '"')
fmt' False _     (Quoted q)    = B.copy q
    
-- | Escape a value for use inside a double-quoted array element. Each
-- double quote and each backslash is preceded by a backslash. Every other
-- character is copied unchanged.
esc :: ByteString -> ByteString
esc = B.concatMap escapeChar
  where
    escapeChar ch
      | ch == '"' || ch == '\\' = B.pack ['\\', ch]
      | otherwise               = B.singleton ch