summaryrefslogtreecommitdiff
path: root/impl/EquationSystem.hpp
blob: 55496741152531772efbf63cfc7b99f70d15c301 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#ifndef EQUATION_SYSTEM_HPP
#define EQUATION_SYSTEM_HPP

#include <vector>
#include <set>
#include <map>
#include "Operator.hpp"
#include "Expression.hpp"
#include "VariableAssignment.hpp"

/**
 * Abstract interface for a system of equations "x = e_x" whose values are
 * drawn from Domain.
 *
 * Implementations associate each Variable with a right-hand-side Expression,
 * can evaluate all right-hand sides under a variable assignment, and can
 * print themselves.
 */
template<typename Domain>
struct EquationSystem {
  virtual ~EquationSystem() {}

  // Right-hand side associated with `var`.
  virtual const Expression<Domain>* operator[](const Variable<Domain>& var) const = 0;
  // Evaluates the right-hand sides under `rho`; the returned object is
  // presumably heap-allocated and owned by the caller (confirm per implementation).
  virtual VariableAssignment<Domain>* eval(const VariableAssignment<Domain>& rho) const = 0;

  // Number of variables, and access to the variable with index `index`.
  virtual unsigned int variableCount() const = 0;
  virtual Variable<Domain>& variable(unsigned int index) const = 0;

  // Creates an assignment giving every variable the value `value`.
  virtual StableVariableAssignment<Domain>* assignment(const Domain& value) const = 0;
  // Reverse lookup: the variable whose right-hand side is `expr`, if any.
  virtual Variable<Domain>* varFromExpr(const Expression<Domain>& expr) const = 0;
  // True when `lhs` and `rhs` agree on the variables of this system.
  virtual bool equalAssignments(const VariableAssignment<Domain>& lhs, const VariableAssignment<Domain>& rhs) const = 0;
  // Writes a textual rendering of the system to `out`.
  virtual void print(std::ostream& out) const = 0;
};

/**
 * Owning implementation of EquationSystem.
 *
 * Ownership: every Expression created through maxExpression(), expression(),
 * variable() and constant(), and every Operator registered through
 * maxExpression() or expression(), is deleted in the destructor. References
 * and pointers handed out by these factories therefore remain valid for as
 * long as the system itself lives.
 *
 * Because the system holds these objects through raw owning pointers, it is
 * non-copyable. An implicit member-wise copy would share the pointers, and
 * both copies would delete every object.
 *
 * Each variable has exactly one right-hand-side slot, accessed through
 * operator[]. A newly created variable's slot is NULL until the caller stores
 * a MaxExpression in it. eval() and print() skip variables whose slot is still
 * NULL instead of dereferencing it.
 */
template<typename Domain>
struct ConcreteEquationSystem : public EquationSystem<Domain> {
  // Needed explicitly: declaring the (private) copy constructor below
  // suppresses the implicitly-declared default constructor.
  ConcreteEquationSystem() { }

  virtual ~ConcreteEquationSystem() {
    // Variables and max-expressions are also stored in _expressions, and a
    // std::set holds each pointer once, so every expression is freed exactly once.
    for (typename std::set<Expression<Domain>*>::iterator it = _expressions.begin();
         it != _expressions.end();
         ++it) {
      delete *it;
    }
    for (typename std::set<Operator<Domain>*>::iterator it = _operators.begin();
         it != _operators.end();
         ++it) {
      delete *it;
    }
  }

  /**
   * Creates a MaxExpression over `arguments`, backed by a fresh Maximum
   * operator. Ids are assigned sequentially (0, 1, 2, ...), and the
   * expression can later be retrieved with maxExpression(id).
   */
  MaxExpression<Domain>& maxExpression(const std::vector<Expression<Domain>*>& arguments) {
    unsigned int id = _max_expressions.size();
    Maximum<Domain>* max = new Maximum<Domain>();
    // Register each object for ownership right after allocating it. If a later
    // step throws (e.g. std::bad_alloc), the destructor still frees it.
    _operators.insert(max);
    MaxExpression<Domain>* expr = new MaxExpression<Domain>(id, *max, arguments);
    _expressions.insert(expr);
    _max_expressions.push_back(expr);
    return *expr;
  }
  /// The max-expression with id `i` (no bounds check).
  MaxExpression<Domain>& maxExpression(unsigned int i) const {
    return *_max_expressions[i];
  }
  /// Number of max-expressions created so far.
  unsigned int maxExpressionCount() const {
    return _max_expressions.size();
  }

  /**
   * Creates an expression applying `op` to `arguments`.
   *
   * The system takes ownership of `op`. It is deleted in the destructor, so it
   * must be heap-allocated and must not be deleted by the caller. The same
   * operator may be passed to several calls; it is stored once.
   */
  Expression<Domain>& expression(Operator<Domain>* op, const std::vector<Expression<Domain>*>& arguments) {
    _operators.insert(op); // take ownership before the allocation below can throw
    Expression<Domain>* expr = new OperatorExpression<Domain>(*op, arguments);
    _expressions.insert(expr);
    return *expr;
  }

  /**
   * Returns the variable called `name`, creating it on first use.
   *
   * New variables get sequential ids (0, 1, 2, ...) and an empty (NULL)
   * right-hand-side slot.
   */
  Variable<Domain>& variable(const std::string& name) {
    // One lookup; reuse the iterator instead of find() followed by operator[].
    typename std::map<std::string, Variable<Domain>*>::iterator found = _variable_names.find(name);
    if (found != _variable_names.end())
      return *found->second;

    unsigned int id = _variables.size();
    Variable<Domain>* var = new Variable<Domain>(id, name);
    _expressions.insert(var); // owned (and freed by the destructor) from here on
    _variables.push_back(var);
    _right_sides.push_back(NULL);
    _variable_names[name] = var;
    return *var;
  }
  /// The variable with id `id` (no bounds check).
  Variable<Domain>& variable(unsigned int id) const {
    return *_variables[id];
  }
  /// Number of variables created so far.
  unsigned int variableCount() const {
    return _variables.size();
  }

  /// Creates a new constant expression. Every call allocates a new Constant;
  /// equal values are not shared.
  Constant<Domain>& constant(const Domain& value) {
    Constant<Domain>* constant = new Constant<Domain>(value);
    _expressions.insert(constant);
    return *constant;
  }

  /// Right-hand side of `var` (indexed by var.id()); NULL if none assigned yet.
  MaxExpression<Domain>* operator[](const Variable<Domain>& var) const {
    return _right_sides[var.id()];
  }
  /// Mutable right-hand-side slot of `var`. Assign a MaxExpression here to
  /// define the equation for `var`.
  MaxExpression<Domain>*& operator[](const Variable<Domain>& var) {
    return _right_sides[var.id()];
  }

  /// Allocates an assignment with one entry per variable, each set to `value`.
  StableVariableAssignment<Domain>* assignment(const Domain& value) const {
    return new StableVariableAssignment<Domain>(_variables.size(), value);
  }

  /**
   * Evaluates every defined right-hand side under `rho` and returns a newly
   * allocated assignment with the results.
   *
   * A variable whose right-hand side is still NULL keeps the initial value
   * -infinity<Domain>() instead of triggering a null dereference.
   */
  VariableAssignment<Domain>* eval(const VariableAssignment<Domain>& rho) const {
    StableVariableAssignment<Domain>* result = this->assignment(-infinity<Domain>());
    for (unsigned int i = 0, length = _variables.size();
         i < length;
         ++i) {
      const Variable<Domain>& var = *_variables[i];
      const Expression<Domain>* expr = (*this)[var];
      if (expr == NULL)
        continue; // no equation for this variable yet
      (*result)[var] = expr->eval(rho);
    }
    return result;
  }

  /// Reverse lookup (linear scan): the variable whose right-hand side is
  /// `expr`, or NULL if `expr` is not any variable's right-hand side.
  Variable<Domain>* varFromExpr(const Expression<Domain>& expr) const {
    for (unsigned int i = 0, length = _right_sides.size(); i < length; ++i) {
      if (_right_sides[i] == &expr)
        return _variables[i];
    }
    return NULL;
  }

  /// True iff `l` and `r` hold equal values (by operator!=) for every
  /// variable of this system.
  virtual bool equalAssignments(const VariableAssignment<Domain>& l, const VariableAssignment<Domain>& r) const {
    for (unsigned int i = 0, length = _variables.size();
         i < length;
         ++i) {
      const Variable<Domain>& var = *_variables[i];
      if (l[var] != r[var])
        return false;
    }
    return true;
  }

  /// Writes one "var = rhs" line per variable. Variables without a
  /// right-hand side are skipped rather than dereferencing NULL.
  void print(std::ostream& out) const {
    for (unsigned int i = 0, length = _variables.size();
         i < length;
         ++i) {
      if (_right_sides[i] == NULL)
        continue;
      out << *_variables[i] << " = " << *_right_sides[i] << std::endl;
    }
  }

  private:
  // Non-copyable: declared but intentionally never defined (pre-C++11 idiom),
  // because the members below are raw owning pointers.
  ConcreteEquationSystem(const ConcreteEquationSystem&);
  ConcreteEquationSystem& operator=(const ConcreteEquationSystem&);

  std::set<Operator<Domain>*> _operators;           // owned
  std::set<Expression<Domain>*> _expressions;       // owned (includes variables and max-expressions)
  std::vector<Variable<Domain>*> _variables;        // indexed by variable id; not separately owned
  std::map<std::string, Variable<Domain>*> _variable_names;
  std::vector<MaxExpression<Domain>*> _max_expressions; // indexed by max-expression id
  std::vector<MaxExpression<Domain>*> _right_sides;     // parallel to _variables; NULL = undefined
};

/// Streams a textual rendering of `system` by delegating to its print() method.
template<typename T>
std::ostream& operator<<(std::ostream& out, const EquationSystem<T>& system) {
  std::ostream& stream = out;
  system.print(stream);
  return stream;
}

#endif