1 //===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines utilities for generating MLIR. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_ 14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_ 15 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Complex/IR/Complex.h" 18 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 19 #include "mlir/ExecutionEngine/SparseTensorUtils.h" 20 #include "mlir/IR/Builders.h" 21 22 namespace mlir { 23 class Location; 24 class Type; 25 class Value; 26 27 namespace sparse_tensor { 28 29 //===----------------------------------------------------------------------===// 30 // ExecutionEngine/SparseTensorUtils helper functions. 31 //===----------------------------------------------------------------------===// 32 33 /// Converts an overhead storage bitwidth to its internal type-encoding. 34 OverheadType overheadTypeEncoding(unsigned width); 35 36 /// Converts an overhead storage type to its internal type-encoding. 37 OverheadType overheadTypeEncoding(Type tp); 38 39 /// Converts the internal type-encoding for overhead storage to an mlir::Type. 40 Type getOverheadType(Builder &builder, OverheadType ot); 41 42 /// Returns the OverheadType for pointer overhead storage. 43 OverheadType pointerOverheadTypeEncoding(const SparseTensorEncodingAttr &enc); 44 45 /// Returns the OverheadType for index overhead storage. 46 OverheadType indexOverheadTypeEncoding(const SparseTensorEncodingAttr &enc); 47 48 /// Returns the mlir::Type for pointer overhead storage. 49 Type getPointerOverheadType(Builder &builder, 50 const SparseTensorEncodingAttr &enc); 51 52 /// Returns the mlir::Type for index overhead storage. 53 Type getIndexOverheadType(Builder &builder, 54 const SparseTensorEncodingAttr &enc); 55 56 /// Convert OverheadType to its function-name suffix. 57 StringRef overheadTypeFunctionSuffix(OverheadType ot); 58 59 /// Converts an overhead storage type to its function-name suffix. 60 StringRef overheadTypeFunctionSuffix(Type overheadTp); 61 62 /// Converts a primary storage type to its internal type-encoding. 63 PrimaryType primaryTypeEncoding(Type elemTp); 64 65 /// Convert PrimaryType to its function-name suffix. 66 StringRef primaryTypeFunctionSuffix(PrimaryType pt); 67 68 /// Converts a primary storage type to its function-name suffix. 69 StringRef primaryTypeFunctionSuffix(Type elemTp); 70 71 /// Converts the IR's dimension level type to its internal type-encoding. 72 DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt); 73 74 //===----------------------------------------------------------------------===// 75 // Misc code generators. 76 // 77 // TODO: both of these should move upstream to their respective classes. 78 // Once RFCs have been created for those changes, list them here. 79 //===----------------------------------------------------------------------===// 80 81 /// Generates a 1-valued attribute of the given type. This supports 82 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`, 83 /// for unsupported types we raise `llvm_unreachable` rather than 84 /// returning a null attribute. 85 Attribute getOneAttr(Builder &builder, Type tp); 86 87 /// Generates the comparison `v != 0` where `v` is of numeric type. 88 /// For floating types, we use the "unordered" comparator (i.e., returns 89 /// true if `v` is NaN). 90 Value genIsNonzero(OpBuilder &builder, Location loc, Value v); 91 92 //===----------------------------------------------------------------------===// 93 // Constant generators. 94 // 95 // All these functions are just wrappers to improve code legibility; 96 // therefore, we mark them as `inline` to avoid introducing any additional 97 // overhead due to the legibility. 98 // 99 // TODO: Ideally these should move upstream, so that we don't 100 // develop a design island. However, doing so will involve 101 // substantial design work. For related prior discussion, see 102 // <https://llvm.discourse.group/t/evolving-builder-apis-based-on-lessons-learned-from-edsc/879> 103 //===----------------------------------------------------------------------===// 104 105 /// Generates a 0-valued constant of the given type. In addition to 106 /// the scalar types (`ComplexType`, ``FloatType`, `IndexType`, `IntegerType`), 107 /// this also works for `RankedTensorType` and `VectorType` (for which it 108 /// generates a constant `DenseElementsAttr` of zeros). 109 inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { 110 if (auto ctp = tp.dyn_cast<ComplexType>()) { 111 auto zeroe = builder.getZeroAttr(ctp.getElementType()); 112 auto zeroa = builder.getArrayAttr({zeroe, zeroe}); 113 return builder.create<complex::ConstantOp>(loc, tp, zeroa); 114 } 115 return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp)); 116 } 117 118 /// Generates a 1-valued constant of the given type. This supports all 119 /// the same types as `constantZero`. 120 inline Value constantOne(OpBuilder &builder, Location loc, Type tp) { 121 if (auto ctp = tp.dyn_cast<ComplexType>()) { 122 auto zeroe = builder.getZeroAttr(ctp.getElementType()); 123 auto onee = getOneAttr(builder, ctp.getElementType()); 124 auto zeroa = builder.getArrayAttr({onee, zeroe}); 125 return builder.create<complex::ConstantOp>(loc, tp, zeroa); 126 } 127 return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp)); 128 } 129 130 /// Generates a constant of `index` type. 131 inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) { 132 return builder.create<arith::ConstantIndexOp>(loc, i); 133 } 134 135 /// Generates a constant of `i32` type. 136 inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) { 137 return builder.create<arith::ConstantIntOp>(loc, i, 32); 138 } 139 140 /// Generates a constant of `i16` type. 141 inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) { 142 return builder.create<arith::ConstantIntOp>(loc, i, 16); 143 } 144 145 /// Generates a constant of `i8` type. 146 inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) { 147 return builder.create<arith::ConstantIntOp>(loc, i, 8); 148 } 149 150 /// Generates a constant of `i1` type. 151 inline Value constantI1(OpBuilder &builder, Location loc, bool b) { 152 return builder.create<arith::ConstantIntOp>(loc, b, 1); 153 } 154 155 /// Generates a constant of the given `Action`. 156 inline Value constantAction(OpBuilder &builder, Location loc, Action action) { 157 return constantI32(builder, loc, static_cast<uint32_t>(action)); 158 } 159 160 /// Generates a constant of the internal type-encoding for overhead storage. 161 inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc, 162 unsigned width) { 163 return constantI32(builder, loc, 164 static_cast<uint32_t>(overheadTypeEncoding(width))); 165 } 166 167 /// Generates a constant of the internal type-encoding for pointer 168 /// overhead storage. 169 inline Value constantPointerTypeEncoding(OpBuilder &builder, Location loc, 170 const SparseTensorEncodingAttr &enc) { 171 return constantOverheadTypeEncoding(builder, loc, enc.getPointerBitWidth()); 172 } 173 174 /// Generates a constant of the internal type-encoding for index overhead 175 /// storage. 176 inline Value constantIndexTypeEncoding(OpBuilder &builder, Location loc, 177 const SparseTensorEncodingAttr &enc) { 178 return constantOverheadTypeEncoding(builder, loc, enc.getIndexBitWidth()); 179 } 180 181 /// Generates a constant of the internal type-encoding for primary storage. 182 inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, 183 Type elemTp) { 184 return constantI32(builder, loc, 185 static_cast<uint32_t>(primaryTypeEncoding(elemTp))); 186 } 187 188 /// Generates a constant of the internal dimension level type encoding. 189 inline Value 190 constantDimLevelTypeEncoding(OpBuilder &builder, Location loc, 191 SparseTensorEncodingAttr::DimLevelType dlt) { 192 return constantI8(builder, loc, 193 static_cast<uint8_t>(dimLevelTypeEncoding(dlt))); 194 } 195 196 } // namespace sparse_tensor 197 } // namespace mlir 198 199 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_ 200