1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/Dialect/Linalg/IR/Linalg.h" 12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 13 #include "mlir/Dialect/PDL/IR/PDL.h" 14 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Interfaces/SideEffectInterfaces.h" 17 #include "mlir/Parser/Parser.h" 18 #include "llvm/Support/FormatVariadic.h" 19 20 using namespace mlir; 21 using namespace mlir::linalg; 22 using namespace mlir::transform; 23 24 /// Extracts a vector of int64_t from an array attribute. Asserts if the 25 /// attribute contains values other than integers. 26 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) { 27 SmallVector<int64_t> result; 28 result.reserve(attr.size()); 29 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 30 result.push_back(value.getSExtValue()); 31 return result; 32 } 33 34 /// Extracts a vector of unsigned from an array attribute. Asserts if the 35 /// attribute contains values other than intergers. May truncate. 36 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { 37 SmallVector<unsigned> result; 38 result.reserve(attr.size()); 39 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 40 result.push_back(value.getZExtValue()); 41 return result; 42 } 43 44 namespace { 45 /// A simple pattern rewriter that implements no special logic. 46 class SimpleRewriter : public PatternRewriter { 47 public: 48 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 49 }; 50 } // namespace 51 52 //===----------------------------------------------------------------------===// 53 // TileOp 54 //===----------------------------------------------------------------------===// 55 56 /// Apply a tiling transformation to all payload ops and store both the 57 /// tiled operation as well as the created tile loops. 58 static LogicalResult 59 applyTilingToAll(Operation *transformOp, Value target, 60 ArrayRef<int64_t> tileSizes, 61 transform::TransformResults &transformResults, 62 transform::TransformState &state, 63 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) { 64 // Number of loops: Number of tiles sizes that are not zero. 65 size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0); 66 // All payload ops. These should all be LinalgOps for now. 67 ArrayRef<Operation *> payloadOps = state.getPayloadOps(target); 68 69 SmallVector<Operation *> tiledLinalgOps; 70 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 71 for (unsigned int i = 0; i < numLoops; ++i) 72 loopOps[i].reserve(payloadOps.size()); 73 74 for (Operation *target : payloadOps) { 75 auto linalgOp = dyn_cast<linalg::LinalgOp>(target); 76 if (!linalgOp) 77 return transformOp->emitError("only LinalgOps are supported"); 78 79 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp); 80 if (failed(tiled)) 81 return failure(); 82 83 tiledLinalgOps.push_back(tiled->op); 84 if (tiled->loops.size() != numLoops) 85 // Not enough loops were generated. This usually means that the input size 86 // was smaller than the tiling size. 87 // TODO: LinalgTilingPattern should return failure(). 88 return failure(); 89 for (unsigned int i = 0; i < numLoops; ++i) 90 loopOps[i].push_back(tiled->loops[i]); 91 } 92 93 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 94 for (unsigned int i = 0; i < numLoops; ++i) 95 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 96 return success(); 97 } 98 99 LogicalResult transform::TileOp::apply(TransformResults &transformResults, 100 TransformState &state) { 101 LinalgTilingOptions tilingOptions; 102 SmallVector<int64_t> tileSizes = extractI64Array(getSizes()); 103 104 if (!tileSizes.empty()) 105 tilingOptions.setTileSizes(tileSizes); 106 tilingOptions.setInterchange(extractUIntArray(getInterchange())); 107 LinalgTilingPattern pattern(getContext(), tilingOptions); 108 109 return applyTilingToAll(getOperation(), getTarget(), tileSizes, 110 transformResults, state, [&](LinalgOp linalgOp) { 111 SimpleRewriter rewriter(linalgOp.getContext()); 112 return pattern.returningMatchAndRewrite(linalgOp, 113 rewriter); 114 }); 115 } 116 117 ParseResult transform::TileOp::parse(OpAsmParser &parser, 118 OperationState &result) { 119 StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue(); 120 OpAsmParser::UnresolvedOperand targetOperand; 121 SMLoc opLoc; 122 parser.getCurrentLocation(&opLoc); 123 if (parser.parseOperand(targetOperand)) 124 return parser.emitError(opLoc, "expected 'target' operand"); 125 if (parser.parseOptionalAttrDict(result.attributes)) 126 return failure(); 127 Attribute sizesAttr = result.attributes.get(sizesAttrName); 128 if (!sizesAttr) 129 return parser.emitError(opLoc) 130 << "expected '" << sizesAttrName << "' attribute"; 131 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); 132 if (!sizesArrayAttr) 133 return parser.emitError(opLoc) 134 << "'" << sizesAttrName << "' attribute must be an array"; 135 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 136 size_t numExpectedLoops = 137 sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0); 138 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); 139 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) 140 return failure(); 141 return success(); 142 } 143 144 void TileOp::print(OpAsmPrinter &p) { 145 p << ' '; 146 p << getTarget(); 147 p.printOptionalAttrDict((*this)->getAttrs()); 148 } 149 150 void TileOp::getEffects( 151 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 152 &effects) { 153 // `target` arg is consumed and can no longer be used. 154 effects.emplace_back(MemoryEffects::Read::get(), getTarget(), 155 TransformMappingResource::get()); 156 effects.emplace_back(MemoryEffects::Free::get(), getTarget(), 157 TransformMappingResource::get()); 158 159 for (Value r : getResults()) { 160 effects.emplace_back(MemoryEffects::Write::get(), r, 161 TransformMappingResource::get()); 162 effects.emplace_back(MemoryEffects::Allocate::get(), r, 163 TransformMappingResource::get()); 164 } 165 166 effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); 167 effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); 168 } 169 170 //===----------------------------------------------------------------------===// 171 // Transform op registration 172 //===----------------------------------------------------------------------===// 173 174 namespace { 175 /// Registers new ops and declares PDL as dependent dialect since the additional 176 /// ops are using PDL types for operands and results. 177 class LinalgTransformDialectExtension 178 : public transform::TransformDialectExtension< 179 LinalgTransformDialectExtension> { 180 public: 181 LinalgTransformDialectExtension() { 182 declareDependentDialect<pdl::PDLDialect>(); 183 declareDependentDialect<scf::SCFDialect>(); 184 registerTransformOps< 185 #define GET_OP_LIST 186 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 187 >(); 188 } 189 }; 190 } // namespace 191 192 #define GET_OP_CLASSES 193 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 194 195 void mlir::linalg::registerTransformDialectExtension( 196 DialectRegistry ®istry) { 197 registry.addExtensions<LinalgTransformDialectExtension>(); 198 } 199