@@ -126,18 +126,16 @@ class TrackedRegisters {
126126
127127// The security property that is checked is:
128128// When a register is used as the address to jump to in a return instruction,
129- // that register must either:
130- // (a) never be changed within this function, i.e. have the same value as when
131- // the function started, or
129+ // that register must be safe-to-dereference. It must either:
130+ // (a) be safe-to-dereference at function entry and never be changed within this
131+ // function, i.e. have the same value as when the function started, or
132132// (b) have been last written by an authentication instruction.
133133
134134// This property is checked by using dataflow analysis to keep track of which
135- // registers have been written (def-ed), since last authenticated. Those are
136- // exactly the registers containing values that should not be trusted (as they
137- // could have changed since the last time they were authenticated). For pac-ret,
138- // any return instruction using such a register is a gadget to be reported. For
139- // PAuthABI, probably at least any indirect control flow using such a register
140- // should be reported.
135+ // registers have been written (def-ed) since they were last authenticated.
136+ // For pac-ret, any return instruction using a register which is not
137+ // safe-to-dereference is a gadget to be reported. For PAuthABI, probably at
138+ // least any indirect control flow using such a register should be reported.
141139
142140// Furthermore, when producing a diagnostic for a found non-pac-ret protected
143141// return, the analysis also lists the last instructions that wrote to the
@@ -156,10 +154,29 @@ class TrackedRegisters {
156154// in the gadgets to be reported. This information is used in the second run
157155// to also track which instructions last wrote to those registers.
158156
157+ /// A state representing which registers are safe to use by an instruction
158+ /// at a given program point.
159+ ///
160+ /// To simplify reasoning, let's stick with the following approach:
161+ /// * when the state is updated by the data-flow analysis, the sub-, super- and
162+ ///   overlapping registers are marked as needed
163+ /// * when a particular instruction is checked for being a gadget, the single
164+ ///   bit of the BitVector for the register in question should be enough to
165+ ///   answer this.
165+ ///
166+ /// For example, on AArch64:
167+ /// * An AUTIZA X0 instruction marks both X0 and W0 (as well as W0_HI) as
168+ ///   safe-to-dereference. It does not change the state of X0_X1, for example,
169+ ///   as super-registers partially retain their old, unsafe values.
170+ /// * LDR X1, [X0] marks as unsafe both X1 itself and anything it overlaps
171+ ///   with: W1, W1_HI, X0_X1 and so on.
172+ /// * RET (which is implicitly RET X30) is a protected return if and only if
173+ ///   X30 is safe-to-dereference - the state computed for sub- and
174+ ///   super-registers is not inspected.
159175struct State {
160- // / A BitVector containing the registers that have been clobbered, and
161- // / not authenticated.
162- BitVector NonAutClobRegs;
176+ // / A BitVector containing the registers that are either safe at function
177+ // / entry and were not clobbered yet, or those not clobbered since being
178+ // / authenticated.
179+ BitVector SafeToDerefRegs;
163180 // / A vector of sets, only used in the second data flow run.
164181 // / Each element in the vector represents one of the registers for which we
165182 // / track the set of last instructions that wrote to this register. For
@@ -169,16 +186,26 @@ struct State {
169186 std::vector<SmallPtrSet<const MCInst *, 4 >> LastInstWritingReg;
170187 State () {}
171188 State (unsigned NumRegs, unsigned NumRegsToTrack)
172- : NonAutClobRegs(NumRegs), LastInstWritingReg(NumRegsToTrack) {}
173- State &operator |=(const State &StateIn) {
174- NonAutClobRegs |= StateIn.NonAutClobRegs ;
189+ : SafeToDerefRegs(NumRegs), LastInstWritingReg(NumRegsToTrack) {}
190+
191+ // / Returns S, so that S.merge(S1) == S1.merge(S) == S1.
192+ static State getMergeNeutralElement (unsigned NumRegs,
193+ unsigned NumRegsToTrack) {
194+ State S (NumRegs, NumRegsToTrack);
195+ S.SafeToDerefRegs .set ();
196+ return S;
197+ }
198+
199+ State &merge (const State &StateIn) {
200+ SafeToDerefRegs &= StateIn.SafeToDerefRegs ;
175201 for (unsigned I = 0 ; I < LastInstWritingReg.size (); ++I)
176202 for (const MCInst *J : StateIn.LastInstWritingReg [I])
177203 LastInstWritingReg[I].insert (J);
178204 return *this ;
179205 }
206+
180207 bool operator ==(const State &RHS) const {
181- return NonAutClobRegs == RHS.NonAutClobRegs &&
208+ return SafeToDerefRegs == RHS.SafeToDerefRegs &&
182209 LastInstWritingReg == RHS.LastInstWritingReg ;
183210 }
184211 bool operator !=(const State &RHS) const { return !((*this ) == RHS); }
@@ -199,7 +226,7 @@ static void printLastInsts(
199226
200227raw_ostream &operator <<(raw_ostream &OS, const State &S) {
201228 OS << " pacret-state<" ;
202- OS << " NonAutClobRegs : " << S.NonAutClobRegs << " , " ;
229+ OS << " SafeToDerefRegs : " << S.SafeToDerefRegs << " , " ;
203230 printLastInsts (OS, S.LastInstWritingReg );
204231 OS << " >" ;
205232 return OS;
@@ -217,8 +244,8 @@ class PacStatePrinter {
217244void PacStatePrinter::print (raw_ostream &OS, const State &S) const {
218245 RegStatePrinter RegStatePrinter (BC);
219246 OS << " pacret-state<" ;
220- OS << " NonAutClobRegs : " ;
221- RegStatePrinter.print (OS, S.NonAutClobRegs );
247+ OS << " SafeToDerefRegs : " ;
248+ RegStatePrinter.print (OS, S.SafeToDerefRegs );
222249 OS << " , " ;
223250 printLastInsts (OS, S.LastInstWritingReg );
224251 OS << " >" ;
@@ -257,12 +284,24 @@ class PacRetAnalysis
257284
258285 void preflight () {}
259286
287+ State createEntryState () {
288+ State S (NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters ());
289+ for (MCPhysReg Reg : BC.MIB ->getTrustedLiveInRegs ())
290+ S.SafeToDerefRegs |= BC.MIB ->getAliases (Reg, /* OnlySmaller=*/ true );
291+ return S;
292+ }
293+
260294 State getStartingStateAtBB (const BinaryBasicBlock &BB) {
261- return State (NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters ());
295+ if (BB.isEntryPoint ())
296+ return createEntryState ();
297+
298+ return State::getMergeNeutralElement (
299+ NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters ());
262300 }
263301
264302 State getStartingStateAtPoint (const MCInst &Point) {
265- return State (NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters ());
303+ return State::getMergeNeutralElement (
304+ NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters ());
266305 }
267306
268307 void doConfluence (State &StateOut, const State &StateIn) {
@@ -277,7 +316,7 @@ class PacRetAnalysis
277316 dbgs () << " )\n " ;
278317 });
279318
280- StateOut |= StateIn;
319+ StateOut. merge ( StateIn) ;
281320
282321 LLVM_DEBUG ({
283322 dbgs () << " merged state: " ;
@@ -298,7 +337,7 @@ class PacRetAnalysis
298337 });
299338
300339 State Next = Cur;
301- BitVector Written = BitVector (NumRegs, false );
340+ BitVector Clobbered (NumRegs, false );
302341 // Assume a call can clobber all registers, including callee-saved
303342 // registers. There's a good chance that callee-saved registers will be
304343 // saved on the stack at some point during execution of the callee.
@@ -307,36 +346,27 @@ class PacRetAnalysis
307346 // Also, not all functions may respect the AAPCS ABI rules about
308347 // caller/callee-saved registers.
309348 if (BC.MIB ->isCall (Point))
310- Written .set ();
349+ Clobbered .set ();
311350 else
312- // FIXME: `getWrittenRegs` only sets the register directly written in the
313- // instruction, and the smaller aliasing registers. It does not set the
314- // larger aliasing registers. To also set the larger aliasing registers,
315- // we'd have to call `getClobberedRegs`.
316- // It is unclear if there is any test case which shows a different
317- // behaviour between using `getWrittenRegs` vs `getClobberedRegs`. We'd
318- // first would like to see such a test case before making a decision
319- // on whether using `getClobberedRegs` below would be better.
320- // Also see the discussion on this at
321- // https://github.com/llvm/llvm-project/pull/122304#discussion_r1939511909
322- BC.MIB ->getWrittenRegs (Point, Written);
323- Next.NonAutClobRegs |= Written;
351+ BC.MIB ->getClobberedRegs (Point, Clobbered);
352+ Next.SafeToDerefRegs .reset (Clobbered);
324353 // Keep track of this instruction if it writes to any of the registers we
325354 // need to track that for:
326355 for (MCPhysReg Reg : RegsToTrackInstsFor.getRegisters ())
327- if (Written [Reg])
356+ if (Clobbered [Reg])
328357 lastWritingInsts (Next, Reg) = {&Point};
329358
330359 ErrorOr<MCPhysReg> AutReg = BC.MIB ->getAuthenticatedReg (Point);
331360 if (AutReg && *AutReg != BC.MIB ->getNoRegister ()) {
332- // FIXME: should we use `OnlySmaller=false` below? See similar
333- // FIXME about `getWrittenRegs` above and further discussion about this
334- // at
335- // https://github.com/llvm/llvm-project/pull/122304#discussion_r1939515516
336- Next.NonAutClobRegs .reset (
337- BC.MIB ->getAliases (*AutReg, /* OnlySmaller=*/ true ));
338- if (RegsToTrackInstsFor.isTracked (*AutReg))
339- lastWritingInsts (Next, *AutReg).clear ();
361+ // The sub-registers of *AutReg are also trusted now, but not its
362+ // super-registers (as they retain untrusted register units).
363+ BitVector AuthenticatedSubregs =
364+ BC.MIB ->getAliases (*AutReg, /* OnlySmaller=*/ true );
365+ for (MCPhysReg Reg : AuthenticatedSubregs.set_bits ()) {
366+ Next.SafeToDerefRegs .set (Reg);
367+ if (RegsToTrackInstsFor.isTracked (Reg))
368+ lastWritingInsts (Next, Reg).clear ();
369+ }
340370 }
341371
342372 LLVM_DEBUG ({
@@ -397,14 +427,11 @@ static std::shared_ptr<Report> tryCheckReturn(const BinaryContext &BC,
397427 });
398428 if (BC.MIB ->isAuthenticationOfReg (Inst, RetReg))
399429 return nullptr ;
400- BitVector UsedDirtyRegs = S.NonAutClobRegs ;
401- LLVM_DEBUG ({ traceRegMask (BC, " NonAutClobRegs at Ret" , UsedDirtyRegs); });
402- UsedDirtyRegs &= BC.MIB ->getAliases (RetReg, /* OnlySmaller=*/ true );
403- LLVM_DEBUG ({ traceRegMask (BC, " Intersection with RetReg" , UsedDirtyRegs); });
404- if (!UsedDirtyRegs.any ())
430+ LLVM_DEBUG ({ traceRegMask (BC, " SafeToDerefRegs" , S.SafeToDerefRegs ); });
431+ if (S.SafeToDerefRegs [RetReg])
405432 return nullptr ;
406433
407- return std::make_shared<GadgetReport>(RetKind, Inst, UsedDirtyRegs );
434+ return std::make_shared<GadgetReport>(RetKind, Inst, RetReg );
408435}
409436
410437FunctionAnalysisResult
0 commit comments