{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
module Data.ProtoLens.Compiler.Definitions
( Env
, Definition(..)
, MessageInfo(..)
, ServiceInfo(..)
, MethodInfo(..)
, FieldInfo(..)
, OneofInfo(..)
, OneofCase(..)
, FieldName(..)
, Symbol
, nameFromSymbol
, promoteSymbol
, EnumInfo(..)
, EnumValueInfo(..)
, qualifyEnv
, unqualifyEnv
, collectDefinitions
, collectServices
, definedFieldType
, definedType
, camelCase
) where
import Data.Char (isUpper, toUpper)
import Data.Int (Int32)
import Data.List (mapAccumL)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
#if !MIN_VERSION_base(4,11,0)
import Data.Monoid ((<>))
#endif
import qualified Data.Semigroup as Semigroup
import qualified Data.Set as Set
import Data.String (IsString(..))
import Data.Text (Text, cons, splitOn, toLower, uncons, unpack)
import qualified Data.Text as T
import Lens.Family2 ((^.), (^..), toListOf)
import Proto.Google.Protobuf.Descriptor
( DescriptorProto
, EnumDescriptorProto
, EnumValueDescriptorProto
, FieldDescriptorProto
, FileDescriptorProto
, MethodDescriptorProto
, ServiceDescriptorProto
)
import Proto.Google.Protobuf.Descriptor_Fields
( clientStreaming
, enumType
, field
, inputType
, maybe'oneofIndex
, messageType
, method
, name
, nestedType
, number
, oneofDecl
, outputType
, package
, serverStreaming
, service
, typeName
, value
)
import Data.ProtoLens.Compiler.Combinators
( Name
, QName
, ModuleName
, Type
, qual
, tyPromotedString
, unQual
)
-- | Message and enum definitions keyed by their fully-qualified protobuf
-- name (for example @.pkg.Outer.Inner@), with Haskell names of type @n@.
type Env n = Map.Map Text (Definition n)

-- | A definition collected from a @.proto@ file: either a message or an
-- enum.
data Definition n = Message (MessageInfo n) | Enum (EnumInfo n)
    deriving Functor
-- | Information about a protobuf message that is needed to generate its
-- Haskell type.
data MessageInfo n = MessageInfo
    { messageName :: n
      -- ^ Haskell name of the message type.
    , messageDescriptor :: DescriptorProto
      -- ^ The original message descriptor.
    , messageFields :: [FieldInfo]
      -- ^ Fields that do not belong to any @oneof@.
    , messageOneofFields :: [OneofInfo]
      -- ^ The message's @oneof@ groups.
    , messageUnknownFields :: Name
      -- ^ Generated name of the form @_Foo'_unknownFields@; presumably the
      -- record field that holds unrecognized fields (confirm in codegen).
    } deriving Functor
-- | Information about a protobuf service.
data ServiceInfo = ServiceInfo
    { serviceName :: Text
      -- ^ The service name taken from its descriptor.
    , servicePackage :: Text
      -- ^ The @package@ of the file that declares the service.
    , serviceMethods :: [MethodInfo]
      -- ^ The service's RPC methods.
    }

-- | Information about a single RPC method of a service.
data MethodInfo = MethodInfo
    { methodName :: Text
      -- ^ The method name taken from its descriptor.
    , methodIdent :: Text
      -- ^ The camel-cased form of 'methodName'.
    , methodInput :: Text
      -- ^ The input type name, copied from the descriptor.
    , methodOutput :: Text
      -- ^ The output type name, copied from the descriptor.
    , methodClientStreaming :: Bool
    , methodServerStreaming :: Bool
    }
-- | A message field together with the Haskell names generated for it.
--
-- Note: values are built positionally (@FieldInfo f names@), so the field
-- order below matters.
data FieldInfo = FieldInfo
    { fieldDescriptor :: FieldDescriptorProto
    , plainFieldName :: FieldName
    }
-- | A @oneof@ group within a message.
data OneofInfo = OneofInfo
    { oneofFieldName :: FieldName
      -- ^ Names generated from the @oneof@'s own name.
    , oneofTypeName :: Name
      -- ^ Haskell type name generated for the @oneof@, prefixed with the
      -- enclosing message's name.
    , oneofCases :: [OneofCase]
      -- ^ One entry for each field in this @oneof@, in field order.
    }
-- | A single field (alternative) of a @oneof@ group.
data OneofCase = OneofCase
    { caseField :: FieldInfo
      -- ^ The field for this alternative.
    , caseConstructorName :: Name
      -- ^ Haskell constructor name for this alternative.
    , casePrismName :: Name
      -- ^ The constructor name with a leading underscore.
    }
-- | The names generated for a single field.
data FieldName = FieldName
    { overloadedName :: Symbol
      -- ^ The camel-cased field name without a prefix. A prime is appended
      -- if the name is a Haskell keyword.
    , haskellRecordFieldName :: Name
      -- ^ The record field name: an underscore, then the message prefix,
      -- then the overloaded name.
    }
-- | A name that can be turned into a type-level string (see
-- 'promoteSymbol') or into an ordinary 'Name' (see 'nameFromSymbol').
newtype Symbol = Symbol String
    deriving (Eq, Ord, IsString, Semigroup.Semigroup, Monoid)
-- | Use a 'Symbol' as an ordinary Haskell 'Name'.
nameFromSymbol :: Symbol -> Name
nameFromSymbol = \(Symbol str) -> fromString str

-- | Lift a 'Symbol' to a promoted (type-level) string literal.
promoteSymbol :: Symbol -> Type
promoteSymbol sym = case sym of
    Symbol str -> tyPromotedString str
-- | Information about a protobuf enum that is needed to generate its
-- Haskell type.
data EnumInfo n = EnumInfo
    { enumName :: n
      -- ^ Haskell name of the enum type.
    , enumUnrecognizedName :: n
      -- ^ Generated name of the form @Foo'Unrecognized@.
    , enumUnrecognizedValueName :: n
      -- ^ Generated name of the form @Foo'UnrecognizedValue@.
    , enumDescriptor :: EnumDescriptorProto
      -- ^ The original enum descriptor.
    , enumValues :: [EnumValueInfo n]
      -- ^ The enum's values, in descriptor order.
    } deriving Functor
-- | Information about a single value of a protobuf enum.
data EnumValueInfo n = EnumValueInfo
    { enumValueName :: n
      -- ^ Haskell name of the value.
    , enumValueDescriptor :: EnumValueDescriptorProto
      -- ^ The original value descriptor.
    , enumAliasOf :: Maybe Name
      -- ^ If an earlier value in the same enum has the same number, the
      -- Haskell name of the first such value; otherwise 'Nothing'.
      -- This field is always a plain 'Name', so 'fmap' does not change it.
    } deriving Functor
-- | Apply a function to every Haskell name stored in an environment.
mapEnv :: (n -> n') -> Env n -> Env n'
mapEnv f env = Map.map (fmap f) env

-- | Qualify every name in the environment with the given module.
qualifyEnv :: ModuleName -> Env Name -> Env QName
qualifyEnv modName = mapEnv (\n -> qual modName n)

-- | Treat every name in the environment as unqualified.
unqualifyEnv :: Env Name -> Env QName
unqualifyEnv env = mapEnv unQual env
-- | Look up the definition of a field's message or enum type by the
-- field's @type_name@.
-- Calls 'error' if that type is not in the environment.
definedFieldType :: FieldDescriptorProto -> Env QName -> Definition QName
definedFieldType fd env =
    case Map.lookup tyName env of
        Just def -> def
        Nothing -> error $ "definedFieldType: Field type " ++ unpack tyName
                        ++ " not found in environment."
  where
    tyName = fd ^. typeName
-- | Look up a definition by its fully-qualified protobuf name.
-- Calls 'error' if the name is not in the environment.
definedType :: Text -> Env QName -> Definition QName
definedType ty env = Map.findWithDefault missing ty env
  where
    -- Evaluated only when the lookup fails.
    missing = error $ "definedType: Type " ++ unpack ty
                ++ " not found in environment."
-- | Collect every message and enum defined in a file, including nested
-- ones, keyed by fully-qualified protobuf name.
collectDefinitions :: FileDescriptorProto -> Env Name
collectDefinitions fd =
    Map.fromList $ messageAndEnumDefs protoPrefix ""
                        (fd ^. messageType) (fd ^. enumType)
  where
    pkg = fd ^. package
    -- Qualified protobuf names start with "."; when the file declares a
    -- package, the package name and another "." follow it.
    protoPrefix
        | T.null pkg = "."
        | otherwise = "." <> pkg <> "."
-- | Collect information about every service declared in a file.
collectServices :: FileDescriptorProto -> [ServiceInfo]
collectServices fd = fmap (toServiceInfo $ fd ^. package) $ fd ^. service
  where
    toServiceInfo :: Text -> ServiceDescriptorProto -> ServiceInfo
    toServiceInfo pkg sd =
        ServiceInfo
            { serviceName = sd ^. name
            , servicePackage = pkg
            , serviceMethods = fmap toMethodInfo $ sd ^. method
            }

    toMethodInfo :: MethodDescriptorProto -> MethodInfo
    toMethodInfo md =
        MethodInfo
            { methodName = md ^. name
            , methodIdent = camelCase $ md ^. name
            -- The descriptor's type names are already 'Text', so they are
            -- stored directly instead of being converted to 'String' and
            -- back.
            , methodInput = md ^. inputType
            , methodOutput = md ^. outputType
            , methodClientStreaming = md ^. clientStreaming
            , methodServerStreaming = md ^. serverStreaming
            }
-- | Collect definitions for a list of messages and a list of enums that
-- share the same protobuf scope and Haskell prefix. Messages contribute
-- their nested definitions as well.
messageAndEnumDefs :: Text -> String -> [DescriptorProto]
                   -> [EnumDescriptorProto] -> [(Text, Definition Name)]
messageAndEnumDefs protoPrefix hsPrefix messages enums =
    concat [messageDefs protoPrefix hsPrefix m | m <- messages]
        ++ [enumDef protoPrefix hsPrefix e | e <- enums]
-- | Collect the definition of one message, followed by the definitions of
-- all messages and enums nested inside it.
--
-- @protoPrefix@ is the enclosing fully-qualified protobuf scope, ending in
-- @.@. @hsPrefix@ is the Haskell name prefix for that scope.
messageDefs :: Text -> String -> DescriptorProto
            -> [(Text, Definition Name)]
messageDefs protoPrefix hsPrefix d = (protoName, thisDef) : nestedDefs
  where
    msgName = d ^. name
    protoName = protoPrefix <> msgName
    hsMsgName = hsPrefix ++ hsName msgName
    -- Names nested in this message get the message name and a prime as a
    -- prefix, e.g. @Outer'Inner@.
    hsPrefix' = hsMsgName ++ "'"
    nestedDefs = messageAndEnumDefs (protoName <> ".") hsPrefix'
                    (d ^. nestedType) (d ^. enumType)
    fieldsByOneof = groupFieldsByOneofIndex (d ^. field)
    -- Fields without a oneof index are ordinary message fields.
    plainFields = Map.findWithDefault [] Nothing fieldsByOneof
    thisDef = Message MessageInfo
        { messageName = fromString hsMsgName
        , messageDescriptor = d
        , messageFields = map (fieldInfo hsPrefix') plainFields
        , messageOneofFields = collectOneofFields hsPrefix' d fieldsByOneof
        , messageUnknownFields =
            fromString $ "_" ++ hsPrefix' ++ "_unknownFields"
        }
-- | Pair a field descriptor with the Haskell names generated for it, using
-- the given record-field prefix.
fieldInfo :: String -> FieldDescriptorProto -> FieldInfo
fieldInfo hsPrefix f =
    FieldInfo
        { fieldDescriptor = f
        , plainFieldName = mkFieldName hsPrefix (f ^. name)
        }
-- | Collect information about each @oneof@ group in a message.
--
-- Groups are numbered from 0 in descriptor order. The fields of group @i@
-- are read from @allFields@ under the key @Just i@ (see
-- 'groupFieldsByOneofIndex').
collectOneofFields
    :: String -> DescriptorProto -> Map.Map (Maybe Int32) [FieldDescriptorProto]
    -> [OneofInfo]
collectOneofFields hsPrefix d allFields
    = zipWith oneofInfo [0..] $ d ^.. oneofDecl . traverse . name
  where
    oneofInfo idx n = OneofInfo
        { oneofFieldName = mkFieldName hsPrefix n
        , oneofTypeName = fromString $ hsPrefix ++ hsNameUnique subdefTypes n
        , oneofCases = map oneofCase
                            $ Map.findWithDefault [] (Just idx) allFields
        }

    oneofCase f =
        let consName = hsPrefix ++ hsNameUnique subdefCons (f ^. name)
        in OneofCase
            { caseField = fieldInfo hsPrefix f
            , caseConstructorName = fromString consName
            , casePrismName = fromString $ "_" ++ consName
            }

    -- Camel-case and capitalize a name. A prime is appended if the result
    -- matches one of the given unprefixed names. Those names get the same
    -- hsPrefix when generated, so without the prime they would clash.
    -- Set.member is O(log n); Foldable elem on a Set is a linear scan.
    hsNameUnique :: Set.Set String -> Text -> String
    hsNameUnique ns n
        | n' `Set.member` ns = n' ++ "'"
        | otherwise = n'
      where
        n' = hsName $ camelCase n

    -- Names of nested messages and enums; the oneof's type name must not
    -- collide with these.
    subdefTypes = Set.fromList $ map hsName
        $ toListOf (nestedType . traverse . name) d
        ++ toListOf (enumType . traverse . name) d

    -- Names of nested messages and nested enum values; the oneof's
    -- constructor names must not collide with these.
    subdefCons = Set.fromList $ map hsName
        $ toListOf (nestedType . traverse . name) d
        ++ toListOf (enumType . traverse . value . traverse . name) d
-- | Group a message's fields by their @oneof@ index. Each group keeps the
-- original field order. Fields without a oneof index go under 'Nothing'.
groupFieldsByOneofIndex
    :: [FieldDescriptorProto] -> Map.Map (Maybe Int32) [FieldDescriptorProto]
groupFieldsByOneofIndex = foldr insertField Map.empty
  where
    -- The fold runs from the right and each field is prepended to its
    -- group, so every group ends up in the original order.
    insertField f = Map.insertWith (++) (f ^. maybe'oneofIndex) [f]
-- | Turn a protobuf name into a Haskell type or constructor name by
-- upper-casing its first character.
hsName :: Text -> String
hsName t = unpack (capitalize t)
-- | Build the names for a field. The overloaded name is the bare
-- camel-cased name. The record field name is an underscore, then the
-- message prefix, then that same bare name.
mkFieldName :: String -> Text -> FieldName
mkFieldName hsPrefix n =
    let base = fieldName n
    in FieldName (fromString base) (fromString ("_" ++ hsPrefix ++ base))
-- | Camel-case a protobuf field name. A prime is appended if the result is
-- a reserved Haskell keyword.
fieldName :: Text -> String
fieldName n = unpack $ if isReserved then camel <> "'" else camel
  where
    camel = camelCase n
    isReserved = camel `Set.member` reservedKeywords
-- | Convert a snake_case name to camelCase.
--
-- Leading underscores are kept. In the first segment, the leading run of
-- upper-case characters is lower-cased. Every later segment is
-- capitalized, so @"foo_bar_baz"@ becomes @"fooBarBaz"@.
-- Calls 'error' if nothing remains after the leading underscores (this
-- includes the empty string).
camelCase :: Text -> Text
camelCase s
    | [""] <- pieces =
        error $ "camelCase: name consists only of underscores: "
            ++ show s
    | (first : others) <- pieces =
        T.concat $ leading : lowerInitialChars first : map capitalize others
    | otherwise =
        error $ "camelCase: splitOn returned empty list: "
            ++ show rest
  where
    (leading, rest) = T.span (== '_') s
    pieces = splitOn "_" rest
-- | Lower-case the run of upper-case characters at the start of a string.
-- For example, @"FooBar"@ becomes @"fooBar"@ and @"HTTPServer"@ becomes
-- @"httpserver"@.
lowerInitialChars :: Text -> Text
lowerInitialChars s =
    let upperRun = T.takeWhile isUpper s
        remainder = T.dropWhile isUpper s
    in toLower upperRun <> remainder
-- | Words that cannot be used as Haskell variable names. 'fieldName'
-- appends a prime to any generated name that appears in this set.
reservedKeywords :: Set.Set Text
reservedKeywords = Set.fromList (haskell2010 ++ extensions)
  where
    -- Haskell 2010 reserved words. Special identifiers that are still
    -- valid variable names (e.g. "as", "hiding", "qualified") are left
    -- out on purpose.
    haskell2010 :: [Text]
    haskell2010 =
        [ "case", "class", "data", "default", "deriving", "do", "else"
        , "foreign", "if", "import", "in", "infix", "infixl", "infixr"
        , "instance", "let", "module", "newtype", "of", "then", "type"
        , "where"
        ]
    -- Words reserved by common GHC extensions.
    extensions :: [Text]
    extensions =
        [ "mdo"      -- RecursiveDo
        , "rec"      -- Arrows, RecursiveDo
        , "pattern"  -- PatternSynonyms
        , "proc"     -- Arrows
        ]
-- | Collect the definition of one enum. @protoPrefix@ is the enclosing
-- fully-qualified protobuf scope; @hsPrefix@ is the Haskell name prefix
-- for that scope.
enumDef :: Text -> String -> EnumDescriptorProto
        -> (Text, Definition Name)
enumDef protoPrefix hsPrefix d = (protoPrefix <> enumNm, Enum info)
  where
    enumNm = d ^. name
    -- The enum's own name and its value names get the Haskell prefix but
    -- are not capitalized.
    prefixed suffix = fromString (hsPrefix ++ unpack suffix)
    info = EnumInfo
        { enumName = prefixed enumNm
        , enumUnrecognizedName = prefixed (enumNm <> "'Unrecognized")
        , enumUnrecognizedValueName = prefixed (enumNm <> "'UnrecognizedValue")
        , enumDescriptor = d
        , enumValues = collectEnumValues prefixed (d ^. value)
        }
-- | Collect an enum's values in descriptor order. If a value's number was
-- already used by an earlier value, it is recorded as an alias of the
-- first such value (see 'enumAliasOf').
collectEnumValues :: (Text -> Name) -> [EnumValueDescriptorProto]
                  -> [EnumValueInfo Name]
collectEnumValues mkHsName values = snd (mapAccumL step Map.empty values)
  where
    -- The accumulator maps each number seen so far to the name of the
    -- first value that used it.
    step :: Map.Map Int32 Name -> EnumValueDescriptorProto
         -> (Map.Map Int32 Name, EnumValueInfo Name)
    step seen v =
        case Map.lookup num seen of
            Just canonical -> (seen, EnumValueInfo hsNm v (Just canonical))
            Nothing -> (Map.insert num hsNm seen, EnumValueInfo hsNm v Nothing)
      where
        hsNm = mkHsName (v ^. name)
        num = v ^. number
-- | Upper-case the first character of a string. The empty string is
-- returned unchanged.
capitalize :: Text -> Text
capitalize s = case uncons s of
    Nothing -> s
    Just (c, rest) -> cons (toUpper c) rest