1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
11 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
12 #include "mlir/Dialect/Linalg/IR/Linalg.h"
13 #include "mlir/Dialect/Linalg/Passes.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include <iterator>
19 #include <memory>
20 #include <utility>
21
22 using namespace mlir;
23 using namespace mlir::linalg;
24
sourceMaterializationCallback(OpBuilder & builder,Type type,ValueRange inputs,Location loc)25 static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
26 ValueRange inputs, Location loc) {
27 assert(inputs.size() == 1);
28 auto inputType = inputs[0].getType();
29 if (inputType.isa<TensorType>())
30 return nullptr;
31
32 // A detensored value is converted back by creating a new tensor from its
33 // element(s).
34 return builder.create<tensor::FromElementsOp>(
35 loc, RankedTensorType::get({}, inputType), inputs[0]);
36 }
37
38 namespace {
39 /// Defines the criteria a TensorType must follow in order to be considered
40 /// "detensorable".
41 ///
42 /// NOTE: For now, only 0-D tensors are supported.
43 ///
44 /// Returns true if tensorType can be detensored.
canBeDetensored(TensorType tensorType)45 bool canBeDetensored(TensorType tensorType) {
46 return tensorType.hasRank() && tensorType.getRank() == 0;
47 }
48
shouldBeDetensored(Operation * op,TypeConverter typeConverter)49 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
50 GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
51 return genericOp &&
52 llvm::all_of(
53 genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
54 return !typeConverter.isLegal(opOperand->get().getType());
55 });
56 }
57
58 /// A conversion patttern for detensoring `linalg.generic` ops.
59 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
60 public:
61 using OpConversionPattern::OpConversionPattern;
62 LogicalResult
matchAndRewrite(GenericOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const63 matchAndRewrite(GenericOp op, OpAdaptor adaptor,
64 ConversionPatternRewriter &rewriter) const override {
65 Block *originalBlock = op->getBlock();
66
67 // Gather some information about the op before inling its region.
68 Block *opEntryBlock = &*op.region().begin();
69 YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());
70
71 // Split the op's region before the op. This way, we have a clear insertion
72 // point in which the op can be inlined.
73 Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
74 rewriter.inlineRegionBefore(op.region(), newBlock);
75 // Now that op's region is inlined, the operands of its YieldOp are mapped
76 // to the materialized target values. Therefore, we can replace the op's
77 // uses with those of its YielOp's operands.
78 rewriter.replaceOp(op, yieldOp->getOperands());
79
80 // No need for these intermediate blocks, merge them into 1.
81 rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
82 rewriter.mergeBlocks(newBlock, originalBlock, {});
83
84 rewriter.eraseOp(&*Block::iterator(yieldOp));
85
86 return success();
87 }
88 };
89
90 /// A conversion pattern for detensoring internal (non-entry) blocks within a
91 /// function.
92 struct FunctionNonEntryBlockConversion
93 : public OpInterfaceConversionPattern<FunctionOpInterface> {
FunctionNonEntryBlockConversion__anon4fe90d0e0111::FunctionNonEntryBlockConversion94 FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
95 DenseSet<BlockArgument> blockArgsToDetensor)
96 : OpInterfaceConversionPattern(converter, ctx),
97 blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
98
99 LogicalResult
matchAndRewrite__anon4fe90d0e0111::FunctionNonEntryBlockConversion100 matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
101 ConversionPatternRewriter &rewriter) const override {
102 rewriter.startRootUpdate(op);
103 Region ®ion = op.getBody();
104 SmallVector<TypeConverter::SignatureConversion, 2> conversions;
105
106 for (Block &block : llvm::drop_begin(region, 1)) {
107 conversions.emplace_back(block.getNumArguments());
108 TypeConverter::SignatureConversion &back = conversions.back();
109
110 for (BlockArgument blockArgument : block.getArguments()) {
111 int idx = blockArgument.getArgNumber();
112
113 if (blockArgsToDetensor.count(blockArgument))
114 back.addInputs(idx, {getTypeConverter()->convertType(
115 block.getArgumentTypes()[idx])});
116 else
117 back.addInputs(idx, {block.getArgumentTypes()[idx]});
118 }
119 }
120
121 if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter,
122 conversions))) {
123 rewriter.cancelRootUpdate(op);
124 return failure();
125 }
126
127 rewriter.finalizeRootUpdate(op);
128 return success();
129 }
130
131 private:
132 const DenseSet<BlockArgument> blockArgsToDetensor;
133 };
134
135 class DetensorizeTypeConverter : public TypeConverter {
136 public:
DetensorizeTypeConverter()137 DetensorizeTypeConverter() {
138 addConversion([](Type type) { return type; });
139
140 // A TensorType that can be detensored, is converted to the underlying
141 // element type.
142 addConversion([](TensorType tensorType) -> Type {
143 if (canBeDetensored(tensorType))
144 return tensorType.getElementType();
145
146 return tensorType;
147 });
148
149 // A tensor value is detensoried by extracting its element(s).
150 addTargetMaterialization([](OpBuilder &builder, Type type,
151 ValueRange inputs, Location loc) -> Value {
152 return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
153 });
154
155 addSourceMaterialization(sourceMaterializationCallback);
156 addArgumentMaterialization(sourceMaterializationCallback);
157 }
158 };
159
160 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
161 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
162 LinalgDetensorize() = default;
163
164 class CostModel {
165 public:
166 virtual ~CostModel() = default;
167
168 /// A cost model algorithm computes the following outputs:
169 ///
170 /// - opsToDetensor: the list of linalg ops that should be
171 /// detensored.
172 ///
173 /// - blockArgsToDetensor: since the operands and results of detensored
174 /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
175 /// from a BB argument and a linalg op's output can be passed to successor
176 /// BBs), we need to maintain the sub-set of arguments that should be
177 /// detensored (i.e. converted by typeConverter) for each affected BB.
178 ///
179 /// Example:
180 ///
181 /// For the following snippet:
182 /// ...
183 /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
184 /// %7 = linalg.init_tensor [] : tensor<i32>
185 /// %8 = linalg.generic #attrs
186 /// ins(%6, %6 : tensor<i32>, tensor<i32>)
187 /// outs(%7 : tensor<i32>) {
188 /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
189 /// %9 = arith.addi %arg0, %arg1 : i32
190 /// linalg.yield %9 : i32
191 /// } -> tensor<i32>
192 /// %10 = "some.op"(%9)
193 /// br ^bb2(%8 : tensor<i32>)
194 /// ...
195 ///
196 /// if the cost model decides that the linalg.generic op should be
197 /// detensored, then:
198 /// - opsToDetensor should be = {linalg.generic{add}}.
199 /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
200 virtual void compute(FunctionOpInterface func,
201 DetensorizeTypeConverter typeConverter,
202 DenseSet<Operation *> &opsToDetensor,
203 DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
204
205 /// From the blockArgsToDetensor set computed by a CostModel
206 /// implementation, this method computes the corresponding branch op
207 /// detensoring. The result is a map from a branch op to a subset of indices
208 /// of its operands. The indices specify which of the branch op's operands
209 /// should be detensored.
210 ///
211 /// For the previous example, this method would compute: {bb2 -> {0}}.
computeBranchOpDetensoring(const DenseSet<BlockArgument> & blockArgsToDetensor)212 static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
213 const DenseSet<BlockArgument> &blockArgsToDetensor) {
214 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
215
216 for (auto blockArgumentElem : blockArgsToDetensor) {
217 Block *block = blockArgumentElem.getOwner();
218
219 for (PredecessorIterator pred = block->pred_begin();
220 pred != block->pred_end(); ++pred) {
221 BranchOpInterface terminator =
222 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
223 auto blockOperands =
224 terminator.getSuccessorOperands(pred.getSuccessorIndex());
225
226 if (blockOperands.empty() ||
227 blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
228 continue;
229
230 detensorableBranchOps[terminator].insert(
231 blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
232 }
233 }
234
235 return detensorableBranchOps;
236 }
237 };
238
239 /// Detensorize linalg ops involved in control-flow within a function.
240 ///
241 /// This model starts from BranchOps and CondBranchOps within a function. For
242 /// each such branch, the model then walks the use-def chain for the branch's
243 /// condition backwards in order to understand where the condition's value
244 /// comes from. If the condition value is (indirectly) computed by a linalg op
245 /// that can be detensored, the model then continues walking the use-def chain
246 /// in order to understand where the linalg op's operands come from. This
247 /// leads to discovering a "detensoring component". A detensoring component is
248 /// the set of operations + block arguments that are involved in control-flow
249 /// AND can be detensored.
250 class ControlFlowDetectionModel : public CostModel {
251 public:
compute(FunctionOpInterface func,DetensorizeTypeConverter typeConverter,DenseSet<Operation * > & opsToDetensor,DenseSet<BlockArgument> & blockArgsToDetensor)252 void compute(FunctionOpInterface func,
253 DetensorizeTypeConverter typeConverter,
254 DenseSet<Operation *> &opsToDetensor,
255 DenseSet<BlockArgument> &blockArgsToDetensor) override {
256 SmallVector<Value> workList;
257
258 func->walk([&](cf::CondBranchOp condBr) {
259 llvm::append_range(workList, condBr.getOperands());
260 });
261
262 func->walk([&](cf::BranchOp br) {
263 llvm::append_range(workList, br.getOperands());
264 });
265
266 DenseSet<Value> visitedValues;
267 DenseSet<Operation *> visitedOps;
268
269 // For a (to-be-detesored) value, check if it "escapes" the block by being
270 // passed to terminator. If it does, then workList is updated with the
271 // corresponding argument to the successor block.
272 auto updateWorkListWithSuccessorArguments =
273 [&](Value value, BranchOpInterface terminator) {
274 if (!terminator)
275 return;
276
277 for (auto operandIdx :
278 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
279 Value operand = terminator->getOperand(operandIdx);
280
281 if (operand == value) {
282 auto succBlockArg =
283 terminator.getSuccessorBlockArgument(operandIdx);
284
285 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
286 workList.push_back(*succBlockArg);
287 }
288 }
289 };
290
291 while (!workList.empty()) {
292 Value currentItem = workList.pop_back_val();
293
294 if (!visitedValues.insert(currentItem).second)
295 continue;
296
297 // 1 - Look forward:
298 // 1.1 - If currentItem escapes to one or more successors, add
299 // the corresponding successor arguments to workList.
300 updateWorkListWithSuccessorArguments(
301 currentItem, dyn_cast<BranchOpInterface>(
302 currentItem.getParentBlock()->getTerminator()));
303
304 // 1.2 - For each user of currentItem, add the defined values to
305 // workList. This way, the user ops can be inspected later if they are
306 // detensorable and if so, their operands will be added to workList to
307 // potentially discover other parts of the detensorable component.
308 for (auto *user : currentItem.getUsers())
309 llvm::append_range(workList, user->getResults());
310
311 // 2 - Look backward:
312 // 2.1 - The current item is defined by a block argument. If the owner
313 // block is a non-entry one, then:
314 // * Add the argument to blockArgsToDetensor.
315 // * Walk the use-def chain backwards to add each predecessor's
316 // terminator-operands corresponding to currentItem to workList.
317 if (currentItem.dyn_cast<BlockArgument>()) {
318 BlockArgument currentItemBlockArgument =
319 currentItem.cast<BlockArgument>();
320 Block *ownerBlock = currentItemBlockArgument.getOwner();
321
322 // Function arguments are not detensored/converted.
323 if (&*ownerBlock->getParent()->begin() == ownerBlock)
324 continue;
325
326 // This inner-block argument is involved in control-flow, it should be
327 // detensored.
328 blockArgsToDetensor.insert(currentItemBlockArgument);
329
330 for (PredecessorIterator pred = ownerBlock->pred_begin();
331 pred != ownerBlock->pred_end(); ++pred) {
332 BranchOpInterface predTerminator =
333 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
334
335 // TODO: For now, we give up if any of the control-flow components
336 // in a function is not detensorable. Fix that.
337 if (!predTerminator) {
338 opsToDetensor.clear();
339 blockArgsToDetensor.clear();
340 return;
341 }
342
343 auto ownerBlockOperands =
344 predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
345
346 if (ownerBlockOperands.empty() ||
347 ownerBlockOperands.isOperandProduced(
348 currentItemBlockArgument.getArgNumber()))
349 continue;
350
351 // For each predecessor, add the value it passes to that argument to
352 // workList to find out how it's computed.
353 workList.push_back(
354 ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
355 }
356
357 continue;
358 }
359
360 Operation *currentItemDefiningOp = currentItem.getDefiningOp();
361
362 if (!visitedOps.insert(currentItemDefiningOp).second)
363 continue;
364
365 // 2.2 - The current item is computed by a GenericOp. If the op should
366 // be detensored, then:
367 // * Add it to opsToDetensor.
368 // * Add its operands to workList to discover other parts of the
369 // potentially detensorable component.
370 if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
371 // The op was encountered already, no need to inspect it again.
372 if (opsToDetensor.count(genericOp))
373 continue;
374
375 // The op should not be detensored, give up on it but continue with
376 // discovering the rest of the control-flow component.
377 if (!shouldBeDetensored(genericOp, typeConverter)) {
378 continue;
379 }
380
381 opsToDetensor.insert(genericOp);
382 llvm::append_range(workList, genericOp.inputs());
383 continue;
384 }
385
386 // 2.3 - The current item is the result of a FromElementsOp, it will be
387 // trivially detensored later as part of canonicalization patterns
388 // applied at the end of detensoring.
389 //
390 // Note: No need to check whether the result type of this op is
391 // detensorable since if it wasn't we wouldn't reach that point in the
392 // work list.
393 if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
394 continue;
395
396 // 2.4 - The current item is the result of a scalar op, add all its
397 // operands to the work list.
398 if (llvm::all_of(
399 currentItemDefiningOp->getResultTypes(),
400 [&](Type resultType) { return resultType.isIntOrFloat(); }))
401 llvm::append_range(workList, currentItemDefiningOp->getOperands());
402 }
403
404 // Since the cost model gives up on some ops (see the details of step 2.2
405 // above), block arguments that correspond to the values produced by those
406 // ops should not be detensored as well.
407
408 DenseSet<BlockArgument> blockArgsToRemove;
409
410 for (auto &blockArg : blockArgsToDetensor) {
411 Block *block = blockArg.getParentBlock();
412
413 // For the potentially detensorable block argument, find the
414 // correpsonding operands in predecessor blocks.
415 for (PredecessorIterator pred = block->pred_begin();
416 pred != block->pred_end(); ++pred) {
417 BranchOpInterface terminator =
418 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
419 auto blockOperands =
420 terminator.getSuccessorOperands(pred.getSuccessorIndex());
421
422 if (blockOperands.empty() ||
423 blockOperands.isOperandProduced(blockArg.getArgNumber()))
424 continue;
425
426 Operation *definingOp =
427 blockOperands[blockArg.getArgNumber()].getDefiningOp();
428
429 // If the operand is defined by a GenericOp that will not be
430 // detensored, then do not detensor the corresponding block argument.
431 if (isa_and_nonnull<GenericOp>(definingOp) &&
432 opsToDetensor.count(definingOp) == 0) {
433 blockArgsToRemove.insert(blockArg);
434 break;
435 }
436 }
437 }
438
439 for (auto &blockArg : blockArgsToRemove) {
440 blockArgsToDetensor.erase(blockArg);
441 }
442 }
443 };
444
445 /// Detensorize everything that can detensored.
446 class AggressiveDetensoringModel : public CostModel {
447 public:
compute(FunctionOpInterface func,DetensorizeTypeConverter typeConverter,DenseSet<Operation * > & opsToDetensor,DenseSet<BlockArgument> & blockArgsToDetensor)448 void compute(FunctionOpInterface func,
449 DetensorizeTypeConverter typeConverter,
450 DenseSet<Operation *> &opsToDetensor,
451 DenseSet<BlockArgument> &blockArgsToDetensor) override {
452 func->walk([&](GenericOp genericOp) {
453 if (shouldBeDetensored(genericOp, typeConverter))
454 opsToDetensor.insert(genericOp);
455 });
456
457 for (Block &block : llvm::drop_begin(func.getBody(), 1))
458 for (BlockArgument blockArgument : block.getArguments())
459 blockArgsToDetensor.insert(blockArgument);
460 }
461 };
462
runOnOperation__anon4fe90d0e0111::LinalgDetensorize463 void runOnOperation() override {
464 MLIRContext *context = &getContext();
465 DetensorizeTypeConverter typeConverter;
466 RewritePatternSet patterns(context);
467 ConversionTarget target(*context);
468 DenseSet<Operation *> opsToDetensor;
469 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
470 DenseSet<BlockArgument> blockArgsToDetensor;
471 FunctionOpInterface funcOp = cast<FunctionOpInterface>(getOperation());
472
473 if (aggressiveMode.getValue()) {
474 AggressiveDetensoringModel costModel;
475 costModel.compute(funcOp, typeConverter, opsToDetensor,
476 blockArgsToDetensor);
477 } else {
478 ControlFlowDetectionModel costModel;
479 costModel.compute(funcOp, typeConverter, opsToDetensor,
480 blockArgsToDetensor);
481 }
482
483 detensorableBranchOps =
484 CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
485
486 target.addDynamicallyLegalOp<GenericOp>(
487 [&](GenericOp op) { return !opsToDetensor.count(op); });
488
489 target.markUnknownOpDynamicallyLegal([&](Operation *op) {
490 // A function is legal if all of its non-entry blocks are legal. We
491 // don't legalize the entry block (i.e. the function's signature)
492 // since detensoring can't happen along external calling convention
493 // boundaries, which we conservatively approximate as all function
494 // signatures.
495 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
496 Region &body = funcOp.getBody();
497 return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
498 return !llvm::any_of(
499 blockArgsToDetensor, [&](BlockArgument blockArgument) {
500 return blockArgument.getOwner() == &block &&
501 !typeConverter.isLegal(blockArgument.getType());
502 });
503 });
504 }
505
506 if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
507 isLegalForReturnOpTypeConversionPattern(op, typeConverter,
508 /*returnOpAlwaysLegal*/ true))
509 return true;
510
511 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
512 if (!detensorableBranchOps.count(branchOp))
513 return true;
514
515 for (auto operandIdx : detensorableBranchOps[branchOp])
516 if (!typeConverter.isLegal(
517 branchOp->getOperand(operandIdx).getType()))
518 return false;
519
520 return true;
521 }
522
523 return false;
524 });
525
526 patterns.add<DetensorizeGenericOp>(typeConverter, context);
527 patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
528 blockArgsToDetensor);
529 // Since non-entry block arguments get detensorized, we also need to
530 // update the control flow inside the function to reflect the correct
531 // types.
532 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
533 int operandIdx) -> bool {
534 return detensorableBranchOps.count(branchOp) &&
535 detensorableBranchOps[branchOp].count(operandIdx);
536 };
537
538 populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
539 shouldConvertBranchOperand);
540
541 if (failed(
542 applyFullConversion(getOperation(), target, std::move(patterns))))
543 signalPassFailure();
544
545 RewritePatternSet canonPatterns(context);
546 tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
547 if (failed(applyPatternsAndFoldGreedily(getOperation(),
548 std::move(canonPatterns))))
549 signalPassFailure();
550 }
551 };
552 } // namespace
553
createLinalgDetensorizePass()554 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
555 return std::make_unique<LinalgDetensorize>();
556 }
557