//===- LinalgOps.cpp - Implementation of the linalg operations ------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the Linalg operations. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::linalg; /// Forward declarations. /// Generic entry point to create the block for the region of a LinalgOp. /// This is used by both named structured ops created by ods-gen and by manually /// defined C++ ops. /// This is used by both builders and parsers. /// This function creates the block in the region with arguments corresponding /// to the elemental types of `inputTypes` and `outputTypes`. The latter are /// asserted to be of ShapedType. template static void fillStructuredOpRegion( OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, std::function errorHandler = nullptr); /// Generic entry point to create both the region and the block of a LinalgOp. template static void createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes, TypeRange outputTypes); /// Common parsing and printing used for both named structured ops created by /// ods-gen and by manually defined C++ ops. Does not handle regions. static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl &inputTypes, SmallVectorImpl &outputTypes); template static void printCommonStructuredOpParts(OpAsmPrinter &p, NamedStructuredOpType op); /// Specific parsing and printing for named structured ops created by ods-gen. template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, TypeRange inputTypes, TypeRange outputTypes); static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl &resultTypes); template static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result); static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes); template static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); /// Helper function to convert a Value into an OpFoldResult, if the Value is /// known to be a constant index value. static SmallVector getAsOpFoldResult(ArrayRef values) { return llvm::to_vector<4>( llvm::map_range(values, [](Value v) -> OpFoldResult { APInt intValue; if (v.getType().isa() && matchPattern(v, m_ConstantInt(&intValue))) { return IntegerAttr::get(v.getType(), intValue.getSExtValue()); } return v; })); } /// Helper function to convert a vector of `OpFoldResult`s into a vector of /// `Value`s. static SmallVector getAsValues(OpBuilder &b, Location loc, ArrayRef valueOrAttrVec) { return llvm::to_vector<4>( llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { if (auto attr = value.dyn_cast()) return b.create(loc, attr.cast().getInt()); return value.get(); })); } /// This is a common class used for patterns of the form /// ``` /// someop(memrefcast(%src)) -> someop(%src) /// ``` /// It folds the source of the memref.cast into the root operation directly. static LogicalResult foldMemRefCast(Operation *op) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { auto castOp = operand.get().getDefiningOp(); if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { operand.set(castOp.getOperand()); folded = true; } } return success(folded); } /// This is a specialization of `foldMemRefCast` used for patterns of the form /// ``` /// tiled_loop(memrefcast(%src)) -> tiled_loop(%src) /// ``` /// It folds the source of the memref.cast into the root operation directly. static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) { bool folded = false; Location loc = op->getLoc(); Block *body = op.getBody(); OpBuilder b = OpBuilder::atBlockBegin(body); // Update `input` and `output` operands and block arguments if necessary. // Operands list: [lbs, ubs, steps, inputs, outputs]. // Block args list: [ivs, inputs, outputs]. for (size_t operandIndex = op.getNumControlOperands(), bbArgIndex = op.getNumLoops(), e = op.getNumOperands(); operandIndex < e; ++operandIndex, ++bbArgIndex) { OpOperand &operand = op->getOpOperand(operandIndex); auto castOp = operand.get().getDefiningOp(); if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { operand.set(castOp.getOperand()); BlockArgument newBbArg = body->insertArgument(bbArgIndex, castOp.getOperand().getType()); BlockArgument oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1); // Insert memref.cast back to the original type. oldBbArg.replaceAllUsesWith( b.create(loc, oldBbArg.getType(), newBbArg)); body->eraseArgument(oldBbArg.getArgNumber()); folded = true; } } return success(folded); } //===----------------------------------------------------------------------===// // Region builder helper. // TODO: Move this to a utility library. // The public methods on this class are referenced directly from generated code // and bind by name to math functions in the DSL as: // `applyfn__{fnName}` // Examples: // `applyfn__add` // `applyfn__mul` // The naming convention is intentional in order to match snake-cased DSL names. // See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class. // // Implementations of the math functions must be polymorphic over numeric types, // internally performing necessary casts. If the function application makes no // sense, then the only recourse is to assert and return nullptr. This can be // extended later if it becomes possible to fail construction of the region. The // invariant should be enforced at a higher level. // // TODO: These helpers are currently type polymorphic over the class of integer // and floating point types, but they will not internally cast within bit // widths of a class (mixed precision such as i8->i32) or across classes // (i.e. mixed float and integer). Many such combinations are ambiguous or need // to be handled with care and work is being considered to extend the op // language to make such cases explicit. In the mean-time, violating this will // fail verification, which is deemed acceptable. //===----------------------------------------------------------------------===// namespace { class RegionBuilderHelper { public: RegionBuilderHelper(MLIRContext *context, Block &block) : context(context), block(block) {} // Generates operations to cast the given operand to a specified type. // If the cast cannot be performed, a warning will be issued and the // operand returned as-is (which will presumably yield a verification // issue downstream). Value cast(Type toType, Value operand) { OpBuilder builder = getBuilder(); auto loc = operand.getLoc(); if (operand.getType() == toType) return operand; if (auto toIntType = toType.dyn_cast()) { // If operand is floating point, cast directly to the int type. if (operand.getType().isa()) return builder.create(loc, toType, operand); // Cast index operands directly to the int type. if (operand.getType().isIndex()) return builder.create(loc, toType, operand); if (auto fromIntType = operand.getType().dyn_cast()) { // Either sign extend or truncate. if (toIntType.getWidth() > fromIntType.getWidth()) return builder.create(loc, toType, operand); if (toIntType.getWidth() < fromIntType.getWidth()) return builder.create(loc, toType, operand); } } else if (auto toFloatType = toType.dyn_cast()) { // If operand is integer, cast directly to the float type. // Note that it is unclear how to cast from BF16<->FP16. if (operand.getType().isa()) return builder.create(loc, toFloatType, operand); if (auto fromFloatType = operand.getType().dyn_cast()) { if (toFloatType.getWidth() > fromFloatType.getWidth()) return builder.create(loc, toFloatType, operand); if (toFloatType.getWidth() < fromFloatType.getWidth()) return builder.create(loc, toFloatType, operand); } } emitWarning(operand.getLoc()) << "could not cast operand of type " << operand.getType() << " to " << toType; return operand; } Value applyfn__add(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); if (isInteger(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); llvm_unreachable("unsupported non numeric type"); } Value applyfn__sub(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); if (isInteger(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); llvm_unreachable("unsupported non numeric type"); } Value applyfn__mul(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); if (isInteger(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); llvm_unreachable("unsupported non numeric type"); } void yieldOutputs(ValueRange values) { assert(!values.empty() && "linalg ops must yield outputs"); if (values.empty()) return; Value first = values.front(); OpBuilder builder = getBuilder(); builder.create(first.getLoc(), values); } Value constant(std::string value) { OpBuilder builder = getBuilder(); Location loc = builder.getUnknownLoc(); Attribute valueAttr = parseAttribute(value, builder.getContext()); return builder.create(loc, valueAttr.getType(), valueAttr); } Value index(int64_t dim) { OpBuilder builder = getBuilder(); return builder.create(builder.getUnknownLoc(), dim); } Type getIntegerType(unsigned width) { return IntegerType::get(context, width); } Type getFloat32Type() { return Float32Type::get(context); } Type getFloat64Type() { return Float64Type::get(context); } private: MLIRContext *context; Block █ bool isFloatingPoint(Value value) { return value.getType().isa(); } bool isInteger(Value value) { return value.getType().isa(); } OpBuilder getBuilder() { OpBuilder builder(context); builder.setInsertionPointToEnd(&block); return builder; } }; } // namespace //===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) { assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args"); b.create(block.getArgument(0)); } void CopyOp::build(OpBuilder &builder, OperationState &result, Value input, Value output, AffineMap inputPermutation, AffineMap outputPermutation, ArrayRef namedAttrs) { result.addOperands({input, output}); result.addAttributes(namedAttrs); if (inputPermutation) result.addAttribute("inputPermutation", AffineMapAttr::get(inputPermutation)); if (outputPermutation) result.addAttribute("outputPermutation", AffineMapAttr::get(outputPermutation)); result.addRegion(); fillStructuredOpRegion(builder, *result.regions.front(), TypeRange{input.getType()}, TypeRange{output.getType()}); } ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType, Type outputType) { OpBuilder opBuilder(parser.getBuilder().getContext()); fillStructuredOpRegion(opBuilder, r, TypeRange{inputType}, TypeRange{outputType}); return success(); } /// CopyOp region is elided when printing. void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} static LogicalResult verify(CopyOp op) { OpOperand *output = op.getOutputOperand(0); OpOperand *input = op.getInputOperand(0); if (getElementTypeOrSelf(input->get()) != getElementTypeOrSelf(output->get())) return op.emitOpError("expects views of the same type"); if (op.getRank(input) != op.getRank(output)) return op.emitOpError("expects views of the same rank"); auto rank = op.getNumParallelLoops(); auto inputPermutationMap = op.inputPermutation(); if (inputPermutationMap) { if (inputPermutationMap->getNumInputs() != rank) return op.emitOpError("expects optional input_permutation map of rank ") << rank; if (!inputPermutationMap->isPermutation()) return op.emitOpError( "expects optional input_permutation map to be a permutation"); } auto outputPermutationMap = op.outputPermutation(); if (outputPermutationMap) { if (outputPermutationMap->getNumInputs() != rank) return op.emitOpError("expects optional output_permutation map of rank ") << rank; if (!outputPermutationMap->isPermutation()) return op.emitOpError( "expects optional output_permutation map to be a permutation"); } if (rank == 0 && inputPermutationMap) return op.emitOpError("expected no input permutation when rank == 0"); if (rank == 0 && outputPermutationMap) return op.emitOpError("expected no output permutation when rank == 0"); return success(); } void CopyOp::getEffects( SmallVectorImpl> &effects) { effects.emplace_back(MemoryEffects::Read::get(), input(), SideEffects::DefaultResource::get()); effects.emplace_back(MemoryEffects::Write::get(), output(), SideEffects::DefaultResource::get()); } //===----------------------------------------------------------------------===// // FillOp //===----------------------------------------------------------------------===// void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) { assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args"); b.create(block.getArgument(0)); } void FillOp::build(OpBuilder &builder, OperationState &result, Value value, Value output) { build(builder, result, output.getType().dyn_cast(), value, output); fillStructuredOpRegion(builder, *result.regions.front(), TypeRange{value.getType()}, TypeRange{output.getType()}, {}); } ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType, Type outputType) { OpBuilder opBuilder(parser.getBuilder().getContext()); fillStructuredOpRegion(opBuilder, r, TypeRange{valueType}, TypeRange{outputType}); return success(); } /// FillOp region is elided when printing. void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} static LogicalResult verify(FillOp op) { OpOperand *output = op.getOutputOperand(0); Type fillType = op.value().getType(); if (getElementTypeOrSelf(output->get()) != fillType) return op.emitOpError("expects fill type to match view elemental type"); if (!op.getNumResults() && !output->get().getType().isa()) { return op.emitOpError( "expected fill op with no result value to use memref type"); } return success(); } void FillOp::getEffects( SmallVectorImpl> &effects) { if (output().getType().isa()) effects.emplace_back(MemoryEffects::Write::get(), output(), SideEffects::DefaultResource::get()); } //===----------------------------------------------------------------------===// // GenericOps //===----------------------------------------------------------------------===// void GenericOp::build( OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, function_ref bodyBuild) { build(builder, result, resultTensorTypes, inputs, outputs, builder.getAffineMapArrayAttr(indexingMaps), builder.getStrArrayAttr(iteratorTypes), doc.empty() ? StringAttr() : builder.getStringAttr(doc), libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall)); if (!bodyBuild) return; SmallVector blockArgTypes; for (ValueRange container : {inputs, outputs}) for (Value v : container) blockArgTypes.push_back(getElementTypeOrSelf(v)); OpBuilder::InsertionGuard guard(builder); auto ®ion = *result.regions.front(); Block *bodyBlock = builder.createBlock(®ion, region.end(), blockArgTypes); bodyBuild(builder, result.location, bodyBlock->getArguments()); } void GenericOp::build( OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, function_ref bodyBuild) { build(builder, result, TypeRange{}, inputs, outputs, indexingMaps, iteratorTypes, doc, libraryCall, bodyBuild); } void GenericOp::build( OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, function_ref bodyBuild) { build(builder, result, inputs, outputs, indexingMaps, iteratorTypes, /*doc=*/"", /*libraryCall=*/"", bodyBuild); } void GenericOp::build( OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, function_ref bodyBuild) { build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps, iteratorTypes, /*doc=*/"", /*libraryCall=*/"", bodyBuild); } static void print(OpAsmPrinter &p, GenericOp op) { p << op.getOperationName() << " "; // Print extra attributes. auto genericAttrNames = op.linalgTraitAttrNames(); llvm::StringSet<> genericAttrNamesSet; genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end()); SmallVector genericAttrs; for (auto attr : op->getAttrs()) if (genericAttrNamesSet.count(attr.first.strref()) > 0) genericAttrs.push_back(attr); if (!genericAttrs.empty()) { auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs); p << genericDictAttr; } // Printing is shared with named ops, except for the region and attributes printCommonStructuredOpParts(p, op); genericAttrNames.push_back("operand_segment_sizes"); genericAttrNamesSet.insert(genericAttrNames.back()); bool hasExtraAttrs = false; for (NamedAttribute n : op->getAttrs()) { if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.first.strref()))) break; } if (hasExtraAttrs) { p << " attrs = "; p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/genericAttrNames); } // Print region. if (!op.region().empty()) p.printRegion(op.region()); // Print results. printNamedStructuredOpResults(p, op.result_tensors().getTypes()); } static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) { DictionaryAttr dictAttr; // Parse the core linalg traits that must check into a dictAttr. // The name is unimportant as we will overwrite result.attributes. // The core linalg traits must contain the information necessary to pass the // verifier. if (parser.parseAttribute(dictAttr, "_", result.attributes)) return failure(); result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); // Parsing is shared with named ops, except for the region. SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) return failure(); // Optional attributes may be added. if (succeeded(parser.parseOptionalKeyword("attrs"))) if (failed(parser.parseEqual()) || failed(parser.parseOptionalAttrDict(result.attributes))) return failure(); SmallVector regionOperands; std::unique_ptr region = std::make_unique(); SmallVector operandTypes, regionTypes; if (parser.parseRegion(*region, regionOperands, regionTypes)) return failure(); result.addRegion(std::move(region)); // Generic ops may specify that a subset of its outputs are tensors. Such // outputs are specified in the result type. // TODO: may need to move output parsing before region parsing. // Need to wait for declarative assembly resolution to decide. SmallVector outputTensorsTypes; if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) return failure(); result.addTypes(outputTensorsTypes); return success(); } static void getGenericEffectsImpl( SmallVectorImpl> &effects, ValueRange results, ValueRange inputBuffers, ValueRange outputs) { for (Value value : results) { effects.emplace_back(MemoryEffects::Allocate::get(), value, SideEffects::DefaultResource::get()); } for (Value value : inputBuffers) { effects.emplace_back(MemoryEffects::Read::get(), value, SideEffects::DefaultResource::get()); } for (Value value : outputs) { effects.emplace_back(MemoryEffects::Read::get(), value, SideEffects::DefaultResource::get()); effects.emplace_back(MemoryEffects::Write::get(), value, SideEffects::DefaultResource::get()); } } void GenericOp::getEffects( SmallVectorImpl> &effects) { SmallVector inputBuffers = getInputBufferOperands(); SmallVector outputBuffers = getOutputBufferOperands(); getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, outputBuffers); } template static LogicalResult verifyGenericOp(GenericOpType op) { return success(); } static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); } //===----------------------------------------------------------------------===// // InitTensorOp //===----------------------------------------------------------------------===// void InitTensorOp::build(OpBuilder &b, OperationState &result, ArrayRef sizes, Type elementType, ArrayRef attrs) { unsigned rank = sizes.size(); SmallVector dynamicSizes; SmallVector staticSizes; for (unsigned i = 0; i < rank; ++i) { dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes, ShapedType::kDynamicSize); } auto resultType = RankedTensorType ::get(staticSizes, elementType); build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes)); result.addAttributes(attrs); } static LogicalResult verify(InitTensorOp op) { RankedTensorType resultType = op.getType(); SmallVector staticSizes = llvm::to_vector<4>(llvm::map_range( op.static_sizes().cast(), [](Attribute a) -> int64_t { return a.cast().getInt(); })); if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(), op.static_sizes(), op.sizes(), ShapedType::isDynamic))) return failure(); if (op.static_sizes().size() != static_cast(resultType.getRank())) return op->emitError("expected ") << resultType.getRank() << " sizes values"; Type expectedType = InitTensorOp::inferResultType(staticSizes, resultType.getElementType()); if (resultType != expectedType) { return op.emitError("specified type ") << resultType << " does not match the inferred type " << expectedType; } return success(); } Type InitTensorOp::inferResultType(ArrayRef staticSizes, Type elementType) { return RankedTensorType::get(staticSizes, elementType); } namespace { /// Change the type of the result of a `linalg.init_tensor` by making the result /// type statically sized along dimension that in the original operation where /// defined as dynamic, but the size was defined using a `constant` op. For /// example /// /// %c5 = constant 5: index /// %0 = linalg.init_tensor [%arg0, %c5] : tensor /// /// to /// /// %0 = linalg.init_tensor [%arg0, 5] : tensor struct ReplaceStaticShapeDims : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InitTensorOp op, PatternRewriter &rewriter) const override { SmallVector dynamicSizes; SmallVector staticSizes; for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) { // If the size is already static, nothing to do. if (!op.isDynamicSize(i)) { staticSizes.push_back(op.getStaticSize(i)); continue; } // If the size is dynamic but defined using a `constant` op, get the // constant value to find the static size to use. unsigned operandNum = op.getIndexOfDynamicSize(i); Value sizeOperand = op.getOperand(operandNum); if (auto constantIndexOp = sizeOperand.getDefiningOp()) { staticSizes.push_back(constantIndexOp.getValue()); continue; } // Fallback case. Keep the size dynamic. dynamicSizes.push_back(sizeOperand); staticSizes.push_back(ShapedType::kDynamicSize); } RankedTensorType newType = RankedTensorType::get(staticSizes, op.getType().getElementType()); if (newType == op.getType()) return failure(); auto newOp = rewriter.create(op.getLoc(), newType, dynamicSizes, rewriter.getI64ArrayAttr(staticSizes)); rewriter.replaceOpWithNewOp(op, op.getType(), newOp); return success(); } }; } // namespace namespace { /// Since `init_tensor` operation creates a tensor needed only for its shape, a /// slice of this is also needed only for its shape. The result can be /// replaced by a new init_tensor operation of the same size as the extract /// slice op. struct FoldInitTensorWithExtractSliceOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override { if (!sliceOp.source().getDefiningOp()) return failure(); rewriter.replaceOpWithNewOp( sliceOp, sliceOp.sizes(), llvm::to_vector<4>(llvm::map_range( sliceOp.static_sizes(), [](Attribute attr) { return attr.cast().getInt(); })), sliceOp.getSourceType().getElementType()); return success(); } }; template struct FoldInitTensorWithTensorReshapeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, PatternRewriter &rewriter) const override { if (!reshapeOp.src().template getDefiningOp()) return failure(); Location loc = reshapeOp.getLoc(); SmallVector, 4> resultShapes; if (failed(reshapeOp.reifyReturnTypeShapesPerResultDim(rewriter, resultShapes)) || !llvm::hasSingleElement(resultShapes)) return failure(); Value initTensor = rewriter.create( loc, getAsOpFoldResult(resultShapes[0]), reshapeOp.getResultType().getElementType()); if (initTensor.getType() != reshapeOp.getResultType()) { rewriter.replaceOpWithNewOp( reshapeOp, reshapeOp.getResultType(), initTensor); } else { rewriter.replaceOp(reshapeOp, initTensor); } return success(); } }; } // namespace void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, FoldInitTensorWithTensorReshapeOp, ReplaceStaticShapeDims>(context); } LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim( OpBuilder &builder, SmallVectorImpl> &reifiedReturnShapes) { auto shapes = llvm::to_vector<4>(llvm::map_range( llvm::seq(0, getType().getRank()), [&](int64_t dim) -> Value { if (isDynamicSize(dim)) return getDynamicSize(dim); return builder.create(getLoc(), getStaticSize(dim)); })); reifiedReturnShapes.emplace_back(std::move(shapes)); return success(); } //===----------------------------------------------------------------------===// // PadTensorOp //===----------------------------------------------------------------------===// static LogicalResult verify(PadTensorOp op) { auto sourceType = op.source().getType().cast(); auto resultType = op.result().getType().cast(); auto expectedType = PadTensorOp::inferResultType( sourceType, extractFromI64ArrayAttr(op.static_low()), extractFromI64ArrayAttr(op.static_high())); for (int i = 0, e = sourceType.getRank(); i < e; ++i) { if (resultType.getDimSize(i) == expectedType.getDimSize(i)) continue; if (expectedType.isDynamicDim(i)) continue; return op.emitError("specified type ") << resultType << " does not match the inferred type " << expectedType; } auto ®ion = op.region(); unsigned rank = resultType.getRank(); Block &block = region.front(); if (block.getNumArguments() != rank) return op.emitError("expected the block to have ") << rank << " arguments"; // Note: the number and type of yield values are checked in the YieldOp. for (auto en : llvm::enumerate(block.getArgumentTypes())) { if (!en.value().isIndex()) return op.emitOpError("expected block argument ") << (en.index() + 1) << " to be an index"; } return success(); } RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType, ArrayRef staticLow, ArrayRef staticHigh) { unsigned rank = sourceType.getRank(); assert(staticLow.size() == rank && "unexpected staticLow size mismatch"); assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch"); SmallVector resultShape; for (auto i : llvm::seq(0, rank)) { if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamicSize || staticHigh[i] == ShapedType::kDynamicSize) { resultShape.push_back(ShapedType::kDynamicSize); } else { int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i]; resultShape.push_back(size); } } return RankedTensorType::get(resultShape, sourceType.getElementType()); } void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef staticLow, ArrayRef staticHigh, ValueRange low, ValueRange high, ArrayRef attrs) { auto sourceType = source.getType().cast(); auto resultType = inferResultType(sourceType, staticLow, staticHigh); build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh)); result.addAttributes(attrs); } void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source, ValueRange low, ValueRange high, ArrayRef attrs) { auto sourceType = source.getType().cast(); unsigned rank = sourceType.getRank(); SmallVector staticVector(rank, ShapedType::kDynamicSize); build(b, result, source, staticVector, staticVector, low, high, attrs); } void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType, Value source, ArrayRef low, ArrayRef high, ArrayRef attrs) { assert(resultType.isa()); auto sourceType = source.getType().cast(); unsigned rank = sourceType.getRank(); SmallVector dynamicLow, dynamicHigh; SmallVector staticLow, staticHigh; for (unsigned i = 0; i < rank; ++i) { // staticLow and staticHigh have full information of the padding config. // This will grow staticLow and staticHigh with 1 value. If the config is // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1 // value as well. dispatchIndexOpFoldResult(low[i], dynamicLow, staticLow, ShapedType::kDynamicSize); dispatchIndexOpFoldResult(high[i], dynamicHigh, staticHigh, ShapedType::kDynamicSize); } if (!resultType) { resultType = PadTensorOp::inferResultType(sourceType, staticLow, staticHigh); } build(b, result, resultType, source, dynamicLow, dynamicHigh, b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh)); } PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad, ArrayRef low, ArrayRef high, Location loc, OpBuilder &builder) { auto padTensorOp = builder.create(loc, type, source, low, high); int rank = padTensorOp.getResultType().getRank(); SmallVector blockArgTypes; blockArgTypes.assign(rank, builder.getIndexType()); auto ®ion = padTensorOp.region(); // `builder.createBlock` changes the insertion point within the block. Create // a guard to reset the insertion point of the builder after it is destroyed. OpBuilder::InsertionGuard guard(builder); builder.createBlock(®ion, region.end(), blockArgTypes); builder.create(loc, pad); return padTensorOp; } PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad, Location loc, OpBuilder &builder) { SmallVector low, high; auto rankedTensorType = type.cast(); assert(rankedTensorType.hasStaticShape()); int rank = rankedTensorType.getRank(); for (int i = 0; i < rank; ++i) { auto dimOp = builder.createOrFold(loc, source, i); auto resultDimSize = builder.createOrFold( loc, rankedTensorType.getDimSize(i)); auto highValue = builder.createOrFold(loc, resultDimSize, dimOp); high.push_back(highValue); low.push_back(builder.createOrFold(loc, 0)); } return PadTensorOp::createPadScalarOp(type, source, pad, low, high, loc, builder); } LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim( OpBuilder &b, SmallVectorImpl> &reifiedReturnShapes) { Location loc = getLoc(); auto lowPad = getMixedLowPad(); auto highPad = getMixedHighPad(); SmallVector shapes; for (auto dim : llvm::seq(0, getSourceType().getRank())) { // Shape along each dimension is source dim + low pad + high pad. SmallVector mapOperands; mapOperands.push_back(b.createOrFold(loc, source(), dim)); AffineExpr expr = b.getAffineDimExpr(0); unsigned numSymbols = 0; auto addOpFoldResult = [&](OpFoldResult valueOrAttr) { if (Value v = valueOrAttr.dyn_cast()) { expr = expr + b.getAffineSymbolExpr(numSymbols++); mapOperands.push_back(v); return; } int64_t staticValue = valueOrAttr.get().cast().getInt(); expr = expr + staticValue; }; addOpFoldResult(lowPad[dim]); addOpFoldResult(highPad[dim]); shapes.push_back(applyMapToValues( b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]); } reifiedReturnShapes.emplace_back(std::move(shapes)); return success(); } namespace { // Folds linalg.pad_tensor when padding is static zeros. struct FoldStaticZeroPadding : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PadTensorOp padTensorOp, PatternRewriter &rewriter) const override { if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad()) return failure(); rewriter.replaceOpWithNewOp( padTensorOp, padTensorOp.result().getType(), padTensorOp.source()); return success(); } }; } // namespace void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); } /// Return the padding value of the PadTensorOp if it constant. In this context, /// "constant" means an actual constant or "defined outside of the block". /// /// Values are considered constant in three cases: /// - A ConstantLike value. /// - A basic block argument from a different block. /// - A value defined outside of the block. /// /// If the padding value is not constant, an empty Value is returned. Value PadTensorOp::getConstantPaddingValue() { auto yieldOp = dyn_cast(getRegion().front().getTerminator()); if (!yieldOp || yieldOp.values().size() != 1) return {}; Value padValue = yieldOp.values().front(); // Check if yield value is a constant. if (matchPattern(padValue, m_Constant())) return padValue; // Check if yield value is defined inside the PadTensorOp block. if (padValue.getParentBlock() == &getRegion().front()) return {}; // Else: Yield value defined outside of the PadTensorOp block. return padValue; } OpFoldResult PadTensorOp::fold(ArrayRef) { if (getResultType().hasStaticShape() && getResultType() == getSourceType()) return source(); return {}; } //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// Optional> mlir::linalg::getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType) { // Make the sourceType greater rank than the targetType. If they are same // rank, then its an unsupported reshape op. if (sourceType.getRank() == targetType.getRank()) return llvm::None; if (sourceType.getRank() < targetType.getRank()) std::swap(sourceType, targetType); ArrayRef sourceShape = sourceType.getShape(); ArrayRef targetShape = targetType.getShape(); unsigned sourceDim = 0; SmallVector reassociationMap; reassociationMap.reserve(targetType.getRank()); ReassociationIndices currIndices; int64_t prodOfCollapsedDims = 1; while (sourceDim < sourceShape.size()) { unsigned targetDim = reassociationMap.size(); // If all the dimensions of the targetShape are exhausted, then the // remaining dims in the source shape must be all 1s. So for such cases, set // 1 as the target shape. The actual reassociation indices will be handled // later. int64_t currTargetShape = (targetDim < targetType.getRank() ? targetShape[targetDim] : 1); while (sourceShape[sourceDim] != ShapedType::kDynamicSize && prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape && sourceDim < sourceShape.size()) { prodOfCollapsedDims *= sourceShape[sourceDim]; currIndices.push_back(sourceDim++); } // If the current expanded dimension is dynamic, then the collapsed // dimensions should also be dynamic and product of all previous unprocessed // dimensions of the expanded shape should be 1. if (sourceShape[sourceDim] == ShapedType::kDynamicSize && (currTargetShape != ShapedType::kDynamicSize || prodOfCollapsedDims != 1)) return llvm::None; // If the collapsed dim is dynamic, the current expanded dim should also // be dynamic. if (currTargetShape == ShapedType::kDynamicSize && sourceShape[sourceDim] != ShapedType::kDynamicSize) return llvm::None; // For static shapes, if the product of dimensions of the expanded shape // should match the collapsed dimension shape. if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape) return llvm::None; currIndices.push_back(sourceDim++); // If the reassociation is empty but the currIndices is not, this by // definition is folding unit-dimensions with the result being scalar type. // So only append the `currIndices` if reassociation map is not empty. if (targetDim == targetShape.size()) { if (!reassociationMap.empty() && !currIndices.empty()) reassociationMap.back().append(currIndices.begin(), currIndices.end()); // Break out of the loops. We should be done here. break; } reassociationMap.emplace_back(ReassociationIndices{}); std::swap(reassociationMap.back(), currIndices); prodOfCollapsedDims = 1; } // All the dimensions in the two shapes must have been processed. if (reassociationMap.size() != targetShape.size() || sourceDim != sourceShape.size()) return llvm::None; return reassociationMap; } template static void print(OpAsmPrinter &p, ReshapeLikeOp op) { p << op.getOperationName() << ' ' << op.src() << " ["; llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) { p << '['; auto arrayAttr = attr.template cast(); llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) { p << attr.cast().getInt(); }); p << ']'; }); p << "] "; p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{op.getReassociationAttrName()}); p << ": " << op.src().getType() << " into " << op.getType(); } static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) { print(p, op); } static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) { print(p, op); } static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) { print(p, op); } static void print(OpAsmPrinter &p, linalg::TensorCollapseShapeOp op) { print(p, op); } static constexpr StringRef getReassociationAttrName() { return "reassociation"; } static ParseResult parseReshapeLikeOp(OpAsmParser &parser, OperationState &result) { // Parse the operand. OpAsmParser::OperandType src; if (parser.parseOperand(src)) return failure(); // Parse reassociation indices. Builder &b = parser.getBuilder(); SmallVector reassociation; if (parser.parseLSquare()) return failure(); while (true) { if (succeeded(parser.parseOptionalRSquare())) break; if (parser.parseLSquare()) return failure(); SmallVector indices; while (true) { int64_t index; if (parser.parseInteger(index)) return failure(); indices.push_back(index); if (succeeded(parser.parseOptionalComma())) continue; if (failed(parser.parseRSquare())) return failure(); break; } reassociation.push_back(b.getI64ArrayAttr(indices)); if (succeeded(parser.parseOptionalComma())) continue; if (failed(parser.parseRSquare())) return failure(); break; } result.addAttribute(getReassociationAttrName(), b.getArrayAttr(reassociation)); // Parse optional attributes. parser.parseOptionalAttrDict(result.attributes); // Parse types. Type srcType; Type resultType; if (parser.parseColon() || parser.parseType(srcType) || parser.resolveOperand(src, srcType, result.operands) || parser.parseKeyword("into") || parser.parseType(resultType)) return failure(); result.addTypes(resultType); return success(); } /// Collapse reassociation maps that are used in pair of reshape ops where one /// is a producer and other is the consumer. Only valid to use this method when /// both the producer and consumer are collapsing dimensions or both are /// expanding dimensions. /// /// For example, /// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>, /// affine_map<(d0, d1, d2, d3, d4) -> (d2)>, /// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] /// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>, /// affine_map<(d0, d1, d2) -> (d2)>] /// /// is folded into /// /// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, /// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] static Optional> collapseReassociationIndices(ArrayRef mapsProducer, ArrayRef mapsConsumer, MLIRContext *context) { // Make the producer the larger sized vector. If they are of same size, the // resulting reshape is not a supported reshape op. if (mapsProducer.size() == mapsConsumer.size()) return llvm::None; if (mapsProducer.size() < mapsConsumer.size()) std::swap(mapsProducer, mapsConsumer); // Handle the corner case of the result being a rank 0 shaped type. Return an // empty reassociation. if (mapsConsumer.empty()) return SmallVector{}; if (mapsProducer.size() != mapsConsumer[0].getNumDims()) return llvm::None; unsigned currDim = 0; SmallVector reassociationMaps; for (AffineMap rhs : mapsConsumer) { ReassociationIndices reassociations; for (AffineExpr rhsExpr : rhs.getResults()) { AffineDimExpr dimExpr = rhsExpr.cast(); for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults(); i < e; ++i) reassociations.push_back(currDim++); } reassociationMaps.push_back(std::move(reassociations)); } return reassociationMaps; } namespace { /// Pattern to collapse producer/consumer reshape ops that are both collapsing /// dimensions or are both expanding dimensions. template struct CollapseReshapeOps : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, PatternRewriter &rewriter) const override { auto srcReshapeOp = reshapeOp.src().template getDefiningOp(); if (!srcReshapeOp) return failure(); ShapedType resultType = reshapeOp.getResultType(); Optional> reassociationIndices = collapseReassociationIndices(srcReshapeOp.getReassociationMaps(), reshapeOp.getReassociationMaps(), rewriter.getContext()); if (!reassociationIndices) return failure(); rewriter.replaceOpWithNewOp( reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); return success(); } }; /// Pattern to collapse producer/consumer reshape ops that are both collapsing /// dimensions or are both expanding dimensions. template struct CollapseMixedReshapeOps : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, PatternRewriter &rewriter) const override { auto srcReshapeOp = reshapeOp.src().template getDefiningOp(); if (!srcReshapeOp) return failure(); ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType(); ShapedType intermediateType = reshapeOp.getSrcType(); ShapedType resultType = reshapeOp.getResultType(); // If the source reshape can be collapsed/expanded into the target reshape // they can still be folded. This can only be reasoned about statically // for cases where // - either all shapes are static, or // - The number of dynamic dimensions matches in the source of source and // result with all other dimensions being 1. Optional> reassociationIndices = getReassociationIndicesForReshape(srcReshapeSrcType, resultType); if (!reassociationIndices) return failure(); bool originalOpExpands = intermediateType.getRank() > srcReshapeSrcType.getRank(); bool resultingOpExpands = resultType.getRank() > srcReshapeSrcType.getRank(); if (!(resultingOpExpands ^ originalOpExpands)) rewriter.replaceOpWithNewOp( reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); else rewriter.replaceOpWithNewOp( reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); return success(); } }; } // namespace template static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef operands) { // Fold producer-consumer reshape ops that where the operand type of the // producer is same as the return type of the consumer. auto reshapeSrcOp = reshapeOp.src().template getDefiningOp(); if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType()) return reshapeSrcOp.src(); // Reshape of a constant can be replaced with a new constant. if (auto elements = operands.front().dyn_cast_or_null()) { return elements.reshape( reshapeOp.getResult().getType().template cast()); } return nullptr; } /// Return true if the reassociation specification is valid, false otherwise. /// When false, the `invalidIndex` integer pointer is optionally filled with the /// index of the offending reassociation map. static bool isReassociationValid(ArrayRef reassociation, int *invalidIndex = nullptr) { if (reassociation.empty()) return true; unsigned nDims = reassociation[0].getNumDims(); unsigned nextExpectedDim = 0; for (auto it : llvm::enumerate(reassociation)) { auto m = it.value(); if (m.getNumDims() != nDims || m.getNumSymbols() != 0) { if (invalidIndex) *invalidIndex = it.index(); return false; } for (auto e : m.getResults()) { auto d = e.dyn_cast(); if (!d || d.getPosition() != nextExpectedDim++) { if (invalidIndex) *invalidIndex = it.index(); return false; } } } if (nextExpectedDim != nDims) { if (invalidIndex) *invalidIndex = reassociation.size() - 1; return false; } return true; } /// Detect whether memref dims [dim, dim + extent) can be reshaped without /// copies. static bool isReshapableDimBand(unsigned dim, unsigned extent, ArrayRef sizes, ArrayRef strides) { assert(sizes.size() == strides.size() && "mismatched ranks"); // off by 1 indexing to avoid out of bounds // V for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) { // Only bands of static shapes are reshapable. This is due to the fact that // there is no relation between dynamic sizes and dynamic strides: we do not // have enough information to know whether a "-1" size corresponds to the // proper symbol in the AffineExpr of a stride. if (ShapedType::isDynamic(sizes[dim + 1])) return false; // TODO: Refine this by passing the proper nDims and nSymbols so we can // simplify on the fly and catch more reshapable cases. if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) return false; } return true; } /// Compute the MemRefType obtained by applying the `reassociation` (which is /// expected to be valid) to `type`. /// If `type` is Contiguous MemRefType, this always produce a contiguous /// MemRefType. static MemRefType computeReshapeCollapsedType(MemRefType type, ArrayRef reassociation) { auto sizes = type.getShape(); AffineExpr offset; SmallVector strides; auto status = getStridesAndOffset(type, strides, offset); (void)status; assert(succeeded(status) && "expected strided memref"); SmallVector newSizes; newSizes.reserve(reassociation.size()); SmallVector newStrides; newStrides.reserve(reassociation.size()); // Use the fact that reassociation is valid to simplify the logic: only use // each map's rank. assert(isReassociationValid(reassociation) && "invalid reassociation"); unsigned currentDim = 0; for (AffineMap m : reassociation) { unsigned dim = m.getNumResults(); int64_t size = 1; AffineExpr stride = strides[currentDim + dim - 1]; if (!isReshapableDimBand(currentDim, dim, sizes, strides)) { size = ShapedType::kDynamicSize; stride = AffineExpr(); } else { for (unsigned d = 0; d < dim; ++d) size *= sizes[currentDim + d]; } newSizes.push_back(size); newStrides.push_back(stride); currentDim += dim; } // Early-exit: if `type` is contiguous, the result must be contiguous. if (canonicalizeStridedLayout(type).getAffineMaps().empty()) return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({}); // Convert back to int64_t because we don't have enough information to create // new strided layouts from AffineExpr only. This corresponds to a case where // copies may be necessary. int64_t intOffset = ShapedType::kDynamicStrideOrOffset; if (auto o = offset.dyn_cast()) intOffset = o.getValue(); SmallVector intStrides; intStrides.reserve(strides.size()); for (auto stride : newStrides) { if (auto cst = stride.dyn_cast_or_null()) intStrides.push_back(cst.getValue()); else intStrides.push_back(ShapedType::kDynamicStrideOrOffset); } auto layout = makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); return canonicalizeStridedLayout( MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout})); } template unsigned getMaxPosOfType(ArrayRef exprArrays) { unsigned pos = 0; for (const auto &exprs : exprArrays) { for (auto expr : exprs) { expr.walk([&pos](AffineExpr e) { if (auto d = e.dyn_cast()) pos = std::max(pos, d.getPosition()); }); } } return pos; } static SmallVector getSymbolLessAffineMaps(ArrayRef reassociation) { unsigned maxDim = getMaxPosOfType(reassociation); assert(getMaxPosOfType(reassociation) == 0 && "Expected symbol-less expressions"); SmallVector maps; maps.reserve(reassociation.size()); for (const auto &exprs : reassociation) { assert(!exprs.empty()); maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext())); } return maps; } static SmallVector convertReassociationMapsToIndices( OpBuilder &b, ArrayRef reassociationExprs) { SmallVector reassociationIndices; for (const auto &exprs : reassociationExprs) { ReassociationIndices indices; indices.reserve(exprs.size()); for (const auto &expr : exprs) indices.push_back(expr.cast().getPosition()); reassociationIndices.push_back(indices); } return reassociationIndices; } static SmallVector, 2> convertReassociationIndicesToExprs( OpBuilder &b, ArrayRef reassociationIndices) { SmallVector, 2> reassociationMaps; for (const auto &indices : reassociationIndices) { SmallVector reassociationMap; reassociationMap.reserve(indices.size()); for (int64_t index : indices) reassociationMap.push_back(b.getAffineDimExpr(index)); reassociationMaps.push_back(std::move(reassociationMap)); } return reassociationMaps; } SmallVector CollapseShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } SmallVector CollapseShapeOp::getReassociationExprs() { OpBuilder b(this->getContext()); return convertReassociationIndicesToExprs(b, getReassociationIndices()); } SmallVector ExpandShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } SmallVector ExpandShapeOp::getReassociationExprs() { OpBuilder b(this->getContext()); return convertReassociationIndicesToExprs(b, getReassociationIndices()); } SmallVector TensorCollapseShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } SmallVector TensorCollapseShapeOp::getReassociationExprs() { OpBuilder b(this->getContext()); return convertReassociationIndicesToExprs(b, getReassociationIndices()); } SmallVector TensorExpandShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } SmallVector TensorExpandShapeOp::getReassociationExprs() { OpBuilder b(this->getContext()); return convertReassociationIndicesToExprs(b, getReassociationIndices()); } /// For reshape op compute the shape at dimension `dimIndex` of the output in /// terms of shape of the `src`, when the reshape op is a collapsing /// operation. It is the product of the shape of the collapsed dimensions of the /// `src`. static OpFoldResult getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef reassociationMap) { AffineMap map = reassociationMap[dimIndex]; unsigned startPos = map.getResults().front().cast().getPosition(); unsigned endPos = map.getResults().back().cast().getPosition(); AffineExpr expr; SmallVector dynamicDims; for (auto dim : llvm::seq(startPos, endPos + 1)) { dynamicDims.push_back(builder.createOrFold(loc, src, dim)); AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos); expr = (expr ? expr * currExpr : currExpr); } return applyMapToValues(builder, loc, AffineMap::get(0, endPos - startPos + 1, expr), dynamicDims)[0]; } /// Given the `src` of a collapsing reshape op and its reassociation maps, /// compute the shape of the result of the reshape. static SmallVector getCollapsedOutputShapeFromInputShape( OpBuilder &builder, Location loc, Value src, ArrayRef dstStaticShape, ArrayRef reassociation) { return llvm::to_vector<4>(llvm::map_range( llvm::seq(0, dstStaticShape.size()), [&](int64_t dim) { return getCollapsedOutputDimFromInputShape(builder, loc, dim, src, reassociation); })); } /// Compute a map that for a given dimension of the expanded type gives the /// dimension in the collapsed type it maps to. Essentially its the inverse of /// the `reassocation` maps. static llvm::DenseMap getExpandedDimToCollapsedDimMap(ArrayRef reassociation) { llvm::DenseMap expandedDimToCollapsedDim; for (auto map : enumerate(reassociation)) { unsigned startPos = map.value().getResults().front().cast().getPosition(); unsigned endPos = map.value().getResults().back().cast().getPosition(); for (auto dim : llvm::seq(startPos, endPos + 1)) { expandedDimToCollapsedDim[dim] = map.index(); } } return expandedDimToCollapsedDim; } /// For an expanding reshape op, compute the value for a dimension of the output /// from the shape of the input. static OpFoldResult getExpandedOutputDimFromInputShape( OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef dstStaticShape, ArrayRef reassociation, llvm::DenseMap &expandedDimToCollapsedDim) { if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) { return builder.getI64IntegerAttr(dstStaticShape[dimIndex]); } unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex]; unsigned startPos = reassociation[sourceDimPos] .getResults() .front() .cast() .getPosition(); unsigned endPos = reassociation[sourceDimPos] .getResults() .back() .cast() .getPosition(); int64_t linearizedStaticDim = 1; for (auto d : llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) { if (d.index() + startPos == static_cast(dimIndex)) continue; assert(!ShapedType::isDynamic(d.value()) && "single dimension cannot be expanded into multiple dynamic " "dimensions"); linearizedStaticDim *= d.value(); } Value sourceDim = builder.create(loc, src, sourceDimPos); return applyMapToValues( builder, loc, AffineMap::get( 0, 1, builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)), sourceDim)[0]; } /// Given the `src` of an expanding reshape op, the reassociation maps and the /// result type, compute the shape of the result of the reshape. static SmallVector getExpandedOutputShapeFromInputShape( OpBuilder &builder, Location loc, Value src, ArrayRef dstStaticShape, ArrayRef reassociation) { llvm::DenseMap expandedDimToCollapsedDim = getExpandedDimToCollapsedDimMap(reassociation); return llvm::to_vector<4>(llvm::map_range( llvm::seq(0, dstStaticShape.size()), [&](int64_t dim) { return getExpandedOutputDimFromInputShape(builder, loc, dim, src, dstStaticShape, reassociation, expandedDimToCollapsedDim); })); } static SmallVector getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, ArrayRef dstStaticShape, ArrayRef reassocation) { return dstStaticShape.size() > static_cast(src.getType().cast().getRank()) ? getExpandedOutputShapeFromInputShape( builder, loc, src, dstStaticShape, reassocation) : getCollapsedOutputShapeFromInputShape( builder, loc, src, dstStaticShape, reassocation); } static ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef reassociation) { SmallVector reassociationAttr = llvm::to_vector<4>(llvm::map_range( reassociation, [&](ReassociationIndices indices) -> Attribute { return b.getI64ArrayAttr(indices).cast(); })); return b.getArrayAttr(reassociationAttr); } void mlir::linalg::ExpandShapeOp::build( OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { auto memRefType = src.getType().cast(); auto resultType = computeReshapeCollapsedType( memRefType, getSymbolLessAffineMaps( convertReassociationIndicesToExprs(b, reassociation))); build(b, result, resultType, src, attrs); result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); } Value mlir::linalg::ExpandShapeOp::getViewSource() { return src(); } void mlir::linalg::CollapseShapeOp::build( OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { auto memRefType = src.getType().cast(); auto resultType = computeReshapeCollapsedType( memRefType, getSymbolLessAffineMaps( convertReassociationIndicesToExprs(b, reassociation))); build(b, result, resultType, src, attrs); result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); } Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); } /// Verify that shapes of the reshaped types using following rules /// 1) if a dimension in the collapsed type is static, then the corresponding /// dimensions in the expanded shape should be /// a) static /// b) the product should be same as the collaped shape. /// 2) if a dimension in the collaped type is dynamic, one and only one of the /// corresponding dimensions in the expanded type should be dynamic. This /// rule is only needed with reshape operations that are expanding. template static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType, ShapedType expandedType, bool isExpandingReshape) { ArrayRef collapsedShape = collapsedType.getShape(); ArrayRef expandedShape = expandedType.getShape(); unsigned expandedDimStart = 0; for (auto map : llvm::enumerate(op.getReassociationMaps())) { Optional dynamicShape; int64_t linearizedStaticShape = 1; for (auto dim : llvm::enumerate(expandedShape.slice( expandedDimStart, map.value().getNumResults()))) { if (ShapedType::isDynamic(dim.value())) { if (isExpandingReshape && dynamicShape) { return op->emitOpError("invalid to have a single dimension (") << map.index() << ") expanded into multiple dynamic dims (" << expandedDimStart + dynamicShape.getValue() << "," << expandedDimStart + dim.index() << ")"; } dynamicShape = dim.index(); } else { linearizedStaticShape *= dim.value(); } } if (dynamicShape) { if (!ShapedType::isDynamic(collapsedShape[map.index()])) { return op->emitOpError("expected dimension ") << map.index() << " of collapsed type to be dynamic since one or more of the " "corresponding dimensions in the expanded type is dynamic"; } } else { if (collapsedShape[map.index()] != linearizedStaticShape) { return op->emitOpError("expected dimension ") << map.index() << " of collapsed type to be static value of " << linearizedStaticShape << " "; } } expandedDimStart += map.value().getNumResults(); } return success(); } // Common verifier for reshape-like types. Fills `expandedType` and // `collapsedType` with the proper `src` or `result` type. template ::value || std::is_same::value> static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType) { unsigned expandedRank = expandedType.getRank(); unsigned collapsedRank = collapsedType.getRank(); if (expandedRank < collapsedRank) return op.emitOpError("expected the type ") << expandedType << " to have higher rank than the type = " << collapsedType; if (expandedRank == 0) return op.emitOpError("expected non-zero memref ranks"); if (expandedRank == collapsedRank) return op.emitOpError("expected to collapse or expand dims"); if (collapsedRank == 0) { // If collapsed rank is 0, then expanded type must be static shaped and of // sizes 1. if (llvm::any_of(expandedType.getShape(), [](int64_t dim) -> bool { return dim != 1; })) return op.emitOpError("invalid to reshape tensor/memref with non-unit " "extent dimensions to zero-rank tensor/memref"); return success(); } if (collapsedRank != op.reassociation().size()) return op.emitOpError("expected rank of the collapsed type(") << collapsedRank << ") to be the number of reassociation maps(" << op.reassociation().size() << ")"; auto maps = op.getReassociationMaps(); for (auto it : llvm::enumerate(maps)) if (it.value().getNumDims() != expandedRank) return op.emitOpError("expected reassociation map #") << it.index() << " of same rank as expanded memref(" << expandedRank << "), but got " << it.value().getNumDims(); int invalidIdx = 0; if (!isReassociationValid(maps, &invalidIdx)) return op.emitOpError("expected reassociation map #") << invalidIdx << " to be valid and contiguous"; return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion); } template static LogicalResult verifyReshapeOp(TensorReshapeOp op, MemRefType expandedType, MemRefType collapsedType) { if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType))) return failure(); auto maps = op.getReassociationMaps(); MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); if (collapsedType != expectedType) return op.emitOpError("expected collapsed type to be ") << expectedType << ", but got " << collapsedType; return success(); } static LogicalResult verify(ExpandShapeOp op) { return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); } void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, CollapseMixedReshapeOps>(context); } static LogicalResult verify(CollapseShapeOp op) { return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); } void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, CollapseMixedReshapeOps>(context); } //===----------------------------------------------------------------------===// // TensorReshapeOp //===----------------------------------------------------------------------===// /// Compute the RankedTensorType obtained by applying `reassociation` to `type`. static RankedTensorType computeTensorReshapeCollapsedType(RankedTensorType type, ArrayRef reassociation) { auto shape = type.getShape(); SmallVector newShape; newShape.reserve(reassociation.size()); // Use the fact that reassociation is valid to simplify the logic: only use // each map's rank. assert(isReassociationValid(reassociation) && "invalid reassociation"); unsigned currentDim = 0; for (AffineMap m : reassociation) { unsigned dim = m.getNumResults(); auto band = shape.slice(currentDim, dim); int64_t size = 1; if (llvm::is_contained(band, ShapedType::kDynamicSize)) size = ShapedType::kDynamicSize; else for (unsigned d = 0; d < dim; ++d) size *= shape[currentDim + d]; newShape.push_back(size); currentDim += dim; } return RankedTensorType::get(newShape, type.getElementType()); } void mlir::linalg::TensorCollapseShapeOp::build( OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { auto resultType = computeTensorReshapeCollapsedType( src.getType().cast(), getSymbolLessAffineMaps( convertReassociationIndicesToExprs(b, reassociation))); build(b, result, resultType, src, attrs); result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); } void mlir::linalg::TensorExpandShapeOp::build( OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { auto resultType = computeTensorReshapeCollapsedType( src.getType().cast(), getSymbolLessAffineMaps( convertReassociationIndicesToExprs(b, reassociation))); build(b, result, resultType, src, attrs); result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); } template static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType) { if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType))) return failure(); auto maps = op.getReassociationMaps(); RankedTensorType expectedType = computeTensorReshapeCollapsedType(expandedType, maps); if (collapsedType != expectedType) return op.emitOpError("expected collapsed type to be ") << expectedType << ", but got " << collapsedType; return success(); } static LogicalResult verify(TensorExpandShapeOp op) { return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType()); } static LogicalResult verify(TensorCollapseShapeOp op) { return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType()); } namespace { /// Reshape of a splat constant can be replaced with a constant of the result /// type. template struct FoldReshapeWithConstant : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, PatternRewriter &rewriter) const override { DenseElementsAttr attr; if (!matchPattern(reshapeOp.src(), m_Constant(&attr))) return failure(); if (!attr || !attr.isSplat()) return failure(); DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer( reshapeOp.getResultType(), attr.getRawData(), true); rewriter.replaceOpWithNewOp(reshapeOp, newAttr); return success(); } }; /// Fold linalg.fill -> linalg.tensor_reshape chain. /// /// For such op chains, we can create new linalg.fill ops with the result /// type of the linalg.tensor_reshape op. template struct FoldFillWithTensorReshape : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, PatternRewriter &rewriter) const override { auto oldFill = reshapeOp.src().template getDefiningOp(); if (!oldFill) return failure(); Location loc = oldFill.getLoc(); auto newInit = rewriter.create( loc, reshapeOp.getResultType(), oldFill.output(), reshapeOp.reassociation()); rewriter.replaceOpWithNewOp(reshapeOp, oldFill.value(), newInit); return success(); } }; } // namespace void TensorExpandShapeOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { results .add, CollapseMixedReshapeOps, FoldFillWithTensorReshape, FoldInitTensorWithTensorReshapeOp, FoldReshapeWithConstant>(context); } void TensorCollapseShapeOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { results .add, CollapseMixedReshapeOps, FoldFillWithTensorReshape, FoldInitTensorWithTensorReshapeOp, FoldReshapeWithConstant>(context); } LogicalResult TensorExpandShapeOp::reifyReturnTypeShapesPerResultDim( OpBuilder &b, SmallVectorImpl> &reifiedReturnShapes) { auto resultShape = getAsValues(b, getLoc(), getReshapeOutputShapeFromInputShape( b, getLoc(), src(), getResultType().getShape(), getReassociationMaps())); reifiedReturnShapes.emplace_back(std::move(resultShape)); return success(); } LogicalResult TensorCollapseShapeOp::reifyReturnTypeShapesPerResultDim( OpBuilder &b, SmallVectorImpl> &reifiedReturnShapes) { auto resultShape = getAsValues(b, getLoc(), getReshapeOutputShapeFromInputShape( b, getLoc(), src(), getResultType().getShape(), getReassociationMaps())); reifiedReturnShapes.emplace_back(std::move(resultShape)); return success(); } //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, linalg::YieldOp op) { p << op.getOperationName(); if (op.getNumOperands() > 0) p << ' ' << op.getOperands(); p.printOptionalAttrDict(op->getAttrs()); if (op.getNumOperands() > 0) p << " : " << op.getOperandTypes(); } static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { SmallVector opInfo; SmallVector types; llvm::SMLoc loc = parser.getCurrentLocation(); return failure(parser.parseOperandList(opInfo) || parser.parseOptionalAttrDict(result.attributes) || (!opInfo.empty() && parser.parseColonTypeList(types)) || parser.resolveOperands(opInfo, types, loc, result.operands)); } // Check the operand number and types must match the element types of the // LinalgOp interface's shaped operands. static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) { if (op.getNumOperands() != linalgOp.getNumOutputs()) return op.emitOpError("expected number of yield values (") << linalgOp.getNumOutputs() << ") to match the number of operands of the enclosing " << "LinalgOp (" << op.getNumOperands() << ")"; for (OpOperand &opOperand : op->getOpOperands()) { OpOperand *outputOperand = linalgOp.getOutputOperand(opOperand.getOperandNumber()); Type elementType = getElementTypeOrSelf(outputOperand->get().getType()); if (opOperand.get().getType() != elementType) return op.emitOpError("type of yield operand ") << (opOperand.getOperandNumber() + 1) << " (" << opOperand.get().getType() << ") doesn't match " << "the element type of the enclosing linalg.generic op (" << elementType << ")"; } return success(); } static LogicalResult verify(linalg::YieldOp op) { auto *parentOp = op->getParentOp(); if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) return op.emitOpError("expected single non-empty parent region"); if (auto linalgOp = dyn_cast(parentOp)) return verifyYield(op, cast(parentOp)); if (auto padTensorOp = dyn_cast(parentOp)) { if (op.getNumOperands() != 1) return op.emitOpError("expected single yield operand (got ") << op->getNumOperands() << ")"; if (op.getOperand(0).getType() != padTensorOp.getType().cast().getElementType()) return op.emitOpError("expected yield type to match shape element type"); return success(); } if (auto tiledLoopOp = dyn_cast(parentOp)) { // Check if output args with tensor types match results types. SmallVector tensorOuts; llvm::copy_if( tiledLoopOp.outputs(), std::back_inserter(tensorOuts), [&](Value out) { return out.getType().isa(); }); if (tensorOuts.size() != op.values().size()) return op.emitOpError("expected number of tensor output args = ") << tensorOuts.size() << " to match the number of yield operands = " << op.values().size(); TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts)); for (auto &item : llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) { Type outType, resultType; unsigned index = item.index(); std::tie(outType, resultType) = item.value(); if (outType != resultType) return op.emitOpError("expected yield operand ") << index << " with type = " << resultType << " to match output arg type = " << outType; } return success(); } return op.emitOpError("expected parent op with LinalgOp interface"); } //===----------------------------------------------------------------------===// // TiledLoopOp //===----------------------------------------------------------------------===// void TiledLoopOp::build(OpBuilder &builder, OperationState &result, ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps, ValueRange inputs, ValueRange outputs, ArrayAttr iteratorTypes, function_ref bodyBuilderFn) { build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs, iteratorTypes, llvm::None, bodyBuilderFn); } void TiledLoopOp::build(OpBuilder &builder, OperationState &result, ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps, ValueRange inputs, ValueRange outputs, ArrayAttr iteratorTypes, Optional distributionTypes, function_ref bodyBuilderFn) { result.addOperands(lowerBounds); result.addOperands(upperBounds); result.addOperands(steps); result.addOperands(inputs); result.addOperands(outputs); result.addAttribute( TiledLoopOp::getOperandSegmentSizeAttr(), builder.getI32VectorAttr({static_cast(lowerBounds.size()), static_cast(upperBounds.size()), static_cast(steps.size()), static_cast(inputs.size()), static_cast(outputs.size())})); result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); if (distributionTypes.hasValue()) result.addAttribute(getDistributionTypesAttrName(), distributionTypes.getValue()); // Add output types for `RankedTensorType` output arguments. for (Value output : outputs) { Type outputType = output.getType(); if (outputType.isa()) result.addTypes(outputType); } OpBuilder::InsertionGuard guard(builder); unsigned numIVs = steps.size(); SmallVector argTypes(numIVs, builder.getIndexType()); for (Type type : TypeRange(inputs)) argTypes.push_back(type); for (Type type : TypeRange(outputs)) argTypes.push_back(type); Region *bodyRegion = result.addRegion(); Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes); if (bodyBuilderFn) { builder.setInsertionPointToStart(bodyBlock); bodyBuilderFn(builder, result.location, bodyBlock->getArguments().take_front(numIVs), bodyBlock->getArguments().slice(numIVs, inputs.size()), bodyBlock->getArguments().take_back(outputs.size())); TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location); } } static void print(OpAsmPrinter &p, TiledLoopOp op) { p << op.getOperationName() << " (" << op.getInductionVars() << ") = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step() << ")"; if (!op.inputs().empty()) { p << " ins ("; llvm::interleaveComma(llvm::zip(op.getRegionInputArgs(), op.inputs()), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it) << ": " << std::get<1>(it).getType(); }); p << ")"; } if (!op.outputs().empty()) { p << " outs ("; llvm::interleaveComma(llvm::zip(op.getRegionOutputArgs(), op.outputs()), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it) << ": " << std::get<1>(it).getType(); }); p << ")"; } if (llvm::any_of(op.iterator_types(), [](Attribute attr) { return attr.cast().getValue() != getParallelIteratorTypeName(); })) p << " iterators" << op.iterator_types() << ""; if (op.distribution_types().hasValue()) p << " distribution" << op.distribution_types().getValue() << ""; p.printRegion(op.region(), /*printEntryBlockArgs=*/false); p.printOptionalAttrDict( op->getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(), getIteratorTypesAttrName(), getDistributionTypesAttrName()}); } static ParseResult parseTiledLoopOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); // Parse an opening `(` followed by induction variables followed by `)` SmallVector ivs; if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, OpAsmParser::Delimiter::Paren)) return failure(); // Parse loop bounds. SmallVector lower; if (parser.parseEqual() || parser.parseOperandList(lower, ivs.size(), OpAsmParser::Delimiter::Paren) || parser.resolveOperands(lower, builder.getIndexType(), result.operands)) return failure(); SmallVector upper; if (parser.parseKeyword("to") || parser.parseOperandList(upper, ivs.size(), OpAsmParser::Delimiter::Paren) || parser.resolveOperands(upper, builder.getIndexType(), result.operands)) return failure(); // Parse step values. SmallVector steps; if (parser.parseKeyword("step") || parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren) || parser.resolveOperands(steps, builder.getIndexType(), result.operands)) return failure(); // Parse input tensors. SmallVector inputs, input_region_args; SmallVector inputTypes; if (succeeded(parser.parseOptionalKeyword("ins"))) { llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation(); if (parser.parseAssignmentListWithTypes(input_region_args, inputs, inputTypes)) return failure(); if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc, result.operands)) return failure(); } // Parse output tensors. SmallVector outputs, output_region_args; SmallVector outputTypes; if (succeeded(parser.parseOptionalKeyword("outs"))) { llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation(); if (parser.parseAssignmentListWithTypes(output_region_args, outputs, outputTypes)) return failure(); if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc, result.operands)) return failure(); for (Type outputType : outputTypes) if (outputType.isa()) result.addTypes(outputType); } // Parse attributes. SmallVector iterTypes, distributionTypes; auto parseAttr = [&](StringRef keyword, SmallVector *attrs) { if (succeeded(parser.parseOptionalKeyword(keyword))) { StringAttr attr; if (parser.parseLSquare() || parser.parseAttribute(attr)) return failure(); attrs->push_back(attr); for (int i = 1, e = ivs.size(); i < e; ++i) { if (parser.parseComma() || parser.parseAttribute(attr)) return failure(); attrs->push_back(attr); } if (parser.parseRSquare()) return failure(); } return success(); }; if (failed(parseAttr("iterators", &iterTypes)) || failed(parseAttr("distribution", &distributionTypes))) return failure(); // Set all loop iterator types to "parallel" if they are not printed in IR. if (iterTypes.empty()) { auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName()); iterTypes = SmallVector(ivs.size(), parallelIter); } result.addAttribute(getIteratorTypesAttrName(), builder.getArrayAttr(iterTypes)); if (!distributionTypes.empty()) result.addAttribute(getDistributionTypesAttrName(), builder.getArrayAttr(distributionTypes)); result.addAttribute( TiledLoopOp::getOperandSegmentSizeAttr(), builder.getI32VectorAttr({static_cast(lower.size()), static_cast(upper.size()), static_cast(steps.size()), static_cast(inputs.size()), static_cast(outputs.size())})); // Parse the body. Region *body = result.addRegion(); SmallVector region_types(ivs.size(), builder.getIndexType()); region_types.append(inputTypes); region_types.append(outputTypes); SmallVector region_args(ivs); region_args.append(input_region_args); region_args.append(output_region_args); if (parser.parseRegion(*body, region_args, region_types)) return failure(); // Parse optional attributes. parser.parseOptionalAttrDict(result.attributes); return success(); } Region &TiledLoopOp::getLoopBody() { return region(); } LogicalResult TiledLoopOp::moveOutOfLoop(ArrayRef ops) { for (auto *op : ops) op->moveBefore(*this); return success(); } bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) { return !region().isAncestor(value.getParentRegion()); } static LogicalResult verify(TiledLoopOp op) { // Check if iterator types are provided for every loop dimension. if (op.iterator_types().size() != op.getNumLoops()) return op.emitOpError("expected iterator types array attribute size = ") << op.iterator_types().size() << " to match the number of loops = " << op.getNumLoops(); // Check if types of input arguments match region args types. for (auto &item : llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) { Value input, inputRegionArg; unsigned index = item.index(); std::tie(input, inputRegionArg) = item.value(); if (input.getType() != inputRegionArg.getType()) return op.emitOpError("expected input arg ") << index << " with type = " << input.getType() << " to match region arg " << index + op.getNumLoops() << " type = " << inputRegionArg.getType(); } // Check if types of input arguments match region args types. for (auto &item : llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) { Value output, outputRegionArg; unsigned index = item.index(); std::tie(output, outputRegionArg) = item.value(); if (output.getType() != outputRegionArg.getType()) return op.emitOpError("expected output arg ") << index << " with type = " << output.getType() << " to match region arg " << index + op.getNumLoops() + op.inputs().size() << " type = " << outputRegionArg.getType(); } return success(); } namespace { static constexpr int64_t kNoMatch = -1; // Folds away TiledLoopOp inputs if they have no uses within the body. // // Example: // // %0 = linalg.tiled_loop ... ins (%in_ = %in: tensor<...>, // %in_buf_ = %in_buf: memref<...>) {...} // Becomes // // linalg.tiled_loop ... ins (%in_buf_ = %in_buf: memref<...>) {...} struct TiledLoopInputsFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop, PatternRewriter &rewriter) const final { SmallVector newInputs, regionInputTensorArgs; // Store ids of the corresponding old and new input operands. SmallVector oldInputIdToNew(tiledLoop.inputs().size(), kNoMatch); for (auto en : llvm::enumerate( llvm::zip(tiledLoop.inputs(), tiledLoop.getRegionInputArgs()))) { Value in, bbArg; size_t index = en.index(); std::tie(in, bbArg) = en.value(); if (!bbArg.use_empty()) { oldInputIdToNew[index] = newInputs.size(); newInputs.push_back(in); } } if (newInputs.size() == tiledLoop.inputs().size()) return failure(); Location loc = tiledLoop.getLoc(); auto newTiledLoop = rewriter.create( loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(), newInputs, tiledLoop.outputs(), tiledLoop.iterator_types(), tiledLoop.distribution_types()); // Clone the region. BlockAndValueMapping bvm; bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars()); bvm.map(tiledLoop.getRegionOutputArgs(), newTiledLoop.getRegionOutputArgs()); for (const auto &en : llvm::enumerate(oldInputIdToNew)) if (en.value() != kNoMatch) bvm.map(tiledLoop.getRegionInputArgs()[en.index()], newTiledLoop.getRegionInputArgs()[en.value()]); OpBuilder innerBuilder = OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener()); for (auto &op : *tiledLoop.getBody()) innerBuilder.clone(op, bvm); rewriter.replaceOp(tiledLoop, newTiledLoop.getResults()); return success(); } }; // Folds away TiledLoopOp output tensors when the following conditions are met: // * result of `linalg.tiled_loop` has no uses // * output tensor is the argument of `linalg.yield` // // Example: // // %0 = linalg.tiled_loop ... outs (%o_ = %out: tensor<...>, // %obuf_ = %out_buf: memref<...>) { // ... // linalg.yield %o_ : tensor ... // } // // Becomes // // linalg.tiled_loop ... outs (%obuf_ = %out_buf: memref<...>) { // ... // linalg.yield // } struct TiledLoopResultsFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop, PatternRewriter &rewriter) const final { if (tiledLoop.getNumResults() == 0) return failure(); Block *block = tiledLoop.getBody(); auto yieldOp = cast(block->getTerminator()); // Match the pattern and collect output buffers that will replace the output // tensors and also the ops that will be ignored when cloning the body. SmallVector newOutputOperands, newYieldArgs; int resultId = 0; // Store ids of the corresponding old and new output operands. SmallVector oldOutputIdToNew(tiledLoop.outputs().size(), kNoMatch); // Store ids of the corresponding old and new results. SmallVector oldResultIdToNew(tiledLoop.getNumResults(), kNoMatch); SmallVector resultReplacement(tiledLoop.getNumResults()); for (auto en : llvm::enumerate( llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) { size_t index = en.index(); Value out = std::get<0>(en.value()); Value outRegionArg = std::get<1>(en.value()); if (!out.getType().isa()) { oldOutputIdToNew[index] = newOutputOperands.size(); newOutputOperands.push_back(out); continue; } Value result = tiledLoop.getResult(resultId); Value yieldArg = yieldOp.getOperand(resultId); if (yieldArg != outRegionArg || !result.use_empty()) { oldOutputIdToNew[index] = newOutputOperands.size(); oldResultIdToNew[resultId] = newYieldArgs.size(); resultReplacement[resultId] = out; newOutputOperands.push_back(out); newYieldArgs.push_back(yieldArg); } ++resultId; } if (newOutputOperands.size() == tiledLoop.outputs().size()) return failure(); Location loc = tiledLoop.getLoc(); auto newTiledLoop = rewriter.create( loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(), tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types(), tiledLoop.distribution_types()); // Clone the region. BlockAndValueMapping bvm; bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars()); bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs()); for (const auto &en : llvm::enumerate(oldOutputIdToNew)) { if (en.value() != kNoMatch) bvm.map(tiledLoop.getRegionOutputArgs()[en.index()], newTiledLoop.getRegionOutputArgs()[en.value()]); else bvm.map(tiledLoop.getRegionOutputArgs()[en.index()], tiledLoop.outputs()[en.index()]); } OpBuilder innerBuilder = OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener()); for (auto &op : tiledLoop.getBody()->without_terminator()) innerBuilder.clone(op, bvm); innerBuilder.create( loc, llvm::to_vector<2>(llvm::map_range( newYieldArgs, [&](Value arg) { return bvm.lookup(arg); }))); for (const auto &en : llvm::enumerate(oldResultIdToNew)) if (en.value() != kNoMatch) resultReplacement[en.index()] = newTiledLoop.getResult(en.value()); rewriter.replaceOp(tiledLoop, resultReplacement); return success(); } }; } // namespace void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } LogicalResult TiledLoopOp::fold(ArrayRef, SmallVectorImpl &) { return foldMemRefCastInTiledLoopOp(*this); } //===----------------------------------------------------------------------===// // IndexOp //===----------------------------------------------------------------------===// static LogicalResult verify(IndexOp op) { auto linalgOp = dyn_cast(op->getParentOp()); if (!linalgOp) return op.emitOpError("expected parent op with LinalgOp interface"); if (linalgOp.getNumLoops() <= op.dim()) return op.emitOpError("expected dim (") << op.dim() << ") to be lower than the number of loops (" << linalgOp.getNumLoops() << ") of the enclosing LinalgOp"; return success(); } /////// Operations corresponding to library calls defined with Tablegen //////// template static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op, ArrayRef attrs, bool isStride) { auto strideOrDilation = isStride ? "stride" : "dilation"; if (attrs.size() != op.getNumWindowLoops()) return op.emitOpError("expects num ") << strideOrDilation << "s equal to number of window dimensions: " << attrs.size() << " vs " << op.getNumWindowLoops(); return success(); } void ConvOp::getEffects( SmallVectorImpl> &effects) { effects.emplace_back(MemoryEffects::Read::get(), input(), SideEffects::DefaultResource::get()); effects.emplace_back(MemoryEffects::Read::get(), filter(), SideEffects::DefaultResource::get()); effects.emplace_back(MemoryEffects::Write::get(), output(), SideEffects::DefaultResource::get()); } static LogicalResult verify(ConvOp op) { auto oType = op.output().getType().cast(); auto fType = op.filter().getType().cast(); auto iType = op.input().getType().cast(); if (oType.getElementType() != iType.getElementType() || oType.getElementType() != fType.getElementType()) return op.emitOpError("expects memref elemental types to match"); if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank()) return op.emitOpError("expects memref ranks to match"); if (auto strides = op.strides()) { if (failed(verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true))) return failure(); } if (auto dilations = op.dilations()) { if (failed(verifyStrideOrDilation(op, dilations->getValue(), /*isStride=*/false))) return failure(); } return success(); } template static LogicalResult verifySingleInputPoolingOp(PoolingOp op) { auto inputType = op.input().getType().template cast(); auto outputType = op.output().getType().template cast(); if (outputType.getElementType() != inputType.getElementType()) return op.emitOpError("expects memref elemental types to match"); auto windowDimsType = op.windowDims().getType().template cast(); if (outputType.getRank() != inputType.getRank() || outputType.getRank() != windowDimsType.getRank()) return op.emitOpError("expects memref ranks to match"); if (auto strides = op.strides()) { if (failed(verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true))) return failure(); } if (auto dilations = op.dilations()) { if (failed(verifyStrideOrDilation(op, dilations->getValue(), /*isStride=*/false))) return failure(); } return success(); } #define DEFINE_POOLING_OP_GET_EFFECTS(OP_NAME) \ void OP_NAME::getEffects( \ SmallVectorImpl> \ &effects) { \ effects.emplace_back(MemoryEffects::Read::get(), input(), \ SideEffects::DefaultResource::get()); \ effects.emplace_back(MemoryEffects::Write::get(), output(), \ SideEffects::DefaultResource::get()); \ } static LogicalResult verify(PoolingMaxOp op) { return verifySingleInputPoolingOp(op); } static LogicalResult verify(PoolingMinOp op) { return verifySingleInputPoolingOp(op); } static LogicalResult verify(PoolingSumOp op) { return verifySingleInputPoolingOp(op); } DEFINE_POOLING_OP_GET_EFFECTS(PoolingMaxOp) DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp) DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp) #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc" #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`. /// Assumes `op` is a LinalgOp. void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName, SmallVectorImpl &res) { if (!cast(op).iterator_types()) return; unsigned dim = 0; MLIRContext *ctx = op->getContext(); for (auto tn : cast(op).iterator_types().getAsValueRange()) { if (tn == iteratorTypeName) res.push_back(getAffineDimExpr(dim, ctx)); ++dim; } } AffineMap mlir::linalg::extractOrIdentityMap(Optional maybeMap, unsigned rank, MLIRContext *context) { if (maybeMap) return maybeMap.getValue(); if (rank == 0) return AffineMap::get(context); return AffineMap::getMultiDimIdentityMap(rank, context); } SmallVector mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context) { SmallVector res; res.reserve(num); for (unsigned i = 0; i < num; ++i) res.push_back(getAffineDimExpr(startIdx++, context)); return res; } template SmallVector mlir::linalg::weightedPoolingInputIndex(PoolingOp op, ArrayRef outputDims, ArrayRef windowDims) { assert(outputDims.size() == windowDims.size()); SmallVector res; res.reserve(outputDims.size()); for (unsigned i = 0, e = outputDims.size(); i < e; ++i) { // TODO: add a level of indirection to linalg.generic. auto expr = op.getStride(i) * outputDims[i] + op.getDilation(i) * windowDims[i] - op.getLowPad(i); res.push_back(expr); } return res; } #define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE) \ template SmallVector \ mlir::linalg::weightedPoolingInputIndex( \ OP_TYPE op, ArrayRef outputDims, \ ArrayRef windowDims); INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(ConvOp) INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMaxOp) INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMinOp) INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingSumOp) SmallVector mlir::linalg::concat(ArrayRef a, ArrayRef b) { auto rangeA = llvm::make_range(a.begin(), a.end()); auto rangeB = llvm::make_range(b.begin(), b.end()); auto concatRanges = llvm::concat(rangeA, rangeB); return llvm::to_vector<4>(concatRanges); } static void appendMangledType(llvm::raw_string_ostream &ss, Type t) { if (auto memref = t.dyn_cast()) { ss << "view"; for (auto size : memref.getShape()) if (size < 0) ss << "sx"; else ss << size << "x"; appendMangledType(ss, memref.getElementType()); } else if (auto vec = t.dyn_cast()) { ss << "vector"; llvm::interleave( vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; }); appendMangledType(ss, vec.getElementType()); } else if (t.isSignlessIntOrIndexOrFloat()) { ss << t; } else { llvm_unreachable("Invalid type for linalg library name mangling"); } } std::string mlir::linalg::generateLibraryCallName(Operation *op) { assert(isa(op)); std::string name(op->getName().getStringRef().str()); name.reserve(128); std::replace(name.begin(), name.end(), '.', '_'); llvm::raw_string_ostream ss(name); ss << "_"; auto types = op->getOperandTypes(); llvm::interleave( types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); }, [&]() { ss << "_"; }); return ss.str(); } // TODO: Consider making all this boilerplate easy to autogenerate // with Tablegen. This seems a desirable property in the context of // OpInterfaces where a Linalg "named" op **isa** LinalgOp. OpFoldResult ExpandShapeOp::fold(ArrayRef operands) { if (succeeded(foldMemRefCast(*this))) return getResult(); return foldReshapeOp(*this, operands); } OpFoldResult CollapseShapeOp::fold(ArrayRef operands) { if (succeeded(foldMemRefCast(*this))) return getResult(); return foldReshapeOp(*this, operands); } OpFoldResult TensorExpandShapeOp::fold(ArrayRef operands) { return foldReshapeOp(*this, operands); } OpFoldResult TensorCollapseShapeOp::fold(ArrayRef operands) { return foldReshapeOp(*this, operands); } //===----------------------------------------------------------------------===// // Support for named Linalg ops defined in ods-gen. //===----------------------------------------------------------------------===// /// Generic entry point to create the block for the region of a LinalgOp. /// This is used by both named structured ops created by ods-gen and by manually /// defined C++ ops. /// This is used by both builders and parsers. /// This function creates the block in the region with arguments corresponding /// to the elemental types of `inputTypes` and `outputTypes`, which are asserted /// to be ShapedType. template static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, std::function errorHandler) { assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); // TODO: atm all operands go through getElementTypeOrSelf, // reconsider when we have evidence we need to. SmallVector argTypes; for (auto containers : {inputTypes, outputTypes}) for (auto t : containers) argTypes.push_back(getElementTypeOrSelf(t)); // RAII. OpBuilder::InsertionGuard guard(opBuilder); Block *body = opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes); unsigned actual = body->getNumArguments(); unsigned expected = NamedStructuredOpType::getNumRegionArgs(); if (expected != actual) { if (errorHandler) errorHandler(expected, actual); return; } opBuilder.setInsertionPointToStart(body); ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); NamedStructuredOpType::regionBuilder(b, *body); // indexing_maps is an auto-generated method. // iterator_types is an auto-generated method. } /// Generic entry point to create both the region and the block of a LinalgOp. template void createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes, TypeRange outputTypes) { Region ®ion = *result.addRegion(); fillStructuredOpRegion( opBuilder, region, inputTypes, outputTypes, [&](unsigned expected, unsigned actual) { assert(expected != actual && "incorrect number of arguments"); }); } /// Common parsing used for both named structured ops created by ods-gen and by /// manually defined C++ ops. Does not handle regions. static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl &inputTypes, SmallVectorImpl &outputTypes) { llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc; SmallVector inputsOperands, outputsOperands; parser.parseOptionalAttrDict(result.attributes); if (succeeded(parser.parseOptionalKeyword("ins"))) { if (parser.parseLParen()) return failure(); inputsOperandsLoc = parser.getCurrentLocation(); if (parser.parseOperandList(inputsOperands) || parser.parseColonTypeList(inputTypes) || parser.parseRParen()) return failure(); } if (succeeded(parser.parseOptionalKeyword("outs"))) { outputsOperandsLoc = parser.getCurrentLocation(); if (parser.parseLParen() || parser.parseOperandList(outputsOperands) || parser.parseColonTypeList(outputTypes) || parser.parseRParen()) return failure(); } if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, result.operands) || parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc, result.operands)) return failure(); result.addAttribute("operand_segment_sizes", parser.getBuilder().getI32VectorAttr( {static_cast(inputsOperands.size()), static_cast(outputsOperands.size())})); return success(); } template static void printCommonStructuredOpParts(OpAsmPrinter &p, NamedStructuredOpType op) { if (!op.inputs().empty()) p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; if (!op.outputs().empty()) p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")"; } //===----------------------------------------------------------------------===// // Specific parsing and printing for named structured ops created by ods-gen. //===----------------------------------------------------------------------===// template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, TypeRange inputTypes, TypeRange outputTypes) { ParseResult res = success(); OpBuilder opBuilder(parser.getBuilder().getContext()); // Resolve `captures` into `capturedValues` at parse time so we can build the // region with captures. SmallVector capturedValues; fillStructuredOpRegion( opBuilder, region, inputTypes, outputTypes, [&](unsigned expected, unsigned actual) { res = parser.emitError( parser.getCurrentLocation(), llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated " "region expects {0} args, got {1}", expected, actual)); region.front().dump(); }); return res; } static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl &resultTypes) { if (parser.parseOptionalArrowTypeList(resultTypes)) return failure(); return success(); } template static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result) { // TODO: Enable when ods-gen supports captures. SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) return failure(); // TODO: consider merging results parsing into region parsing. // Need to wait for declarative assembly resolution to decide. SmallVector outputTensorsTypes; if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) return failure(); result.addTypes(outputTensorsTypes); std::unique_ptr region = std::make_unique(); if (parseNamedStructuredOpRegion( parser, *region, inputTypes, outputTypes)) return failure(); result.addRegion(std::move(region)); return success(); } static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes) { if (resultTypes.empty()) return; p.printOptionalArrowTypeList(resultTypes); } template static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { p << op.getOperationName(); p.printOptionalAttrDict( op->getAttrs(), /*elidedAttrs=*/{"operand_segment_sizes", // See generated code in mlir-linalg-yaml-gen.cpp "linalg.memoized_indexing_maps"}); // Printing is shared with generic ops, except for the region and // attributes. printCommonStructuredOpParts(p, op); // Results printing. printNamedStructuredOpResults(p, op.result_tensors().getTypes()); // Region is elided. } template static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) { return verifyGenericOp(op); } //===----------------------------------------------------------------------===// // Canonicalizers and Folders. //===----------------------------------------------------------------------===// namespace { struct EraseDeadLinalgOp : public OpInterfaceRewritePattern { using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { for (OpOperand *opOperand : op.getInputAndOutputOperands()) { // Linalg "inputs" may be either tensor or memref type. // tensor<0xelt_type> is a convention that may not always mean // "0 iterations". Only erase in cases we see memref<...x0x...>. auto mt = opOperand->get().getType().dyn_cast(); if (!mt) continue; if (llvm::is_contained(op.getShape(opOperand), 0)) { rewriter.eraseOp(op); return success(); } } return failure(); } }; struct FoldTensorCastOp : public OpInterfaceRewritePattern { using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { // If no operand comes from a tensor::CastOp and can be folded then fail. bool hasTensorCastOperand = llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) { if (opOperand->get().isa()) return false; auto castOp = opOperand->get().getDefiningOp(); return castOp && canFoldIntoConsumerOp(castOp); }); if (!hasTensorCastOperand) return failure(); SmallVector newResultTypes; newResultTypes.reserve(op->getNumResults()); SmallVector newOperands; newOperands.reserve(op->getNumOperands()); // Inputs may fold. for (OpOperand *opOperand : op.getInputOperands()) { auto tensorCastOp = opOperand->get().getDefiningOp(); newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : opOperand->get()); } // Init tensors may fold, in which case the resultType must also change. for (OpOperand *opOperand : op.getOutputOperands()) { auto tensorCastOp = opOperand->get().getDefiningOp(); bool fold = canFoldIntoConsumerOp(tensorCastOp); newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand->get()); newResultTypes.push_back(newOperands.back().getType()); } // Clone op. Operation *newOp = op.clone(rewriter, op->getLoc(), newResultTypes, newOperands); SmallVector replacements; replacements.reserve(newOp->getNumResults()); for (auto result : llvm::zip(op->getResults(), newOp->getResults())) { Value oldResult = std::get<0>(result); Value newResult = std::get<1>(result); if (newResult.getType() != oldResult.getType()) { replacements.push_back(rewriter.create( op->getLoc(), oldResult.getType(), newResult)); } else { replacements.push_back(newResult); } } rewriter.replaceOp(op, replacements); return success(); } }; } // namespace namespace { // Deduplicate redundant args of a linalg op. // An arg is redundant if it has the same Value and indexing map as another. struct DeduplicateInputs : public OpInterfaceRewritePattern { using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { // This pattern reduces the number of arguments of an op, which breaks // the invariants of semantically charged named ops. if (!isa(op)) return failure(); // Associate each input to an equivalent "canonical" input that has the same // Value and indexing map. // // In the non-duplicate case, input `i` will have canonical input `i`. But // in the case of duplicated inputs, the canonical input could be some other // input `< i`. That is, a later input will have some earlier input as its // canonical input. llvm::SmallDenseMap, unsigned> canonicalInput; // For later remapping tasks like deduplicating payload block arguments, // having a simple "inputIndex -> canonicalInputIndex" integer mapping is // convenient. SmallVector canonicalInputIndices; for (OpOperand *opOperand : op.getInputOperands()) { AffineMap indexingMap = op.getTiedIndexingMap(opOperand); // STL-like maps have a convenient behavior for our use case here. In the // case of duplicate keys, the insertion is rejected, and the returned // iterator gives access to the value already in the map. auto pair = canonicalInput.insert( {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()}); canonicalInputIndices.push_back(pair.first->second); } // If there are no duplicate args, then bail out. if (canonicalInput.size() == op.getNumInputs()) return failure(); // The operands for the newly canonicalized op. SmallVector newOperands; for (OpOperand *opOperand : op.getInputOperands()) if (canonicalInputIndices[opOperand->getOperandNumber()] == opOperand->getOperandNumber()) newOperands.push_back(opOperand->get()); SmallVector outputOperands = op.getOutputOperands(); llvm::append_range(newOperands, outputOperands); // Repair the indexing maps by filtering out the ones that have been // eliminated. SmallVector newIndexingMaps; for (OpOperand *opOperand : op.getInputOperands()) if (canonicalInputIndices[opOperand->getOperandNumber()] == opOperand->getOperandNumber()) newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand)); for (OpOperand *opOperand : op.getOutputOperands()) newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand)); // Clone the old op with new operands. Operation *newOp = op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands); auto newLinalgOp = cast(newOp); newOp->setAttr("indexing_maps", rewriter.getAffineMapArrayAttr(newIndexingMaps)); // Set the number of inputs to the new value. The `clone` call above kept // the value from the original op. newLinalgOp.setNumInputs(canonicalInput.size()); // Repair the payload entry block by RAUW'ing redundant arguments and // erasing them. Block &payload = newOp->getRegion(0).front(); SmallVector inputOperands = op.getInputOperands(); for (OpOperand *opOperand : llvm::reverse(inputOperands)) { // Iterate in reverse, so that we erase later args first, preventing the // argument list from shifting unexpectedly and invalidating all our // indices. unsigned operandNumber = opOperand->getOperandNumber(); if (canonicalInputIndices[operandNumber] == operandNumber) continue; payload.getArgument(operandNumber) .replaceAllUsesWith( payload.getArgument(canonicalInputIndices[operandNumber])); payload.eraseArgument(operandNumber); } rewriter.replaceOp(op, newOp->getResults()); return success(); } }; /// Remove generic operations (on tensors) that are just copying /// the values from inputs to the results. Requirements are /// 1) All iterator types are parallel /// 2) The body contains just a yield operation with the yielded values being /// the arguments corresponding to the operands. struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern { using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { if (auto copyOp = dyn_cast(*op)) { assert(copyOp.hasBufferSemantics()); if (copyOp.input() == copyOp.output() && copyOp.inputPermutation() == copyOp.outputPermutation()) { rewriter.eraseOp(op); return success(); } } if (!isa(op)) return failure(); if (!op.hasTensorSemantics()) return failure(); // Check all indexing maps are identity. if (llvm::any_of(op.getIndexingMaps(), [](AffineMap map) { return !map.isIdentity(); })) return failure(); // Check that the body of the linalg operation is just a linalg.yield // operation. Block &body = op->getRegion(0).front(); if (!llvm::hasSingleElement(body)) return failure(); auto yieldOp = dyn_cast(body.getTerminator()); if (!yieldOp) return failure(); // Get the argument number of the returned values. That is the operand // number to use for replacing uses of this operation. SmallVector returnedArgs; for (Value yieldVal : yieldOp.values()) { auto yieldArg = yieldVal.dyn_cast(); if (!yieldArg || yieldArg.getOwner() != &body) return failure(); unsigned argumentNumber = yieldArg.getArgNumber(); returnedArgs.push_back(op->getOperand(argumentNumber)); } if (returnedArgs.size() != op.getOperation()->getNumResults()) return failure(); rewriter.replaceOp(op, returnedArgs); return success(); } }; } // namespace #define LINALGOP_FOLDERS(XXX) \ LogicalResult XXX::fold(ArrayRef, \ SmallVectorImpl &) { \ return foldMemRefCast(*this); \ } LINALGOP_FOLDERS(ConvOp) LINALGOP_FOLDERS(PoolingMaxOp) LINALGOP_FOLDERS(PoolingMinOp) LINALGOP_FOLDERS(PoolingSumOp) LINALGOP_FOLDERS(CopyOp) LINALGOP_FOLDERS(FillOp) LINALGOP_FOLDERS(GenericOp) // All named ops canonicalizers and folders are auto-generated in the // .cpp.inc. //===----------------------------------------------------------------------===// // LinalgDialect //===----------------------------------------------------------------------===// void LinalgDialect::getCanonicalizationPatterns( RewritePatternSet &results) const { results.add(getContext()); }