-- | Maintainer: Arnaud Bailly <arnaud.oqube@gmail.com>
--
-- Properties for configuring firewall (iptables) rules

module Propellor.Property.Firewall (
	rule,
	installed,
	Chain(..),
	Table(..),
	Target(..),
	Proto(..),
	Rules(..),
	ConnectionState(..),
	ICMPTypeMatch(..),
	TCPFlag(..),
	Frequency(..),
	IPWithMask(..),
) where

import qualified Data.Semigroup as Sem
import Data.Char
import Data.List

import Propellor.Base
import qualified Propellor.Property.Apt as Apt
import qualified Propellor.Property.Network as Network

-- | Ensures the iptables package is installed (via apt).
installed :: Property DebianLike
installed = Apt.installed pkgs
  where
	pkgs = ["iptables"]

-- | A property that ensures an iptables rule is present in the given chain
-- of the given table, jumping to the given target when the rule's match
-- criteria are met.
--
-- The rule is first looked up with @iptables -C@. Only when that command
-- does not succeed is the rule appended with @iptables -A@. A rule that is
-- already present therefore yields 'NoChange' rather than being added again.
rule :: Chain -> Table -> Target -> Rules -> Property Linux
rule c tb tg rs = property desc ensureRule
  where
	theRule = Rule c tb tg rs

	desc = "firewall rule: " <> show theRule

	-- Arguments shared by the check and the append invocations.
	ruleArgs = toIpTable theRule

	-- Run iptables with the given action flag followed by the rule arguments.
	iptables :: String -> IO Bool
	iptables action = boolSystem "iptables" (Param action : ruleArgs)

	ensureRule :: Propellor Result
	ensureRule = liftIO $ do
		present <- iptables "-C"
		if present
			then return NoChange
			else toResult <$> iptables "-A"

-- | Renders a rule as iptables parameters, without the leading action flag
-- (such as @-A@ or @-C@). The output is the chain name, then @-t table@,
-- then @-j target@, then the match arguments produced by 'toIpTableArg'.
toIpTable :: Rule -> [CommandParam]
toIpTable (Rule chain table target rules) = Param <$> concat
	[ [val chain]
	, ["-t", val table]
	, ["-j", val target]
	, toIpTableArg rules
	]

-- | Renders match criteria as iptables command-line arguments.
--
-- Criteria provided by an iptables match extension are prefixed with the
-- @-m module@ pair that loads it. For rules combined with ':-', the left
-- operand's arguments come first. 'Everything' renders to no arguments.
toIpTableArg :: Rules -> [String]
toIpTableArg rules = case rules of
	Everything -> []
	-- Protocols are rendered as the lowercased constructor name, e.g. "tcp".
	Proto proto -> ["-p", map toLower (show proto)]
	DPort port -> ["--dport", val port]
	DPortRange (portf, portt) -> ["--dport", val portf ++ ":" ++ val portt]
	InIFace iface -> ["-i", iface]
	OutIFace iface -> ["-o", iface]
	Ctstate states -> withModule "conntrack" ["--ctstate", commaList show states]
	ICMPType i -> withModule "icmp" ["--icmp-type", val i]
	RateLimit f -> withModule "limit" ["--limit", val f]
	TCPFlags mask comp ->
		withModule "tcp" ["--tcp-flags", commaList show mask, commaList show comp]
	TCPSyn -> ["--syn"]
	GroupOwner (Group g) -> withModule "owner" ["--gid-owner", g]
	Source ipwm -> ["-s", commaList val ipwm]
	Destination ipwm -> ["-d", commaList val ipwm]
	NotDestination ipwm -> ["!", "-d", commaList val ipwm]
	NatDestination ip mport ->
		["--to-destination", val ip ++ maybe "" ((':' :) . val) mport]
	r :- r' -> toIpTableArg r ++ toIpTableArg r'
  where
	-- Prefix options with the "-m <module>" pair that loads their extension.
	withModule :: String -> [String] -> [String]
	withModule modName opts = "-m" : modName : opts

	-- Render each element and join the results with commas.
	commaList :: (a -> String) -> [a] -> String
	commaList render = intercalate "," . map render

-- | An IP address with an optional mask. The mask is given either as a full
-- address or as a number; both are rendered as @ip/mask@.
data IPWithMask
	= IPWithNoMask IPAddr
	| IPWithIPMask IPAddr IPAddr
	| IPWithNumMask IPAddr Int
	deriving (Eq, Show)

-- | Renders the bare address, or @address/mask@ when a mask is present.
instance ConfigurableValue IPWithMask where
	val ipwm = case ipwm of
		IPWithNoMask ip -> val ip
		IPWithIPMask ip mask -> withMask ip (val mask)
		IPWithNumMask ip bits -> withMask ip (val bits)
	  where
		withMask ip m = val ip ++ "/" ++ m

-- | A complete firewall rule, as rendered by 'toIpTable'.
data Rule = Rule
	{ ruleChain :: Chain
	-- ^ Chain the rule is added to.
	, ruleTable :: Table
	-- ^ Table passed with @-t@.
	, ruleTarget :: Target
	-- ^ Target passed with @-j@.
	, ruleRules :: Rules
	-- ^ Match criteria.
	}
	deriving (Eq, Show)

-- | The iptables table a rule is added to (selected with @-t@).
data Table
	= Filter
	| Nat
	| Mangle
	| Raw
	| Security
	deriving (Eq, Show)

-- | The table name as passed to iptables' @-t@ option.
instance ConfigurableValue Table where
	val t = case t of
		Filter -> "filter"
		Nat -> "nat"
		Mangle -> "mangle"
		Raw -> "raw"
		Security -> "security"

-- | The target a matching packet jumps to (selected with @-j@). Any other
-- target can be named with 'TargetCustom'.
data Target
	= ACCEPT
	| REJECT
	| DROP
	| LOG
	| TargetCustom String
	deriving (Eq, Show)

-- | The target name as passed to @-j@. Custom target names are passed
-- through verbatim.
instance ConfigurableValue Target where
	val t = case t of
		ACCEPT -> "ACCEPT"
		REJECT -> "REJECT"
		DROP -> "DROP"
		LOG -> "LOG"
		TargetCustom name -> name

-- | The chain a rule is added to: one of the standard chains, or any other
-- chain named with 'ChainCustom'.
data Chain
	= INPUT
	| OUTPUT
	| FORWARD
	| PREROUTING
	| POSTROUTING
	| ChainCustom String
	deriving (Eq, Show)

-- | The chain name as it appears right after the action flag (@-A@/@-C@).
-- Custom chain names are passed through verbatim.
instance ConfigurableValue Chain where
	val ch = case ch of
		INPUT -> "INPUT"
		OUTPUT -> "OUTPUT"
		FORWARD -> "FORWARD"
		PREROUTING -> "PREROUTING"
		POSTROUTING -> "POSTROUTING"
		ChainCustom name -> name

-- | Protocols matched with @-p@. They are rendered as the lowercased
-- constructor name.
data Proto
	= TCP
	| UDP
	| ICMP
	deriving (Eq, Show)

-- | Connection-tracking states for @--ctstate@. The constructor names are
-- rendered verbatim (via 'show'), so they must spell iptables' state names.
data ConnectionState
	= ESTABLISHED
	| RELATED
	| NEW
	| INVALID
	deriving (Eq, Show)

-- | An ICMP type for @--icmp-type@, given either by name or by numeric code.
data ICMPTypeMatch
	= ICMPTypeName String
	| ICMPTypeCode Int
	deriving (Eq, Show)

-- | The @--icmp-type@ value: the type name verbatim, or the numeric code.
instance ConfigurableValue ICMPTypeMatch where
	val m = case m of
		ICMPTypeName name -> name
		ICMPTypeCode code -> val code

-- | A rate for the @limit@ match. Currently only a per-second count is
-- supported.
data Frequency
	= NumBySecond Int
	deriving (Eq, Show)

-- | Renders as @N/second@, the form passed to @--limit@.
instance ConfigurableValue Frequency where
	val (NumBySecond count) = concat [val count, "/second"]

-- | First argument of @--tcp-flags@ (the flags to examine).
type TCPFlagMask = [TCPFlag]

-- | Second argument of @--tcp-flags@ (the flags expected to be set).
type TCPFlagComp = [TCPFlag]

-- | TCP flag names. They are rendered verbatim (via 'show') and joined with
-- commas, so the constructor names must spell iptables' flag names.
data TCPFlag
	= SYN
	| ACK
	| FIN
	| RST
	| URG
	| PSH
	| ALL
	| NONE
	deriving (Eq, Show)

-- | Match criteria for a firewall 'rule'. Each constructor maps to one group
-- of iptables arguments. Combine several with ':-' (or '<>').
data Rules
	= Everything
	-- ^ No match criteria; renders to no arguments.
	| Proto Proto
	-- ^ Protocol match (@-p@).
	--
	-- NOTE: protocol and port matches are order dependent, so this should
	-- really be a dedicated data type pairing a protocol with its ports.
	| DPort Port
	-- ^ Destination port (@--dport@).
	| DPortRange (Port, Port)
	-- ^ Destination port range, rendered as @--dport from:to@.
	| InIFace Network.Interface
	-- ^ Incoming interface (@-i@).
	| OutIFace Network.Interface
	-- ^ Outgoing interface (@-o@).
	| Ctstate [ ConnectionState ]
	-- ^ Connection-tracking states (@-m conntrack --ctstate@).
	| ICMPType ICMPTypeMatch
	-- ^ ICMP type (@-m icmp --icmp-type@).
	| RateLimit Frequency
	-- ^ Rate limit (@-m limit --limit@).
	| TCPFlags TCPFlagMask TCPFlagComp
	-- ^ TCP flags (@-m tcp --tcp-flags mask comp@).
	| TCPSyn
	-- ^ Rendered as @--syn@.
	| GroupOwner Group
	-- ^ Owner group (@-m owner --gid-owner@).
	| Source [ IPWithMask ]
	-- ^ Source addresses (@-s@, comma-separated).
	| Destination [ IPWithMask ]
	-- ^ Destination addresses (@-d@, comma-separated).
	| NotDestination [ IPWithMask ]
	-- ^ Negated destination addresses (@! -d@, comma-separated).
	--
	-- NOTE(review): iptables likely rejects @!@ combined with more than one
	-- address; confirm before passing several.
	| NatDestination IPAddr (Maybe Port)
	-- ^ Rendered as @--to-destination ip[:port]@. Presumably used together
	-- with a DNAT target; verify against callers.
	| Rules :- Rules
	-- ^ Combine two rules. The left operand's arguments come first.
	deriving (Eq, Show)

-- Left-associative at the lowest precedence, so @a :- b :- c@ parses as
-- @(a :- b) :- c@.
infixl 0 :-

-- | '<>' chains two rule sets with ':-'. The left operand's arguments are
-- rendered first.
--
-- 'Everything' carries no match criteria (it renders to no iptables
-- arguments), so it is dropped instead of being wrapped in ':-'. This makes
-- @Everything <> r == r@ and @r <> Everything == r@ hold under the derived
-- 'Eq', as the 'Monoid' identity laws require for 'mempty' ('Everything').
-- It also keeps a trailing 'Everything' out of the @show@-based description
-- of rules built with 'mconcat'. The generated iptables arguments are the
-- same as when chaining unconditionally.
instance Sem.Semigroup Rules where
	Everything <> r = r
	r <> Everything = r
	r <> r' = r :- r'

-- | 'mempty' is 'Everything', which renders to no iptables arguments.
-- 'mappend' defers to the 'Sem.Semigroup' instance.
instance Monoid Rules where
	mempty = Everything
	mappend r r' = r Sem.<> r'