1 //===- Interchange.cpp - Linalg interchange transformation ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg interchange transformation. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 17 #include "mlir/Dialect/Linalg/Utils/Utils.h" 18 #include "mlir/Dialect/Utils/IndexingUtils.h" 19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 20 #include "mlir/Dialect/Vector/IR/VectorOps.h" 21 #include "mlir/IR/AffineExpr.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Support/LLVM.h" 26 #include "llvm/ADT/ScopeExit.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 #include <type_traits> 30 31 #define DEBUG_TYPE "linalg-interchange" 32 33 using namespace mlir; 34 using namespace mlir::linalg; 35 36 static LogicalResult 37 interchangeGenericOpPrecondition(GenericOp genericOp, 38 ArrayRef<unsigned> interchangeVector) { 39 // Interchange vector must be non-empty and match the number of loops. 40 if (interchangeVector.empty() || 41 genericOp.getNumLoops() != interchangeVector.size()) 42 return failure(); 43 // Permutation map must be invertible. 44 if (!inversePermutation(AffineMap::getPermutationMap(interchangeVector, 45 genericOp.getContext()))) 46 return failure(); 47 return success(); 48 } 49 50 FailureOr<GenericOp> 51 mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, 52 ArrayRef<unsigned> interchangeVector) { 53 if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector))) 54 return rewriter.notifyMatchFailure(genericOp, "preconditions not met"); 55 56 // 1. Compute the inverse permutation map, it must be non-null since the 57 // preconditions are satisfied. 58 MLIRContext *context = genericOp.getContext(); 59 AffineMap permutationMap = inversePermutation( 60 AffineMap::getPermutationMap(interchangeVector, context)); 61 assert(permutationMap && "unexpected null map"); 62 63 // Start a guarded inplace update. 64 rewriter.startRootUpdate(genericOp); 65 auto guard = 66 llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(genericOp); }); 67 68 // 2. Compute the interchanged indexing maps. 69 SmallVector<AffineMap> newIndexingMaps; 70 for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { 71 AffineMap m = genericOp.getTiedIndexingMap(opOperand); 72 if (!permutationMap.isEmpty()) 73 m = m.compose(permutationMap); 74 newIndexingMaps.push_back(m); 75 } 76 genericOp->setAttr(getIndexingMapsAttrName(), 77 rewriter.getAffineMapArrayAttr(newIndexingMaps)); 78 79 // 3. Compute the interchanged iterator types. 80 ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue(); 81 SmallVector<Attribute> itTypesVector; 82 llvm::append_range(itTypesVector, itTypes); 83 SmallVector<int64_t> permutation(interchangeVector.begin(), 84 interchangeVector.end()); 85 applyPermutationToVector(itTypesVector, permutation); 86 genericOp->setAttr(getIteratorTypesAttrName(), 87 ArrayAttr::get(context, itTypesVector)); 88 89 // 4. Transform the index operations by applying the permutation map. 90 if (genericOp.hasIndexSemantics()) { 91 OpBuilder::InsertionGuard guard(rewriter); 92 for (IndexOp indexOp : 93 llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) { 94 rewriter.setInsertionPoint(indexOp); 95 SmallVector<Value> allIndices; 96 allIndices.reserve(genericOp.getNumLoops()); 97 llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()), 98 std::back_inserter(allIndices), [&](uint64_t dim) { 99 return rewriter.create<IndexOp>(indexOp->getLoc(), dim); 100 }); 101 rewriter.replaceOpWithNewOp<AffineApplyOp>( 102 indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices); 103 } 104 } 105 106 return genericOp; 107 } 108