Cerata
A library to generate structural hardware designs
expression.cc
1 // Copyright 2018-2019 Delft University of Technology
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <memory>
16 #include <utility>
17 
18 #include "cerata/expression.h"
19 #include "cerata/graph.h"
20 
21 namespace cerata {
22 
23 static std::string ToString(Node *n) {
24  std::stringstream ss;
25  ss << n;
26  return ss.str();
27 }
28 
29 static std::string ToString(const std::shared_ptr<Node>& n) {
30  std::stringstream ss;
31  ss << n;
32  return ss.str();
33 }
34 
35 std::shared_ptr<Expression> Expression::Make(Op op, std::shared_ptr<Node> lhs, std::shared_ptr<Node> rhs) {
36  auto e = new Expression(op, std::move(lhs), std::move(rhs));
37  auto result = std::shared_ptr<Expression>(e);
38  if (e->parent()) {
39  e->parent().value()->Add(result);
40  }
41  return result;
42 }
43 
44 // Hash the the node pointers into a short string.
45 static std::string GenerateName(Expression *expr, std::shared_ptr<Node> lhs, std::shared_ptr<Node> rhs) {
46  auto l = ::cerata::ToString(std::move(lhs));
47  auto e = ::cerata::ToString(expr);
48  auto r = ::cerata::ToString(std::move(rhs));
49  std::string result = "Expr_" + l + e + r;
50  return result;
51 }
52 
53 Expression::Expression(Expression::Op op, std::shared_ptr<Node> lhs, std::shared_ptr<Node> rhs)
54  : MultiOutputNode(GenerateName(this, lhs, rhs), NodeID::EXPRESSION, string()),
55  operation_(op),
56  lhs_(std::move(lhs)),
57  rhs_(std::move(rhs)) {
58  if (lhs_->parent() && rhs_->parent()) {
59  auto lp = *lhs_->parent();
60  auto rp = *rhs_->parent();
61  if (lp != rp) {
62  CERATA_LOG(ERROR, "Can only generate expressions between nodes on same parent.");
63  }
64  }
65  if (lhs_->parent()) {
66  auto lp = *lhs_->parent();
67  SetParent(lp);
68  } else if (rhs_->parent()) {
69  auto rp = *rhs_->parent();
70  SetParent(rp);
71  }
72 }
73 
74 std::shared_ptr<Node> Expression::MergeIntLiterals(Expression *exp) {
75  if (exp->lhs_->IsLiteral() && exp->rhs_->IsLiteral()) {
76  auto l = std::dynamic_pointer_cast<const Literal>(exp->lhs_);
77  auto r = std::dynamic_pointer_cast<const Literal>(exp->rhs_);
78  if ((l->storage_type() == Literal::StorageType::INT)
79  && (r->storage_type() == Literal::StorageType::INT)
80  && (l->type() == r->type())) {
81  std::shared_ptr<Node> new_node;
82  switch (exp->operation_) {
83  case Op::ADD: return intl(l->IntValue() + r->IntValue());
84  case Op::SUB: return intl(l->IntValue() - r->IntValue());
85  case Op::MUL: return intl(l->IntValue() * r->IntValue());
86  case Op::DIV: return intl(l->IntValue() / r->IntValue());
87  }
88  }
89  }
90  return exp->shared_from_this();
91 }
92 
93 std::shared_ptr<Node> Expression::EliminateZeroOne(Expression *exp) {
94  switch (exp->operation_) {
95  case Op::ADD: {
96  if (exp->lhs_ == intl(0)) return exp->rhs_;
97  if (exp->rhs_ == intl(0)) return exp->lhs_;
98  break;
99  }
100  case Op::SUB: {
101  if (exp->rhs_ == intl(0)) return exp->lhs_;
102  break;
103  }
104  case Op::MUL: {
105  if (exp->lhs_ == intl(0)) return intl(0);
106  if (exp->rhs_ == intl(0)) return intl(0);
107  if (exp->lhs_ == intl(1)) return exp->rhs_;
108  if (exp->rhs_ == intl(1)) return exp->lhs_;
109  break;
110  }
111  case Op::DIV: {
112  if (exp->lhs_ == intl(0) && exp->rhs_ != intl(0)) return intl(0);
113  if (exp->rhs_ == intl(0)) { CERATA_LOG(FATAL, "Division by 0."); }
114  if (exp->rhs_ == intl(1)) return exp->lhs_;
115  break;
116  }
117  }
118  return exp->shared_from_this();
119 }
120 
121 std::shared_ptr<Node> Expression::Minimize(Node *node) {
122  std::shared_ptr<Node> result = node->shared_from_this();
123 
124  // If this node is an expression, we need to minimize its lhs and rhs first.
125  if (node->IsExpression()) {
126  auto expr = std::dynamic_pointer_cast<Expression>(result);
127  // Attempt to minimize children.
128  auto min_lhs = Minimize(expr->lhs());
129  auto min_rhs = Minimize(expr->rhs());
130 
131  // If minimization took place in either node, create a new expression with the minimized nodes.
132  if ((min_lhs != expr->lhs_) || (min_rhs != expr->rhs_)) {
133  expr = Expression::Make(expr->operation_, min_lhs, min_rhs);
134  }
135 
136  // Apply zero/one elimination on the minimized expression.
137  result = EliminateZeroOne(expr.get());
138 
139  // Integer literal merging
140  if (result->IsExpression()) {
141  expr = std::dynamic_pointer_cast<Expression>(result);
142  result = MergeIntLiterals(expr.get());
143  }
144  // TODO(johanpel): put some more elaborate minimization function/rules etc.. here
145  }
146  return result;
147 }
148 
149 std::string ToString(Expression::Op operation) {
150  switch (operation) {
151  case Expression::Op::ADD:return "+";
152  case Expression::Op::SUB:return "-";
153  case Expression::Op::MUL:return "*";
154  case Expression::Op::DIV:return "/";
155  }
156  return "INVALID OP";
157 }
158 
159 std::string Expression::ToString() const {
160  auto min = Minimize(const_cast<Expression *>(this));
161  if (min->IsExpression()) {
162  auto mine = std::dynamic_pointer_cast<Expression>(min);
163  auto ls = mine->lhs_->ToString();
164  auto op = cerata::ToString(mine->operation_);
165  auto rs = mine->rhs_->ToString();
166  return ls + op + rs;
167  } else {
168  return min->ToString();
169  }
170 }
171 
172 std::shared_ptr<Object> Expression::Copy() const {
174  std::dynamic_pointer_cast<Node>(lhs_->Copy()),
175  std::dynamic_pointer_cast<Node>(rhs_->Copy()));
176 }
177 
178 Node *Expression::CopyOnto(Graph *dst, const std::string &name, NodeMap *rebinding) const {
179  auto new_lhs = this->lhs_;
180  auto new_rhs = this->rhs_;
181  ImplicitlyRebindNodes(dst, {lhs_.get(), rhs_.get()}, rebinding);
182  // Check for both sides if they were already in the rebind map.
183  // If not, make copies onto the graph for those nodes as well.
184  if (rebinding->count(lhs_.get()) > 0) {
185  new_lhs = rebinding->at(lhs_.get())->shared_from_this();
186  } else {
187  new_lhs = lhs_->CopyOnto(dst, lhs_->name(), rebinding)->shared_from_this();
188  }
189  if (rebinding->count(rhs_.get()) > 0) {
190  new_rhs = rebinding->at(rhs_.get())->shared_from_this();
191  } else {
192  new_rhs = rhs_->CopyOnto(dst, rhs_->name(), rebinding)->shared_from_this();
193  }
194  auto result = Expression::Make(operation_, new_lhs, new_rhs);
195  (*rebinding)[this] = result.get();
196  dst->Add(result);
197  return result.get();
198 }
199 
200 void Expression::AppendReferences(std::vector<Object *> *out) const {
201  out->push_back(lhs_.get());
202  lhs_->AppendReferences(out);
203  out->push_back(rhs_.get());
204  rhs_->AppendReferences(out);
205 }
206 
207 } // namespace cerata
cerata::Expression::Minimize
static std::shared_ptr< Node > Minimize(Node *node)
Minimize a node if it is an expression; otherwise return a shared pointer to the input node itself (no copy is made).
Definition: expression.cc:121
cerata::Graph::Add
virtual Graph & Add(const std::shared_ptr< Object > &object)
Add an object to the component.
Definition: graph.cc:32
cerata::Expression::ToString
std::string ToString() const override
Minimize the expression and convert it to a human-readable string.
Definition: expression.cc:159
cerata::Expression::EliminateZeroOne
static std::shared_ptr< Node > EliminateZeroOne(Expression *exp)
Eliminate nodes that have zero or one on either side for specific expressions.
Definition: expression.cc:93
cerata::Expression
A node representing a binary tree of other nodes.
Definition: expression.h:34
cerata::Graph
A graph representing a hardware structure.
Definition: graph.h:37
cerata::Expression::Make
static std::shared_ptr< Expression > Make(Op op, std::shared_ptr< Node > lhs, std::shared_ptr< Node > rhs)
Short-hand to create a smart pointer to an expression.
Definition: expression.cc:35
cerata::Expression::Op
Op
Binary expression operator enum class.
Definition: expression.h:37
cerata::Object::SetParent
virtual void SetParent(Graph *parent)
Set the parent graph of this object.
Definition: object.cc:25
cerata::Expression::lhs_
std::shared_ptr< Node > lhs_
The left hand side node.
Definition: expression.h:77
cerata
Contains every Cerata class, function, etc...
Definition: api.h:41
cerata::Node
A node.
Definition: node.h:42
cerata::Expression::lhs
Node * lhs() const
Return the left-hand side node of the expression.
Definition: expression.h:54
cerata::NodeMap
std::unordered_map< const Node *, Node * > NodeMap
A mapping from one node to another node, used e.g. in type-generic rebinding.
Definition: node.h:135
cerata::Node::NodeID
NodeID
Node type IDs with different properties.
Definition: node.h:45
cerata::Expression::rhs_
std::shared_ptr< Node > rhs_
The right hand side node.
Definition: expression.h:79
cerata::string
std::shared_ptr< Type > string()
Return a static string type.
Definition: type.cc:261
cerata::Expression::rhs
Node * rhs() const
Return the right-hand side node of the expression.
Definition: expression.h:56
cerata::intl
std::shared_ptr< Literal > intl(int64_t i)
Obtain a shared pointer to an integer literal from the default node pool.
Definition: pool.h:144
cerata::ToString
std::string ToString()
Return a human-readable string from a type.
Definition: utils.h:185
cerata::Expression::Expression
Expression(Op op, std::shared_ptr< Node > lhs, std::shared_ptr< Node > rhs)
Construct a new expression.
Definition: expression.cc:53
cerata::Expression::AppendReferences
void AppendReferences(std::vector< Object * > *out) const override
Depth-first traverse the expression tree and add any nodes owned.
Definition: expression.cc:200
cerata::Expression::MergeIntLiterals
static std::shared_ptr< Node > MergeIntLiterals(Expression *exp)
Merge expressions of integer literals into their resulting integer literal.
Definition: expression.cc:74
cerata::Expression::operation_
Op operation_
The binary operator of this expression.
Definition: expression.h:75
cerata::Expression::Copy
std::shared_ptr< Object > Copy() const override
Copy this expression.
Definition: expression.cc:172
cerata::Expression::CopyOnto
Node * CopyOnto(Graph *dst, const std::string &name, NodeMap *rebinding) const override
Copy this expression onto a graph and rebind anything in the expression tree.
Definition: expression.cc:178
cerata::ImplicitlyRebindNodes
void ImplicitlyRebindNodes(Graph *dst, const std::vector< Node * > &nodes, NodeMap *rebinding)
Make sure that the NodeMap contains all nodes to be rebound onto the destination graph.
Definition: node.cc:39
cerata::MultiOutputNode
A no-input, multiple-outputs node.
Definition: node.h:140