1 //===-- PreCGRewrite.cpp --------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CGOps.h" 14 #include "PassDetail.h" 15 #include "flang/Optimizer/CodeGen/CodeGen.h" 16 #include "flang/Optimizer/Dialect/FIRDialect.h" 17 #include "flang/Optimizer/Dialect/FIROps.h" 18 #include "flang/Optimizer/Dialect/FIRType.h" 19 #include "flang/Optimizer/Support/FIRContext.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/STLExtras.h" 22 23 //===----------------------------------------------------------------------===// 24 // Codegen rewrite: rewriting of subgraphs of ops 25 //===----------------------------------------------------------------------===// 26 27 using namespace fir; 28 29 #define DEBUG_TYPE "flang-codegen-rewrite" 30 31 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec, 32 ShapeOp shape) { 33 vec.append(shape.extents().begin(), shape.extents().end()); 34 } 35 36 // Operands of fir.shape_shift split into two vectors. 37 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec, 38 llvm::SmallVectorImpl<mlir::Value> &shiftVec, 39 ShapeShiftOp shift) { 40 auto endIter = shift.pairs().end(); 41 for (auto i = shift.pairs().begin(); i != endIter;) { 42 shiftVec.push_back(*i++); 43 shapeVec.push_back(*i++); 44 } 45 } 46 47 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec, 48 ShiftOp shift) { 49 vec.append(shift.origins().begin(), shift.origins().end()); 50 } 51 52 namespace { 53 54 /// Convert fir.embox to the extended form where necessary. 55 /// 56 /// The embox operation can take arguments that specify multidimensional array 57 /// properties at runtime. These properties may be shared between distinct 58 /// objects that have the same properties. Before we lower these small DAGs to 59 /// LLVM-IR, we gather all the information into a single extended operation. For 60 /// example, 61 /// ``` 62 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1> 63 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1> 64 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>> 65 /// ``` 66 /// can be rewritten as 67 /// ``` 68 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> !fir.box<!fir.array<?xi32>> 69 /// ``` 70 class EmboxConversion : public mlir::OpRewritePattern<EmboxOp> { 71 public: 72 using OpRewritePattern::OpRewritePattern; 73 74 mlir::LogicalResult 75 matchAndRewrite(EmboxOp embox, 76 mlir::PatternRewriter &rewriter) const override { 77 auto shapeVal = embox.getShape(); 78 // If the embox does not include a shape, then do not convert it 79 if (shapeVal) 80 return rewriteDynamicShape(embox, rewriter, shapeVal); 81 if (auto boxTy = embox.getType().dyn_cast<BoxType>()) 82 if (auto seqTy = boxTy.getEleTy().dyn_cast<SequenceType>()) 83 if (seqTy.hasConstantShape()) 84 return rewriteStaticShape(embox, rewriter, seqTy); 85 return mlir::failure(); 86 } 87 88 mlir::LogicalResult rewriteStaticShape(EmboxOp embox, 89 mlir::PatternRewriter &rewriter, 90 SequenceType seqTy) const { 91 auto loc = embox.getLoc(); 92 llvm::SmallVector<mlir::Value> shapeOpers; 93 auto idxTy = rewriter.getIndexType(); 94 for (auto ext : seqTy.getShape()) { 95 auto iAttr = rewriter.getIndexAttr(ext); 96 auto extVal = rewriter.create<mlir::ConstantOp>(loc, idxTy, iAttr); 97 shapeOpers.push_back(extVal); 98 } 99 auto xbox = rewriter.create<cg::XEmboxOp>( 100 loc, embox.getType(), embox.memref(), shapeOpers, llvm::None, 101 llvm::None, llvm::None, embox.lenParams()); 102 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 103 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 104 return mlir::success(); 105 } 106 107 mlir::LogicalResult rewriteDynamicShape(EmboxOp embox, 108 mlir::PatternRewriter &rewriter, 109 mlir::Value shapeVal) const { 110 auto loc = embox.getLoc(); 111 auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp()); 112 llvm::SmallVector<mlir::Value> shapeOpers; 113 llvm::SmallVector<mlir::Value> shiftOpers; 114 if (shapeOp) { 115 populateShape(shapeOpers, shapeOp); 116 } else { 117 auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp()); 118 assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift"); 119 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 120 } 121 llvm::SmallVector<mlir::Value> sliceOpers; 122 llvm::SmallVector<mlir::Value> subcompOpers; 123 if (auto s = embox.getSlice()) 124 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 125 sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); 126 subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); 127 } 128 auto xbox = rewriter.create<cg::XEmboxOp>( 129 loc, embox.getType(), embox.memref(), shapeOpers, shiftOpers, 130 sliceOpers, subcompOpers, embox.lenParams()); 131 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 132 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 133 return mlir::success(); 134 } 135 }; 136 137 /// Convert fir.rebox to the extended form where necessary. 138 /// 139 /// For example, 140 /// ``` 141 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<?xi32>> 142 /// ``` 143 /// converted to 144 /// ``` 145 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>, index, index) -> !fir.box<!fir.array<?xi32>> 146 /// ``` 147 class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> { 148 public: 149 using OpRewritePattern::OpRewritePattern; 150 151 mlir::LogicalResult 152 matchAndRewrite(ReboxOp rebox, 153 mlir::PatternRewriter &rewriter) const override { 154 auto loc = rebox.getLoc(); 155 llvm::SmallVector<mlir::Value> shapeOpers; 156 llvm::SmallVector<mlir::Value> shiftOpers; 157 if (auto shapeVal = rebox.shape()) { 158 if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp())) 159 populateShape(shapeOpers, shapeOp); 160 else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp())) 161 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 162 else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp())) 163 populateShift(shiftOpers, shiftOp); 164 else 165 return mlir::failure(); 166 } 167 llvm::SmallVector<mlir::Value> sliceOpers; 168 llvm::SmallVector<mlir::Value> subcompOpers; 169 if (auto s = rebox.slice()) 170 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 171 sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); 172 subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); 173 } 174 175 auto xRebox = rewriter.create<cg::XReboxOp>( 176 loc, rebox.getType(), rebox.box(), shapeOpers, shiftOpers, sliceOpers, 177 subcompOpers); 178 LLVM_DEBUG(llvm::dbgs() 179 << "rewriting " << rebox << " to " << xRebox << '\n'); 180 rewriter.replaceOp(rebox, xRebox.getOperation()->getResults()); 181 return mlir::success(); 182 } 183 }; 184 185 /// Convert all fir.array_coor to the extended form. 186 /// 187 /// For example, 188 /// ``` 189 /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>, !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32> 190 /// ``` 191 /// converted to 192 /// ``` 193 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) -> !fir.ref<i32> 194 /// ``` 195 class ArrayCoorConversion : public mlir::OpRewritePattern<ArrayCoorOp> { 196 public: 197 using OpRewritePattern::OpRewritePattern; 198 199 mlir::LogicalResult 200 matchAndRewrite(ArrayCoorOp arrCoor, 201 mlir::PatternRewriter &rewriter) const override { 202 auto loc = arrCoor.getLoc(); 203 llvm::SmallVector<mlir::Value> shapeOpers; 204 llvm::SmallVector<mlir::Value> shiftOpers; 205 if (auto shapeVal = arrCoor.shape()) { 206 if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp())) 207 populateShape(shapeOpers, shapeOp); 208 else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp())) 209 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 210 else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp())) 211 populateShift(shiftOpers, shiftOp); 212 else 213 return mlir::failure(); 214 } 215 llvm::SmallVector<mlir::Value> sliceOpers; 216 llvm::SmallVector<mlir::Value> subcompOpers; 217 if (auto s = arrCoor.slice()) 218 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 219 sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); 220 subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); 221 } 222 auto xArrCoor = rewriter.create<cg::XArrayCoorOp>( 223 loc, arrCoor.getType(), arrCoor.memref(), shapeOpers, shiftOpers, 224 sliceOpers, subcompOpers, arrCoor.indices(), arrCoor.lenParams()); 225 LLVM_DEBUG(llvm::dbgs() 226 << "rewriting " << arrCoor << " to " << xArrCoor << '\n'); 227 rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults()); 228 return mlir::success(); 229 } 230 }; 231 232 class CodeGenRewrite : public CodeGenRewriteBase<CodeGenRewrite> { 233 public: 234 void runOnOperation() override final { 235 auto op = getOperation(); 236 auto &context = getContext(); 237 mlir::OpBuilder rewriter(&context); 238 mlir::ConversionTarget target(context); 239 target.addLegalDialect<FIROpsDialect, FIRCodeGenDialect, 240 mlir::StandardOpsDialect>(); 241 target.addIllegalOp<ArrayCoorOp>(); 242 target.addIllegalOp<ReboxOp>(); 243 target.addDynamicallyLegalOp<EmboxOp>([](EmboxOp embox) { 244 return !(embox.getShape() || 245 embox.getType().cast<BoxType>().getEleTy().isa<SequenceType>()); 246 }); 247 mlir::OwningRewritePatternList patterns(&context); 248 patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>( 249 &context); 250 if (mlir::failed( 251 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 252 mlir::emitError(mlir::UnknownLoc::get(&context), 253 "error in running the pre-codegen conversions"); 254 signalPassFailure(); 255 } 256 } 257 }; 258 259 } // namespace 260 261 std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() { 262 return std::make_unique<CodeGenRewrite>(); 263 } 264