1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/Dialect/Linalg/IR/Linalg.h" 12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 13 #include "mlir/Dialect/PDL/IR/PDL.h" 14 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Interfaces/SideEffectInterfaces.h" 17 #include "mlir/Parser/Parser.h" 18 #include "llvm/Support/FormatVariadic.h" 19 20 using namespace mlir; 21 using namespace mlir::linalg; 22 using namespace mlir::transform; 23 24 /// Extracts a vector of int64_t from an array attribute. Asserts if the 25 /// attribute contains values other than integers. 26 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) { 27 SmallVector<int64_t> result; 28 result.reserve(attr.size()); 29 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 30 result.push_back(value.getSExtValue()); 31 return result; 32 } 33 34 /// Extracts a vector of unsigned from an array attribute. Asserts if the 35 /// attribute contains values other than intergers. May truncate. 36 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { 37 SmallVector<unsigned> result; 38 result.reserve(attr.size()); 39 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 40 result.push_back(value.getZExtValue()); 41 return result; 42 } 43 44 namespace { 45 /// A simple pattern rewriter that implements no special logic. 46 class SimpleRewriter : public PatternRewriter { 47 public: 48 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 49 }; 50 } // namespace 51 52 //===----------------------------------------------------------------------===// 53 // TileOp 54 //===----------------------------------------------------------------------===// 55 56 /// Apply a tiling transformation to all payload ops and store both the 57 /// tiled operation as well as the created tile loops. 58 static LogicalResult 59 applyTilingToAll(Operation *transformOp, Value target, 60 ArrayRef<int64_t> tileSizes, 61 transform::TransformResults &transformResults, 62 transform::TransformState &state, 63 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) { 64 // Number of loops: Number of tiles sizes that are not zero. 65 size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0); 66 // All payload ops. These should all be LinalgOps for now. 67 ArrayRef<Operation *> payloadOps = state.getPayloadOps(target); 68 69 SmallVector<Operation *> tiledLinalgOps; 70 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 71 for (unsigned int i = 0; i < numLoops; ++i) 72 loopOps[i].reserve(payloadOps.size()); 73 74 for (Operation *target : payloadOps) { 75 auto linalgOp = dyn_cast<linalg::LinalgOp>(target); 76 if (!linalgOp) 77 return transformOp->emitError("only LinalgOps are supported"); 78 79 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp); 80 if (failed(tiled)) 81 return failure(); 82 83 tiledLinalgOps.push_back(tiled->op); 84 if (tiled->loops.size() != numLoops) 85 // Not enough loops were generated. This usually means that the input size 86 // was smaller than the tiling size. 87 // TODO: LinalgTilingPattern should return failure(). 88 return failure(); 89 for (unsigned int i = 0; i < numLoops; ++i) 90 loopOps[i].push_back(tiled->loops[i]); 91 } 92 93 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 94 for (unsigned int i = 0; i < numLoops; ++i) 95 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 96 return success(); 97 } 98 99 LogicalResult transform::TileOp::apply(TransformResults &transformResults, 100 TransformState &state) { 101 LinalgTilingOptions tilingOptions; 102 SmallVector<int64_t> tileSizes = extractI64Array(getSizes()); 103 104 if (!tileSizes.empty()) 105 tilingOptions.setTileSizes(tileSizes); 106 tilingOptions.setInterchange(extractUIntArray(getInterchange())); 107 LinalgTilingPattern pattern(getContext(), tilingOptions); 108 109 return applyTilingToAll(getOperation(), getTarget(), tileSizes, 110 transformResults, state, [&](LinalgOp linalgOp) { 111 SimpleRewriter rewriter(linalgOp.getContext()); 112 return pattern.returningMatchAndRewrite(linalgOp, 113 rewriter); 114 }); 115 } 116 117 ParseResult transform::TileOp::parse(OpAsmParser &parser, 118 OperationState &result) { 119 StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue(); 120 OpAsmParser::UnresolvedOperand targetOperand; 121 SMLoc opLoc = parser.getCurrentLocation(); 122 if (parser.parseOperand(targetOperand) || 123 parser.parseOptionalAttrDict(result.attributes)) 124 return failure(); 125 Attribute sizesAttr = result.attributes.get(sizesAttrName); 126 if (!sizesAttr) 127 return parser.emitError(opLoc) 128 << "expected '" << sizesAttrName << "' attribute"; 129 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); 130 if (!sizesArrayAttr) 131 return parser.emitError(opLoc) 132 << "'" << sizesAttrName << "' attribute must be an array"; 133 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 134 size_t numExpectedLoops = 135 sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0); 136 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); 137 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) 138 return failure(); 139 return success(); 140 } 141 142 void TileOp::print(OpAsmPrinter &p) { 143 p << ' '; 144 p << getTarget(); 145 p.printOptionalAttrDict((*this)->getAttrs()); 146 } 147 148 void TileOp::getEffects( 149 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 150 &effects) { 151 // `target` arg is consumed and can no longer be used. 152 effects.emplace_back(MemoryEffects::Read::get(), getTarget(), 153 TransformMappingResource::get()); 154 effects.emplace_back(MemoryEffects::Free::get(), getTarget(), 155 TransformMappingResource::get()); 156 157 for (Value r : getResults()) { 158 effects.emplace_back(MemoryEffects::Write::get(), r, 159 TransformMappingResource::get()); 160 effects.emplace_back(MemoryEffects::Allocate::get(), r, 161 TransformMappingResource::get()); 162 } 163 164 effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); 165 effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); 166 } 167 168 //===----------------------------------------------------------------------===// 169 // Transform op registration 170 //===----------------------------------------------------------------------===// 171 172 namespace { 173 /// Registers new ops and declares PDL as dependent dialect since the additional 174 /// ops are using PDL types for operands and results. 175 class LinalgTransformDialectExtension 176 : public transform::TransformDialectExtension< 177 LinalgTransformDialectExtension> { 178 public: 179 LinalgTransformDialectExtension() { 180 declareDependentDialect<pdl::PDLDialect>(); 181 declareDependentDialect<scf::SCFDialect>(); 182 registerTransformOps< 183 #define GET_OP_LIST 184 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 185 >(); 186 } 187 }; 188 } // namespace 189 190 #define GET_OP_CLASSES 191 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 192 193 void mlir::linalg::registerTransformDialectExtension( 194 DialectRegistry ®istry) { 195 registry.addExtensions<LinalgTransformDialectExtension>(); 196 } 197