Fletchgen
The Fletcher Design Generator
basic_types.cc
1 // Copyright 2018-2019 Delft University of Technology
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "fletchgen/basic_types.h"
16 
17 #include <cerata/api.h>
18 #include <cerata/vhdl/vhdl.h>
19 #include <fletcher/common.h>
20 
21 #include <memory>
22 
23 namespace fletchgen {
24 
25 using cerata::vector;
26 using cerata::bit;
27 using cerata::intl;
28 using cerata::integer;
29 using cerata::parameter;
30 using cerata::field;
31 using cerata::record;
32 using cerata::stream;
33 
// Validity bit
// NOTE(review): BIT_FACTORY is defined outside this file (presumably in basic_types.h);
// it presumably defines a factory function validity() returning a single-bit type — confirm there.
BIT_FACTORY(validity)

// Non-nullable fixed-width types.
// Each VEC_FACTORY(name, N) defines a factory function name(); ConvertFixedWidthType below calls
// e.g. uint8(), float64() and bool_() and returns their results as std::shared_ptr<Type>.
// NOTE(review): N is presumably the bit width of the generated vector type — confirm against the
// macro definition. The trailing underscore in bool_ presumably avoids clashing with the C++ keyword.
VEC_FACTORY(bool_, 1)
VEC_FACTORY(int8, 8)
VEC_FACTORY(uint8, 8)
VEC_FACTORY(int16, 16)
VEC_FACTORY(uint16, 16)
VEC_FACTORY(int32, 32)
VEC_FACTORY(uint32, 32)
VEC_FACTORY(int64, 64)
VEC_FACTORY(uint64, 64)
VEC_FACTORY(float8, 8)
VEC_FACTORY(float16, 16)
VEC_FACTORY(float32, 32)
VEC_FACTORY(float64, 64)
VEC_FACTORY(date32, 32)
VEC_FACTORY(date64, 64)
VEC_FACTORY(time32, 32)
VEC_FACTORY(time64, 64)
VEC_FACTORY(timestamp, 64)
VEC_FACTORY(decimal128, 128)
// Non-primitive element types (not produced by ConvertFixedWidthType in this file).
VEC_FACTORY(utf8c, 8)
VEC_FACTORY(byte, 8)
VEC_FACTORY(offset, 32)
60 
61 // Create basic clock domains
62 std::shared_ptr<ClockDomain> kernel_cd() {
63  static std::shared_ptr<ClockDomain> result = std::make_shared<ClockDomain>("kcd");
64  return result;
65 }
66 
67 std::shared_ptr<ClockDomain> bus_cd() {
68  static std::shared_ptr<ClockDomain> result = std::make_shared<ClockDomain>("bcd");
69  return result;
70 }
71 
72 // Create basic clock & reset record type
73 std::shared_ptr<Type> cr() {
74  static std::shared_ptr<Type> result = record("cr", {
75  field("clk", bit()),
76  field("reset", bit())});
77  result->meta[cerata::vhdl::meta::NO_INSERT_SIGNAL] = "true";
78  return result;
79 }
80 
81 std::shared_ptr<Type> valid(int width, bool on_primitive) {
82  if (width > 1 || on_primitive) {
83  return vector("valid", width);
84  } else {
85  return bit("valid");
86  }
87 }
88 
89 std::shared_ptr<Type> ready(int width, bool on_primitive) {
90  if (width > 1 || on_primitive) {
91  return vector("ready", width);
92  } else {
93  return bit("ready");
94  }
95 }
96 
97 // Data channel
98 std::shared_ptr<Type> data(int width) {
99  std::shared_ptr<Type> result = vector("data", width);
100  // Mark this type so later we can figure out that it was concatenated onto the data port of an ArrayReader/Writer.
101  result->meta[meta::ARRAY_DATA] = "true";
102  return result;
103 }
104 
105 // Length channel
106 std::shared_ptr<Type> length(int width) {
107  std::shared_ptr<Type> result = vector("length", width);
108  // Mark this type so later we can figure out that it was concatenated onto the data port of an ArrayReader/Writer.
109  result->meta[meta::ARRAY_DATA] = "true";
110  return result;
111 }
112 
113 std::shared_ptr<Type> count(int width) {
114  std::shared_ptr<Type> result = vector(width);
115  // Mark this type so later we can figure out that it was concatenated onto the data port of an ArrayReader/Writer.
116  result->meta[meta::ARRAY_DATA] = "true";
117  result->meta[meta::COUNT] = std::to_string(width);
118  return result;
119 }
120 
121 std::shared_ptr<Type> dvalid(int width, bool on_primitive) {
122  if (width > 1 || on_primitive) {
123  return vector("dvalid", width);
124  } else {
125  return bit("dvalid");
126  }
127 }
128 
129 std::shared_ptr<Type> last(int width, bool on_primitive) {
130  std::shared_ptr<Type> result;
131  if (width > 1 || on_primitive) {
132  result = vector("last", width);
133  } else {
134  result = bit("last");
135  }
136  result->meta[meta::LAST] = "true";
137  return result;
138 }
139 
140 int GetFixedWidthTypeBitWidth(const arrow::DataType &arrow_type) {
141  auto fwt = dynamic_cast<const arrow::FixedWidthType *>(&arrow_type);
142  if (fwt == nullptr) {
143  FLETCHER_LOG(ERROR, "Not a fixed-width Arrow type: " + arrow_type.ToString());
144  }
145  return fwt->bit_width();
146 }
147 
148 std::shared_ptr<Type> ConvertFixedWidthType(const std::shared_ptr<arrow::DataType> &arrow_type, int epc) {
149  if (epc == 1) {
150  // Only need to cover fixed-width data types in this function
151  switch (arrow_type->id()) {
152  case arrow::Type::UINT8: return uint8();
153  case arrow::Type::UINT16: return uint16();
154  case arrow::Type::UINT32: return uint32();
155  case arrow::Type::UINT64: return uint64();
156  case arrow::Type::INT8: return int8();
157  case arrow::Type::INT16: return int16();
158  case arrow::Type::INT32: return int32();
159  case arrow::Type::INT64: return int64();
160  case arrow::Type::HALF_FLOAT: return float16();
161  case arrow::Type::FLOAT: return float32();
162  case arrow::Type::DOUBLE: return float64();
163  case arrow::Type::BOOL: return bool_();
164  case arrow::Type::DATE32: return date32();
165  case arrow::Type::DATE64: return date64();
166  case arrow::Type::TIME32: return time32();
167  case arrow::Type::TIME64: return time64();
168  case arrow::Type::TIMESTAMP: return timestamp();
169  case arrow::Type::DECIMAL: return decimal128();
170  default:throw std::runtime_error("Unsupported Arrow DataType: " + arrow_type->ToString());
171  }
172  } else {
173  auto fwt = std::dynamic_pointer_cast<arrow::FixedWidthType>(arrow_type);
174  if (fwt == nullptr) {
175  FLETCHER_LOG(ERROR, "Not a fixed-width Arrow type: " + arrow_type->ToString());
176  }
177  return cerata::vector(epc * fwt->bit_width());
178  }
179 }
180 
181 std::optional<cerata::Port *> GetClockResetPort(cerata::Graph *graph, const ClockDomain &domain) {
182  for (auto crn : graph->GetNodes()) {
183  if (crn->type()->IsEqual(*cr()) && crn->IsPort()) {
184  // TODO(johanpel): better comparison
185  if (crn->AsPort()->domain().get() == &domain) {
186  return crn->AsPort();
187  }
188  }
189  }
190  return std::nullopt;
191 }
192 
193 } // namespace fletchgen
constexpr char LAST[]
Key to mark the last field in Arrow data streams.
Definition: basic_types.h:32
constexpr char ARRAY_DATA[]
Key for automated type mapping.
Definition: basic_types.h:28
constexpr char COUNT[]
Key to mark the count field in Arrow data streams.
Definition: basic_types.h:30
Contains all classes and functions related to Fletchgen.
Definition: array.cc:29
std::shared_ptr< ClockDomain > kernel_cd()
Fletcher accelerator clock domain.
Definition: basic_types.cc:62
std::shared_ptr< Type > valid(int width, bool on_primitive)
Fletcher valid.
Definition: basic_types.cc:81
std::shared_ptr< Type > cr()
Fletcher clock/reset.
Definition: basic_types.cc:73
std::shared_ptr< Type > data(int width)
Fletcher data.
Definition: basic_types.cc:98
std::shared_ptr< Type > ConvertFixedWidthType(const std::shared_ptr< arrow::DataType > &arrow_type, int epc)
Convert a fixed-width arrow::DataType to a fixed-width Fletcher Type.
Definition: basic_types.cc:148
std::shared_ptr< Type > ready(int width, bool on_primitive)
Fletcher ready.
Definition: basic_types.cc:89
std::shared_ptr< Type > dvalid(int width, bool on_primitive)
Fletcher dvalid.
Definition: basic_types.cc:121
std::shared_ptr< Type > length(int width)
Fletcher length.
Definition: basic_types.cc:106
std::shared_ptr< Type > count(int width)
Fletcher count.
Definition: basic_types.cc:113
int GetFixedWidthTypeBitWidth(const arrow::DataType &arrow_type)
Returns the bit-width of a fixed-width Arrow type. Throws if it's not a fixed-width type.
Definition: basic_types.cc:140
std::shared_ptr< ClockDomain > bus_cd()
Fletcher bus clock domain.
Definition: basic_types.cc:67
std::shared_ptr< Type > last(int width, bool on_primitive)
Fletcher last.
Definition: basic_types.cc:129
std::optional< cerata::Port * > GetClockResetPort(cerata::Graph *graph, const ClockDomain &domain)
Return the clock/reset port of a graph for a specific clock domain, if it exists.
Definition: basic_types.cc:181