Skip to content

Commit 57fdb54

Browse files
authored
fix: skip unnecessary null checks in equals() methods for generated structures (#1346)
1 parent 74772d9 commit 57fdb54

File tree

1 file changed

+27
-13
lines changed
  • codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering

1 file changed

+27
-13
lines changed

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/StructureGenerator.kt

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ package software.amazon.smithy.kotlin.codegen.rendering
66

77
import software.amazon.smithy.codegen.core.CodegenException
88
import software.amazon.smithy.codegen.core.Symbol
9-
import software.amazon.smithy.kotlin.codegen.core.*
9+
import software.amazon.smithy.kotlin.codegen.core.RenderingContext
10+
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
11+
import software.amazon.smithy.kotlin.codegen.core.defaultName
12+
import software.amazon.smithy.kotlin.codegen.core.withBlock
1013
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
1114
import software.amazon.smithy.kotlin.codegen.model.*
1215
import software.amazon.smithy.kotlin.codegen.rendering.serde.ClientErrorCorrection
@@ -192,20 +195,31 @@ class StructureGenerator(
192195

193196
for (memberShape in sortedMembers) {
194197
val target = model.expectShape(memberShape.target)
195-
val memberName = memberNameSymbolIndex[memberShape]!!.first
198+
val (memberName, memberSymbol) = memberNameSymbolIndex.getValue(memberShape)
196199
if (target is BlobShape && !target.hasTrait<StreamingTrait>()) {
197-
openBlock("if (#1L != null) {", memberName)
198-
.write("if (other.#1L == null) return false", memberName)
199-
.write("if (!#1L.contentEquals(other.#1L)) return false", memberName)
200-
.closeBlock("} else if (other.#1L != null) return false", memberName)
200+
if (memberSymbol.isNullable) {
201+
openBlock("if (#1L != null) {", memberName)
202+
.write("if (other.#1L == null) return false", memberName)
203+
.write("if (!#1L.contentEquals(other.#1L)) return false", memberName)
204+
.closeBlock("} else if (other.#1L != null) return false", memberName)
205+
} else {
206+
write("if (!#1L.contentEquals(other.#1L)) return false", memberName)
207+
}
201208
} else if (target is ListShape && target.member.targetOrSelf(model).isBlobShape) {
202-
openBlock("if (#L != null) {", memberName)
203-
.write("if (other.#L == null) return false", memberName)
204-
.write("if (#1L.size != other.#1L.size) return false", memberName)
205-
.openBlock("for (i in #L.indices) {", memberName)
206-
.write("if (!#1L[i].contentEquals(other.#1L[i])) return false", memberName)
207-
.closeBlock("}")
208-
.closeBlock("} else if (other.#1L != null) return false", memberName)
209+
if (memberSymbol.isNullable) {
210+
openBlock("if (#L != null) {", memberName)
211+
.write("if (other.#L == null) return false", memberName)
212+
.write("if (#1L.size != other.#1L.size) return false", memberName)
213+
.openBlock("for (i in #L.indices) {", memberName)
214+
.write("if (!#1L[i].contentEquals(other.#1L[i])) return false", memberName)
215+
.closeBlock("}")
216+
.closeBlock("} else if (other.#1L != null) return false", memberName)
217+
} else {
218+
write("if (#1L.size != other.#1L.size) return false", memberName)
219+
.openBlock("for (i in #L.indices) {", memberName)
220+
.write("if (!#1L[i].contentEquals(other.#1L[i])) return false", memberName)
221+
.closeBlock("}")
222+
}
209223
} else if (target is DoubleShape || target is FloatShape) {
210224
// NaNs must be compared using .equals()
211225
write("if (!(#1L?.equals(other.#1L) ?: (other.#1L == null))) return false", memberName)

0 commit comments

Comments
 (0)