//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::linalg;
using namespace mlir::scf;

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
    return cst.value() == 0;
  return false;
}

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ValueRange tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZero(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ValueRange tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ValueRange tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether the `map` varies with respect to a non-zero `tileSize`.
static bool isTiled(AffineMap map, ValueRange tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}

Optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.region();
  if (!llvm::hasSingleElement(region))
    return llvm::None;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return llvm::None;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return llvm::None;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return llvm::None;
}

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(range.offset);
    ubs.emplace_back(range.size);
    steps.emplace_back(range.stride);
  }
}

namespace mlir {
namespace linalg {

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
              linalg::IndexOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand *opOperand : op.getOutputOperands()) {
    if (!op.getTiedIndexingMap(opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

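/// Check whether `permutation` is a permutation of the range
/// [0, permutation.size()). Illustrative examples (not exhaustive):
/// {1, 2, 0} is a permutation, while {0, 0, 1} and {0, 3} are not.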
bool isPermutation(ArrayRef<int64_t> permutation) {
  // Count the number of appearances for all indices.
  SmallVector<int64_t> indexCounts(permutation.size(), 0);
  for (auto index : permutation) {
    // Exit if the index is out-of-range.
    if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
      return false;
    indexCounts[index]++;
  }
  // Return true if all indices appear once.
  return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
}

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

/// Given an operation, retrieves the value of each dynamic dimension by
/// constructing the necessary DimOp operators.
SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b) {
  SmallVector<Value, 4> dynOperands;
  auto shapedType = val.getType().cast<ShapedType>();
  for (const auto &dim : llvm::enumerate(shapedType.getShape())) {
    if (dim.value() == ShapedType::kDynamicSize)
      dynOperands.push_back(createOrFoldDimOp(b, loc, val, dim.index()));
  }
  return dynOperands;
}

void getUpperBoundForIndex(Value value, AffineMap &boundMap,
                           SmallVectorImpl<Value> &boundOperands,
                           bool constantRequired) {
  // Initialize `boundMap` and `boundOperands` to the identity returning
  // `value`. This combination is the default result of the method if no
  // simplification is possible.
  assert(value.getType().isIndex() && "expect value to have index type");
  boundMap = AffineMap::getMultiDimIdentityMap(1, value.getContext());
  boundOperands.assign({value});
  canonicalizeMapAndOperands(&boundMap, &boundOperands);

  // Continue only if there is an affine index computation to simplify.
  Operation *definingOp = value.getDefiningOp();
  if (!definingOp || !isa<AffineApplyOp, AffineMinOp>(definingOp))
    return;

  // Get the backward slice containing the affine index computation.
  SetVector<Operation *> backwardSlice;
  getBackwardSlice(definingOp, &backwardSlice, [](Operation *op) {
    return isa<AffineApplyOp, AffineMinOp>(op);
  });
  backwardSlice.insert(definingOp);

  // Setup a system of affine constraints that describe the index computation.
  FlatAffineValueConstraints constraints;

  // Helper to find or create an identifier for the given value.
  auto findOrCreateId = [&](Value value) {
    if (!constraints.containsVar(value)) {
      constraints.appendDimVar(value);
      return true;
    }
    unsigned pos;
    constraints.findVar(value, &pos);
    return pos < constraints.getNumDimVars();
  };
  // Helper to get the position for the given value.
  auto getPosition = [&](Value value) {
    unsigned pos;
    bool exists = constraints.findVar(value, &pos);
    (void)exists;
    assert(exists && "expect to find the identifier");
    return pos;
  };

  // Add the affine operations in `backwardSlice` to the constraints.
  for (Operation *op : llvm::reverse(backwardSlice)) {
    // Add an identifier for all op results and operands.
    if (!(llvm::all_of(op->getResults(), findOrCreateId) &&
          llvm::all_of(op->getOperands(), findOrCreateId)))
      return;

    // Add AffineApplyOps to the constraints.
    if (auto applyOp = dyn_cast<AffineApplyOp>(op)) {
      AffineMap map = constraints.computeAlignedMap(applyOp.getAffineMap(),
                                                    applyOp.getOperands());
      if (failed(constraints.addBound(IntegerPolyhedron::EQ,
                                      getPosition(applyOp.getResult()), map)))
        return;
      continue;
    }
    // Add AffineMinOps to the constraints.
    auto minOp = cast<AffineMinOp>(op);
    AffineMap map = constraints.computeAlignedMap(minOp.getAffineMap(),
                                                  minOp.getOperands());
    if (failed(constraints.addBound(IntegerPolyhedron::UB,
                                    getPosition(minOp.getResult()), map,
                                    /*isClosedBound=*/true)))
      return;
  }

  // Obtain an upper bound for the affine index computation by projecting out
  // all temporary results and expressing the upper bound for `value` in terms
  // of the terminals of the index computation.
  unsigned pos = getPosition(value);
  if (constantRequired) {
    auto ubConst = constraints.getConstantBound(
        FlatAffineValueConstraints::BoundType::UB, pos);
    if (!ubConst)
      return;

    boundMap = AffineMap::getConstantMap(*ubConst, value.getContext());
    return;
  }

  SmallVector<AffineMap> lowerBounds(1), upperBounds(1);
  constraints.getSliceBounds(pos, 1, value.getContext(), &lowerBounds,
                             &upperBounds,
                             /*getClosedUB=*/true);
  // Verify `upperBounds[0]` is valid and has at least one result.
  if (!upperBounds[0] || upperBounds[0].getNumResults() == 0)
    return;

  // Set `boundMap` and `boundOperands` to the computed upper bound.
  boundMap = upperBounds[0];
  constraints.getAllValues(&boundOperands);
  erase_value(boundOperands, value);
  canonicalizeMapAndOperands(&boundMap, &boundOperands);
}

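/// Illustrative example (assumed input, not exhaustive): if `value` is defined
/// by `affine.min affine_map<(d0) -> (256, d0)>(%iv)` and `%iv` has no known
/// bounds, the constant upper bound returned is 256.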
FailureOr<int64_t> getConstantUpperBoundForIndex(Value value) {
  // Compute an upper bound for `value`.
  AffineMap boundMap;
  SmallVector<Value> boundOperands;
  getUpperBoundForIndex(value, boundMap, boundOperands,
                        /*constantRequired=*/true);

  // Search the results of `boundMap` for constant upper bounds.
  SmallVector<int64_t> constantBounds;
  for (AffineExpr result : boundMap.getResults())
    if (auto constExpr = result.dyn_cast<AffineConstantExpr>())
      constantBounds.push_back(constExpr.getValue());

  // Return the minimal upper bound or failure if none is found.
  if (constantBounds.empty())
    return failure();
  return *std::min_element(constantBounds.begin(), constantBounds.end());
}

tensor::ExtractSliceOp makeComposedExtractSliceOp(
    OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
  assert(source && "expect source to be nonzero");

  // Do not fold if the producer is not an ExtractSliceOp.
  auto producerOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!producerOp)
    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                            strides);

  // Do not fold if the producer is rank reducing or if there are any non-unit
  // strides. Supporting non-unit strides complicates the offset computation
  // since the consumer offsets need to be multiplied by the producer strides.
  // TODO: support non-unit strides once there are use cases.
  SmallVector<OpFoldResult> allStrides = producerOp.getMixedStrides();
  allStrides.append(strides.begin(), strides.end());
  bool hasNonUnitStride = any_of(allStrides, [](OpFoldResult ofr) {
    return getConstantIntValue(ofr) != static_cast<int64_t>(1);
  });
  if (hasNonUnitStride ||
      producerOp.getSourceType().getRank() !=
          producerOp.getResult().getType().cast<ShapedType>().getRank())
    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                            strides);

  // Fold the producer by adding the offsets and extracting the slice directly
  // from the producer source tensor.
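  // For example (illustrative): extracting a slice at offset %o2 from a slice
  // that was itself extracted from %t at offset %o1 yields a single
  // extract_slice of %t at offset %o1 + %o2, keeping the consumer sizes and
  // unit strides.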
  SmallVector<OpFoldResult> foldedOffsets(offsets.begin(), offsets.end());
  AffineExpr dim1, dim2;
  bindDims(b.getContext(), dim1, dim2);
  for (const auto &en : enumerate(producerOp.getMixedOffsets())) {
    SmallVector<Value> offsetValues = {
        getValueOrCreateConstantIndexOp(b, loc, foldedOffsets[en.index()]),
        getValueOrCreateConstantIndexOp(b, loc, en.value())};
    foldedOffsets[en.index()] =
        makeComposedAffineApply(b, loc, dim1 + dim2, offsetValues).getResult();
  }
  return b.create<tensor::ExtractSliceOp>(loc, producerOp.getSource(),
                                          foldedOffsets, sizes, strides);
}

Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = current.cast<OpResult>();
    current = linalgOp.getOutputOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the matched
  // LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by
  // `padOp`, is rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
  // by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}

GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                          Value outputTensor,
                          ArrayRef<int64_t> transposeVector) {
  auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
  Type elementType = resultTensorType.getElementType();

  assert(isPermutation(transposeVector) &&
         "expect transpose vector to be a permutation");
  assert(transposeVector.size() ==
             static_cast<size_t>(resultTensorType.getRank()) &&
         "expect transpose vector size to match result tensor rank");

  // Compute the transpose and the identity indexing maps.
  SmallVector<AffineMap> indexingMaps = {
      inversePermutation(AffineMap::getPermutationMap(
          SmallVector<unsigned>(transposeVector.begin(), transposeVector.end()),
          b.getContext())),
      AffineMap::getMultiDimIdentityMap(transposeVector.size(),
                                        b.getContext())};
  SmallVector<llvm::StringRef> iteratorTypes(transposeVector.size(),
                                             getParallelIteratorTypeName());

  // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
  auto transposeOp = b.create<GenericOp>(
      loc, resultTensorType, inputTensor, outputTensor,
      b.getAffineMapArrayAttr(indexingMaps), b.getStrArrayAttr(iteratorTypes),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  Region &body = transposeOp.getRegion();
  body.push_back(new Block());
  body.front().addArguments({elementType, elementType}, {loc, loc});

  // Create the body of the transpose operation.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToEnd(&body.front());
  b.create<YieldOp>(loc, transposeOp.getRegion().front().getArgument(0));
  return transposeOp;
}

GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = to.getType().cast<MemRefType>();
#ifndef NDEBUG
  auto memrefTypeFrom = from.getType().cast<MemRefType>();
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),
                                       getParallelIteratorTypeName());
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::makeArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<Attribute> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    Optional<LinalgLoopDistributionOptions> distributionOptions,
    ArrayRef<StringRef> distributionTypes) {
  SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
  // Create procInfo so it dominates loops, if appropriate.
  SmallVector<ProcInfo, 4> procInfo;
  SmallVector<DistributionMethod, 0> distributionMethod;
  if (distributionOptions) {
    // Collect loop ranges for parallel dimensions.
    SmallVector<Range, 2> parallelLoopRanges;
    for (const auto &iteratorType : enumerate(iteratorTypes))
      if (isParallelIterator(iteratorType.value()))
        parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);

    // Get their distribution schemes.
    distributionMethod = distributionOptions->distributionMethod;
    if (distributionMethod.size() < parallelLoopRanges.size())
      parallelLoopRanges.resize(distributionMethod.size());
    procInfo = distributionOptions->procInfo(b, loc, parallelLoopRanges);
  }

  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == linalgOp.getOutputTensorOperands().size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse =
            linalgOp.getInputAndOutputOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getInputOperands();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (!distributionOptions || loopNest.loops.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  SmallVector<scf::ForOp, 4> loops;
  for (const auto &iteratorType : enumerate(iteratorTypes))
    if (isParallelIterator(iteratorType.value()))
      loops.push_back(loopNest.loops[iteratorType.index()]);

  // Distribute - only supports cyclic distribution for now.
  for (auto it : llvm::zip(loops, procInfo, distributionMethod))
    if (std::get<2>(it) == DistributionMethod::Cyclic)
      mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId,
                            std::get<1>(it).nprocs);
}

/// Specialization to build an affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<Attribute> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    Optional<LinalgLoopDistributionOptions>, ArrayRef<StringRef>) {
  SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto op = v.getDefiningOp<arith::ConstantIndexOp>();
    assert(op && "Affine loops require constant steps");
    constantSteps.push_back(op.value());
  }

  mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                            [&](OpBuilder &b, Location loc, ValueRange ivs) {
                              SmallVector<Value> operandValuesToUse =
                                  linalgOp.getInputAndOutputOperands();
                              bodyBuilderFn(b, loc, ivs, operandValuesToUse);
                            });
}

/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
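/// For a cyclic distribution this computes `lb = lb + procId * step` and
/// `step = nprocs * step`; `ub` is left unchanged.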
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb = makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
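/// For example (illustrative): with iterator types
/// [parallel, parallel, reduction, parallel], this emits an scf.parallel over
/// the first two dimensions, an scf.for over the third, and a nested
/// scf.parallel over the fourth.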
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<Attribute> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage,
    ArrayRef<DistributionMethod> distributionMethod = {}) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // Find the outermost parallel loops and drop their types from the list.
  unsigned nLoops = iteratorTypes.size();
  unsigned nOuterPar =
      nLoops - iteratorTypes.drop_while(isParallelIterator).size();

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse. Note that we wouldn't have dropped anything from `iteratorTypes`
  // in this case.
  if (nOuterPar == 0) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(b, loc, lbs.drop_front(), ubs.drop_front(),
                                   steps.drop_front(),
                                   iteratorTypes.drop_front(), bodyBuilderFn,
                                   ivStorage, distributionMethod);
        });
    return;
  }
  if (distributionMethod.empty()) {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(nOuterPar), ubs.take_front(nOuterPar),
        steps.take_front(nOuterPar),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(nOuterPar),
              ubs.drop_front(nOuterPar), steps.drop_front(nOuterPar),
              iteratorTypes.drop_front(nOuterPar), bodyBuilderFn, ivStorage,
              (distributionMethod.size() < nOuterPar)
                  ? ArrayRef<DistributionMethod>()
                  : distributionMethod.drop_front(nOuterPar));
        });
    return;
  }

  // Process all consecutive similarly distributed loops simultaneously.
  DistributionMethod methodToUse = distributionMethod[0];
  unsigned numProcessed = 1;
  for (unsigned i = 1; i < nOuterPar && i < distributionMethod.size(); ++i) {
    if (distributionMethod[i] != methodToUse)
      break;
    numProcessed++;
  }

  switch (methodToUse) {
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
              (distributionMethod.size() < numProcessed)
                  ? ArrayRef<DistributionMethod>()
                  : distributionMethod.drop_front(numProcessed));
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(
          b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
          steps.drop_front(numProcessed),
          iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
          distributionMethod.drop_front(numProcessed));
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        bodyBuilderFn, ivStorage, distributionMethod.drop_front(numProcessed));
    return;
  }
}

/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<Attribute> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    Optional<LinalgLoopDistributionOptions> distributionOptions,
    ArrayRef<StringRef> distributionTypes) {
  SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  SmallVector<DistributionMethod, 0> distributionMethod;
  if (distributionOptions) {
    auto &options = *distributionOptions;
    distributionMethod.assign(distributionOptions->distributionMethod.begin(),
                              distributionOptions->distributionMethod.end());
    SmallVector<Range, 2> parallelLoopRanges;
    for (const auto &iteratorType : enumerate(iteratorTypes)) {
      if (isParallelIterator(iteratorType.value()))
        parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
    }
    if (distributionMethod.size() < parallelLoopRanges.size())
      parallelLoopRanges.resize(distributionMethod.size());
    SmallVector<ProcInfo, 2> procInfo =
        options.procInfo(b, loc, parallelLoopRanges);
    unsigned index = 0;
    for (const auto &iteratorType : enumerate(iteratorTypes)) {
      if (index >= procInfo.size())
        break;
      if (isParallelIterator(iteratorType.value())) {
        unsigned i = iteratorType.index();
        updateBoundsForCyclicDistribution(b, loc, procInfo[index].procId,
                                          procInfo[index].nprocs, lbsStorage[i],
                                          ubsStorage[i], stepsStorage[i]);
        index++;
      }
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        SmallVector<Value> operandValuesToUse =
            linalgOp.getInputAndOutputOperands();
        bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      },
      ivs, distributionMethod);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

static Value fullyComposeAndAffineApply(OpBuilder &b, Location loc,
                                        AffineExpr expr, ValueRange operands) {
  AffineMap map = AffineMap::inferFromExprList({expr}).front();
  SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
  mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
  canonicalizeMapAndOperands(&map, &normalizedOperands);
  return b.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
}

Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                     ValueRange tileSizes, AffineMap map, ValueRange lbs,
                     ValueRange ubs, ValueRange subShapeSizes,
                     bool omitPartialTileCheck) {
  auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Construct a new subview / extract_slice for the tile.
  SmallVector<OpFoldResult, 4> offsets, sizes, strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      offsets.push_back(builder.getIndexAttr(0));
      Value dim = createOrFoldDimOp(builder, loc, valueToTile, r);
      sizes.push_back(getAsOpFoldResult(dim));
      strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: submap: " << m << "\n");
    auto offset = applyMapToValues(builder, loc, m, lbs).front();
    offsets.push_back(getAsOpFoldResult(offset));
    auto closedIntSize =
        applyMapToValues(builder, loc, m, subShapeSizes).front();
    // Resulting size needs to be made half open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    Value size =
        fullyComposeAndAffineApply(builder, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "makeTiledShape: new offset: " << offset << "\n");
    strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sizes.push_back(getAsOpFoldResult(size));
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    auto sizeCst = size.getDefiningOp<arith::ConstantIndexOp>();
    auto hasTileSizeOne = sizeCst && sizeCst.value() == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % sizeCst.value()) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      bindDims(builder.getContext(), dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first calculate
      // the max index and then add one. This is important because for
      // convolution ops, the input window dimension's affine map is of the
      // form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
      // dimension and `s0` is the stride. Directly using the dimension size of
      // the output/filter window dimensions would cause incorrect calculations.
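      // For example (illustrative): with submap `(d0 * 2 + d1)` and upper
      // bounds `%ub0` and `%ub1`, the computation below yields
      // `(%ub0 - 1) * 2 + (%ub1 - 1) + 1` as the dimension size.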
      AffineMap minusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 - 1}})
              .front();
      AffineMap plusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 + 1}})
              .front();
      auto maxIndices = llvm::to_vector<8>(llvm::map_range(ubs, [&](Value ub) {
        return makeComposedAffineApply(builder, loc, minusOneMap, {ub})
            .getResult();
      }));
      Value maxIndex = applyMapToValues(builder, loc, m, maxIndices).front();
      Value d = makeComposedAffineApply(builder, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}})
                             .front();
      SmallVector<Value, 4> operands{size, d, offset};
      fullyComposeAffineMapAndOperands(&minMap, &operands);
      canonicalizeMapAndOperands(&minMap, &operands);
      size = builder.create<AffineMinOp>(loc, builder.getIndexType(), minMap,
                                         operands);
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sizes.push_back(getAsOpFoldResult(size));
  }

  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, offsets, sizes, strides);
                      })
                      .Case([&](RankedTensorType) {
                        return makeComposedExtractSliceOp(
                            builder, loc, valueToTile, offsets, sizes, strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp->getResult(0);
}

Value createSlice(OpBuilder &builder, Location loc, Value value,
                  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
                  ArrayRef<OpFoldResult> strides) {
  if (value.getType().isa<MemRefType>()) {
    return builder.create<memref::SubViewOp>(loc, value, offsets, sizes,
                                             strides);
  }

  // This intentionally does not attempt to compose the extract_slice
  // operations.
  assert(value.getType().isa<RankedTensorType>() &&
         "expected a ranked tensor type");
  return builder.create<tensor::ExtractSliceOp>(loc, value, offsets, sizes,
                                                strides);
}

SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                      ValueRange ivs, ValueRange tileSizes) {
  SmallVector<Value> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZero(tileSizes[idx]);
    offsets.push_back(
        isTiled ? ivs[idxIvs++]
                : b.create<arith::ConstantIndexOp>(loc, 0).getResult());
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}

SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
                                    ValueRange tileSizes,
                                    ArrayRef<Value> sizeBounds) {
  SmallVector<Value> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZero(tileSizes[idx]);
    // Before composing, we need to make range a closed interval.
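    // For example (illustrative): a tile size of 4 becomes a closed-interval
    // size of 3 here; makeTiledShape later adds 1 back to recover the
    // half-open size.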
    Value size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    sizes.push_back(fullyComposeAndAffineApply(b, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}

SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  // TODO: use an interface/adaptor to avoid leaking position in
  // `tiledOperands`.
  return llvm::to_vector(
      llvm::map_range(op.getOutputTensorOperands(), [&](OpOperand *opOperand) {
        return operands[opOperand->getOperandNumber()].getType();
      }));
}

SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand *opOperand : op.getOutputTensorOperands()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    Value outputTensor = operands[opOperand->getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}

Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
                              OpFoldResult opFoldResult) {
  if (auto value = opFoldResult.dyn_cast<Value>())
    return value;
  auto attr = opFoldResult.get<Attribute>().cast<IntegerAttr>();
  return builder.create<arith::ConstantIndexOp>(attr.getValue().getSExtValue());
}

SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
                                      LinalgOp linalgOp,
                                      ArrayRef<Value> valuesToTile,
                                      ValueRange ivs, ValueRange tileSizes,
                                      ArrayRef<Value> sizeBounds,
                                      bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](Value v) { return !isZero(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<Value> lbs = computeTileOffsets(b, loc, ivs, tileSizes);
  SmallVector<Value> subShapeSizes =
      computeTileSizes(b, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) ==
             linalgOp.getNumInputsAndOutputs() &&
         "expected one value to tile for every operand");
  SmallVector<Value, 4> tiledShapes;
  tiledShapes.reserve(valuesToTile.size());
  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getTiedIndexingMap(opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
    // an extract/insert slice pair for all output tensors simplifies follow up
    // transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.
    if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) {
      tiledShapes.push_back(shapedOp);
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
                              << opOperand->get().getType() << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    tiledShapes.push_back(makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs,
                                         sizeBounds, subShapeSizes,
                                         omitPartialTileCheck));
  }

  return tiledShapes;
}

void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef<Value> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}

void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<Value> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.dim() >= offsets.size() || offsets[indexOp.dim()] == nullptr)
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    AffineApplyOp applyOp = makeComposedAffineApply(
        b, indexOp.getLoc(), index + offset,
        ValueRange{indexOp.getResult(), offsets[indexOp.dim()]});
    b.replaceOpWithIf(indexOp, applyOp.getResult(), [&](OpOperand &use) {
      return use.getOwner() != applyOp;
    });
  }
}

/// Get the reassociation maps to fold the result of an extract_slice (or the
/// source of an insert_slice) operation with the given offsets and sizes to
/// its rank-reduced version. This is only done for the cases where the size is
/// 1 and the offset is 0. Strictly speaking, the offset 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
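/// Illustrative example (not exhaustive): mixed sizes `[1, 4, 1, 8]` yield the
/// reassociation `[[0, 1], [2, 3]]`.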
Optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociation is not empty, fold the remaining unit dimensions
  // into the last dimension. If the reassociation so far is empty, leave it
  // empty; this will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

} // namespace linalg
} // namespace mlir