{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
module Data.Avro.Deriving.NormSchema
where
import Control.Monad.State.Strict
import Data.Avro.Schema
import qualified Data.List as L
import Data.List.NonEmpty (NonEmpty ((:|)))
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe)
import Data.Semigroup ((<>))
import qualified Data.Set as S
import Data.Text (Text)
import qualified Data.Text as T
-- | Produce one schema per named definition found in the given schema.
--
-- The definitions are collected with 'getTypes'. Each one is then passed
-- through 'normSchema', and every run starts from the same initial state:
-- a map from each collected type name to its definition. The state left
-- over after a run is discarded, so runs do not influence one another.
extractDerivables :: Schema -> [Schema]
extractDerivables schema = map normalise definitions
  where
    -- (name, definition) pairs in the order 'getTypes' reports them.
    definitions = getTypes schema

    -- Lookup table used as the starting state of every normalisation run.
    -- 'M.fromList' keeps the last definition when a name occurs twice.
    lookupTable = M.fromList definitions

    normalise entry = evalState (normSchema (snd entry)) lookupTable
-- | Collect the named definitions contained in a type, paired with their names.
--
-- * A record yields itself first, followed by whatever its field types yield.
-- * Arrays and maps yield whatever their element type yields.
-- * A union yields the results of its options, in order.
-- * Enums and fixed types yield just themselves.
-- * Every other constructor yields an empty list.
getTypes :: Type -> [(TypeName, Type)]
getTypes record@Record{} =
  (name record, record) : concatMap (getTypes . fldType) (fields record)
getTypes (Array item) = getTypes item
getTypes (Union (firstOption :| moreOptions) _) =
  concatMap getTypes (firstOption : moreOptions)
getTypes (Map value) = getTypes value
getTypes enum@Enum{} = [(name enum, enum)]
getTypes fixed@Fixed{} = [(name fixed, fixed)]
getTypes _ = []
-- | Rewrite a schema using the state as a table from type names to schemas.
--
-- * A named reference is replaced by the schema stored under its name. The
--   table entry is then overwritten with the reference itself, so a later
--   lookup of the same name yields the reference and not the definition.
--   A name missing from the table aborts with 'error'.
-- * Arrays, maps and unions are rebuilt around their normalised components.
-- * A record first stores a reference to itself under its own name, then
--   normalises the type of each of its fields in order.
-- * Any other schema is returned unchanged.
normSchema :: Schema -> State (M.Map TypeName Schema) Schema
normSchema ref@(NamedType tn) = gets (M.lookup tn) >>= maybe unresolved inline
  where
    unresolved = error $ "Unable to resolve schema: " <> show (typeName ref)

    -- Hand back the stored schema and leave only the reference in the table.
    inline definition = do
      modify' (M.insert tn ref)
      pure definition
normSchema (Array item) = Array <$> normSchema item
normSchema (Map value) = Map <$> normSchema value
normSchema (Union options lookupFn) =
  (\normalised -> Union normalised lookupFn) <$> traverse normSchema options
normSchema record@Record{name = tn} = do
  -- Register the reference before descending into the fields, so that
  -- occurrences of this name inside them resolve to the reference.
  modify' (M.insert tn (NamedType tn))
  normalisedFields <- traverse normField (fields record)
  pure $ record { fields = normalisedFields }
  where
    normField fld = (\ty -> fld { fldType = ty }) <$> normSchema (fldType fld)
normSchema other = pure other