//===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities for generating MLIR.  It declares
// conversions between IR-level types/attributes and the type-encodings
// used by ExecutionEngine/SparseTensorUtils, a couple of miscellaneous
// code generators, and thin inline wrappers for building constants.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/IR/Builders.h"

namespace mlir {
// Forward declarations of the core IR classes that appear in the
// signatures below.
class Location;
class Type;
class Value;

namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// ExecutionEngine/SparseTensorUtils helper functions.
//
// Only declarations appear in this section; the definitions are not part
// of this header.
//===----------------------------------------------------------------------===//

/// Converts an overhead storage bitwidth to its internal type-encoding.
/// The inline constant generators later in this header call this with
/// the pointer/index bitwidth of a `SparseTensorEncodingAttr`.
OverheadType overheadTypeEncoding(unsigned width);

/// Converts an overhead storage type to its internal type-encoding.
/// This is the overload for callers that hold an `mlir::Type` rather
/// than a bitwidth.
OverheadType overheadTypeEncoding(Type tp);

/// Converts the internal type-encoding for overhead storage to an
/// mlir::Type, using `builder` to obtain the type.
Type getOverheadType(Builder &builder, OverheadType ot);

/// Returns the OverheadType for the pointer overhead storage of the
/// sparse tensor encoding `enc`.
OverheadType pointerOverheadTypeEncoding(const SparseTensorEncodingAttr &enc);

/// Returns the OverheadType for the index overhead storage of the
/// sparse tensor encoding `enc`.
OverheadType indexOverheadTypeEncoding(const SparseTensorEncodingAttr &enc);

/// Returns the mlir::Type for pointer overhead storage.
Type getPointerOverheadType(Builder &builder,
                            const SparseTensorEncodingAttr &enc);

/// Returns the mlir::Type for the index overhead storage of the sparse
/// tensor encoding `enc`.
Type getIndexOverheadType(Builder &builder,
                          const SparseTensorEncodingAttr &enc);

/// Converts an OverheadType to its function-name suffix.
StringRef overheadTypeFunctionSuffix(OverheadType ot);

/// Converts an overhead storage type to its function-name suffix.
/// This is the overload for callers that hold an `mlir::Type` rather
/// than an `OverheadType`.
StringRef overheadTypeFunctionSuffix(Type overheadTp);

/// Converts a primary storage type to its internal type-encoding.
/// The parameter name suggests `elemTp` is the tensor's element type
/// (TODO confirm against the definition).
PrimaryType primaryTypeEncoding(Type elemTp);

/// Converts a PrimaryType to its function-name suffix.
StringRef primaryTypeFunctionSuffix(PrimaryType pt);

/// Converts a primary storage type to its function-name suffix.
/// This is the overload for callers that hold an `mlir::Type` rather
/// than a `PrimaryType`.
StringRef primaryTypeFunctionSuffix(Type elemTp);

/// Converts the IR's dimension level type (the `DimLevelType` nested in
/// `SparseTensorEncodingAttr`) to its internal type-encoding (the
/// namespace-level `DimLevelType`).
DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt);

//===----------------------------------------------------------------------===//
// Misc code generators.
//
// TODO: both of these (`getOneAttr` and `genIsNonzero`) should move
// upstream to their respective classes.  Once RFCs have been created
// for those changes, list them here.
//===----------------------------------------------------------------------===//

/// Generates a 1-valued attribute of the given type.  This supports
/// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
/// for unsupported types we raise `llvm_unreachable` rather than
/// returning a null attribute.
Attribute getOneAttr(Builder &builder, Type tp);

/// Generates the comparison `v != 0` where `v` is of numeric type.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
89 Value genIsNonzero(OpBuilder &builder, Location loc, Value v); 90 91 //===----------------------------------------------------------------------===// 92 // Constant generators. 93 // 94 // All these functions are just wrappers to improve code legibility; 95 // therefore, we mark them as `inline` to avoid introducing any additional 96 // overhead due to the legibility. 97 // 98 // TODO: Ideally these should move upstream, so that we don't 99 // develop a design island. However, doing so will involve 100 // substantial design work. For related prior discussion, see 101 // <https://llvm.discourse.group/t/evolving-builder-apis-based-on-lessons-learned-from-edsc/879> 102 //===----------------------------------------------------------------------===// 103 104 /// Generates a 0-valued constant of the given type. In addition to 105 /// the scalar types (`FloatType`, `IndexType`, `IntegerType`), this also 106 /// works for `RankedTensorType` and `VectorType` (for which it generates 107 /// a constant `DenseElementsAttr` of zeros). 108 inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { 109 return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp)); 110 } 111 112 /// Generates a 1-valued constant of the given type. This supports all 113 /// the same types as `constantZero`. 114 inline Value constantOne(OpBuilder &builder, Location loc, Type tp) { 115 return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp)); 116 } 117 118 /// Generates a constant of `index` type. 119 inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) { 120 return builder.create<arith::ConstantIndexOp>(loc, i); 121 } 122 123 /// Generates a constant of `i32` type. 124 inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) { 125 return builder.create<arith::ConstantIntOp>(loc, i, 32); 126 } 127 128 /// Generates a constant of `i16` type. 
129 inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) { 130 return builder.create<arith::ConstantIntOp>(loc, i, 16); 131 } 132 133 /// Generates a constant of `i8` type. 134 inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) { 135 return builder.create<arith::ConstantIntOp>(loc, i, 8); 136 } 137 138 /// Generates a constant of `i1` type. 139 inline Value constantI1(OpBuilder &builder, Location loc, bool b) { 140 return builder.create<arith::ConstantIntOp>(loc, b, 1); 141 } 142 143 /// Generates a constant of the given `Action`. 144 inline Value constantAction(OpBuilder &builder, Location loc, Action action) { 145 return constantI32(builder, loc, static_cast<uint32_t>(action)); 146 } 147 148 /// Generates a constant of the internal type-encoding for overhead storage. 149 inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc, 150 unsigned width) { 151 return constantI32(builder, loc, 152 static_cast<uint32_t>(overheadTypeEncoding(width))); 153 } 154 155 /// Generates a constant of the internal type-encoding for pointer 156 /// overhead storage. 157 inline Value constantPointerTypeEncoding(OpBuilder &builder, Location loc, 158 const SparseTensorEncodingAttr &enc) { 159 return constantOverheadTypeEncoding(builder, loc, enc.getPointerBitWidth()); 160 } 161 162 /// Generates a constant of the internal type-encoding for index overhead 163 /// storage. 164 inline Value constantIndexTypeEncoding(OpBuilder &builder, Location loc, 165 const SparseTensorEncodingAttr &enc) { 166 return constantOverheadTypeEncoding(builder, loc, enc.getIndexBitWidth()); 167 } 168 169 /// Generates a constant of the internal type-encoding for primary storage. 170 inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, 171 Type elemTp) { 172 return constantI32(builder, loc, 173 static_cast<uint32_t>(primaryTypeEncoding(elemTp))); 174 } 175 176 /// Generates a constant of the internal dimension level type encoding. 
177 inline Value 178 constantDimLevelTypeEncoding(OpBuilder &builder, Location loc, 179 SparseTensorEncodingAttr::DimLevelType dlt) { 180 return constantI8(builder, loc, 181 static_cast<uint8_t>(dimLevelTypeEncoding(dlt))); 182 } 183 184 } // namespace sparse_tensor 185 } // namespace mlir 186 187 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_ 188