|
| 1 | +// Copyright 2024 The Go Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a BSD-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +// Package clonetest provides utility functions for testing Clone operations. |
| 6 | +// |
| 7 | +// The [NonZero] helper may be used to construct a type in which fields are |
| 8 | +// recursively set to a non-zero value. This value can then be cloned, and the |
| 9 | +// [ZeroOut] helper can set values stored in the clone to zero, recursively. |
| 10 | +// Doing so should not mutate the original. |
| 11 | +package clonetest |
| 12 | + |
| 13 | +import ( |
| 14 | + "fmt" |
| 15 | + "reflect" |
| 16 | +) |
| 17 | + |
| 18 | +// NonZero returns a T set to some appropriate nonzero value: |
| 19 | +// - Values of basic type are set to an arbitrary non-zero value. |
| 20 | +// - Struct fields are set to a non-zero value. |
| 21 | +// - Array indices are set to a non-zero value. |
| 22 | +// - Pointers point to a non-zero value. |
| 23 | +// - Maps and slices are given a non-zero element. |
| 24 | +// - Chan, Func, Interface, UnsafePointer are all unsupported. |
| 25 | +// |
| 26 | +// NonZero breaks cycles by returning a zero value for recursive types. |
| 27 | +func NonZero[T any]() T { |
| 28 | + var x T |
| 29 | + t := reflect.TypeOf(x) |
| 30 | + if t == nil { |
| 31 | + panic("untyped nil") |
| 32 | + } |
| 33 | + v := nonZeroValue(t, nil) |
| 34 | + return v.Interface().(T) |
| 35 | +} |
| 36 | + |
| 37 | +// nonZeroValue returns a non-zero, addressable value of the given type. |
| 38 | +func nonZeroValue(t reflect.Type, seen []reflect.Type) reflect.Value { |
| 39 | + for _, t2 := range seen { |
| 40 | + if t == t2 { |
| 41 | + // Cycle: return the zero value. |
| 42 | + return reflect.Zero(t) |
| 43 | + } |
| 44 | + } |
| 45 | + seen = append(seen, t) |
| 46 | + v := reflect.New(t).Elem() |
| 47 | + switch t.Kind() { |
| 48 | + case reflect.Bool: |
| 49 | + v.SetBool(true) |
| 50 | + |
| 51 | + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
| 52 | + v.SetInt(1) |
| 53 | + |
| 54 | + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: |
| 55 | + v.SetUint(1) |
| 56 | + |
| 57 | + case reflect.Float32, reflect.Float64: |
| 58 | + v.SetFloat(1) |
| 59 | + |
| 60 | + case reflect.Complex64, reflect.Complex128: |
| 61 | + v.SetComplex(1) |
| 62 | + |
| 63 | + case reflect.Array: |
| 64 | + for i := 0; i < v.Len(); i++ { |
| 65 | + v.Index(i).Set(nonZeroValue(t.Elem(), seen)) |
| 66 | + } |
| 67 | + |
| 68 | + case reflect.Map: |
| 69 | + v2 := reflect.MakeMap(t) |
| 70 | + v2.SetMapIndex(nonZeroValue(t.Key(), seen), nonZeroValue(t.Elem(), seen)) |
| 71 | + v.Set(v2) |
| 72 | + |
| 73 | + case reflect.Pointer: |
| 74 | + v2 := nonZeroValue(t.Elem(), seen) |
| 75 | + v.Set(v2.Addr()) |
| 76 | + |
| 77 | + case reflect.Slice: |
| 78 | + v2 := reflect.Append(v, nonZeroValue(t.Elem(), seen)) |
| 79 | + v.Set(v2) |
| 80 | + |
| 81 | + case reflect.String: |
| 82 | + v.SetString(".") |
| 83 | + |
| 84 | + case reflect.Struct: |
| 85 | + for i := 0; i < v.NumField(); i++ { |
| 86 | + v.Field(i).Set(nonZeroValue(t.Field(i).Type, seen)) |
| 87 | + } |
| 88 | + |
| 89 | + default: // Chan, Func, Interface, UnsafePointer |
| 90 | + panic(fmt.Sprintf("reflect kind %v not supported", t.Kind())) |
| 91 | + } |
| 92 | + return v |
| 93 | +} |
| 94 | + |
| 95 | +// ZeroOut recursively sets values contained in t to zero. |
| 96 | +// Values of king Chan, Func, Interface, UnsafePointer are all unsupported. |
| 97 | +// |
| 98 | +// No attempt is made to handle cyclic values. |
| 99 | +func ZeroOut[T any](t *T) { |
| 100 | + v := reflect.ValueOf(t).Elem() |
| 101 | + zeroOutValue(v) |
| 102 | +} |
| 103 | + |
| 104 | +func zeroOutValue(v reflect.Value) { |
| 105 | + if v.IsZero() { |
| 106 | + return // nothing to do; this also handles untyped nil values |
| 107 | + } |
| 108 | + |
| 109 | + switch v.Kind() { |
| 110 | + case reflect.Bool, |
| 111 | + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, |
| 112 | + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, |
| 113 | + reflect.Float32, reflect.Float64, |
| 114 | + reflect.Complex64, reflect.Complex128, |
| 115 | + reflect.String: |
| 116 | + |
| 117 | + v.Set(reflect.Zero(v.Type())) |
| 118 | + |
| 119 | + case reflect.Array: |
| 120 | + for i := 0; i < v.Len(); i++ { |
| 121 | + zeroOutValue(v.Index(i)) |
| 122 | + } |
| 123 | + |
| 124 | + case reflect.Map: |
| 125 | + iter := v.MapRange() |
| 126 | + for iter.Next() { |
| 127 | + mv := iter.Value() |
| 128 | + if mv.CanAddr() { |
| 129 | + zeroOutValue(mv) |
| 130 | + } else { |
| 131 | + mv = reflect.New(mv.Type()).Elem() |
| 132 | + } |
| 133 | + v.SetMapIndex(iter.Key(), mv) |
| 134 | + } |
| 135 | + |
| 136 | + case reflect.Pointer: |
| 137 | + zeroOutValue(v.Elem()) |
| 138 | + |
| 139 | + case reflect.Slice: |
| 140 | + for i := 0; i < v.Len(); i++ { |
| 141 | + zeroOutValue(v.Index(i)) |
| 142 | + } |
| 143 | + |
| 144 | + case reflect.Struct: |
| 145 | + for i := 0; i < v.NumField(); i++ { |
| 146 | + zeroOutValue(v.Field(i)) |
| 147 | + } |
| 148 | + |
| 149 | + default: |
| 150 | + panic(fmt.Sprintf("reflect kind %v not supported", v.Kind())) |
| 151 | + } |
| 152 | +} |
0 commit comments