@@ -300,83 +300,20 @@ bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
300300 return replaceDynamicExtractElementInst (EEI);
301301}
302302
303- static void buildConstExprGEPChain (GetElementPtrInst &GEPI, Value *CurrentPtr,
304- SmallVector<ConstantExpr *, 4 > &GEPChain,
305- IRBuilder<> &Builder) {
306- // Process the rest of the chain in reverse order (skipping the innermost)
307- for (int I = GEPChain.size () - 2 ; I >= 0 ; I--) {
308- ConstantExpr *CE = GEPChain[I];
309- GetElementPtrInst *GEPInst =
310- cast<GetElementPtrInst>(CE->getAsInstruction ());
311- GEPInst->insertBefore (GEPI.getIterator ());
312- SmallVector<Value *, MaxVecSize> CurrIndices (GEPInst->indices ());
313-
314- // Create a new GEP instruction
315- Type *SourceTy = GEPInst->getSourceElementType ();
316- CurrentPtr =
317- Builder.CreateGEP (SourceTy, CurrentPtr, CurrIndices, GEPInst->getName (),
318- GEPInst->getNoWrapFlags ());
319-
320- // If this is the outermost GEP, update the main GEPI
321- if (I == 0 ) {
322- GEPI.setOperand (GEPI.getPointerOperandIndex (), CurrentPtr);
323- }
324-
325- // Clean up the temporary instruction
326- GEPInst->eraseFromParent ();
327- }
328- }
329-
330303bool DataScalarizerVisitor::visitGetElementPtrInst (GetElementPtrInst &GEPI) {
331- Value *PtrOperand = GEPI. getPointerOperand ( );
332- Type *OrigGEPType = GEPI. getSourceElementType ();
333- Type *NewGEPType = OrigGEPType ;
304+ GEPOperator *GOp = cast<GEPOperator>(&GEPI );
305+ Value *PtrOperand = GOp-> getPointerOperand ();
306+ Type *NewGEPType = GOp-> getSourceElementType () ;
334307 bool NeedsTransform = false ;
335- // Check if the pointer operand is a ConstantExpr GEP
336- if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand);
337- PtrOpGEPCE && PtrOpGEPCE->getOpcode () == Instruction::GetElementPtr) {
338-
339- // Collect all nested GEPs in the chain
340- SmallVector<ConstantExpr *, 4 > GEPChain;
341- Value *BasePointer = PtrOpGEPCE->getOperand (0 );
342- GEPChain.push_back (PtrOpGEPCE);
343-
344- // Walk up the chain to find all nested GEPs and the base pointer
345- while (auto *NextGEP = dyn_cast<ConstantExpr>(BasePointer)) {
346- if (NextGEP->getOpcode () != Instruction::GetElementPtr)
347- break ;
348-
349- GEPChain.push_back (NextGEP);
350- BasePointer = NextGEP->getOperand (0 );
351- }
352308
353- // Check if the base pointer is a global that needs replacement
354- if (GlobalVariable *NewGlobal = lookupReplacementGlobal (BasePointer)) {
355- IRBuilder<> Builder (&GEPI);
356-
357- // Create a new GEP for the innermost GEP (last in the chain)
358- ConstantExpr *InnerGEPCE = GEPChain.back ();
359- GetElementPtrInst *InnerGEP =
360- cast<GetElementPtrInst>(InnerGEPCE->getAsInstruction ());
361- InnerGEP->insertBefore (GEPI.getIterator ());
362-
363- SmallVector<Value *, MaxVecSize> Indices (InnerGEP->indices ());
364- Type *NewGEPType = NewGlobal->getValueType ();
365- Value *NewInnerGEP =
366- Builder.CreateGEP (NewGEPType, NewGlobal, Indices, InnerGEP->getName (),
367- InnerGEP->getNoWrapFlags ());
368-
369- // If there's only one GEP in the chain, update the main GEPI directly
370- if (GEPChain.size () == 1 )
371- GEPI.setOperand (GEPI.getPointerOperandIndex (), NewInnerGEP);
372- else
373- // For multiple GEPs, we need to create a chain of GEPs
374- buildConstExprGEPChain (GEPI, NewInnerGEP, GEPChain, Builder);
375-
376- // Clean up the innermost GEP
377- InnerGEP->eraseFromParent ();
378- return true ;
379- }
309+ // Unwrap GEP ConstantExprs to find the base operand and element type
310+ while (auto *CE = dyn_cast<ConstantExpr>(PtrOperand)) {
311+ if (auto *GEPCE = dyn_cast<GEPOperator>(CE)) {
312+ GOp = GEPCE;
313+ PtrOperand = GEPCE->getPointerOperand ();
314+ NewGEPType = GEPCE->getSourceElementType ();
315+ } else
316+ break ;
380317 }
381318
382319 if (GlobalVariable *NewGlobal = lookupReplacementGlobal (PtrOperand)) {
@@ -385,30 +322,32 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
385322 NeedsTransform = true ;
386323 } else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
387324 Type *AllocatedType = Alloca->getAllocatedType ();
388- // Only transform if the allocated type is an array
389- if ( AllocatedType != OrigGEPType && isa<ArrayType>(AllocatedType )) {
325+ if (isa<ArrayType>(AllocatedType) &&
326+ AllocatedType != GOp-> getResultElementType ( )) {
390327 NewGEPType = AllocatedType;
391328 NeedsTransform = true ;
392329 }
393330 }
394331
395- // Scalar geps should remain scalars geps. The dxil-flatten-arrays pass will
396- // convert these scalar geps into flattened array geps
397- if (!isa<ArrayType>(OrigGEPType))
398- NewGEPType = OrigGEPType;
399-
400- // Note: We bail if this isn't a gep touched via alloca or global
401- // transformations
402332 if (!NeedsTransform)
403333 return false ;
404334
335+ // Keep scalar GEPs scalar; dxil-flatten-arrays will do flattening later
336+ if (!isa<ArrayType>(GOp->getSourceElementType ()))
337+ NewGEPType = GOp->getSourceElementType ();
338+
405339 IRBuilder<> Builder (&GEPI);
406340 SmallVector<Value *, MaxVecSize> Indices (GEPI.indices ());
407-
408341 Value *NewGEP = Builder.CreateGEP (NewGEPType, PtrOperand, Indices,
409- GEPI.getName (), GEPI.getNoWrapFlags ());
410- GEPI.replaceAllUsesWith (NewGEP);
411- GEPI.eraseFromParent ();
342+ GOp->getName (), GOp->getNoWrapFlags ());
343+
344+ GOp->replaceAllUsesWith (NewGEP);
345+
346+ if (auto *CE = dyn_cast<ConstantExpr>(GOp))
347+ CE->destroyConstant ();
348+ else if (auto *OldGEPI = dyn_cast<GetElementPtrInst>(GOp))
349+ OldGEPI->eraseFromParent (); // This will always be true in visit* context
350+
412351 return true ;
413352}
414353
0 commit comments