1 //===-- PreCGRewrite.cpp --------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CGOps.h" 14 #include "PassDetail.h" 15 #include "flang/Optimizer/CodeGen/CodeGen.h" 16 #include "flang/Optimizer/Dialect/FIRDialect.h" 17 #include "flang/Optimizer/Dialect/FIROps.h" 18 #include "flang/Optimizer/Dialect/FIRType.h" 19 #include "flang/Optimizer/Support/FIRContext.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/STLExtras.h" 22 #include "llvm/Support/Debug.h" 23 24 //===----------------------------------------------------------------------===// 25 // Codegen rewrite: rewriting of subgraphs of ops 26 //===----------------------------------------------------------------------===// 27 28 using namespace fir; 29 30 #define DEBUG_TYPE "flang-codegen-rewrite" 31 32 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec, 33 ShapeOp shape) { 34 vec.append(shape.extents().begin(), shape.extents().end()); 35 } 36 37 // Operands of fir.shape_shift split into two vectors. 38 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec, 39 llvm::SmallVectorImpl<mlir::Value> &shiftVec, 40 ShapeShiftOp shift) { 41 auto endIter = shift.pairs().end(); 42 for (auto i = shift.pairs().begin(); i != endIter;) { 43 shiftVec.push_back(*i++); 44 shapeVec.push_back(*i++); 45 } 46 } 47 48 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec, 49 ShiftOp shift) { 50 vec.append(shift.origins().begin(), shift.origins().end()); 51 } 52 53 namespace { 54 55 /// Convert fir.embox to the extended form where necessary. 56 /// 57 /// The embox operation can take arguments that specify multidimensional array 58 /// properties at runtime. These properties may be shared between distinct 59 /// objects that have the same properties. Before we lower these small DAGs to 60 /// LLVM-IR, we gather all the information into a single extended operation. For 61 /// example, 62 /// ``` 63 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1> 64 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1> 65 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>> 66 /// ``` 67 /// can be rewritten as 68 /// ``` 69 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> !fir.box<!fir.array<?xi32>> 70 /// ``` 71 class EmboxConversion : public mlir::OpRewritePattern<EmboxOp> { 72 public: 73 using OpRewritePattern::OpRewritePattern; 74 75 mlir::LogicalResult 76 matchAndRewrite(EmboxOp embox, 77 mlir::PatternRewriter &rewriter) const override { 78 auto shapeVal = embox.getShape(); 79 // If the embox does not include a shape, then do not convert it 80 if (shapeVal) 81 return rewriteDynamicShape(embox, rewriter, shapeVal); 82 if (auto boxTy = embox.getType().dyn_cast<BoxType>()) 83 if (auto seqTy = boxTy.getEleTy().dyn_cast<SequenceType>()) 84 if (seqTy.hasConstantShape()) 85 return rewriteStaticShape(embox, rewriter, seqTy); 86 return mlir::failure(); 87 } 88 89 mlir::LogicalResult rewriteStaticShape(EmboxOp embox, 90 mlir::PatternRewriter &rewriter, 91 SequenceType seqTy) const { 92 auto loc = embox.getLoc(); 93 llvm::SmallVector<mlir::Value> shapeOpers; 94 auto idxTy = rewriter.getIndexType(); 95 for (auto ext : seqTy.getShape()) { 96 auto iAttr = rewriter.getIndexAttr(ext); 97 auto extVal = rewriter.create<mlir::ConstantOp>(loc, idxTy, iAttr); 98 shapeOpers.push_back(extVal); 99 } 100 auto xbox = rewriter.create<cg::XEmboxOp>( 101 loc, embox.getType(), embox.memref(), shapeOpers, llvm::None, 102 llvm::None, llvm::None, embox.lenParams()); 103 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 104 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 105 return mlir::success(); 106 } 107 108 mlir::LogicalResult rewriteDynamicShape(EmboxOp embox, 109 mlir::PatternRewriter &rewriter, 110 mlir::Value shapeVal) const { 111 auto loc = embox.getLoc(); 112 auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp()); 113 llvm::SmallVector<mlir::Value> shapeOpers; 114 llvm::SmallVector<mlir::Value> shiftOpers; 115 if (shapeOp) { 116 populateShape(shapeOpers, shapeOp); 117 } else { 118 auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp()); 119 assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift"); 120 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 121 } 122 llvm::SmallVector<mlir::Value> sliceOpers; 123 llvm::SmallVector<mlir::Value> subcompOpers; 124 if (auto s = embox.getSlice()) 125 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 126 sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); 127 subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); 128 } 129 auto xbox = rewriter.create<cg::XEmboxOp>( 130 loc, embox.getType(), embox.memref(), shapeOpers, shiftOpers, 131 sliceOpers, subcompOpers, embox.lenParams()); 132 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 133 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 134 return mlir::success(); 135 } 136 }; 137 138 /// Convert fir.rebox to the extended form where necessary. 139 /// 140 /// For example, 141 /// ``` 142 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<?xi32>> 143 /// ``` 144 /// converted to 145 /// ``` 146 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>, index, index) -> !fir.box<!fir.array<?xi32>> 147 /// ``` 148 class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> { 149 public: 150 using OpRewritePattern::OpRewritePattern; 151 152 mlir::LogicalResult 153 matchAndRewrite(ReboxOp rebox, 154 mlir::PatternRewriter &rewriter) const override { 155 auto loc = rebox.getLoc(); 156 llvm::SmallVector<mlir::Value> shapeOpers; 157 llvm::SmallVector<mlir::Value> shiftOpers; 158 if (auto shapeVal = rebox.shape()) { 159 if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp())) 160 populateShape(shapeOpers, shapeOp); 161 else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp())) 162 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 163 else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp())) 164 populateShift(shiftOpers, shiftOp); 165 else 166 return mlir::failure(); 167 } 168 llvm::SmallVector<mlir::Value> sliceOpers; 169 llvm::SmallVector<mlir::Value> subcompOpers; 170 if (auto s = rebox.slice()) 171 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 172 sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); 173 subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); 174 } 175 176 auto xRebox = rewriter.create<cg::XReboxOp>( 177 loc, rebox.getType(), rebox.box(), shapeOpers, shiftOpers, sliceOpers, 178 subcompOpers); 179 LLVM_DEBUG(llvm::dbgs() 180 << "rewriting " << rebox << " to " << xRebox << '\n'); 181 rewriter.replaceOp(rebox, xRebox.getOperation()->getResults()); 182 return mlir::success(); 183 } 184 }; 185 186 /// Convert all fir.array_coor to the extended form. 187 /// 188 /// For example, 189 /// ``` 190 /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>, !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32> 191 /// ``` 192 /// converted to 193 /// ``` 194 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) -> !fir.ref<i32> 195 /// ``` 196 class ArrayCoorConversion : public mlir::OpRewritePattern<ArrayCoorOp> { 197 public: 198 using OpRewritePattern::OpRewritePattern; 199 200 mlir::LogicalResult 201 matchAndRewrite(ArrayCoorOp arrCoor, 202 mlir::PatternRewriter &rewriter) const override { 203 auto loc = arrCoor.getLoc(); 204 llvm::SmallVector<mlir::Value> shapeOpers; 205 llvm::SmallVector<mlir::Value> shiftOpers; 206 if (auto shapeVal = arrCoor.shape()) { 207 if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp())) 208 populateShape(shapeOpers, shapeOp); 209 else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp())) 210 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 211 else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp())) 212 populateShift(shiftOpers, shiftOp); 213 else 214 return mlir::failure(); 215 } 216 llvm::SmallVector<mlir::Value> sliceOpers; 217 llvm::SmallVector<mlir::Value> subcompOpers; 218 if (auto s = arrCoor.slice()) 219 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 220 sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); 221 subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); 222 } 223 auto xArrCoor = rewriter.create<cg::XArrayCoorOp>( 224 loc, arrCoor.getType(), arrCoor.memref(), shapeOpers, shiftOpers, 225 sliceOpers, subcompOpers, arrCoor.indices(), arrCoor.lenParams()); 226 LLVM_DEBUG(llvm::dbgs() 227 << "rewriting " << arrCoor << " to " << xArrCoor << '\n'); 228 rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults()); 229 return mlir::success(); 230 } 231 }; 232 233 class CodeGenRewrite : public CodeGenRewriteBase<CodeGenRewrite> { 234 public: 235 void runOnOperation() override final { 236 auto op = getOperation(); 237 auto &context = getContext(); 238 mlir::OpBuilder rewriter(&context); 239 mlir::ConversionTarget target(context); 240 target.addLegalDialect<FIROpsDialect, FIRCodeGenDialect, 241 mlir::StandardOpsDialect>(); 242 target.addIllegalOp<ArrayCoorOp>(); 243 target.addIllegalOp<ReboxOp>(); 244 target.addDynamicallyLegalOp<EmboxOp>([](EmboxOp embox) { 245 return !(embox.getShape() || 246 embox.getType().cast<BoxType>().getEleTy().isa<SequenceType>()); 247 }); 248 mlir::OwningRewritePatternList patterns(&context); 249 patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>( 250 &context); 251 if (mlir::failed( 252 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 253 mlir::emitError(mlir::UnknownLoc::get(&context), 254 "error in running the pre-codegen conversions"); 255 signalPassFailure(); 256 } 257 } 258 }; 259 260 } // namespace 261 262 std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() { 263 return std::make_unique<CodeGenRewrite>(); 264 } 265