summaryrefslogtreecommitdiff
path: root/impl/EquationSystem.hpp
blob: c7e5f4ccfd8374608d492cb5eafaa9080523a561 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
#ifndef EQUATION_SYSTEM_HPP
#define EQUATION_SYSTEM_HPP

#include <map>
#include <ostream>
#include <set>
#include <string>
#include <vector>
#include "Operator.hpp"
#include "Expression.hpp"
#include "VariableAssignment.hpp"
#include "IdMap.hpp"

// Forward declaration only: EquationSystem::eval() below takes a
// MaxStrategy<Domain> by const reference, and the full definition is pulled
// in by the #include of "MaxStrategy.hpp" at the bottom of this header.
// NOTE(review): presumably done this way because MaxStrategy.hpp itself needs
// EquationSystem - confirm before reordering the includes.
template<typename Domain>
struct MaxStrategy;

template<typename Domain>
struct EquationSystem {
  EquationSystem()
    : _expr_to_var(NULL) { }

  virtual ~EquationSystem() {
    for (typename std::set<Expression<Domain>*>::iterator it = _expressions.begin();
         it != _expressions.end();
         ++it) {
      delete *it;
    }
    for (typename std::set<Operator<Domain>*>::iterator it = _operators.begin();
         it != _operators.end();
         ++it) {
      delete *it;
    }
    if (_expr_to_var)
      delete _expr_to_var;
  }

  MaxExpression<Domain>& maxExpression(const std::vector<Expression<Domain>*>& arguments) {
    unsigned int id = _max_expressions.size();
    Maximum<Domain>* max = new Maximum<Domain>();
    MaxExpression<Domain>* expr = new MaxExpression<Domain>(id, *max, arguments);
    _operators.insert(max);
    _max_expressions.push_back(expr);
    _expressions.insert(expr);
    return *expr;
  }
  MaxExpression<Domain>& maxExpression(unsigned int i) const {
    return *_max_expressions[i];
  }
  unsigned int maxExpressionCount() const {
    return _max_expressions.size();
  }

  Expression<Domain>& expression(Operator<Domain>* op, const std::vector<Expression<Domain>*>& arguments) {
    Expression<Domain>* expr = new OperatorExpression<Domain>(*op, arguments);
    _operators.insert(op);
    _expressions.insert(expr);
    return *expr;
  }

  Variable<Domain>& variable(const std::string& name) {
    if (_variable_names.find(name) == _variable_names.end()) {
      // not found - create a new variable and whatnot
      unsigned int id = _variables.size();
      Variable<Domain>* var = new Variable<Domain>(id, name);
      _variables.push_back(var);
      _right_sides.push_back(NULL);
      _expressions.insert(var);
      _variable_names[name] = var;
      return *var;
    } else {
      return *_variable_names[name];
    }
  }
  Variable<Domain>& variable(unsigned int id) const {
    return *_variables[id];
  }
  unsigned int variableCount() const {
    return _variables.size();
  }

  Constant<Domain>& constant(const Domain& value) {
    Constant<Domain>* constant = new Constant<Domain>(value);
    _expressions.insert(constant);
    return *constant;
  }

  MaxExpression<Domain>* operator[](const Variable<Domain>& var) const {
    return _right_sides[var.id()];
  }
  MaxExpression<Domain>*& operator[](const Variable<Domain>& var) {
    return _right_sides[var.id()];
  }

  StableVariableAssignment<Domain>* assignment(const Domain& value) const {
    return new StableVariableAssignment<Domain>(_variables.size(), value);
  }
  StableVariableAssignment<Domain>* eval(const VariableAssignment<Domain>& rho) const {
    StableVariableAssignment<Domain>* result = this->assignment(-infinity<Domain>());
    for (unsigned int i = 0, length = _variables.size();
         i < length;
         ++i) {
      const Variable<Domain>& var = *_variables[i];
      const Expression<Domain>& expr = * (*this)[var];
      (*result)[var] = expr.eval(rho);
    }
    return result;
  }
  StableVariableAssignment<Domain>* eval(const VariableAssignment<Domain>& rho, const MaxStrategy<Domain>& strat) const {
    StableVariableAssignment<Domain>* result = this->assignment(-infinity<Domain>());
    for (unsigned int i = 0, length = _variables.size();
         i < length;
         ++i) {
      const Variable<Domain>& var = *_variables[i];
      const Expression<Domain>& expr = * (*this)[var];
      (*result)[var] = expr.eval(rho, strat);
    }
    return result;
  }

  void indexMaxExpressions() {
    _expr_to_var = new IdMap<MaxExpression<Domain>,Variable<Domain>*>(maxExpressionCount(), NULL); 
    for (unsigned int i = 0, length = _right_sides.size(); i < length; ++i) {
      (*_expr_to_var)[*_right_sides[i]] = _variables[i];
    }
  }

  Variable<Domain>* varFromExpr(const Expression<Domain>& expr) const {
    if (_expr_to_var) { // we've indexed:
      auto* maxExpr = dynamic_cast<const MaxExpression<Domain>*>(&expr);
      if (maxExpr) {
        return (*_expr_to_var)[*maxExpr];
      } else {
        return NULL;
      }
    } else {
      for (unsigned int i = 0, length = _right_sides.size(); i < length; ++i) {
        if (_right_sides[i] == &expr)
          return _variables[i];
      }
      return NULL;
    }
  }

  virtual bool equalAssignments(const VariableAssignment<Domain>& l, const VariableAssignment<Domain>& r) const {
    for (unsigned int i = 0, length = _variables.size();
         i < length;
         ++i) {
      const Variable<Domain>& var = *_variables[i];
      if (l[var] != r[var])
        return false;
    }
    return true;
  }

  void print(std::ostream& cout) const {
    for (unsigned int i = 0, length = _variables.size();
         i < length;
         ++i) {
      cout << *_variables[i] << " = " << *_right_sides[i] << std::endl;
    }
  }

  private:
  std::set<Operator<Domain>*> _operators;
  std::set<Expression<Domain>*> _expressions;
  std::vector<Variable<Domain>*> _variables;
  std::map<std::string, Variable<Domain>*> _variable_names;
  IdMap<MaxExpression<Domain>, Variable<Domain>*>* _expr_to_var;
  std::vector<MaxExpression<Domain>*> _max_expressions;
  std::vector<MaxExpression<Domain>*> _right_sides;
};

template<typename T>
std::ostream& operator<<(std::ostream& cout, const EquationSystem<T>& system) {
  system.print(cout);
  return cout;
}

#include "MaxStrategy.hpp"

#endif