module Hermod.ReCon.Presburger.Internal.CooperQE (eliminate, eliminateExists) where

import           Hermod.ReCon.Common.Types (NatValue)
import           Hermod.ReCon.Integer.Polynomial.Term
import qualified Hermod.ReCon.Presburger.Internal.IR.AffineDNF as AffineDNF
import qualified Hermod.ReCon.Presburger.Internal.IR.CompNF as CompNF
import           Hermod.ReCon.Presburger.Internal.IR.DNF (DNF, DNFConjunct (..), quoteDNF)
import qualified Hermod.ReCon.Presburger.Internal.IR.DNF as DNF
import qualified Hermod.ReCon.Presburger.Internal.IR.ForallFree as F
import qualified Hermod.ReCon.Presburger.Internal.IR.NegNF as NegNF
import qualified Hermod.ReCon.Presburger.Internal.IR.NormAffineDNF as NormAffineDNF
import           Hermod.ReCon.Presburger.Internal.IR.QuantifierFree (QuantifierFree)
import qualified Hermod.ReCon.Presburger.Internal.IR.QuantifierFree as Q

import           Prelude hiding (Foldable (..))

import           Data.List (foldl', foldr, null)

-- | Remove every existential quantifier from a 'F.ForallFree' formula,
-- yielding an equivalent 'QuantifierFree' formula.
--
-- Connectives and atoms are copied across constructor-for-constructor.
-- For @F.IntExists x body@ the body is processed first (so the innermost
-- quantifiers disappear before the outer ones), and the now
-- quantifier-free body is handed to 'eliminateLeaf' to remove @x@.
eliminate :: F.ForallFree -> QuantifierFree
eliminate formula = case formula of
  F.Top              -> Q.Top
  F.Bottom           -> Q.Bottom
  F.Not a            -> Q.Not (eliminate a)
  F.And a b          -> Q.And (eliminate a) (eliminate b)
  F.Or a b           -> Q.Or (eliminate a) (eliminate b)
  F.Implies a b      -> Q.Implies (eliminate a) (eliminate b)
  F.IntBinRel r i j  -> Q.IntBinRel r i j
  F.IntDiv k t       -> Q.IntDiv k t
  F.IntExists x body -> eliminateLeaf x (eliminate body)

-- | Eliminate a single existential: given @x@ and a quantifier-free @qf@,
-- return a quantifier-free formula equivalent to @∃x. qf@.
--
-- The formula is pushed through the IR stages in order:
-- 'NegNF.fromQuantifierFree', 'CompNF.fromNegNF', 'DNF.fromCompNF',
-- 'AffineDNF.fromDNF' (with respect to @x@), 'NormAffineDNF.fromAffineDNF',
-- then 'eliminateExists' removes @x@ and 'quoteDNF' converts the resulting
-- 'DNF' back to a 'QuantifierFree' formula.
eliminateLeaf :: VariableIdentifier -> QuantifierFree -> QuantifierFree
eliminateLeaf x qf =
  let negNF      = NegNF.fromQuantifierFree qf
      compNF     = CompNF.fromNegNF negNF
      plainDNF   = DNF.fromCompNF compNF
      affineDNF  = AffineDNF.fromDNF x plainDNF
      normalised = NormAffineDNF.fromAffineDNF affineDNF
  in quoteDNF (eliminateExists normalised)

-- | Eliminate @∃x@ from a formula already in 'NormAffineDNF.NormAffineDNF'
-- form with respect to @x@, returning a quantifier-free 'DNF' over the
-- remaining variables.
--
-- Existential quantification distributes over disjunction, so every
-- disjunct is expanded independently by 'elimDisjunct' and the resulting
-- disjunct lists are concatenated.
--
-- For one conjunction @C(x)@, Cooper's theorem (unified form) states
--
-- > ∃x. C(x)  ⟺  ∨_{j=1}^{δ} C₋∞(j)  ∨  ∨_{l ∈ LB} ∨_{j=1}^{δ} C(l + j)
--
-- where @LB@ is the set of lower bounds @l@ from atoms @l < x@, @δ@ is the
-- lcm of the moduli @h@ of all divisibility atoms @h | x + e@ and their
-- negations @¬(h | x + e)@, and @C₋∞@ is @C@ with every upper-bound atom
-- @x < u@ replaced by True and every lower-bound atom @l < x@ replaced by
-- False.
eliminateExists :: NormAffineDNF.NormAffineDNF -> DNF
eliminateExists normalised =
  foldr (\disjunct rest -> elimDisjunct disjunct ++ rest) [] normalised

-- ── per-disjunct ─────────────────────────────────────────────────────────────

-- | The atoms of one conjunction, bucketed by how they involve the
-- variable @x@ being eliminated.
data Partitioned = Partitioned
  { lbs     :: [IntTerm]
    -- ^ lower bounds: each @l@ stands for the atom @l < x@
  , ubs     :: [IntTerm]
    -- ^ upper bounds: each @u@ stands for the atom @x < u@
  , divs    :: [(NatValue, IntTerm)]
    -- ^ each @(h, e)@ stands for the atom @h | x + e@
  , negDivs :: [(NatValue, IntTerm)]
    -- ^ each @(h, e)@ stands for the atom @¬(h | x + e)@
  , frees   :: [DNFConjunct]
    -- ^ atoms that do not mention @x@, kept as-is
  }

-- | A 'Partitioned' with every bucket empty; the seed value for folding
-- 'classify' over a conjunction.
emptyPartitioned :: Partitioned
emptyPartitioned =
  Partitioned
    { lbs     = []
    , ubs     = []
    , divs    = []
    , negDivs = []
    , frees   = []
    }

-- | Prepend one normalised atom to the 'Partitioned' bucket matching its
-- shape:
--
-- * @LtR l@ (@l < x@) goes to 'lbs', @LtL u@ (@x < u@) goes to 'ubs';
-- * @DivL h e@ goes to 'divs' (positive) or 'negDivs' (negated);
-- * the x-free forms @LtC@ / @DivC@ are re-expressed as 'DNFConjunct's
--   and go to 'frees'.
classify :: NormAffineDNF.NormAffineDNFConjunct -> Partitioned -> Partitioned
classify atom p = case atom of
  NormAffineDNF.IntLt rel -> case rel of
    NormAffineDNF.LtR l   -> p { lbs = l : lbs p }
    NormAffineDNF.LtL u   -> p { ubs = u : ubs p }
    NormAffineDNF.LtC i j -> keep (DNF.IntLt i j)
  NormAffineDNF.IntDiv d -> case d of
    NormAffineDNF.DivL h e -> p { divs = (h, e) : divs p }
    NormAffineDNF.DivC h i -> keep (DNF.IntDiv h i)
  NormAffineDNF.IntNegDiv d -> case d of
    NormAffineDNF.DivL h e -> p { negDivs = (h, e) : negDivs p }
    NormAffineDNF.DivC h i -> keep (DNF.IntNegDiv h i)
  where
    -- Store an atom that does not involve x unchanged.
    keep c = p { frees = c : frees p }

-- | Apply Cooper's test-point expansion to a single conjunction, returning
-- the disjuncts of an equivalent x-free 'DNF'.
--
-- The atoms are first bucketed with 'classify'. @delta@ is the lcm of the
-- moduli of every divisibility atom that mentions @x@, positive or
-- negated. The result is:
--
-- * the @-∞@ branch: @x = j@ for @j ∈ [1..delta]@ with upper bounds
--   dropped, emitted only when the conjunction has no lower bound;
-- * the lower-bound branch: @x = l + j@ for every lower bound @l@ and
--   @j ∈ [1..delta]@.
elimDisjunct :: NormAffineDNF.NormAffineDNFDisjunct -> DNF
elimDisjunct conjuncts = infBranch ++ lbBranch
  where
    p :: Partitioned
    p = foldr classify emptyPartitioned conjuncts

    -- Both h | x + e and ¬(h | x + e) are periodic in x with period h, so
    -- the moduli of the negated atoms must contribute to delta as well.
    -- Leaving them out misses witnesses: for ¬(2 | x + 1) alone delta
    -- would be 1, only x = 1 would be tested (¬(2 | 2) is False), and the
    -- witness x = 0 would be lost.
    -- NOTE(review): a modulus of 0 makes delta 0 (lcm _ 0 == 0) and both
    -- test ranges empty; presumably NormAffineDNF never emits one — confirm.
    delta :: NatValue
    delta = foldl' (\acc (d, _) -> lcm acc d) 1 (divs p ++ negDivs p)

    offsets :: [NatValue]
    offsets = [1 .. delta]

    -- At x → -∞ every lower-bound atom l < x is False, so a conjunction
    -- that has any lower bound contributes nothing to this branch.
    infBranch :: DNF
    infBranch
      | null (lbs p) =
          [ substituteAt True p (IntConst (fromIntegral j)) | j <- offsets ]
      | otherwise = []

    -- Test points just above each lower bound: x = lb + j.
    lbBranch :: DNF
    lbBranch =
      [ substituteAt False p (IntSum lb (IntConst (fromIntegral j)))
      | lb <- lbs p
      , j  <- offsets
      ]

-- | Build one disjunct (a list of 'DNFConjunct's) by substituting the term
-- @t@ for @x@ in every bucket of a 'Partitioned'.
--
-- The x-free atoms are kept unchanged. When @isInf@ is True the upper-bound
-- atoms @x < u@ are omitted, because they hold at @x → -∞@. 'elimDisjunct'
-- only passes True when there are no lower bounds.
substituteAt :: Bool -> Partitioned -> IntTerm -> [DNFConjunct]
substituteAt isInf p t =
  concat
    [ frees p
    , upperAtoms
    , [ DNF.IntLt l t | l <- lbs p ]
    , [ DNF.IntDiv h (IntSum t e) | (h, e) <- divs p ]
    , [ DNF.IntNegDiv h (IntSum t e) | (h, e) <- negDivs p ]
    ]
  where
    upperAtoms
      | isInf     = []
      | otherwise = [ DNF.IntLt t u | u <- ubs p ]