{-# LANGUAGE DerivingStrategies         #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedLists            #-}
{-# LANGUAGE OverloadedStrings          #-}
module Auth.Biscuit.Symbols
  ( Symbols
  , BlockSymbols
  , ReverseSymbols
  , SymbolRef (..)
  , getSymbol
  , addSymbols
  , addFromBlock
  , addFromBlocks
  , reverseSymbols
  , getSymbolList
  , getSymbolCode
  , newSymbolTable
  ) where

import           Control.Monad      (join)
import           Data.Int           (Int64)
import           Data.Map           (Map, elems, (!?))
import qualified Data.Map           as Map
import           Data.Set           (Set, difference, union)
import qualified Data.Set           as Set
import           Data.Text          (Text)

import           Auth.Biscuit.Utils (maybeToRight)

-- | A reference to a symbol: the numeric index under which the symbol's
-- text is stored in a symbol table.
newtype SymbolRef = SymbolRef
  { getSymbolRef :: Int64 -- ^ the raw numeric index
  }
  deriving stock (Eq)

-- | A symbol reference is rendered as @#@ followed by its numeric index
-- (e.g. @#12@).
instance Show SymbolRef where
  show = ("#" <>) . show . getSymbolRef

-- | A symbol table: maps numeric symbol indices to the text of each symbol.
newtype Symbols = Symbols
  { getSymbols :: Map Int64 Text -- ^ the underlying index-to-text map
  }
  deriving stock (Eq, Show)

-- | The symbol table carried by a single block: maps numeric symbol indices
-- to symbol text.
--
-- The 'Semigroup' instance is the one of the underlying 'Map', i.e. a
-- left-biased union.
newtype BlockSymbols = BlockSymbols
  { getBlockSymbols :: Map Int64 Text -- ^ the underlying index-to-text map
  }
  deriving stock (Eq, Show)
  deriving newtype (Semigroup)

-- | A reversed symbol table: maps the text of each symbol to its numeric
-- index.
--
-- The 'Semigroup' instance is the one of the underlying 'Map', i.e. a
-- left-biased union.
newtype ReverseSymbols = ReverseSymbols
  { getReverseSymbols :: Map Text Int64 -- ^ the underlying text-to-index map
  }
  deriving stock (Eq, Show)
  deriving newtype (Semigroup)

-- | Look up the text stored for a symbol reference in a symbol table.
--
-- When the table has no entry at that index, the result is a 'Left' whose
-- message names the missing index. The 'Maybe'-to-'Either' conversion is
-- delegated to 'maybeToRight' (presumably yielding 'Right' with the symbol
-- text when the entry exists).
getSymbol :: Symbols -> SymbolRef -> Either String Text
getSymbol (Symbols m) (SymbolRef i) =
  maybeToRight ("Missing symbol at id #" <> show i) $ m !? i

-- | Given already existing symbols and a set of symbols used in a block,
-- compute the symbol table carried by this specific block.
--
-- Symbols that are already known (either common symbols or entries of the
-- provided table) are left out. The remaining ones are numbered
-- consecutively, in ascending 'Set' order, starting at
-- @1024 + (size of the provided table - number of common symbols)@.
addSymbols :: Symbols -> Set Text -> BlockSymbols
addSymbols (Symbols m) symbols =
  let -- everything that already has an index and must not get a new one
      existingSymbols :: Set Text
      existingSymbols = Set.fromList (elems commonSymbols) `union` Set.fromList (elems m)
      newSymbols :: [Text]
      newSymbols = Set.toList $ symbols `difference` existingSymbols
      -- first free index after the non-common entries of the provided table
      starting :: Int64
      starting = fromIntegral $ 1024 + (Map.size m - Map.size commonSymbols)
   in BlockSymbols $ Map.fromList (zip [starting..] newSymbols)

-- | List the symbol texts held by a block symbol table, in ascending order
-- of their indices.
getSymbolList :: BlockSymbols -> [Text]
getSymbolList (BlockSymbols m) = Map.elems m

-- | A fresh symbol table, containing only the common symbols.
newSymbolTable :: Symbols
newSymbolTable = Symbols commonSymbols

-- | Given the symbol table of a protobuf block, update the provided symbol
-- table.
--
-- The two tables are merged with a left-biased union: when an index is
-- present in both, the entry of the provided (existing) table is kept.
addFromBlock :: Symbols -> BlockSymbols -> Symbols
addFromBlock (Symbols m) (BlockSymbols bm) =
  Symbols $ m <> bm

-- | Compute a global symbol table from a series of block symbol tables.
--
-- The per-block symbol lists are concatenated in order and numbered
-- consecutively from 1024; the result is merged with the common symbols.
addFromBlocks :: [[Text]] -> Symbols
addFromBlocks blocksTables =
  let allSymbols :: [Text]
      allSymbols = join blocksTables
   in Symbols $ commonSymbols <> Map.fromList (zip [1024..] allSymbols)

-- | Reverse a symbol table, turning the index-to-text mapping into a
-- text-to-index mapping.
--
-- If the same text is stored under several indices, the highest index is
-- the one retained ('Map.fromList' keeps the last binding of a key, and
-- 'Map.toList' enumerates bindings in ascending index order).
reverseSymbols :: Symbols -> ReverseSymbols
reverseSymbols =
  let swap (a, b) = (b, a)
   in ReverseSymbols . Map.fromList . fmap swap . Map.toList . getSymbols

-- | Given a reverse symbol table (symbol refs indexed by their textual
-- representation), turn a textual representation into a symbol ref.
--
-- This function is partial: callers are expected to guarantee that the
-- reverse table contains the requested textual symbol. If it does not, the
-- call fails with an 'error' naming the missing symbol.
getSymbolCode :: ReverseSymbols -> Text -> SymbolRef
getSymbolCode (ReverseSymbols rm) t =
  maybe (error missing) SymbolRef (Map.lookup t rm)
  where
    -- Spell out which symbol was missing instead of the generic @Map.!@ failure.
    missing =
      "getSymbolCode: symbol " <> show t <> " is not in the reverse symbol table"

-- | The common symbols defined in the biscuit spec, numbered consecutively
-- from 0 in the order listed below.
commonSymbols :: Map Int64 Text
commonSymbols = Map.fromList $ zip [0..]
  [ "read"
  , "write"
  , "resource"
  , "operation"
  , "right"
  , "time"
  , "role"
  , "owner"
  , "tenant"
  , "namespace"
  , "user"
  , "team"
  , "service"
  , "admin"
  , "email"
  , "group"
  , "member"
  , "ip_address"
  , "client"
  , "client_ip"
  , "domain"
  , "path"
  , "version"
  , "cluster"
  , "node"
  , "hostname"
  , "nonce"
  , "query"
  ]