1 //===- MemRefBuilder.cpp - Helper for LLVM MemRef equivalents -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
10 #include "MemRefDescriptor.h"
11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/Support/MathExtras.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // MemRefDescriptor implementation
20 //===----------------------------------------------------------------------===//
21 
22 /// Construct a helper for the given descriptor value.
MemRefDescriptor(Value descriptor)23 MemRefDescriptor::MemRefDescriptor(Value descriptor)
24     : StructBuilder(descriptor) {
25   assert(value != nullptr && "value cannot be null");
26   indexType = value.getType()
27                   .cast<LLVM::LLVMStructType>()
28                   .getBody()[kOffsetPosInMemRefDescriptor];
29 }
30 
31 /// Builds IR creating an `undef` value of the descriptor type.
undef(OpBuilder & builder,Location loc,Type descriptorType)32 MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
33                                          Type descriptorType) {
34 
35   Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
36   return MemRefDescriptor(descriptor);
37 }
38 
39 /// Builds IR creating a MemRef descriptor that represents `type` and
40 /// populates it with static shape and stride information extracted from the
41 /// type.
42 MemRefDescriptor
fromStaticShape(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,MemRefType type,Value memory)43 MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
44                                   LLVMTypeConverter &typeConverter,
45                                   MemRefType type, Value memory) {
46   assert(type.hasStaticShape() && "unexpected dynamic shape");
47 
48   // Extract all strides and offsets and verify they are static.
49   int64_t offset;
50   SmallVector<int64_t, 4> strides;
51   auto result = getStridesAndOffset(type, strides, offset);
52   (void)result;
53   assert(succeeded(result) && "unexpected failure in stride computation");
54   assert(!ShapedType::isDynamicStrideOrOffset(offset) &&
55          "expected static offset");
56   assert(!llvm::any_of(strides, ShapedType::isDynamicStrideOrOffset) &&
57          "expected static strides");
58 
59   auto convertedType = typeConverter.convertType(type);
60   assert(convertedType && "unexpected failure in memref type conversion");
61 
62   auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
63   descr.setAllocatedPtr(builder, loc, memory);
64   descr.setAlignedPtr(builder, loc, memory);
65   descr.setConstantOffset(builder, loc, offset);
66 
67   // Fill in sizes and strides
68   for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
69     descr.setConstantSize(builder, loc, i, type.getDimSize(i));
70     descr.setConstantStride(builder, loc, i, strides[i]);
71   }
72   return descr;
73 }
74 
75 /// Builds IR extracting the allocated pointer from the descriptor.
allocatedPtr(OpBuilder & builder,Location loc)76 Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
77   return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
78 }
79 
80 /// Builds IR inserting the allocated pointer into the descriptor.
setAllocatedPtr(OpBuilder & builder,Location loc,Value ptr)81 void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
82                                        Value ptr) {
83   setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
84 }
85 
86 /// Builds IR extracting the aligned pointer from the descriptor.
alignedPtr(OpBuilder & builder,Location loc)87 Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
88   return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
89 }
90 
91 /// Builds IR inserting the aligned pointer into the descriptor.
setAlignedPtr(OpBuilder & builder,Location loc,Value ptr)92 void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
93                                      Value ptr) {
94   setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
95 }
96 
97 // Creates a constant Op producing a value of `resultType` from an index-typed
98 // integer attribute.
createIndexAttrConstant(OpBuilder & builder,Location loc,Type resultType,int64_t value)99 static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
100                                      Type resultType, int64_t value) {
101   return builder.create<LLVM::ConstantOp>(
102       loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
103 }
104 
105 /// Builds IR extracting the offset from the descriptor.
offset(OpBuilder & builder,Location loc)106 Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
107   return builder.create<LLVM::ExtractValueOp>(
108       loc, indexType, value,
109       builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
110 }
111 
112 /// Builds IR inserting the offset into the descriptor.
setOffset(OpBuilder & builder,Location loc,Value offset)113 void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
114                                  Value offset) {
115   value = builder.create<LLVM::InsertValueOp>(
116       loc, structType, value, offset,
117       builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
118 }
119 
120 /// Builds IR inserting the offset into the descriptor.
setConstantOffset(OpBuilder & builder,Location loc,uint64_t offset)121 void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
122                                          uint64_t offset) {
123   setOffset(builder, loc,
124             createIndexAttrConstant(builder, loc, indexType, offset));
125 }
126 
127 /// Builds IR extracting the pos-th size from the descriptor.
size(OpBuilder & builder,Location loc,unsigned pos)128 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
129   return builder.create<LLVM::ExtractValueOp>(
130       loc, indexType, value,
131       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
132 }
133 
size(OpBuilder & builder,Location loc,Value pos,int64_t rank)134 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
135                              int64_t rank) {
136   auto indexPtrTy = LLVM::LLVMPointerType::get(indexType);
137   auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank);
138   auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy);
139 
140   // Copy size values to stack-allocated memory.
141   auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
142   auto one = createIndexAttrConstant(builder, loc, indexType, 1);
143   auto sizes = builder.create<LLVM::ExtractValueOp>(
144       loc, arrayTy, value,
145       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor}));
146   auto sizesPtr =
147       builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, one, /*alignment=*/0);
148   builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
149 
150   // Load an return size value of interest.
151   auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
152                                                ValueRange({zero, pos}));
153   return builder.create<LLVM::LoadOp>(loc, resultPtr);
154 }
155 
156 /// Builds IR inserting the pos-th size into the descriptor
setSize(OpBuilder & builder,Location loc,unsigned pos,Value size)157 void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
158                                Value size) {
159   value = builder.create<LLVM::InsertValueOp>(
160       loc, structType, value, size,
161       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
162 }
163 
setConstantSize(OpBuilder & builder,Location loc,unsigned pos,uint64_t size)164 void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
165                                        unsigned pos, uint64_t size) {
166   setSize(builder, loc, pos,
167           createIndexAttrConstant(builder, loc, indexType, size));
168 }
169 
170 /// Builds IR extracting the pos-th stride from the descriptor.
stride(OpBuilder & builder,Location loc,unsigned pos)171 Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
172   return builder.create<LLVM::ExtractValueOp>(
173       loc, indexType, value,
174       builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
175 }
176 
177 /// Builds IR inserting the pos-th stride into the descriptor
setStride(OpBuilder & builder,Location loc,unsigned pos,Value stride)178 void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
179                                  Value stride) {
180   value = builder.create<LLVM::InsertValueOp>(
181       loc, structType, value, stride,
182       builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
183 }
184 
setConstantStride(OpBuilder & builder,Location loc,unsigned pos,uint64_t stride)185 void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
186                                          unsigned pos, uint64_t stride) {
187   setStride(builder, loc, pos,
188             createIndexAttrConstant(builder, loc, indexType, stride));
189 }
190 
getElementPtrType()191 LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
192   return value.getType()
193       .cast<LLVM::LLVMStructType>()
194       .getBody()[kAlignedPtrPosInMemRefDescriptor]
195       .cast<LLVM::LLVMPointerType>();
196 }
197 
198 /// Creates a MemRef descriptor structure from a list of individual values
199 /// composing that descriptor, in the following order:
200 /// - allocated pointer;
201 /// - aligned pointer;
202 /// - offset;
203 /// - <rank> sizes;
204 /// - <rank> shapes;
205 /// where <rank> is the MemRef rank as provided in `type`.
pack(OpBuilder & builder,Location loc,LLVMTypeConverter & converter,MemRefType type,ValueRange values)206 Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
207                              LLVMTypeConverter &converter, MemRefType type,
208                              ValueRange values) {
209   Type llvmType = converter.convertType(type);
210   auto d = MemRefDescriptor::undef(builder, loc, llvmType);
211 
212   d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
213   d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
214   d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
215 
216   int64_t rank = type.getRank();
217   for (unsigned i = 0; i < rank; ++i) {
218     d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
219     d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
220   }
221 
222   return d;
223 }
224 
225 /// Builds IR extracting individual elements of a MemRef descriptor structure
226 /// and returning them as `results` list.
unpack(OpBuilder & builder,Location loc,Value packed,MemRefType type,SmallVectorImpl<Value> & results)227 void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
228                               MemRefType type,
229                               SmallVectorImpl<Value> &results) {
230   int64_t rank = type.getRank();
231   results.reserve(results.size() + getNumUnpackedValues(type));
232 
233   MemRefDescriptor d(packed);
234   results.push_back(d.allocatedPtr(builder, loc));
235   results.push_back(d.alignedPtr(builder, loc));
236   results.push_back(d.offset(builder, loc));
237   for (int64_t i = 0; i < rank; ++i)
238     results.push_back(d.size(builder, loc, i));
239   for (int64_t i = 0; i < rank; ++i)
240     results.push_back(d.stride(builder, loc, i));
241 }
242 
243 /// Returns the number of non-aggregate values that would be produced by
244 /// `unpack`.
getNumUnpackedValues(MemRefType type)245 unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
246   // Two pointers, offset, <rank> sizes, <rank> shapes.
247   return 3 + 2 * type.getRank();
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // MemRefDescriptorView implementation.
252 //===----------------------------------------------------------------------===//
253 
MemRefDescriptorView(ValueRange range)254 MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
255     : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}
256 
allocatedPtr()257 Value MemRefDescriptorView::allocatedPtr() {
258   return elements[kAllocatedPtrPosInMemRefDescriptor];
259 }
260 
alignedPtr()261 Value MemRefDescriptorView::alignedPtr() {
262   return elements[kAlignedPtrPosInMemRefDescriptor];
263 }
264 
offset()265 Value MemRefDescriptorView::offset() {
266   return elements[kOffsetPosInMemRefDescriptor];
267 }
268 
size(unsigned pos)269 Value MemRefDescriptorView::size(unsigned pos) {
270   return elements[kSizePosInMemRefDescriptor + pos];
271 }
272 
stride(unsigned pos)273 Value MemRefDescriptorView::stride(unsigned pos) {
274   return elements[kSizePosInMemRefDescriptor + rank + pos];
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // UnrankedMemRefDescriptor implementation
279 //===----------------------------------------------------------------------===//
280 
281 /// Construct a helper for the given descriptor value.
UnrankedMemRefDescriptor(Value descriptor)282 UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
283     : StructBuilder(descriptor) {}
284 
285 /// Builds IR creating an `undef` value of the descriptor type.
undef(OpBuilder & builder,Location loc,Type descriptorType)286 UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
287                                                          Location loc,
288                                                          Type descriptorType) {
289   Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
290   return UnrankedMemRefDescriptor(descriptor);
291 }
rank(OpBuilder & builder,Location loc)292 Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
293   return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
294 }
setRank(OpBuilder & builder,Location loc,Value v)295 void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
296                                        Value v) {
297   setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
298 }
memRefDescPtr(OpBuilder & builder,Location loc)299 Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
300                                               Location loc) {
301   return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
302 }
setMemRefDescPtr(OpBuilder & builder,Location loc,Value v)303 void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
304                                                 Location loc, Value v) {
305   setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
306 }
307 
308 /// Builds IR populating an unranked MemRef descriptor structure from a list
309 /// of individual constituent values in the following order:
310 /// - rank of the memref;
311 /// - pointer to the memref descriptor.
pack(OpBuilder & builder,Location loc,LLVMTypeConverter & converter,UnrankedMemRefType type,ValueRange values)312 Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
313                                      LLVMTypeConverter &converter,
314                                      UnrankedMemRefType type,
315                                      ValueRange values) {
316   Type llvmType = converter.convertType(type);
317   auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
318 
319   d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
320   d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
321   return d;
322 }
323 
324 /// Builds IR extracting individual elements that compose an unranked memref
325 /// descriptor and returns them as `results` list.
unpack(OpBuilder & builder,Location loc,Value packed,SmallVectorImpl<Value> & results)326 void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
327                                       Value packed,
328                                       SmallVectorImpl<Value> &results) {
329   UnrankedMemRefDescriptor d(packed);
330   results.reserve(results.size() + 2);
331   results.push_back(d.rank(builder, loc));
332   results.push_back(d.memRefDescPtr(builder, loc));
333 }
334 
computeSizes(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,ArrayRef<UnrankedMemRefDescriptor> values,SmallVectorImpl<Value> & sizes)335 void UnrankedMemRefDescriptor::computeSizes(
336     OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
337     ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
338   if (values.empty())
339     return;
340 
341   // Cache the index type.
342   Type indexType = typeConverter.getIndexType();
343 
344   // Initialize shared constants.
345   Value one = createIndexAttrConstant(builder, loc, indexType, 1);
346   Value two = createIndexAttrConstant(builder, loc, indexType, 2);
347   Value pointerSize = createIndexAttrConstant(
348       builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
349   Value indexSize =
350       createIndexAttrConstant(builder, loc, indexType,
351                               ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));
352 
353   sizes.reserve(sizes.size() + values.size());
354   for (UnrankedMemRefDescriptor desc : values) {
355     // Emit IR computing the memory necessary to store the descriptor. This
356     // assumes the descriptor to be
357     //   { type*, type*, index, index[rank], index[rank] }
358     // and densely packed, so the total size is
359     //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
360     // TODO: consider including the actual size (including eventual padding due
361     // to data layout) into the unranked descriptor.
362     Value doublePointerSize =
363         builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
364 
365     // (1 + 2 * rank) * sizeof(index)
366     Value rank = desc.rank(builder, loc);
367     Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank);
368     Value doubleRankIncremented =
369         builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one);
370     Value rankIndexSize = builder.create<LLVM::MulOp>(
371         loc, indexType, doubleRankIncremented, indexSize);
372 
373     // Total allocation size.
374     Value allocationSize = builder.create<LLVM::AddOp>(
375         loc, indexType, doublePointerSize, rankIndexSize);
376     sizes.push_back(allocationSize);
377   }
378 }
379 
allocatedPtr(OpBuilder & builder,Location loc,Value memRefDescPtr,Type elemPtrPtrType)380 Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
381                                              Value memRefDescPtr,
382                                              Type elemPtrPtrType) {
383 
384   Value elementPtrPtr =
385       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
386   return builder.create<LLVM::LoadOp>(loc, elementPtrPtr);
387 }
388 
setAllocatedPtr(OpBuilder & builder,Location loc,Value memRefDescPtr,Type elemPtrPtrType,Value allocatedPtr)389 void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
390                                                Value memRefDescPtr,
391                                                Type elemPtrPtrType,
392                                                Value allocatedPtr) {
393   Value elementPtrPtr =
394       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
395   builder.create<LLVM::StoreOp>(loc, allocatedPtr, elementPtrPtr);
396 }
397 
alignedPtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType)398 Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
399                                            LLVMTypeConverter &typeConverter,
400                                            Value memRefDescPtr,
401                                            Type elemPtrPtrType) {
402   Value elementPtrPtr =
403       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
404 
405   Value one =
406       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
407   Value alignedGep = builder.create<LLVM::GEPOp>(
408       loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
409   return builder.create<LLVM::LoadOp>(loc, alignedGep);
410 }
411 
setAlignedPtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType,Value alignedPtr)412 void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
413                                              LLVMTypeConverter &typeConverter,
414                                              Value memRefDescPtr,
415                                              Type elemPtrPtrType,
416                                              Value alignedPtr) {
417   Value elementPtrPtr =
418       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
419 
420   Value one =
421       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
422   Value alignedGep = builder.create<LLVM::GEPOp>(
423       loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
424   builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
425 }
426 
offset(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType)427 Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
428                                        LLVMTypeConverter &typeConverter,
429                                        Value memRefDescPtr,
430                                        Type elemPtrPtrType) {
431   Value elementPtrPtr =
432       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
433 
434   Value two =
435       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
436   Value offsetGep = builder.create<LLVM::GEPOp>(
437       loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
438   offsetGep = builder.create<LLVM::BitcastOp>(
439       loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
440   return builder.create<LLVM::LoadOp>(loc, offsetGep);
441 }
442 
setOffset(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType,Value offset)443 void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
444                                          LLVMTypeConverter &typeConverter,
445                                          Value memRefDescPtr,
446                                          Type elemPtrPtrType, Value offset) {
447   Value elementPtrPtr =
448       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
449 
450   Value two =
451       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
452   Value offsetGep = builder.create<LLVM::GEPOp>(
453       loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
454   offsetGep = builder.create<LLVM::BitcastOp>(
455       loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
456   builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
457 }
458 
sizeBasePtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,LLVM::LLVMPointerType elemPtrPtrType)459 Value UnrankedMemRefDescriptor::sizeBasePtr(
460     OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
461     Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) {
462   Type elemPtrTy = elemPtrPtrType.getElementType();
463   Type indexTy = typeConverter.getIndexType();
464   Type structPtrTy =
465       LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral(
466           indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy}));
467   Value structPtr =
468       builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
469 
470   Type int32Type = typeConverter.convertType(builder.getI32Type());
471   Value zero =
472       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
473   Value three = builder.create<LLVM::ConstantOp>(loc, int32Type,
474                                                  builder.getI32IntegerAttr(3));
475   return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy),
476                                      structPtr, ValueRange({zero, three}));
477 }
478 
size(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value sizeBasePtr,Value index)479 Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
480                                      LLVMTypeConverter &typeConverter,
481                                      Value sizeBasePtr, Value index) {
482   Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
483   Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
484                                                    ValueRange({index}));
485   return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
486 }
487 
setSize(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value sizeBasePtr,Value index,Value size)488 void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
489                                        LLVMTypeConverter &typeConverter,
490                                        Value sizeBasePtr, Value index,
491                                        Value size) {
492   Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
493   Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
494                                                    ValueRange({index}));
495   builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
496 }
497 
strideBasePtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value sizeBasePtr,Value rank)498 Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
499                                               LLVMTypeConverter &typeConverter,
500                                               Value sizeBasePtr, Value rank) {
501   Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
502   return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
503                                      ValueRange({rank}));
504 }
505 
stride(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value strideBasePtr,Value index,Value stride)506 Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
507                                        LLVMTypeConverter &typeConverter,
508                                        Value strideBasePtr, Value index,
509                                        Value stride) {
510   Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
511   Value strideStoreGep = builder.create<LLVM::GEPOp>(
512       loc, indexPtrTy, strideBasePtr, ValueRange({index}));
513   return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
514 }
515 
setStride(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value strideBasePtr,Value index,Value stride)516 void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
517                                          LLVMTypeConverter &typeConverter,
518                                          Value strideBasePtr, Value index,
519                                          Value stride) {
520   Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
521   Value strideStoreGep = builder.create<LLVM::GEPOp>(
522       loc, indexPtrTy, strideBasePtr, ValueRange({index}));
523   builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
524 }
525