1 //===-- PreCGRewrite.cpp --------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CGOps.h" 14 #include "PassDetail.h" 15 #include "flang/Optimizer/CodeGen/CodeGen.h" 16 #include "flang/Optimizer/Dialect/FIRDialect.h" 17 #include "flang/Optimizer/Dialect/FIROps.h" 18 #include "flang/Optimizer/Dialect/FIRType.h" 19 #include "flang/Optimizer/Support/FIRContext.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/STLExtras.h" 22 #include "llvm/Support/Debug.h" 23 24 //===----------------------------------------------------------------------===// 25 // Codegen rewrite: rewriting of subgraphs of ops 26 //===----------------------------------------------------------------------===// 27 28 using namespace fir; 29 30 #define DEBUG_TYPE "flang-codegen-rewrite" 31 32 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec, 33 ShapeOp shape) { 34 vec.append(shape.extents().begin(), shape.extents().end()); 35 } 36 37 // Operands of fir.shape_shift split into two vectors. 38 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec, 39 llvm::SmallVectorImpl<mlir::Value> &shiftVec, 40 ShapeShiftOp shift) { 41 auto endIter = shift.pairs().end(); 42 for (auto i = shift.pairs().begin(); i != endIter;) { 43 shiftVec.push_back(*i++); 44 shapeVec.push_back(*i++); 45 } 46 } 47 48 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec, 49 ShiftOp shift) { 50 vec.append(shift.origins().begin(), shift.origins().end()); 51 } 52 53 namespace { 54 55 /// Convert fir.embox to the extended form where necessary. 56 /// 57 /// The embox operation can take arguments that specify multidimensional array 58 /// properties at runtime. These properties may be shared between distinct 59 /// objects that have the same properties. Before we lower these small DAGs to 60 /// LLVM-IR, we gather all the information into a single extended operation. For 61 /// example, 62 /// ``` 63 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1> 64 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1> 65 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, 66 /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>> 67 /// ``` 68 /// can be rewritten as 69 /// ``` 70 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : 71 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> 72 /// !fir.box<!fir.array<?xi32>> 73 /// ``` 74 class EmboxConversion : public mlir::OpRewritePattern<EmboxOp> { 75 public: 76 using OpRewritePattern::OpRewritePattern; 77 78 mlir::LogicalResult 79 matchAndRewrite(EmboxOp embox, 80 mlir::PatternRewriter &rewriter) const override { 81 auto shapeVal = embox.getShape(); 82 // If the embox does not include a shape, then do not convert it 83 if (shapeVal) 84 return rewriteDynamicShape(embox, rewriter, shapeVal); 85 if (auto boxTy = embox.getType().dyn_cast<BoxType>()) 86 if (auto seqTy = boxTy.getEleTy().dyn_cast<SequenceType>()) 87 if (seqTy.hasConstantShape()) 88 return rewriteStaticShape(embox, rewriter, seqTy); 89 return mlir::failure(); 90 } 91 92 mlir::LogicalResult rewriteStaticShape(EmboxOp embox, 93 mlir::PatternRewriter &rewriter, 94 SequenceType seqTy) const { 95 auto loc = embox.getLoc(); 96 llvm::SmallVector<mlir::Value> shapeOpers; 97 auto idxTy = rewriter.getIndexType(); 98 for (auto ext : seqTy.getShape()) { 99 auto iAttr = rewriter.getIndexAttr(ext); 100 auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr); 101 shapeOpers.push_back(extVal); 102 } 103 auto xbox = rewriter.create<cg::XEmboxOp>( 104 loc, embox.getType(), embox.memref(), shapeOpers, llvm::None, 105 llvm::None, llvm::None, llvm::None, embox.typeparams()); 106 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 107 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 108 return mlir::success(); 109 } 110 111 mlir::LogicalResult rewriteDynamicShape(EmboxOp embox, 112 mlir::PatternRewriter &rewriter, 113 mlir::Value shapeVal) const { 114 auto loc = embox.getLoc(); 115 auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp()); 116 llvm::SmallVector<mlir::Value> shapeOpers; 117 llvm::SmallVector<mlir::Value> shiftOpers; 118 if (shapeOp) { 119 populateShape(shapeOpers, shapeOp); 120 } else { 121 auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp()); 122 assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift"); 123 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 124 } 125 llvm::SmallVector<mlir::Value> sliceOpers; 126 llvm::SmallVector<mlir::Value> subcompOpers; 127 llvm::SmallVector<mlir::Value> substrOpers; 128 if (auto s = embox.getSlice()) 129 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 130 sliceOpers.assign(sliceOp.triples().begin(), sliceOp.triples().end()); 131 subcompOpers.assign(sliceOp.fields().begin(), sliceOp.fields().end()); 132 substrOpers.assign(sliceOp.substr().begin(), sliceOp.substr().end()); 133 } 134 auto xbox = rewriter.create<cg::XEmboxOp>( 135 loc, embox.getType(), embox.memref(), shapeOpers, shiftOpers, 136 sliceOpers, subcompOpers, substrOpers, embox.typeparams()); 137 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 138 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 139 return mlir::success(); 140 } 141 }; 142 143 /// Convert fir.rebox to the extended form where necessary. 144 /// 145 /// For example, 146 /// ``` 147 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> 148 /// !fir.box<!fir.array<?xi32>> 149 /// ``` 150 /// converted to 151 /// ``` 152 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>, 153 /// index, index) -> !fir.box<!fir.array<?xi32>> 154 /// ``` 155 class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> { 156 public: 157 using OpRewritePattern::OpRewritePattern; 158 159 mlir::LogicalResult 160 matchAndRewrite(ReboxOp rebox, 161 mlir::PatternRewriter &rewriter) const override { 162 auto loc = rebox.getLoc(); 163 llvm::SmallVector<mlir::Value> shapeOpers; 164 llvm::SmallVector<mlir::Value> shiftOpers; 165 if (auto shapeVal = rebox.shape()) { 166 if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp())) 167 populateShape(shapeOpers, shapeOp); 168 else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp())) 169 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 170 else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp())) 171 populateShift(shiftOpers, shiftOp); 172 else 173 return mlir::failure(); 174 } 175 llvm::SmallVector<mlir::Value> sliceOpers; 176 llvm::SmallVector<mlir::Value> subcompOpers; 177 llvm::SmallVector<mlir::Value> substrOpers; 178 if (auto s = rebox.slice()) 179 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 180 sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); 181 subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); 182 substrOpers.append(sliceOp.substr().begin(), sliceOp.substr().end()); 183 } 184 185 auto xRebox = rewriter.create<cg::XReboxOp>( 186 loc, rebox.getType(), rebox.box(), shapeOpers, shiftOpers, sliceOpers, 187 subcompOpers, substrOpers); 188 LLVM_DEBUG(llvm::dbgs() 189 << "rewriting " << rebox << " to " << xRebox << '\n'); 190 rewriter.replaceOp(rebox, xRebox.getOperation()->getResults()); 191 return mlir::success(); 192 } 193 }; 194 195 /// Convert all fir.array_coor to the extended form. 196 /// 197 /// For example, 198 /// ``` 199 /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>, 200 /// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32> 201 /// ``` 202 /// converted to 203 /// ``` 204 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : 205 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) -> 206 /// !fir.ref<i32> 207 /// ``` 208 class ArrayCoorConversion : public mlir::OpRewritePattern<ArrayCoorOp> { 209 public: 210 using OpRewritePattern::OpRewritePattern; 211 212 mlir::LogicalResult 213 matchAndRewrite(ArrayCoorOp arrCoor, 214 mlir::PatternRewriter &rewriter) const override { 215 auto loc = arrCoor.getLoc(); 216 llvm::SmallVector<mlir::Value> shapeOpers; 217 llvm::SmallVector<mlir::Value> shiftOpers; 218 if (auto shapeVal = arrCoor.shape()) { 219 if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp())) 220 populateShape(shapeOpers, shapeOp); 221 else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp())) 222 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 223 else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp())) 224 populateShift(shiftOpers, shiftOp); 225 else 226 return mlir::failure(); 227 } 228 llvm::SmallVector<mlir::Value> sliceOpers; 229 llvm::SmallVector<mlir::Value> subcompOpers; 230 if (auto s = arrCoor.slice()) 231 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 232 sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); 233 subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); 234 assert(sliceOp.substr().empty() && 235 "Don't allow substring operations on array_coor. This " 236 "restriction may be lifted in the future."); 237 } 238 auto xArrCoor = rewriter.create<cg::XArrayCoorOp>( 239 loc, arrCoor.getType(), arrCoor.memref(), shapeOpers, shiftOpers, 240 sliceOpers, subcompOpers, arrCoor.indices(), arrCoor.typeparams()); 241 LLVM_DEBUG(llvm::dbgs() 242 << "rewriting " << arrCoor << " to " << xArrCoor << '\n'); 243 rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults()); 244 return mlir::success(); 245 } 246 }; 247 248 class CodeGenRewrite : public CodeGenRewriteBase<CodeGenRewrite> { 249 public: 250 void runOnOperation() override final { 251 auto op = getOperation(); 252 auto &context = getContext(); 253 mlir::OpBuilder rewriter(&context); 254 mlir::ConversionTarget target(context); 255 target.addLegalDialect<mlir::arith::ArithmeticDialect, FIROpsDialect, 256 FIRCodeGenDialect, mlir::StandardOpsDialect>(); 257 target.addIllegalOp<ArrayCoorOp>(); 258 target.addIllegalOp<ReboxOp>(); 259 target.addDynamicallyLegalOp<EmboxOp>([](EmboxOp embox) { 260 return !(embox.getShape() || 261 embox.getType().cast<BoxType>().getEleTy().isa<SequenceType>()); 262 }); 263 mlir::OwningRewritePatternList patterns(&context); 264 patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>( 265 &context); 266 if (mlir::failed( 267 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 268 mlir::emitError(mlir::UnknownLoc::get(&context), 269 "error in running the pre-codegen conversions"); 270 signalPassFailure(); 271 } 272 } 273 }; 274 275 } // namespace 276 277 std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() { 278 return std::make_unique<CodeGenRewrite>(); 279 } 280