1 //===- MemRefBuilder.cpp - Helper for LLVM MemRef equivalents -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
10 #include "MemRefDescriptor.h"
11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/Support/MathExtras.h"
15
16 using namespace mlir;
17
18 //===----------------------------------------------------------------------===//
19 // MemRefDescriptor implementation
20 //===----------------------------------------------------------------------===//
21
22 /// Construct a helper for the given descriptor value.
MemRefDescriptor(Value descriptor)23 MemRefDescriptor::MemRefDescriptor(Value descriptor)
24 : StructBuilder(descriptor) {
25 assert(value != nullptr && "value cannot be null");
26 indexType = value.getType()
27 .cast<LLVM::LLVMStructType>()
28 .getBody()[kOffsetPosInMemRefDescriptor];
29 }
30
31 /// Builds IR creating an `undef` value of the descriptor type.
undef(OpBuilder & builder,Location loc,Type descriptorType)32 MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
33 Type descriptorType) {
34
35 Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
36 return MemRefDescriptor(descriptor);
37 }
38
39 /// Builds IR creating a MemRef descriptor that represents `type` and
40 /// populates it with static shape and stride information extracted from the
41 /// type.
42 MemRefDescriptor
fromStaticShape(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,MemRefType type,Value memory)43 MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
44 LLVMTypeConverter &typeConverter,
45 MemRefType type, Value memory) {
46 assert(type.hasStaticShape() && "unexpected dynamic shape");
47
48 // Extract all strides and offsets and verify they are static.
49 int64_t offset;
50 SmallVector<int64_t, 4> strides;
51 auto result = getStridesAndOffset(type, strides, offset);
52 (void)result;
53 assert(succeeded(result) && "unexpected failure in stride computation");
54 assert(!ShapedType::isDynamicStrideOrOffset(offset) &&
55 "expected static offset");
56 assert(!llvm::any_of(strides, ShapedType::isDynamicStrideOrOffset) &&
57 "expected static strides");
58
59 auto convertedType = typeConverter.convertType(type);
60 assert(convertedType && "unexpected failure in memref type conversion");
61
62 auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
63 descr.setAllocatedPtr(builder, loc, memory);
64 descr.setAlignedPtr(builder, loc, memory);
65 descr.setConstantOffset(builder, loc, offset);
66
67 // Fill in sizes and strides
68 for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
69 descr.setConstantSize(builder, loc, i, type.getDimSize(i));
70 descr.setConstantStride(builder, loc, i, strides[i]);
71 }
72 return descr;
73 }
74
75 /// Builds IR extracting the allocated pointer from the descriptor.
allocatedPtr(OpBuilder & builder,Location loc)76 Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
77 return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
78 }
79
80 /// Builds IR inserting the allocated pointer into the descriptor.
setAllocatedPtr(OpBuilder & builder,Location loc,Value ptr)81 void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
82 Value ptr) {
83 setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
84 }
85
86 /// Builds IR extracting the aligned pointer from the descriptor.
alignedPtr(OpBuilder & builder,Location loc)87 Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
88 return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
89 }
90
91 /// Builds IR inserting the aligned pointer into the descriptor.
setAlignedPtr(OpBuilder & builder,Location loc,Value ptr)92 void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
93 Value ptr) {
94 setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
95 }
96
97 // Creates a constant Op producing a value of `resultType` from an index-typed
98 // integer attribute.
createIndexAttrConstant(OpBuilder & builder,Location loc,Type resultType,int64_t value)99 static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
100 Type resultType, int64_t value) {
101 return builder.create<LLVM::ConstantOp>(
102 loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
103 }
104
105 /// Builds IR extracting the offset from the descriptor.
offset(OpBuilder & builder,Location loc)106 Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
107 return builder.create<LLVM::ExtractValueOp>(
108 loc, indexType, value,
109 builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
110 }
111
112 /// Builds IR inserting the offset into the descriptor.
setOffset(OpBuilder & builder,Location loc,Value offset)113 void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
114 Value offset) {
115 value = builder.create<LLVM::InsertValueOp>(
116 loc, structType, value, offset,
117 builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
118 }
119
120 /// Builds IR inserting the offset into the descriptor.
setConstantOffset(OpBuilder & builder,Location loc,uint64_t offset)121 void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
122 uint64_t offset) {
123 setOffset(builder, loc,
124 createIndexAttrConstant(builder, loc, indexType, offset));
125 }
126
127 /// Builds IR extracting the pos-th size from the descriptor.
size(OpBuilder & builder,Location loc,unsigned pos)128 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
129 return builder.create<LLVM::ExtractValueOp>(
130 loc, indexType, value,
131 builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
132 }
133
size(OpBuilder & builder,Location loc,Value pos,int64_t rank)134 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
135 int64_t rank) {
136 auto indexPtrTy = LLVM::LLVMPointerType::get(indexType);
137 auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank);
138 auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy);
139
140 // Copy size values to stack-allocated memory.
141 auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
142 auto one = createIndexAttrConstant(builder, loc, indexType, 1);
143 auto sizes = builder.create<LLVM::ExtractValueOp>(
144 loc, arrayTy, value,
145 builder.getI64ArrayAttr({kSizePosInMemRefDescriptor}));
146 auto sizesPtr =
147 builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, one, /*alignment=*/0);
148 builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
149
150 // Load an return size value of interest.
151 auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
152 ValueRange({zero, pos}));
153 return builder.create<LLVM::LoadOp>(loc, resultPtr);
154 }
155
156 /// Builds IR inserting the pos-th size into the descriptor
setSize(OpBuilder & builder,Location loc,unsigned pos,Value size)157 void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
158 Value size) {
159 value = builder.create<LLVM::InsertValueOp>(
160 loc, structType, value, size,
161 builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
162 }
163
setConstantSize(OpBuilder & builder,Location loc,unsigned pos,uint64_t size)164 void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
165 unsigned pos, uint64_t size) {
166 setSize(builder, loc, pos,
167 createIndexAttrConstant(builder, loc, indexType, size));
168 }
169
170 /// Builds IR extracting the pos-th stride from the descriptor.
stride(OpBuilder & builder,Location loc,unsigned pos)171 Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
172 return builder.create<LLVM::ExtractValueOp>(
173 loc, indexType, value,
174 builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
175 }
176
177 /// Builds IR inserting the pos-th stride into the descriptor
setStride(OpBuilder & builder,Location loc,unsigned pos,Value stride)178 void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
179 Value stride) {
180 value = builder.create<LLVM::InsertValueOp>(
181 loc, structType, value, stride,
182 builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
183 }
184
setConstantStride(OpBuilder & builder,Location loc,unsigned pos,uint64_t stride)185 void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
186 unsigned pos, uint64_t stride) {
187 setStride(builder, loc, pos,
188 createIndexAttrConstant(builder, loc, indexType, stride));
189 }
190
getElementPtrType()191 LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
192 return value.getType()
193 .cast<LLVM::LLVMStructType>()
194 .getBody()[kAlignedPtrPosInMemRefDescriptor]
195 .cast<LLVM::LLVMPointerType>();
196 }
197
198 /// Creates a MemRef descriptor structure from a list of individual values
199 /// composing that descriptor, in the following order:
200 /// - allocated pointer;
201 /// - aligned pointer;
202 /// - offset;
203 /// - <rank> sizes;
204 /// - <rank> shapes;
205 /// where <rank> is the MemRef rank as provided in `type`.
pack(OpBuilder & builder,Location loc,LLVMTypeConverter & converter,MemRefType type,ValueRange values)206 Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
207 LLVMTypeConverter &converter, MemRefType type,
208 ValueRange values) {
209 Type llvmType = converter.convertType(type);
210 auto d = MemRefDescriptor::undef(builder, loc, llvmType);
211
212 d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
213 d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
214 d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
215
216 int64_t rank = type.getRank();
217 for (unsigned i = 0; i < rank; ++i) {
218 d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
219 d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
220 }
221
222 return d;
223 }
224
225 /// Builds IR extracting individual elements of a MemRef descriptor structure
226 /// and returning them as `results` list.
unpack(OpBuilder & builder,Location loc,Value packed,MemRefType type,SmallVectorImpl<Value> & results)227 void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
228 MemRefType type,
229 SmallVectorImpl<Value> &results) {
230 int64_t rank = type.getRank();
231 results.reserve(results.size() + getNumUnpackedValues(type));
232
233 MemRefDescriptor d(packed);
234 results.push_back(d.allocatedPtr(builder, loc));
235 results.push_back(d.alignedPtr(builder, loc));
236 results.push_back(d.offset(builder, loc));
237 for (int64_t i = 0; i < rank; ++i)
238 results.push_back(d.size(builder, loc, i));
239 for (int64_t i = 0; i < rank; ++i)
240 results.push_back(d.stride(builder, loc, i));
241 }
242
243 /// Returns the number of non-aggregate values that would be produced by
244 /// `unpack`.
getNumUnpackedValues(MemRefType type)245 unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
246 // Two pointers, offset, <rank> sizes, <rank> shapes.
247 return 3 + 2 * type.getRank();
248 }
249
250 //===----------------------------------------------------------------------===//
251 // MemRefDescriptorView implementation.
252 //===----------------------------------------------------------------------===//
253
MemRefDescriptorView(ValueRange range)254 MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
255 : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}
256
allocatedPtr()257 Value MemRefDescriptorView::allocatedPtr() {
258 return elements[kAllocatedPtrPosInMemRefDescriptor];
259 }
260
alignedPtr()261 Value MemRefDescriptorView::alignedPtr() {
262 return elements[kAlignedPtrPosInMemRefDescriptor];
263 }
264
offset()265 Value MemRefDescriptorView::offset() {
266 return elements[kOffsetPosInMemRefDescriptor];
267 }
268
size(unsigned pos)269 Value MemRefDescriptorView::size(unsigned pos) {
270 return elements[kSizePosInMemRefDescriptor + pos];
271 }
272
stride(unsigned pos)273 Value MemRefDescriptorView::stride(unsigned pos) {
274 return elements[kSizePosInMemRefDescriptor + rank + pos];
275 }
276
277 //===----------------------------------------------------------------------===//
278 // UnrankedMemRefDescriptor implementation
279 //===----------------------------------------------------------------------===//
280
281 /// Construct a helper for the given descriptor value.
UnrankedMemRefDescriptor(Value descriptor)282 UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
283 : StructBuilder(descriptor) {}
284
285 /// Builds IR creating an `undef` value of the descriptor type.
undef(OpBuilder & builder,Location loc,Type descriptorType)286 UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
287 Location loc,
288 Type descriptorType) {
289 Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
290 return UnrankedMemRefDescriptor(descriptor);
291 }
rank(OpBuilder & builder,Location loc)292 Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
293 return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
294 }
setRank(OpBuilder & builder,Location loc,Value v)295 void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
296 Value v) {
297 setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
298 }
memRefDescPtr(OpBuilder & builder,Location loc)299 Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
300 Location loc) {
301 return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
302 }
setMemRefDescPtr(OpBuilder & builder,Location loc,Value v)303 void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
304 Location loc, Value v) {
305 setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
306 }
307
308 /// Builds IR populating an unranked MemRef descriptor structure from a list
309 /// of individual constituent values in the following order:
310 /// - rank of the memref;
311 /// - pointer to the memref descriptor.
pack(OpBuilder & builder,Location loc,LLVMTypeConverter & converter,UnrankedMemRefType type,ValueRange values)312 Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
313 LLVMTypeConverter &converter,
314 UnrankedMemRefType type,
315 ValueRange values) {
316 Type llvmType = converter.convertType(type);
317 auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
318
319 d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
320 d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
321 return d;
322 }
323
324 /// Builds IR extracting individual elements that compose an unranked memref
325 /// descriptor and returns them as `results` list.
unpack(OpBuilder & builder,Location loc,Value packed,SmallVectorImpl<Value> & results)326 void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
327 Value packed,
328 SmallVectorImpl<Value> &results) {
329 UnrankedMemRefDescriptor d(packed);
330 results.reserve(results.size() + 2);
331 results.push_back(d.rank(builder, loc));
332 results.push_back(d.memRefDescPtr(builder, loc));
333 }
334
computeSizes(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,ArrayRef<UnrankedMemRefDescriptor> values,SmallVectorImpl<Value> & sizes)335 void UnrankedMemRefDescriptor::computeSizes(
336 OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
337 ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
338 if (values.empty())
339 return;
340
341 // Cache the index type.
342 Type indexType = typeConverter.getIndexType();
343
344 // Initialize shared constants.
345 Value one = createIndexAttrConstant(builder, loc, indexType, 1);
346 Value two = createIndexAttrConstant(builder, loc, indexType, 2);
347 Value pointerSize = createIndexAttrConstant(
348 builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
349 Value indexSize =
350 createIndexAttrConstant(builder, loc, indexType,
351 ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));
352
353 sizes.reserve(sizes.size() + values.size());
354 for (UnrankedMemRefDescriptor desc : values) {
355 // Emit IR computing the memory necessary to store the descriptor. This
356 // assumes the descriptor to be
357 // { type*, type*, index, index[rank], index[rank] }
358 // and densely packed, so the total size is
359 // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
360 // TODO: consider including the actual size (including eventual padding due
361 // to data layout) into the unranked descriptor.
362 Value doublePointerSize =
363 builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
364
365 // (1 + 2 * rank) * sizeof(index)
366 Value rank = desc.rank(builder, loc);
367 Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank);
368 Value doubleRankIncremented =
369 builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one);
370 Value rankIndexSize = builder.create<LLVM::MulOp>(
371 loc, indexType, doubleRankIncremented, indexSize);
372
373 // Total allocation size.
374 Value allocationSize = builder.create<LLVM::AddOp>(
375 loc, indexType, doublePointerSize, rankIndexSize);
376 sizes.push_back(allocationSize);
377 }
378 }
379
allocatedPtr(OpBuilder & builder,Location loc,Value memRefDescPtr,Type elemPtrPtrType)380 Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
381 Value memRefDescPtr,
382 Type elemPtrPtrType) {
383
384 Value elementPtrPtr =
385 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
386 return builder.create<LLVM::LoadOp>(loc, elementPtrPtr);
387 }
388
setAllocatedPtr(OpBuilder & builder,Location loc,Value memRefDescPtr,Type elemPtrPtrType,Value allocatedPtr)389 void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
390 Value memRefDescPtr,
391 Type elemPtrPtrType,
392 Value allocatedPtr) {
393 Value elementPtrPtr =
394 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
395 builder.create<LLVM::StoreOp>(loc, allocatedPtr, elementPtrPtr);
396 }
397
alignedPtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType)398 Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
399 LLVMTypeConverter &typeConverter,
400 Value memRefDescPtr,
401 Type elemPtrPtrType) {
402 Value elementPtrPtr =
403 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
404
405 Value one =
406 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
407 Value alignedGep = builder.create<LLVM::GEPOp>(
408 loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
409 return builder.create<LLVM::LoadOp>(loc, alignedGep);
410 }
411
setAlignedPtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType,Value alignedPtr)412 void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
413 LLVMTypeConverter &typeConverter,
414 Value memRefDescPtr,
415 Type elemPtrPtrType,
416 Value alignedPtr) {
417 Value elementPtrPtr =
418 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
419
420 Value one =
421 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
422 Value alignedGep = builder.create<LLVM::GEPOp>(
423 loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
424 builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
425 }
426
offset(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType)427 Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
428 LLVMTypeConverter &typeConverter,
429 Value memRefDescPtr,
430 Type elemPtrPtrType) {
431 Value elementPtrPtr =
432 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
433
434 Value two =
435 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
436 Value offsetGep = builder.create<LLVM::GEPOp>(
437 loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
438 offsetGep = builder.create<LLVM::BitcastOp>(
439 loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
440 return builder.create<LLVM::LoadOp>(loc, offsetGep);
441 }
442
setOffset(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType,Value offset)443 void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
444 LLVMTypeConverter &typeConverter,
445 Value memRefDescPtr,
446 Type elemPtrPtrType, Value offset) {
447 Value elementPtrPtr =
448 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
449
450 Value two =
451 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
452 Value offsetGep = builder.create<LLVM::GEPOp>(
453 loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
454 offsetGep = builder.create<LLVM::BitcastOp>(
455 loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
456 builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
457 }
458
sizeBasePtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,LLVM::LLVMPointerType elemPtrPtrType)459 Value UnrankedMemRefDescriptor::sizeBasePtr(
460 OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
461 Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) {
462 Type elemPtrTy = elemPtrPtrType.getElementType();
463 Type indexTy = typeConverter.getIndexType();
464 Type structPtrTy =
465 LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral(
466 indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy}));
467 Value structPtr =
468 builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
469
470 Type int32Type = typeConverter.convertType(builder.getI32Type());
471 Value zero =
472 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
473 Value three = builder.create<LLVM::ConstantOp>(loc, int32Type,
474 builder.getI32IntegerAttr(3));
475 return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy),
476 structPtr, ValueRange({zero, three}));
477 }
478
size(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value sizeBasePtr,Value index)479 Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
480 LLVMTypeConverter &typeConverter,
481 Value sizeBasePtr, Value index) {
482 Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
483 Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
484 ValueRange({index}));
485 return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
486 }
487
setSize(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value sizeBasePtr,Value index,Value size)488 void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
489 LLVMTypeConverter &typeConverter,
490 Value sizeBasePtr, Value index,
491 Value size) {
492 Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
493 Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
494 ValueRange({index}));
495 builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
496 }
497
strideBasePtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value sizeBasePtr,Value rank)498 Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
499 LLVMTypeConverter &typeConverter,
500 Value sizeBasePtr, Value rank) {
501 Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
502 return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
503 ValueRange({rank}));
504 }
505
stride(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value strideBasePtr,Value index,Value stride)506 Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
507 LLVMTypeConverter &typeConverter,
508 Value strideBasePtr, Value index,
509 Value stride) {
510 Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
511 Value strideStoreGep = builder.create<LLVM::GEPOp>(
512 loc, indexPtrTy, strideBasePtr, ValueRange({index}));
513 return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
514 }
515
setStride(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value strideBasePtr,Value index,Value stride)516 void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
517 LLVMTypeConverter &typeConverter,
518 Value strideBasePtr, Value index,
519 Value stride) {
520 Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
521 Value strideStoreGep = builder.create<LLVM::GEPOp>(
522 loc, indexPtrTy, strideBasePtr, ValueRange({index}));
523 builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
524 }
525