1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
13 #include "mlir/Dialect/Utils/StaticValueUtils.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Interfaces/ViewLikeInterface.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25
26 using namespace mlir;
27 using namespace mlir::memref;
28
29 namespace {
30 /// Idiomatic saturated operations on offsets, sizes and strides.
31 namespace saturated_arith {
32 struct Wrapper {
  static Wrapper stride(int64_t v) {
34 return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
35 : Wrapper{false, v};
36 }
  static Wrapper offset(int64_t v) {
38 return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
39 : Wrapper{false, v};
40 }
  static Wrapper size(int64_t v) {
42 return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
43 }
  int64_t asOffset() {
45 return saturated ? ShapedType::kDynamicStrideOrOffset : v;
46 }
  int64_t asSize() { return saturated ? ShapedType::kDynamicSize : v; }
  int64_t asStride() {
49 return saturated ? ShapedType::kDynamicStrideOrOffset : v;
50 }
  bool operator==(Wrapper other) {
52 return (saturated && other.saturated) ||
53 (!saturated && !other.saturated && v == other.v);
54 }
  bool operator!=(Wrapper other) { return !(*this == other); }
  Wrapper operator+(Wrapper other) {
57 if (saturated || other.saturated)
58 return Wrapper{true, 0};
59 return Wrapper{false, other.v + v};
60 }
  Wrapper operator*(Wrapper other) {
62 if (saturated || other.saturated)
63 return Wrapper{true, 0};
64 return Wrapper{false, other.v * v};
65 }
66 bool saturated;
67 int64_t v;
68 };
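// Illustrative sketch of the saturation semantics (hypothetical values):
//   Wrapper::offset(2) + Wrapper::offset(3) folds to a static offset of 5,
//   while Wrapper::stride(4) * Wrapper::size(ShapedType::kDynamicSize)
//   saturates, so asStride() returns ShapedType::kDynamicStrideOrOffset.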
69 } // namespace saturated_arith
70 } // namespace
71
72 /// Materialize a single constant operation from a given attribute value with
73 /// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
75 Attribute value, Type type,
76 Location loc) {
77 if (arith::ConstantOp::isBuildableWith(value, type))
78 return builder.create<arith::ConstantOp>(loc, value, type);
79 return nullptr;
80 }
81
82 //===----------------------------------------------------------------------===//
83 // Common canonicalization pattern support logic
84 //===----------------------------------------------------------------------===//
85
86 /// This is a common class used for patterns of the form
87 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
88 /// into the root operation directly.
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
90 bool folded = false;
91 for (OpOperand &operand : op->getOpOperands()) {
92 auto cast = operand.get().getDefiningOp<CastOp>();
93 if (cast && operand.get() != inner &&
94 !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
95 operand.set(cast.getOperand());
96 folded = true;
97 }
98 }
99 return success(folded);
100 }
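// Example (illustrative): via this helper, the fold of memref.dealloc rewrites
//   %0 = memref.cast %arg0 : memref<8xf32> to memref<?xf32>
//   memref.dealloc %0 : memref<?xf32>
// into
//   memref.dealloc %arg0 : memref<8xf32>
// i.e. the ranked cast is folded directly into the consuming operation.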
101
102 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
103 /// type.
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
105 if (auto memref = type.dyn_cast<MemRefType>())
106 return RankedTensorType::get(memref.getShape(), memref.getElementType());
107 if (auto memref = type.dyn_cast<UnrankedMemRefType>())
108 return UnrankedTensorType::get(memref.getElementType());
109 return NoneType::get(type.getContext());
110 }
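// Example (illustrative): memref<4x?xf32> maps to tensor<4x?xf32> and
// memref<*xf32> maps to tensor<*xf32>; any non-memref type yields NoneType.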
111
112 //===----------------------------------------------------------------------===//
113 // AllocOp / AllocaOp
114 //===----------------------------------------------------------------------===//
115
116 template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
118 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
119 "applies to only alloc or alloca");
120 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
121 if (!memRefType)
122 return op.emitOpError("result must be a memref");
123
124 if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
125 memRefType.getNumDynamicDims())
126 return op.emitOpError("dimension operand count does not equal memref "
127 "dynamic dimension count");
128
129 unsigned numSymbols = 0;
130 if (!memRefType.getLayout().isIdentity())
131 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
132 if (op.getSymbolOperands().size() != numSymbols)
133 return op.emitOpError("symbol operand count does not equal memref symbol "
134 "count: expected ")
135 << numSymbols << ", got " << op.getSymbolOperands().size();
136
137 return success();
138 }
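// Example (illustrative): `%0 = memref.alloc(%d) : memref<4x?xf32>` passes the
// checks above (one dynamic dimension, one size operand), whereas
// `memref.alloc() : memref<?xf32>` is rejected for a mismatched dynamic
// dimension count.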
139
LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
141
LogicalResult AllocaOp::verify() {
143 // An alloca op needs to have an ancestor with an allocation scope trait.
144 if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
145 return emitOpError(
146 "requires an ancestor op with AutomaticAllocationScope trait");
147
148 return verifyAllocLikeOp(*this);
149 }
150
151 namespace {
152 /// Fold constant dimensions into an alloc like operation.
153 template <typename AllocLikeOp>
154 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
155 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
156
  LogicalResult matchAndRewrite(AllocLikeOp alloc,
158 PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
161 if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
162 return matchPattern(operand, matchConstantIndex());
163 }))
164 return failure();
165
166 auto memrefType = alloc.getType();
167
168 // Ok, we have one or more constant operands. Collect the non-constant ones
169 // and keep track of the resultant memref type to build.
170 SmallVector<int64_t, 4> newShapeConstants;
171 newShapeConstants.reserve(memrefType.getRank());
172 SmallVector<Value, 4> dynamicSizes;
173
174 unsigned dynamicDimPos = 0;
175 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
176 int64_t dimSize = memrefType.getDimSize(dim);
177 // If this is already static dimension, keep it.
178 if (dimSize != -1) {
179 newShapeConstants.push_back(dimSize);
180 continue;
181 }
182 auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
183 auto *defOp = dynamicSize.getDefiningOp();
184 if (auto constantIndexOp =
185 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
186 // Dynamic shape dimension will be folded.
187 newShapeConstants.push_back(constantIndexOp.value());
188 } else {
189 // Dynamic shape dimension not folded; copy dynamicSize from old memref.
190 newShapeConstants.push_back(-1);
191 dynamicSizes.push_back(dynamicSize);
192 }
193 dynamicDimPos++;
194 }
195
196 // Create new memref type (which will have fewer dynamic dimensions).
197 MemRefType newMemRefType =
198 MemRefType::Builder(memrefType).setShape(newShapeConstants);
199 assert(static_cast<int64_t>(dynamicSizes.size()) ==
200 newMemRefType.getNumDynamicDims());
201
202 // Create and insert the alloc op for the new memref.
203 auto newAlloc = rewriter.create<AllocLikeOp>(
204 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
205 alloc.getAlignmentAttr());
206 // Insert a cast so we have the same type as the old alloc.
207 auto resultCast =
208 rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
209
210 rewriter.replaceOp(alloc, {resultCast});
211 return success();
212 }
213 };
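// Example (illustrative) of the SimplifyAllocConst rewrite:
//   %c16 = arith.constant 16 : index
//   %0 = memref.alloc(%c16, %n) : memref<?x?xf32>
// becomes
//   %1 = memref.alloc(%n) : memref<16x?xf32>
//   %0 = memref.cast %1 : memref<16x?xf32> to memref<?x?xf32>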
214
215 /// Fold alloc operations with no users or only store and dealloc uses.
216 template <typename T>
217 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
218 using OpRewritePattern<T>::OpRewritePattern;
219
  LogicalResult matchAndRewrite(T alloc,
221 PatternRewriter &rewriter) const override {
222 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
223 if (auto storeOp = dyn_cast<StoreOp>(op))
224 return storeOp.getValue() == alloc;
225 return !isa<DeallocOp>(op);
226 }))
227 return failure();
228
229 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
230 rewriter.eraseOp(user);
231
232 rewriter.eraseOp(alloc);
233 return success();
234 }
235 };
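// Example (illustrative) of SimplifyDeadAlloc: an allocation whose only uses
// are stores into it and its dealloc, e.g.
//   %0 = memref.alloc() : memref<4xf32>
//   memref.store %v, %0[%i] : memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>
// is erased together with those uses, since its contents are never read.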
236 } // namespace
237
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
239 MLIRContext *context) {
240 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
241 }
242
void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
244 MLIRContext *context) {
245 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
246 context);
247 }
248
249 //===----------------------------------------------------------------------===//
250 // AllocaScopeOp
251 //===----------------------------------------------------------------------===//
252
void AllocaScopeOp::print(OpAsmPrinter &p) {
254 bool printBlockTerminators = false;
255
256 p << ' ';
257 if (!getResults().empty()) {
258 p << " -> (" << getResultTypes() << ")";
259 printBlockTerminators = true;
260 }
261 p << ' ';
262 p.printRegion(getBodyRegion(),
263 /*printEntryBlockArgs=*/false,
264 /*printBlockTerminators=*/printBlockTerminators);
265 p.printOptionalAttrDict((*this)->getAttrs());
266 }
267
ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
269 // Create a region for the body.
270 result.regions.reserve(1);
271 Region *bodyRegion = result.addRegion();
272
273 // Parse optional results type list.
274 if (parser.parseOptionalArrowTypeList(result.types))
275 return failure();
276
277 // Parse the body region.
278 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
279 return failure();
280 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
281 result.location);
282
283 // Parse the optional attribute list.
284 if (parser.parseOptionalAttrDict(result.attributes))
285 return failure();
286
287 return success();
288 }
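// Example (illustrative) of the assembly handled by print/parse above:
//   memref.alloca_scope {
//     %a = memref.alloca() : memref<16xf32>
//     ...
//   }
// and, when the scope returns values, `memref.alloca_scope -> (f32) { ... }`
// with the terminator printed explicitly.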
289
void AllocaScopeOp::getSuccessorRegions(
291 Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
293 if (index) {
294 regions.push_back(RegionSuccessor(getResults()));
295 return;
296 }
297
298 regions.push_back(RegionSuccessor(&getBodyRegion()));
299 }
300
301 /// Given an operation, return whether this op is guaranteed to
302 /// allocate an AutomaticAllocationScopeResource
static bool isGuaranteedAutomaticAllocation(Operation *op) {
304 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
305 if (!interface)
306 return false;
307 for (auto res : op->getResults()) {
308 if (auto effect =
309 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
310 if (isa<SideEffects::AutomaticAllocationScopeResource>(
311 effect->getResource()))
312 return true;
313 }
314 }
315 return false;
316 }
317
318 /// Given an operation, return whether this op itself could
319 /// allocate an AutomaticAllocationScopeResource. Note that
320 /// this will not check whether an operation contained within
321 /// the op can allocate.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
323 // This op itself doesn't create a stack allocation,
324 // the inner allocation should be handled separately.
325 if (op->hasTrait<OpTrait::HasRecursiveSideEffects>())
326 return false;
327 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
328 if (!interface)
329 return true;
330 for (auto res : op->getResults()) {
331 if (auto effect =
332 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
333 if (isa<SideEffects::AutomaticAllocationScopeResource>(
334 effect->getResource()))
335 return true;
336 }
337 }
338 return false;
339 }
340
/// Return whether this op is the last non-terminator op in a region, i.e. it
/// is in a single-block region and is only followed by the terminator. This
/// prevents extending the lifetime of allocations.
static bool lastNonTerminatorInRegion(Operation *op) {
346 return op->getNextNode() == op->getBlock()->getTerminator() &&
347 op->getParentRegion()->getBlocks().size() == 1;
348 }
349
350 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
351 /// or it contains no allocation.
352 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
353 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
354
  LogicalResult matchAndRewrite(AllocaScopeOp op,
356 PatternRewriter &rewriter) const override {
357 bool hasPotentialAlloca =
358 op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
359 if (alloc == op)
360 return WalkResult::advance();
361 if (isOpItselfPotentialAutomaticAllocation(alloc))
362 return WalkResult::interrupt();
363 if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
364 return WalkResult::skip();
365 return WalkResult::advance();
366 }).wasInterrupted();
367
368 // If this contains no potential allocation, it is always legal to
369 // inline. Otherwise, consider two conditions:
370 if (hasPotentialAlloca) {
371 // If the parent isn't an allocation scope, or we are not the last
372 // non-terminator op in the parent, we will extend the lifetime.
373 if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
374 return failure();
375 if (!lastNonTerminatorInRegion(op))
376 return failure();
377 }
378
379 Block *block = &op.getRegion().front();
380 Operation *terminator = block->getTerminator();
381 ValueRange results = terminator->getOperands();
382 rewriter.mergeBlockBefore(block, op);
383 rewriter.replaceOp(op, results);
384 rewriter.eraseOp(terminator);
385 return success();
386 }
387 };
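// Example (illustrative) of AllocaScopeInliner: a scope containing no
// potential allocations, e.g.
//   memref.alloca_scope {
//     %0 = arith.addf %a, %b : f32
//     memref.store %0, %m[%i] : memref<4xf32>
//   }
// is always inlined into the parent block, dropping the scope and its
// terminator.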
388
389 /// Move allocations into an allocation scope, if it is legal to
390 /// move them (e.g. their operands are available at the location
391 /// the op would be moved to).
392 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
393 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
394
  LogicalResult matchAndRewrite(AllocaScopeOp op,
396 PatternRewriter &rewriter) const override {
397
398 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
399 return failure();
400
401 Operation *lastParentWithoutScope = op->getParentOp();
402
403 if (!lastParentWithoutScope ||
404 lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
405 return failure();
406
    // Only apply if this is the last non-terminator op in its block (lest
    // allocation lifetimes be extended) and that block is the sole block of
    // its region.
410 if (!lastNonTerminatorInRegion(op) ||
411 !lastNonTerminatorInRegion(lastParentWithoutScope))
412 return failure();
413
414 while (!lastParentWithoutScope->getParentOp()
415 ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
416 lastParentWithoutScope = lastParentWithoutScope->getParentOp();
417 if (!lastParentWithoutScope ||
418 !lastNonTerminatorInRegion(lastParentWithoutScope))
419 return failure();
420 }
421 assert(lastParentWithoutScope->getParentOp()
422 ->hasTrait<OpTrait::AutomaticAllocationScope>());
423
424 Region *containingRegion = nullptr;
425 for (auto &r : lastParentWithoutScope->getRegions()) {
426 if (r.isAncestor(op->getParentRegion())) {
427 assert(containingRegion == nullptr &&
428 "only one region can contain the op");
429 containingRegion = &r;
430 }
431 }
432 assert(containingRegion && "op must be contained in a region");
433
434 SmallVector<Operation *> toHoist;
435 op->walk([&](Operation *alloc) {
436 if (!isGuaranteedAutomaticAllocation(alloc))
437 return WalkResult::skip();
438
439 // If any operand is not defined before the location of
440 // lastParentWithoutScope (i.e. where we would hoist to), skip.
441 if (llvm::any_of(alloc->getOperands(), [&](Value v) {
442 return containingRegion->isAncestor(v.getParentRegion());
443 }))
444 return WalkResult::skip();
445 toHoist.push_back(alloc);
446 return WalkResult::advance();
447 });
448
449 if (toHoist.empty())
450 return failure();
451 rewriter.setInsertionPoint(lastParentWithoutScope);
452 for (auto *op : toHoist) {
453 auto *cloned = rewriter.clone(*op);
454 rewriter.replaceOp(op, cloned->getResults());
455 }
456 return success();
457 }
458 };
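// Example (illustrative) of AllocaScopeHoister: for a constant-sized alloca in
// a scope that sits inside a loop, e.g.
//   scf.for %i = %lb to %ub step %c1 {
//     memref.alloca_scope {
//       %a = memref.alloca() : memref<16xf32>
//       ...
//     }
//   }
// the memref.alloca can be hoisted in front of the loop, into the enclosing
// function's AutomaticAllocationScope, assuming the scope and the loop are the
// last non-terminator ops in their blocks.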
459
void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
461 MLIRContext *context) {
462 results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
463 }
464
465 //===----------------------------------------------------------------------===//
466 // AssumeAlignmentOp
467 //===----------------------------------------------------------------------===//
468
LogicalResult AssumeAlignmentOp::verify() {
470 if (!llvm::isPowerOf2_32(getAlignment()))
471 return emitOpError("alignment must be power of 2");
472 return success();
473 }
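// Example (illustrative): `memref.assume_alignment %0, 16 : memref<8x64xf32>`
// verifies, while an alignment of 12 is rejected as not a power of 2.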
474
475 //===----------------------------------------------------------------------===//
476 // CastOp
477 //===----------------------------------------------------------------------===//
478
479 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
481 /// and implement canonicalization patterns for ops in different dialects that
482 /// may consume the results of memref.cast operations. Such foldable memref.cast
483 /// operations are typically inserted as `view` and `subview` ops are
484 /// canonicalized, to preserve the type compatibility of their uses.
485 ///
486 /// Returns true when all conditions are met:
487 /// 1. source and result are ranked memrefs with strided semantics and same
488 /// element type and rank.
489 /// 2. each of the source's size, offset or stride has more static information
490 /// than the corresponding result's size, offset or stride.
491 ///
492 /// Example 1:
493 /// ```mlir
494 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
495 /// %2 = consumer %1 ... : memref<?x?xf32> ...
496 /// ```
497 ///
498 /// may fold into:
499 ///
500 /// ```mlir
501 /// %2 = consumer %0 ... : memref<8x16xf32> ...
502 /// ```
503 ///
504 /// Example 2:
505 /// ```
506 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
507 /// to memref<?x?xf32>
508 /// consumer %1 : memref<?x?xf32> ...
509 /// ```
510 ///
511 /// may fold into:
512 ///
513 /// ```
514 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
515 /// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
517 MemRefType sourceType = castOp.getSource().getType().dyn_cast<MemRefType>();
518 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
519
520 // Requires ranked MemRefType.
521 if (!sourceType || !resultType)
522 return false;
523
524 // Requires same elemental type.
525 if (sourceType.getElementType() != resultType.getElementType())
526 return false;
527
528 // Requires same rank.
529 if (sourceType.getRank() != resultType.getRank())
530 return false;
531
532 // Only fold casts between strided memref forms.
533 int64_t sourceOffset, resultOffset;
534 SmallVector<int64_t, 4> sourceStrides, resultStrides;
535 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
536 failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
537 return false;
538
539 // If cast is towards more static sizes along any dimension, don't fold.
540 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
541 auto ss = std::get<0>(it), st = std::get<1>(it);
542 if (ss != st)
543 if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
544 return false;
545 }
546
547 // If cast is towards more static offset along any dimension, don't fold.
548 if (sourceOffset != resultOffset)
549 if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
550 !ShapedType::isDynamicStrideOrOffset(resultOffset))
551 return false;
552
553 // If cast is towards more static strides along any dimension, don't fold.
554 for (auto it : llvm::zip(sourceStrides, resultStrides)) {
555 auto ss = std::get<0>(it), st = std::get<1>(it);
556 if (ss != st)
557 if (ShapedType::isDynamicStrideOrOffset(ss) &&
558 !ShapedType::isDynamicStrideOrOffset(st))
559 return false;
560 }
561
562 return true;
563 }
564
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
566 if (inputs.size() != 1 || outputs.size() != 1)
567 return false;
568 Type a = inputs.front(), b = outputs.front();
569 auto aT = a.dyn_cast<MemRefType>();
570 auto bT = b.dyn_cast<MemRefType>();
571
572 auto uaT = a.dyn_cast<UnrankedMemRefType>();
573 auto ubT = b.dyn_cast<UnrankedMemRefType>();
574
575 if (aT && bT) {
576 if (aT.getElementType() != bT.getElementType())
577 return false;
578 if (aT.getLayout() != bT.getLayout()) {
579 int64_t aOffset, bOffset;
580 SmallVector<int64_t, 4> aStrides, bStrides;
581 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
582 failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
583 aStrides.size() != bStrides.size())
584 return false;
585
586 // Strides along a dimension/offset are compatible if the value in the
587 // source memref is static and the value in the target memref is the
588 // same. They are also compatible if either one is dynamic (see
589 // description of MemRefCastOp for details).
590 auto checkCompatible = [](int64_t a, int64_t b) {
591 return (a == MemRefType::getDynamicStrideOrOffset() ||
592 b == MemRefType::getDynamicStrideOrOffset() || a == b);
593 };
594 if (!checkCompatible(aOffset, bOffset))
595 return false;
596 for (const auto &aStride : enumerate(aStrides))
597 if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
598 return false;
599 }
600 if (aT.getMemorySpace() != bT.getMemorySpace())
601 return false;
602
603 // They must have the same rank, and any specified dimensions must match.
604 if (aT.getRank() != bT.getRank())
605 return false;
606
607 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
608 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
609 if (aDim != -1 && bDim != -1 && aDim != bDim)
610 return false;
611 }
612 return true;
613 } else {
614 if (!aT && !uaT)
615 return false;
616 if (!bT && !ubT)
617 return false;
618 // Unranked to unranked casting is unsupported
619 if (uaT && ubT)
620 return false;
621
622 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
623 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
624 if (aEltType != bEltType)
625 return false;
626
627 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
628 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
629 return aMemSpace == bMemSpace;
630 }
631
632 return false;
633 }
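// Examples (illustrative): memref<?x16xf32> and memref<4x?xf32> are
// cast-compatible (each dimension is static in at most one of the types);
// memref<4xf32> and memref<8xf32> are not; and two unranked memrefs are
// rejected above since unranked-to-unranked casts are unsupported.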
634
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
636 return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
637 }
638
639 //===----------------------------------------------------------------------===//
640 // CopyOp
641 //===----------------------------------------------------------------------===//
642
643 namespace {
644 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
645 /// and element type, the cast can be skipped. Such CastOps only cast the layout
646 /// of the type.
647 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
648 using OpRewritePattern<CopyOp>::OpRewritePattern;
649
  LogicalResult matchAndRewrite(CopyOp copyOp,
651 PatternRewriter &rewriter) const override {
652 bool modified = false;
653
654 // Check source.
655 if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
      auto fromType = castOp.getSource().getType().dyn_cast<MemRefType>();
      auto toType = copyOp.getSource().getType().dyn_cast<MemRefType>();
658
659 if (fromType && toType) {
660 if (fromType.getShape() == toType.getShape() &&
661 fromType.getElementType() == toType.getElementType()) {
662 rewriter.updateRootInPlace(copyOp, [&] {
663 copyOp.getSourceMutable().assign(castOp.getSource());
664 });
665 modified = true;
666 }
667 }
668 }
669
670 // Check target.
671 if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
      auto fromType = castOp.getSource().getType().dyn_cast<MemRefType>();
      auto toType = copyOp.getTarget().getType().dyn_cast<MemRefType>();
674
675 if (fromType && toType) {
676 if (fromType.getShape() == toType.getShape() &&
677 fromType.getElementType() == toType.getElementType()) {
678 rewriter.updateRootInPlace(copyOp, [&] {
679 copyOp.getTargetMutable().assign(castOp.getSource());
680 });
681 modified = true;
682 }
683 }
684 }
685
686 return success(modified);
687 }
688 };
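// Example (illustrative) of FoldCopyOfCast, with
// #map = affine_map<(i, j) -> (16 * i + j)>:
//   %0 = memref.cast %src : memref<8x16xf32, #map> to memref<8x16xf32>
//   memref.copy %0, %dst : memref<8x16xf32> to memref<8x16xf32>
// is rewritten to copy directly from %src; casts that change the shape or the
// element type are left untouched by this pattern.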
689
690 /// Fold memref.copy(%x, %x).
691 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
692 using OpRewritePattern<CopyOp>::OpRewritePattern;
693
  LogicalResult matchAndRewrite(CopyOp copyOp,
695 PatternRewriter &rewriter) const override {
696 if (copyOp.getSource() != copyOp.getTarget())
697 return failure();
698
699 rewriter.eraseOp(copyOp);
700 return success();
701 }
702 };
703 } // namespace
704
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
706 MLIRContext *context) {
707 results.add<FoldCopyOfCast, FoldSelfCopy>(context);
708 }
709
LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
711 SmallVectorImpl<OpFoldResult> &results) {
712 /// copy(memrefcast) -> copy
713 bool folded = false;
714 Operation *op = *this;
715 for (OpOperand &operand : op->getOpOperands()) {
716 auto castOp = operand.get().getDefiningOp<memref::CastOp>();
717 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
718 operand.set(castOp.getOperand());
719 folded = true;
720 }
721 }
722 return success(folded);
723 }
724
725 //===----------------------------------------------------------------------===//
726 // DeallocOp
727 //===----------------------------------------------------------------------===//
728
LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
730 SmallVectorImpl<OpFoldResult> &results) {
731 /// dealloc(memrefcast) -> dealloc
732 return foldMemRefCast(*this);
733 }
734
735 //===----------------------------------------------------------------------===//
736 // DimOp
737 //===----------------------------------------------------------------------===//
738
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
740 int64_t index) {
741 auto loc = result.location;
742 Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
743 build(builder, result, source, indexValue);
744 }
745
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
747 Value index) {
748 auto indexTy = builder.getIndexType();
749 build(builder, result, indexTy, source, index);
750 }
751
Optional<int64_t> DimOp::getConstantIndex() {
753 if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
754 return constantOp.getValue().cast<IntegerAttr>().getInt();
755 return {};
756 }
757
LogicalResult DimOp::verify() {
759 // Assume unknown index to be in range.
760 Optional<int64_t> index = getConstantIndex();
761 if (!index)
762 return success();
763
764 // Check that constant index is not knowingly out of range.
765 auto type = getSource().getType();
766 if (auto memrefType = type.dyn_cast<MemRefType>()) {
767 if (*index >= memrefType.getRank())
768 return emitOpError("index is out of range");
769 } else if (type.isa<UnrankedMemRefType>()) {
770 // Assume index to be in range.
771 } else {
772 llvm_unreachable("expected operand with memref type");
773 }
774 return success();
775 }
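// Example (illustrative): `memref.dim %m, %c3 : memref<2x4xf32>` is rejected
// with "index is out of range", while any constant index is conservatively
// accepted for an unranked memref<*xf32> operand.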
776
/// Return a map whose keys are the elements of `vals` and whose values are
/// the number of occurrences of each element. Use std::map, since the `vals`
/// here are strides and the dynamic stride value is the same as the tombstone
/// value for `DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
782 std::map<int64_t, unsigned> numOccurences;
783 for (auto val : vals)
784 numOccurences[val]++;
785 return numOccurences;
786 }
787
/// Given `originalType` and a `reducedType` whose shape is assumed to be a
/// subset of `originalType` with some `1` entries erased, return the set of
/// indices that specifies which entries of the original shape are dropped to
/// obtain the reduced shape.
792 /// This accounts for cases where there are multiple unit-dims, but only a
793 /// subset of those are dropped. For MemRefTypes these can be disambiguated
794 /// using the strides. If a dimension is dropped the stride must be dropped too.
795 static llvm::Optional<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
797 ArrayRef<OpFoldResult> sizes) {
798 llvm::SmallBitVector unusedDims(originalType.getRank());
799 if (originalType.getRank() == reducedType.getRank())
800 return unusedDims;
801
802 for (const auto &dim : llvm::enumerate(sizes))
803 if (auto attr = dim.value().dyn_cast<Attribute>())
804 if (attr.cast<IntegerAttr>().getInt() == 1)
805 unusedDims.set(dim.index());
806
  // Early exit for the case where the number of unused dims matches the rank
  // reduction.
809 if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
810 originalType.getRank())
811 return unusedDims;
812
813 SmallVector<int64_t> originalStrides, candidateStrides;
814 int64_t originalOffset, candidateOffset;
815 if (failed(
816 getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
817 failed(
818 getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
819 return llvm::None;
820
  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of each stride in the original
  // type and in the candidate type: for each unused dim, that stride should
  // not be present in the candidate type. Note that several dimensions may
  // have the same size; we do not need to figure out exactly which dim
  // corresponds to which stride, only that the number of repetitions of a
  // stride in the original type plus the number of unused dims with that
  // stride equals the number of repetitions of that stride in the candidate.
830 std::map<int64_t, unsigned> currUnaccountedStrides =
831 getNumOccurences(originalStrides);
832 std::map<int64_t, unsigned> candidateStridesNumOccurences =
833 getNumOccurences(candidateStrides);
834 for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
835 if (!unusedDims.test(dim))
836 continue;
837 int64_t originalStride = originalStrides[dim];
838 if (currUnaccountedStrides[originalStride] >
839 candidateStridesNumOccurences[originalStride]) {
840 // This dim can be treated as dropped.
841 currUnaccountedStrides[originalStride]--;
842 continue;
843 }
844 if (currUnaccountedStrides[originalStride] ==
845 candidateStridesNumOccurences[originalStride]) {
846 // The stride for this is not dropped. Keep as is.
847 unusedDims.reset(dim);
848 continue;
849 }
850 if (currUnaccountedStrides[originalStride] <
851 candidateStridesNumOccurences[originalStride]) {
      // This should never happen: the reduced-rank type cannot contain a
      // stride that was not present in the original type.
854 return llvm::None;
855 }
856 }
857
858 if ((int64_t)unusedDims.count() + reducedType.getRank() !=
859 originalType.getRank())
860 return llvm::None;
861 return unusedDims;
862 }
863
llvm::SmallBitVector SubViewOp::getDroppedDims() {
865 MemRefType sourceType = getSourceType();
866 MemRefType resultType = getType();
867 llvm::Optional<llvm::SmallBitVector> unusedDims =
868 computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
869 assert(unusedDims && "unable to find unused dims of subview");
870 return *unusedDims;
871 }
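// Example (illustrative): for the rank-reducing subview
//   %0 = memref.subview %src[0, 0, 0] [1, 16, 4] [1, 1, 1]
//          : memref<1x16x4xf32> to memref<16x4xf32>
// getDroppedDims() returns a bit vector with only bit 0 set: the leading unit
// dimension (and its stride) is the one elided by the result type.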
872
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
874 // All forms of folding require a known index.
875 auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
876 if (!index)
877 return {};
878
879 // Folding for unranked types (UnrankedMemRefType) is not supported.
880 auto memrefType = getSource().getType().dyn_cast<MemRefType>();
881 if (!memrefType)
882 return {};
883
884 // Fold if the shape extent along the given index is known.
885 if (!memrefType.isDynamicDim(index.getInt())) {
886 Builder builder(getContext());
887 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
888 }
889
890 // The size at the given index is now known to be a dynamic size.
891 unsigned unsignedIndex = index.getValue().getZExtValue();
892
893 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
894 Operation *definingOp = getSource().getDefiningOp();
895
896 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
897 return *(alloc.getDynamicSizes().begin() +
898 memrefType.getDynamicDimIndex(unsignedIndex));
899
900 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
901 return *(alloca.getDynamicSizes().begin() +
902 memrefType.getDynamicDimIndex(unsignedIndex));
903
904 if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
905 return *(view.getDynamicSizes().begin() +
906 memrefType.getDynamicDimIndex(unsignedIndex));
907
908 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
909 llvm::SmallBitVector unusedDims = subview.getDroppedDims();
910 unsigned resultIndex = 0;
911 unsigned sourceRank = subview.getSourceType().getRank();
912 unsigned sourceIndex = 0;
913 for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
914 if (unusedDims.test(i))
915 continue;
916 if (resultIndex == unsignedIndex) {
917 sourceIndex = i;
918 break;
919 }
920 resultIndex++;
921 }
922 assert(subview.isDynamicSize(sourceIndex) &&
923 "expected dynamic subview size");
924 return subview.getDynamicSize(sourceIndex);
925 }
926
927 if (auto sizeInterface =
928 dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
929 assert(sizeInterface.isDynamicSize(unsignedIndex) &&
930 "Expected dynamic subview size");
931 return sizeInterface.getDynamicSize(unsignedIndex);
932 }
933
934 // dim(memrefcast) -> dim
935 if (succeeded(foldMemRefCast(*this)))
936 return getResult();
937
938 return {};
939 }
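// Examples (illustrative) of the folds above, for
// %m = memref.alloc(%n) : memref<?x8xf32>:
//   `memref.dim %m, %c1` folds to the constant 8 (static extent), and
//   `memref.dim %m, %c0` folds to %n (the matching dynamic size operand).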
940
941 namespace {
942 /// Fold dim of a memref reshape operation to a load into the reshape's shape
943 /// operand.
944 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
945 using OpRewritePattern<DimOp>::OpRewritePattern;
946
  LogicalResult matchAndRewrite(DimOp dim,
948 PatternRewriter &rewriter) const override {
949 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
950
951 if (!reshape)
952 return failure();
953
954 // Place the load directly after the reshape to ensure that the shape memref
955 // was not mutated.
956 rewriter.setInsertionPointAfter(reshape);
957 Location loc = dim.getLoc();
958 Value load =
959 rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
960 if (load.getType() != dim.getType())
961 load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
962 rewriter.replaceOp(dim, load);
963 return success();
964 }
965 };
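// Example (illustrative) of DimOfMemRefReshape:
//   %r = memref.reshape %m(%shape)
//          : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %c1 : memref<?x?xf32>
// becomes a load of the requested extent from the shape operand:
//   %d = memref.load %shape[%c1] : memref<2xindex>
// with an arith.index_cast inserted when the shape element type is not index.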
966
967 } // namespace
968
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
970 MLIRContext *context) {
971 results.add<DimOfMemRefReshape>(context);
972 }
973
974 // ---------------------------------------------------------------------------
975 // DmaStartOp
976 // ---------------------------------------------------------------------------
977
void DmaStartOp::build(OpBuilder &builder, OperationState &result,
979 Value srcMemRef, ValueRange srcIndices, Value destMemRef,
980 ValueRange destIndices, Value numElements,
981 Value tagMemRef, ValueRange tagIndices, Value stride,
982 Value elementsPerStride) {
983 result.addOperands(srcMemRef);
984 result.addOperands(srcIndices);
985 result.addOperands(destMemRef);
986 result.addOperands(destIndices);
987 result.addOperands({numElements, tagMemRef});
988 result.addOperands(tagIndices);
989 if (stride)
990 result.addOperands({stride, elementsPerStride});
991 }
992
void DmaStartOp::print(OpAsmPrinter &p) {
994 p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
995 << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
996 << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
997 if (isStrided())
998 p << ", " << getStride() << ", " << getNumElementsPerStride();
999
1000 p.printOptionalAttrDict((*this)->getAttrs());
1001 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1002 << ", " << getTagMemRef().getType();
1003 }
1004
1005 // Parse DmaStartOp.
1006 // Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
1012 //
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1014 OpAsmParser::UnresolvedOperand srcMemRefInfo;
1015 SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1016 OpAsmParser::UnresolvedOperand dstMemRefInfo;
1017 SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1018 OpAsmParser::UnresolvedOperand numElementsInfo;
1019 OpAsmParser::UnresolvedOperand tagMemrefInfo;
1020 SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1021 SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1022
1023 SmallVector<Type, 3> types;
1024 auto indexType = parser.getBuilder().getIndexType();
1025
1026 // Parse and resolve the following list of operands:
1027 // *) source memref followed by its indices (in square brackets).
1028 // *) destination memref followed by its indices (in square brackets).
  // *) number of elements to transfer.
1030 if (parser.parseOperand(srcMemRefInfo) ||
1031 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1032 parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1033 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1034 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1035 parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1036 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1037 return failure();
1038
1039 // Parse optional stride and elements per stride.
1040 if (parser.parseTrailingOperandList(strideInfo))
1041 return failure();
1042
1043 bool isStrided = strideInfo.size() == 2;
1044 if (!strideInfo.empty() && !isStrided) {
1045 return parser.emitError(parser.getNameLoc(),
1046 "expected two stride related operands");
1047 }
1048
1049 if (parser.parseColonTypeList(types))
1050 return failure();
1051 if (types.size() != 3)
1052 return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1053
1054 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1055 parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1056 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1057 parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1058 // size should be an index.
1059 parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1060 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1061 // tag indices should be index.
1062 parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1063 return failure();
1064
1065 if (isStrided) {
1066 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1067 return failure();
1068 }
1069
1070 return success();
1071 }
1072
LogicalResult DmaStartOp::verify() {
1074 unsigned numOperands = getNumOperands();
1075
1076 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1077 // the number of elements.
1078 if (numOperands < 4)
1079 return emitOpError("expected at least 4 operands");
1080
1081 // Check types of operands. The order of these calls is important: the later
1082 // calls rely on some type properties to compute the operand position.
1083 // 1. Source memref.
1084 if (!getSrcMemRef().getType().isa<MemRefType>())
1085 return emitOpError("expected source to be of memref type");
1086 if (numOperands < getSrcMemRefRank() + 4)
1087 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1088 << " operands";
1089 if (!getSrcIndices().empty() &&
1090 !llvm::all_of(getSrcIndices().getTypes(),
1091 [](Type t) { return t.isIndex(); }))
1092 return emitOpError("expected source indices to be of index type");
1093
1094 // 2. Destination memref.
1095 if (!getDstMemRef().getType().isa<MemRefType>())
1096 return emitOpError("expected destination to be of memref type");
1097 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1098 if (numOperands < numExpectedOperands)
1099 return emitOpError() << "expected at least " << numExpectedOperands
1100 << " operands";
1101 if (!getDstIndices().empty() &&
1102 !llvm::all_of(getDstIndices().getTypes(),
1103 [](Type t) { return t.isIndex(); }))
1104 return emitOpError("expected destination indices to be of index type");
1105
1106 // 3. Number of elements.
1107 if (!getNumElements().getType().isIndex())
1108 return emitOpError("expected num elements to be of index type");
1109
1110 // 4. Tag memref.
1111 if (!getTagMemRef().getType().isa<MemRefType>())
1112 return emitOpError("expected tag to be of memref type");
1113 numExpectedOperands += getTagMemRefRank();
1114 if (numOperands < numExpectedOperands)
1115 return emitOpError() << "expected at least " << numExpectedOperands
1116 << " operands";
1117 if (!getTagIndices().empty() &&
1118 !llvm::all_of(getTagIndices().getTypes(),
1119 [](Type t) { return t.isIndex(); }))
1120 return emitOpError("expected tag indices to be of index type");
1121
1122 // Optional stride-related operands must be either both present or both
1123 // absent.
1124 if (numOperands != numExpectedOperands &&
1125 numOperands != numExpectedOperands + 2)
1126 return emitOpError("incorrect number of operands");
1127
1128 // 5. Strides.
1129 if (isStrided()) {
1130 if (!getStride().getType().isIndex() ||
1131 !getNumElementsPerStride().getType().isIndex())
1132 return emitOpError(
1133 "expected stride and num elements per stride to be of type index");
1134 }
1135
1136 return success();
1137 }
1138
LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1140 SmallVectorImpl<OpFoldResult> &results) {
1141 /// dma_start(memrefcast) -> dma_start
1142 return foldMemRefCast(*this);
1143 }
1144
1145 // ---------------------------------------------------------------------------
1146 // DmaWaitOp
1147 // ---------------------------------------------------------------------------
1148
LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1150 SmallVectorImpl<OpFoldResult> &results) {
1151 /// dma_wait(memrefcast) -> dma_wait
1152 return foldMemRefCast(*this);
1153 }
1154
LogicalResult DmaWaitOp::verify() {
1156 // Check that the number of tag indices matches the tagMemRef rank.
1157 unsigned numTagIndices = getTagIndices().size();
1158 unsigned tagMemRefRank = getTagMemRefRank();
1159 if (numTagIndices != tagMemRefRank)
1160 return emitOpError() << "expected tagIndices to have the same number of "
1161 "elements as the tagMemRef rank, expected "
1162 << tagMemRefRank << ", but got " << numTagIndices;
1163 return success();
1164 }
1165
1166 //===----------------------------------------------------------------------===//
1167 // GenericAtomicRMWOp
1168 //===----------------------------------------------------------------------===//
1169
void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1171 Value memref, ValueRange ivs) {
1172 result.addOperands(memref);
1173 result.addOperands(ivs);
1174
1175 if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
1176 Type elementType = memrefType.getElementType();
1177 result.addTypes(elementType);
1178
1179 Region *bodyRegion = result.addRegion();
1180 bodyRegion->push_back(new Block());
1181 bodyRegion->addArgument(elementType, memref.getLoc());
1182 }
1183 }
1184
LogicalResult GenericAtomicRMWOp::verify() {
1186 auto &body = getRegion();
1187 if (body.getNumArguments() != 1)
1188 return emitOpError("expected single number of entry block arguments");
1189
1190 if (getResult().getType() != body.getArgument(0).getType())
1191 return emitOpError("expected block argument of the same type result type");
1192
1193 bool hasSideEffects =
1194 body.walk([&](Operation *nestedOp) {
1195 if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
1196 return WalkResult::advance();
1197 nestedOp->emitError(
1198 "body of 'memref.generic_atomic_rmw' should contain "
1199 "only operations with no side effects");
1200 return WalkResult::interrupt();
1201 })
1202 .wasInterrupted();
1203 return hasSideEffects ? failure() : success();
1204 }
1205
ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1207 OperationState &result) {
1208 OpAsmParser::UnresolvedOperand memref;
1209 Type memrefType;
1210 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1211
1212 Type indexType = parser.getBuilder().getIndexType();
1213 if (parser.parseOperand(memref) ||
1214 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1215 parser.parseColonType(memrefType) ||
1216 parser.resolveOperand(memref, memrefType, result.operands) ||
1217 parser.resolveOperands(ivs, indexType, result.operands))
1218 return failure();
1219
1220 Region *body = result.addRegion();
1221 if (parser.parseRegion(*body, {}) ||
1222 parser.parseOptionalAttrDict(result.attributes))
1223 return failure();
1224 result.types.push_back(memrefType.cast<MemRefType>().getElementType());
1225 return success();
1226 }
1227
void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1229 p << ' ' << getMemref() << "[" << getIndices()
1230 << "] : " << getMemref().getType() << ' ';
1231 p.printRegion(getRegion());
1232 p.printOptionalAttrDict((*this)->getAttrs());
1233 }
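// Example (illustrative) of the syntax handled by parse/print above:
//   %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
//   ^bb0(%current_value : f32):
//     %c1 = arith.constant 1.0 : f32
//     %inc = arith.addf %c1, %current_value : f32
//     memref.atomic_yield %inc : f32
//   }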
1234
1235 //===----------------------------------------------------------------------===//
1236 // AtomicYieldOp
1237 //===----------------------------------------------------------------------===//
1238
LogicalResult AtomicYieldOp::verify() {
1240 Type parentType = (*this)->getParentOp()->getResultTypes().front();
1241 Type resultType = getResult().getType();
1242 if (parentType != resultType)
1243 return emitOpError() << "types mismatch between yield op: " << resultType
1244 << " and its parent: " << parentType;
1245 return success();
1246 }
1247
1248 //===----------------------------------------------------------------------===//
1249 // GlobalOp
1250 //===----------------------------------------------------------------------===//
1251
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1253 TypeAttr type,
1254 Attribute initialValue) {
1255 p << type;
1256 if (!op.isExternal()) {
1257 p << " = ";
1258 if (op.isUninitialized())
1259 p << "uninitialized";
1260 else
1261 p.printAttributeWithoutType(initialValue);
1262 }
1263 }
1264
1265 static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1267 Attribute &initialValue) {
1268 Type type;
1269 if (parser.parseType(type))
1270 return failure();
1271
1272 auto memrefType = type.dyn_cast<MemRefType>();
1273 if (!memrefType || !memrefType.hasStaticShape())
1274 return parser.emitError(parser.getNameLoc())
1275 << "type should be static shaped memref, but got " << type;
1276 typeAttr = TypeAttr::get(type);
1277
1278 if (parser.parseOptionalEqual())
1279 return success();
1280
1281 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1282 initialValue = UnitAttr::get(parser.getContext());
1283 return success();
1284 }
1285
1286 Type tensorType = getTensorTypeFromMemRefType(memrefType);
1287 if (parser.parseAttribute(initialValue, tensorType))
1288 return failure();
1289 if (!initialValue.isa<ElementsAttr>())
1290 return parser.emitError(parser.getNameLoc())
1291 << "initial value should be a unit or elements attribute";
1292 return success();
1293 }
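// Examples (illustrative) of the forms accepted above:
//   memref.global "private" constant @c : memref<2xf32> = dense<[0.0, 2.0]>
//   memref.global @z : memref<3xf32> = uninitialized
//   memref.global "private" @ext : memref<4xi32>   (external: no initializer)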
1294
LogicalResult GlobalOp::verify() {
1296 auto memrefType = getType().dyn_cast<MemRefType>();
1297 if (!memrefType || !memrefType.hasStaticShape())
1298 return emitOpError("type should be static shaped memref, but got ")
1299 << getType();
1300
1301 // Verify that the initial value, if present, is either a unit attribute or
1302 // an elements attribute.
1303 if (getInitialValue().has_value()) {
1304 Attribute initValue = getInitialValue().value();
1305 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1306 return emitOpError("initial value should be a unit or elements "
1307 "attribute, but got ")
1308 << initValue;
1309
1310 // Check that the type of the initial value is compatible with the type of
1311 // the global variable.
1312 if (initValue.isa<ElementsAttr>()) {
1313 Type initType = initValue.getType();
1314 Type tensorType = getTensorTypeFromMemRefType(memrefType);
1315 if (initType != tensorType)
1316 return emitOpError("initial value expected to be of type ")
1317 << tensorType << ", but was of type " << initType;
1318 }
1319 }
1320
1321 if (Optional<uint64_t> alignAttr = getAlignment()) {
1322 uint64_t alignment = *alignAttr;
1323
1324 if (!llvm::isPowerOf2_64(alignment))
1325 return emitError() << "alignment attribute value " << alignment
1326 << " is not a power of 2";
1327 }
1328
1329 // TODO: verify visibility for declarations.
1330 return success();
1331 }
1332
ElementsAttr GlobalOp::getConstantInitValue() {
1334 auto initVal = getInitialValue();
1335 if (getConstant() && initVal.has_value())
1336 return initVal.value().cast<ElementsAttr>();
1337 return {};
1338 }
1339
1340 //===----------------------------------------------------------------------===//
1341 // GetGlobalOp
1342 //===----------------------------------------------------------------------===//
1343
1344 LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
1348 auto global =
1349 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1350 if (!global)
1351 return emitOpError("'")
1352 << getName() << "' does not reference a valid global memref";
1353
1354 Type resultType = getResult().getType();
1355 if (global.getType() != resultType)
1356 return emitOpError("result type ")
1357 << resultType << " does not match type " << global.getType()
1358 << " of the global memref @" << getName();
1359 return success();
1360 }
1361
1362 //===----------------------------------------------------------------------===//
1363 // LoadOp
1364 //===----------------------------------------------------------------------===//
1365
LogicalResult LoadOp::verify() {
1367 if (getNumOperands() != 1 + getMemRefType().getRank())
1368 return emitOpError("incorrect number of indices for load");
1369 return success();
1370 }
1371
OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1373 /// load(memrefcast) -> load
1374 if (succeeded(foldMemRefCast(*this)))
1375 return getResult();
1376 return OpFoldResult();
1377 }
1378
1379 //===----------------------------------------------------------------------===//
1380 // PrefetchOp
1381 //===----------------------------------------------------------------------===//
1382
void PrefetchOp::print(OpAsmPrinter &p) {
1384 p << " " << getMemref() << '[';
1385 p.printOperands(getIndices());
1386 p << ']' << ", " << (getIsWrite() ? "write" : "read");
1387 p << ", locality<" << getLocalityHint();
1388 p << ">, " << (getIsDataCache() ? "data" : "instr");
1389 p.printOptionalAttrDict(
1390 (*this)->getAttrs(),
1391 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1392 p << " : " << getMemRefType();
1393 }
1394
ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1396 OpAsmParser::UnresolvedOperand memrefInfo;
1397 SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1398 IntegerAttr localityHint;
1399 MemRefType type;
1400 StringRef readOrWrite, cacheType;
1401
1402 auto indexTy = parser.getBuilder().getIndexType();
1403 auto i32Type = parser.getBuilder().getIntegerType(32);
1404 if (parser.parseOperand(memrefInfo) ||
1405 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1406 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1407 parser.parseComma() || parser.parseKeyword("locality") ||
1408 parser.parseLess() ||
1409 parser.parseAttribute(localityHint, i32Type, "localityHint",
1410 result.attributes) ||
1411 parser.parseGreater() || parser.parseComma() ||
1412 parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1413 parser.resolveOperand(memrefInfo, type, result.operands) ||
1414 parser.resolveOperands(indexInfo, indexTy, result.operands))
1415 return failure();
1416
1417 if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1418 return parser.emitError(parser.getNameLoc(),
1419 "rw specifier has to be 'read' or 'write'");
1420 result.addAttribute(
1421 PrefetchOp::getIsWriteAttrStrName(),
1422 parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1423
1424 if (!cacheType.equals("data") && !cacheType.equals("instr"))
1425 return parser.emitError(parser.getNameLoc(),
1426 "cache type has to be 'data' or 'instr'");
1427
1428 result.addAttribute(
1429 PrefetchOp::getIsDataCacheAttrStrName(),
1430 parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1431
1432 return success();
1433 }
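
// For reference, the printer and parser above round-trip a textual form like
// the following (illustrative sketch; the operand names and the memref type
// are hypothetical):
//
//   memref.prefetch %buf[%i, %j], read, locality<3>, data : memref<400x400xf32>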

LogicalResult PrefetchOp::verify() {
  if (getNumOperands() != 1 + getMemRefType().getRank())
    return emitOpError("too few indices");

  return success();
}

LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = type.dyn_cast<ShapedType>();
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
}

//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//

/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
/// `staticSizes` and `staticStrides` are automatically filled with
/// source-memref-rank sentinel values that encode dynamic entries.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}
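
// Illustrative sketch of the textual form these builders produce (the value
// names, types and layout notation below are hypothetical):
//
//   %view = memref.reinterpret_cast %src to
//       offset: [0], sizes: [%d0, 10], strides: [%s0, 1]
//     : memref<?x?xf32> to memref<?x10xf32, offset: 0, strides: [?, 1]>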

// TODO: ponder whether we want to allow missing trailing sizes/strides that are
// completed automatically, like we have for subview and extract_slice.
LogicalResult ReinterpretCastOp::verify() {
  // The source and result memrefs should be in the same memory space.
  auto srcType = getSource().getType().cast<BaseMemRefType>();
  auto resultType = getType().cast<MemRefType>();
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in result memref type and in static_sizes attribute.
  for (auto &en : llvm::enumerate(llvm::zip(
           resultType.getShape(), extractFromI64ArrayAttr(getStaticSizes())))) {
    int64_t resultSize = std::get<0>(en.value());
    int64_t expectedSize = std::get<1>(en.value());
    if (!ShapedType::isDynamic(resultSize) &&
        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << en.index();
  }

  // Match offset and strides in static_offsets and static_strides attributes.
  // If the result memref type has no affine map specified, an identity layout
  // is assumed.
  int64_t resultOffset;
  SmallVector<int64_t, 4> resultStrides;
  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")
           << resultType;

  // Match offset in result memref type and in static_offsets attribute.
  int64_t expectedOffset = extractFromI64ArrayAttr(getStaticOffsets()).front();
  if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
      !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
      resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << resultOffset << " instead of " << expectedOffset;

  // Match strides in result memref type and in static_strides attribute.
  for (auto &en : llvm::enumerate(llvm::zip(
           resultStrides, extractFromI64ArrayAttr(getStaticStrides())))) {
    int64_t resultStride = std::get<0>(en.value());
    int64_t expectedStride = std::get<1>(en.value());
    if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
        !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
        resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << expectedStride << " instead of " << resultStride
             << " in dim = " << en.index();
  }

  return success();
}

OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
  Value src = getSource();
  auto getPrevSrc = [&]() -> Value {
    // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
      return prev.getSource();

    // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<CastOp>())
      return prev.getSource();

    // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
    // are 0.
    if (auto prev = src.getDefiningOp<SubViewOp>())
      if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
            return isConstantIntValue(val, 0);
          }))
        return prev.getSource();

    return nullptr;
  };

  if (auto prevSrc = getPrevSrc()) {
    getSourceMutable().assign(prevSrc);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

/// Helper function for verifying the shape of ExpandShapeOp and
/// CollapseShapeOp result and operand. Layout maps are verified separately.
///
/// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
/// allowed in a reassociation group.
static LogicalResult
verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
                     ArrayRef<int64_t> expandedShape,
                     ArrayRef<ReassociationIndices> reassociation,
                     bool allowMultipleDynamicDimsPerGroup) {
  // There must be one reassociation group per collapsed dimension.
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  // The next expected expanded dimension index (while iterating over
  // reassociation indices).
  int64_t nextDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    ReassociationIndices group = it.value();
    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
        return op->emitOpError("reassociation index ")
               << expandedDim << " is out of bounds";

      // Check if there are multiple dynamic dims in a reassociation group.
      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
          return op->emitOpError(
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;
      }
    }

    // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
      return op->emitOpError("collapsed dim (")
             << collapsedDim
             << ") must be dynamic if and only if reassociation group is "
                "dynamic";

    // If all dims in the reassociation group are static, the size of the
    // collapsed dim can be verified.
    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
        return op->emitOpError("collapsed dim size (")
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize << ")";
    }
  }

  if (collapsedShape.empty()) {
    // Rank 0: All expanded dimensions must be 1.
    for (int64_t d : expandedShape)
      if (d != 1)
        return op->emitOpError(
            "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
    // Rank >= 1: Number of dimensions among all reassociation groups must match
    // the result memref rank.
    return op->emitOpError("expanded rank (")
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim
           << ")";
  }

  return success();
}
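
// Illustrative example of the rules above: for an expanded shape [2, 3, ?, 5]
// with reassociation [[0, 1], [2, 3]], the only valid collapsed shape is
// [6, ?]. The first group is fully static, so its collapsed size must equal
// 2 * 3 = 6; the second group contains a dynamic dimension, so its collapsed
// size must be dynamic as well.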

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}

SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}

SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

/// Compute the layout map after expanding a given source MemRef type with the
/// specified reassociation indices.
static FailureOr<AffineMap>
computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                         ArrayRef<ReassociationIndices> reassociation) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
    return failure();
  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  // 1-1 mapping between srcStrides and reassociation packs.
  // Each srcStride starts with the given value and gets expanded according to
  // the proper entries in resultShape.
  // Example:
  //   srcStrides = [10000, 1, 100],
  //   reassociations = [ [0], [1], [2, 3, 4]],
  //   resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
  //   -> For the purpose of stride calculation, the useful sizes are:
  //      [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
  //   resultStrides = [10000, 1, 600, 200, 100]
  // Note that a stride does not get expanded along the first entry of each
  // shape pack.
  SmallVector<int64_t> reverseResultStrides;
  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    ReassociationIndices reassoc = std::get<0>(it);
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      using saturated_arith::Wrapper;
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
                               Wrapper::size(resultShape[shapeIndex--]))
                                  .asStride();
    }
  }
  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);
  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
                                    srcType.getContext());
}

static FailureOr<MemRefType>
computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape,
                    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // Source may not be contiguous. Compute the layout map.
  FailureOr<AffineMap> computedLayout =
      computeExpandedLayoutMap(srcType, resultShape, reassociation);
  if (failed(computedLayout))
    return failure();
  auto computedType =
      MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                      srcType.getMemorySpaceAsInt());
  return canonicalizeStridedLayout(computedType);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  // Only ranked memref source values are supported.
  auto srcType = src.getType().cast<MemRefType>();
  FailureOr<MemRefType> resultType =
      computeExpandedType(srcType, resultShape, reassociation);
  // Failure of this assertion usually indicates a problem with the source
  // type, e.g., could not get strides/offset.
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation);
}
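
// Illustrative sketch of an op created through the builder above (the value
// name and types are hypothetical):
//
//   %0 = memref.expand_shape %arg [[0, 1], [2]]
//       : memref<6x4xf32> into memref<2x3x4xf32>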

LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                  resultType.getShape(),
                                  getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/false)))
    return failure();

  // Compute expected result type (including layout map).
  FailureOr<MemRefType> expectedResultType = computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");

  // Check actual result type.
  auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
  if (*expectedResultType != canonicalizedResultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << canonicalizedResultType;

  return success();
}

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
      context);
}

/// Compute the layout map after collapsing a given source MemRef type with the
/// specified reassociation indices.
///
/// Note: All collapsed dims in a reassociation group must be contiguous. It is
/// not possible to check this by inspecting a MemRefType in the general case.
/// If non-contiguity cannot be checked statically, the collapse is assumed to
/// be valid (and thus accepted by this function) unless `strict = true`.
static FailureOr<AffineMap>
computeCollapsedLayoutMap(MemRefType srcType,
                          ArrayRef<ReassociationIndices> reassociation,
                          bool strict = false) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  auto srcShape = srcType.getShape();
  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
    return failure();

  // The result stride of a reassociation group is the stride of the last entry
  // of the reassociation. (TODO: Should be the minimum stride in the
  // reassociation because strides are not necessarily sorted. E.g., when using
  // memref.transpose.) Dimensions of size 1 should be skipped, because their
  // strides are meaningless and could have any arbitrary value.
  SmallVector<int64_t> resultStrides;
  resultStrides.reserve(reassociation.size());
  for (const ReassociationIndices &reassoc : reassociation) {
    ArrayRef<int64_t> ref = llvm::makeArrayRef(reassoc);
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
    } else {
      // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
      // the corresponding stride may have to be skipped. (See above comment.)
      // Therefore, the result stride cannot be statically determined and must
      // be dynamic.
      resultStrides.push_back(ShapedType::kDynamicStrideOrOffset);
    }
  }

  // Validate that each reassociation group is contiguous.
  unsigned resultStrideIndex = resultStrides.size() - 1;
  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
    auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
    using saturated_arith::Wrapper;
    auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      stride = stride * Wrapper::size(srcShape[idx]);

      // Both source and result stride must have the same static value. In that
      // case, we can be sure that the dimensions are collapsible (because they
      // are contiguous).
      //
      // One special case is when the srcShape is `1`, in which case it can
      // never produce non-contiguity.
      if (srcShape[idx] == 1)
        continue;

      // If `strict = false` (default during op verification), we accept cases
      // where one or both strides are dynamic. This is best effort: We reject
      // ops where obviously non-contiguous dims are collapsed, but accept ops
      // where we cannot be sure statically. Such ops may fail at runtime. See
      // the op documentation for details.
      auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
      if (strict && (stride.saturated || srcStride.saturated))
        return failure();

      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
        return failure();
    }
  }
  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
                                    srcType.getContext());
}
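
// Illustrative example of the contiguity check above: collapsing
// memref<4x5xf32, offset: 0, strides: [5, 1]> with reassociation [[0, 1]] is
// accepted, because the group stride (the stride of the last entry, 1)
// multiplied by the size 5 equals the stride of the outer dimension. With
// strides [10, 1] instead, 1 * 5 != 10, so the dims are not contiguous and the
// computation fails.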

bool CollapseShapeOp::isGuaranteedCollapsible(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  // MemRefs with standard layout are always collapsible.
  if (srcType.getLayout().isIdentity())
    return true;

  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
                                             /*strict=*/true));
}

static MemRefType
computeCollapsedType(MemRefType srcType,
                     ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<int64_t> resultShape;
  resultShape.reserve(reassociation.size());
  for (const ReassociationIndices &group : reassociation) {
    using saturated_arith::Wrapper;
    auto groupSize = Wrapper::size(1);
    for (int64_t srcDim : group)
      groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
    resultShape.push_back(groupSize.asSize());
  }

  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // Source may not be fully contiguous. Compute the layout map.
  // Note: Dimensions that are collapsed into a single dim are assumed to be
  // contiguous.
  FailureOr<AffineMap> computedLayout =
      computeCollapsedLayoutMap(srcType, reassociation);
  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  auto computedType =
      MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                      srcType.getMemorySpaceAsInt());
  return canonicalizeStridedLayout(computedType);
}

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto srcType = src.getType().cast<MemRefType>();
  MemRefType resultType = computeCollapsedType(srcType, reassociation);
  build(b, result, resultType, src, attrs);
  result.addAttribute(::mlir::getReassociationAttrName(),
                      getReassociationIndicesAttribute(b, reassociation));
}
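
// Illustrative sketch of an op created through the builder above (the value
// name and types are hypothetical):
//
//   %0 = memref.collapse_shape %arg [[0, 1], [2]]
//       : memref<2x3x4xf32> into memref<6x4xf32>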

LogicalResult CollapseShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                  srcType.getShape(), getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute expected result type (including layout map).
  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());
  } else {
    // Source may not be fully contiguous. Compute the layout map.
    // Note: Dimensions that are collapsed into a single dim are assumed to be
    // contiguous.
    FailureOr<AffineMap> computedLayout =
        computeCollapsedLayoutMap(srcType, getReassociationIndices());
    if (failed(computedLayout))
      return emitOpError(
          "invalid source layout map or collapsing non-contiguous dims");
    auto computedType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        *computedLayout, srcType.getMemorySpaceAsInt());
    expectedResultType = canonicalizeStridedLayout(computedType);
  }

  auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
  if (expectedResultType != canonicalizedResultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << canonicalizedResultType;

  return success();
}

struct CollapseShapeOpMemRefCastFolder
    : public OpRewritePattern<CollapseShapeOp> {
public:
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp op,
                                PatternRewriter &rewriter) const override {
    auto cast = op.getOperand().getDefiningOp<CastOp>();
    if (!cast)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(cast))
      return failure();

    Type newResultType =
        computeCollapsedType(cast.getOperand().getType().cast<MemRefType>(),
                             op.getReassociationIndices());

    if (newResultType == op.getResultType()) {
      rewriter.updateRootInPlace(
          op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
    } else {
      Value newOp = rewriter.create<CollapseShapeOp>(
          op->getLoc(), cast.getSource(), op.getReassociationIndices());
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
    }
    return success();
  }
};

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
              CollapseShapeOpMemRefCastFolder>(context);
}

OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
}

OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

LogicalResult ReshapeOp::verify() {
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();

  Type operandElementType = operandType.cast<ShapedType>().getElementType();
  Type resultElementType = resultType.cast<ShapedType>().getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");

  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");

  int64_t shapeSize = getShape().getType().cast<MemRefType>().getDimSize(0);
  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamicSize)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}
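
// Illustrative sketch of a memref.reshape accepted by this verifier (the value
// names and types are hypothetical):
//
//   %dst = memref.reshape %src(%shape)
//       : (memref<4x5xf32>, memref<2xi32>) -> memref<10x2xf32>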

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

LogicalResult StoreOp::verify() {
  if (getNumOperands() != 2 + getMemRefType().getRank())
    return emitOpError("store index operand count not equal to memref rank");

  return success();
}

LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this, getValueToStore());
}

//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//

/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<int64_t> staticOffsets,
                                ArrayRef<int64_t> staticSizes,
                                ArrayRef<int64_t> staticStrides) {
  unsigned rank = sourceMemRefType.getRank();
  (void)rank;
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  // Extract source offset and strides.
  int64_t sourceOffset;
  SmallVector<int64_t, 4> sourceStrides;
  auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
  assert(succeeded(res) && "SubViewOp expected strided memref type");
  (void)res;

  // Compute target offset whose value is:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
    using saturated_arith::Wrapper;
    targetOffset =
        (Wrapper::offset(targetOffset) +
         Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
            .asOffset();
  }

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    using saturated_arith::Wrapper;
    targetStrides.push_back(
        (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
            .asStride());
  }

  // The type is now known.
  return MemRefType::get(
      staticSizes, sourceMemRefType.getElementType(),
      makeStridedLinearLayoutMap(targetStrides, targetOffset,
                                 sourceMemRefType.getContext()),
      sourceMemRefType.getMemorySpace());
}
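
// Worked example of the arithmetic above (illustrative): for a source
// memref<8x16xf32> (strides [16, 1], offset 0) with static offsets [2, 4],
// sizes [4, 4] and strides [1, 2], the inferred type has shape 4x4,
// offset = 0 + 2 * 16 + 4 * 1 = 36, and strides = [16 * 1, 1 * 2] = [16, 2].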

Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<OpFoldResult> offsets,
                                ArrayRef<OpFoldResult> sizes,
                                ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);
}

Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                           MemRefType sourceRankedTensorType,
                                           ArrayRef<int64_t> offsets,
                                           ArrayRef<int64_t> sizes,
                                           ArrayRef<int64_t> strides) {
  auto inferredType =
      inferResultType(sourceRankedTensorType, offsets, sizes, strides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
         "expected result shape rank to not exceed the inferred rank");
  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;

  // Compute which dimensions are dropped.
  Optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
      computeRankReductionMask(inferredType.getShape(), resultShape);
  assert(dimsToProject.has_value() && "invalid rank reduction");
  llvm::SmallBitVector dimsToProjectVector(inferredType.getRank());
  for (unsigned dim : *dimsToProject)
    dimsToProjectVector.set(dim);

  // Compute layout map and result type.
  AffineMap map = getProjectedMap(inferredType.getLayout().getAffineMap(),
                                  dimsToProjectVector);
  return MemRefType::get(resultShape, inferredType.getElementType(), map,
                         inferredType.getMemorySpace());
}

Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                           MemRefType sourceRankedTensorType,
                                           ArrayRef<OpFoldResult> offsets,
                                           ArrayRef<OpFoldResult> sizes,
                                           ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}
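
// Builder usage sketch (illustrative; `b`, `loc`, `source` and `dynSize` are
// hypothetical values assumed to exist in the caller):
//
//   SmallVector<OpFoldResult> offsets = {b.getIndexAttr(0), b.getIndexAttr(0)};
//   SmallVector<OpFoldResult> sizes = {dynSize, b.getIndexAttr(4)};
//   SmallVector<OpFoldResult> strides = {b.getIndexAttr(1), b.getIndexAttr(1)};
//   auto subview = b.create<SubViewOp>(loc, source, offsets, sizes, strides);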

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return getSource(); }

/// Return true if t1 and t2 have equal offsets (both dynamic or of same
/// static value).
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  AffineExpr t1Offset, t2Offset;
  SmallVector<AffineExpr> t1Strides, t2Strides;
  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}

/// Checks if `originalType` can be rank-reduced to `candidateRankReducedType`.
/// This function is a slight variant of the `is subsequence` algorithm where
/// the non-matching dimensions must be 1.
static SliceVerificationResult
isRankReducedMemRefType(MemRefType originalType,
                        MemRefType candidateRankReducedType,
                        ArrayRef<OpFoldResult> sizes) {
  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
  if (partialRes != SliceVerificationResult::Success)
    return partialRes;

  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
      originalType, candidateRankReducedType, sizes);

  // Sizes cannot be matched in case empty vector is returned.
  if (!optionalUnusedDimsMask)
    return SliceVerificationResult::LayoutMismatch;

  if (originalType.getMemorySpace() !=
      candidateRankReducedType.getMemorySpace())
    return SliceVerificationResult::MemSpaceMismatch;

  // No amount of stride dropping can reconcile incompatible offsets.
  if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
    return SliceVerificationResult::LayoutMismatch;

  return SliceVerificationResult::Success;
}

template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
                                            OpTy op, Type expectedType) {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType();
  case SliceVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.");
  case SliceVerificationResult::LayoutMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout) ";
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(getStaticOffsets()),
      extractFromI64ArrayAttr(getStaticSizes()),
      extractFromI64ArrayAttr(getStaticStrides()));

  auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
                                        subViewType, getMixedSizes());
  return produceSubViewErrorMsg(result, *this, expectedType);
}

raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type for the given `sourceType`. Additionally, reduce
/// the rank of the inferred result type if `currentResultType` is lower rank
/// than `currentSourceType`. Use this signature if `sourceType` is updated
/// together with the result type. In this case, it is important to compute
/// the dropped dimensions using `currentSourceType` whose strides align with
/// `currentResultType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                                       mixedSizes, mixedStrides)
                                .cast<MemRefType>();
  llvm::Optional<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(currentSourceType, currentResultType,
                                     mixedSizes);
  // Return nullptr as failure mode.
  if (!unusedDims)
    return nullptr;
  SmallVector<int64_t> shape;
  for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
    if (unusedDims->test(sizes.index()))
      continue;
    shape.push_back(sizes.value());
  }
  AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
  if (!layoutMap.isIdentity())
    layoutMap = getProjectedMap(layoutMap, *unusedDims);
  return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
                         nonRankReducedType.getMemorySpace());
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type. Additionally, reduce the rank of the inferred
/// result type if `currentResultType` is lower rank than `sourceType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType sourceType,
    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
    ArrayRef<OpFoldResult> mixedStrides) {
  return getCanonicalSubViewResultType(currentResultType, sourceType,
                                       sourceType, mixedOffsets, mixedSizes,
                                       mixedStrides);
}
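
// Illustrative example of the rank reduction handled above (the value name and
// types are hypothetical): a subview that drops a unit dimension from its
// result type,
//
//   %0 = memref.subview %arg[0, 0, 0] [1, 16, 4] [1, 1, 1]
//       : memref<8x16x4xf32> to memref<16x4xf32>
//
// first infers memref<1x16x4xf32> and then projects out the leading unit
// dimension to obtain the rank-reduced canonical result type.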

/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if all offsets are zero, all strides are 1, and the source
/// shape is the same as the size of the subview. In such cases, the subview
/// can be folded into its source.
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Check offsets are zero.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        Optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 0;
      }))
    return false;

  // Check strides are one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        Optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 1;
      }))
    return false;

  // Check all size values are static and match the (static) source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    Optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || *intValue != sourceShape[size.index()])
      return false;
  }
  // All conditions met. The `SubViewOp` is foldable as a no-op.
  return true;
}
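
// A trivial subview in the sense above (illustrative sketch; the value name and
// type are hypothetical):
//
//   %0 = memref.subview %arg[0, 0] [8, 16] [1, 1]
//       : memref<8x16xf32> to memref<8x16xf32>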

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick
    // in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the MemRefCastOp source operand type to infer the result type and the
    // current SubViewOp source operand type to compute the dropped dimensions
    // if the operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        castOp.getSource().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};

/// Canonicalize subview ops that are no-ops: replace them with their source,
/// or with a cast of the source when the source type and the result type
/// differ only in their layout (e.g., due to use of `affine_map`).
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.getSource());
    return success();
  }
};
} // namespace
2592
2593 /// Return the canonical type of the result of a subview.
2594 struct SubViewReturnTypeCanonicalizer {
operator ()SubViewReturnTypeCanonicalizer2595 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
2596 ArrayRef<OpFoldResult> mixedSizes,
2597 ArrayRef<OpFoldResult> mixedStrides) {
2598 return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
2599 mixedOffsets, mixedSizes,
2600 mixedStrides);
2601 }
2602 };
2603
2604 /// A canonicalizer wrapper to replace SubViewOps.
2605 struct SubViewCanonicalizer {
operator ()SubViewCanonicalizer2606 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
2607 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2608 }
2609 };
2610
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2611 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2612 MLIRContext *context) {
2613 results
2614 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
2615 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
2616 SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
2617 }
2618
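// SubViewOp::fold below handles the degenerate case where the subview has a
// fully static result type identical to its source type, e.g. (illustrative)
//   memref.subview %0[0, 0] [4, 8] [1, 1] : memref<4x8xf32> to memref<4x8xf32>
// which folds to %0 without creating any new IR.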
OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = getSource().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
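///
/// For example (a worked sketch, not taken from a test): applying the
/// permutation map `(d0, d1) -> (d1, d0)` to `memref<4x8xf32>` (row-major
/// strides [8, 1], offset 0) permutes the sizes to [8, 4] and composes the
/// strided layout map with the permutation, giving a result of shape 8x4
/// whose layout map is equivalent to `(d0, d1) -> (d1 * 8 + d0)`, i.e.
/// strides [1, 8] over the same underlying buffer.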
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (const auto &en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType)
      .setShape(sizes)
      .setLayout(AffineMapAttr::get(map));
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
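//
// For example (an illustrative form, with the result layout abbreviated as
// #transposed_layout since it depends on the source strides):
//
//   %1 = memref.transpose %0 (d0, d1) -> (d1, d0)
//          : memref<?x?xf32> to memref<?x?xf32, #transposed_layout>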
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << getIn() << " " << getPermutation();
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
  p << " : " << getIn().getType() << " to " << getType();
}

ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

LogicalResult TransposeOp::verify() {
  if (!getPermutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getShapedType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = getIn().getType().cast<MemRefType>();
  auto dstType = getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, getPermutation());
  if (dstType != transposedType)
    return emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

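// For reference, a sketch of the form being verified below (illustrative
// types; `%offset` is the byte shift into the flat base buffer and the
// trailing operands supply the dynamic sizes of the result):
//
//   %view = memref.view %base[%offset][%size0, %size1]
//             : memref<2048xi8> to memref<?x?xf32>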
LogicalResult ViewOp::verify() {
  auto baseType = getOperand(0).getType().cast<MemRefType>();
  auto viewType = getType();

  // The base memref should have identity layout map (or none).
  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ") << viewType;

  return success();
}

Value ViewOp::getViewSource() { return getSource(); }

namespace {

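/// Fold `memref.view` size operands that are produced by `arith.constant`
/// index ops into the result type. A minimal sketch (illustrative values):
///
///   %c16 = arith.constant 16 : index
///   %v = memref.view %base[%off][%c16, %n]
///          : memref<2048xi8> to memref<?x?xf32>
///
/// becomes a view whose folded dimension is static, cast back to the original
/// result type:
///
///   %w = memref.view %base[%off][%n] : memref<2048xi8> to memref<16x?xf32>
///   %v = memref.cast %w : memref<16x?xf32> to memref<?x?xf32>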
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get the offset from the old memref view type 'memrefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // The offset cannot be folded into the result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy the operand from the old
        // memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create the new memref type with constant-folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create the new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(
        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
        viewOp.getByteShift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
    return success();
  }
};

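/// Replace a `memref.view` whose base operand is a `memref.cast` of a
/// `memref.alloc` with a view of the allocation itself. A minimal sketch
/// (illustrative types):
///
///   %a = memref.alloc() : memref<2048xi8>
///   %c = memref.cast %a : memref<2048xi8> to memref<?xi8>
///   %v = memref.view %c[%off][%n] : memref<?xi8> to memref<?xf32>
///
/// becomes
///
///   %v = memref.view %a[%off][%n] : memref<2048xi8> to memref<?xf32>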
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.getByteShift(),
                                        viewOp.getSizes());
    return success();
  }
};

} // namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

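// For reference, a sketch of the form being verified below (illustrative; the
// kind must agree with the value type, and the number of subscripts with the
// memref rank):
//
//   %old = memref.atomic_rmw "addf" %val, %buf[%i] : (f32, memref<10xf32>) -> f32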
LogicalResult AtomicRMWOp::verify() {
  if (getMemRefType().getRank() != getNumOperands() - 2)
    return emitOpError(
        "expects the number of subscripts to be equal to memref rank");
  switch (getKind()) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::maxf:
  case arith::AtomicRMWKind::minf:
  case arith::AtomicRMWKind::mulf:
    if (!getValue().getType().isa<FloatType>())
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects a floating-point type";
    break;
  case arith::AtomicRMWKind::addi:
  case arith::AtomicRMWKind::maxs:
  case arith::AtomicRMWKind::maxu:
  case arith::AtomicRMWKind::mins:
  case arith::AtomicRMWKind::minu:
  case arith::AtomicRMWKind::muli:
  case arith::AtomicRMWKind::ori:
  case arith::AtomicRMWKind::andi:
    if (!getValue().getType().isa<IntegerType>())
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects an integer type";
    break;
  default:
    break;
  }
  return success();
}

OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
  /// atomicrmw(memrefcast) -> atomicrmw
  if (succeeded(foldMemRefCast(*this, getValue())))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"