@@ -20,6 +20,7 @@ import (
2020 "math"
2121 "regexp"
2222 "strconv"
23+ "strings"
2324
2425 "golang.org/x/exp/constraints"
2526 "golang.org/x/tools/go/analysis"
@@ -37,16 +38,20 @@ type integer struct {
3738}
3839
3940type rangeResult struct {
40- MinValue int
41- MaxValue uint
42- IsRangeCheck bool
43- ConvertFound bool
41+ MinValue int
42+ MaxValue uint
43+ ExplixitPositiveVals []uint
44+ ExplicitNegativeVals []int
45+ IsRangeCheck bool
46+ ConvertFound bool
4447}
4548
46- type branchBounds struct {
47- MinValue * int
48- MaxValue * uint
49- ConvertFound bool
49+ type branchResults struct {
50+ MinValue * int
51+ MaxValue * uint
52+ ExplixitPositiveVals []uint
53+ ExplicitNegativeVals []int
54+ ConvertFound bool
5055}
5156
5257func newConversionOverflowAnalyzer (id string , description string ) * analysis.Analyzer {
@@ -151,11 +156,6 @@ func parseIntType(intType string) (integer, error) {
151156 min = - 1 << (intSize - 1 )
152157
153158 } else {
154- // Perform a bounds check
155- if intSize < 0 {
156- return integer {}, fmt .Errorf ("invalid bit size: %d" , intSize )
157- }
158-
159159 max = (1 << uint (intSize )) - 1
160160 min = 0
161161 }
@@ -255,6 +255,8 @@ func hasExplicitRangeCheck(instr *ssa.Convert, dstType string) bool {
255255
256256 minValue := srcInt .min
257257 maxValue := srcInt .max
258+ explicitPositiveVals := []uint {}
259+ explicitNegativeVals := []int {}
258260
259261 if minValue > dstInt .min && maxValue < dstInt .max {
260262 return true
@@ -269,6 +271,8 @@ func hasExplicitRangeCheck(instr *ssa.Convert, dstType string) bool {
269271 if result .IsRangeCheck {
270272 minValue = max (minValue , & result .MinValue )
271273 maxValue = min (maxValue , & result .MaxValue )
274+ explicitPositiveVals = append (explicitPositiveVals , result .ExplixitPositiveVals ... )
275+ explicitNegativeVals = append (explicitNegativeVals , result .ExplicitNegativeVals ... )
272276 }
273277 case * ssa.Call :
274278 // these function return an int of a guaranteed size
@@ -287,7 +291,9 @@ func hasExplicitRangeCheck(instr *ssa.Convert, dstType string) bool {
287291 }
288292 }
289293
290- if minValue >= dstInt .min && maxValue <= dstInt .max {
294+ if explicitValsInRange (explicitPositiveVals , explicitNegativeVals , dstInt ) {
295+ return true
296+ } else if minValue >= dstInt .min && maxValue <= dstInt .max {
291297 return true
292298 }
293299 }
@@ -350,6 +356,29 @@ func updateResultFromBinOp(result *rangeResult, binOp *ssa.BinOp, instr *ssa.Con
350356 updateMinMaxForLessOrEqual (result , constVal , binOp .Op , operandsFlipped , successPathConvert )
351357 case token .GEQ , token .GTR :
352358 updateMinMaxForGreaterOrEqual (result , constVal , binOp .Op , operandsFlipped , successPathConvert )
359+ case token .EQL :
360+ if ! successPathConvert {
361+ break
362+ }
363+
364+ // determine if the constant value is positive or negative
365+ if strings .Contains (constVal .String (), "-" ) {
366+ result .ExplicitNegativeVals = append (result .ExplicitNegativeVals , int (constVal .Int64 ()))
367+ } else {
368+ result .ExplixitPositiveVals = append (result .ExplixitPositiveVals , uint (constVal .Uint64 ()))
369+ }
370+
371+ case token .NEQ :
372+ if successPathConvert {
373+ break
374+ }
375+
376+ // determine if the constant value is positive or negative
377+ if strings .Contains (constVal .String (), "-" ) {
378+ result .ExplicitNegativeVals = append (result .ExplicitNegativeVals , int (constVal .Int64 ()))
379+ } else {
380+ result .ExplixitPositiveVals = append (result .ExplixitPositiveVals , uint (constVal .Uint64 ()))
381+ }
353382 }
354383}
355384
@@ -384,8 +413,8 @@ func updateMinMaxForGreaterOrEqual(result *rangeResult, constVal *ssa.Const, op
384413}
385414
386415// walkBranchForConvert walks the branch of the if statement to find the range of the variable and where the conversion is
387- func walkBranchForConvert (block * ssa.BasicBlock , instr * ssa.Convert , visitedIfs map [* ssa.If ]bool ) branchBounds {
388- bounds := branchBounds {}
416+ func walkBranchForConvert (block * ssa.BasicBlock , instr * ssa.Convert , visitedIfs map [* ssa.If ]bool ) branchResults {
417+ bounds := branchResults {}
389418
390419 for _ , blockInstr := range block .Instrs {
391420 switch v := blockInstr .(type ) {
@@ -417,12 +446,35 @@ func walkBranchForConvert(block *ssa.BasicBlock, instr *ssa.Convert, visitedIfs
417446func isRangeCheck (v ssa.Value , x ssa.Value ) bool {
418447 switch op := v .(type ) {
419448 case * ssa.BinOp :
420- return (op .X == x || op .Y == x ) &&
421- (op .Op == token .LSS || op .Op == token .LEQ || op .Op == token .GTR || op .Op == token .GEQ )
449+ switch op .Op {
450+ case token .LSS , token .LEQ , token .GTR , token .GEQ ,
451+ token .EQL , token .NEQ :
452+ return op .X == x || op .Y == x
453+ }
422454 }
423455 return false
424456}
425457
458+ func explicitValsInRange (explicitPosVals []uint , explicitNegVals []int , dstInt integer ) bool {
459+ if len (explicitPosVals ) == 0 && len (explicitNegVals ) == 0 {
460+ return false
461+ }
462+
463+ for _ , val := range explicitPosVals {
464+ if val > dstInt .max {
465+ return false
466+ }
467+ }
468+
469+ for _ , val := range explicitNegVals {
470+ if val < dstInt .min {
471+ return false
472+ }
473+ }
474+
475+ return true
476+ }
477+
426478func min [T constraints.Integer ](a T , b * T ) T {
427479 if b == nil {
428480 return a
0 commit comments