@@ -141,6 +141,20 @@ struct CufDeallocateOpConversion
141141 }
142142};
143143
144+ static bool inDeviceContext (mlir::Operation *op) {
145+ if (op->getParentOfType <cuf::KernelOp>())
146+ return true ;
147+ if (auto funcOp = op->getParentOfType <mlir::func::FuncOp>()) {
148+ if (auto cudaProcAttr =
149+ funcOp.getOperation ()->getAttrOfType <cuf::ProcAttributeAttr>(
150+ cuf::getProcAttrName ())) {
151+ return cudaProcAttr.getValue () != cuf::ProcAttribute::Host &&
152+ cudaProcAttr.getValue () != cuf::ProcAttribute::HostDevice;
153+ }
154+ }
155+ return false ;
156+ }
157+
144158struct CufAllocOpConversion : public mlir ::OpRewritePattern<cuf::AllocOp> {
145159 using OpRewritePattern::OpRewritePattern;
146160
@@ -157,6 +171,16 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
157171 if (!boxTy)
158172 return failure ();
159173
174+ if (inDeviceContext (op.getOperation ())) {
175+ // In device context just replace the cuf.alloc operation with a fir.alloc
176+ // the cuf.free will be removed.
177+ rewriter.replaceOpWithNewOp <fir::AllocaOp>(
178+ op, op.getInType (), op.getUniqName () ? *op.getUniqName () : " " ,
179+ op.getBindcName () ? *op.getBindcName () : " " , op.getTypeparams (),
180+ op.getShape ());
181+ return mlir::success ();
182+ }
183+
160184 auto mod = op->getParentOfType <mlir::ModuleOp>();
161185 fir::FirOpBuilder builder (rewriter, mod);
162186 mlir::Location loc = op.getLoc ();
@@ -200,6 +224,11 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
200224 if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy ()))
201225 return failure ();
202226
227+ if (inDeviceContext (op.getOperation ())) {
228+ rewriter.eraseOp (op);
229+ return mlir::success ();
230+ }
231+
203232 auto mod = op->getParentOfType <mlir::ModuleOp>();
204233 fir::FirOpBuilder builder (rewriter, mod);
205234 mlir::Location loc = op.getLoc ();
@@ -248,6 +277,7 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
248277 [](::cuf::AllocateOp op) { return isBoxGlobal (op); });
249278 target.addDynamicallyLegalOp <cuf::DeallocateOp>(
250279 [](::cuf::DeallocateOp op) { return isBoxGlobal (op); });
280+ target.addLegalDialect <fir::FIROpsDialect>();
251281 patterns.insert <CufAllocOpConversion>(ctx, &*dl, &typeConverter);
252282 patterns.insert <CufAllocateOpConversion, CufDeallocateOpConversion,
253283 CufFreeOpConversion>(ctx);
0 commit comments