1 //===- Dialect.cpp - Implementation of the linalg dialect and types -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the Linalg dialect types and dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/FunctionSupport.h" 19 #include "mlir/Parser.h" 20 #include "mlir/Support/LLVM.h" 21 #include "mlir/Transforms/InliningUtils.h" 22 23 #include "llvm/ADT/StringExtras.h" 24 #include "llvm/Support/raw_ostream.h" 25 26 using namespace mlir; 27 using namespace mlir::linalg; 28 29 //===----------------------------------------------------------------------===// 30 // LinalgDialect Dialect Interfaces 31 //===----------------------------------------------------------------------===// 32 33 namespace { 34 35 struct LinalgInlinerInterface : public DialectInlinerInterface { 36 using DialectInlinerInterface::DialectInlinerInterface; 37 38 // We don't have any special restrictions on what can be inlined into 39 // destination regions (e.g. while/conditional bodies). Always allow it. 40 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 41 BlockAndValueMapping &valueMapping) const final { 42 return true; 43 } 44 // Operations in Linalg dialect are always legal to inline. 45 bool isLegalToInline(Operation *, Region *, bool, 46 BlockAndValueMapping &) const final { 47 return true; 48 } 49 // Handle the given inlined terminator by replacing it with a new operation 50 // as necessary. Required when the region has only one block. 51 void handleTerminator(Operation *op, 52 ArrayRef<Value> valuesToRepl) const final {} 53 }; 54 55 } // namespace 56 57 //===----------------------------------------------------------------------===// 58 // LinalgDialect 59 //===----------------------------------------------------------------------===// 60 61 /// Attribute name used to to memoize indexing maps for named ops. 62 constexpr const ::llvm::StringLiteral 63 LinalgDialect::kMemoizedIndexingMapsAttrName; 64 65 /// Attribute name used to mark the bufferization layout for region 66 /// arguments during linalg comprehensive bufferization. 67 constexpr const ::llvm::StringLiteral 68 comprehensive_bufferize::BufferizableOpInterface::kBufferLayoutAttrName; 69 70 /// Attribute name used to mark region arguments that can be bufferized 71 /// in-place during linalg comprehensive bufferization. 72 constexpr const ::llvm::StringLiteral 73 comprehensive_bufferize::BufferizableOpInterface::kInplaceableAttrName; 74 75 /// Trait to check if T provides a `regionBuilder` method. 76 template <typename T, typename... Args> 77 using has_region_builder = decltype(T::regionBuilder); 78 template <typename T> 79 using detect_has_region_builder = llvm::is_detected<has_region_builder, T>; 80 81 /// SFINAE helper for single C++ class without a `regionBuilder` method (e.g. 82 /// an OpInterface). 83 template <typename OpType, typename = std::enable_if_t< 84 !detect_has_region_builder<OpType>::value>> 85 void addNamedOpBuilderImpl( 86 llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) { 87 // Do nothing. 88 } 89 90 template <typename OpType, 91 typename = std::enable_if_t<detect_has_region_builder<OpType>::value>, 92 typename = void> 93 void addNamedOpBuilderImpl( 94 llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) { 95 map.insert(std::make_pair( 96 OpType::getOperationName(), 97 static_cast<LinalgDialect::RegionBuilderFunType>(OpType::regionBuilder))); 98 } 99 100 template <typename... OpTypes> 101 void addNamedOpBuilders( 102 llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) { 103 (void)std::initializer_list<int>{0, 104 (addNamedOpBuilderImpl<OpTypes>(map), 0)...}; 105 } 106 107 void mlir::linalg::LinalgDialect::initialize() { 108 addOperations< 109 #define GET_OP_LIST 110 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" 111 >(); 112 addOperations< 113 #define GET_OP_LIST 114 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 115 >(); 116 117 // Fill the Linalg-specific OpName to RegionBuilder map. 118 addNamedOpBuilders< 119 #define GET_OP_LIST 120 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 121 >(namedStructuredOpRegionBuilders); 122 123 addInterfaces<LinalgInlinerInterface>(); 124 } 125 126 LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op, 127 NamedAttribute attr) { 128 using comprehensive_bufferize::BufferizableOpInterface; 129 130 if (attr.getName() == BufferizableOpInterface::kInplaceableAttrName) { 131 if (!attr.getValue().isa<BoolAttr>()) { 132 return op->emitError() 133 << "'" << BufferizableOpInterface::kInplaceableAttrName 134 << "' is expected to be a boolean attribute"; 135 } 136 if (!op->hasTrait<OpTrait::FunctionLike>()) 137 return op->emitError() << "expected " << attr.getName() 138 << " to be used on function-like operations"; 139 return success(); 140 } 141 if (attr.getName() == BufferizableOpInterface::kBufferLayoutAttrName) { 142 if (!attr.getValue().isa<AffineMapAttr>()) { 143 return op->emitError() 144 << "'" << BufferizableOpInterface::kBufferLayoutAttrName 145 << "' is expected to be a affine map attribute"; 146 } 147 if (!op->hasTrait<OpTrait::FunctionLike>()) 148 return op->emitError() << "expected " << attr.getName() 149 << " to be used on function-like operations"; 150 return success(); 151 } 152 if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName) 153 return success(); 154 return op->emitError() << "attribute '" << attr.getName() 155 << "' not supported by the linalg dialect"; 156 } 157