Skip to content

Commit ade882a

Browse files
committed
basic equal operator logic
Signed-off-by: czechbol <[email protected]>
1 parent 8b91515 commit ade882a

File tree

2 files changed

+108
-18
lines changed

2 files changed

+108
-18
lines changed

analyzers/conversion_overflow.go

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"math"
2121
"regexp"
2222
"strconv"
23+
"strings"
2324

2425
"golang.org/x/exp/constraints"
2526
"golang.org/x/tools/go/analysis"
@@ -37,16 +38,20 @@ type integer struct {
3738
}
3839

3940
type rangeResult struct {
40-
MinValue int
41-
MaxValue uint
42-
IsRangeCheck bool
43-
ConvertFound bool
41+
MinValue int
42+
MaxValue uint
43+
ExplixitPositiveVals []uint
44+
ExplicitNegativeVals []int
45+
IsRangeCheck bool
46+
ConvertFound bool
4447
}
4548

46-
type branchBounds struct {
47-
MinValue *int
48-
MaxValue *uint
49-
ConvertFound bool
49+
type branchResults struct {
50+
MinValue *int
51+
MaxValue *uint
52+
ExplixitPositiveVals []uint
53+
ExplicitNegativeVals []int
54+
ConvertFound bool
5055
}
5156

5257
func newConversionOverflowAnalyzer(id string, description string) *analysis.Analyzer {
@@ -151,11 +156,6 @@ func parseIntType(intType string) (integer, error) {
151156
min = -1 << (intSize - 1)
152157

153158
} else {
154-
// Perform a bounds check
155-
if intSize < 0 {
156-
return integer{}, fmt.Errorf("invalid bit size: %d", intSize)
157-
}
158-
159159
max = (1 << uint(intSize)) - 1
160160
min = 0
161161
}
@@ -255,6 +255,8 @@ func hasExplicitRangeCheck(instr *ssa.Convert, dstType string) bool {
255255

256256
minValue := srcInt.min
257257
maxValue := srcInt.max
258+
explicitPositiveVals := []uint{}
259+
explicitNegativeVals := []int{}
258260

259261
if minValue > dstInt.min && maxValue < dstInt.max {
260262
return true
@@ -269,6 +271,8 @@ func hasExplicitRangeCheck(instr *ssa.Convert, dstType string) bool {
269271
if result.IsRangeCheck {
270272
minValue = max(minValue, &result.MinValue)
271273
maxValue = min(maxValue, &result.MaxValue)
274+
explicitPositiveVals = append(explicitPositiveVals, result.ExplixitPositiveVals...)
275+
explicitNegativeVals = append(explicitNegativeVals, result.ExplicitNegativeVals...)
272276
}
273277
case *ssa.Call:
274278
// these function return an int of a guaranteed size
@@ -287,7 +291,9 @@ func hasExplicitRangeCheck(instr *ssa.Convert, dstType string) bool {
287291
}
288292
}
289293

290-
if minValue >= dstInt.min && maxValue <= dstInt.max {
294+
if explicitValsInRange(explicitPositiveVals, explicitNegativeVals, dstInt) {
295+
return true
296+
} else if minValue >= dstInt.min && maxValue <= dstInt.max {
291297
return true
292298
}
293299
}
@@ -350,6 +356,29 @@ func updateResultFromBinOp(result *rangeResult, binOp *ssa.BinOp, instr *ssa.Con
350356
updateMinMaxForLessOrEqual(result, constVal, binOp.Op, operandsFlipped, successPathConvert)
351357
case token.GEQ, token.GTR:
352358
updateMinMaxForGreaterOrEqual(result, constVal, binOp.Op, operandsFlipped, successPathConvert)
359+
case token.EQL:
360+
if !successPathConvert {
361+
break
362+
}
363+
364+
// determine if the constant value is positive or negative
365+
if strings.Contains(constVal.String(), "-") {
366+
result.ExplicitNegativeVals = append(result.ExplicitNegativeVals, int(constVal.Int64()))
367+
} else {
368+
result.ExplixitPositiveVals = append(result.ExplixitPositiveVals, uint(constVal.Uint64()))
369+
}
370+
371+
case token.NEQ:
372+
if successPathConvert {
373+
break
374+
}
375+
376+
// determine if the constant value is positive or negative
377+
if strings.Contains(constVal.String(), "-") {
378+
result.ExplicitNegativeVals = append(result.ExplicitNegativeVals, int(constVal.Int64()))
379+
} else {
380+
result.ExplixitPositiveVals = append(result.ExplixitPositiveVals, uint(constVal.Uint64()))
381+
}
353382
}
354383
}
355384

@@ -384,8 +413,8 @@ func updateMinMaxForGreaterOrEqual(result *rangeResult, constVal *ssa.Const, op
384413
}
385414

386415
// walkBranchForConvert walks the branch of the if statement to find the range of the variable and where the conversion is
387-
func walkBranchForConvert(block *ssa.BasicBlock, instr *ssa.Convert, visitedIfs map[*ssa.If]bool) branchBounds {
388-
bounds := branchBounds{}
416+
func walkBranchForConvert(block *ssa.BasicBlock, instr *ssa.Convert, visitedIfs map[*ssa.If]bool) branchResults {
417+
bounds := branchResults{}
389418

390419
for _, blockInstr := range block.Instrs {
391420
switch v := blockInstr.(type) {
@@ -417,12 +446,35 @@ func walkBranchForConvert(block *ssa.BasicBlock, instr *ssa.Convert, visitedIfs
417446
func isRangeCheck(v ssa.Value, x ssa.Value) bool {
418447
switch op := v.(type) {
419448
case *ssa.BinOp:
420-
return (op.X == x || op.Y == x) &&
421-
(op.Op == token.LSS || op.Op == token.LEQ || op.Op == token.GTR || op.Op == token.GEQ)
449+
switch op.Op {
450+
case token.LSS, token.LEQ, token.GTR, token.GEQ,
451+
token.EQL, token.NEQ:
452+
return op.X == x || op.Y == x
453+
}
422454
}
423455
return false
424456
}
425457

458+
func explicitValsInRange(explicitPosVals []uint, explicitNegVals []int, dstInt integer) bool {
459+
if len(explicitPosVals) == 0 && len(explicitNegVals) == 0 {
460+
return false
461+
}
462+
463+
for _, val := range explicitPosVals {
464+
if val > dstInt.max {
465+
return false
466+
}
467+
}
468+
469+
for _, val := range explicitNegVals {
470+
if val < dstInt.min {
471+
return false
472+
}
473+
}
474+
475+
return true
476+
}
477+
426478
func min[T constraints.Integer](a T, b *T) T {
427479
if b == nil {
428480
return a

testutils/g115_samples.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,44 @@ func main() {
643643
fmt.Printf("%d\n", b)
644644
}
645645
panic("out of range")
646+
}
647+
`,
648+
}, 0, gosec.NewConfig()},
649+
{[]string{
650+
`
651+
package main
652+
653+
import (
654+
"fmt"
655+
"math/rand"
656+
)
657+
658+
func main() {
659+
a := rand.Int63()
660+
if a == 3 || a == 4 {
661+
b := int32(a)
662+
fmt.Printf("%d\n", b)
663+
}
664+
panic("out of range")
665+
}
666+
`,
667+
}, 0, gosec.NewConfig()},
668+
{[]string{
669+
`
670+
package main
671+
672+
import (
673+
"fmt"
674+
"math/rand"
675+
)
676+
677+
func main() {
678+
a := rand.Int63()
679+
if a != 3 || a != 4 {
680+
panic("out of range")
681+
}
682+
b := int32(a)
683+
fmt.Printf("%d\n", b)
646684
}
647685
`,
648686
}, 0, gosec.NewConfig()},

0 commit comments

Comments
 (0)