1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Utils/Utils.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/SCF.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/Dialect/Utils/StaticValueUtils.h"
24 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/AffineExprVisitor.h"
26 #include "mlir/IR/AffineMap.h"
27 #include "mlir/IR/Matchers.h"
28 #include "mlir/IR/OpImplementation.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Transforms/LoopUtils.h"
31 #include "llvm/Support/Debug.h"
32 
33 #define DEBUG_TYPE "linalg-utils"
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 using namespace mlir::scf;
38 
39 static bool isZero(Value v) {
40   if (auto cst = v.getDefiningOp<ConstantIndexOp>())
41     return cst.getValue() == 0;
42   return false;
43 }
44 
45 namespace {
46 
47 // Helper visitor to determine whether an AffineExpr is tiled.
48 // This is achieved by traversing every AffineDimExpr with position `pos` and
49 // checking whether the corresponding `tileSizes[pos]` is non-zero.
50 // This also enforces only positive coefficients occur in multiplications.
51 //
52 // Example:
53 //   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
54 //
55 struct TileCheck : public AffineExprVisitor<TileCheck> {
56   TileCheck(ValueRange tileSizes) : isTiled(false), tileSizes(tileSizes) {}
57 
58   void visitDimExpr(AffineDimExpr expr) {
59     isTiled |= !isZero(tileSizes[expr.getPosition()]);
60   }
61   void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
62     visit(expr.getLHS());
63     visit(expr.getRHS());
64     if (expr.getKind() == mlir::AffineExprKind::Mul)
65       assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
66              "nonpositive multiplying coefficient");
67   }
68   bool isTiled;
69   ValueRange tileSizes;
70 };
71 
72 } // namespace
73 
74 static bool isTiled(AffineExpr expr, ValueRange tileSizes) {
75   if (!expr)
76     return false;
77   TileCheck t(tileSizes);
78   t.visit(expr);
79   return t.isTiled;
80 }
81 
82 // Checks whether the `map  varies with respect to a non-zero `tileSize`.
83 static bool isTiled(AffineMap map, ValueRange tileSizes) {
84   if (!map)
85     return false;
86   for (unsigned r = 0; r < map.getNumResults(); ++r)
87     if (isTiled(map.getResult(r), tileSizes))
88       return true;
89   return false;
90 }
91 
92 Optional<RegionMatcher::BinaryOpKind>
93 RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
94   auto &region = op.region();
95   if (!llvm::hasSingleElement(region))
96     return llvm::None;
97 
98   Block &block = region.front();
99   if (block.getNumArguments() != 2 ||
100       !block.getArgument(0).getType().isSignlessIntOrFloat() ||
101       !block.getArgument(1).getType().isSignlessIntOrFloat())
102     return llvm::None;
103 
104   auto &ops = block.getOperations();
105   if (!llvm::hasSingleElement(block.without_terminator()))
106     return llvm::None;
107 
108   using mlir::matchers::m_Val;
109   auto a = m_Val(block.getArgument(0));
110   auto b = m_Val(block.getArgument(1));
111 
112   auto addPattern = m_Op<linalg::YieldOp>(m_Op<AddIOp>(a, b));
113   if (addPattern.match(&ops.back()))
114     return BinaryOpKind::IAdd;
115 
116   return llvm::None;
117 }
118 
/// Explicit instantiation of loop nest generator for different loop types.
/// The `doit` member of each of these instantiations is provided by an
/// explicit specialization further down in this file.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
template struct mlir::linalg::GenerateLoopNest<TiledLoopOp>;
124 
125 /// Given a list of subview ranges, extract individual values for lower, upper
126 /// bounds and steps and put them into the corresponding vectors.
127 static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
128                          SmallVectorImpl<Value> &ubs,
129                          SmallVectorImpl<Value> &steps) {
130   for (Range range : ranges) {
131     lbs.emplace_back(range.offset);
132     ubs.emplace_back(range.size);
133     steps.emplace_back(range.stride);
134   }
135 }
136 
137 namespace mlir {
138 namespace linalg {
139 
140 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
141 /// the type of `source`.
142 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
143   if (source.getType().isa<UnrankedMemRefType, MemRefType>())
144     return b.createOrFold<memref::DimOp>(loc, source, dim);
145   if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
146     return b.createOrFold<tensor::DimOp>(loc, source, dim);
147   llvm_unreachable("Expected MemRefType or TensorType");
148 }
149 
150 /// Given an operation, retrieves the value of each dynamic dimension through
151 /// constructing the necessary DimOp operators.
152 SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b) {
153   SmallVector<Value, 4> dynOperands;
154   auto shapedType = val.getType().cast<ShapedType>();
155   for (auto dim : llvm::enumerate(shapedType.getShape())) {
156     if (dim.value() == ShapedType::kDynamicSize)
157       dynOperands.push_back(createOrFoldDimOp(b, loc, val, dim.index()));
158   }
159   return dynOperands;
160 }
161 
162 /// If `size` comes from an AffineMinOp and one of the values of AffineMinOp
163 /// is a constant then return a new value set to the smallest such constant.
164 /// Otherwise returngetSmallestBoundingIndex nullptr.
165 IntegerAttr getSmallestBoundingIndex(Value size) {
166   Optional<int64_t> boundingConst = {};
167   if (auto affineMinOp = size.getDefiningOp<AffineMinOp>()) {
168     for (auto e : affineMinOp.getAffineMap().getResults())
169       if (auto cst = e.dyn_cast<AffineConstantExpr>())
170         boundingConst = boundingConst
171                             ? std::min(boundingConst.getValue(), cst.getValue())
172                             : cst.getValue();
173   } else if (auto constIndexOp = size.getDefiningOp<ConstantOp>()) {
174     if (constIndexOp.getType().isa<IndexType>())
175       boundingConst = constIndexOp.value().cast<IntegerAttr>().getInt();
176   } else if (auto affineApplyOp = size.getDefiningOp<AffineApplyOp>()) {
177     if (auto cExpr = affineApplyOp.getAffineMap()
178                          .getResult(0)
179                          .dyn_cast<AffineConstantExpr>())
180       boundingConst = cExpr.getValue();
181   } else if (auto dimOp = size.getDefiningOp<tensor::DimOp>()) {
182     auto shape = dimOp.source().getType().dyn_cast<ShapedType>();
183     if (auto constOp = dimOp.index().getDefiningOp<ConstantOp>()) {
184       if (auto indexAttr = constOp.value().dyn_cast<IntegerAttr>()) {
185         auto dimIndex = indexAttr.getInt();
186         if (!shape.isDynamicDim(dimIndex)) {
187           boundingConst = shape.getShape()[dimIndex];
188         }
189       }
190     }
191   }
192   if (boundingConst && *boundingConst >= 0)
193     return Builder(size.getContext()).getIndexAttr(*boundingConst);
194   return nullptr;
195 }
196 
197 /// Specialization to build an scf "for" nest.
198 template <>
199 void GenerateLoopNest<scf::ForOp>::doit(
200     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
201     ArrayRef<Attribute> iteratorTypes,
202     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
203                                   ValueRange)>
204         bodyBuilderFn,
205     Optional<LinalgLoopDistributionOptions> distributionOptions,
206     ArrayRef<StringRef> distributionTypes) {
207   SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
208   // Create procInfo so it dominates loops, if appropriate.
209   SmallVector<ProcInfo, 4> procInfo;
210   SmallVector<DistributionMethod, 0> distributionMethod;
211   if (distributionOptions.hasValue()) {
212     // Collect loop ranges for parallel dimensions.
213     SmallVector<Range, 2> parallelLoopRanges;
214     for (auto iteratorType : enumerate(iteratorTypes))
215       if (isParallelIterator(iteratorType.value()))
216         parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
217 
218     // Get their distribution schemes.
219     distributionMethod = distributionOptions->distributionMethod;
220     if (distributionMethod.size() < parallelLoopRanges.size())
221       parallelLoopRanges.resize(distributionMethod.size());
222     procInfo = distributionOptions->procInfo(b, loc, parallelLoopRanges);
223   }
224 
225   SmallVector<Value, 4> lbs, ubs, steps;
226   unpackRanges(loopRanges, lbs, ubs, steps);
227   LoopNest loopNest = mlir::scf::buildLoopNest(
228       b, loc, lbs, ubs, steps, iterArgInitValues, bodyBuilderFn);
229 
230   if (!distributionOptions || loopNest.loops.empty())
231     return;
232 
233   // Filter out scf.for loops that were created out of parallel dimensions.
234   SmallVector<scf::ForOp, 4> loops;
235   for (auto iteratorType : enumerate(iteratorTypes))
236     if (isParallelIterator(iteratorType.value()))
237       loops.push_back(loopNest.loops[iteratorType.index()]);
238 
239   // Distribute - only supports cyclic distribution for now.
240   for (auto it : llvm::zip(loops, procInfo, distributionMethod))
241     if (std::get<2>(it) == DistributionMethod::Cyclic)
242       mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId,
243                             std::get<1>(it).nprocs);
244 }
245 
246 /// Specialization to build affine "for" nest.
247 template <>
248 void GenerateLoopNest<AffineForOp>::doit(
249     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
250     ArrayRef<Attribute> iteratorTypes,
251     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
252                                   ValueRange)>
253         bodyBuilderFn,
254     Optional<LinalgLoopDistributionOptions>, ArrayRef<StringRef>) {
255   SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
256   assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
257   SmallVector<Value, 4> lbs, ubs, steps;
258   unpackRanges(loopRanges, lbs, ubs, steps);
259 
260   // Affine loops require constant steps.
261   SmallVector<int64_t, 4> constantSteps;
262   constantSteps.reserve(steps.size());
263   for (Value v : steps) {
264     auto op = v.getDefiningOp<ConstantIndexOp>();
265     assert(op && "Affine loops require constant steps");
266     constantSteps.push_back(op.getValue());
267   }
268 
269   mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
270                             [&](OpBuilder &b, Location loc, ValueRange ivs) {
271                               bodyBuilderFn(b, loc, ivs, {});
272                             });
273 }
274 
275 /// Specialization to build an linalg.tiled_loop
276 template <>
277 void GenerateLoopNest<TiledLoopOp>::doit(
278     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
279     ArrayRef<Attribute> iteratorTypes,
280     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
281                                   ValueRange)>
282         bodyBuilderFn,
283     Optional<LinalgLoopDistributionOptions> distributionOptions,
284     ArrayRef<StringRef> distributionTypes) {
285   SmallVector<ProcInfo, 2> procInfo;
286   SmallVector<Value, 4> lbs, ubs, steps;
287   unpackRanges(loopRanges, lbs, ubs, steps);
288 
289   auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
290                               ValueRange ivs, ValueRange inputs,
291                               ValueRange outputs) {
292     SmallVector<Value> outputTensors = linalgOp.getOutputTensorOperands();
293     scf::ValueVector results =
294         bodyBuilderFn(nestedBuilder, nestedLoc, ivs, outputTensors);
295     nestedBuilder.create<linalg::YieldOp>(nestedLoc, results);
296   };
297 
298   SmallVector<Value> inputOperands = linalgOp.getInputOperands();
299   SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
300   auto tiledLoop =
301       b.create<TiledLoopOp>(loc, lbs, ubs, steps, inputOperands, outputOperands,
302                             b.getArrayAttr(iteratorTypes), wrappedBuilderFn);
303   if (!distributionTypes.empty())
304     tiledLoop.setDistributionTypes(b, distributionTypes);
305 
306   // Replace inputs/outputs with the corresponding region args.
307   auto isInsideTiledLoop = [&](OpOperand &operand) {
308     return operand.getOwner()->getBlock() == tiledLoop.getBody();
309   };
310   for (auto it : llvm::zip(inputOperands, tiledLoop.getRegionInputArgs()))
311     std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop);
312   for (auto it : llvm::zip(outputOperands, tiledLoop.getRegionOutputArgs()))
313     std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop);
314 }
315 
316 /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
317 void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
318                                        Value nprocs, Value &lb, Value &ub,
319                                        Value &step) {
320   AffineExpr d0, d1;
321   bindDims(b.getContext(), d0, d1);
322   AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
323   lb = makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
324   step = makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
325 }
326 
/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
///
/// `distributionMethod` lists, in order, how successive parallel loops are
/// distributed. Once it is exhausted (or when it is empty), outer parallel
/// loops are emitted as a plain scf.parallel. Consecutive outer parallel
/// loops sharing the same method are handled together:
///   - Cyclic: a single scf.parallel over those loops;
///   - CyclicNumProcsGeNumIters: no loop is built. The lower bounds are
///     appended to `ivStorage` in place of induction variables, and the rest
///     of the nest is guarded by an scf.if checking `lb < ub` (signed) for
///     each of them;
///   - CyclicNumProcsEqNumIters: no loop and no guard. The lower bounds are
///     appended to `ivStorage` directly.
/// In this file, the scf::ParallelOp specialization of GenerateLoopNest is
/// the caller that passes distribution methods. It rewrites the bounds of
/// the distributed loops per processor before calling this function.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<Attribute> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage,
    ArrayRef<DistributionMethod> distributionMethod = {}) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // Count the outermost (leading) parallel loops.
  unsigned nLoops = iteratorTypes.size();
  unsigned nOuterPar =
      nLoops - iteratorTypes.drop_while(isParallelIterator).size();

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse. Note that we wouldn't have dropped anything from `iteratorTypes`
  // in this case. The distribution methods only describe parallel loops, so
  // they are forwarded unchanged.
  if (nOuterPar == 0) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(b, loc, lbs.drop_front(), ubs.drop_front(),
                                   steps.drop_front(),
                                   iteratorTypes.drop_front(), bodyBuilderFn,
                                   ivStorage, distributionMethod);
        });
    return;
  }
  if (distributionMethod.empty()) {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(nOuterPar), ubs.take_front(nOuterPar),
        steps.take_front(nOuterPar),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          // `distributionMethod` is empty on this path, so this always
          // forwards an empty list.
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(nOuterPar),
              ubs.drop_front(nOuterPar), steps.drop_front(nOuterPar),
              iteratorTypes.drop_front(nOuterPar), bodyBuilderFn, ivStorage,
              (distributionMethod.size() < nOuterPar)
                  ? ArrayRef<DistributionMethod>()
                  : distributionMethod.drop_front(nOuterPar));
        });
    return;
  }

  // Process all consecutive similarly distributed loops simultaneously.
  // `numProcessed` is bounded by both the number of outer parallel loops and
  // the number of remaining distribution methods.
  DistributionMethod methodToUse = distributionMethod[0];
  unsigned numProcessed = 1;
  for (unsigned i = 1; i < nOuterPar && i < distributionMethod.size(); ++i) {
    if (distributionMethod[i] != methodToUse)
      break;
    numProcessed++;
  }

  switch (methodToUse) {
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
              (distributionMethod.size() < numProcessed)
                  ? ArrayRef<DistributionMethod>()
                  : distributionMethod.drop_front(numProcessed));
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    // Each processor executes at most one iteration, located at its `lb`.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(
          b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
          steps.drop_front(numProcessed),
          iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
          distributionMethod.drop_front(numProcessed));
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        bodyBuilderFn, ivStorage, distributionMethod.drop_front(numProcessed));
    return;
  }
}
449 
/// Specialization for generating a mix of parallel and sequential scf loops.
///
/// `linalgOp` must have no output tensor operands (no iteration arguments
/// are threaded through), and the body builder always receives an empty
/// range of iteration values. Only the first `loopRanges.size()` entries of
/// `iteratorTypes` are used.
///
/// When `distributionOptions` is set, the ranges of the parallel loops
/// (truncated to the number of distribution methods) are passed to its
/// `procInfo` callback. The first `procInfo.size()` parallel loops then get
/// their bounds rewritten with `updateBoundsForCyclicDistribution`, and the
/// distribution methods are forwarded to `generateParallelLoopNest`.
/// `distributionTypes` is unused by this specialization.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<Attribute> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    Optional<LinalgLoopDistributionOptions> distributionOptions,
    ArrayRef<StringRef> distributionTypes) {
  SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  SmallVector<DistributionMethod, 0> distributionMethod;
  if (distributionOptions) {
    auto &options = distributionOptions.getValue();
    distributionMethod.assign(distributionOptions->distributionMethod.begin(),
                              distributionOptions->distributionMethod.end());
    // Collect the ranges of the parallel loops, in loop order, keeping at
    // most one per distribution method.
    SmallVector<Range, 2> parallelLoopRanges;
    for (auto iteratorType : enumerate(iteratorTypes)) {
      if (isParallelIterator(iteratorType.value()))
        parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
    }
    if (distributionMethod.size() < parallelLoopRanges.size())
      parallelLoopRanges.resize(distributionMethod.size());
    SmallVector<ProcInfo, 2> procInfo =
        options.procInfo(b, loc, parallelLoopRanges);
    // Pair the `index`-th parallel loop with `procInfo[index]` and rewrite
    // its bounds in place; stop once the processor info is exhausted.
    unsigned index = 0;
    for (auto iteratorType : enumerate(iteratorTypes)) {
      if (index >= procInfo.size())
        break;
      if (isParallelIterator(iteratorType.value())) {
        unsigned i = iteratorType.index();
        updateBoundsForCyclicDistribution(b, loc, procInfo[index].procId,
                                          procInfo[index].nprocs, lbsStorage[i],
                                          ubsStorage[i], stepsStorage[i]);
        index++;
      }
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  // `ivs` accumulates one value per loop: the induction variable, or the
  // lower bound for distribution methods that do not materialize a loop.
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, {});
      },
      ivs, distributionMethod);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}
514 
515 Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
516                      ValueRange tileSizes, AffineMap map, ValueRange lbs,
517                      ValueRange subShapeSizes) {
518   auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
519   assert(shapedType && "only shaped types can be tiled");
520   ArrayRef<int64_t> shape = shapedType.getShape();
521   int64_t rank = shapedType.getRank();
522 
523   // Construct a new subview / extract_slice for the tile.
524   SmallVector<OpFoldResult, 4> offsets, sizes, strides;
525   offsets.reserve(rank);
526   sizes.reserve(rank);
527   strides.reserve(rank);
528   for (unsigned r = 0; r < rank; ++r) {
529     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: for dim#" << r);
530     if (!isTiled(map.getSubMap({r}), tileSizes)) {
531       offsets.push_back(builder.getIndexAttr(0));
532       Value dim = createOrFoldDimOp(builder, loc, valueToTile, r);
533       sizes.push_back(getAsOpFoldResult(dim));
534       strides.push_back(builder.getIndexAttr(1));
535       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
536       continue;
537     }
538     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
539 
540     // Tiling creates a new slice at the proper index, the slice step is 1
541     // (i.e. the op does not subsample, stepping occurs in the loop).
542     auto m = map.getSubMap({r});
543     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: submap: " << m << "\n");
544     auto offset = applyMapToValues(builder, loc, m, lbs).front();
545     offsets.push_back(offset);
546     auto closedIntSize =
547         applyMapToValues(builder, loc, m, subShapeSizes).front();
548     // Resulting size needs to be made half open interval again.
549     AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
550     Value size = makeComposedAffineApply(builder, loc, s0 + 1, closedIntSize);
551     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: raw size: " << size << "\n");
552 
553     // The size of the subview / extract_slice should be trimmed to avoid
554     // out-of-bounds accesses, unless we statically know the subshape size
555     // divides the shape size evenly.
556     int64_t shapeSize = shape[r];
557     auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
558     if (ShapedType::isDynamic(shapeSize) || !sizeCst ||
559         (shapeSize % sizeCst.getValue()) != 0) {
560       LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
561                               << ", size: " << size
562                               << ": make sure in bound with affine.min\n");
563       AffineExpr dim0, dim1, dim2;
564       bindDims(builder.getContext(), dim0, dim1, dim2);
565       // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
566       AffineMap minMap =
567           AffineMap::inferFromExprList(
568               ArrayRef<ArrayRef<AffineExpr>>{{dim0, dim1 - dim2}})
569               .front();
570       Value d = createOrFoldDimOp(builder, loc, valueToTile, r);
571       SmallVector<Value, 4> operands{size, d, offset};
572       fullyComposeAffineMapAndOperands(&minMap, &operands);
573       size = builder.create<AffineMinOp>(loc, builder.getIndexType(), minMap,
574                                          operands);
575     }
576 
577     sizes.push_back(size);
578     LLVM_DEBUG(llvm::dbgs()
579                << "makeTiledShape: new offset: " << offset << "\n");
580     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
581     strides.push_back(builder.getIndexAttr(1));
582   }
583 
584   Operation *sliceOp = shapedType.isa<MemRefType>()
585                            ? builder
586                                  .create<memref::SubViewOp>(
587                                      loc, valueToTile, offsets, sizes, strides)
588                                  .getOperation()
589                            : builder
590                                  .create<tensor::ExtractSliceOp>(
591                                      loc, valueToTile, offsets, sizes, strides)
592                                  .getOperation();
593   return sliceOp->getResult(0);
594 }
595 
596 SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
597                                       ValueRange ivs, ValueRange tileSizes) {
598   SmallVector<Value> offsets;
599   for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
600     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
601     bool isTiled = !isZero(tileSizes[idx]);
602     offsets.push_back(isTiled ? ivs[idxIvs++]
603                               : b.create<ConstantIndexOp>(loc, 0).getResult());
604     LLVM_DEBUG(llvm::dbgs()
605                << "computeTileOffsets: " << offsets.back() << "\n");
606   }
607   return offsets;
608 }
609 
610 SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
611                                     ValueRange tileSizes,
612                                     ArrayRef<Value> sizeBounds) {
613   SmallVector<Value> sizes;
614   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
615     bool isTiled = !isZero(tileSizes[idx]);
616     // Before composing, we need to make range a closed interval.
617     Value size = isTiled ? tileSizes[idx] : sizeBounds[idx];
618     AffineExpr d0 = getAffineDimExpr(0, b.getContext());
619     sizes.push_back(makeComposedAffineApply(b, loc, d0 - 1, size));
620     LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
621   }
622   return sizes;
623 }
624 
625 SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
626                                       LinalgOp linalgOp,
627                                       ArrayRef<Value> valuesToTile,
628                                       ValueRange ivs, ValueRange tileSizes,
629                                       ArrayRef<Value> sizeBounds) {
630   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
631                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
632                            [](Value v) { return !isZero(v); })) &&
633          "expected as many ivs as non-zero sizes");
634 
635   // Construct (potentially temporary) mins and maxes on which to apply maps
636   // that define tile subshapes.
637   SmallVector<Value> lbs = computeTileOffsets(b, loc, ivs, tileSizes);
638   SmallVector<Value> subShapeSizes =
639       computeTileSizes(b, loc, ivs, tileSizes, sizeBounds);
640 
641   assert(static_cast<int64_t>(valuesToTile.size()) ==
642              linalgOp.getNumInputsAndOutputs() &&
643          "expected one value to tile for every operand");
644   SmallVector<Value, 4> tiledShapes;
645   tiledShapes.reserve(valuesToTile.size());
646   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
647     Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
648     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
649     AffineMap map = linalgOp.getTiedIndexingMap(opOperand);
650     // If the shape is not tiled, we can use it as is.
651     if (!isTiled(map, tileSizes)) {
652       tiledShapes.push_back(shapedOp);
653       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
654                               << opOperand->get().getType() << "\n");
655       continue;
656     }
657     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
658 
659     tiledShapes.push_back(
660         makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs, subShapeSizes));
661   }
662 
663   return tiledShapes;
664 }
665 
666 } // namespace linalg
667 } // namespace mlir
668