Skip to content

Rust: Fix type inference for explicit dereference with * to the Deref trait #19820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ module Impl {
*/
abstract class Call extends ExprImpl::Expr {
/** Holds if the receiver of this call is implicitly borrowed. */
predicate receiverImplicitlyBorrowed() { this.implicitBorrowAt(TSelfArgumentPosition()) }
predicate receiverImplicitlyBorrowed() { this.implicitBorrowAt(TSelfArgumentPosition(), _) }

/** Gets the trait targeted by this call, if any. */
abstract Trait getTrait();
Expand All @@ -47,7 +47,7 @@ module Impl {
abstract Expr getArgument(ArgumentPosition pos);

/** Holds if the argument at `pos` might be implicitly borrowed. */
abstract predicate implicitBorrowAt(ArgumentPosition pos);
abstract predicate implicitBorrowAt(ArgumentPosition pos, boolean certain);

/** Gets the number of arguments _excluding_ any `self` argument. */
int getNumberOfArguments() { result = count(this.getArgument(TPositionalArgumentPosition(_))) }
Expand Down Expand Up @@ -85,7 +85,7 @@ module Impl {

override Trait getTrait() { none() }

override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) { none() }

override Expr getArgument(ArgumentPosition pos) {
result = super.getArgList().getArg(pos.asPosition())
Expand All @@ -109,7 +109,7 @@ module Impl {
qualifier.toString() != "Self"
}

override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) { none() }

override Expr getArgument(ArgumentPosition pos) {
pos.isSelf() and result = super.getArgList().getArg(0)
Expand All @@ -123,7 +123,9 @@ module Impl {

override Trait getTrait() { none() }

override predicate implicitBorrowAt(ArgumentPosition pos) { pos.isSelf() }
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) {
pos.isSelf() and certain = false
}

override Expr getArgument(ArgumentPosition pos) {
pos.isSelf() and result = this.(MethodCallExpr).getReceiver()
Expand All @@ -143,10 +145,13 @@ module Impl {

override Trait getTrait() { result = trait }

override predicate implicitBorrowAt(ArgumentPosition pos) {
pos.isSelf() and borrows >= 1
or
pos.asPosition() = 0 and borrows = 2
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) {
(
pos.isSelf() and borrows >= 1
or
pos.asPosition() = 0 and borrows = 2
) and
certain = true
}

override Expr getArgument(ArgumentPosition pos) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ private predicate isOverloaded(string op, int arity, string path, string method,
op = "!" and path = "core::ops::bit::Not" and method = "not" and borrows = 0
or
// Dereference
op = "*" and path = "core::ops::deref::Deref" and method = "deref" and borrows = 0
op = "*" and path = "core::ops::deref::Deref" and method = "deref" and borrows = 1
)
or
arity = 2 and
Expand Down
187 changes: 105 additions & 82 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
prefix1.isEmpty() and
prefix2 = TypePath::singleton(TRefTypeParameter())
or
n1 = n2.(DerefExpr).getExpr() and
prefix1 = TypePath::singleton(TRefTypeParameter()) and
prefix2.isEmpty()
or
exists(BlockExpr be |
n1 = be and
n2 = be.getStmtList().getTailExpr() and
Expand Down Expand Up @@ -640,20 +636,20 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}

private newtype TAccessPosition =
TArgumentAccessPosition(ArgumentPosition pos, Boolean borrowed) or
TArgumentAccessPosition(ArgumentPosition pos, Boolean borrowed, Boolean certain) or
TReturnAccessPosition()

class AccessPosition extends TAccessPosition {
ArgumentPosition getArgumentPosition() { this = TArgumentAccessPosition(result, _) }
ArgumentPosition getArgumentPosition() { this = TArgumentAccessPosition(result, _, _) }

predicate isBorrowed() { this = TArgumentAccessPosition(_, true) }
predicate isBorrowed(boolean certain) { this = TArgumentAccessPosition(_, true, certain) }

predicate isReturn() { this = TReturnAccessPosition() }

string toString() {
exists(ArgumentPosition pos, boolean borrowed |
this = TArgumentAccessPosition(pos, borrowed) and
result = pos + ":" + borrowed
exists(ArgumentPosition pos, boolean borrowed, boolean certain |
this = TArgumentAccessPosition(pos, borrowed, certain) and
result = pos + ":" + borrowed + ":" + certain
)
or
this.isReturn() and
Expand All @@ -674,10 +670,15 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}

AstNode getNodeAt(AccessPosition apos) {
exists(ArgumentPosition pos, boolean borrowed |
apos = TArgumentAccessPosition(pos, borrowed) and
result = this.getArgument(pos) and
if this.implicitBorrowAt(pos) then borrowed = true else borrowed = false
exists(ArgumentPosition pos, boolean borrowed, boolean certain |
apos = TArgumentAccessPosition(pos, borrowed, certain) and
result = this.getArgument(pos)
|
if this.implicitBorrowAt(pos, _)
then borrowed = true and this.implicitBorrowAt(pos, certain)
else (
borrowed = false and certain = true
)
)
or
result = this and apos.isReturn()
Expand Down Expand Up @@ -705,51 +706,54 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
predicate adjustAccessType(
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
) {
if apos.isBorrowed()
then
exists(Type selfParamType |
selfParamType =
target
.getParameterType(TArgumentDeclarationPosition(apos.getArgumentPosition()),
TypePath::nil())
|
if selfParamType = TRefType()
apos.isBorrowed(true) and
pathAdj = TypePath::cons(TRefTypeParameter(), path) and
tAdj = t
or
apos.isBorrowed(false) and
exists(Type selfParamType |
selfParamType =
target
.getParameterType(TArgumentDeclarationPosition(apos.getArgumentPosition()),
TypePath::nil())
|
if selfParamType = TRefType()
then
if t != TRefType() and path.isEmpty()
then
if t != TRefType() and path.isEmpty()
// adjust for implicit borrow
pathAdj.isEmpty() and
tAdj = TRefType()
or
// adjust for implicit borrow
pathAdj = TypePath::singleton(TRefTypeParameter()) and
tAdj = t
else
if path.isCons(TRefTypeParameter(), _)
then
pathAdj = path and
tAdj = t
else (
// adjust for implicit borrow
pathAdj.isEmpty() and
tAdj = TRefType()
or
// adjust for implicit borrow
pathAdj = TypePath::singleton(TRefTypeParameter()) and
not (t = TRefType() and path.isEmpty()) and
pathAdj = TypePath::cons(TRefTypeParameter(), path) and
tAdj = t
else
if path.isCons(TRefTypeParameter(), _)
then
pathAdj = path and
tAdj = t
else (
// adjust for implicit borrow
not (t = TRefType() and path.isEmpty()) and
pathAdj = TypePath::cons(TRefTypeParameter(), path) and
tAdj = t
)
else (
// adjust for implicit deref
path.isCons(TRefTypeParameter(), pathAdj) and
tAdj = t
or
not path.isCons(TRefTypeParameter(), _) and
not (t = TRefType() and path.isEmpty()) and
pathAdj = path and
tAdj = t
)
)
else (
// adjust for implicit deref
path.isCons(TRefTypeParameter(), pathAdj) and
tAdj = t
or
not path.isCons(TRefTypeParameter(), _) and
not (t = TRefType() and path.isEmpty()) and
pathAdj = path and
tAdj = t
)
else (
pathAdj = path and
tAdj = t
)
or
not apos.isBorrowed(_) and
pathAdj = path and
tAdj = t
}
}

Expand All @@ -766,35 +770,47 @@ private Type inferCallExprBaseType(AstNode n, TypePath path) {
TypePath path0
|
n = a.getNodeAt(apos) and
result = CallExprBaseMatching::inferAccessType(a, apos, path0) and
if apos.isBorrowed()
then
exists(Type argType | argType = inferType(n) |
if argType = TRefType()
then
path = path0 and
path0.isCons(TRefTypeParameter(), _)
or
// adjust for implicit deref
result = CallExprBaseMatching::inferAccessType(a, apos, path0)
|
(
apos.isBorrowed(true)
or
// The desugaring of the unary `*e` is `*Deref::deref(&e)`. To handle the
// deref expression after the call we must strip a `&` from the type at
// the return position.
apos.isReturn() and a instanceof DerefExpr
) and
path0.isCons(TRefTypeParameter(), path)
or
apos.isBorrowed(false) and
exists(Type argType | argType = inferType(n) |
if argType = TRefType()
then
path = path0 and
path0.isCons(TRefTypeParameter(), _)
or
// adjust for implicit deref
not path0.isCons(TRefTypeParameter(), _) and
not (path0.isEmpty() and result = TRefType()) and
path = TypePath::cons(TRefTypeParameter(), path0)
else (
not (
argType.(StructType).asItemNode() instanceof StringStruct and
result.(StructType).asItemNode() instanceof Builtins::Str
) and
(
not path0.isCons(TRefTypeParameter(), _) and
not (path0.isEmpty() and result = TRefType()) and
path = TypePath::cons(TRefTypeParameter(), path0)
else (
not (
argType.(StructType).asItemNode() instanceof StringStruct and
result.(StructType).asItemNode() instanceof Builtins::Str
) and
(
not path0.isCons(TRefTypeParameter(), _) and
not (path0.isEmpty() and result = TRefType()) and
path = path0
or
// adjust for implicit borrow
path0.isCons(TRefTypeParameter(), path)
)
path = path0
or
// adjust for implicit borrow
path0.isCons(TRefTypeParameter(), path)
)
)
else path = path0
)
or
not apos.isBorrowed(_) and
path = path0
)
}

Expand Down Expand Up @@ -1141,8 +1157,15 @@ final class MethodCall extends Call {
(
path0.isCons(TRefTypeParameter(), path)
or
not path0.isCons(TRefTypeParameter(), _) and
not (path0.isEmpty() and result = TRefType()) and
(
not path0.isCons(TRefTypeParameter(), _) and
not (path0.isEmpty() and result = TRefType())
or
// Ideally we should find all methods on reference types, but as
// that currently causes a blowup we limit this to the `deref`
// method in order to make dereferencing work.
this.getMethodName() = "deref"
) and
path = path0
)
|
Expand Down Expand Up @@ -1389,7 +1412,7 @@ private module Cached {
predicate receiverHasImplicitDeref(AstNode receiver) {
exists(CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos |
apos.getArgumentPosition().isSelf() and
apos.isBorrowed() and
apos.isBorrowed(_) and
receiver = a.getNodeAt(apos) and
inferType(receiver) = TRefType() and
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil()) != TRefType()
Expand All @@ -1401,7 +1424,7 @@ private module Cached {
predicate receiverHasImplicitBorrow(AstNode receiver) {
exists(CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos |
apos.getArgumentPosition().isSelf() and
apos.isBorrowed() and
apos.isBorrowed(_) and
receiver = a.getNodeAt(apos) and
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil()) = TRefType() and
inferType(receiver) != TRefType()
Expand Down
2 changes: 1 addition & 1 deletion rust/ql/test/library-tests/dataflow/global/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ fn test_operator_overloading() {

let a = MyInt { value: source(28) };
let c = *a;
sink(c); // $ MISSING: hasValueFlow=28
sink(c); // $ hasTaintFlow=28 MISSING: hasValueFlow=28
}

trait MyTrait {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
| main.rs:165:13:165:34 | ...::new(...) | main.rs:158:5:161:5 | fn new |
| main.rs:165:24:165:33 | source(...) | main.rs:1:1:3:1 | fn source |
| main.rs:167:5:167:11 | sink(...) | main.rs:5:1:7:1 | fn sink |
| main.rs:181:10:181:14 | * ... | main.rs:188:5:190:5 | fn deref |
| main.rs:189:11:189:15 | * ... | main.rs:188:5:190:5 | fn deref |
| main.rs:195:28:195:36 | source(...) | main.rs:1:1:3:1 | fn source |
| main.rs:197:13:197:17 | ... + ... | main.rs:173:5:176:5 | fn add |
| main.rs:198:5:198:17 | sink(...) | main.rs:5:1:7:1 | fn sink |
Expand Down
Loading