1 //===-- PreCGRewrite.cpp --------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CGOps.h" 14 #include "PassDetail.h" 15 #include "flang/Optimizer/CodeGen/CodeGen.h" 16 #include "flang/Optimizer/Dialect/FIRDialect.h" 17 #include "flang/Optimizer/Dialect/FIROps.h" 18 #include "flang/Optimizer/Dialect/FIRType.h" 19 #include "flang/Optimizer/Support/FIRContext.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/STLExtras.h" 22 #include "llvm/Support/Debug.h" 23 24 //===----------------------------------------------------------------------===// 25 // Codegen rewrite: rewriting of subgraphs of ops 26 //===----------------------------------------------------------------------===// 27 28 using namespace fir; 29 using namespace mlir; 30 31 #define DEBUG_TYPE "flang-codegen-rewrite" 32 33 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec, 34 ShapeOp shape) { 35 vec.append(shape.getExtents().begin(), shape.getExtents().end()); 36 } 37 38 // Operands of fir.shape_shift split into two vectors. 39 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec, 40 llvm::SmallVectorImpl<mlir::Value> &shiftVec, 41 ShapeShiftOp shift) { 42 auto endIter = shift.getPairs().end(); 43 for (auto i = shift.getPairs().begin(); i != endIter;) { 44 shiftVec.push_back(*i++); 45 shapeVec.push_back(*i++); 46 } 47 } 48 49 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec, 50 ShiftOp shift) { 51 vec.append(shift.getOrigins().begin(), shift.getOrigins().end()); 52 } 53 54 namespace { 55 56 /// Convert fir.embox to the extended form where necessary. 57 /// 58 /// The embox operation can take arguments that specify multidimensional array 59 /// properties at runtime. These properties may be shared between distinct 60 /// objects that have the same properties. Before we lower these small DAGs to 61 /// LLVM-IR, we gather all the information into a single extended operation. For 62 /// example, 63 /// ``` 64 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1> 65 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1> 66 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, 67 /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>> 68 /// ``` 69 /// can be rewritten as 70 /// ``` 71 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : 72 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> 73 /// !fir.box<!fir.array<?xi32>> 74 /// ``` 75 class EmboxConversion : public mlir::OpRewritePattern<EmboxOp> { 76 public: 77 using OpRewritePattern::OpRewritePattern; 78 79 mlir::LogicalResult 80 matchAndRewrite(EmboxOp embox, 81 mlir::PatternRewriter &rewriter) const override { 82 auto shapeVal = embox.getShape(); 83 // If the embox does not include a shape, then do not convert it 84 if (shapeVal) 85 return rewriteDynamicShape(embox, rewriter, shapeVal); 86 if (auto boxTy = embox.getType().dyn_cast<BoxType>()) 87 if (auto seqTy = boxTy.getEleTy().dyn_cast<SequenceType>()) 88 if (seqTy.hasConstantShape()) 89 return rewriteStaticShape(embox, rewriter, seqTy); 90 return mlir::failure(); 91 } 92 93 mlir::LogicalResult rewriteStaticShape(EmboxOp embox, 94 mlir::PatternRewriter &rewriter, 95 SequenceType seqTy) const { 96 auto loc = embox.getLoc(); 97 llvm::SmallVector<mlir::Value> shapeOpers; 98 auto idxTy = rewriter.getIndexType(); 99 for (auto ext : seqTy.getShape()) { 100 auto iAttr = rewriter.getIndexAttr(ext); 101 auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr); 102 shapeOpers.push_back(extVal); 103 } 104 auto xbox = rewriter.create<cg::XEmboxOp>( 105 loc, embox.getType(), embox.getMemref(), shapeOpers, llvm::None, 106 llvm::None, llvm::None, llvm::None, embox.getTypeparams()); 107 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 108 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 109 return mlir::success(); 110 } 111 112 mlir::LogicalResult rewriteDynamicShape(EmboxOp embox, 113 mlir::PatternRewriter &rewriter, 114 mlir::Value shapeVal) const { 115 auto loc = embox.getLoc(); 116 auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp()); 117 llvm::SmallVector<mlir::Value> shapeOpers; 118 llvm::SmallVector<mlir::Value> shiftOpers; 119 if (shapeOp) { 120 populateShape(shapeOpers, shapeOp); 121 } else { 122 auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp()); 123 assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift"); 124 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 125 } 126 llvm::SmallVector<mlir::Value> sliceOpers; 127 llvm::SmallVector<mlir::Value> subcompOpers; 128 llvm::SmallVector<mlir::Value> substrOpers; 129 if (auto s = embox.getSlice()) 130 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 131 sliceOpers.assign(sliceOp.getTriples().begin(), 132 sliceOp.getTriples().end()); 133 subcompOpers.assign(sliceOp.getFields().begin(), 134 sliceOp.getFields().end()); 135 substrOpers.assign(sliceOp.getSubstr().begin(), 136 sliceOp.getSubstr().end()); 137 } 138 auto xbox = rewriter.create<cg::XEmboxOp>( 139 loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers, 140 sliceOpers, subcompOpers, substrOpers, embox.getTypeparams()); 141 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 142 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 143 return mlir::success(); 144 } 145 }; 146 147 /// Convert fir.rebox to the extended form where necessary. 148 /// 149 /// For example, 150 /// ``` 151 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> 152 /// !fir.box<!fir.array<?xi32>> 153 /// ``` 154 /// converted to 155 /// ``` 156 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>, 157 /// index, index) -> !fir.box<!fir.array<?xi32>> 158 /// ``` 159 class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> { 160 public: 161 using OpRewritePattern::OpRewritePattern; 162 163 mlir::LogicalResult 164 matchAndRewrite(ReboxOp rebox, 165 mlir::PatternRewriter &rewriter) const override { 166 auto loc = rebox.getLoc(); 167 llvm::SmallVector<mlir::Value> shapeOpers; 168 llvm::SmallVector<mlir::Value> shiftOpers; 169 if (auto shapeVal = rebox.getShape()) { 170 if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp())) 171 populateShape(shapeOpers, shapeOp); 172 else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp())) 173 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 174 else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp())) 175 populateShift(shiftOpers, shiftOp); 176 else 177 return mlir::failure(); 178 } 179 llvm::SmallVector<mlir::Value> sliceOpers; 180 llvm::SmallVector<mlir::Value> subcompOpers; 181 llvm::SmallVector<mlir::Value> substrOpers; 182 if (auto s = rebox.getSlice()) 183 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 184 sliceOpers.append(sliceOp.getTriples().begin(), 185 sliceOp.getTriples().end()); 186 subcompOpers.append(sliceOp.getFields().begin(), 187 sliceOp.getFields().end()); 188 substrOpers.append(sliceOp.getSubstr().begin(), 189 sliceOp.getSubstr().end()); 190 } 191 192 auto xRebox = rewriter.create<cg::XReboxOp>( 193 loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers, 194 sliceOpers, subcompOpers, substrOpers); 195 LLVM_DEBUG(llvm::dbgs() 196 << "rewriting " << rebox << " to " << xRebox << '\n'); 197 rewriter.replaceOp(rebox, xRebox.getOperation()->getResults()); 198 return mlir::success(); 199 } 200 }; 201 202 /// Convert all fir.array_coor to the extended form. 203 /// 204 /// For example, 205 /// ``` 206 /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>, 207 /// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32> 208 /// ``` 209 /// converted to 210 /// ``` 211 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : 212 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) -> 213 /// !fir.ref<i32> 214 /// ``` 215 class ArrayCoorConversion : public mlir::OpRewritePattern<ArrayCoorOp> { 216 public: 217 using OpRewritePattern::OpRewritePattern; 218 219 mlir::LogicalResult 220 matchAndRewrite(ArrayCoorOp arrCoor, 221 mlir::PatternRewriter &rewriter) const override { 222 auto loc = arrCoor.getLoc(); 223 llvm::SmallVector<mlir::Value> shapeOpers; 224 llvm::SmallVector<mlir::Value> shiftOpers; 225 if (auto shapeVal = arrCoor.getShape()) { 226 if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp())) 227 populateShape(shapeOpers, shapeOp); 228 else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp())) 229 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 230 else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp())) 231 populateShift(shiftOpers, shiftOp); 232 else 233 return mlir::failure(); 234 } 235 llvm::SmallVector<mlir::Value> sliceOpers; 236 llvm::SmallVector<mlir::Value> subcompOpers; 237 if (auto s = arrCoor.getSlice()) 238 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 239 sliceOpers.append(sliceOp.getTriples().begin(), 240 sliceOp.getTriples().end()); 241 subcompOpers.append(sliceOp.getFields().begin(), 242 sliceOp.getFields().end()); 243 assert(sliceOp.getSubstr().empty() && 244 "Don't allow substring operations on array_coor. This " 245 "restriction may be lifted in the future."); 246 } 247 auto xArrCoor = rewriter.create<cg::XArrayCoorOp>( 248 loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers, 249 sliceOpers, subcompOpers, arrCoor.getIndices(), 250 arrCoor.getTypeparams()); 251 LLVM_DEBUG(llvm::dbgs() 252 << "rewriting " << arrCoor << " to " << xArrCoor << '\n'); 253 rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults()); 254 return mlir::success(); 255 } 256 }; 257 258 class CodeGenRewrite : public CodeGenRewriteBase<CodeGenRewrite> { 259 public: 260 void runOnOperation() override final { 261 auto op = getOperation(); 262 auto &context = getContext(); 263 mlir::OpBuilder rewriter(&context); 264 mlir::ConversionTarget target(context); 265 target.addLegalDialect<mlir::arith::ArithmeticDialect, FIROpsDialect, 266 FIRCodeGenDialect, mlir::func::FuncDialect>(); 267 target.addIllegalOp<ArrayCoorOp>(); 268 target.addIllegalOp<ReboxOp>(); 269 target.addDynamicallyLegalOp<EmboxOp>([](EmboxOp embox) { 270 return !(embox.getShape() || 271 embox.getType().cast<BoxType>().getEleTy().isa<SequenceType>()); 272 }); 273 mlir::RewritePatternSet patterns(&context); 274 patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>( 275 &context); 276 if (mlir::failed( 277 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 278 mlir::emitError(mlir::UnknownLoc::get(&context), 279 "error in running the pre-codegen conversions"); 280 signalPassFailure(); 281 } 282 } 283 }; 284 285 } // namespace 286 287 std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() { 288 return std::make_unique<CodeGenRewrite>(); 289 } 290