{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE Safe #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -- | -- Module : Data.Tree.AVL.Invariants -- Description : Type level AVL invariants -- Copyright : (c) Nicolás Rodríguez, 2021 -- License : GPL-3 -- Maintainer : Nicolás Rodríguez -- Stability : experimental -- Portability : POSIX -- -- Type level restrictions for the key ordering in type safe AVL trees. module Data.Tree.AVL.Invariants ( BS (Balanced, LeftHeavy, RightHeavy), BalancedState, Height, BalancedHeights, US (LeftUnbalanced, NotUnbalanced, RightUnbalanced), UnbalancedState, ) where import Data.Tree.ITree (Tree (EmptyTree, ForkTree)) import Data.Tree.Node (Node) import Data.Type.Bool (If) import GHC.TypeLits (ErrorMessage (ShowType, Text, (:<>:)), TypeError) import GHC.TypeNats (Nat, type (+), type (-), type (<=?)) import Prelude (Bool (True)) -- | Get the maximum between two type level natural numbers. type family Max (n1 :: Nat) (n2 :: Nat) :: Nat where Max n1 n2 = ( If (n1 <=? n2) n2 n1 ) -- | Get the height of a tree. type family Height (t :: Tree) :: Nat where Height 'EmptyTree = 0 Height ('ForkTree l (Node _n _a) r) = 1 + Max (Height l) (Height r) -- | Check if two type level natural numbers, -- that represent the heights of some left and right sub trees, -- differ at most in one (i.e., the tree is balanced). type family BalancedHeights (h1 :: Nat) (h2 :: Nat) (k :: Nat) :: Bool where BalancedHeights 0 0 _k = 'True BalancedHeights 1 0 _k = 'True BalancedHeights _h1 0 k = TypeError ('Text "The left sub tree at node with key " ':<>: 'ShowType k ':<>: 'Text " has +2 greater height!") BalancedHeights 0 1 _k = 'True BalancedHeights 0 _h2 k = TypeError ('Text "The right sub tree at node with key " ':<>: 'ShowType k ':<>: 'Text " has +2 greater height!") BalancedHeights h1 h2 k = BalancedHeights (h1 - 1) (h2 - 1) k -- | Data type that represents the state of unbalance of the sub trees: -- -- [`LeftUnbalanced`] @height(left sub tree) = height(right sub tree) + 2@ -- -- [`RightUnbalanced`] @height(right sub tree) = height(left sub tree) + 2@ -- -- [`NotUnbalanced`] @tree is not unbalanced@ data US = LeftUnbalanced | RightUnbalanced | NotUnbalanced -- | Check from two type level natural numbers, -- that represent the heights of some left and right sub trees, -- if the tree is balanced or if some of those sub trees is unbalanced. type family UnbalancedState (h1 :: Nat) (h2 :: Nat) :: US where UnbalancedState 0 0 = 'NotUnbalanced UnbalancedState 1 0 = 'NotUnbalanced UnbalancedState 0 1 = 'NotUnbalanced UnbalancedState 2 0 = 'LeftUnbalanced UnbalancedState 0 2 = 'RightUnbalanced UnbalancedState h1 h2 = UnbalancedState (h1 - 1) (h2 - 1) -- | Data type that represents the state of balance of the sub trees in a balanced tree: -- -- [`LeftHeavy`] @height(left sub tree) = height(right sub tree) + 1@ -- -- [`RightHeavy`] @height(right sub tree) = height(left sub tree) + 1@ -- -- [`Balanced`] @height(left sub tree) = height(right sub tree)@ data BS = LeftHeavy | RightHeavy | Balanced -- | Check from two type level natural numbers, -- that represent the heights of some left and right sub trees, -- if some of those sub trees have height larger than the other. type family BalancedState (h1 :: Nat) (h2 :: Nat) :: BS where BalancedState 0 0 = 'Balanced BalancedState 1 0 = 'LeftHeavy BalancedState 0 1 = 'RightHeavy BalancedState h1 h2 = BalancedState (h1 - 1) (h2 - 1)