1 //===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines utilities for generating MLIR. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_ 14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_ 15 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 18 #include "mlir/ExecutionEngine/SparseTensorUtils.h" 19 #include "mlir/IR/Builders.h" 20 21 namespace mlir { 22 class Location; 23 class Type; 24 class Value; 25 26 namespace sparse_tensor { 27 28 //===----------------------------------------------------------------------===// 29 // ExecutionEngine/SparseTensorUtils helper functions. 30 //===----------------------------------------------------------------------===// 31 32 /// Converts an overhead storage bitwidth to its internal type-encoding. 33 OverheadType overheadTypeEncoding(unsigned width); 34 35 /// Converts an overhead storage type to its internal type-encoding. 36 OverheadType overheadTypeEncoding(Type tp); 37 38 /// Converts the internal type-encoding for overhead storage to an mlir::Type. 39 Type getOverheadType(Builder &builder, OverheadType ot); 40 41 /// Returns the mlir::Type for pointer overhead storage. 42 Type getPointerOverheadType(Builder &builder, 43 const SparseTensorEncodingAttr &enc); 44 45 /// Returns the mlir::Type for index overhead storage. 46 Type getIndexOverheadType(Builder &builder, 47 const SparseTensorEncodingAttr &enc); 48 49 /// Convert OverheadType to its function-name suffix. 50 StringRef overheadTypeFunctionSuffix(OverheadType ot); 51 52 /// Converts an overhead storage type to its function-name suffix. 53 StringRef overheadTypeFunctionSuffix(Type overheadTp); 54 55 /// Converts a primary storage type to its internal type-encoding. 56 PrimaryType primaryTypeEncoding(Type elemTp); 57 58 /// Convert PrimaryType to its function-name suffix. 59 StringRef primaryTypeFunctionSuffix(PrimaryType pt); 60 61 /// Converts a primary storage type to its function-name suffix. 62 StringRef primaryTypeFunctionSuffix(Type elemTp); 63 64 /// Converts the IR's dimension level type to its internal type-encoding. 65 DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt); 66 67 //===----------------------------------------------------------------------===// 68 // Misc code generators. 69 // 70 // TODO: both of these should move upstream to their respective classes. 71 // Once RFCs have been created for those changes, list them here. 72 //===----------------------------------------------------------------------===// 73 74 /// Generates a 1-valued attribute of the given type. This supports 75 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`, 76 /// for unsupported types we raise `llvm_unreachable` rather than 77 /// returning a null attribute. 78 Attribute getOneAttr(Builder &builder, Type tp); 79 80 /// Generates the comparison `v != 0` where `v` is of numeric type. 81 /// For floating types, we use the "unordered" comparator (i.e., returns 82 /// true if `v` is NaN). 83 Value genIsNonzero(OpBuilder &builder, Location loc, Value v); 84 85 //===----------------------------------------------------------------------===// 86 // Constant generators. 87 // 88 // All these functions are just wrappers to improve code legibility; 89 // therefore, we mark them as `inline` to avoid introducing any additional 90 // overhead due to the legibility. 91 // 92 // TODO: Ideally these should move upstream, so that we don't 93 // develop a design island. However, doing so will involve 94 // substantial design work. For related prior discussion, see 95 // <https://llvm.discourse.group/t/evolving-builder-apis-based-on-lessons-learned-from-edsc/879> 96 //===----------------------------------------------------------------------===// 97 98 /// Generates a 0-valued constant of the given type. In addition to 99 /// the scalar types (`FloatType`, `IndexType`, `IntegerType`), this also 100 /// works for `RankedTensorType` and `VectorType` (for which it generates 101 /// a constant `DenseElementsAttr` of zeros). 102 inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { 103 return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp)); 104 } 105 106 /// Generates a 1-valued constant of the given type. This supports all 107 /// the same types as `constantZero`. 108 inline Value constantOne(OpBuilder &builder, Location loc, Type tp) { 109 return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp)); 110 } 111 112 /// Generates a constant of `index` type. 113 inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) { 114 return builder.create<arith::ConstantIndexOp>(loc, i); 115 } 116 117 /// Generates a constant of `i32` type. 118 inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) { 119 return builder.create<arith::ConstantIntOp>(loc, i, 32); 120 } 121 122 /// Generates a constant of `i16` type. 123 inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) { 124 return builder.create<arith::ConstantIntOp>(loc, i, 16); 125 } 126 127 /// Generates a constant of `i8` type. 128 inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) { 129 return builder.create<arith::ConstantIntOp>(loc, i, 8); 130 } 131 132 /// Generates a constant of `i1` type. 133 inline Value constantI1(OpBuilder &builder, Location loc, bool b) { 134 return builder.create<arith::ConstantIntOp>(loc, b, 1); 135 } 136 137 /// Generates a constant of the given `Action`. 138 inline Value constantAction(OpBuilder &builder, Location loc, Action action) { 139 return constantI32(builder, loc, static_cast<uint32_t>(action)); 140 } 141 142 /// Generates a constant of the internal type-encoding for overhead storage. 143 inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc, 144 unsigned width) { 145 return constantI32(builder, loc, 146 static_cast<uint32_t>(overheadTypeEncoding(width))); 147 } 148 149 /// Generates a constant of the internal type-encoding for pointer 150 /// overhead storage. 151 inline Value constantPointerTypeEncoding(OpBuilder &builder, Location loc, 152 const SparseTensorEncodingAttr &enc) { 153 return constantOverheadTypeEncoding(builder, loc, enc.getPointerBitWidth()); 154 } 155 156 /// Generates a constant of the internal type-encoding for index overhead 157 /// storage. 158 inline Value constantIndexTypeEncoding(OpBuilder &builder, Location loc, 159 const SparseTensorEncodingAttr &enc) { 160 return constantOverheadTypeEncoding(builder, loc, enc.getIndexBitWidth()); 161 } 162 163 /// Generates a constant of the internal type-encoding for primary storage. 164 inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, 165 Type elemTp) { 166 return constantI32(builder, loc, 167 static_cast<uint32_t>(primaryTypeEncoding(elemTp))); 168 } 169 170 /// Generates a constant of the internal dimension level type encoding. 171 inline Value 172 constantDimLevelTypeEncoding(OpBuilder &builder, Location loc, 173 SparseTensorEncodingAttr::DimLevelType dlt) { 174 return constantI8(builder, loc, 175 static_cast<uint8_t>(dimLevelTypeEncoding(dlt))); 176 } 177 178 } // namespace sparse_tensor 179 } // namespace mlir 180 181 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_ 182