17a7eacc7SStephan Herhut //===- ParallelLoopMapper.cpp - Utilities for mapping parallel loops to GPU =//
27a7eacc7SStephan Herhut //
37a7eacc7SStephan Herhut // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47a7eacc7SStephan Herhut // See https://llvm.org/LICENSE.txt for license information.
57a7eacc7SStephan Herhut // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67a7eacc7SStephan Herhut //
77a7eacc7SStephan Herhut //===----------------------------------------------------------------------===//
87a7eacc7SStephan Herhut //
97a7eacc7SStephan Herhut // This file implements utilities to generate mappings for parallel loops to
107a7eacc7SStephan Herhut // GPU devices.
117a7eacc7SStephan Herhut //
127a7eacc7SStephan Herhut //===----------------------------------------------------------------------===//
137a7eacc7SStephan Herhut 
14d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/ParallelLoopMapper.h"
157a7eacc7SStephan Herhut 
16bcf3d524SChristian Sigg #include "PassDetail.h"
17d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
18d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/Passes.h"
19*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
207a7eacc7SStephan Herhut #include "mlir/IR/AffineMap.h"
21bcf3d524SChristian Sigg 
2246bb6613SMaheshRavishankar namespace mlir {
2346bb6613SMaheshRavishankar 
24bcf3d524SChristian Sigg using scf::ParallelOp;
2546bb6613SMaheshRavishankar 
getMappingAttrName()26bcf3d524SChristian Sigg StringRef gpu::getMappingAttrName() { return "mapping"; }
27bcf3d524SChristian Sigg 
287bdd3722SMogball LogicalResult
setMappingAttr(ParallelOp ploopOp,ArrayRef<ParallelLoopDimMappingAttr> mapping)297bdd3722SMogball gpu::setMappingAttr(ParallelOp ploopOp,
307bdd3722SMogball                     ArrayRef<ParallelLoopDimMappingAttr> mapping) {
3146bb6613SMaheshRavishankar   // Verify that each processor is mapped to only once.
3246bb6613SMaheshRavishankar   llvm::DenseSet<gpu::Processor> specifiedMappings;
3346bb6613SMaheshRavishankar   for (auto dimAttr : mapping) {
347bdd3722SMogball     gpu::Processor processor = dimAttr.getProcessor();
3546bb6613SMaheshRavishankar     if (processor != gpu::Processor::Sequential &&
3646bb6613SMaheshRavishankar         specifiedMappings.count(processor))
3746bb6613SMaheshRavishankar       return ploopOp.emitError(
3846bb6613SMaheshRavishankar           "invalid mapping multiple loops to same processor");
3946bb6613SMaheshRavishankar   }
4046bb6613SMaheshRavishankar   ArrayRef<Attribute> mappingAsAttrs(mapping.data(), mapping.size());
411ffc1aaaSChristian Sigg   ploopOp->setAttr(getMappingAttrName(),
42c2c83e97STres Popp                    ArrayAttr::get(ploopOp.getContext(), mappingAsAttrs));
4346bb6613SMaheshRavishankar   return success();
4446bb6613SMaheshRavishankar }
4546bb6613SMaheshRavishankar 
46bcf3d524SChristian Sigg namespace gpu {
477a7eacc7SStephan Herhut namespace {
487a7eacc7SStephan Herhut enum MappingLevel { MapGrid = 0, MapBlock = 1, Sequential = 2 };
49bcf3d524SChristian Sigg } // namespace
507a7eacc7SStephan Herhut 
517a7eacc7SStephan Herhut static constexpr int kNumHardwareIds = 3;
527a7eacc7SStephan Herhut 
537a7eacc7SStephan Herhut /// Bounded increment on MappingLevel. Increments to the next
547a7eacc7SStephan Herhut /// level unless Sequential was already reached.
operator ++(MappingLevel & mappingLevel)55bcf3d524SChristian Sigg static MappingLevel &operator++(MappingLevel &mappingLevel) {
567a7eacc7SStephan Herhut   if (mappingLevel < Sequential) {
577a7eacc7SStephan Herhut     mappingLevel = static_cast<MappingLevel>(mappingLevel + 1);
587a7eacc7SStephan Herhut   }
597a7eacc7SStephan Herhut   return mappingLevel;
607a7eacc7SStephan Herhut }
617a7eacc7SStephan Herhut 
627a7eacc7SStephan Herhut /// Computed the hardware id to use for a given mapping level. Will
637a7eacc7SStephan Herhut /// assign x,y and z hardware ids for the first 3 dimensions and use
647a7eacc7SStephan Herhut /// sequential after.
659db53a18SRiver Riddle /// TODO: Make this use x for the inner-most loop that is
6646bb6613SMaheshRavishankar /// distributed to map to x, the next innermost to y and the next innermost to
6746bb6613SMaheshRavishankar /// z.
getHardwareIdForMapping(MappingLevel level,int dimension)68bcf3d524SChristian Sigg static Processor getHardwareIdForMapping(MappingLevel level, int dimension) {
6946bb6613SMaheshRavishankar 
707a7eacc7SStephan Herhut   if (dimension >= kNumHardwareIds || level == Sequential)
7146bb6613SMaheshRavishankar     return Processor::Sequential;
7246bb6613SMaheshRavishankar   switch (level) {
7346bb6613SMaheshRavishankar   case MapGrid:
7446bb6613SMaheshRavishankar     switch (dimension) {
7546bb6613SMaheshRavishankar     case 0:
7646bb6613SMaheshRavishankar       return Processor::BlockX;
7746bb6613SMaheshRavishankar     case 1:
7846bb6613SMaheshRavishankar       return Processor::BlockY;
7946bb6613SMaheshRavishankar     case 2:
8046bb6613SMaheshRavishankar       return Processor::BlockZ;
8146bb6613SMaheshRavishankar     default:
8246bb6613SMaheshRavishankar       return Processor::Sequential;
8346bb6613SMaheshRavishankar     }
8446bb6613SMaheshRavishankar     break;
8546bb6613SMaheshRavishankar   case MapBlock:
8646bb6613SMaheshRavishankar     switch (dimension) {
8746bb6613SMaheshRavishankar     case 0:
8846bb6613SMaheshRavishankar       return Processor::ThreadX;
8946bb6613SMaheshRavishankar     case 1:
9046bb6613SMaheshRavishankar       return Processor::ThreadY;
9146bb6613SMaheshRavishankar     case 2:
9246bb6613SMaheshRavishankar       return Processor::ThreadZ;
9346bb6613SMaheshRavishankar     default:
9446bb6613SMaheshRavishankar       return Processor::Sequential;
9546bb6613SMaheshRavishankar     }
9646bb6613SMaheshRavishankar   default:;
9746bb6613SMaheshRavishankar   }
9846bb6613SMaheshRavishankar   return Processor::Sequential;
997a7eacc7SStephan Herhut }
1007a7eacc7SStephan Herhut 
1017a7eacc7SStephan Herhut /// Add mapping information to the given parallel loop. Do not add
1027a7eacc7SStephan Herhut /// mapping information if the loop already has it. Also, don't
1037a7eacc7SStephan Herhut /// start a mapping at a nested loop.
mapParallelOp(ParallelOp parallelOp,MappingLevel mappingLevel=MapGrid)1047a7eacc7SStephan Herhut static void mapParallelOp(ParallelOp parallelOp,
1057a7eacc7SStephan Herhut                           MappingLevel mappingLevel = MapGrid) {
1067a7eacc7SStephan Herhut   // Do not try to add a mapping to already mapped loops or nested loops.
1071ffc1aaaSChristian Sigg   if (parallelOp->getAttr(getMappingAttrName()) ||
1080bf4a82aSChristian Sigg       ((mappingLevel == MapGrid) && parallelOp->getParentOfType<ParallelOp>()))
1097a7eacc7SStephan Herhut     return;
1107a7eacc7SStephan Herhut 
1117a7eacc7SStephan Herhut   MLIRContext *ctx = parallelOp.getContext();
1127a7eacc7SStephan Herhut   Builder b(ctx);
1137bdd3722SMogball   SmallVector<ParallelLoopDimMappingAttr, 4> attrs;
114c2d03e4eSAlexander Belyaev   attrs.reserve(parallelOp.getNumLoops());
115c2d03e4eSAlexander Belyaev   for (int i = 0, e = parallelOp.getNumLoops(); i < e; ++i) {
1167bdd3722SMogball     attrs.push_back(b.getAttr<ParallelLoopDimMappingAttr>(
11746bb6613SMaheshRavishankar         getHardwareIdForMapping(mappingLevel, i), b.getDimIdentityMap(),
11846bb6613SMaheshRavishankar         b.getDimIdentityMap()));
1197a7eacc7SStephan Herhut   }
120e21adfa3SRiver Riddle   (void)setMappingAttr(parallelOp, attrs);
1217a7eacc7SStephan Herhut   ++mappingLevel;
1227a7eacc7SStephan Herhut   // Parallel loop operations are immediately nested, so do not use
1237a7eacc7SStephan Herhut   // walk but just iterate over the operations.
1247a7eacc7SStephan Herhut   for (Operation &op : *parallelOp.getBody()) {
1257a7eacc7SStephan Herhut     if (ParallelOp nested = dyn_cast<ParallelOp>(op))
1267a7eacc7SStephan Herhut       mapParallelOp(nested, mappingLevel);
1277a7eacc7SStephan Herhut   }
1287a7eacc7SStephan Herhut }
1297a7eacc7SStephan Herhut 
130bcf3d524SChristian Sigg namespace {
131bcf3d524SChristian Sigg struct GpuMapParallelLoopsPass
132bcf3d524SChristian Sigg     : public GpuMapParallelLoopsPassBase<GpuMapParallelLoopsPass> {
runOnOperationmlir::gpu::__anond8c028e30211::GpuMapParallelLoopsPass133bcf3d524SChristian Sigg   void runOnOperation() override {
134bcf3d524SChristian Sigg     for (Region &region : getOperation()->getRegions()) {
1357a7eacc7SStephan Herhut       region.walk([](ParallelOp parallelOp) { mapParallelOp(parallelOp); });
1367a7eacc7SStephan Herhut     }
137bcf3d524SChristian Sigg   }
138bcf3d524SChristian Sigg };
139bcf3d524SChristian Sigg 
140bcf3d524SChristian Sigg } // namespace
141bcf3d524SChristian Sigg } // namespace gpu
142bcf3d524SChristian Sigg } // namespace mlir
143bcf3d524SChristian Sigg 
144bcf3d524SChristian Sigg std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
createGpuMapParallelLoopsPass()145bcf3d524SChristian Sigg mlir::createGpuMapParallelLoopsPass() {
146bcf3d524SChristian Sigg   return std::make_unique<gpu::GpuMapParallelLoopsPass>();
147bcf3d524SChristian Sigg }
148