1 //===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines utilities for generating MLIR. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_ 14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_ 15 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 18 #include "mlir/ExecutionEngine/SparseTensorUtils.h" 19 #include "mlir/IR/Builders.h" 20 21 namespace mlir { 22 class Location; 23 class Type; 24 class Value; 25 26 namespace sparse_tensor { 27 28 //===----------------------------------------------------------------------===// 29 // ExecutionEngine/SparseTensorUtils helper functions. 30 //===----------------------------------------------------------------------===// 31 32 /// Converts an overhead storage bitwidth to its internal type-encoding. 33 OverheadType overheadTypeEncoding(unsigned width); 34 35 /// Converts the internal type-encoding for overhead storage to an mlir::Type. 36 Type getOverheadType(Builder &builder, OverheadType ot); 37 38 /// Returns the mlir::Type for pointer overhead storage. 39 Type getPointerOverheadType(Builder &builder, 40 const SparseTensorEncodingAttr &enc); 41 42 /// Returns the mlir::Type for index overhead storage. 43 Type getIndexOverheadType(Builder &builder, 44 const SparseTensorEncodingAttr &enc); 45 46 /// Converts a primary storage type to its internal type-encoding. 47 PrimaryType primaryTypeEncoding(Type elemTp); 48 49 /// Converts the IR's dimension level type to its internal type-encoding. 50 DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt); 51 52 //===----------------------------------------------------------------------===// 53 // Misc code generators. 54 // 55 // TODO: both of these should move upstream to their respective classes. 56 // Once RFCs have been created for those changes, list them here. 57 //===----------------------------------------------------------------------===// 58 59 /// Generates a 1-valued attribute of the given type. This supports 60 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`, 61 /// for unsupported types we raise `llvm_unreachable` rather than 62 /// returning a null attribute. 63 Attribute getOneAttr(Builder &builder, Type tp); 64 65 /// Generates the comparison `v != 0` where `v` is of numeric type. 66 /// For floating types, we use the "unordered" comparator (i.e., returns 67 /// true if `v` is NaN). 68 Value genIsNonzero(OpBuilder &builder, Location loc, Value v); 69 70 //===----------------------------------------------------------------------===// 71 // Constant generators. 72 // 73 // All these functions are just wrappers to improve code legibility; 74 // therefore, we mark them as `inline` to avoid introducing any additional 75 // overhead due to the legibility. 76 // 77 // TODO: Ideally these should move upstream, so that we don't 78 // develop a design island. However, doing so will involve 79 // substantial design work. For related prior discussion, see 80 // <https://llvm.discourse.group/t/evolving-builder-apis-based-on-lessons-learned-from-edsc/879> 81 //===----------------------------------------------------------------------===// 82 83 /// Generates a 0-valued constant of the given type. In addition to 84 /// the scalar types (`FloatType`, `IndexType`, `IntegerType`), this also 85 /// works for `RankedTensorType` and `VectorType` (for which it generates 86 /// a constant `DenseElementsAttr` of zeros). 87 inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { 88 return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp)); 89 } 90 91 /// Generates a 1-valued constant of the given type. This supports all 92 /// the same types as `constantZero`. 93 inline Value constantOne(OpBuilder &builder, Location loc, Type tp) { 94 return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp)); 95 } 96 97 /// Generates a constant of `index` type. 98 inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) { 99 return builder.create<arith::ConstantIndexOp>(loc, i); 100 } 101 102 /// Generates a constant of `i32` type. 103 inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) { 104 return builder.create<arith::ConstantIntOp>(loc, i, 32); 105 } 106 107 /// Generates a constant of `i16` type. 108 inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) { 109 return builder.create<arith::ConstantIntOp>(loc, i, 16); 110 } 111 112 /// Generates a constant of `i8` type. 113 inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) { 114 return builder.create<arith::ConstantIntOp>(loc, i, 8); 115 } 116 117 /// Generates a constant of `i1` type. 118 inline Value constantI1(OpBuilder &builder, Location loc, bool b) { 119 return builder.create<arith::ConstantIntOp>(loc, b, 1); 120 } 121 122 /// Generates a constant of the given `Action`. 123 inline Value constantAction(OpBuilder &builder, Location loc, Action action) { 124 return constantI32(builder, loc, static_cast<uint32_t>(action)); 125 } 126 127 /// Generates a constant of the internal type-encoding for overhead storage. 128 inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc, 129 unsigned width) { 130 return constantI32(builder, loc, 131 static_cast<uint32_t>(overheadTypeEncoding(width))); 132 } 133 134 /// Generates a constant of the internal type-encoding for pointer 135 /// overhead storage. 136 inline Value constantPointerTypeEncoding(OpBuilder &builder, Location loc, 137 const SparseTensorEncodingAttr &enc) { 138 return constantOverheadTypeEncoding(builder, loc, enc.getPointerBitWidth()); 139 } 140 141 /// Generates a constant of the internal type-encoding for index overhead 142 /// storage. 143 inline Value constantIndexTypeEncoding(OpBuilder &builder, Location loc, 144 const SparseTensorEncodingAttr &enc) { 145 return constantOverheadTypeEncoding(builder, loc, enc.getIndexBitWidth()); 146 } 147 148 /// Generates a constant of the internal type-encoding for primary storage. 149 inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, 150 Type elemTp) { 151 return constantI32(builder, loc, 152 static_cast<uint32_t>(primaryTypeEncoding(elemTp))); 153 } 154 155 /// Generates a constant of the internal dimension level type encoding. 156 inline Value 157 constantDimLevelTypeEncoding(OpBuilder &builder, Location loc, 158 SparseTensorEncodingAttr::DimLevelType dlt) { 159 return constantI8(builder, loc, 160 static_cast<uint8_t>(dimLevelTypeEncoding(dlt))); 161 } 162 163 } // namespace sparse_tensor 164 } // namespace mlir 165 166 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_ 167