summaryrefslogtreecommitdiff
path: root/impl/MaxStrategy.hpp
blob: 06f61b7b1764468430ea4dfca4a5e70ad103fc46 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#ifndef MAX_EXPRESSION_HPP
#define MAX_EXPRESSION_HPP

#include "Expression.hpp"
#include "EquationSystem.hpp"

template<typename Domain>
struct MaxStrategyExpression : public Expression<Domain> {
  /// Wraps `expr` so that every max-expression reachable from it is evaluated
  /// using the argument selected by `strategy`, instead of taking a maximum.
  ///
  /// Both `expr` and `strategy` are stored by reference: they must outlive
  /// this object. Because the strategy is read by reference, later changes
  /// to it are observed by subsequent calls to eval().
  MaxStrategyExpression(const Expression<Domain>& expr, const IdMap<MaxExpression<Domain>,unsigned int>& strategy)
    : _expr(expr), _strategy(strategy) { }

  /// Evaluates the wrapped expression under `rho` with max-expressions
  /// resolved by the strategy:
  ///  - a MaxExpression evaluates only its argument at index `_strategy[max]`;
  ///  - any other OperatorExpression evaluates each argument recursively
  ///    (under the same strategy) and applies its operator to the results;
  ///  - any non-operator expression is evaluated directly via its own eval().
  virtual Domain eval(const VariableAssignment<Domain>& rho) const {
    // Relies on the concrete Expression subclasses (discovered through
    // dynamic_cast) rather than on a virtual hook in Expression.
    const OperatorExpression<Domain>* opExpr = dynamic_cast<const OperatorExpression<Domain>*>(&_expr);
    if (!opExpr) {
      return _expr.eval(rho);
    }

    // Bind by const reference: avoids copying the argument vector at every
    // node of every (recursive) evaluation. This is valid whether
    // arguments() returns a reference or a temporary (lifetime extension).
    const std::vector<Expression<Domain>*>& args = opExpr->arguments();

    const MaxExpression<Domain>* maxExpr = dynamic_cast<const MaxExpression<Domain>*>(opExpr);
    if (maxExpr) {
      // NOTE(review): the strategy index is not bounds-checked against
      // args.size(); this assumes the strategy only ever holds valid indices.
      unsigned int i = _strategy[*maxExpr];
      return MaxStrategyExpression(*args[i], _strategy).eval(rho);
    }

    std::vector<Domain> argumentValues;
    argumentValues.reserve(args.size());
    for (typename std::vector<Expression<Domain>*>::const_iterator it = args.begin();
         it != args.end();
         ++it) {
      argumentValues.push_back(MaxStrategyExpression(**it, _strategy).eval(rho));
    }
    return opExpr->op().eval(argumentValues);
  }

  /// Prints the wrapped expression (the strategy is not shown).
  void print(std::ostream& out) const {
    out << _expr;
  }
  private:
  const Expression<Domain>& _expr;
  const IdMap<MaxExpression<Domain>,unsigned int>& _strategy;
};

/// An EquationSystem view of a ConcreteEquationSystem in which every
/// max-expression is resolved by a strategy: for each max-expression the
/// strategy stores the index of the argument to use instead of the maximum.
/// Every strategy entry starts at 0.
///
/// The underlying system is held by reference and must outlive this object.
template<typename Domain>
struct MaxStrategy : public EquationSystem<Domain> {
  MaxStrategy(const ConcreteEquationSystem<Domain>& system)
    : _system(system), _expressions(system.variableCount(), NULL), _strategy(system.maxExpressionCount(), 0) {
    // At most one wrapper is created per variable (assuming every Variable
    // passed to operator[] belongs to _system). Reserving up front means the
    // push_back in operator[] never reallocates, so it cannot throw after the
    // wrapper has already been allocated.
    _owned.reserve(system.variableCount());
  }

  /// Deletes the wrapper expressions created lazily by operator[]. This does
  /// not touch _system, so it stays valid even if the system is gone by then.
  virtual ~MaxStrategy() {
    for (typename std::vector<MaxStrategyExpression<Domain>*>::const_iterator it = _owned.begin();
         it != _owned.end();
         ++it) {
      delete *it;
    }
  }

  /// Returns the right-hand side of `v`, wrapped so that max-expressions are
  /// resolved by this strategy. The wrapper is created on first access and
  /// cached. It reads the strategy by reference, so later set()/improve()
  /// calls are reflected without rebuilding it. This object keeps ownership.
  const Expression<Domain>* operator[](const Variable<Domain>& v) const {
    if (_expressions[v] == NULL) {
      MaxStrategyExpression<Domain>* expression = new MaxStrategyExpression<Domain>(*_system[v], _strategy);
      _owned.push_back(expression);
      _expressions[v] = expression;
    }
    return _expressions[v];
  }

  /// Returns the argument index currently chosen for max-expression `e`.
  unsigned int get(const MaxExpression<Domain>& e) const {
    return _strategy[e];
  }
  /// Chooses argument index `i` for max-expression `e`; returns `i`.
  unsigned int set(const MaxExpression<Domain>& e, unsigned int i) {
    _strategy[e] = i;
    return i;
  }

  /// Evaluates every variable's right-hand side under `rho` with the current
  /// strategy. Returns the assignment produced by assignment(); presumably
  /// the caller takes ownership (TODO confirm against
  /// ConcreteEquationSystem::assignment).
  VariableAssignment<Domain>* eval(const VariableAssignment<Domain>& rho) const {
    StableVariableAssignment<Domain>* result = this->assignment(-infinity<Domain>());
    for (unsigned int i = 0, length = _system.variableCount();
         i < length;
         ++i) {
      const Variable<Domain>& var = _system.variable(i);
      const Expression<Domain>& expr = * (*this)[var];
      (*result)[var] = expr.eval(rho);
    }
    return result;
  }

  unsigned int variableCount() const {
    return _system.variableCount();
  }

  Variable<Domain>& variable(unsigned int i) const {
    return _system.variable(i);
  }

  StableVariableAssignment<Domain>* assignment(const Domain& v) const {
    return _system.assignment(v);
  }

  bool equalAssignments(const VariableAssignment<Domain>& l, const VariableAssignment<Domain>& r) const {
    return _system.equalAssignments(l, r);
  }

  /// One strategy-improvement step against `rho`. For each max-expression,
  /// this starts from the value of the currently chosen argument. It then
  /// switches to any argument whose value under `rho` is strictly greater
  /// than the best seen so far, so ties keep the earlier choice.
  void improve(const VariableAssignment<Domain>& rho) {
    for (unsigned int i = 0, length = _system.maxExpressionCount();
         i < length;
         ++i) {
      MaxExpression<Domain>& expr = _system.maxExpression(i);
      Domain bestValue = MaxStrategyExpression<Domain>(expr, _strategy).eval(rho);
      unsigned int bestIndex = this->get(expr);

      // This relies on the fact that an expression will only be processed after the expressions
      // it depends on (which should always be true, as they form a DAG)
      const std::vector<Expression<Domain>*>& args = expr.arguments();
      for (unsigned int j = 0, argCount = args.size();
           j < argCount;
           ++j) {
        const Domain value = MaxStrategyExpression<Domain>(*args[j], _strategy).eval(rho);
        if (bestValue < value) {
          bestValue = value;
          bestIndex = j;
        }
      }
      this->set(expr, bestIndex);
    }
  }

  void print(std::ostream& cout) const {
    cout << _system << std::endl;
  }

  private:
  // Copying would make two objects delete the same cached wrappers.
  // Declared but intentionally not defined (pre-C++11 non-copyable idiom).
  MaxStrategy(const MaxStrategy&);
  MaxStrategy& operator=(const MaxStrategy&);

  const ConcreteEquationSystem<Domain>& _system;
  // Non-owning cache: variable -> its wrapper (NULL until first requested).
  mutable IdMap<Variable<Domain>,Expression<Domain>*> _expressions;
  IdMap<MaxExpression<Domain>,unsigned int> _strategy;
  // Owning list of every wrapper allocated by operator[]; freed in the destructor.
  mutable std::vector<MaxStrategyExpression<Domain>*> _owned;
};

#endif