//===- LinalgOps.cpp - Implementation of the linalg operations -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::linalg;

#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc"

/// Forward declarations.

/// Generic entry point to create the block for the region of a LinalgOp.
/// This is used by both named structured ops created by ods-gen and by
/// manually defined C++ ops.
/// This is used by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`. The latter are
/// asserted to be of ShapedType.
template <typename NamedStructuredOpType>
static void fillStructuredOpRegion(
    OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
    TypeRange outputTypes,
    llvm::function_ref<void(unsigned, unsigned)> errorHandler = nullptr);

/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
static void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
                                            OperationState &result,
                                            TypeRange inputTypes,
                                            TypeRange outputTypes);

/// Common parsing and printing used for both named structured ops created by
/// ods-gen and by manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes);

template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
                                         NamedStructuredOpType op);

/// Specific parsing and printing for named structured ops created by ods-gen.
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
                             TypeRange inputTypes, TypeRange outputTypes);

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes);

template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result);

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes);

template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);

/// This is a common class used for patterns of the form
/// ```
///   someop(memrefcast(%src)) -> someop(%src)
/// ```
/// It folds the source of the memref.cast into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

/// This is a specialization of `foldMemRefCast` used for patterns of the form
/// ```
///   tiled_loop(memrefcast(%src)) -> tiled_loop(%src)
/// ```
/// It folds the source of the memref.cast into the root operation directly.
static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
  bool folded = false;
  Location loc = op->getLoc();

  Block *body = op.getBody();
  OpBuilder b = OpBuilder::atBlockBegin(body);

  // Update `input` and `output` operands and block arguments if necessary.
  // Operands list: [lbs, ubs, steps, inputs, outputs].
  // Block args list: [ivs, inputs, outputs].
  for (size_t operandIndex = op.getNumControlOperands(),
              bbArgIndex = op.getNumLoops(), e = op.getNumOperands();
       operandIndex < e; ++operandIndex, ++bbArgIndex) {
    OpOperand &operand = op->getOpOperand(operandIndex);

    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      BlockArgument newBbArg =
          body->insertArgument(bbArgIndex, castOp.getOperand().getType());
      BlockArgument oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1);

      // Insert memref.cast back to the original type.
      oldBbArg.replaceAllUsesWith(
          b.create<memref::CastOp>(loc, oldBbArg.getType(), newBbArg));
      body->eraseArgument(oldBbArg.getArgNumber());

      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code
// and bind by name to math and type conversion functions in the DSL as:
//   `arithfn__{fnName}`
//   `typefn__{fnName}`
// Examples:
//   `arithfn__add`
//   `arithfn__mul`
//   `typefn__cast`
// The naming convention is intentional in order to match snake-cased DSL names.
// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr. This
// can be extended later if it becomes possible to fail construction of the
// region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the mean-time, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(MLIRContext *context, Block &block)
      : context(context), block(block) {}

  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
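  //
  // For example (illustrative, not exhaustive): casting an `i32` operand to
  // `f32` is expected to emit `arith.sitofp` (or `arith.uitofp` when
  // `isUnsignedCast` is set), while `f64` -> `f32` emits `arith.truncf`.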
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder builder = getBuilder();
    auto loc = operand.getLoc();

    if (operand.getType() == toType)
      return operand;
    if (auto toIntType = toType.dyn_cast<IntegerType>()) {
      // If operand is floating point, cast directly to the int type.
      if (operand.getType().isa<FloatType>()) {
        if (isUnsignedCast)
          return builder.create<arith::FPToUIOp>(loc, toType, operand);
        return builder.create<arith::FPToSIOp>(loc, toType, operand);
      }
      // Cast index operands directly to the int type.
      if (operand.getType().isIndex())
        return builder.create<arith::IndexCastOp>(loc, toType, operand);
      if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
        // Either extend or truncate.
        if (toIntType.getWidth() > fromIntType.getWidth()) {
          if (isUnsignedCast)
            return builder.create<arith::ExtUIOp>(loc, toType, operand);
          return builder.create<arith::ExtSIOp>(loc, toType, operand);
        }
        if (toIntType.getWidth() < fromIntType.getWidth())
          return builder.create<arith::TruncIOp>(loc, toType, operand);
      }
    } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
      // If operand is integer, cast directly to the float type.
      // Note that it is unclear how to cast from BF16<->FP16.
      if (operand.getType().isa<IntegerType>()) {
        if (isUnsignedCast)
          return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
        return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
      }
      if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
        if (toFloatType.getWidth() > fromFloatType.getWidth())
          return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
        if (toFloatType.getWidth() < fromFloatType.getWidth())
          return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
      }
    }

    emitWarning(operand.getLoc()) << "could not cast operand of type "
                                  << operand.getType() << " to " << toType;
    return operand;
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value typefn__cast(Type toType, Value operand) {
    return cast(toType, operand, false);
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value typefn__cast_unsigned(Type toType, Value operand) {
    return cast(toType, operand, true);
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__add(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::AddIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__exp(Value x) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(x))
      return builder.create<math::ExpOp>(x.getLoc(), x);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__log(Value x) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(x))
      return builder.create<math::LogOp>(x.getLoc(), x);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__sub(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::SubIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__mul(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MulIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__max(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MaxSIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__max_unsigned(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MaxUIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__min(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MinSIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__min_unsigned(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MinUIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  void yieldOutputs(ValueRange values) {
    assert(!values.empty() && "linalg ops must yield outputs");
    if (values.empty())
      return;
    Value first = values.front();
    OpBuilder builder = getBuilder();
    builder.create<YieldOp>(first.getLoc(), values);
  }

  Value constant(const std::string &value) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
                                             valueAttr);
  }

  Value index(int64_t dim) {
    OpBuilder builder = getBuilder();
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(context, width);
  }

  Type getFloat32Type() { return Float32Type::get(context); }

  Type getFloat64Type() { return Float64Type::get(context); }

private:
  MLIRContext *context;
  Block &block;

  bool isFloatingPoint(Value value) {
    return value.getType().isa<FloatType>();
  }

  bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }

  OpBuilder getBuilder() {
    OpBuilder builder(context);
    builder.setInsertionPointToEnd(&block);
    return builder;
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
  assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
  b.create<linalg::YieldOp>(block.getArgument(0));
}

void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
                   Value output, AffineMap inputPermutation,
                   AffineMap outputPermutation,
                   ArrayRef<NamedAttribute> namedAttrs) {
  result.addOperands({input, output});
  result.addAttributes(namedAttrs);
  if (inputPermutation)
    result.addAttribute("inputPermutation",
                        AffineMapAttr::get(inputPermutation));
  if (outputPermutation)
    result.addAttribute("outputPermutation",
                        AffineMapAttr::get(outputPermutation));
  result.addRegion();
  fillStructuredOpRegion<CopyOp>(builder, *result.regions.front(),
                                 TypeRange{input.getType()},
                                 TypeRange{output.getType()});
}

ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
                              Type outputType) {
  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion<CopyOp>(opBuilder, r, TypeRange{inputType},
                                 TypeRange{outputType});
  return success();
}

/// CopyOp
region is elided when printing. void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} static LogicalResult verify(CopyOp op) { OpOperand *output = op.getOutputOperand(0); OpOperand *input = op.getInputOperand(0); if (getElementTypeOrSelf(input->get()) != getElementTypeOrSelf(output->get())) return op.emitOpError("expects views of the same type"); if (op.getRank(input) != op.getRank(output)) return op.emitOpError("expects views of the same rank"); auto rank = op.getNumParallelLoops(); auto inputPermutationMap = op.inputPermutation(); if (inputPermutationMap) { if (inputPermutationMap->getNumInputs() != rank) return op.emitOpError("expects optional input_permutation map of rank ") << rank; if (!inputPermutationMap->isPermutation()) return op.emitOpError( "expects optional input_permutation map to be a permutation"); } auto outputPermutationMap = op.outputPermutation(); if (outputPermutationMap) { if (outputPermutationMap->getNumInputs() != rank) return op.emitOpError("expects optional output_permutation map of rank ") << rank; if (!outputPermutationMap->isPermutation()) return op.emitOpError( "expects optional output_permutation map to be a permutation"); } if (rank == 0 && inputPermutationMap) return op.emitOpError("expected no input permutation when rank == 0"); if (rank == 0 && outputPermutationMap) return op.emitOpError("expected no output permutation when rank == 0"); return success(); } void CopyOp::getEffects( SmallVectorImpl> &effects) { effects.emplace_back(MemoryEffects::Read::get(), input(), SideEffects::DefaultResource::get()); effects.emplace_back(MemoryEffects::Write::get(), output(), SideEffects::DefaultResource::get()); } namespace { /// Remove copy operations that copy data inplace. Requirements are: /// 1) The input and output values are identical. /// 2) The input and output permutation maps are identical. struct EraseIdentityCopyOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CopyOp copyOp, PatternRewriter &rewriter) const override { assert(copyOp.hasBufferSemantics()); if (copyOp.input() == copyOp.output() && copyOp.inputPermutation() == copyOp.outputPermutation()) { rewriter.eraseOp(copyOp); return success(); } return failure(); } }; } // namespace void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); } //===----------------------------------------------------------------------===// // FillOp //===----------------------------------------------------------------------===// void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) { assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args"); b.create(block.getArgument(0)); } void FillOp::build(OpBuilder &builder, OperationState &result, Value value, Value output) { build(builder, result, output.getType().dyn_cast(), value, output); fillStructuredOpRegion(builder, *result.regions.front(), TypeRange{value.getType()}, TypeRange{output.getType()}, {}); } ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType, Type outputType) { OpBuilder opBuilder(parser.getContext()); fillStructuredOpRegion(opBuilder, r, TypeRange{valueType}, TypeRange{outputType}); return success(); } /// FillOp region is elided when printing. 
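/// For reference, the elided-region form of linalg.fill is expected to print
/// roughly as (illustrative):
///   linalg.fill(%cst, %buf) : f32, memref<16xf32>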
void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} static LogicalResult verify(FillOp op) { OpOperand *output = op.getOutputOperand(0); Type fillType = op.value().getType(); if (getElementTypeOrSelf(output->get()) != fillType) return op.emitOpError("expects fill type to match view elemental type"); return success(); } void FillOp::getEffects( SmallVectorImpl> &effects) { if (output().getType().isa()) effects.emplace_back(MemoryEffects::Write::get(), output(), SideEffects::DefaultResource::get()); } namespace { /// Fold linalg.fill -> tensor.expand/collapse_shape chain. /// /// For such op chains, we can create new linalg.fill ops with the result /// type of the tensor.expand/collapse_shape op. template struct FoldFillWithTensorReshape : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, PatternRewriter &rewriter) const override { auto oldFill = reshapeOp.src().template getDefiningOp(); if (!oldFill) return failure(); Location loc = oldFill.getLoc(); auto newInit = rewriter.create( loc, reshapeOp.getResultType(), oldFill.output(), reshapeOp.reassociation()); rewriter.replaceOpWithNewOp(reshapeOp, oldFill.value(), newInit); return success(); } }; } // namespace void FillOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, FoldFillWithTensorReshape>(context); } //===----------------------------------------------------------------------===// // GenericOps //===----------------------------------------------------------------------===// void GenericOp::build( OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, function_ref bodyBuild, ArrayRef attributes) { build(builder, result, resultTensorTypes, inputs, outputs, builder.getAffineMapArrayAttr(indexingMaps), builder.getStrArrayAttr(iteratorTypes), doc.empty() ? StringAttr() : builder.getStringAttr(doc), libraryCall.empty() ? 
StringAttr() : builder.getStringAttr(libraryCall)); result.addAttributes(attributes); if (!bodyBuild) return; SmallVector blockArgTypes; for (ValueRange container : {inputs, outputs}) for (Value v : container) blockArgTypes.push_back(getElementTypeOrSelf(v)); OpBuilder::InsertionGuard guard(builder); auto ®ion = *result.regions.front(); Block *bodyBlock = builder.createBlock(®ion, region.end(), blockArgTypes); bodyBuild(builder, result.location, bodyBlock->getArguments()); } void GenericOp::build( OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, function_ref bodyBuild, ArrayRef attributes) { build(builder, result, TypeRange{}, inputs, outputs, indexingMaps, iteratorTypes, doc, libraryCall, bodyBuild, attributes); } void GenericOp::build( OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, function_ref bodyBuild, ArrayRef attributes) { build(builder, result, inputs, outputs, indexingMaps, iteratorTypes, /*doc=*/"", /*libraryCall=*/"", bodyBuild, attributes); } void GenericOp::build( OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, function_ref bodyBuild, ArrayRef attributes) { build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps, iteratorTypes, /*doc=*/"", /*libraryCall=*/"", bodyBuild, attributes); } static void print(OpAsmPrinter &p, GenericOp op) { p << " "; // Print extra attributes. auto genericAttrNames = op.linalgTraitAttrNames(); llvm::StringSet<> genericAttrNamesSet; genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end()); SmallVector genericAttrs; for (auto attr : op->getAttrs()) if (genericAttrNamesSet.count(attr.getName().strref()) > 0) genericAttrs.push_back(attr); if (!genericAttrs.empty()) { auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs); p << genericDictAttr; } // Printing is shared with named ops, except for the region and attributes printCommonStructuredOpParts(p, op); genericAttrNames.push_back("operand_segment_sizes"); genericAttrNamesSet.insert(genericAttrNames.back()); bool hasExtraAttrs = false; for (NamedAttribute n : op->getAttrs()) { if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref()))) break; } if (hasExtraAttrs) { p << " attrs = "; p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/genericAttrNames); } // Print region. if (!op.region().empty()) { p << ' '; p.printRegion(op.region()); } // Print results. printNamedStructuredOpResults(p, op.result_tensors().getTypes()); } static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) { DictionaryAttr dictAttr; // Parse the core linalg traits that must check into a dictAttr. // The name is unimportant as we will overwrite result.attributes. // The core linalg traits must contain the information necessary to pass the // verifier. if (parser.parseAttribute(dictAttr, "_", result.attributes)) return failure(); result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); // Parsing is shared with named ops, except for the region. SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) return failure(); // Optional attributes may be added. 
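  // As a sketch of the concrete syntax handled below (illustrative):
  //   linalg.generic {indexing_maps = [...], iterator_types = ["parallel"]}
  //       ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>)
  //       attrs = {other_attribute = 42 : i64} { ... } -> tensor<8xf32>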
if (succeeded(parser.parseOptionalKeyword("attrs"))) if (failed(parser.parseEqual()) || failed(parser.parseOptionalAttrDict(result.attributes))) return failure(); SmallVector regionOperands; std::unique_ptr region = std::make_unique(); SmallVector operandTypes, regionTypes; if (parser.parseRegion(*region, regionOperands, regionTypes)) return failure(); result.addRegion(std::move(region)); // Generic ops may specify that a subset of its outputs are tensors. Such // outputs are specified in the result type. // TODO: may need to move output parsing before region parsing. // Need to wait for declarative assembly resolution to decide. SmallVector outputTensorsTypes; if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) return failure(); result.addTypes(outputTensorsTypes); return success(); } static void getGenericEffectsImpl( SmallVectorImpl> &effects, ValueRange results, ValueRange inputBuffers, ValueRange outputs) { for (Value value : results) { effects.emplace_back(MemoryEffects::Allocate::get(), value, SideEffects::DefaultResource::get()); } for (Value value : inputBuffers) { effects.emplace_back(MemoryEffects::Read::get(), value, SideEffects::DefaultResource::get()); } for (Value value : outputs) { effects.emplace_back(MemoryEffects::Read::get(), value, SideEffects::DefaultResource::get()); effects.emplace_back(MemoryEffects::Write::get(), value, SideEffects::DefaultResource::get()); } } void GenericOp::getEffects( SmallVectorImpl> &effects) { SmallVector inputBuffers = getInputBufferOperands(); SmallVector outputBuffers = getOutputBufferOperands(); getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, outputBuffers); } template static LogicalResult verifyGenericOp(GenericOpType op) { return success(); } static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); } namespace { // Deduplicate redundant args of a linalg generic op. // An arg is redundant if it has the same Value and indexing map as another. struct DeduplicateGenericOpInputs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { // Associate each input to an equivalent "canonical" input that has the same // Value and indexing map. // // In the non-duplicate case, input `i` will have canonical input `i`. But // in the case of duplicated inputs, the canonical input could be some other // input `< i`. That is, a later input will have some earlier input as its // canonical input. llvm::SmallDenseMap, unsigned> canonicalInput; // For later remapping tasks like deduplicating payload block arguments, // having a simple "inputIndex -> canonicalInputIndex" integer mapping is // convenient. SmallVector canonicalInputIndices; for (OpOperand *opOperand : genericOp.getInputOperands()) { AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); // STL-like maps have a convenient behavior for our use case here. In the // case of duplicate keys, the insertion is rejected, and the returned // iterator gives access to the value already in the map. auto pair = canonicalInput.insert( {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()}); canonicalInputIndices.push_back(pair.first->second); } // If there are no duplicate args, then bail out. if (canonicalInput.size() == genericOp.getNumInputs()) return failure(); // The operands for the newly canonicalized op. 
SmallVector newInputOperands; for (OpOperand *opOperand : genericOp.getInputOperands()) if (canonicalInputIndices[opOperand->getOperandNumber()] == opOperand->getOperandNumber()) newInputOperands.push_back(opOperand->get()); // Repair the indexing maps by filtering out the ones that have been // eliminated. SmallVector newIndexingMaps; for (OpOperand *opOperand : genericOp.getInputOperands()) if (canonicalInputIndices[opOperand->getOperandNumber()] == opOperand->getOperandNumber()) newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand)); for (OpOperand *opOperand : genericOp.getOutputOperands()) newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand)); // Clone the old op with new operands. SmallVector outputOperands = genericOp.getOutputOperands(); auto newOp = rewriter.create( genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands, outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps), genericOp.iterator_types(), genericOp.docAttr(), genericOp.library_callAttr()); // Copy over unknown attributes. They might be load bearing for some flow. ArrayRef odsAttrs = genericOp.getAttributeNames(); for (NamedAttribute kv : genericOp->getAttrs()) { if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) { newOp->setAttr(kv.getName(), kv.getValue()); } } rewriter.inlineRegionBefore(genericOp.region(), newOp.region(), newOp.region().begin()); // Repair the payload entry block by RAUW'ing redundant arguments and // erasing them. Block &payload = newOp.region().front(); SmallVector inputOperands = genericOp.getInputOperands(); for (OpOperand *opOperand : llvm::reverse(inputOperands)) { // Iterate in reverse, so that we erase later args first, preventing the // argument list from shifting unexpectedly and invalidating all our // indices. unsigned operandNumber = opOperand->getOperandNumber(); if (canonicalInputIndices[operandNumber] == operandNumber) continue; payload.getArgument(operandNumber) .replaceAllUsesWith( payload.getArgument(canonicalInputIndices[operandNumber])); payload.eraseArgument(operandNumber); } rewriter.replaceOp(genericOp, newOp->getResults()); return success(); } }; /// Remove generic operations (on tensors) that are just copying /// the values from inputs to the results. Requirements are /// 1) All iterator types are parallel /// 2) The body contains just a yield operation with the yielded values being /// the arguments corresponding to the operands. struct EraseIdentityGenericOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { if (!genericOp.hasTensorSemantics()) return failure(); // Check all indexing maps are identity. if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) { return !map.isIdentity(); })) return failure(); // Check that the body of the linalg operation is just a linalg.yield // operation. Block &body = genericOp.region().front(); if (!llvm::hasSingleElement(body)) return failure(); auto yieldOp = dyn_cast(body.getTerminator()); if (!yieldOp) return failure(); // Get the argument number of the returned values. That is the operand // number to use for replacing uses of this operation. 
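    // For instance (illustrative), a parallel generic whose body is just
    //   linalg.yield %arg0 : f32
    // with identity indexing maps forwards input 0 unchanged, so uses of the
    // result can be replaced by that input operand directly.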
SmallVector returnedArgs; for (Value yieldVal : yieldOp.values()) { auto yieldArg = yieldVal.dyn_cast(); if (!yieldArg || yieldArg.getOwner() != &body) return failure(); unsigned argumentNumber = yieldArg.getArgNumber(); returnedArgs.push_back(genericOp->getOperand(argumentNumber)); } if (returnedArgs.size() != genericOp->getNumResults()) return failure(); rewriter.replaceOp(genericOp, returnedArgs); return success(); } }; } // namespace void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); } //===----------------------------------------------------------------------===// // InitTensorOp //===----------------------------------------------------------------------===// void InitTensorOp::build(OpBuilder &b, OperationState &result, ArrayRef sizes, Type elementType, ArrayRef attrs) { SmallVector dynamicSizes; SmallVector staticSizes; dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, ShapedType::kDynamicSize); auto resultType = RankedTensorType ::get(staticSizes, elementType); build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes)); result.addAttributes(attrs); } static LogicalResult verify(InitTensorOp op) { RankedTensorType resultType = op.getType(); SmallVector staticSizes = llvm::to_vector<4>(llvm::map_range( op.static_sizes().cast(), [](Attribute a) -> int64_t { return a.cast().getInt(); })); if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(), op.static_sizes(), op.sizes(), ShapedType::isDynamic))) return failure(); if (op.static_sizes().size() != static_cast(resultType.getRank())) return op->emitError("expected ") << resultType.getRank() << " sizes values"; Type expectedType = InitTensorOp::inferResultType( staticSizes, resultType.getElementType(), resultType.getEncoding()); if (resultType != expectedType) { return op.emitError("specified type ") << resultType << " does not match the inferred type " << expectedType; } return success(); } Type InitTensorOp::inferResultType(ArrayRef staticSizes, Type elementType, Attribute encoding) { return RankedTensorType::get(staticSizes, elementType, encoding); } namespace { /// Change the type of the result of a `linalg.init_tensor` by making the result /// type statically sized along dimension that in the original operation where /// defined as dynamic, but the size was defined using a `constant` op. For /// example /// /// %c5 = arith.constant 5: index /// %0 = linalg.init_tensor [%arg0, %c5] : tensor /// /// to /// /// %0 = linalg.init_tensor [%arg0, 5] : tensor struct ReplaceStaticShapeDims : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InitTensorOp op, PatternRewriter &rewriter) const override { SmallVector dynamicSizes; SmallVector staticSizes; for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) { // If the size is already static, nothing to do. if (!op.isDynamicSize(i)) { staticSizes.push_back(op.getStaticSize(i)); continue; } // If the size is dynamic but defined using a `constant` op, get the // constant value to find the static size to use. unsigned operandNum = op.getIndexOfDynamicSize(i); Value sizeOperand = op.getOperand(operandNum); if (auto constantIndexOp = sizeOperand.getDefiningOp()) { staticSizes.push_back(constantIndexOp.value()); continue; } // Fallback case. Keep the size dynamic. 
dynamicSizes.push_back(sizeOperand); staticSizes.push_back(ShapedType::kDynamicSize); } RankedTensorType newType = RankedTensorType::get(staticSizes, op.getType().getElementType()); if (newType == op.getType()) return failure(); auto newOp = rewriter.create(op.getLoc(), newType, dynamicSizes, rewriter.getI64ArrayAttr(staticSizes)); rewriter.replaceOpWithNewOp(op, op.getType(), newOp); return success(); } }; } // namespace namespace { /// Since `init_tensor` operation creates a tensor needed only for its shape, a /// slice of this is also needed only for its shape. The result can be /// replaced by a new init_tensor operation of the same size as the extract /// slice op. struct FoldInitTensorWithExtractSliceOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override { if (!sliceOp.source().getDefiningOp()) return failure(); // ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved // as well as its result type. rewriter.replaceOpWithNewOp( sliceOp, sliceOp.sizes(), sliceOp.result().getType().cast().getShape(), sliceOp.getSourceType().getElementType()); return success(); } }; template struct FoldInitTensorWithTensorReshapeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, PatternRewriter &rewriter) const override { if (!reshapeOp.src().template getDefiningOp()) return failure(); Location loc = reshapeOp.getLoc(); ReifiedRankedShapedTypeDims resultShapes; ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = cast(reshapeOp.getOperation()); if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter, resultShapes)) || !llvm::hasSingleElement(resultShapes)) return failure(); Value initTensor = rewriter.create( loc, getAsOpFoldResult(resultShapes[0]), reshapeOp.getResultType().getElementType()); if (initTensor.getType() != reshapeOp.getResultType()) { rewriter.replaceOpWithNewOp( reshapeOp, reshapeOp.getResultType(), initTensor); } else { rewriter.replaceOp(reshapeOp, initTensor); } return success(); } }; struct FoldInitTensorWithDimOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tensor::DimOp dimOp, PatternRewriter &rewriter) const override { Optional maybeConstantIndex = dimOp.getConstantIndex(); auto initTensorOp = dimOp.source().getDefiningOp(); if (!initTensorOp || !maybeConstantIndex) return failure(); if (!initTensorOp.isDynamicSize(*maybeConstantIndex)) return failure(); rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(*maybeConstantIndex)); return success(); } }; } // namespace void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, FoldInitTensorWithTensorReshapeOp, ReplaceStaticShapeDims>(context); } LogicalResult InitTensorOp::reifyResultShapes( OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { auto shapes = llvm::to_vector<4>(llvm::map_range( llvm::seq(0, getType().getRank()), [&](int64_t dim) -> Value { if (isDynamicSize(dim)) return getDynamicSize(dim); return builder.create(getLoc(), getStaticSize(dim)); })); reifiedReturnShapes.emplace_back(std::move(shapes)); return success(); } //===----------------------------------------------------------------------===// // PadTensorOp //===----------------------------------------------------------------------===// // TODO: Replace custom directive with AllTypesMatch as soon as it // 
supports optional types. void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, Type typeToInfer, Type typeToInferFrom) {} ParseResult parseInferType(OpAsmParser &parser, Optional optOperand, Type &typeToInfer, Type typeToInferFrom) { if (optOperand) typeToInfer = typeToInferFrom; return success(); } static LogicalResult verify(PadTensorOp op) { auto sourceType = op.source().getType().cast(); auto resultType = op.result().getType().cast(); auto expectedType = PadTensorOp::inferResultType( sourceType, extractFromI64ArrayAttr(op.static_low()), extractFromI64ArrayAttr(op.static_high())); for (int i = 0, e = sourceType.getRank(); i < e; ++i) { if (resultType.getDimSize(i) == expectedType.getDimSize(i)) continue; if (expectedType.isDynamicDim(i)) continue; return op.emitError("specified type ") << resultType << " does not match the inferred type " << expectedType; } auto ®ion = op.region(); unsigned rank = resultType.getRank(); Block &block = region.front(); if (block.getNumArguments() != rank) return op.emitError("expected the block to have ") << rank << " arguments"; // Note: the number and type of yield values are checked in the YieldOp. for (const auto &en : llvm::enumerate(block.getArgumentTypes())) { if (!en.value().isIndex()) return op.emitOpError("expected block argument ") << (en.index() + 1) << " to be an index"; } return success(); } RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType, ArrayRef staticLow, ArrayRef staticHigh, ArrayRef resultShape) { unsigned rank = sourceType.getRank(); assert(staticLow.size() == rank && "unexpected staticLow size mismatch"); assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch"); assert((resultShape.empty() || resultShape.size() == rank) && "unexpected resultShape size mismatch"); SmallVector inferredShape; for (auto i : llvm::seq(0, rank)) { if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamicSize || staticHigh[i] == ShapedType::kDynamicSize) { inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize : resultShape[i]); } else { int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i]; assert((resultShape.empty() || size == resultShape[i] || resultShape[i] == ShapedType::kDynamicSize) && "mismatch between inferred shape and result shape"); inferredShape.push_back(size); } } return RankedTensorType::get(inferredShape, sourceType.getElementType()); } void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef staticLow, ArrayRef staticHigh, ValueRange low, ValueRange high, bool nofold, ArrayRef attrs) { auto sourceType = source.getType().cast(); auto resultType = inferResultType(sourceType, staticLow, staticHigh); build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh), nofold ? 
b.getUnitAttr() : UnitAttr()); result.addAttributes(attrs); } void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source, ValueRange low, ValueRange high, bool nofold, ArrayRef attrs) { auto sourceType = source.getType().cast(); unsigned rank = sourceType.getRank(); SmallVector staticVector(rank, ShapedType::kDynamicSize); build(b, result, source, staticVector, staticVector, low, high, nofold, attrs); } void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType, Value source, ArrayRef low, ArrayRef high, bool nofold, ArrayRef attrs) { assert(resultType.isa()); auto sourceType = source.getType().cast(); SmallVector dynamicLow, dynamicHigh; SmallVector staticLow, staticHigh; // staticLow and staticHigh have full information of the padding config. // This will grow staticLow and staticHigh with 1 value. If the config is // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1 // value as well. dispatchIndexOpFoldResults(low, dynamicLow, staticLow, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh, ShapedType::kDynamicSize); if (!resultType) { resultType = PadTensorOp::inferResultType(sourceType, staticLow, staticHigh); } build(b, result, resultType, source, dynamicLow, dynamicHigh, b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr()); result.addAttributes(attrs); } PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad, ArrayRef low, ArrayRef high, bool nofold, Location loc, OpBuilder &builder) { auto padTensorOp = builder.create(loc, type, source, low, high, nofold); int rank = padTensorOp.getResultType().getRank(); SmallVector blockArgTypes; blockArgTypes.assign(rank, builder.getIndexType()); auto ®ion = padTensorOp.region(); // `builder.createBlock` changes the insertion point within the block. Create // a guard to reset the insertion point of the builder after it is destroyed. OpBuilder::InsertionGuard guard(builder); builder.createBlock(®ion, region.end(), blockArgTypes); builder.create(loc, pad); return padTensorOp; } PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad, bool nofold, Location loc, OpBuilder &b) { SmallVector low, high; auto rankedTensorType = type.cast(); assert(rankedTensorType.hasStaticShape()); for (const auto &en : enumerate(rankedTensorType.getShape())) { AffineExpr d0; bindDims(b.getContext(), d0); auto dimOp = b.createOrFold(loc, source, en.index()); Value paddingWidth = makeComposedAffineApply(b, loc, en.value() - d0, {dimOp}); high.push_back(paddingWidth); low.push_back(b.createOrFold(loc, 0)); } return PadTensorOp::createPadScalarOp(type, source, pad, low, high, nofold, loc, b); } LogicalResult PadTensorOp::reifyResultShapes( OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { Location loc = getLoc(); auto lowPad = getMixedLowPad(); auto highPad = getMixedHighPad(); SmallVector shapes; for (auto dim : llvm::seq(0, getSourceType().getRank())) { // Shape along each dimension is source dim + low pad + high pad. 
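    // E.g. (illustrative) a source dimension of size 8 padded with low = 1 and
    // high = 2 reifies to an affine expression evaluating to 8 + 1 + 2 = 11.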
SmallVector mapOperands; mapOperands.push_back(b.createOrFold(loc, source(), dim)); AffineExpr expr = b.getAffineDimExpr(0); unsigned numSymbols = 0; auto addOpFoldResult = [&](OpFoldResult valueOrAttr) { if (Value v = valueOrAttr.dyn_cast()) { expr = expr + b.getAffineSymbolExpr(numSymbols++); mapOperands.push_back(v); return; } int64_t staticValue = valueOrAttr.get().cast().getInt(); expr = expr + staticValue; }; addOpFoldResult(lowPad[dim]); addOpFoldResult(highPad[dim]); shapes.push_back(applyMapToValues( b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]); } reifiedReturnShapes.emplace_back(std::move(shapes)); return success(); } //===----------------------------------------------------------------------===// // Methods related to PadTensor tiling. //===----------------------------------------------------------------------===// SmallVector PadTensorOp::getDestinationOperands(OpBuilder &b) { ReifiedRankedShapedTypeDims reifiedShapes; (void)reifyResultShapes(b, reifiedShapes); SmallVector mixedSizes = getAsOpFoldResult(reifiedShapes[0]); Value initTensor = b.create(getLoc(), mixedSizes, getResultType().getElementType()); return {initTensor}; } SmallVector PadTensorOp::getLoopIteratorTypes() { SmallVector iteratorTypes(getResultType().getRank(), getParallelIteratorTypeName()); return iteratorTypes; } SmallVector PadTensorOp::getIterationDomain(OpBuilder &b) { ReifiedRankedShapedTypeDims reifiedShapes; (void)reifyResultShapes(b, reifiedShapes); Value zero = b.create(getLoc(), 0); Value one = b.create(getLoc(), 1); // Initialize all the ranges to {zero, one, one}. All the `ub`s are // overwritten. SmallVector loopRanges(reifiedShapes[0].size(), {zero, one, one}); for (const auto &ub : enumerate(reifiedShapes[0])) loopRanges[ub.index()].size = ub.value(); return loopRanges; } SmallVector PadTensorOp::getTiledImplementation( OpBuilder &b, ValueRange dest, ArrayRef offsets, ArrayRef sizes, bool /*tileDestOperands*/) { // Only constant padding value supported. Value padValue = getConstantPaddingValue(); if (!padValue) return {}; // Helper variables and functions for various arithmetic operations. These are // used extensively for computing new offset/length and padding values. Location loc = getLoc(); AffineExpr dim0, dim1; bindDims(b.getContext(), dim0, dim1); // Add two integers. auto addMap = AffineMap::get(2, 0, {dim0 + dim1}); auto add = [&](Value v1, Value v2) { return b.createOrFold(loc, addMap, ValueRange{v1, v2}); }; // Subtract two integers. auto subMap = AffineMap::get(2, 0, {dim0 - dim1}); auto sub = [&](Value v1, Value v2) { return b.createOrFold(loc, subMap, ValueRange{v1, v2}); }; // Take the minimum of two integers. auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext()); auto min = [&](Value v1, Value v2) { return b.createOrFold(loc, idMap, ValueRange{v1, v2}); }; // Take the maximum of two integers. auto max = [&](Value v1, Value v2) { return b.createOrFold(loc, idMap, ValueRange{v1, v2}); }; // Zero index-typed integer. auto zero = b.create(loc, 0); // Helper function for filling static/dynamic low/high padding indices vectors // of PadTensorOp. auto appendIndex = [&](Value val, SmallVector &dynIndices, SmallVector &staticIndices) { if (auto constInt = getConstantIntValue(val)) { staticIndices.push_back(*constInt); } else { staticIndices.push_back(ShapedType::kDynamicSize); dynIndices.push_back(val); } }; // Compute new offsets, lengths, low padding, high padding. 
SmallVector newOffsets, newLengths, newStrides; SmallVector newLows, newHighs; SmallVector staticNewLows, staticNewHighs; // Set to true if the original data source is not read at all. bool hasZeroLen = false; // Same as hasZeroLen, but for dynamic dimension sizes. This condition // is true if the original data source turns out to be unused at runtime. Value dynHasZeroLenCond; int64_t rank = getSourceType().getRank(); for (unsigned dim = 0; dim < rank; ++dim) { auto low = getValueOrCreateConstantIndexOp(b, loc, getMixedLowPad()[dim]); bool hasLowPad = getConstantIntValue(low) != static_cast(0); auto high = getValueOrCreateConstantIndexOp(b, loc, getMixedHighPad()[dim]); bool hasHighPad = getConstantIntValue(high) != static_cast(0); auto offset = getValueOrCreateConstantIndexOp(b, loc, offsets[dim]); auto length = getValueOrCreateConstantIndexOp(b, loc, sizes[dim]); auto srcSize = b.createOrFold(loc, source(), dim); // The new amount of low padding is `low - offset`. Except for the case // where none of the low padding is read. In that case, the new amount of // low padding is zero. // // Optimization: If low = 0, then newLow = 0. Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero; appendIndex(newLow, newLows, staticNewLows); // Start reading the data from position `offset - low`. Since the original // read may have started in the low padding zone, this value could be // negative. Therefore, start reading from: // // max(offset - low, 0) // // The original read could also have started in the high padding zone. // In that case, set the offset to the end of source tensor. The new // ExtractSliceOp length will be zero in that case. (Effectively reading no // data from the source.) // // Optimization: If low = 0, then the formula can be simplified. Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize) : min(offset, srcSize); newOffsets.push_back(getAsOpFoldResult(newOffset)); // The original ExtractSliceOp was reading until position `offset + length`. // Therefore, the corresponding position within the source tensor is: // // offset + length - low // // In case the original ExtractSliceOp stopped reading within the low // padding zone, this value can be negative. In that case, the end position // of the read should be zero. (Similar to newOffset.) // // The original read could also have stopped in the high padding zone. // In that case, set the end positition of the read should be the end of the // source tensor. (Similar to newOffset.) // // endLoc = min(max(offset - low + length, 0), srcSize) // // The new ExtractSliceOp length is `endLoc - newOffset`. // // Optimization: If low = 0, then the formula can be simplified. Value endLoc = hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize) : min(add(offset, length), srcSize); Value newLength = sub(endLoc, newOffset); newLengths.push_back(getAsOpFoldResult(newLength)); // Check if newLength is zero. In that case, no SubTensorOp should be // executed. if (auto newLengthInt = getConstantIntValue(newLength)) { hasZeroLen |= *newLengthInt == 0; } else { Value check = b.create(loc, arith::CmpIPredicate::eq, newLength, zero); dynHasZeroLenCond = dynHasZeroLenCond ? b.create(loc, check, dynHasZeroLenCond) : check; } // The amount of high padding is simply the number of elements remaining, // so that the result has the same length as the original ExtractSliceOp. // As an optimization, if the original high padding is zero, then the new // high padding must also be zero. Value newHigh = hasHighPad ? 
sub(sub(length, newLength), newLow) : zero; appendIndex(newHigh, newHighs, staticNewHighs); // Only unit stride supported. newStrides.push_back(b.getIndexAttr(1)); } // The shape of the result can be obtained from the sizes passed in. SmallVector dynDims; SmallVector shape; dispatchIndexOpFoldResults(sizes, dynDims, shape, ShapedType::kDynamicSize); RankedTensorType resultType = RankedTensorType::get(shape, getResultType().getElementType()); // Insert cast to ensure that types match. (May be folded away.) auto castResult = [&](Value val) -> Operation * { auto castOp = b.create(loc, resultType, val); return castOp; }; // In cases where the original data source is unused: Emit a GenerateOp and // do not generate a SliceOp. (The result shape of the SliceOp would // have a dimension of size 0, the semantics of which is unclear.) auto createGenerateOp = [&]() { // Create GenerateOp. auto generateOp = b.create( loc, resultType, dynDims, [&](OpBuilder &builder, Location gLoc, ValueRange indices) { builder.create(gLoc, padValue); }); return castResult(generateOp); }; // Emit a SliceOp and a PadTensorOp. Should not be used in cases where // the result shape of the new SliceOp has a zero dimension. auto createPadTensorOfSubTensor = [&]() { // Create pad_tensor(subtensor(x)). auto newSliceOp = b.create( loc, source(), newOffsets, newLengths, newStrides); auto newPadTensorOp = b.create( loc, newSliceOp, staticNewLows, staticNewHighs, newLows, newHighs); // Copy region to new PadTensorOp. BlockAndValueMapping bvm; region().cloneInto(&newPadTensorOp.getRegion(), bvm); // Cast result and return. return castResult(newPadTensorOp); }; // Rewrite subtensor(pad_tensor(x)) into a GenerateOp it is statically known // that the original data source x is not used. if (hasZeroLen) { return {createGenerateOp()}; } // If there are dynamic dimensions: Generate an scf.if check to avoid creating // SliceOps with result dimensions of size 0 at runtime. if (dynHasZeroLenCond) { auto result = b.create( loc, resultType, dynHasZeroLenCond, /*thenBuilder=*/ [&](OpBuilder &b, Location loc) { b.create(loc, createGenerateOp()->getResult(0)); }, /*elseBuilder=*/ [&](OpBuilder &b, Location loc) { b.create(loc, createPadTensorOfSubTensor()->getResult(0)); }); return {result}; } return {createPadTensorOfSubTensor()}; } namespace { // Folds linalg.pad_tensor when padding is static zeros and the attribute // doesn't request otherwise. struct FoldStaticZeroPadding : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PadTensorOp padTensorOp, PatternRewriter &rewriter) const override { if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad()) return failure(); if (padTensorOp.nofold()) return failure(); rewriter.replaceOpWithNewOp( padTensorOp, padTensorOp.result().getType(), padTensorOp.source()); return success(); } }; // Fold CastOp into PadTensorOp when adding static information. 
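// For example (illustrative):
//   %0 = tensor.cast %a : tensor<4x?xf32> to tensor<?x?xf32>
//   %1 = linalg.pad_tensor %0 ... : tensor<?x?xf32> to tensor<?x?xf32>
// can be rewritten to pad %a directly, with the result type re-inferred from
// the more static source shape.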
struct FoldSourceTensorCast : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PadTensorOp padTensorOp, PatternRewriter &rewriter) const override { auto castOp = padTensorOp.source().getDefiningOp(); if (!tensor::canFoldIntoConsumerOp(castOp)) return failure(); auto newResultType = PadTensorOp::inferResultType( castOp.source().getType().cast(), extractFromI64ArrayAttr(padTensorOp.static_low()), extractFromI64ArrayAttr(padTensorOp.static_high()), padTensorOp.getResultType().getShape()); if (newResultType == padTensorOp.getResultType()) { rewriter.updateRootInPlace(padTensorOp, [&]() { padTensorOp.sourceMutable().assign(castOp.source()); }); } else { auto newOp = rewriter.create( padTensorOp->getLoc(), newResultType, padTensorOp.source(), padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(), padTensorOp.static_high(), padTensorOp.nofold()); BlockAndValueMapping mapper; padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper); rewriter.replaceOpWithNewOp( padTensorOp, padTensorOp.getResultType(), newOp); } return success(); } }; // Fold CastOp using the result of PadTensorOp back into the latter if it adds // static information. struct FoldTargetTensorCast : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PadTensorOp padTensorOp, PatternRewriter &rewriter) const override { if (!padTensorOp.result().hasOneUse()) return failure(); auto tensorCastOp = dyn_cast(*padTensorOp->getUsers().begin()); if (!tensorCastOp) return failure(); if (!tensor::preservesStaticInformation(padTensorOp.result().getType(), tensorCastOp.dest().getType())) return failure(); auto replacementOp = rewriter.create( padTensorOp.getLoc(), tensorCastOp.dest().getType(), padTensorOp.source(), padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(), padTensorOp.static_high(), padTensorOp.nofold()); replacementOp.region().takeBody(padTensorOp.region()); rewriter.replaceOp(padTensorOp, replacementOp.result()); rewriter.replaceOp(tensorCastOp, replacementOp.result()); return success(); } }; } // namespace void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); results.add(context); } /// Return the padding value of the PadTensorOp if it constant. In this context, /// "constant" means an actual constant or "defined outside of the block". /// /// Values are considered constant in three cases: /// - A ConstantLike value. /// - A basic block argument from a different block. /// - A value defined outside of the block. /// /// If the padding value is not constant, an empty Value is returned. Value PadTensorOp::getConstantPaddingValue() { auto yieldOp = dyn_cast(getRegion().front().getTerminator()); if (!yieldOp || yieldOp.values().size() != 1) return {}; Value padValue = yieldOp.values().front(); // Check if yield value is a constant. if (matchPattern(padValue, m_Constant())) return padValue; // Check if yield value is defined inside the PadTensorOp block. if (padValue.getParentBlock() == &getRegion().front()) return {}; // Else: Yield value defined outside of the PadTensorOp block. 
return padValue; } OpFoldResult PadTensorOp::fold(ArrayRef) { if (getResultType().hasStaticShape() && getResultType() == getSourceType() && !nofold()) return source(); return {}; } //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, linalg::YieldOp op) { if (op.getNumOperands() > 0) p << ' ' << op.getOperands(); p.printOptionalAttrDict(op->getAttrs()); if (op.getNumOperands() > 0) p << " : " << op.getOperandTypes(); } static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { SmallVector opInfo; SmallVector types; llvm::SMLoc loc = parser.getCurrentLocation(); return failure(parser.parseOperandList(opInfo) || parser.parseOptionalAttrDict(result.attributes) || (!opInfo.empty() && parser.parseColonTypeList(types)) || parser.resolveOperands(opInfo, types, loc, result.operands)); } // Check the operand number and types must match the element types of the // LinalgOp interface's shaped operands. static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) { if (op.getNumOperands() != linalgOp.getNumOutputs()) return op.emitOpError("expected number of yield values (") << linalgOp.getNumOutputs() << ") to match the number of operands of the enclosing " << "LinalgOp (" << op.getNumOperands() << ")"; for (OpOperand &opOperand : op->getOpOperands()) { OpOperand *outputOperand = linalgOp.getOutputOperand(opOperand.getOperandNumber()); Type elementType = getElementTypeOrSelf(outputOperand->get().getType()); if (opOperand.get().getType() != elementType) return op.emitOpError("type of yield operand ") << (opOperand.getOperandNumber() + 1) << " (" << opOperand.get().getType() << ") doesn't match " << "the element type of the enclosing linalg.generic op (" << elementType << ")"; } return success(); } static LogicalResult verify(linalg::YieldOp op) { auto *parentOp = op->getParentOp(); if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) return op.emitOpError("expected single non-empty parent region"); if (auto linalgOp = dyn_cast(parentOp)) return verifyYield(op, cast(parentOp)); if (auto padTensorOp = dyn_cast(parentOp)) { if (op.getNumOperands() != 1) return op.emitOpError("expected single yield operand (got ") << op->getNumOperands() << ")"; if (op.getOperand(0).getType() != padTensorOp.getType().cast().getElementType()) return op.emitOpError("expected yield type to match shape element type"); return success(); } if (auto tiledLoopOp = dyn_cast(parentOp)) { // Check if output args with tensor types match results types. 
SmallVector tensorOuts; llvm::copy_if( tiledLoopOp.outputs(), std::back_inserter(tensorOuts), [&](Value out) { return out.getType().isa(); }); if (tensorOuts.size() != op.values().size()) return op.emitOpError("expected number of tensor output args = ") << tensorOuts.size() << " to match the number of yield operands = " << op.values().size(); TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts)); for (auto &item : llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) { Type outType, resultType; unsigned index = item.index(); std::tie(outType, resultType) = item.value(); if (outType != resultType) return op.emitOpError("expected yield operand ") << index << " with type = " << resultType << " to match output arg type = " << outType; } return success(); } return op.emitOpError("expected parent op with LinalgOp interface"); } //===----------------------------------------------------------------------===// // TiledLoopOp //===----------------------------------------------------------------------===// void TiledLoopOp::build(OpBuilder &builder, OperationState &result, ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps, ValueRange inputs, ValueRange outputs, ArrayAttr iteratorTypes, function_ref bodyBuilderFn) { build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs, iteratorTypes, llvm::None, bodyBuilderFn); } void TiledLoopOp::build(OpBuilder &builder, OperationState &result, ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps, ValueRange inputs, ValueRange outputs, ArrayAttr iteratorTypes, Optional distributionTypes, function_ref bodyBuilderFn) { result.addOperands(lowerBounds); result.addOperands(upperBounds); result.addOperands(steps); result.addOperands(inputs); result.addOperands(outputs); result.addAttribute( TiledLoopOp::getOperandSegmentSizeAttr(), builder.getI32VectorAttr({static_cast(lowerBounds.size()), static_cast(upperBounds.size()), static_cast(steps.size()), static_cast(inputs.size()), static_cast(outputs.size())})); result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); if (distributionTypes.hasValue()) result.addAttribute(getDistributionTypesAttrName(), distributionTypes.getValue()); // Add output types for `RankedTensorType` output arguments. 
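  // Note (added): memref-typed outputs are mutated in place and therefore
  // contribute no result types.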
for (Value output : outputs) { Type outputType = output.getType(); if (outputType.isa()) result.addTypes(outputType); } OpBuilder::InsertionGuard guard(builder); unsigned numIVs = steps.size(); SmallVector argTypes(numIVs, builder.getIndexType()); for (Type type : TypeRange(inputs)) argTypes.push_back(type); for (Type type : TypeRange(outputs)) argTypes.push_back(type); Region *bodyRegion = result.addRegion(); Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes); if (bodyBuilderFn) { builder.setInsertionPointToStart(bodyBlock); bodyBuilderFn(builder, result.location, bodyBlock->getArguments().take_front(numIVs), bodyBlock->getArguments().slice(numIVs, inputs.size()), bodyBlock->getArguments().take_back(outputs.size())); TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location); } } static void print(OpAsmPrinter &p, TiledLoopOp op) { p << " (" << op.getInductionVars() << ") = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step() << ")"; if (!op.inputs().empty()) { p << " ins ("; llvm::interleaveComma(llvm::zip(op.getRegionInputArgs(), op.inputs()), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it) << ": " << std::get<1>(it).getType(); }); p << ")"; } if (!op.outputs().empty()) { p << " outs ("; llvm::interleaveComma(llvm::zip(op.getRegionOutputArgs(), op.outputs()), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it) << ": " << std::get<1>(it).getType(); }); p << ")"; } if (llvm::any_of(op.iterator_types(), [](Attribute attr) { return attr.cast().getValue() != getParallelIteratorTypeName(); })) p << " iterators" << op.iterator_types(); if (op.distribution_types().hasValue()) p << " distribution" << op.distribution_types().getValue(); p << ' '; p.printRegion(op.region(), /*printEntryBlockArgs=*/false); p.printOptionalAttrDict( op->getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(), getIteratorTypesAttrName(), getDistributionTypesAttrName()}); } static ParseResult parseTiledLoopOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); // Parse an opening `(` followed by induction variables followed by `)` SmallVector ivs; if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, OpAsmParser::Delimiter::Paren)) return failure(); // Parse loop bounds. SmallVector lower; if (parser.parseEqual() || parser.parseOperandList(lower, ivs.size(), OpAsmParser::Delimiter::Paren) || parser.resolveOperands(lower, builder.getIndexType(), result.operands)) return failure(); SmallVector upper; if (parser.parseKeyword("to") || parser.parseOperandList(upper, ivs.size(), OpAsmParser::Delimiter::Paren) || parser.resolveOperands(upper, builder.getIndexType(), result.operands)) return failure(); // Parse step values. SmallVector steps; if (parser.parseKeyword("step") || parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren) || parser.resolveOperands(steps, builder.getIndexType(), result.operands)) return failure(); // Parse input tensors. SmallVector inputs, inputRegionArgs; SmallVector inputTypes; if (succeeded(parser.parseOptionalKeyword("ins"))) { llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation(); if (parser.parseAssignmentListWithTypes(inputRegionArgs, inputs, inputTypes)) return failure(); if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc, result.operands)) return failure(); } // Parse output tensors. 
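  // Illustrative form (assumed, mirroring the `ins` clause above):
  //   outs (%out_ = %out: tensor<8x8xf32>)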
SmallVector outputs, outputRegionArgs; SmallVector outputTypes; if (succeeded(parser.parseOptionalKeyword("outs"))) { llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation(); if (parser.parseAssignmentListWithTypes(outputRegionArgs, outputs, outputTypes)) return failure(); if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc, result.operands)) return failure(); for (Type outputType : outputTypes) if (outputType.isa()) result.addTypes(outputType); } // Parse attributes. SmallVector iterTypes, distributionTypes; auto parseAttr = [&](StringRef keyword, SmallVector *attrs) { if (succeeded(parser.parseOptionalKeyword(keyword))) { StringAttr attr; if (parser.parseLSquare() || parser.parseAttribute(attr)) return failure(); attrs->push_back(attr); for (int i = 1, e = ivs.size(); i < e; ++i) { if (parser.parseComma() || parser.parseAttribute(attr)) return failure(); attrs->push_back(attr); } if (parser.parseRSquare()) return failure(); } return success(); }; if (failed(parseAttr("iterators", &iterTypes)) || failed(parseAttr("distribution", &distributionTypes))) return failure(); // Set all loop iterator types to "parallel" if they are not printed in IR. if (iterTypes.empty()) { auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName()); iterTypes = SmallVector(ivs.size(), parallelIter); } result.addAttribute(getIteratorTypesAttrName(), builder.getArrayAttr(iterTypes)); if (!distributionTypes.empty()) result.addAttribute(getDistributionTypesAttrName(), builder.getArrayAttr(distributionTypes)); result.addAttribute( TiledLoopOp::getOperandSegmentSizeAttr(), builder.getI32VectorAttr({static_cast(lower.size()), static_cast(upper.size()), static_cast(steps.size()), static_cast(inputs.size()), static_cast(outputs.size())})); // Parse the body. Region *body = result.addRegion(); SmallVector regionTypes(ivs.size(), builder.getIndexType()); regionTypes.append(inputTypes); regionTypes.append(outputTypes); SmallVector regionArgs(ivs); regionArgs.append(inputRegionArgs); regionArgs.append(outputRegionArgs); if (parser.parseRegion(*body, regionArgs, regionTypes)) return failure(); // Parse optional attributes. parser.parseOptionalAttrDict(result.attributes); return success(); } Region &TiledLoopOp::getLoopBody() { return region(); } LogicalResult TiledLoopOp::moveOutOfLoop(ArrayRef ops) { for (auto *op : ops) op->moveBefore(*this); return success(); } bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) { return !region().isAncestor(value.getParentRegion()); } static LogicalResult verify(TiledLoopOp op) { // Check if iterator types are provided for every loop dimension. if (op.iterator_types().size() != op.getNumLoops()) return op.emitOpError("expected iterator types array attribute size = ") << op.iterator_types().size() << " to match the number of loops = " << op.getNumLoops(); // Check if types of input arguments match region args types. for (auto &item : llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) { Value input, inputRegionArg; unsigned index = item.index(); std::tie(input, inputRegionArg) = item.value(); if (input.getType() != inputRegionArg.getType()) return op.emitOpError("expected input arg ") << index << " with type = " << input.getType() << " to match region arg " << index + op.getNumLoops() << " type = " << inputRegionArg.getType(); } // Check if types of input arguments match region args types. 
for (auto &item : llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) { Value output, outputRegionArg; unsigned index = item.index(); std::tie(output, outputRegionArg) = item.value(); if (output.getType() != outputRegionArg.getType()) return op.emitOpError("expected output arg ") << index << " with type = " << output.getType() << " to match region arg " << index + op.getNumLoops() + op.inputs().size() << " type = " << outputRegionArg.getType(); } return success(); } namespace { static constexpr int64_t kNoMatch = -1; // Folds away TiledLoopOp inputs if they have no uses within the body. // // Example: // // %0 = linalg.tiled_loop ... ins (%in_ = %in: tensor<...>, // %in_buf_ = %in_buf: memref<...>) {...} // Becomes // // linalg.tiled_loop ... ins (%in_buf_ = %in_buf: memref<...>) {...} struct TiledLoopInputsFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop, PatternRewriter &rewriter) const final { SmallVector newInputs, regionInputTensorArgs; // Store ids of the corresponding old and new input operands. SmallVector oldInputIdToNew(tiledLoop.inputs().size(), kNoMatch); for (const auto &en : llvm::enumerate( llvm::zip(tiledLoop.inputs(), tiledLoop.getRegionInputArgs()))) { Value in, bbArg; size_t index = en.index(); std::tie(in, bbArg) = en.value(); if (!bbArg.use_empty()) { oldInputIdToNew[index] = newInputs.size(); newInputs.push_back(in); } } if (newInputs.size() == tiledLoop.inputs().size()) return failure(); Location loc = tiledLoop.getLoc(); auto newTiledLoop = rewriter.create( loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(), newInputs, tiledLoop.outputs(), tiledLoop.iterator_types(), tiledLoop.distribution_types()); // Clone the region. BlockAndValueMapping bvm; bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars()); bvm.map(tiledLoop.getRegionOutputArgs(), newTiledLoop.getRegionOutputArgs()); for (const auto &en : llvm::enumerate(oldInputIdToNew)) if (en.value() != kNoMatch) bvm.map(tiledLoop.getRegionInputArgs()[en.index()], newTiledLoop.getRegionInputArgs()[en.value()]); OpBuilder innerBuilder = OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener()); for (auto &op : *tiledLoop.getBody()) innerBuilder.clone(op, bvm); rewriter.replaceOp(tiledLoop, newTiledLoop.getResults()); return success(); } }; } // namespace /// A simple, conservative analysis to determine if the loop is shape /// conserving. I.e., the type of the arg-th yielded value is the same as the /// type of the corresponding basic block argument of the loop. /// Note: This function handles only simple cases. Expand as needed. static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) { auto yieldOp = cast(loopOp.getLoopBody().front().getTerminator()); if (yieldOp.values().empty()) // Tiled loop either has no outputs or is a "memref-based version". In // either case, the loop is shape conserving. return true; assert(arg < static_cast(yieldOp.values().size()) && "arg is out of bounds"); Value value = yieldOp.values()[arg]; while (value) { if (value == loopOp.getRegionOutputArgs()[arg]) return true; OpResult opResult = value.dyn_cast(); if (!opResult) return false; using tensor::InsertSliceOp; value = llvm::TypeSwitch(opResult.getOwner()) .template Case( [&](InsertSliceOp op) { return op.dest(); }) .template Case([&](TiledLoopOp loopOp) { return isShapePreserving(loopOp, opResult.getResultNumber()) ? 
loopOp.outputs()[opResult.getResultNumber()] : Value(); }) .Default([&](auto op) { return Value(); }); } return false; } namespace { /// Fold dim(x) where `x` is an input/output argument of a TiledLoopOp block /// to dim(y) where `y` is the initial input/output value of the argument. /// /// E.g.: /// %y = ... : tensor<...> /// linalg.tiled_loop ... ins(%x = %y : tensor<...>) { /// tensor.dim %x, %c0 : tensor<...> /// } /// /// is folded to: /// %y = ... : tensor<...> /// linalg.tiled_loop ... ins(%x = %y : tensor<...>) { /// tensor.dim %y, %c0 : tensor<...> /// } /// /// Note: Dim ops are folded only if it can be proven that the runtime type of /// the yielded value (in case of outputs) does not change with loop iterations. template struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy dimOp, PatternRewriter &rewriter) const final { auto src = dimOp.source().template dyn_cast(); if (!src) return failure(); auto loopOp = dyn_cast(src.getOwner()->getParent()->getParentOp()); if (!loopOp) return failure(); unsigned numLoops = loopOp.getNumLoops(); unsigned numInputArgs = loopOp.getRegionInputArgs().size(); if (src.getArgNumber() >= numInputArgs + numLoops && !isShapePreserving(loopOp, src.getArgNumber() - numInputArgs - numLoops)) return failure(); auto inputArgs = loopOp.getRegionInputArgs(); auto it1 = llvm::find(inputArgs, src); if (it1 != inputArgs.end()) { rewriter.updateRootInPlace(dimOp, [&] { dimOp.sourceMutable().assign(loopOp.inputs()[it1 - inputArgs.begin()]); }); return success(); } auto outputArgs = loopOp.getRegionOutputArgs(); auto it2 = llvm::find(outputArgs, src); if (it2 != outputArgs.end()) { rewriter.updateRootInPlace(dimOp, [&] { dimOp.sourceMutable().assign( loopOp.outputs()[it2 - outputArgs.begin()]); }); return success(); } return failure(); } }; /// Fold dim(r) where `r` is the result of a TiledLoopOp to dim(y) where `y` /// is the initial output value of the loop. /// /// E.g.: /// %y = ... : tensor<...> /// %r = linalg.tiled_loop ... outs(%i = %y : tensor<...>) { /// ... /// } /// %0 = tensor.dim %r, %c0 : tensor<...> /// /// is folded to: /// %y = ... : tensor<...> /// linalg.tiled_loop ... outs(%i = %y : tensor<...>) { /// ... /// } /// %0 = tensor.dim %y, %c0 : tensor<...> /// /// Note: Dim ops are folded only if it can be proven that the runtime type of /// the yielded value (in case of outputs) does not change with loop iterations. template struct DimOfTiledLoopResultFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy dimOp, PatternRewriter &rewriter) const final { auto loopOp = dimOp.source().template getDefiningOp(); if (!loopOp) return failure(); auto opResult = dimOp.source().template cast(); unsigned resultNumber = opResult.getResultNumber(); if (!isShapePreserving(loopOp, resultNumber)) return failure(); rewriter.updateRootInPlace(dimOp, [&]() { dimOp.sourceMutable().assign(loopOp.outputs()[resultNumber]); }); return success(); } }; // Folds away TiledLoopOp output tensors when the following conditions are met: // * result of `linalg.tiled_loop` has no uses // * output tensor is the argument of `linalg.yield` // // Example: // // %0 = linalg.tiled_loop ... outs (%o_ = %out: tensor<...>, // %obuf_ = %out_buf: memref<...>) { // ... // linalg.yield %o_ : tensor ... // } // // Becomes // // linalg.tiled_loop ... outs (%obuf_ = %out_buf: memref<...>) { // ... 
// linalg.yield // } struct TiledLoopResultsFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop, PatternRewriter &rewriter) const final { if (tiledLoop.getNumResults() == 0) return failure(); Block *block = tiledLoop.getBody(); auto yieldOp = cast(block->getTerminator()); // Match the pattern and collect output buffers that will replace the output // tensors and also the ops that will be ignored when cloning the body. SmallVector newOutputOperands, newYieldArgs; int resultId = 0; // Store ids of the corresponding old and new output operands. SmallVector oldOutputIdToNew(tiledLoop.outputs().size(), kNoMatch); // Store ids of the corresponding old and new results. SmallVector oldResultIdToNew(tiledLoop.getNumResults(), kNoMatch); SmallVector resultReplacement(tiledLoop.getNumResults()); for (const auto &en : llvm::enumerate( llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) { size_t index = en.index(); Value out = std::get<0>(en.value()); Value outRegionArg = std::get<1>(en.value()); if (!out.getType().isa()) { oldOutputIdToNew[index] = newOutputOperands.size(); newOutputOperands.push_back(out); continue; } Value result = tiledLoop.getResult(resultId); Value yieldArg = yieldOp.getOperand(resultId); if (yieldArg != outRegionArg || !result.use_empty()) { oldOutputIdToNew[index] = newOutputOperands.size(); oldResultIdToNew[resultId] = newYieldArgs.size(); resultReplacement[resultId] = out; newOutputOperands.push_back(out); newYieldArgs.push_back(yieldArg); } ++resultId; } if (newOutputOperands.size() == tiledLoop.outputs().size()) return failure(); Location loc = tiledLoop.getLoc(); auto newTiledLoop = rewriter.create( loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(), tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types(), tiledLoop.distribution_types()); // Clone the region. 
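    // Note (added): region output args whose output tensors were folded away
    // are mapped to the original outer values so that any remaining uses in
    // the cloned body stay valid.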
    BlockAndValueMapping bvm;
    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
    bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs());
    for (const auto &en : llvm::enumerate(oldOutputIdToNew)) {
      if (en.value() != kNoMatch)
        bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
                newTiledLoop.getRegionOutputArgs()[en.value()]);
      else
        bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
                tiledLoop.outputs()[en.index()]);
    }
    OpBuilder innerBuilder =
        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
    for (auto &op : tiledLoop.getBody()->without_terminator())
      innerBuilder.clone(op, bvm);
    innerBuilder.create<linalg::YieldOp>(
        loc, llvm::to_vector<2>(llvm::map_range(
                 newYieldArgs, [&](Value arg) { return bvm.lookup(arg); })));

    for (const auto &en : llvm::enumerate(oldResultIdToNew))
      if (en.value() != kNoMatch)
        resultReplacement[en.index()] = newTiledLoop.getResult(en.value());
    rewriter.replaceOp(tiledLoop, resultReplacement);

    return success();
  }
};
} // namespace

void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder,
                 DimOfTiledLoopInsOutsFolder<tensor::DimOp>,
                 DimOfTiledLoopInsOutsFolder<memref::DimOp>,
                 DimOfTiledLoopResultFolder<tensor::DimOp>,
                 DimOfTiledLoopResultFolder<memref::DimOp>>(context);
}

LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
                                SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCastInTiledLoopOp(*this);
}

//===----------------------------------------------------------------------===//
// IndexOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(IndexOp op) {
  auto linalgOp = dyn_cast<LinalgOp>(op->getParentOp());
  if (!linalgOp)
    return op.emitOpError("expected parent op with LinalgOp interface");
  if (linalgOp.getNumLoops() <= op.dim())
    return op.emitOpError("expected dim (")
           << op.dim() << ") to be lower than the number of loops ("
           << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
  return success();
}

/////// Operations corresponding to library calls defined with Tablegen ////////

#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

/// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
/// Assumes `op` is a LinalgOp.
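/// E.g. (illustrative): for a matmul, whose iterator types are
/// ["parallel", "parallel", "reduction"], requesting the "reduction" loops
/// selects only dimension 2.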
void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName, SmallVectorImpl &res) { if (!cast(op).iterator_types()) return; unsigned dim = 0; for (auto tn : cast(op).iterator_types().getAsValueRange()) { if (tn == iteratorTypeName) res.push_back(dim); ++dim; } } AffineMap mlir::linalg::extractOrIdentityMap(Optional maybeMap, unsigned rank, MLIRContext *context) { if (maybeMap) return maybeMap.getValue(); if (rank == 0) return AffineMap::get(context); return AffineMap::getMultiDimIdentityMap(rank, context); } SmallVector mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context) { SmallVector res; res.reserve(num); for (unsigned i = 0; i < num; ++i) res.push_back(getAffineDimExpr(startIdx++, context)); return res; } SmallVector mlir::linalg::concat(ArrayRef a, ArrayRef b) { auto rangeA = llvm::make_range(a.begin(), a.end()); auto rangeB = llvm::make_range(b.begin(), b.end()); auto concatRanges = llvm::concat(rangeA, rangeB); return llvm::to_vector<4>(concatRanges); } static void appendMangledType(llvm::raw_string_ostream &ss, Type t) { if (auto memref = t.dyn_cast()) { ss << "view"; for (auto size : memref.getShape()) if (size < 0) ss << "sx"; else ss << size << "x"; appendMangledType(ss, memref.getElementType()); } else if (auto vec = t.dyn_cast()) { ss << "vector"; llvm::interleave( vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; }); appendMangledType(ss, vec.getElementType()); } else if (t.isSignlessIntOrIndexOrFloat()) { ss << t; } else { llvm_unreachable("Invalid type for linalg library name mangling"); } } std::string mlir::linalg::generateLibraryCallName(Operation *op) { assert(isa(op)); std::string name(op->getName().getStringRef().str()); name.reserve(128); std::replace(name.begin(), name.end(), '.', '_'); llvm::raw_string_ostream ss(name); ss << "_"; auto types = op->getOperandTypes(); llvm::interleave( types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); }, [&]() { ss << "_"; }); return ss.str(); } //===----------------------------------------------------------------------===// // Support for named Linalg ops defined in ods-gen. //===----------------------------------------------------------------------===// /// Generic entry point to create the block for the region of a LinalgOp. /// This is used by both named structured ops created by ods-gen and by manually /// defined C++ ops. /// This is used by both builders and parsers. /// This function creates the block in the region with arguments corresponding /// to the elemental types of `inputTypes` and `outputTypes`, which are asserted /// to be ShapedType. template static void fillStructuredOpRegion( OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, llvm::function_ref errorHandler) { assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); // TODO: atm all operands go through getElementTypeOrSelf, // reconsider when we have evidence we need to. SmallVector argTypes; for (auto containers : {inputTypes, outputTypes}) for (auto t : containers) argTypes.push_back(getElementTypeOrSelf(t)); // RAII. 
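  // Note (added): the insertion guard below restores the builder's previous
  // insertion point when this function returns.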
OpBuilder::InsertionGuard guard(opBuilder); Block *body = opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes); unsigned actual = body->getNumArguments(); unsigned expected = NamedStructuredOpType::getNumRegionArgs(); if (expected != actual) { if (errorHandler) errorHandler(expected, actual); return; } opBuilder.setInsertionPointToStart(body); ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); NamedStructuredOpType::regionBuilder(b, *body); // indexing_maps is an auto-generated method. // iterator_types is an auto-generated method. } /// Generic entry point to create both the region and the block of a LinalgOp. template void createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes, TypeRange outputTypes) { Region ®ion = *result.addRegion(); fillStructuredOpRegion( opBuilder, region, inputTypes, outputTypes, [&](unsigned expected, unsigned actual) { assert(expected != actual && "incorrect number of arguments"); }); } /// Common parsing used for both named structured ops created by ods-gen and by /// manually defined C++ ops. Does not handle regions. static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl &inputTypes, SmallVectorImpl &outputTypes) { llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc; SmallVector inputsOperands, outputsOperands; parser.parseOptionalAttrDict(result.attributes); if (succeeded(parser.parseOptionalKeyword("ins"))) { if (parser.parseLParen()) return failure(); inputsOperandsLoc = parser.getCurrentLocation(); if (parser.parseOperandList(inputsOperands) || parser.parseColonTypeList(inputTypes) || parser.parseRParen()) return failure(); } if (succeeded(parser.parseOptionalKeyword("outs"))) { outputsOperandsLoc = parser.getCurrentLocation(); if (parser.parseLParen() || parser.parseOperandList(outputsOperands) || parser.parseColonTypeList(outputTypes) || parser.parseRParen()) return failure(); } if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, result.operands) || parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc, result.operands)) return failure(); result.addAttribute("operand_segment_sizes", parser.getBuilder().getI32VectorAttr( {static_cast(inputsOperands.size()), static_cast(outputsOperands.size())})); return success(); } template static void printCommonStructuredOpParts(OpAsmPrinter &p, NamedStructuredOpType op) { if (!op.inputs().empty()) p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; if (!op.outputs().empty()) p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")"; } //===----------------------------------------------------------------------===// // Specific parsing and printing for named structured ops created by ods-gen. //===----------------------------------------------------------------------===// template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, TypeRange inputTypes, TypeRange outputTypes) { ParseResult res = success(); OpBuilder opBuilder(parser.getContext()); // Resolve `captures` into `capturedValues` at parse time so we can build the // region with captures. 
SmallVector capturedValues; fillStructuredOpRegion( opBuilder, region, inputTypes, outputTypes, [&](unsigned expected, unsigned actual) { res = parser.emitError( parser.getCurrentLocation(), llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated " "region expects {0} args, got {1}", expected, actual)); region.front().dump(); }); return res; } static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl &resultTypes) { if (parser.parseOptionalArrowTypeList(resultTypes)) return failure(); return success(); } template static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result) { // TODO: Enable when ods-gen supports captures. SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) return failure(); // TODO: consider merging results parsing into region parsing. // Need to wait for declarative assembly resolution to decide. SmallVector outputTensorsTypes; if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) return failure(); result.addTypes(outputTensorsTypes); std::unique_ptr region = std::make_unique(); if (parseNamedStructuredOpRegion( parser, *region, inputTypes, outputTypes)) return failure(); result.addRegion(std::move(region)); return success(); } static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes) { if (resultTypes.empty()) return; p.printOptionalArrowTypeList(resultTypes); } template static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { p.printOptionalAttrDict( op->getAttrs(), /*elidedAttrs=*/{"operand_segment_sizes", // See generated code in mlir-linalg-yaml-gen.cpp "linalg.memoized_indexing_maps"}); // Printing is shared with generic ops, except for the region and // attributes. printCommonStructuredOpParts(p, op); // Results printing. printNamedStructuredOpResults(p, op.result_tensors().getTypes()); // Region is elided. } template static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) { return verifyGenericOp(op); } //===----------------------------------------------------------------------===// // Canonicalizers and Folders. //===----------------------------------------------------------------------===// namespace { struct EraseDeadLinalgOp : public OpInterfaceRewritePattern { using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { for (OpOperand *opOperand : op.getInputAndOutputOperands()) { // Linalg "inputs" may be either tensor or memref type. // tensor<0xelt_type> is a convention that may not always mean // "0 iterations". Only erase in cases we see memref<...x0x...>. auto mt = opOperand->get().getType().dyn_cast(); if (!mt) continue; if (llvm::is_contained(op.getShape(opOperand), 0)) { rewriter.eraseOp(op); return success(); } } return failure(); } }; struct FoldTensorCastOp : public OpInterfaceRewritePattern { using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { // If no operand comes from a tensor::CastOp and can be folded then fail. 
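    // Note (added): a producing tensor.cast can only be folded when its source
    // type carries more static information than its result type (see
    // tensor::canFoldIntoConsumerOp).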
    bool hasTensorCastOperand =
        llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
          if (opOperand->get().isa<BlockArgument>())
            return false;
          auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
          return castOp && canFoldIntoConsumerOp(castOp);
        });
    if (!hasTensorCastOperand)
      return failure();

    SmallVector<Type, 4> newResultTypes;
    newResultTypes.reserve(op->getNumResults());
    SmallVector<Value, 8> newOperands;
    newOperands.reserve(op->getNumOperands());
    // Inputs may fold.
    for (OpOperand *opOperand : op.getInputOperands()) {
      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
      newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
                                ? tensorCastOp.source()
                                : opOperand->get());
    }
    // Init tensors may fold, in which case the resultType must also change.
    for (OpOperand *opOperand : op.getOutputOperands()) {
      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
      bool fold = canFoldIntoConsumerOp(tensorCastOp);
      newOperands.push_back(fold ? tensorCastOp.getOperand()
                                 : opOperand->get());
      newResultTypes.push_back(newOperands.back().getType());
    }
    // Clone op.
    Operation *newOp =
        op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
      Value oldResult = std::get<0>(result);
      Value newResult = std::get<1>(result);
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);

    return success();
  }
};

} // namespace

#define LINALGOP_FOLDERS(XXX)                                                  \
  LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
                          SmallVectorImpl<OpFoldResult> &) {                   \
    return foldMemRefCast(*this);                                              \
  }

LINALGOP_FOLDERS(CopyOp)
LINALGOP_FOLDERS(FillOp)
LINALGOP_FOLDERS(GenericOp)

// All named ops canonicalizers and folders are auto-generated in the
// .cpp.inc.

//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//

void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<ConstantOp>(loc, type, value);
}