1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Utils/Utils.h"
14 
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
20 #include "mlir/Dialect/Linalg/IR/Linalg.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
25 #include "mlir/Dialect/Tensor/IR/Tensor.h"
26 #include "mlir/Dialect/Utils/StaticValueUtils.h"
27 #include "mlir/IR/AffineExpr.h"
28 #include "mlir/IR/AffineExprVisitor.h"
29 #include "mlir/IR/AffineMap.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/OpImplementation.h"
32 #include "mlir/Pass/Pass.h"
33 #include "mlir/Transforms/LoopUtils.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 
37 #define DEBUG_TYPE "linalg-utils"
38 
39 using namespace mlir;
40 using namespace mlir::linalg;
41 using namespace mlir::scf;
42 
43 static bool isZero(Value v) {
44   if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
45     return cst.value() == 0;
46   return false;
47 }
48 
49 namespace {
50 
51 // Helper visitor to determine whether an AffineExpr is tiled.
52 // This is achieved by traversing every AffineDimExpr with position `pos` and
53 // checking whether the corresponding `tileSizes[pos]` is non-zero.
54 // It also enforces that only positive coefficients occur in multiplications.
55 //
56 // Example:
57 //   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
58 //
59 struct TileCheck : public AffineExprVisitor<TileCheck> {
60   TileCheck(ValueRange tileSizes) : isTiled(false), tileSizes(tileSizes) {}
61 
62   void visitDimExpr(AffineDimExpr expr) {
63     isTiled |= !isZero(tileSizes[expr.getPosition()]);
64   }
65   void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
66     visit(expr.getLHS());
67     visit(expr.getRHS());
68     if (expr.getKind() == mlir::AffineExprKind::Mul)
69       assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
70              "nonpositive multiplying coefficient");
71   }
72   bool isTiled;
73   ValueRange tileSizes;
74 };
75 
76 } // namespace
77 
78 static bool isTiled(AffineExpr expr, ValueRange tileSizes) {
79   if (!expr)
80     return false;
81   TileCheck t(tileSizes);
82   t.visit(expr);
83   return t.isTiled;
84 }
85 
86 // Checks whether `map` varies with respect to a non-zero `tileSizes` entry.
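// For example, the map `(d0, d1, d2) -> (d0, d1)` is tiled by
// tileSizes = [0, 4, 0] because its second result varies with the tiled
// dimension d1, but it is not tiled by [0, 0, 4] since d2 appears in no
// result.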
87 static bool isTiled(AffineMap map, ValueRange tileSizes) {
88   if (!map)
89     return false;
90   for (unsigned r = 0; r < map.getNumResults(); ++r)
91     if (isTiled(map.getResult(r), tileSizes))
92       return true;
93   return false;
94 }
95 
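// Illustrative region shape recognized by `matchAsScalarBinaryOp` below as
// `BinaryOpKind::IAdd` (a sketch; block argument names are arbitrary and any
// signless integer type works):
//   ^bb0(%a: i32, %b: i32):
//     %0 = arith.addi %a, %b : i32
//     linalg.yield %0 : i32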
96 Optional<RegionMatcher::BinaryOpKind>
97 RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
98   auto &region = op.region();
99   if (!llvm::hasSingleElement(region))
100     return llvm::None;
101 
102   Block &block = region.front();
103   if (block.getNumArguments() != 2 ||
104       !block.getArgument(0).getType().isSignlessIntOrFloat() ||
105       !block.getArgument(1).getType().isSignlessIntOrFloat())
106     return llvm::None;
107 
108   auto &ops = block.getOperations();
109   if (!llvm::hasSingleElement(block.without_terminator()))
110     return llvm::None;
111 
112   using mlir::matchers::m_Val;
113   auto a = m_Val(block.getArgument(0));
114   auto b = m_Val(block.getArgument(1));
115 
116   auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
117   if (addPattern.match(&ops.back()))
118     return BinaryOpKind::IAdd;
119 
120   return llvm::None;
121 }
122 
123 /// Explicit instantiation of loop nest generator for different loop types.
124 template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
125 template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
126 template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
127 template struct mlir::linalg::GenerateLoopNest<TiledLoopOp>;
128 
129 /// Given a list of subview ranges, extract the individual values for the lower
130 /// bounds, upper bounds and steps, and put them into the corresponding vectors.
131 static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
132                          SmallVectorImpl<Value> &ubs,
133                          SmallVectorImpl<Value> &steps) {
134   for (Range range : ranges) {
135     lbs.emplace_back(range.offset);
136     ubs.emplace_back(range.size);
137     steps.emplace_back(range.stride);
138   }
139 }
140 
141 namespace mlir {
142 namespace linalg {
143 
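// For example, [1, 2, 0] is a permutation of size three, whereas [1, 1, 0]
// (duplicate index) and [0, 3, 1] (out-of-range index) are not.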
144 bool isPermutation(ArrayRef<int64_t> permutation) {
145   // Count the number of appearances for all indices.
146   SmallVector<int64_t> indexCounts(permutation.size(), 0);
147   for (auto index : permutation) {
148     // Exit if the index is out-of-range.
149     if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
150       return false;
151     indexCounts[index]++;
152   }
153   // Return true if all indices appear once.
154   return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
155 }
156 
157 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
158 /// the type of `source`.
159 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
160   if (source.getType().isa<UnrankedMemRefType, MemRefType>())
161     return b.createOrFold<memref::DimOp>(loc, source, dim);
162   if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
163     return b.createOrFold<tensor::DimOp>(loc, source, dim);
164   llvm_unreachable("Expected MemRefType or TensorType");
165 }
166 
167 /// Given a shaped value `val`, retrieves the value of each dynamic dimension
168 /// by constructing the necessary DimOps.
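/// For example, for a `val` of type `tensor<?x4x?xf32>` this returns the
/// `tensor.dim` values of dimensions 0 and 2.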
169 SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b) {
170   SmallVector<Value, 4> dynOperands;
171   auto shapedType = val.getType().cast<ShapedType>();
172   for (const auto &dim : llvm::enumerate(shapedType.getShape())) {
173     if (dim.value() == ShapedType::kDynamicSize)
174       dynOperands.push_back(createOrFoldDimOp(b, loc, val, dim.index()));
175   }
176   return dynOperands;
177 }
178 
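// Illustrative usage sketch (hypothetical caller; `tileSize` is assumed to be
// an index-typed Value produced by an affine.min/affine.apply chain):
//   AffineMap boundMap;
//   SmallVector<Value> boundOperands;
//   getUpperBoundForIndex(tileSize, boundMap, boundOperands);
//   // Applying `boundMap` to `boundOperands` yields upper bound(s) for
//   // `tileSize`; if no simplification was possible, it is the identity map
//   // over `tileSize` itself.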
179 void getUpperBoundForIndex(Value value, AffineMap &boundMap,
180                            SmallVectorImpl<Value> &boundOperands) {
181   // Initialize `boundMap` and `boundOperands` to the identity returning
182   // `value`. This combination is the default result of the method if no
183   // simplification is possible.
184   assert(value.getType().isIndex() && "expect value to have index type");
185   boundMap = AffineMap::getMultiDimIdentityMap(1, value.getContext());
186   boundOperands.assign({value});
187   canonicalizeMapAndOperands(&boundMap, &boundOperands);
188 
189   // Continue only if there is an affine index computation to simplify.
190   Operation *definingOp = value.getDefiningOp();
191   if (!definingOp || !isa<AffineApplyOp, AffineMinOp>(definingOp))
192     return;
193 
194   // Get the backward slice containing the affine index computation.
195   SetVector<Operation *> backwardSlice;
196   getBackwardSlice(definingOp, &backwardSlice, [](Operation *op) {
197     return isa<AffineApplyOp, AffineMinOp>(op);
198   });
199   backwardSlice.insert(definingOp);
200 
201   // Set up a system of affine constraints describing the index computation.
202   FlatAffineValueConstraints constraints;
203 
204   // Helper to find or create an identifier for the given value.
205   auto findOrCreateId = [&](Value value) {
206     if (!constraints.containsId(value)) {
207       constraints.appendDimId(value);
208       return true;
209     }
210     unsigned pos;
211     constraints.findId(value, &pos);
212     return pos < constraints.getNumDimIds();
213   };
214   // Helper to get the position for the given value.
215   auto getPosition = [&](Value value) {
216     unsigned pos;
217     bool exists = constraints.findId(value, &pos);
218     (void)exists;
219     assert(exists && "expect to find the identifier");
220     return pos;
221   };
222 
223   // Add the affine operations in `backwardSlice` to the constraints.
224   for (Operation *op : llvm::reverse(backwardSlice)) {
225     // Add an identifier for all op results and operands.
226     if (!(llvm::all_of(op->getResults(), findOrCreateId) &&
227           llvm::all_of(op->getOperands(), findOrCreateId)))
228       return;
229     // Add AffineApplyOps to the constraints.
230     if (auto applyOp = dyn_cast<AffineApplyOp>(op)) {
231       AffineValueMap valueMap(applyOp.getAffineMap(), applyOp.getOperands(),
232                               applyOp.getResult());
233       if (failed(constraints.composeMap(&valueMap)))
234         return;
235       continue;
236     }
237     // Add AffineMinOps to the constraints.
238     auto minOp = cast<AffineMinOp>(op);
239     AffineMap map = constraints.computeAlignedMap(minOp.getAffineMap(),
240                                                   minOp.getOperands());
241     if (failed(constraints.addBound(FlatAffineConstraints::UB,
242                                     getPosition(minOp.getResult()), map)))
243       return;
244   }
245 
246   // Obtain an upper bound for the affine index computation by projecting out
247   // all temporary results and expressing the upper bound for `value` in terms
248   // of the terminals of the index computation.
249   SmallVector<AffineMap> lowerBounds(1), upperBounds(1);
250   constraints.getSliceBounds(getPosition(value), 1, value.getContext(),
251                              &lowerBounds, &upperBounds);
252 
253   // Verify `upperBounds[0]` is valid and has at least one result.
254   if (!upperBounds[0] || upperBounds[0].getNumResults() == 0)
255     return;
256 
257   // Set `boundMap` and `boundOperands` to the computed upper bound.
258   boundMap = upperBounds[0];
259   constraints.getAllValues(&boundOperands);
260   erase_value(boundOperands, value);
261   canonicalizeMapAndOperands(&boundMap, &boundOperands);
262 }
263 
264 FailureOr<int64_t> getConstantUpperBoundForIndex(Value value) {
265   // Compute an upper bound for `value`.
266   AffineMap boundMap;
267   SmallVector<Value> boundOperands;
268   getUpperBoundForIndex(value, boundMap, boundOperands);
269 
270   // Search the results of `boundMap` for constant upper bounds.
271   SmallVector<int64_t> constantBounds;
272   for (AffineExpr result : boundMap.getResults())
273     if (auto constExpr = result.dyn_cast<AffineConstantExpr>())
274       constantBounds.push_back(constExpr.getValue());
275 
276   // Return the minimal upper bound or failure if none is found.
277   if (constantBounds.empty())
278     return failure();
279   return *std::min_element(constantBounds.begin(), constantBounds.end());
280 }
281 
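// Illustrative behavior of the helper below for unit strides and no rank
// reduction (types omitted for brevity): requesting a slice of
//   %0 = tensor.extract_slice %src[%a, %b] [16, 16] [1, 1]
// with offsets [%c, %d], sizes [4, 4] and strides [1, 1] produces,
// conceptually,
//   %1 = tensor.extract_slice %src[%a + %c, %b + %d] [4, 4] [1, 1]
// where the summed offsets are materialized with affine.apply ops.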
282 tensor::ExtractSliceOp makeComposedExtractSliceOp(
283     OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
284     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
285   assert(source && "expect source to be nonzero");
286 
287   // Do not fold if the producer is not an ExtractSliceOp.
288   auto producerOp = source.getDefiningOp<tensor::ExtractSliceOp>();
289   if (!producerOp)
290     return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
291                                             strides);
292 
293   // Do not fold if the producer is rank reducing or if there are any non-unit
294   // strides. Supporting non-unit strides complicates the offset computation
295   // since the consumer offsets need to be multiplied by the producer strides.
296   // TODO: support non-unit strides once there are use cases.
297   SmallVector<OpFoldResult> allStrides = producerOp.getMixedStrides();
298   allStrides.append(strides.begin(), strides.end());
299   bool hasNonUnitStride = any_of(allStrides, [](OpFoldResult ofr) {
300     return getConstantIntValue(ofr) != static_cast<int64_t>(1);
301   });
302   if (hasNonUnitStride ||
303       producerOp.getSourceType().getRank() !=
304           producerOp.getResult().getType().cast<ShapedType>().getRank())
305     return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
306                                             strides);
307 
308   // Fold the producer by adding the offsets and extracting the slice directly
309   // from the producer source tensor.
310   SmallVector<OpFoldResult> foldedOffsets(offsets.begin(), offsets.end());
311   AffineExpr dim1, dim2;
312   bindDims(b.getContext(), dim1, dim2);
313   for (const auto &en : enumerate(producerOp.getMixedOffsets())) {
314     SmallVector<Value> offsetValues = {
315         getValueOrCreateConstantIndexOp(b, loc, foldedOffsets[en.index()]),
316         getValueOrCreateConstantIndexOp(b, loc, en.value())};
317     foldedOffsets[en.index()] =
318         makeComposedAffineApply(b, loc, dim1 + dim2, offsetValues).getResult();
319   }
320   return b.create<tensor::ExtractSliceOp>(loc, producerOp.source(),
321                                           foldedOffsets, sizes, strides);
322 }
323 
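// Illustrative IR sketch of the pattern matched below (padding regions and
// most types are omitted; all names are arbitrary):
//   %s0 = tensor.extract_slice %arg[..] [%h, %w] [1, 1]
//   %p0 = linalg.pad_tensor %s0 low[0, 0] high[..]    // to tensor<8x8xf32>
//   %r  = linalg.matmul ... outs(%p0 : tensor<8x8xf32>) -> tensor<8x8xf32>
//   %s1 = tensor.extract_slice %r[0, 0] [%h, %w] [1, 1]
// Padding %s1 back to tensor<8x8xf32> with the same padding value returns %r
// directly instead of creating a second PadTensorOp.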
324 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
325                             Value source, Value pad, bool nofold) {
326   assert(type.hasStaticShape() && "expect tensor type to have static shape");
327 
328   // Exit if `source` is not defined by an ExtractSliceOp.
329   auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
330   if (!sliceOp)
331     return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
332 
333   // Search the `source` use-def chain for padded LinalgOps.
334   Value current = sliceOp.source();
335   while (current) {
336     auto linalgOp = current.getDefiningOp<LinalgOp>();
337     if (!linalgOp)
338       break;
339     OpResult opResult = current.cast<OpResult>();
340     current = linalgOp.getOutputOperand(opResult.getResultNumber())->get();
341   }
342   auto padTensorOp = current ? current.getDefiningOp<PadTensorOp>() : nullptr;
343 
344   // Exit if the search fails to match a PadTensorOp at the end of the matched
345   // LinalgOp sequence.
346   if (!padTensorOp)
347     return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
348 
349   // Exit if the padded result type does not match.
350   if (sliceOp.source().getType() != type)
351     return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
352 
353   // Exit if `padTensorOp` is not exclusively high padded (non-zero low pad).
354   if (llvm::any_of(padTensorOp.getMixedLowPad(), [](OpFoldResult ofr) {
355         return getConstantIntValue(ofr) != static_cast<int64_t>(0);
356       }))
357     return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
358 
359   // Exit if `padTensorOpSliceOp`, which defines the slice used by
360   // `padTensorOp`, is rank-reducing.
361   auto padTensorOpSliceOp =
362       padTensorOp.source().getDefiningOp<tensor::ExtractSliceOp>();
363   if (!padTensorOpSliceOp || sliceOp.getMixedSizes().size() !=
364                                  padTensorOpSliceOp.getMixedSizes().size())
365     return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
366 
367   // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
368   // by `padTensorOp`.
369   if (llvm::any_of(llvm::zip(sliceOp.getMixedSizes(),
370                              padTensorOpSliceOp.getMixedSizes()),
371                    [](std::tuple<OpFoldResult, OpFoldResult> it) {
372                      return !isEqualConstantIntOrValue(std::get<0>(it),
373                                                        std::get<1>(it));
374                    }))
375     return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
376 
377   // Exit if the padding values do not match.
378   Attribute padTensorOpPadAttr, padAttr;
379   Value padTensorOpPad = padTensorOp.getConstantPaddingValue();
380   if (!padTensorOpPad ||
381       !matchPattern(padTensorOpPad, m_Constant(&padTensorOpPadAttr)) ||
382       !matchPattern(pad, m_Constant(&padAttr)) || padTensorOpPadAttr != padAttr)
383     return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
384 
385   // Return the padded result if the padding values and sizes match.
386   return sliceOp.source();
387 }
388 
389 /// Specialization to build an scf "for" nest.
390 template <>
391 void GenerateLoopNest<scf::ForOp>::doit(
392     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
393     ArrayRef<Attribute> iteratorTypes,
394     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
395                                   ValueRange)>
396         bodyBuilderFn,
397     Optional<LinalgLoopDistributionOptions> distributionOptions,
398     ArrayRef<StringRef> distributionTypes) {
399   SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
400   // Create procInfo so that it dominates the loops, if distribution is used.
401   SmallVector<ProcInfo, 4> procInfo;
402   SmallVector<DistributionMethod, 0> distributionMethod;
403   if (distributionOptions.hasValue()) {
404     // Collect loop ranges for parallel dimensions.
405     SmallVector<Range, 2> parallelLoopRanges;
406     for (const auto &iteratorType : enumerate(iteratorTypes))
407       if (isParallelIterator(iteratorType.value()))
408         parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
409 
410     // Get their distribution schemes.
411     distributionMethod = distributionOptions->distributionMethod;
412     if (distributionMethod.size() < parallelLoopRanges.size())
413       parallelLoopRanges.resize(distributionMethod.size());
414     procInfo = distributionOptions->procInfo(b, loc, parallelLoopRanges);
415   }
416 
417   SmallVector<Value, 4> lbs, ubs, steps;
418   unpackRanges(loopRanges, lbs, ubs, steps);
419   LoopNest loopNest = mlir::scf::buildLoopNest(
420       b, loc, lbs, ubs, steps, iterArgInitValues,
421       [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
422         assert(iterArgs.size() == linalgOp.getOutputTensorOperands().size() &&
423                "expect the number of output tensors and iter args to match");
424         SmallVector<Value> operandValuesToUse =
425             linalgOp.getInputAndOutputOperands();
426         if (!iterArgs.empty()) {
427           operandValuesToUse = linalgOp.getInputOperands();
428           operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
429         }
430         return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
431       });
432 
433   if (!distributionOptions || loopNest.loops.empty())
434     return;
435 
436   // Collect the scf.for loops that were created from parallel dimensions.
437   SmallVector<scf::ForOp, 4> loops;
438   for (const auto &iteratorType : enumerate(iteratorTypes))
439     if (isParallelIterator(iteratorType.value()))
440       loops.push_back(loopNest.loops[iteratorType.index()]);
441 
442   // Distribute - only supports cyclic distribution for now.
443   for (auto it : llvm::zip(loops, procInfo, distributionMethod))
444     if (std::get<2>(it) == DistributionMethod::Cyclic)
445       mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId,
446                             std::get<1>(it).nprocs);
447 }
448 
449 /// Specialization to build an affine "for" nest.
450 template <>
451 void GenerateLoopNest<AffineForOp>::doit(
452     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
453     ArrayRef<Attribute> iteratorTypes,
454     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
455                                   ValueRange)>
456         bodyBuilderFn,
457     Optional<LinalgLoopDistributionOptions>, ArrayRef<StringRef>) {
458   SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
459   assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
460   SmallVector<Value, 4> lbs, ubs, steps;
461   unpackRanges(loopRanges, lbs, ubs, steps);
462 
463   // Affine loops require constant steps.
464   SmallVector<int64_t, 4> constantSteps;
465   constantSteps.reserve(steps.size());
466   for (Value v : steps) {
467     auto op = v.getDefiningOp<arith::ConstantIndexOp>();
468     assert(op && "Affine loops require constant steps");
469     constantSteps.push_back(op.value());
470   }
471 
472   mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
473                             [&](OpBuilder &b, Location loc, ValueRange ivs) {
474                               SmallVector<Value> operandValuesToUse =
475                                   linalgOp.getInputAndOutputOperands();
476                               bodyBuilderFn(b, loc, ivs, operandValuesToUse);
477                             });
478 }
479 
480 /// Specialization to build a linalg.tiled_loop op.
481 template <>
482 void GenerateLoopNest<TiledLoopOp>::doit(
483     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
484     ArrayRef<Attribute> iteratorTypes,
485     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
486                                   ValueRange)>
487         bodyBuilderFn,
488     Optional<LinalgLoopDistributionOptions> distributionOptions,
489     ArrayRef<StringRef> distributionTypes) {
490   SmallVector<ProcInfo, 2> procInfo;
491   SmallVector<Value, 4> lbs, ubs, steps;
492   unpackRanges(loopRanges, lbs, ubs, steps);
493 
494   auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
495                               ValueRange ivs, ValueRange inputs,
496                               ValueRange outputs) {
497     SmallVector<Value> operandValuesToUse = inputs;
498     operandValuesToUse.append(outputs.begin(), outputs.end());
499     scf::ValueVector results =
500         bodyBuilderFn(nestedBuilder, nestedLoc, ivs, operandValuesToUse);
501     nestedBuilder.create<linalg::YieldOp>(nestedLoc, results);
502   };
503 
504   SmallVector<Value> inputOperands = linalgOp.getInputOperands();
505   SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
506   auto tiledLoop =
507       b.create<TiledLoopOp>(loc, lbs, ubs, steps, inputOperands, outputOperands,
508                             b.getArrayAttr(iteratorTypes), wrappedBuilderFn);
509   if (!distributionTypes.empty())
510     tiledLoop.setDistributionTypes(b, distributionTypes);
511 }
512 
513 /// Update `lb`, `ub` and `step` to get per-processor `lb`, `ub` and `step`.
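/// With the maps below, a loop `for %i = %lb to %ub step %step` distributed
/// cyclically over `nprocs` processors becomes, for processor `procId`,
/// `for %i = %lb + %procId * %step to %ub step %step * %nprocs`.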
514 void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
515                                        Value nprocs, Value &lb, Value &ub,
516                                        Value &step) {
517   AffineExpr d0, d1;
518   bindDims(b.getContext(), d0, d1);
519   AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
520   lb = makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
521   step = makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
522 }
523 
524 /// Generates a loop nest consisting of scf.parallel and scf.for, depending
525 /// on the `iteratorTypes`. Consecutive parallel loops create a single
526 /// scf.parallel operation; each sequential loop creates a new scf.for
527 /// operation. The body of the innermost loop is populated by
528 /// `bodyBuilderFn` that accepts a range of induction variables for all
529 /// loops. `ivStorage` is used to store the partial list of induction
530 /// variables.
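// For example, with `iteratorTypes` = ["parallel", "parallel", "reduction",
// "parallel"] and no distribution, the generated nest is an scf.parallel with
// two induction variables, containing an scf.for, containing an scf.parallel
// with a single induction variable around the body.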
531 // TODO: this function can be made iterative instead. However, it
532 // will have at most as many recursive calls as nested loops, which rarely
533 // exceeds 10.
534 static void generateParallelLoopNest(
535     OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
536     ValueRange steps, ArrayRef<Attribute> iteratorTypes,
537     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
538     SmallVectorImpl<Value> &ivStorage,
539     ArrayRef<DistributionMethod> distributionMethod = {}) {
540   assert(lbs.size() == ubs.size());
541   assert(lbs.size() == steps.size());
542   assert(lbs.size() == iteratorTypes.size());
543 
544   // If there are no (more) loops to be generated, generate the body and be
545   // done with it.
546   if (iteratorTypes.empty()) {
547     bodyBuilderFn(b, loc, ivStorage);
548     return;
549   }
550 
551   // Find the outermost parallel loops and drop their types from the list.
552   unsigned nLoops = iteratorTypes.size();
553   unsigned nOuterPar =
554       nLoops - iteratorTypes.drop_while(isParallelIterator).size();
555 
556   // If there are no outer parallel loops, generate one sequential loop and
557   // recurse. Note that we wouldn't have dropped anything from `iteratorTypes`
558   // in this case.
559   if (nOuterPar == 0) {
560     LoopNest singleLoop = buildLoopNest(
561         b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
562         [&](OpBuilder &b, Location loc, ValueRange ivs) {
563           ivStorage.append(ivs.begin(), ivs.end());
564           generateParallelLoopNest(b, loc, lbs.drop_front(), ubs.drop_front(),
565                                    steps.drop_front(),
566                                    iteratorTypes.drop_front(), bodyBuilderFn,
567                                    ivStorage, distributionMethod);
568         });
569     return;
570   }
571   if (distributionMethod.empty()) {
572     // Generate a single parallel loop-nest operation for all outermost
573     // parallel loops and recurse.
574     b.create<scf::ParallelOp>(
575         loc, lbs.take_front(nOuterPar), ubs.take_front(nOuterPar),
576         steps.take_front(nOuterPar),
577         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
578           ivStorage.append(localIvs.begin(), localIvs.end());
579           generateParallelLoopNest(
580               nestedBuilder, nestedLoc, lbs.drop_front(nOuterPar),
581               ubs.drop_front(nOuterPar), steps.drop_front(nOuterPar),
582               iteratorTypes.drop_front(nOuterPar), bodyBuilderFn, ivStorage,
583               (distributionMethod.size() < nOuterPar)
584                   ? ArrayRef<DistributionMethod>()
585                   : distributionMethod.drop_front(nOuterPar));
586         });
587     return;
588   }
589 
590   // Process all consecutive similarly distributed loops simultaneously.
591   DistributionMethod methodToUse = distributionMethod[0];
592   unsigned numProcessed = 1;
593   for (unsigned i = 1; i < nOuterPar && i < distributionMethod.size(); ++i) {
594     if (distributionMethod[i] != methodToUse)
595       break;
596     numProcessed++;
597   }
598 
599   switch (methodToUse) {
600   case DistributionMethod::Cyclic: {
601     // Generate a single parallel loop-nest operation for all outermost
602     // parallel loops and recurse.
603     b.create<scf::ParallelOp>(
604         loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
605         steps.take_front(numProcessed),
606         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
607           ivStorage.append(localIvs.begin(), localIvs.end());
608           generateParallelLoopNest(
609               nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
610               ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
611               iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
612               (distributionMethod.size() < numProcessed)
613                   ? ArrayRef<DistributionMethod>()
614                   : distributionMethod.drop_front(numProcessed));
615         });
616     return;
617   }
618   case DistributionMethod::CyclicNumProcsGeNumIters: {
619     // Check (for the processed loops) that the iteration is in-bounds.
620     ArithBuilder ab(b, loc);
621     Value cond = ab.slt(lbs[0], ubs[0]);
622     for (unsigned i = 1; i < numProcessed; ++i)
623       cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
624     ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
625     b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
626       generateParallelLoopNest(
627           b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
628           steps.drop_front(numProcessed),
629           iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
630           distributionMethod.drop_front(numProcessed));
631       b.create<scf::YieldOp>(loc, ValueRange{});
632     });
633     return;
634   }
635   case DistributionMethod::CyclicNumProcsEqNumIters:
636     // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
637     // with inner loop generation.
638     ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
639     generateParallelLoopNest(
640         b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
641         steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
642         bodyBuilderFn, ivStorage, distributionMethod.drop_front(numProcessed));
643     return;
644   }
645 }
646 
647 /// Specialization for generating a mix of parallel and sequential scf loops.
648 template <>
649 void GenerateLoopNest<scf::ParallelOp>::doit(
650     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
651     ArrayRef<Attribute> iteratorTypes,
652     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
653                                   ValueRange)>
654         bodyBuilderFn,
655     Optional<LinalgLoopDistributionOptions> distributionOptions,
656     ArrayRef<StringRef> distributionTypes) {
657   SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
658   assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
659   // This function may be passed more iterator types than ranges.
660   assert(iteratorTypes.size() >= loopRanges.size() &&
661          "expected iterator type for all ranges");
662   iteratorTypes = iteratorTypes.take_front(loopRanges.size());
663   SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
664   unsigned numLoops = iteratorTypes.size();
665   ivs.reserve(numLoops);
666   lbsStorage.reserve(numLoops);
667   ubsStorage.reserve(numLoops);
668   stepsStorage.reserve(numLoops);
669 
670   // Get the loop lb, ub, and step.
671   unpackRanges(loopRanges, lbsStorage, ubsStorage, stepsStorage);
672 
673   // Modify the lb, ub, and step based on the distribution options.
674   SmallVector<DistributionMethod, 0> distributionMethod;
675   if (distributionOptions) {
676     auto &options = distributionOptions.getValue();
677     distributionMethod.assign(distributionOptions->distributionMethod.begin(),
678                               distributionOptions->distributionMethod.end());
679     SmallVector<Range, 2> parallelLoopRanges;
680     for (const auto &iteratorType : enumerate(iteratorTypes)) {
681       if (isParallelIterator(iteratorType.value()))
682         parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
683     }
684     if (distributionMethod.size() < parallelLoopRanges.size())
685       parallelLoopRanges.resize(distributionMethod.size());
686     SmallVector<ProcInfo, 2> procInfo =
687         options.procInfo(b, loc, parallelLoopRanges);
688     unsigned index = 0;
689     for (const auto &iteratorType : enumerate(iteratorTypes)) {
690       if (index >= procInfo.size())
691         break;
692       if (isParallelIterator(iteratorType.value())) {
693         unsigned i = iteratorType.index();
694         updateBoundsForCyclicDistribution(b, loc, procInfo[index].procId,
695                                           procInfo[index].nprocs, lbsStorage[i],
696                                           ubsStorage[i], stepsStorage[i]);
697         index++;
698       }
699     }
700   }
701   ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
702   generateParallelLoopNest(
703       b, loc, lbs, ubs, steps, iteratorTypes,
704       [&](OpBuilder &b, Location loc, ValueRange ivs) {
705         SmallVector<Value> operandValuesToUse =
706             linalgOp.getInputAndOutputOperands();
707         bodyBuilderFn(b, loc, ivs, operandValuesToUse);
708       },
709       ivs, distributionMethod);
710 
711   assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
712 }
713 
714 static Value fullyComposeAndAffineApply(OpBuilder &b, Location loc,
715                                         AffineExpr expr, ValueRange operands) {
716   AffineMap map = AffineMap::inferFromExprList({expr}).front();
717   SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
718   mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
719   canonicalizeMapAndOperands(&map, &normalizedOperands);
720   return b.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
721 }
722 
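// Illustrative example of the slice computed below (a sketch, assuming an
// identity indexing map, a tensor<?x128xf32> operand whose dimension 0 has
// dynamic size %d0, `lbs` = [%i, %c0], and a tile size %ts that tiles only
// dimension 0 and does not statically divide %d0): the resulting
// extract_slice uses offsets [%i, 0], sizes [min(%ts, %d0 - %i), 128], and
// strides [1, 1].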
723 Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
724                      ValueRange tileSizes, AffineMap map, ValueRange lbs,
725                      ValueRange ubs, ValueRange subShapeSizes) {
726   auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
727   assert(shapedType && "only shaped types can be tiled");
728   ArrayRef<int64_t> shape = shapedType.getShape();
729   int64_t rank = shapedType.getRank();
730 
731   // Construct a new subview / extract_slice for the tile.
732   SmallVector<OpFoldResult, 4> offsets, sizes, strides;
733   offsets.reserve(rank);
734   sizes.reserve(rank);
735   strides.reserve(rank);
736   for (unsigned r = 0; r < rank; ++r) {
737     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: for dim#" << r);
738     if (!isTiled(map.getSubMap({r}), tileSizes)) {
739       offsets.push_back(builder.getIndexAttr(0));
740       Value dim = createOrFoldDimOp(builder, loc, valueToTile, r);
741       sizes.push_back(getAsOpFoldResult(dim));
742       strides.push_back(builder.getIndexAttr(1));
743       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
744       continue;
745     }
746     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
747 
748     // Tiling creates a new slice at the proper index; the slice step is 1,
749     // i.e. the op does not subsample; stepping occurs in the enclosing loop.
750     auto m = map.getSubMap({r});
751     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: submap: " << m << "\n");
752     auto offset = applyMapToValues(builder, loc, m, lbs).front();
753     offsets.push_back(offset);
754     auto closedIntSize =
755         applyMapToValues(builder, loc, m, subShapeSizes).front();
756     // The resulting size needs to be turned back into a half-open interval.
757     AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
758     Value size =
759         fullyComposeAndAffineApply(builder, loc, s0 + 1, closedIntSize);
760     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: raw size: " << size << "\n");
761 
762     // The size of the subview / extract_slice should be trimmed to avoid
763     // out-of-bounds accesses, unless:
764     // a. We statically know the subshape size divides the shape size evenly.
765     // b. The subshape size is 1. According to the way the loops are set up,
766     //    tensors with "0" dimensions would never be constructed.
767     int64_t shapeSize = shape[r];
768     auto sizeCst = size.getDefiningOp<arith::ConstantIndexOp>();
769     auto hasTileSizeOne = sizeCst && sizeCst.value() == 1;
770     auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
771                          ((shapeSize % sizeCst.value()) == 0);
772     if (!hasTileSizeOne && !dividesEvenly) {
773       LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
774                               << ", size: " << size
775                               << ": make sure in bound with affine.min\n");
776 
777       AffineExpr dim0, dim1, dim2;
778       bindDims(builder.getContext(), dim0, dim1, dim2);
779 
780       // Get the dimension size for this dimension. We need to first calculate
781       // the max index and then add one. This is important because for
782       // convolution ops the affine map of the input window dimension has the
783       // form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
784       // dimension and `s0` is the stride. Directly using the dimension size of
785       // the output/filter window dimensions would give an incorrect result.
786       AffineMap minusOneMap =
787           AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 - 1}})
788               .front();
789       AffineMap plusOneMap =
790           AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 + 1}})
791               .front();
792       auto maxIndices = llvm::to_vector<8>(llvm::map_range(ubs, [&](Value ub) {
793         return makeComposedAffineApply(builder, loc, minusOneMap, {ub})
794             .getResult();
795       }));
796       Value maxIndex = applyMapToValues(builder, loc, m, maxIndices).front();
797       Value d = makeComposedAffineApply(builder, loc, plusOneMap, {maxIndex});
798 
799       // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
800       AffineMap minMap = AffineMap::inferFromExprList(
801                              {ArrayRef<AffineExpr>{dim0, dim1 - dim2}})
802                              .front();
803       SmallVector<Value, 4> operands{size, d, offset};
804       fullyComposeAffineMapAndOperands(&minMap, &operands);
805       canonicalizeMapAndOperands(&minMap, &operands);
806       size = builder.create<AffineMinOp>(loc, builder.getIndexType(), minMap,
807                                          operands);
808     }
809 
810     sizes.push_back(size);
811     LLVM_DEBUG(llvm::dbgs()
812                << "makeTiledShape: new offset: " << offset << "\n");
813     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
814     strides.push_back(builder.getIndexAttr(1));
815   }
816 
817   auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
818                       .Case([&](MemRefType) {
819                         return builder.create<memref::SubViewOp>(
820                             loc, valueToTile, offsets, sizes, strides);
821                       })
822                       .Case([&](RankedTensorType) {
823                         return makeComposedExtractSliceOp(
824                             builder, loc, valueToTile, offsets, sizes, strides);
825                       })
826                       .Default([](ShapedType) -> Operation * {
827                         llvm_unreachable("Unexpected shaped type");
828                       });
829   return sliceOp->getResult(0);
830 }
831 
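// For example, with `ivs` = [%i, %j] and `tileSizes` = [%ts0, 0, %ts2], i.e.
// only dimensions 0 and 2 are tiled, the computed offsets are [%i, 0, %j].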
832 SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
833                                       ValueRange ivs, ValueRange tileSizes) {
834   SmallVector<Value> offsets;
835   for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
836     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
837     bool isTiled = !isZero(tileSizes[idx]);
838     offsets.push_back(
839         isTiled ? ivs[idxIvs++]
840                 : b.create<arith::ConstantIndexOp>(loc, 0).getResult());
841     LLVM_DEBUG(llvm::dbgs()
842                << "computeTileOffsets: " << offsets.back() << "\n");
843   }
844   return offsets;
845 }
846 
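// For example, with `tileSizes` = [%ts, 0] and `sizeBounds` = [%d0, %d1], the
// computed (closed-interval) sizes are [%ts - 1, %d1 - 1].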
847 SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
848                                     ValueRange tileSizes,
849                                     ArrayRef<Value> sizeBounds) {
850   SmallVector<Value> sizes;
851   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
852     bool isTiled = !isZero(tileSizes[idx]);
853     // Before composing, we need to make the range a closed interval.
854     Value size = isTiled ? tileSizes[idx] : sizeBounds[idx];
855     AffineExpr d0 = getAffineDimExpr(0, b.getContext());
856     sizes.push_back(fullyComposeAndAffineApply(b, loc, d0 - 1, size));
857     LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
858   }
859   return sizes;
860 }
861 
862 SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
863                                       LinalgOp linalgOp,
864                                       ArrayRef<Value> valuesToTile,
865                                       ValueRange ivs, ValueRange tileSizes,
866                                       ArrayRef<Value> sizeBounds) {
867   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
868                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
869                            [](Value v) { return !isZero(v); })) &&
870          "expected as many ivs as non-zero sizes");
871 
872   // Construct (potentially temporary) mins and maxes on which to apply maps
873   // that define tile subshapes.
874   SmallVector<Value> lbs = computeTileOffsets(b, loc, ivs, tileSizes);
875   SmallVector<Value> subShapeSizes =
876       computeTileSizes(b, loc, ivs, tileSizes, sizeBounds);
877 
878   assert(static_cast<int64_t>(valuesToTile.size()) ==
879              linalgOp.getNumInputsAndOutputs() &&
880          "expected one value to tile for every operand");
881   SmallVector<Value, 4> tiledShapes;
882   tiledShapes.reserve(valuesToTile.size());
883   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
884     Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
885     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
886     AffineMap map = linalgOp.getTiedIndexingMap(opOperand);
887     // Use `opOperand` as is if it is not tiled and not an output tensor. Having
888     // an extract/insert slice pair for all output tensors simplifies follow-up
889     // transformations such as padding and bufferization since the
890     // extract/insert slice pairs make the accessed iteration argument
891     // subdomains explicit.
892     if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) {
893       tiledShapes.push_back(shapedOp);
894       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
895                               << opOperand->get().getType() << "\n");
896       continue;
897     }
898     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
899 
900     tiledShapes.push_back(makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs,
901                                          sizeBounds, subShapeSizes));
902   }
903 
904   return tiledShapes;
905 }
906 
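// For example, if the tile loop with induction variable %iv0 tiles dimension
// 0, every `%idx = linalg.index 0` inside the tiled op is rewritten via an
// affine.apply so that its users see `%idx + %iv0`; dimensions whose entry in
// `ivs` is null are left untouched.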
907 void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp,
908                                     ArrayRef<Value> ivs) {
909   if (tiledOp.hasIndexSemantics()) {
910     for (IndexOp indexOp : tiledOp.getBlock()->getOps<IndexOp>()) {
911       if (ivs[indexOp.dim()] == nullptr)
912         continue;
913       OpBuilder::InsertionGuard guard(b);
914       b.setInsertionPointAfter(indexOp);
915       AffineExpr index, offset;
916       bindDims(b.getContext(), index, offset);
917       AffineApplyOp applyOp = makeComposedAffineApply(
918           b, indexOp.getLoc(), index + offset,
919           ValueRange{indexOp.getResult(), ivs[indexOp.dim()]});
920       indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
921     }
922   }
923 }
924 
925 } // namespace linalg
926 } // namespace mlir
927