1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/Dialect/Linalg/IR/Linalg.h" 12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 13 #include "mlir/Dialect/PDL/IR/PDL.h" 14 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Interfaces/SideEffectInterfaces.h" 17 #include "mlir/Parser/Parser.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 #include "llvm/Support/FormatVariadic.h" 20 21 using namespace mlir; 22 using namespace mlir::linalg; 23 using namespace mlir::transform; 24 25 /// Extracts a vector of int64_t from an array attribute. Asserts if the 26 /// attribute contains values other than integers. 27 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) { 28 SmallVector<int64_t> result; 29 result.reserve(attr.size()); 30 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 31 result.push_back(value.getSExtValue()); 32 return result; 33 } 34 35 /// Extracts a vector of unsigned from an array attribute. Asserts if the 36 /// attribute contains values other than intergers. May truncate. 37 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { 38 SmallVector<unsigned> result; 39 result.reserve(attr.size()); 40 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 41 result.push_back(value.getZExtValue()); 42 return result; 43 } 44 45 namespace { 46 /// A simple pattern rewriter that implements no special logic. 47 class SimpleRewriter : public PatternRewriter { 48 public: 49 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 50 }; 51 } // namespace 52 53 /// Attempts to apply the pattern specified as template argument to the given 54 /// operation. The pattern is expected to have a `returningMatchAndRewrite` 55 /// function that returns the "main" result or failure. Returns failure if the 56 /// pattern failed to apply. Extra arguments are forwarded to the pattern 57 /// constructor. 58 template <typename PatternTy, typename... Args> 59 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { 60 // Check if the given operation has the type expected by the pattern. 61 using OpTy = typename llvm::function_traits< 62 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; 63 auto op = dyn_cast<OpTy>(operation); 64 if (!op) 65 return failure(); 66 67 // Apply the pattern directly to the op. 68 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); 69 SimpleRewriter rewriter(operation->getContext()); 70 rewriter.setInsertionPoint(operation); 71 auto result = pattern.returningMatchAndRewrite(op, rewriter); 72 if (failed(result)) 73 return failure(); 74 return cast<LinalgOp>(result->getOperation()); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // DecomposeOp 79 //===----------------------------------------------------------------------===// 80 81 FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) { 82 FailureOr<LinalgOp> windowed = 83 tryApply<DownscaleSizeOneWindowed2DConvolution>(target); 84 if (succeeded(windowed)) 85 return windowed; 86 87 FailureOr<LinalgOp> depthwise = 88 tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target); 89 if (succeeded(depthwise)) 90 return depthwise; 91 92 return reportUnknownTransformError(target); 93 } 94 95 //===----------------------------------------------------------------------===// 96 // FuseOp 97 //===----------------------------------------------------------------------===// 98 99 /// Apply a tiling transformation to all payload ops and store both the 100 /// tiled operation as well as the created tile loops. 101 static LogicalResult 102 applyTilingToAll(Operation *transformOp, Value target, 103 ArrayRef<int64_t> tileSizes, 104 transform::TransformResults &transformResults, 105 transform::TransformState &state, 106 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) { 107 // Number of loops: Number of tiles sizes that are not zero. 108 size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0); 109 // All payload ops. These should all be LinalgOps for now. 110 ArrayRef<Operation *> payloadOps = state.getPayloadOps(target); 111 112 SmallVector<Operation *> tiledLinalgOps; 113 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 114 for (unsigned int i = 0; i < numLoops; ++i) 115 loopOps[i].reserve(payloadOps.size()); 116 117 for (Operation *target : payloadOps) { 118 auto linalgOp = dyn_cast<linalg::LinalgOp>(target); 119 if (!linalgOp) 120 return transformOp->emitError("only LinalgOps are supported"); 121 122 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp); 123 if (failed(tiled)) 124 return failure(); 125 126 tiledLinalgOps.push_back(tiled->op); 127 if (tiled->loops.size() != numLoops) 128 // Not enough loops were generated. This usually means that the input size 129 // was smaller than the tiling size. 130 // TODO: LinalgTilingPattern should return failure(). 131 return failure(); 132 for (unsigned int i = 0; i < numLoops; ++i) 133 loopOps[i].push_back(tiled->loops[i]); 134 } 135 136 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 137 for (unsigned int i = 0; i < numLoops; ++i) 138 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 139 return success(); 140 } 141 142 /// Parse a tiling-like operation that returns the tiled op as well as the 143 /// created tile loops. The function counts the non-zero tile sizes to compute 144 /// the number of results. 145 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, 146 StringRef sizesAttrName) { 147 OpAsmParser::UnresolvedOperand targetOperand; 148 SMLoc opLoc = parser.getCurrentLocation(); 149 if (parser.parseOperand(targetOperand) || 150 parser.parseOptionalAttrDict(result.attributes)) 151 return failure(); 152 Attribute sizesAttr = result.attributes.get(sizesAttrName); 153 if (!sizesAttr) 154 return parser.emitError(opLoc) 155 << "expected '" << sizesAttrName << "' attribute"; 156 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); 157 if (!sizesArrayAttr) 158 return parser.emitError(opLoc) 159 << "'" << sizesAttrName << "' attribute must be an array"; 160 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 161 size_t numExpectedLoops = 162 sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0); 163 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); 164 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) 165 return failure(); 166 return success(); 167 } 168 169 LogicalResult 170 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, 171 mlir::transform::TransformState &state) { 172 LinalgTilingAndFusionOptions fusionOptions; 173 fusionOptions.tileSizes = extractI64Array(getTileSizes()); 174 fusionOptions.tileInterchange = extractI64Array(getTileInterchange()); 175 176 return applyTilingToAll( 177 getOperation(), getTarget(), fusionOptions.tileSizes, transformResults, 178 state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> { 179 LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); 180 SimpleRewriter rewriter(getContext()); 181 rewriter.setInsertionPoint(linalgOp); 182 FailureOr<TileLoopNest> tileLoopNest = 183 pattern.returningMatchAndRewrite(linalgOp, rewriter); 184 if (failed(tileLoopNest)) 185 return failure(); 186 187 TiledLinalgOp tiledLinalgOp; 188 tiledLinalgOp.op = tileLoopNest->getRootOp(); 189 tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(), 190 tileLoopNest->getLoopOps().end()}; 191 return tiledLinalgOp; 192 }); 193 } 194 195 ParseResult transform::FuseOp::parse(OpAsmParser &parser, 196 OperationState &result) { 197 return parseTileLikeOp( 198 parser, result, 199 transform::FuseOp::getTileSizesAttrName(result.name).getValue()); 200 } 201 202 void transform::FuseOp::print(OpAsmPrinter &p) { 203 p << ' '; 204 p << getTarget(); 205 p.printOptionalAttrDict((*this)->getAttrs()); 206 } 207 208 LogicalResult transform::FuseOp::verify() { 209 SmallVector<int64_t> permutation = extractI64Array(getTileInterchange()); 210 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); 211 if (!std::is_permutation(sequence.begin(), sequence.end(), 212 permutation.begin(), permutation.end())) { 213 return emitOpError() << "expects interchange to be a permutation, found " 214 << getTileInterchange(); 215 } 216 return success(); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // GeneralizeOp 221 //===----------------------------------------------------------------------===// 222 223 FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) { 224 // Exit early if no transformation is needed. 225 if (isa<GenericOp>(target)) 226 return target; 227 228 FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target); 229 if (succeeded(generic)) 230 return generic; 231 232 return reportUnknownTransformError(target); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // InterchangeOp 237 //===----------------------------------------------------------------------===// 238 239 FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) { 240 SmallVector<unsigned> interchangeVector = 241 extractUIntArray(getIteratorInterchange()); 242 // Exit early if no transformation is needed. 243 if (interchangeVector.empty()) 244 return target; 245 246 auto genericTarget = dyn_cast<GenericOp>(target.getOperation()); 247 if (!genericTarget) { 248 InFlightDiagnostic diag = emitOpError() 249 << "applies to " << GenericOp::getOperationName() 250 << " ops"; 251 diag.attachNote(target.getLoc()) << "attempted to apply to this op"; 252 return diag; 253 } 254 255 return tryApply<GenericOpInterchangePattern>(target, interchangeVector); 256 } 257 258 LogicalResult transform::InterchangeOp::verify() { 259 SmallVector<unsigned> permutation = 260 extractUIntArray(getIteratorInterchange()); 261 auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size())); 262 if (!std::is_permutation(sequence.begin(), sequence.end(), 263 permutation.begin(), permutation.end())) { 264 return emitOpError() 265 << "expects iterator_interchange to be a permutation, found " 266 << getIteratorInterchange(); 267 } 268 return success(); 269 } 270 271 //===---------------------------------------------------------------------===// 272 // PadOp 273 //===---------------------------------------------------------------------===// 274 275 FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) { 276 // Convert the integer packing flags to booleans. 277 SmallVector<bool> packPaddings; 278 for (int64_t packPadding : extractI64Array(getPackPaddings())) 279 packPaddings.push_back(static_cast<bool>(packPadding)); 280 281 // Convert the padding values to attributes. 282 SmallVector<Attribute> paddingValues; 283 for (auto const &it : 284 llvm::zip(getPaddingValues(), target->getOperandTypes())) { 285 Attribute attr = std::get<0>(it); 286 Type elementType = getElementTypeOrSelf(std::get<1>(it)); 287 // Try to parse string attributes to obtain an attribute of element type. 288 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 289 paddingValues.push_back( 290 parseAttribute(attr.cast<StringAttr>(), elementType)); 291 if (!paddingValues.back()) { 292 InFlightDiagnostic diag = emitOpError() 293 << "expects a padding value that parses to " 294 << elementType << ", got " << std::get<0>(it); 295 diag.attachNote(target.getLoc()) << "when applied to this op"; 296 return diag; 297 } 298 continue; 299 } 300 // Otherwise, add the attribute directly. 301 if (attr.getType() != elementType) { 302 InFlightDiagnostic diag = emitOpError() 303 << "expects a padding value of type " 304 << elementType << ", got " << attr; 305 diag.attachNote(target.getLoc()) << "when applied to this op"; 306 return diag; 307 } 308 paddingValues.push_back(attr); 309 } 310 311 // Extract the transpose vectors. 312 SmallVector<SmallVector<int64_t>> transposePaddings; 313 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>()) 314 transposePaddings.push_back( 315 extractI64Array(transposeVector.cast<ArrayAttr>())); 316 317 LinalgPaddingOptions paddingOptions; 318 paddingOptions.setPaddingValues(paddingValues); 319 paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions())); 320 paddingOptions.setPackPaddings(packPaddings); 321 paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings())); 322 paddingOptions.setTransposePaddings(transposePaddings); 323 324 FailureOr<LinalgOp> result = 325 tryApply<LinalgPaddingPattern>(target, paddingOptions); 326 if (succeeded(result)) 327 return result; 328 329 InFlightDiagnostic diag = emitError() 330 << "failed to apply pattern to target op"; 331 diag.attachNote(target.getLoc()) << "target op"; 332 return diag; 333 } 334 335 LogicalResult transform::PadOp::verify() { 336 SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings()); 337 if (any_of(packPaddings, [](int64_t packPadding) { 338 return packPadding != 0 && packPadding != 1; 339 })) { 340 return emitOpError() 341 << "expects pack_paddings to contain booleans (0/1), found " 342 << getPackPaddings(); 343 } 344 345 SmallVector<int64_t> paddingDimensions = 346 extractI64Array(getPaddingDimensions()); 347 if (any_of(paddingDimensions, 348 [](int64_t paddingDimension) { return paddingDimension < 0; })) { 349 return emitOpError() 350 << "expects padding_dimensions to contain positive integers, found " 351 << getPaddingDimensions(); 352 } 353 354 SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings()); 355 if (any_of(hoistPaddings, 356 [](int64_t hoistPadding) { return hoistPadding < 0; })) { 357 return emitOpError() 358 << "expects hoist_paddings to contain positive integers, found " 359 << getHoistPaddings(); 360 } 361 362 ArrayAttr transposes = getTransposePaddings(); 363 for (Attribute attr : transposes) { 364 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr); 365 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 366 if (!std::is_permutation(sequence.begin(), sequence.end(), 367 transpose.begin(), transpose.end())) { 368 return emitOpError() 369 << "expects transpose_paddings to be a permutation, found " 370 << attr; 371 } 372 } 373 return success(); 374 } 375 376 //===----------------------------------------------------------------------===// 377 // ScalarizeOp 378 //===----------------------------------------------------------------------===// 379 380 FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) { 381 LinalgTilingOptions tilingOptions; 382 tilingOptions.scalarizeDynamicDims(); 383 // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile 384 // sizes and asserts that it is not already set. 385 SmallVector<int64_t> emptyTileSizes; 386 LinalgTilingPattern pattern(getContext(), tilingOptions); 387 SimpleRewriter rewriter(getContext()); 388 rewriter.setInsertionPoint(target); 389 FailureOr<TiledLinalgOp> result = 390 pattern.returningMatchAndRewrite(target, rewriter); 391 if (failed(result)) 392 return failure(); 393 394 return result->op; 395 } 396 397 //===----------------------------------------------------------------------===// 398 // TileOp 399 //===----------------------------------------------------------------------===// 400 401 LogicalResult transform::TileOp::apply(TransformResults &transformResults, 402 TransformState &state) { 403 LinalgTilingOptions tilingOptions; 404 SmallVector<int64_t> tileSizes = extractI64Array(getSizes()); 405 406 if (!tileSizes.empty()) 407 tilingOptions.setTileSizes(tileSizes); 408 tilingOptions.setInterchange(extractUIntArray(getInterchange())); 409 LinalgTilingPattern pattern(getContext(), tilingOptions); 410 411 return applyTilingToAll(getOperation(), getTarget(), tileSizes, 412 transformResults, state, [&](LinalgOp linalgOp) { 413 SimpleRewriter rewriter(linalgOp.getContext()); 414 return pattern.returningMatchAndRewrite(linalgOp, 415 rewriter); 416 }); 417 } 418 419 ParseResult transform::TileOp::parse(OpAsmParser &parser, 420 OperationState &result) { 421 return parseTileLikeOp(parser, result, 422 TileOp::getSizesAttrName(result.name).getValue()); 423 } 424 425 void TileOp::print(OpAsmPrinter &p) { 426 p << ' '; 427 p << getTarget(); 428 p.printOptionalAttrDict((*this)->getAttrs()); 429 } 430 431 //===----------------------------------------------------------------------===// 432 // VectorizeOp 433 //===----------------------------------------------------------------------===// 434 435 FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target) { 436 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 437 InFlightDiagnostic diag = emitOpError() 438 << "applies only to isolated-from-above targets"; 439 diag.attachNote(target->getLoc()) << "non-isolated target"; 440 return diag; 441 } 442 443 MLIRContext *ctx = getContext(); 444 RewritePatternSet patterns(ctx); 445 patterns.add<LinalgVectorizationPattern>(ctx); 446 447 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 448 vector::populateVectorReductionToContractPatterns(patterns); 449 patterns.add<linalg::LinalgCopyVTRForwardingPattern, 450 linalg::LinalgCopyVTWForwardingPattern>(ctx, 451 /*benefit=*/2); 452 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); 453 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); 454 if (getVectorizePadding()) 455 linalg::populatePadOpVectorizationPatterns(patterns); 456 457 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) 458 return reportUnknownTransformError(target); 459 return target; 460 } 461 462 //===----------------------------------------------------------------------===// 463 // Transform op registration 464 //===----------------------------------------------------------------------===// 465 466 namespace { 467 /// Registers new ops and declares PDL as dependent dialect since the additional 468 /// ops are using PDL types for operands and results. 469 class LinalgTransformDialectExtension 470 : public transform::TransformDialectExtension< 471 LinalgTransformDialectExtension> { 472 public: 473 LinalgTransformDialectExtension() { 474 declareDependentDialect<pdl::PDLDialect>(); 475 declareDependentDialect<scf::SCFDialect>(); 476 declareDependentDialect<vector::VectorDialect>(); 477 registerTransformOps< 478 #define GET_OP_LIST 479 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 480 >(); 481 } 482 }; 483 } // namespace 484 485 #define GET_OP_CLASSES 486 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 487 488 void mlir::linalg::registerTransformDialectExtension( 489 DialectRegistry ®istry) { 490 registry.addExtensions<LinalgTransformDialectExtension>(); 491 } 492