summaryrefslogtreecommitdiff
path: root/impl/MaxStrategy.hpp
blob: 2be9f4c4ec4d1e94f8864e6a93007b105c4c1b5a (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#ifndef MAX_EXPRESSION_HPP
#define MAX_EXPRESSION_HPP

#include "Expression.hpp"
#include "EquationSystem.hpp"

template<typename Domain>
/**
 * View of an expression under a fixed max-strategy.
 *
 * Evaluation walks the wrapped expression: a MaxExpression is replaced by the
 * single argument whose index `strategy` stores for it; any other
 * OperatorExpression has its arguments evaluated the same way before its
 * operator is applied; anything else is evaluated directly.
 *
 * Both the expression and the strategy map are held by reference (not copied),
 * so they must outlive this object, and later changes to the strategy map are
 * seen by subsequent eval() calls.
 */
struct MaxStrategyExpression : public Expression<Domain> {
  MaxStrategyExpression(const Expression<Domain>& expr, const IdMap<MaxExpression<Domain>,unsigned int>& strategy)
    : _expr(expr), _strategy(strategy) { }

  /**
   * Evaluates the wrapped expression in `rho`, resolving every max-expression
   * to the argument selected by the strategy.
   */
  virtual Domain eval(const VariableAssignment<Domain>& rho) const {
    // relies on implementation details (the concrete expression classes) - BAD BAD BAD, maybe
    const OperatorExpression<Domain>* opExpr = dynamic_cast<const OperatorExpression<Domain>*>(&_expr);
    if (!opExpr) {
      // not an operator expression: evaluate it directly
      return _expr.eval(rho);
    }

    // Bind by const reference: avoids copying the argument vector at every
    // node of this recursive walk (a returned temporary has its lifetime extended).
    const std::vector<Expression<Domain>*>& args = opExpr->arguments();

    const MaxExpression<Domain>* maxExpr = dynamic_cast<const MaxExpression<Domain>*>(opExpr);
    if (maxExpr) {
      // under a fixed strategy only the selected argument contributes
      const unsigned int i = _strategy[*maxExpr];
      return MaxStrategyExpression(*args[i], _strategy).eval(rho);
    }

    std::vector<Domain> argumentValues;
    argumentValues.reserve(args.size());
    for (typename std::vector<Expression<Domain>*>::const_iterator it = args.begin();
         it != args.end();
         ++it) {
      argumentValues.push_back(MaxStrategyExpression(**it, _strategy).eval(rho));
    }
    return opExpr->op().eval(argumentValues);
  }

  /** Prints the wrapped expression (the strategy choice is not shown). */
  void print(std::ostream& cout) const {
    cout << _expr;
  }
  private:
  const Expression<Domain>& _expr;
  const IdMap<MaxExpression<Domain>,unsigned int>& _strategy;
};

template<typename Domain>
/**
 * Equation system obtained from a ConcreteEquationSystem by fixing, for every
 * max-expression, which of its arguments is taken (the "strategy").
 *
 * Right-hand sides are MaxStrategyExpression wrappers that are created lazily
 * by operator[] and owned by this object (deleted in the destructor).
 * Variable enumeration, assignment creation/comparison and printing are
 * forwarded to the underlying system, which is held by reference and must
 * outlive this object.
 */
struct MaxStrategy : public EquationSystem<Domain> {
  /** Creates a strategy over `system` that initially selects argument 0 of every max-expression. */
  MaxStrategy(const ConcreteEquationSystem<Domain>& system)
    : _system(system), _expressions(system.variableCount(), NULL), _strategy(system.maxExpressionCount(), 0) {
  }

  /**
   * Copies the strategy choices of `other`.
   *
   * The cache of wrapper expressions is deliberately NOT copied. Those
   * wrappers are owned (and deleted) by `other`, so sharing the pointers would
   * delete them twice. They also refer to `other`'s strategy map rather than to
   * this copy's. The cache is instead rebuilt on demand by operator[].
   * (Copy-assignment stays unavailable because `_system` is a reference.)
   */
  MaxStrategy(const MaxStrategy& other)
    : EquationSystem<Domain>(other),
      _system(other._system),
      _expressions(other._system.variableCount(), NULL),
      _strategy(other._strategy) {
  }

  /** Deletes every wrapper expression created by operator[]. */
  ~MaxStrategy() {
    for (unsigned int i = 0, length = _system.variableCount();
         i < length;
         ++i) {
      Expression<Domain>* expr = _expressions[_system.variable(i)];
      delete expr; // deleting NULL (a never-requested entry) is a no-op
    }
  }

  /**
   * Returns the right-hand side of `v` under the current strategy. The wrapper
   * is created on first request and cached; it reads the strategy map by
   * reference, so later set() calls are reflected without rebuilding it.
   */
  const Expression<Domain>* operator[](const Variable<Domain>& v) const {
    if (_expressions[v] == NULL) {
      _expressions[v] = new MaxStrategyExpression<Domain>(*_system[v], _strategy);
    }
    return _expressions[v];
  }

  /** Returns the index of the argument currently selected for `e`. */
  unsigned int get(const MaxExpression<Domain>& e) const {
    return _strategy[e];
  }
  /** Selects argument `i` for `e`; returns `i`. */
  unsigned int set(const MaxExpression<Domain>& e, unsigned int i) {
    _strategy[e] = i;
    return i;
  }

  /**
   * Evaluates every right-hand side under the current strategy in `rho`.
   * The result is freshly allocated via assignment(); the caller owns it and
   * must delete it.
   */
  VariableAssignment<Domain>* eval(const VariableAssignment<Domain>& rho) const {
    StableVariableAssignment<Domain>* result = this->assignment(-infinity<Domain>());
    for (unsigned int i = 0, length = _system.variableCount();
         i < length;
         ++i) {
      const Variable<Domain>& var = _system.variable(i);
      const Expression<Domain>& expr = * (*this)[var];
      (*result)[var] = expr.eval(rho);
    }
    return result;
  }

  unsigned int variableCount() const {
    return _system.variableCount();
  }

  Variable<Domain>& variable(unsigned int i) const {
    return _system.variable(i);
  }

  StableVariableAssignment<Domain>* assignment(const Domain& v) const {
    return _system.assignment(v);
  }

  bool equalAssignments(const VariableAssignment<Domain>& l, const VariableAssignment<Domain>& r) const {
    return _system.equalAssignments(l, r);
  }


  /**
   * A strategy-improvement step: possibly changes the strategy based on `rho`.
   * The implementations in this file return true iff the strategy was changed.
   */
  struct ImprovementOperator {
    virtual ~ImprovementOperator() { }
    virtual bool improve(MaxStrategy<Domain>&, const VariableAssignment<Domain>&) const = 0;
  };
  /** Applies `op` to this strategy in `rho`; returns what `op` returns. */
  bool improve(const ImprovementOperator& op, const VariableAssignment<Domain>& rho) {
    return op.improve(*this, rho);
  }

  /**
   * For each max-expression, switches to an argument whose value in `rho`
   * (under the current strategy) is strictly greater than the value of the
   * currently selected one. Among improving arguments, the last strictly
   * increasing maximum is kept, so ties never cause a switch.
   */
  struct NaiveImprovementOperator : public ImprovementOperator {
    bool improve(MaxStrategy<Domain>& strat, const VariableAssignment<Domain>& rho) const {
      bool changed = false;
      for (unsigned int i = 0, exprCount = strat._system.maxExpressionCount();
           i < exprCount;
           ++i) {
        MaxExpression<Domain>& expr = strat._system.maxExpression(i);
        // value of the currently selected argument
        Domain bestValue = MaxStrategyExpression<Domain>(expr, strat._strategy).eval(rho);
        const unsigned int lastIndex = strat.get(expr);
        unsigned int bestIndex = lastIndex;

        // this relies on the fact that an expression will only be processed after the expressions
        // it depends on (which should always be true, as they form a DAG)
        const std::vector<Expression<Domain>*>& args = expr.arguments();
        for (unsigned int j = 0, argCount = args.size();
             j < argCount;
             ++j) {
          const Domain value = MaxStrategyExpression<Domain>(*args[j], strat._strategy).eval(rho);
          if (bestValue < value) {
            bestValue = value;
            bestIndex = j;
          }
        }
        if (bestIndex != lastIndex) {
          changed = true;
          strat.set(expr, bestIndex);
        }
      }
      return changed;
    }
  };

  /**
   * Applies a sub-operator repeatedly: after each successful improvement the
   * strategy is re-evaluated, and the sub-operator is applied again to the
   * resulting assignment, until it reports no change. Returns whether the
   * first application changed anything. The sub-operator is held by reference
   * and must outlive this object.
   */
  struct RepeatedImprovementOperator : public ImprovementOperator {
    RepeatedImprovementOperator(const ImprovementOperator& op)
      : _subImprovement(op) { }
    bool improve(MaxStrategy<Domain>& strat, const VariableAssignment<Domain>& rho) const {
      if (!_subImprovement.improve(strat, rho))
        return false;
      // Iterate rather than recurse so long improvement chains cannot overflow
      // the stack. eval() copies values into a fresh assignment, so only the
      // most recent assignment needs to be kept alive.
      VariableAssignment<Domain>* current = strat.eval(rho);
      while (_subImprovement.improve(strat, *current)) {
        VariableAssignment<Domain>* next = strat.eval(*current);
        delete current;
        current = next;
      }
      delete current;
      return true;
    }
    private:
    const ImprovementOperator& _subImprovement;
  };

  void print(std::ostream& cout) const {
    cout << _system << std::endl;
  }

  private:
  const ConcreteEquationSystem<Domain>& _system;
  // lazily-created, owned wrappers (NULL until requested via operator[])
  mutable IdMap<Variable<Domain>,Expression<Domain>*> _expressions;
  // selected argument index per max-expression
  IdMap<MaxExpression<Domain>,unsigned int> _strategy;
};

#endif