1 //===-- PreCGRewrite.cpp --------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CGOps.h" 14 #include "PassDetail.h" 15 #include "flang/Optimizer/CodeGen/CodeGen.h" 16 #include "flang/Optimizer/Dialect/FIRDialect.h" 17 #include "flang/Optimizer/Dialect/FIROps.h" 18 #include "flang/Optimizer/Dialect/FIRType.h" 19 #include "flang/Optimizer/Support/FIRContext.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/STLExtras.h" 22 #include "llvm/Support/Debug.h" 23 24 //===----------------------------------------------------------------------===// 25 // Codegen rewrite: rewriting of subgraphs of ops 26 //===----------------------------------------------------------------------===// 27 28 using namespace fir; 29 30 #define DEBUG_TYPE "flang-codegen-rewrite" 31 32 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec, 33 ShapeOp shape) { 34 vec.append(shape.getExtents().begin(), shape.getExtents().end()); 35 } 36 37 // Operands of fir.shape_shift split into two vectors. 38 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec, 39 llvm::SmallVectorImpl<mlir::Value> &shiftVec, 40 ShapeShiftOp shift) { 41 auto endIter = shift.getPairs().end(); 42 for (auto i = shift.getPairs().begin(); i != endIter;) { 43 shiftVec.push_back(*i++); 44 shapeVec.push_back(*i++); 45 } 46 } 47 48 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec, 49 ShiftOp shift) { 50 vec.append(shift.getOrigins().begin(), shift.getOrigins().end()); 51 } 52 53 namespace { 54 55 /// Convert fir.embox to the extended form where necessary. 56 /// 57 /// The embox operation can take arguments that specify multidimensional array 58 /// properties at runtime. These properties may be shared between distinct 59 /// objects that have the same properties. Before we lower these small DAGs to 60 /// LLVM-IR, we gather all the information into a single extended operation. For 61 /// example, 62 /// ``` 63 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1> 64 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1> 65 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, 66 /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>> 67 /// ``` 68 /// can be rewritten as 69 /// ``` 70 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : 71 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> 72 /// !fir.box<!fir.array<?xi32>> 73 /// ``` 74 class EmboxConversion : public mlir::OpRewritePattern<EmboxOp> { 75 public: 76 using OpRewritePattern::OpRewritePattern; 77 78 mlir::LogicalResult 79 matchAndRewrite(EmboxOp embox, 80 mlir::PatternRewriter &rewriter) const override { 81 auto shapeVal = embox.getShape(); 82 // If the embox does not include a shape, then do not convert it 83 if (shapeVal) 84 return rewriteDynamicShape(embox, rewriter, shapeVal); 85 if (auto boxTy = embox.getType().dyn_cast<BoxType>()) 86 if (auto seqTy = boxTy.getEleTy().dyn_cast<SequenceType>()) 87 if (seqTy.hasConstantShape()) 88 return rewriteStaticShape(embox, rewriter, seqTy); 89 return mlir::failure(); 90 } 91 92 mlir::LogicalResult rewriteStaticShape(EmboxOp embox, 93 mlir::PatternRewriter &rewriter, 94 SequenceType seqTy) const { 95 auto loc = embox.getLoc(); 96 llvm::SmallVector<mlir::Value> shapeOpers; 97 auto idxTy = rewriter.getIndexType(); 98 for (auto ext : seqTy.getShape()) { 99 auto iAttr = rewriter.getIndexAttr(ext); 100 auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr); 101 shapeOpers.push_back(extVal); 102 } 103 auto xbox = rewriter.create<cg::XEmboxOp>( 104 loc, embox.getType(), embox.getMemref(), shapeOpers, llvm::None, 105 llvm::None, llvm::None, llvm::None, embox.getTypeparams()); 106 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 107 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 108 return mlir::success(); 109 } 110 111 mlir::LogicalResult rewriteDynamicShape(EmboxOp embox, 112 mlir::PatternRewriter &rewriter, 113 mlir::Value shapeVal) const { 114 auto loc = embox.getLoc(); 115 auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp()); 116 llvm::SmallVector<mlir::Value> shapeOpers; 117 llvm::SmallVector<mlir::Value> shiftOpers; 118 if (shapeOp) { 119 populateShape(shapeOpers, shapeOp); 120 } else { 121 auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp()); 122 assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift"); 123 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 124 } 125 llvm::SmallVector<mlir::Value> sliceOpers; 126 llvm::SmallVector<mlir::Value> subcompOpers; 127 llvm::SmallVector<mlir::Value> substrOpers; 128 if (auto s = embox.getSlice()) 129 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 130 sliceOpers.assign(sliceOp.getTriples().begin(), 131 sliceOp.getTriples().end()); 132 subcompOpers.assign(sliceOp.getFields().begin(), 133 sliceOp.getFields().end()); 134 substrOpers.assign(sliceOp.getSubstr().begin(), 135 sliceOp.getSubstr().end()); 136 } 137 auto xbox = rewriter.create<cg::XEmboxOp>( 138 loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers, 139 sliceOpers, subcompOpers, substrOpers, embox.getTypeparams()); 140 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); 141 rewriter.replaceOp(embox, xbox.getOperation()->getResults()); 142 return mlir::success(); 143 } 144 }; 145 146 /// Convert fir.rebox to the extended form where necessary. 147 /// 148 /// For example, 149 /// ``` 150 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> 151 /// !fir.box<!fir.array<?xi32>> 152 /// ``` 153 /// converted to 154 /// ``` 155 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>, 156 /// index, index) -> !fir.box<!fir.array<?xi32>> 157 /// ``` 158 class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> { 159 public: 160 using OpRewritePattern::OpRewritePattern; 161 162 mlir::LogicalResult 163 matchAndRewrite(ReboxOp rebox, 164 mlir::PatternRewriter &rewriter) const override { 165 auto loc = rebox.getLoc(); 166 llvm::SmallVector<mlir::Value> shapeOpers; 167 llvm::SmallVector<mlir::Value> shiftOpers; 168 if (auto shapeVal = rebox.getShape()) { 169 if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp())) 170 populateShape(shapeOpers, shapeOp); 171 else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp())) 172 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 173 else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp())) 174 populateShift(shiftOpers, shiftOp); 175 else 176 return mlir::failure(); 177 } 178 llvm::SmallVector<mlir::Value> sliceOpers; 179 llvm::SmallVector<mlir::Value> subcompOpers; 180 llvm::SmallVector<mlir::Value> substrOpers; 181 if (auto s = rebox.getSlice()) 182 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 183 sliceOpers.append(sliceOp.getTriples().begin(), 184 sliceOp.getTriples().end()); 185 subcompOpers.append(sliceOp.getFields().begin(), 186 sliceOp.getFields().end()); 187 substrOpers.append(sliceOp.getSubstr().begin(), 188 sliceOp.getSubstr().end()); 189 } 190 191 auto xRebox = rewriter.create<cg::XReboxOp>( 192 loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers, 193 sliceOpers, subcompOpers, substrOpers); 194 LLVM_DEBUG(llvm::dbgs() 195 << "rewriting " << rebox << " to " << xRebox << '\n'); 196 rewriter.replaceOp(rebox, xRebox.getOperation()->getResults()); 197 return mlir::success(); 198 } 199 }; 200 201 /// Convert all fir.array_coor to the extended form. 202 /// 203 /// For example, 204 /// ``` 205 /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>, 206 /// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32> 207 /// ``` 208 /// converted to 209 /// ``` 210 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : 211 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) -> 212 /// !fir.ref<i32> 213 /// ``` 214 class ArrayCoorConversion : public mlir::OpRewritePattern<ArrayCoorOp> { 215 public: 216 using OpRewritePattern::OpRewritePattern; 217 218 mlir::LogicalResult 219 matchAndRewrite(ArrayCoorOp arrCoor, 220 mlir::PatternRewriter &rewriter) const override { 221 auto loc = arrCoor.getLoc(); 222 llvm::SmallVector<mlir::Value> shapeOpers; 223 llvm::SmallVector<mlir::Value> shiftOpers; 224 if (auto shapeVal = arrCoor.getShape()) { 225 if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp())) 226 populateShape(shapeOpers, shapeOp); 227 else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp())) 228 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); 229 else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp())) 230 populateShift(shiftOpers, shiftOp); 231 else 232 return mlir::failure(); 233 } 234 llvm::SmallVector<mlir::Value> sliceOpers; 235 llvm::SmallVector<mlir::Value> subcompOpers; 236 if (auto s = arrCoor.getSlice()) 237 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) { 238 sliceOpers.append(sliceOp.getTriples().begin(), 239 sliceOp.getTriples().end()); 240 subcompOpers.append(sliceOp.getFields().begin(), 241 sliceOp.getFields().end()); 242 assert(sliceOp.getSubstr().empty() && 243 "Don't allow substring operations on array_coor. This " 244 "restriction may be lifted in the future."); 245 } 246 auto xArrCoor = rewriter.create<cg::XArrayCoorOp>( 247 loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers, 248 sliceOpers, subcompOpers, arrCoor.getIndices(), 249 arrCoor.getTypeparams()); 250 LLVM_DEBUG(llvm::dbgs() 251 << "rewriting " << arrCoor << " to " << xArrCoor << '\n'); 252 rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults()); 253 return mlir::success(); 254 } 255 }; 256 257 class CodeGenRewrite : public CodeGenRewriteBase<CodeGenRewrite> { 258 public: 259 void runOnOperation() override final { 260 auto op = getOperation(); 261 auto &context = getContext(); 262 mlir::OpBuilder rewriter(&context); 263 mlir::ConversionTarget target(context); 264 target.addLegalDialect<mlir::arith::ArithmeticDialect, FIROpsDialect, 265 FIRCodeGenDialect, mlir::func::FuncDialect>(); 266 target.addIllegalOp<ArrayCoorOp>(); 267 target.addIllegalOp<ReboxOp>(); 268 target.addDynamicallyLegalOp<EmboxOp>([](EmboxOp embox) { 269 return !(embox.getShape() || 270 embox.getType().cast<BoxType>().getEleTy().isa<SequenceType>()); 271 }); 272 mlir::RewritePatternSet patterns(&context); 273 patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>( 274 &context); 275 if (mlir::failed( 276 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 277 mlir::emitError(mlir::UnknownLoc::get(&context), 278 "error in running the pre-codegen conversions"); 279 signalPassFailure(); 280 } 281 } 282 }; 283 284 } // namespace 285 286 std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() { 287 return std::make_unique<CodeGenRewrite>(); 288 } 289