-- Filter.hs: OpenPGP (RFC4880) packet filtering
-- Copyright © 2014  Clint Adams
-- This software is released under the terms of the Expat license.
-- (See the LICENSE file).

module Data.Conduit.OpenPGP.Filter (
   conduitFilter
 , FilterPredicates(..)
 , Expr(..)
 , PKPPredicate(..)
 , PKPVar(..)
 , PKPOp(..)
 , PKPValue(..)
 , SPPredicate(..)
 , SPVar(..)
 , SPOp(..)
 , SPValue(..)
 , OPredicate(..)
 , OVar(..)
 , OOp(..)
 , OValue(..)
) where

import qualified Data.ByteString as B
import Data.Conduit
import qualified Data.Conduit.List as CL

import Codec.Encryption.OpenPGP.Internal (sigType, sigPKA, sigHA)
import Codec.Encryption.OpenPGP.KeyInfo (keySize)
import Codec.Encryption.OpenPGP.Types

-- | The three predicate trees consulted by 'conduitFilter'.
--
-- Each packet is tested against exactly one of them, chosen by the packet's
-- constructor.
data FilterPredicates = FilterPredicates
    { _pubKeyPktPredicate :: Expr PKPPredicate  -- ^ tested against 'PublicKeyPkt' payloads
    , _sigPktPredicate :: Expr SPPredicate      -- ^ tested against 'SignaturePkt' payloads
    , _otherPredicate :: Expr OPredicate        -- ^ tested against every other packet
    }

-- | A boolean expression tree whose leaves are predicates of type @a@.
--
-- 'EAny' is always true, 'E' wraps a single leaf predicate, and 'EAnd',
-- 'EOr' and 'ENot' combine sub-expressions with the usual boolean meaning.
data Expr a
    = EAny
    | E a
    | EAnd (Expr a) (Expr a)
    | EOr (Expr a) (Expr a)
    | ENot (Expr a)

-- | Evaluate an expression tree against a value.
--
-- The supplied function decides each leaf ('E') predicate for the value;
-- 'EAny' is 'True', and the connectives use '&&', '||' and 'not'
-- (so 'EAnd' and 'EOr' short-circuit on their left operand).
eval :: (a -> v -> Bool) -> Expr a -> v -> Bool
eval test expr value = go expr
  where
    go (E leaf) = test leaf value
    go (ENot sub) = not (go sub)
    go (EOr l r) = go l || go r
    go (EAnd l r) = go l && go r
    go EAny = True

-- | Comparison operator of a 'PKPPredicate'.
data PKPOp
    = PKEquals
    | PKLessThan
    | PKGreaterThan

-- | One comparison on a public-key packet.
--
-- The selected payload field is the left operand and the 'PKPValue' is the
-- right operand.
data PKPPredicate = PKPPredicate PKPVar PKPOp PKPValue

-- | The public-key payload field a 'PKPPredicate' inspects.
data PKPVar
    = PKPVVersion    -- ^ key version
    | PKPVPKA        -- ^ public-key algorithm
    | PKPVKeysize    -- ^ key size as reported by 'keySize'
    | PKPVTimestamp  -- ^ the payload's timestamp field

-- | Right-hand operand of a public-key-packet predicate.
--
-- Every constructor is compared through one integer representation
-- ('pkvToInt'): 'PKPInt' uses the wrapped 'Int', and 'PKPPKA' uses the
-- algorithm's numeric value as given by 'fromFVal'. 'Eq' and 'Ord' share
-- that representation, so @x == y@ holds exactly when @compare x y == EQ@.
-- In particular, a 'PKPInt' literal can equal a field extracted as 'PKPPKA'.
data PKPValue = PKPInt Int
              | PKPPKA PubKeyAlgorithm

-- Defined via 'pkvToInt' (not derived) so that it agrees with 'Ord'.
instance Eq PKPValue where
    i == j = pkvToInt i == pkvToInt j

instance Ord PKPValue where
    compare i j = compare (pkvToInt i) (pkvToInt j)

-- | Integer representation of a 'PKPValue', used by both 'Eq' and 'Ord'.
pkvToInt :: PKPValue -> Int
pkvToInt (PKPInt i) = i
pkvToInt (PKPPKA i) = fromIntegral (fromFVal i)

-- | Comparison operator of an 'SPPredicate'.
data SPOp
    = SPEquals
    | SPLessThan
    | SPGreaterThan

-- | One comparison on a signature packet.
--
-- The selected payload field is the left operand and the 'SPValue' is the
-- right operand.
data SPPredicate = SPPredicate SPVar SPOp SPValue

-- | The signature payload field an 'SPPredicate' inspects.
data SPVar
    = SPVVersion  -- ^ signature version
    | SPVSigType  -- ^ signature type
    | SPVPKA      -- ^ public-key algorithm
    | SPVHA       -- ^ hash algorithm

-- | Right-hand operand of a signature-packet predicate.
--
-- Every constructor is compared through one integer representation
-- ('spvToInt'): 'SPInt' uses the wrapped 'Int', and the other constructors
-- use the numeric value given by 'fromFVal'. 'Eq' and 'Ord' share that
-- representation, so @x == y@ holds exactly when @compare x y == EQ@.
-- In particular, an 'SPInt' literal can equal a field extracted as
-- 'SPSigType', 'SPPKA' or 'SPHA'.
data SPValue = SPInt Int
             | SPSigType SigType
             | SPPKA PubKeyAlgorithm
             | SPHA HashAlgorithm

-- Defined via 'spvToInt' (not derived) so that it agrees with 'Ord'.
instance Eq SPValue where
    i == j = spvToInt i == spvToInt j

instance Ord SPValue where
    compare i j = compare (spvToInt i) (spvToInt j)

-- | Integer representation of an 'SPValue', used by both 'Eq' and 'Ord'.
spvToInt :: SPValue -> Int
spvToInt (SPInt i) = i
spvToInt (SPSigType i) = fromIntegral (fromFVal i)
spvToInt (SPPKA i) = fromIntegral (fromFVal i)
spvToInt (SPHA i) = fromIntegral (fromFVal i)

-- | Comparison operator of an 'OPredicate'.
data OOp
    = OEquals
    | OLessThan
    | OGreaterThan

-- | One comparison on a packet that is neither a public-key nor a signature
-- packet.
--
-- The selected packet field is the left operand and the 'OValue' is the
-- right operand.
data OPredicate = OPredicate OVar OOp OValue

-- | The packet field an 'OPredicate' inspects; only the packet tag exists.
data OVar
    = OVTag

-- | Right-hand operand of an 'OPredicate'.
data OValue = OInt Int
    deriving Eq

-- | Orders 'OValue's by their wrapped integer. With a single constructor,
-- this agrees with the derived 'Eq'.
instance Ord OValue where
    compare i j = compare (ovToInt i) (ovToInt j)

-- | The integer carried by an 'OValue'.
ovToInt :: OValue -> Int
ovToInt (OInt i) = i

-- | Pass through only the packets accepted by the given 'FilterPredicates'.
--
-- Public-key packets are tested with '_pubKeyPktPredicate', signature
-- packets with '_sigPktPredicate', and all other packets with
-- '_otherPredicate'.
--
-- Filtering is pure, so any 'Monad' suffices (this is all 'CL.filter'
-- requires). Every 'MonadResource' is also a 'Monad', so existing callers
-- keep working.
conduitFilter :: Monad m => FilterPredicates -> Conduit Pkt m Pkt
conduitFilter = CL.filter . superPredicate

-- | Decide whether a packet passes the filter.
--
-- The predicate tree is chosen by packet constructor: public-key and
-- signature packets are tested on their payloads, and anything else is
-- tested on the whole packet.
superPredicate :: FilterPredicates -> Pkt -> Bool
superPredicate preds pkt =
    case pkt of
        PublicKeyPkt payload -> eval pkpEval (_pubKeyPktPredicate preds) payload
        SignaturePkt payload -> eval spEval (_sigPktPredicate preds) payload
        _ -> eval oEval (_otherPredicate preds) pkt

-- | Evaluate one public-key-packet predicate against a public-key payload.
--
-- The selected field of the payload is the left operand of the comparison,
-- and the predicate's 'PKPValue' is the right operand.
pkpEval :: PKPPredicate -> PKPayload -> Bool
pkpEval (PKPPredicate lhs o rhs) pkp = opreduce o (vreduce lhs pkp) rhs
    where
        -- Operator constructor to the corresponding comparison.
        opreduce PKEquals = (==)
        opreduce PKLessThan = (<)
        opreduce PKGreaterThan = (>)
        -- Extract the selected payload field as a comparable 'PKPValue'.
        vreduce PKPVVersion p = PKPInt (kv (_keyVersion p))
        vreduce PKPVPKA p = PKPPKA (_pkalgo p)
        vreduce PKPVKeysize p = PKPInt (keySize (_pubkey p))
        vreduce PKPVTimestamp p = PKPInt (fromIntegral (_timestamp p))
        -- Numeric key version.
        kv DeprecatedV3 = 3
        kv V4 = 4

-- | Evaluate one signature-packet predicate against a signature payload.
--
-- The selected field is the left operand and the predicate's 'SPValue' is
-- the right operand. The signature type, public-key algorithm and hash
-- algorithm are obtained from 'sigType', 'sigPKA' and 'sigHA'. When that
-- lookup yields 'Nothing', the predicate is 'False' whatever the operator.
spEval :: SPPredicate -> SignaturePayload -> Bool
spEval (SPPredicate lhs o rhs) sp = maybe False (\x -> opreduce o x rhs) (vreduce lhs sp)
    where
        -- Operator constructor to the corresponding comparison.
        opreduce SPEquals = (==)
        opreduce SPLessThan = (<)
        opreduce SPGreaterThan = (>)
        -- Extract the selected field, or 'Nothing' if it is unavailable.
        vreduce SPVVersion s = Just (SPInt (sigVersion s))
        vreduce SPVSigType s = fmap SPSigType (sigType s)
        vreduce SPVPKA s = fmap SPPKA (sigPKA s)
        vreduce SPVHA s = fmap SPHA (sigHA s)
        -- Numeric signature version; 'SigVOther' carries its version value.
        sigVersion (SigV3 {}) = 3
        sigVersion (SigV4 {}) = 4
        sigVersion (SigVOther v _) = fromIntegral v

-- | Evaluate one predicate against a packet that is neither a public-key
-- nor a signature packet.
--
-- The packet's tag (via 'pktTag') is the left operand and the predicate's
-- 'OValue' is the right operand.
oEval :: OPredicate -> Pkt -> Bool
oEval (OPredicate var op rhs) pkt = compareWith op (extract var pkt) rhs
  where
    compareWith OEquals = (==)
    compareWith OLessThan = (<)
    compareWith OGreaterThan = (>)
    extract OVTag p = OInt (fromIntegral (pktTag p))
