summaryrefslogtreecommitdiff
path: root/impl/EquationSystem.hpp
blob: 2fd24bda38a4bf32abdb4bb312021054a8e2181d (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#ifndef EQUATION_SYSTEM_HPP
#define EQUATION_SYSTEM_HPP

#include <cassert>
#include <map>
#include <ostream>
#include <set>
#include <string>
#include <vector>
#include "Operator.hpp"
#include "Expression.hpp"
#include "VariableAssignment.hpp"
#include "IdMap.hpp"

// Forward declaration only. "MaxStrategy.hpp" is included at the bottom of
// this header, after EquationSystem is complete -- presumably because the two
// headers refer to each other (TODO confirm against MaxStrategy.hpp).
template<typename Domain>
struct MaxStrategy;

template<typename Domain>
struct EquationSystem {
  EquationSystem()
    : _expr_to_var(NULL) { }

  virtual ~EquationSystem() {
    for (typename std::set<Expression<Domain>*>::iterator it = _expressions.begin();
         it != _expressions.end();
         ++it) {
      delete *it;
    }
    for (typename std::set<Operator<Domain>*>::iterator it = _operators.begin();
         it != _operators.end();
         ++it) {
      delete *it;
    }
    delete _expr_to_var;
  }

  MaxExpression<Domain>& maxExpression(const std::vector<Expression<Domain>*>& arguments) {
    unsigned int id = _max_expressions.size();
    Maximum<Domain>* max = new Maximum<Domain>();
    MaxExpression<Domain>* expr = new MaxExpression<Domain>(id, *max, arguments);
    _operators.insert(max);
    _max_expressions.push_back(expr);
    _expressions.insert(expr);
    return *expr;
  }
  MaxExpression<Domain>& maxExpression(unsigned int i) const {
    return *_max_expressions[i];
  }
  unsigned int maxExpressionCount() const {
    return _max_expressions.size();
  }

  Expression<Domain>& expression(Operator<Domain>* op, const std::vector<Expression<Domain>*>& arguments) {
    Expression<Domain>* expr = new OperatorExpression<Domain>(*op, arguments);
    _operators.insert(op);
    _expressions.insert(expr);
    return *expr;
  }

  Variable<Domain>& variable(const std::string& name) {
    if (_variable_names.find(name) == _variable_names.end()) {
      // not found - create a new variable and whatnot
      unsigned int id = _variables.size();
      Variable<Domain>* var = new Variable<Domain>(id, name);
      _variables.push_back(var);
      _right_sides.push_back(NULL);
      _expressions.insert(var);
      _variable_names[name] = var;
      return *var;
    } else {
      return *_variable_names[name];
    }
  }
  Variable<Domain>& variable(unsigned int id) const {
    return *_variables[id];
  }
  unsigned int variableCount() const {
    return _variables.size();
  }

  Constant<Domain>& constant(const Domain& value) {
    Constant<Domain>* constant = new Constant<Domain>(value);
    _expressions.insert(constant);
    return *constant;
  }

  MaxExpression<Domain>* operator[](const Variable<Domain>& var) const {
    return _right_sides[var.id()];
  }
  MaxExpression<Domain>*& operator[](const Variable<Domain>& var) {
    return _right_sides[var.id()];
  }

  void indexMaxExpressions() {
    _expr_to_var = new IdMap<MaxExpression<Domain>,Variable<Domain>*>(maxExpressionCount(), NULL); 
    for (unsigned int i = 0, length = _right_sides.size(); i < length; ++i) {
      _right_sides[i]->mapTo(*_variables[i], *_expr_to_var);
    }
  }

  Variable<Domain>* varFromExpr(const Expression<Domain>& expr) const {
    assert(_expr_to_var); // make sure we've indexed
    const MaxExpression<Domain>* maxExpr = expr.toMaxExpression();//dynamic_cast<const MaxExpression<Domain>*>(&expr);
    if (maxExpr) {
      return (*_expr_to_var)[*maxExpr];
    } else {
      return NULL;
    }
  }

  virtual bool equalAssignments(const VariableAssignment<Domain>& l, const VariableAssignment<Domain>& r) const {
    for (unsigned int i = 0, length = _variables.size();
         i < length;
         ++i) {
      const Variable<Domain>& var = *_variables[i];
      if (l[var] != r[var])
        return false;
    }
    return true;
  }

  void print(std::ostream& cout) const {
    for (unsigned int i = 0, length = _variables.size();
         i < length;
         ++i) {
      cout << *_variables[i] << " = " << *_right_sides[i] << std::endl;
    }
  }

  private:
  std::set<Operator<Domain>*> _operators;
  std::set<Expression<Domain>*> _expressions;
  std::vector<Variable<Domain>*> _variables;
  std::map<std::string, Variable<Domain>*> _variable_names;
  IdMap<MaxExpression<Domain>, Variable<Domain>*>* _expr_to_var;
  std::vector<MaxExpression<Domain>*> _max_expressions;
  std::vector<MaxExpression<Domain>*> _right_sides;
};

template<typename T>
std::ostream& operator<<(std::ostream& cout, const EquationSystem<T>& system) {
  system.print(cout);
  return cout;
}

#include "MaxStrategy.hpp"

#endif