{-# LANGUAGE DerivingStrategies         #-}
{-# LANGUAGE DuplicateRecordFields      #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase                 #-}
{-# LANGUAGE NamedFieldPuns             #-}
{-# LANGUAGE OverloadedStrings          #-}
{-# LANGUAGE QuasiQuotes                #-}
{-# LANGUAGE RecordWildCards            #-}
{-# LANGUAGE TupleSections              #-}
module Auth.Biscuit.Datalog.ScopedExecutor
  ( BlockWithRevocationId
  , runAuthorizer
  , runAuthorizerWithLimits
  , runAuthorizerNoTimeout
  , runFactGeneration
  , PureExecError (..)
  , AuthorizationSuccess (..)
  , getBindings
  , queryGeneratedFacts
  , queryAvailableFacts
  , getVariableValues
  , getSingleVariableValue
  , FactGroup (..)
  , collectWorld
  ) where

import           Control.Monad                 (unless, when)
import           Control.Monad.State           (StateT (..), evalStateT, get,
                                                gets, lift, put)
import           Data.Bifunctor                (first)
import           Data.ByteString               (ByteString)
import           Data.Foldable                 (sequenceA_)
import           Data.List                     (genericLength)
import           Data.List.NonEmpty            (NonEmpty)
import qualified Data.List.NonEmpty            as NE
import           Data.Map                      (Map)
import qualified Data.Map                      as Map
import           Data.Map.Strict               ((!?))
import           Data.Set                      (Set)
import qualified Data.Set                      as Set
import           Data.Text                     (Text)
import           Numeric.Natural               (Natural)
import           Validation                    (Validation (..))

import           Auth.Biscuit.Crypto           (PublicKey)
import           Auth.Biscuit.Datalog.AST
import           Auth.Biscuit.Datalog.Executor (Bindings, ExecutionError (..),
                                                FactGroup (..), Limits (..),
                                                MatchedQuery (..),
                                                ResultError (..), Scoped,
                                                checkCheck, checkPolicy,
                                                countFacts, defaultLimits,
                                                fromScopedFacts,
                                                getBindingsForRuleBody,
                                                getFactsForRule,
                                                keepAuthorized', toScopedFacts)
import           Auth.Biscuit.Datalog.Parser   (fact)
import           Auth.Biscuit.Timer            (timer)
import           Auth.Biscuit.Utils            (foldMapM, mapMaybeM)
import           Data.Bitraversable            (bisequence)

-- | A block, along with its revocation id and the (optional) public key
-- attached to it. The authority block and every extra block are passed
-- around in this shape.
type BlockWithRevocationId = (Block, ByteString, Maybe PublicKey)

-- | A subset of 'ExecutionError' that can only happen during fact generation
data PureExecError
  = Facts
  -- ^ The total number of facts reached the @maxFacts@ limit
  | Iterations
  -- ^ The number of elapsed iterations reached the @maxIterations@ limit
  | BadRule
  -- ^ A rule head uses a variable that is not present in the rule body
  | BadExpression String
  -- ^ Generating facts from a rule failed; the message is carried along
  deriving (Eq, Show)
-- | Proof that a biscuit was authorized successfully. In addition to the matched
-- @allow query@, the generated facts are kept around for further querying.
-- Since only authority facts can be trusted, they are kept separate.
data AuthorizationSuccess
  = AuthorizationSuccess
  { matchedAllowQuery :: MatchedQuery
  -- ^ The allow query that matched
  , allFacts          :: FactGroup
  -- ^ All the facts that were generated by the biscuit, grouped by their origin
  , limits            :: Limits
  -- ^ Limits used when running datalog. It is kept around to allow further
  -- datalog computation when querying facts
  }
  deriving (Eq, Show)

-- | Get the matched variables from the @allow@ query used to authorize the biscuit.
-- This can be used in conjunction with 'getVariableValues' or 'getSingleVariableValue'
-- to extract the actual values
getBindings :: AuthorizationSuccess -> Set Bindings
getBindings AuthorizationSuccess{matchedAllowQuery=MatchedQuery{bindings}} = bindings

-- | Given a series of blocks and an authorizer, ensure that all
-- the checks and policies match
runAuthorizer :: BlockWithRevocationId
              -- ^ The authority block
              -> [BlockWithRevocationId]
              -- ^ The extra blocks
              -> Authorizer
              -- ^ An authorizer
              -> IO (Either ExecutionError AuthorizationSuccess)
runAuthorizer = runAuthorizerWithLimits defaultLimits

-- | Given a series of blocks and an authorizer, ensure that all
-- the checks and policies match, with provided execution
-- constraints.
--
-- The pure evaluation is run under 'timer' with the @maxTime@ limit; when
-- 'timer' yields 'Nothing', 'Timeout' is returned.
runAuthorizerWithLimits :: Limits
                        -- ^ custom limits
                        -> BlockWithRevocationId
                        -- ^ The authority block
                        -> [BlockWithRevocationId]
                        -- ^ The extra blocks
                        -> Authorizer
                        -- ^ An authorizer
                        -> IO (Either ExecutionError AuthorizationSuccess)
runAuthorizerWithLimits l@Limits{maxTime} authority blocks v = do
  -- The result has to be forced (to WHNF) *inside* the timed action: a lazy
  -- `pure` would hand back an unevaluated thunk right away, and the datalog
  -- evaluation would then run outside of the timer, once the caller inspects
  -- the result. Knowing whether the result is `Left` or `Right` requires
  -- generating the facts and running checks and policies.
  -- NOTE(review): assumes `timer` bounds the run time of the action it is
  -- given (its definition is not visible from this module) -- confirm.
  resultOrTimeout <- timer maxTime $ pure $! runAuthorizerNoTimeout l authority blocks v
  pure $ case resultOrTimeout of
    Nothing -> Left Timeout
    Just r  -> r

-- | Build one @revocation_id(index, id)@ fact per block, the authority block
-- having index 0 and the extra blocks following in order.
mkRevocationIdFacts :: BlockWithRevocationId -> [BlockWithRevocationId]
                    -> Set Fact
mkRevocationIdFacts authority blocks =
  let allIds :: [(Int, ByteString)]
      allIds = zip [0..] $ snd' <$> authority : blocks
      snd' (_,b,_) = b
      mkFact (index, rid) = [fact|revocation_id({index}, {rid})|]
   in Set.fromList $ mkFact <$> allIds

-- | State threaded through fact generation (see 'runStep').
data ComputeState
  = ComputeState
  { sLimits     :: Limits -- readonly
  , sRules      :: Map Natural (Set EvalRule) -- readonly
  , sBlockCount :: Natural
  -- state
  , sIterations :: Int -- elapsed iterations
  , sFacts      :: FactGroup -- facts generated so far
  }
  deriving (Eq, Show)

-- | Collect rules and facts from the authority block, the extra blocks and
-- the authorizer into an initial 'ComputeState'.
--
-- Blocks are numbered from 0 (authority); the authorizer block gets the
-- id @1 + length blocks@, which is also used as the origin of the
-- revocation id facts.
mkInitState :: Limits -> BlockWithRevocationId -> [BlockWithRevocationId] -> Authorizer -> ComputeState
mkInitState limits authority blocks authorizer =
  let fst' (a,_,_) = a
      trd' (_,_,c) = c
      sBlockCount = 1 + genericLength blocks
      -- the authority block has no external key
      externalKeys = Nothing : (trd' <$> blocks)
      revocationWorld = (mempty, FactGroup $ Map.singleton (Set.singleton sBlockCount) $ mkRevocationIdFacts authority blocks)
      firstBlock = fst' authority
      otherBlocks = fst' <$> blocks
      allBlocks = zip [0..] (firstBlock : otherBlocks) <> [(sBlockCount, vBlock authorizer)]
      (sRules, sFacts) = revocationWorld <> foldMap (uncurry collectWorld . fmap (toEvaluation externalKeys)) allBlocks
   in ComputeState
        { sLimits = limits
        , sRules
        , sBlockCount
        , sIterations = 0
        , sFacts
        }

-- | Pure version of 'runAuthorizerWithLimits': generate all the facts, then
-- run every check and the policies against them. No time limit is enforced
-- here; the fact and iteration limits still apply during fact generation.
runAuthorizerNoTimeout :: Limits
                       -> BlockWithRevocationId
                       -> [BlockWithRevocationId]
                       -> Authorizer
                       -> Either ExecutionError AuthorizationSuccess
runAuthorizerNoTimeout limits authority blocks authorizer = do
  let fst' (a,_,_) = a
      trd' (_,_,c) = c
      blockCount = 1 + genericLength blocks
      externalKeys = Nothing : (trd' <$> blocks)
      (<$$>) = fmap . fmap
      (<$$$>) = fmap . fmap . fmap
      initState = mkInitState limits authority blocks authorizer
      toExecutionError = \case
        Facts           -> TooManyFacts
        Iterations      -> TooManyIterations
        BadRule         -> InvalidRule
        BadExpression e -> EvaluationError e
  allFacts <- first toExecutionError $ computeAllFacts initState
  let -- checks of every block (authorizer included), tagged with the block id
      checks = bChecks <$$> ( zip [0..] (fst' <$> authority : blocks) <> [(blockCount,vBlock authorizer)] )
      policies = vPolicies authorizer
      checkResults = checkChecks limits blockCount allFacts (checkToEvaluation externalKeys <$$$> checks)
      policyResults = checkPolicies limits blockCount allFacts (policyToEvaluation externalKeys <$> policies)
  case bisequence (checkResults, policyResults) of
    Left e -> Left $ EvaluationError e
    Right (Success (), Left Nothing)  -> Left $ ResultError $ NoPoliciesMatched []
    Right (Success (), Left (Just p)) -> Left $ ResultError $ DenyRuleMatched [] p
    Right (Failure cs, Left Nothing)  -> Left $ ResultError $ NoPoliciesMatched (NE.toList cs)
    Right (Failure cs, Left (Just p)) -> Left $ ResultError $ DenyRuleMatched (NE.toList cs) p
    Right (Failure cs, Right _)       -> Left $ ResultError $ FailedChecks cs
    Right (Success (), Right p)       -> Right $ AuthorizationSuccess { matchedAllowQuery = p
                                                                      , allFacts
                                                                      , limits
                                                                      }

-- | Run a single round of fact generation and return the number of facts
-- that were actually added. Fails with 'Facts' or 'Iterations' when the
-- corresponding limit is reached.
runStep :: StateT ComputeState (Either PureExecError) Int
runStep = do
  state@ComputeState{sLimits,sFacts,sRules,sBlockCount,sIterations} <- get
  let Limits{maxFacts, maxIterations} = sLimits
      previousCount = countFacts sFacts
      generatedFacts :: Either PureExecError FactGroup
      generatedFacts = first BadExpression $ extend sLimits sBlockCount sRules sFacts
  newFacts <- (sFacts <>) <$> lift generatedFacts
  let newCount = countFacts newFacts
      -- counting the facts returned by `extend` is not equivalent to
      -- comparing complete counts, as `extend` may return facts that
      -- are already present in `sFacts`
      addedFactsCount = newCount - previousCount
  when (newCount >= maxFacts) $ lift $ Left Facts
  when (sIterations >= maxIterations) $ lift $ Left Iterations
  put $ state { sIterations = sIterations + 1
              , sFacts = newFacts
              }
  pure addedFactsCount

-- | Check if every variable from the head is present in the body
checkRuleHead :: EvalRule -> Bool
checkRuleHead Rule{rhead, body} =
  let headVars = extractVariables [rhead]
      bodyVars = extractVariables body
   in headVars `Set.isSubsetOf` bodyVars

-- | Repeatedly generate new facts until it converges (no new
-- facts are generated). Fails with 'BadRule' upfront if any rule
-- does not pass 'checkRuleHead'.
computeAllFacts :: ComputeState -> Either PureExecError FactGroup
computeAllFacts initState@ComputeState{sRules} = do
  let checkRules = all (all checkRuleHead) sRules
      go = do
        newFacts <- runStep
        if newFacts > 0 then go else gets sFacts

  unless checkRules $ Left BadRule
  evalStateT go initState

-- | Small helper used in tests to directly provide rules and facts without creating
-- a biscuit token
runFactGeneration :: Limits -> Natural -> Map Natural (Set EvalRule) -> FactGroup -> Either PureExecError FactGroup
runFactGeneration sLimits sBlockCount sRules sFacts =
  let initState = ComputeState{sIterations = 0, ..}
   in computeAllFacts initState

-- | Run the checks of every block, accumulating all the failed checks.
-- An evaluation error (@Left@) aborts the whole run.
checkChecks :: Limits -> Natural -> FactGroup -> [(Natural, [EvalCheck])] -> Either String (Validation (NonEmpty Check) ())
checkChecks limits blockCount allFacts =
  fmap sequenceA_ . traverse (uncurry $ checkChecksForGroup limits blockCount allFacts)

-- | Run the checks of a single block (identified by @checksBlockId@),
-- accumulating the failed ones.
checkChecksForGroup :: Limits -> Natural -> FactGroup -> Natural -> [EvalCheck] -> Either String (Validation (NonEmpty Check) ())
checkChecksForGroup limits blockCount allFacts checksBlockId =
  fmap sequenceA_ . traverse (checkCheck limits blockCount checksBlockId allFacts)

-- | Evaluate the policies and keep the first result produced by
-- 'checkPolicy', in policy order. @Left Nothing@ means that no policy
-- produced a result, @Left (Just q)@ carries the first result when it is a
-- @Left@, and @Right q@ when it is a @Right@.
checkPolicies :: Limits -> Natural -> FactGroup -> [EvalPolicy] -> Either String (Either (Maybe MatchedQuery) MatchedQuery)
checkPolicies limits blockCount allFacts policies = do
  results <- mapMaybeM (checkPolicy limits blockCount allFacts) policies
  pure $ case results of
    p : _ -> first Just p
    []    -> Left Nothing

-- | Generate new facts by applying rules on existing facts
extend :: Limits -> Natural -> Map Natural (Set EvalRule) -> FactGroup -> Either String FactGroup
extend l blockCount rules facts =
  let buildFacts :: Natural -> Set EvalRule -> FactGroup -> Either String (Set (Scoped Fact))
      buildFacts ruleBlockId ruleGroup factGroup =
        let extendRule :: EvalRule -> Either String (Set (Scoped Fact))
            extendRule r@Rule{scope} = getFactsForRule l (toScopedFacts $ keepAuthorized' False blockCount factGroup scope ruleBlockId) r
         in foldMapM extendRule ruleGroup

      extendRuleGroup :: Natural -> Set EvalRule -> Either String FactGroup
      extendRuleGroup ruleBlockId ruleGroup =
            -- todo pre-filter facts based on the weakest rule scope to avoid passing too many facts
            -- to buildFacts
        let authorizedFacts = facts -- test $ keepAuthorized facts $ Set.fromList [0..ruleBlockId]
            -- every generated fact also originates from the block defining the rule
            addRuleOrigin = FactGroup . Map.mapKeysWith (<>) (Set.insert ruleBlockId) . getFactGroup
         in addRuleOrigin . fromScopedFacts <$> buildFacts ruleBlockId ruleGroup authorizedFacts

   in foldMapM (uncurry extendRuleGroup) $ Map.toList rules

-- | Extract the rules and the facts of a block, both tagged with the
-- block id.
collectWorld :: Natural -> EvalBlock -> (Map Natural (Set EvalRule), FactGroup)
collectWorld blockId Block{..} =
  let -- a block can define a default scope for its rule
      -- which is used unless the rule itself has defined a scope
      applyScope r@Rule{scope} = r { scope = if null scope then bScope else scope }
   in ( Map.singleton blockId $ Set.map applyScope $ Set.fromList bRules
      , FactGroup $ Map.singleton (Set.singleton blockId) $ Set.fromList bFacts
      )

-- | Run a query against the facts kept in an 'AuthorizationSuccess', with
-- the limits that were used during authorization.
-- See 'queryAvailableFacts'.
queryGeneratedFacts :: [Maybe PublicKey] -> AuthorizationSuccess -> Query -> Either String (Set Bindings)
queryGeneratedFacts ePks AuthorizationSuccess{allFacts, limits} =
  queryAvailableFacts ePks allFacts limits

-- | Run a query against a group of facts and collect the bindings matched
-- by every item of the query. The block count is taken to be the length of
-- the external keys list.
queryAvailableFacts :: [Maybe PublicKey] -> FactGroup -> Limits -> Query -> Either String (Set Bindings)
queryAvailableFacts ePks allFacts limits q =
  let blockCount = genericLength ePks
      getBindingsForQueryItem QueryItem{qBody,qExpressions,qScope} =
        let facts = toScopedFacts $ keepAuthorized' True blockCount allFacts qScope blockCount
         in Set.map snd <$> getBindingsForRuleBody limits facts qBody qExpressions
   in foldMapM (getBindingsForQueryItem . toEvaluation ePks) q

-- | Extract a set of values from a matched variable for a specific type.
-- Returning @Set Value@ allows to get all values, whatever their type.
getVariableValues :: (Ord t, FromValue t)
                  => Set Bindings
                  -> Text
                  -> Set t
getVariableValues bindings variableName =
  let mapMaybeS f = foldMap (foldMap Set.singleton . f)
      getVar vars = fromValue =<< vars !? variableName
   in mapMaybeS getVar bindings

-- | Extract exactly one value from a matched variable. If the variable has 0
-- matches or more than one match, 'Nothing' will be returned
getSingleVariableValue :: (Ord t, FromValue t)
                       => Set Bindings
                       -> Text
                       -> Maybe t
getSingleVariableValue bindings variableName =
  let values = getVariableValues bindings variableName
   in case Set.toList values of
        [v] -> Just v
        _   -> Nothing