1*7a7eacc7SStephan Herhut //===- ParallelLoopMapper.cpp - Utilities for mapping parallel loops to GPU =//
2*7a7eacc7SStephan Herhut //
3*7a7eacc7SStephan Herhut // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*7a7eacc7SStephan Herhut // See https://llvm.org/LICENSE.txt for license information.
5*7a7eacc7SStephan Herhut // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*7a7eacc7SStephan Herhut //
7*7a7eacc7SStephan Herhut //===----------------------------------------------------------------------===//
8*7a7eacc7SStephan Herhut //
9*7a7eacc7SStephan Herhut // This file implements utilities to generate mappings for parallel loops to
10*7a7eacc7SStephan Herhut // GPU devices.
11*7a7eacc7SStephan Herhut //
12*7a7eacc7SStephan Herhut //===----------------------------------------------------------------------===//
13*7a7eacc7SStephan Herhut 
14*7a7eacc7SStephan Herhut #include "mlir/Dialect/GPU/ParallelLoopMapper.h"
15*7a7eacc7SStephan Herhut 
16*7a7eacc7SStephan Herhut #include "mlir/Dialect/GPU/GPUDialect.h"
17*7a7eacc7SStephan Herhut #include "mlir/Dialect/GPU/Passes.h"
18*7a7eacc7SStephan Herhut #include "mlir/Dialect/LoopOps/LoopOps.h"
19*7a7eacc7SStephan Herhut #include "mlir/IR/AffineMap.h"
20*7a7eacc7SStephan Herhut #include "mlir/Pass/Pass.h"
21*7a7eacc7SStephan Herhut 
22*7a7eacc7SStephan Herhut using namespace mlir;
23*7a7eacc7SStephan Herhut using namespace mlir::gpu;
24*7a7eacc7SStephan Herhut using namespace mlir::loop;
25*7a7eacc7SStephan Herhut 
namespace {

/// Level at which a parallel loop gets mapped. Mapping starts at MapGrid
/// for an outermost loop; every directly nested parallel loop is mapped at
/// the next level, and the level saturates at Sequential. The numeric
/// values are used in arithmetic to derive hardware ids, so they must stay
/// consecutive and start at zero.
enum MappingLevel {
  MapGrid = 0,
  MapBlock = 1,
  Sequential = 2
};

/// Number of hardware ids available per mapping level (x, y and z).
constexpr int kNumHardwareIds = 3;

} // namespace
33*7a7eacc7SStephan Herhut 
34*7a7eacc7SStephan Herhut /// Bounded increment on MappingLevel. Increments to the next
35*7a7eacc7SStephan Herhut /// level unless Sequential was already reached.
36*7a7eacc7SStephan Herhut MappingLevel &operator++(MappingLevel &mappingLevel) {
37*7a7eacc7SStephan Herhut   if (mappingLevel < Sequential) {
38*7a7eacc7SStephan Herhut     mappingLevel = static_cast<MappingLevel>(mappingLevel + 1);
39*7a7eacc7SStephan Herhut   }
40*7a7eacc7SStephan Herhut   return mappingLevel;
41*7a7eacc7SStephan Herhut }
42*7a7eacc7SStephan Herhut 
43*7a7eacc7SStephan Herhut /// Computed the hardware id to use for a given mapping level. Will
44*7a7eacc7SStephan Herhut /// assign x,y and z hardware ids for the first 3 dimensions and use
45*7a7eacc7SStephan Herhut /// sequential after.
46*7a7eacc7SStephan Herhut static int64_t getHardwareIdForMapping(MappingLevel level, int dimension) {
47*7a7eacc7SStephan Herhut   if (dimension >= kNumHardwareIds || level == Sequential)
48*7a7eacc7SStephan Herhut     return Sequential * kNumHardwareIds;
49*7a7eacc7SStephan Herhut   return (level * kNumHardwareIds) + dimension;
50*7a7eacc7SStephan Herhut }
51*7a7eacc7SStephan Herhut 
52*7a7eacc7SStephan Herhut /// Add mapping information to the given parallel loop. Do not add
53*7a7eacc7SStephan Herhut /// mapping information if the loop already has it. Also, don't
54*7a7eacc7SStephan Herhut /// start a mapping at a nested loop.
55*7a7eacc7SStephan Herhut static void mapParallelOp(ParallelOp parallelOp,
56*7a7eacc7SStephan Herhut                           MappingLevel mappingLevel = MapGrid) {
57*7a7eacc7SStephan Herhut   // Do not try to add a mapping to already mapped loops or nested loops.
58*7a7eacc7SStephan Herhut   if (parallelOp.getAttr(gpu::kMappingAttributeName) ||
59*7a7eacc7SStephan Herhut       ((mappingLevel == MapGrid) && parallelOp.getParentOfType<ParallelOp>()))
60*7a7eacc7SStephan Herhut     return;
61*7a7eacc7SStephan Herhut 
62*7a7eacc7SStephan Herhut   MLIRContext *ctx = parallelOp.getContext();
63*7a7eacc7SStephan Herhut   Builder b(ctx);
64*7a7eacc7SStephan Herhut   SmallVector<Attribute, 4> attrs;
65*7a7eacc7SStephan Herhut   attrs.reserve(parallelOp.getNumInductionVars());
66*7a7eacc7SStephan Herhut   for (int i = 0, e = parallelOp.getNumInductionVars(); i < e; ++i) {
67*7a7eacc7SStephan Herhut     SmallVector<NamedAttribute, 3> entries;
68*7a7eacc7SStephan Herhut     entries.emplace_back(b.getNamedAttr(
69*7a7eacc7SStephan Herhut         kProcessorEntryName,
70*7a7eacc7SStephan Herhut         b.getI64IntegerAttr(getHardwareIdForMapping(mappingLevel, i))));
71*7a7eacc7SStephan Herhut     entries.emplace_back(b.getNamedAttr(
72*7a7eacc7SStephan Herhut         kIndexMapEntryName, AffineMapAttr::get(b.getDimIdentityMap())));
73*7a7eacc7SStephan Herhut     entries.emplace_back(b.getNamedAttr(
74*7a7eacc7SStephan Herhut         kBoundMapEntryName, AffineMapAttr::get(b.getDimIdentityMap())));
75*7a7eacc7SStephan Herhut     attrs.push_back(DictionaryAttr::get(entries, ctx));
76*7a7eacc7SStephan Herhut   }
77*7a7eacc7SStephan Herhut   parallelOp.setAttr(kMappingAttributeName, ArrayAttr::get(attrs, ctx));
78*7a7eacc7SStephan Herhut   ++mappingLevel;
79*7a7eacc7SStephan Herhut   // Parallel loop operations are immediately nested, so do not use
80*7a7eacc7SStephan Herhut   // walk but just iterate over the operations.
81*7a7eacc7SStephan Herhut   for (Operation &op : *parallelOp.getBody()) {
82*7a7eacc7SStephan Herhut     if (ParallelOp nested = dyn_cast<ParallelOp>(op))
83*7a7eacc7SStephan Herhut       mapParallelOp(nested, mappingLevel);
84*7a7eacc7SStephan Herhut   }
85*7a7eacc7SStephan Herhut }
86*7a7eacc7SStephan Herhut 
87*7a7eacc7SStephan Herhut void mlir::greedilyMapParallelLoopsToGPU(Region &region) {
88*7a7eacc7SStephan Herhut   region.walk([](ParallelOp parallelOp) { mapParallelOp(parallelOp); });
89*7a7eacc7SStephan Herhut }
90