1 //===- Interchange.cpp - Linalg interchange transformation ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg interchange transformation. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/Linalg/Utils/Utils.h" 17 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 18 #include "mlir/Dialect/Vector/VectorOps.h" 19 #include "mlir/IR/AffineExpr.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/Pass/Pass.h" 23 #include "mlir/Support/LLVM.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/raw_ostream.h" 26 #include <type_traits> 27 28 #define DEBUG_TYPE "linalg-interchange" 29 30 using namespace mlir; 31 using namespace mlir::linalg; 32 33 LogicalResult mlir::linalg::interchangeGenericOpPrecondition( 34 GenericOp genericOp, ArrayRef<unsigned> interchangeVector) { 35 // Interchange vector must be non-empty and match the number of loops. 36 if (interchangeVector.empty() || 37 genericOp.getNumLoops() != interchangeVector.size()) 38 return failure(); 39 // Permutation map must be invertible. 40 if (!inversePermutation(AffineMap::getPermutationMap(interchangeVector, 41 genericOp.getContext()))) 42 return failure(); 43 return success(); 44 } 45 46 void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter, 47 GenericOp genericOp, 48 ArrayRef<unsigned> interchangeVector) { 49 // 1. Compute the inverse permutation map. 50 MLIRContext *context = genericOp.getContext(); 51 AffineMap permutationMap = inversePermutation( 52 AffineMap::getPermutationMap(interchangeVector, context)); 53 assert(permutationMap && "expected permutation to be invertible"); 54 assert(interchangeVector.size() == genericOp.getNumLoops() && 55 "expected interchange vector to have entry for every loop"); 56 57 // 2. Compute the interchanged indexing maps. 58 SmallVector<Attribute, 4> newIndexingMaps; 59 for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { 60 AffineMap m = genericOp.getTiedIndexingMap(opOperand); 61 if (!permutationMap.isEmpty()) 62 m = m.compose(permutationMap); 63 newIndexingMaps.push_back(AffineMapAttr::get(m)); 64 } 65 genericOp->setAttr(getIndexingMapsAttrName(), 66 ArrayAttr::get(context, newIndexingMaps)); 67 68 // 3. Compute the interchanged iterator types. 69 ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue(); 70 SmallVector<Attribute, 4> itTypesVector; 71 llvm::append_range(itTypesVector, itTypes); 72 SmallVector<int64_t> permutation(interchangeVector.begin(), 73 interchangeVector.end()); 74 applyPermutationToVector(itTypesVector, permutation); 75 genericOp->setAttr(getIteratorTypesAttrName(), 76 ArrayAttr::get(context, itTypesVector)); 77 78 // 4. Transform the index operations by applying the permutation map. 79 if (genericOp.hasIndexSemantics()) { 80 // TODO: Remove the assertion and add a getBody() method to LinalgOp 81 // interface once every LinalgOp has a body. 82 assert(genericOp->getNumRegions() == 1 && 83 genericOp->getRegion(0).getBlocks().size() == 1 && 84 "expected generic operation to have one block."); 85 Block &block = genericOp->getRegion(0).front(); 86 OpBuilder::InsertionGuard guard(rewriter); 87 for (IndexOp indexOp : 88 llvm::make_early_inc_range(block.getOps<IndexOp>())) { 89 rewriter.setInsertionPoint(indexOp); 90 SmallVector<Value> allIndices; 91 allIndices.reserve(genericOp.getNumLoops()); 92 llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()), 93 std::back_inserter(allIndices), [&](uint64_t dim) { 94 return rewriter.create<IndexOp>(indexOp->getLoc(), dim); 95 }); 96 rewriter.replaceOpWithNewOp<AffineApplyOp>( 97 indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices); 98 } 99 } 100 } 101