1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Utils/Utils.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/SCF.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineExprVisitor.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpImplementation.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Transforms/LoopUtils.h"
30 #include "llvm/Support/Debug.h"
31 
32 #define DEBUG_TYPE "linalg-utils"
33 
34 using namespace mlir;
35 using namespace mlir::linalg;
36 using namespace mlir::scf;
37 
38 static bool isZero(Value v) {
39   if (auto cst = v.getDefiningOp<ConstantIndexOp>())
40     return cst.getValue() == 0;
41   return false;
42 }
43 
44 namespace {
45 
46 // Helper visitor to determine whether an AffineExpr is tiled.
47 // This is achieved by traversing every AffineDimExpr with position `pos` and
48 // checking whether the corresponding `tileSizes[pos]` is non-zero.
49 // This also enforces only positive coefficients occur in multiplications.
50 //
51 // Example:
52 //   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
53 //
54 struct TileCheck : public AffineExprVisitor<TileCheck> {
55   TileCheck(ValueRange tileSizes) : isTiled(false), tileSizes(tileSizes) {}
56 
57   void visitDimExpr(AffineDimExpr expr) {
58     isTiled |= !isZero(tileSizes[expr.getPosition()]);
59   }
60   void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
61     visit(expr.getLHS());
62     visit(expr.getRHS());
63     if (expr.getKind() == mlir::AffineExprKind::Mul)
64       assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
65              "nonpositive multiplying coefficient");
66   }
67   bool isTiled;
68   ValueRange tileSizes;
69 };
70 
71 } // namespace
72 
73 static bool isTiled(AffineExpr expr, ValueRange tileSizes) {
74   if (!expr)
75     return false;
76   TileCheck t(tileSizes);
77   t.visit(expr);
78   return t.isTiled;
79 }
80 
81 // Checks whether the `map  varies with respect to a non-zero `tileSize`.
82 static bool isTiled(AffineMap map, ValueRange tileSizes) {
83   if (!map)
84     return false;
85   for (unsigned r = 0; r < map.getNumResults(); ++r)
86     if (isTiled(map.getResult(r), tileSizes))
87       return true;
88   return false;
89 }
90 
91 Optional<RegionMatcher::BinaryOpKind>
92 RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
93   auto &region = op.region();
94   if (!llvm::hasSingleElement(region))
95     return llvm::None;
96 
97   Block &block = region.front();
98   if (block.getNumArguments() != 2 ||
99       !block.getArgument(0).getType().isSignlessIntOrFloat() ||
100       !block.getArgument(1).getType().isSignlessIntOrFloat())
101     return llvm::None;
102 
103   auto &ops = block.getOperations();
104   if (!llvm::hasSingleElement(block.without_terminator()))
105     return llvm::None;
106 
107   using mlir::matchers::m_Val;
108   auto a = m_Val(block.getArgument(0));
109   auto b = m_Val(block.getArgument(1));
110 
111   auto addPattern = m_Op<linalg::YieldOp>(m_Op<AddIOp>(a, b));
112   if (addPattern.match(&ops.back()))
113     return BinaryOpKind::IAdd;
114 
115   return llvm::None;
116 }
117 
118 bool mlir::linalg::isParallelIteratorType(Attribute attr) {
119   if (auto strAttr = attr.dyn_cast<StringAttr>()) {
120     return strAttr.getValue() == getParallelIteratorTypeName();
121   }
122   return false;
123 }
124 
125 bool mlir::linalg::isReductionIteratorType(Attribute attr) {
126   if (auto strAttr = attr.dyn_cast<StringAttr>()) {
127     return strAttr.getValue() == getReductionIteratorTypeName();
128   }
129   return false;
130 }
131 
132 bool mlir::linalg::isWindowIteratorType(Attribute attr) {
133   if (auto strAttr = attr.dyn_cast<StringAttr>()) {
134     return strAttr.getValue() == getWindowIteratorTypeName();
135   }
136   return false;
137 }
138 
/// Explicit instantiation of loop nest generator for different loop types.
/// The `doit` member of each of these instantiations is provided by an
/// explicit specialization defined later in this file.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
template struct mlir::linalg::GenerateLoopNest<TiledLoopOp>;
144 
145 /// Given a list of subview ranges, extract individual values for lower, upper
146 /// bounds and steps and put them into the corresponding vectors.
147 static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
148                          SmallVectorImpl<Value> &ubs,
149                          SmallVectorImpl<Value> &steps) {
150   for (Range range : ranges) {
151     lbs.emplace_back(range.offset);
152     ubs.emplace_back(range.size);
153     steps.emplace_back(range.stride);
154   }
155 }
156 
157 namespace mlir {
158 namespace linalg {
159 
160 /// If `size` comes from an AffineMinOp and one of the values of AffineMinOp
161 /// is a constant then return a new value set to the smallest such constant.
162 /// Otherwise returngetSmallestBoundingIndex nullptr.
163 IntegerAttr getSmallestBoundingIndex(Value size) {
164   Optional<int64_t> boundingConst = {};
165   if (auto affineMinOp = size.getDefiningOp<AffineMinOp>()) {
166     for (auto e : affineMinOp.getAffineMap().getResults())
167       if (auto cst = e.dyn_cast<AffineConstantExpr>())
168         boundingConst = boundingConst
169                             ? std::min(boundingConst.getValue(), cst.getValue())
170                             : cst.getValue();
171   } else if (auto constIndexOp = size.getDefiningOp<ConstantOp>()) {
172     if (constIndexOp.getType().isa<IndexType>())
173       boundingConst = constIndexOp.value().cast<IntegerAttr>().getInt();
174   } else if (auto affineApplyOp = size.getDefiningOp<AffineApplyOp>()) {
175     if (auto cExpr = affineApplyOp.getAffineMap()
176                          .getResult(0)
177                          .dyn_cast<AffineConstantExpr>())
178       boundingConst = cExpr.getValue();
179   } else if (auto dimOp = size.getDefiningOp<memref::DimOp>()) {
180     auto shape = dimOp.memrefOrTensor().getType().dyn_cast<ShapedType>();
181     if (auto constOp = dimOp.index().getDefiningOp<ConstantOp>()) {
182       if (auto indexAttr = constOp.value().dyn_cast<IntegerAttr>()) {
183         auto dimIndex = indexAttr.getInt();
184         if (!shape.isDynamicDim(dimIndex)) {
185           boundingConst = shape.getShape()[dimIndex];
186         }
187       }
188     }
189   }
190   if (boundingConst && *boundingConst >= 0)
191     return Builder(size.getContext()).getIndexAttr(*boundingConst);
192   return nullptr;
193 }
194 
195 /// Specialization to build an scf "for" nest.
196 template <>
197 void GenerateLoopNest<scf::ForOp>::doit(
198     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
199     ArrayRef<Attribute> iteratorTypes,
200     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
201                                   ValueRange)>
202         bodyBuilderFn,
203     Optional<LinalgLoopDistributionOptions> distributionOptions,
204     ArrayRef<StringRef> distributionTypes) {
205   SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
206   // Create procInfo so it dominates loops, if appropriate.
207   SmallVector<ProcInfo, 4> procInfo;
208   SmallVector<DistributionMethod, 0> distributionMethod;
209   if (distributionOptions.hasValue()) {
210     // Collect loop ranges for parallel dimensions.
211     SmallVector<Range, 2> parallelLoopRanges;
212     for (auto iteratorType : enumerate(iteratorTypes))
213       if (isParallelIteratorType(iteratorType.value()))
214         parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
215 
216     // Get their distribution schemes.
217     distributionMethod = distributionOptions->distributionMethod;
218     if (distributionMethod.size() < parallelLoopRanges.size())
219       parallelLoopRanges.resize(distributionMethod.size());
220     procInfo = distributionOptions->procInfo(b, loc, parallelLoopRanges);
221   }
222 
223   SmallVector<Value, 4> lbs, ubs, steps;
224   unpackRanges(loopRanges, lbs, ubs, steps);
225   LoopNest loopNest = mlir::scf::buildLoopNest(
226       b, loc, lbs, ubs, steps, iterArgInitValues, bodyBuilderFn);
227 
228   if (!distributionOptions || loopNest.loops.empty())
229     return;
230 
231   // Filter out scf.for loops that were created out of parallel dimensions.
232   SmallVector<scf::ForOp, 4> loops;
233   for (auto iteratorType : enumerate(iteratorTypes))
234     if (isParallelIteratorType(iteratorType.value()))
235       loops.push_back(loopNest.loops[iteratorType.index()]);
236 
237   // Distribute - only supports cyclic distribution for now.
238   for (auto it : llvm::zip(loops, procInfo, distributionMethod))
239     if (std::get<2>(it) == DistributionMethod::Cyclic)
240       mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId,
241                             std::get<1>(it).nprocs);
242 }
243 
244 /// Specialization to build affine "for" nest.
245 template <>
246 void GenerateLoopNest<AffineForOp>::doit(
247     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
248     ArrayRef<Attribute> iteratorTypes,
249     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
250                                   ValueRange)>
251         bodyBuilderFn,
252     Optional<LinalgLoopDistributionOptions>, ArrayRef<StringRef>) {
253   SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
254   assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
255   SmallVector<Value, 4> lbs, ubs, steps;
256   unpackRanges(loopRanges, lbs, ubs, steps);
257 
258   // Affine loops require constant steps.
259   SmallVector<int64_t, 4> constantSteps;
260   constantSteps.reserve(steps.size());
261   for (Value v : steps) {
262     auto op = v.getDefiningOp<ConstantIndexOp>();
263     assert(op && "Affine loops require constant steps");
264     constantSteps.push_back(op.getValue());
265   }
266 
267   mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
268                             [&](OpBuilder &b, Location loc, ValueRange ivs) {
269                               bodyBuilderFn(b, loc, ivs, {});
270                             });
271 }
272 
273 /// Specialization to build an linalg.tiled_loop
274 template <>
275 void GenerateLoopNest<TiledLoopOp>::doit(
276     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
277     ArrayRef<Attribute> iteratorTypes,
278     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
279                                   ValueRange)>
280         bodyBuilderFn,
281     Optional<LinalgLoopDistributionOptions> distributionOptions,
282     ArrayRef<StringRef> distributionTypes) {
283   SmallVector<ProcInfo, 2> procInfo;
284   SmallVector<Value, 4> lbs, ubs, steps;
285   unpackRanges(loopRanges, lbs, ubs, steps);
286 
287   auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
288                               ValueRange ivs, ValueRange inputs,
289                               ValueRange outputs) {
290     SmallVector<Value> outputTensors = linalgOp.getOutputTensorOperands();
291     scf::ValueVector results =
292         bodyBuilderFn(nestedBuilder, nestedLoc, ivs, outputTensors);
293     nestedBuilder.create<linalg::YieldOp>(nestedLoc, results);
294   };
295 
296   SmallVector<Value> inputOperands = linalgOp.getInputOperands();
297   SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
298   auto tiledLoop =
299       b.create<TiledLoopOp>(loc, lbs, ubs, steps, inputOperands, outputOperands,
300                             b.getArrayAttr(iteratorTypes), wrappedBuilderFn);
301   if (!distributionTypes.empty())
302     tiledLoop.setDistributionTypes(b, distributionTypes);
303 
304   // Replace inputs/outputs with the corresponding region args.
305   auto isInsideTiledLoop = [&](OpOperand &operand) {
306     return operand.getOwner()->getBlock() == tiledLoop.getBody();
307   };
308   for (auto it : llvm::zip(inputOperands, tiledLoop.getRegionInputArgs()))
309     std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop);
310   for (auto it : llvm::zip(outputOperands, tiledLoop.getRegionOutputArgs()))
311     std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop);
312 }
313 
314 /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
315 void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
316                                        Value nprocs, Value &lb, Value &ub,
317                                        Value &step) {
318   AffineExpr d0, d1;
319   bindDims(b.getContext(), d0, d1);
320   AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
321   lb = makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
322   step = makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
323 }
324 
/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
///
/// `distributionMethod` gives, in order, the distribution of the parallel
/// loops. It is only consumed when parallel loops are generated and may cover
/// fewer loops than there are parallel loops; once it is exhausted, the
/// remaining parallel loops are emitted as plain scf.parallel. For a group of
/// consecutive loops sharing the same method:
///   - Cyclic: an scf.parallel over the given bounds is emitted;
///   - CyclicNumProcsGeNumIters: no loop is emitted; each lower bound is used
///     as the induction variable and the inner nest is wrapped in an scf.if
///     checking `lb < ub` for every loop of the group;
///   - CyclicNumProcsEqNumIters: no loop and no check; each lower bound is
///     used directly as the induction variable.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<Attribute> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage,
    ArrayRef<DistributionMethod> distributionMethod = {}) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // Find the outermost parallel loops and drop their types from the list.
  // `nOuterPar` is the number of leading parallel iterator types.
  unsigned nLoops = iteratorTypes.size();
  unsigned nOuterPar =
      nLoops - iteratorTypes.drop_while(isParallelIteratorType).size();

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse. Note that we wouldn't have dropped anything from `iteratorTypes`
  // in this case. `distributionMethod` is forwarded untouched because it only
  // applies to parallel loops.
  if (nOuterPar == 0) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(b, loc, lbs.drop_front(), ubs.drop_front(),
                                   steps.drop_front(),
                                   iteratorTypes.drop_front(), bodyBuilderFn,
                                   ivStorage, distributionMethod);
        });
    return;
  }
  // No (remaining) distribution requested for the leading parallel loops.
  if (distributionMethod.empty()) {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(nOuterPar), ubs.take_front(nOuterPar),
        steps.take_front(nOuterPar),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(nOuterPar),
              ubs.drop_front(nOuterPar), steps.drop_front(nOuterPar),
              iteratorTypes.drop_front(nOuterPar), bodyBuilderFn, ivStorage,
              // `distributionMethod` is empty on this path (and `nOuterPar` is
              // non-zero), so this always forwards an empty list.
              (distributionMethod.size() < nOuterPar)
                  ? ArrayRef<DistributionMethod>()
                  : distributionMethod.drop_front(nOuterPar));
        });
    return;
  }

  // Process all consecutive similarly distributed loops simultaneously.
  // `numProcessed` never exceeds `nOuterPar` nor `distributionMethod.size()`.
  DistributionMethod methodToUse = distributionMethod[0];
  unsigned numProcessed = 1;
  for (unsigned i = 1; i < nOuterPar && i < distributionMethod.size(); ++i) {
    if (distributionMethod[i] != methodToUse)
      break;
    numProcessed++;
  }

  switch (methodToUse) {
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
              (distributionMethod.size() < numProcessed)
                  ? ArrayRef<DistributionMethod>()
                  : distributionMethod.drop_front(numProcessed));
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    // No loop is created: the lower bounds act as the induction variables.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(
          b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
          steps.drop_front(numProcessed),
          iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
          distributionMethod.drop_front(numProcessed));
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        bodyBuilderFn, ivStorage, distributionMethod.drop_front(numProcessed));
    return;
  }
}
447 
448 /// Specialization for generating a mix of parallel and sequential scf loops.
449 template <>
450 void GenerateLoopNest<scf::ParallelOp>::doit(
451     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
452     ArrayRef<Attribute> iteratorTypes,
453     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
454                                   ValueRange)>
455         bodyBuilderFn,
456     Optional<LinalgLoopDistributionOptions> distributionOptions,
457     ArrayRef<StringRef> distributionTypes) {
458   SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
459   assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
460   // This function may be passed more iterator types than ranges.
461   assert(iteratorTypes.size() >= loopRanges.size() &&
462          "expected iterator type for all ranges");
463   iteratorTypes = iteratorTypes.take_front(loopRanges.size());
464   SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
465   unsigned numLoops = iteratorTypes.size();
466   ivs.reserve(numLoops);
467   lbsStorage.reserve(numLoops);
468   ubsStorage.reserve(numLoops);
469   stepsStorage.reserve(numLoops);
470 
471   // Get the loop lb, ub, and step.
472   unpackRanges(loopRanges, lbsStorage, ubsStorage, stepsStorage);
473 
474   // Modify the lb, ub, and step based on the distribution options.
475   SmallVector<DistributionMethod, 0> distributionMethod;
476   if (distributionOptions) {
477     auto &options = distributionOptions.getValue();
478     distributionMethod.assign(distributionOptions->distributionMethod.begin(),
479                               distributionOptions->distributionMethod.end());
480     SmallVector<Range, 2> parallelLoopRanges;
481     for (auto iteratorType : enumerate(iteratorTypes)) {
482       if (isParallelIteratorType(iteratorType.value()))
483         parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
484     }
485     if (distributionMethod.size() < parallelLoopRanges.size())
486       parallelLoopRanges.resize(distributionMethod.size());
487     SmallVector<ProcInfo, 2> procInfo =
488         options.procInfo(b, loc, parallelLoopRanges);
489     unsigned index = 0;
490     for (auto iteratorType : enumerate(iteratorTypes)) {
491       if (index >= procInfo.size())
492         break;
493       if (isParallelIteratorType(iteratorType.value())) {
494         unsigned i = iteratorType.index();
495         updateBoundsForCyclicDistribution(b, loc, procInfo[index].procId,
496                                           procInfo[index].nprocs, lbsStorage[i],
497                                           ubsStorage[i], stepsStorage[i]);
498         index++;
499       }
500     }
501   }
502   ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
503   generateParallelLoopNest(
504       b, loc, lbs, ubs, steps, iteratorTypes,
505       [&](OpBuilder &b, Location loc, ValueRange ivs) {
506         bodyBuilderFn(b, loc, ivs, {});
507       },
508       ivs, distributionMethod);
509 
510   assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
511 }
512 
513 SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
514                                       LinalgOp linalgOp,
515                                       ArrayRef<Value> valuesToTile,
516                                       ValueRange ivs, ValueRange tileSizes,
517                                       ArrayRef<Value> sizeBounds) {
518   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
519                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
520                            [](Value v) { return !isZero(v); })) &&
521          "expected as many ivs as non-zero sizes");
522 
523   // Construct (potentially temporary) mins and maxes on which to apply maps
524   // that define tile subshapes.
525   SmallVector<Value, 8> lbs, subShapeSizes;
526   for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
527     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
528     bool isTiled = !isZero(tileSizes[idx]);
529     lbs.push_back(isTiled ? ivs[idxIvs++]
530                           : (Value)b.create<ConstantIndexOp>(loc, 0));
531     // Before composing, we need to make range a closed interval.
532     Value size = isTiled ? tileSizes[idx] : sizeBounds[idx];
533     AffineExpr d0 = getAffineDimExpr(0, b.getContext());
534     subShapeSizes.push_back(makeComposedAffineApply(b, loc, d0 - 1, size));
535     LLVM_DEBUG(llvm::dbgs() << "lb: " << lbs.back() << "\n");
536     LLVM_DEBUG(llvm::dbgs() << "size: " << subShapeSizes.back() << "\n");
537   }
538 
539   assert(static_cast<int64_t>(valuesToTile.size()) ==
540              linalgOp.getNumInputsAndOutputs() &&
541          "expected one value to tile for every operand");
542   MLIRContext *context = b.getContext();
543   SmallVector<Value, 4> tiledShapes;
544   tiledShapes.reserve(valuesToTile.size());
545   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
546     Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
547     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
548     int64_t rank = linalgOp.getRank(opOperand);
549     ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
550     AffineMap map = linalgOp.getTiedIndexingMap(opOperand);
551     // If the shape is not tiled, we can use it as is.
552     if (!isTiled(map, tileSizes)) {
553       tiledShapes.push_back(shapedOp);
554       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
555                               << opOperand->get().getType() << "\n");
556       continue;
557     }
558     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
559 
560     // Construct a new subview / extract_slice for the tile.
561     SmallVector<OpFoldResult, 4> offsets, sizes, strides;
562     offsets.reserve(rank);
563     sizes.reserve(rank);
564     strides.reserve(rank);
565     for (unsigned r = 0; r < rank; ++r) {
566       LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for dim#" << r);
567       if (!isTiled(map.getSubMap({r}), tileSizes)) {
568         offsets.push_back(b.getIndexAttr(0));
569         Value dim = b.createOrFold<memref::DimOp>(loc, shapedOp, r);
570         sizes.push_back(dim);
571         strides.push_back(b.getIndexAttr(1));
572         LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
573         continue;
574       }
575       LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
576 
577       // Tiling creates a new slice at the proper index, the slice step is 1
578       // (i.e. the op does not subsample, stepping occurs in the loop).
579       auto m = map.getSubMap({r});
580       LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: submap: " << map << "\n");
581       auto offset = applyMapToValues(b, loc, m, lbs).front();
582       offsets.push_back(offset);
583       auto closedIntSize = applyMapToValues(b, loc, m, subShapeSizes).front();
584       // Resulting size needs to be made half open interval again.
585       AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
586       Value size = makeComposedAffineApply(b, loc, s0 + 1, closedIntSize);
587       LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: raw size: " << size << "\n");
588 
589       // The size of the subview / extract_slice should be trimmed to avoid
590       // out-of-bounds accesses, unless we statically know the subshape size
591       // divides the shape size evenly.
592       int64_t shapeSize = shape[r];
593       auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
594       if (ShapedType::isDynamic(shapeSize) || !sizeCst ||
595           (shapeSize % sizeCst.getValue()) != 0) {
596         LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: shapeSize=" << shapeSize
597                                 << ", size: " << size
598                                 << ": make sure in bound with affine.min\n");
599         AffineExpr dim0, dim1, dim2;
600         bindDims(context, dim0, dim1, dim2);
601         // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
602         AffineMap minMap =
603             AffineMap::inferFromExprList(
604                 ArrayRef<ArrayRef<AffineExpr>>{{dim0, dim1 - dim2}})
605                 .front();
606         Value d = b.create<memref::DimOp>(loc, shapedOp, r);
607         SmallVector<Value, 4> operands{size, d, offset};
608         fullyComposeAffineMapAndOperands(&minMap, &operands);
609         size = b.create<AffineMinOp>(loc, b.getIndexType(), minMap, operands);
610       }
611 
612       sizes.push_back(size);
613       LLVM_DEBUG(llvm::dbgs()
614                  << "makeTiledShapes: new offset: " << offset << "\n");
615       LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: new size: " << size << "\n");
616       strides.push_back(b.getIndexAttr(1));
617     }
618 
619     if (opOperand->get().getType().isa<MemRefType>())
620       tiledShapes.push_back(
621           b.create<memref::SubViewOp>(loc, shapedOp, offsets, sizes, strides));
622     else
623       tiledShapes.push_back(b.create<tensor::ExtractSliceOp>(
624           loc, shapedOp, offsets, sizes, strides));
625   }
626 
627   return tiledShapes;
628 }
629 
630 } // namespace linalg
631 } // namespace mlir
632