11package org .jetbrains .plugins .scala .codeInspection .typeChecking
22
3- import com .intellij .codeInspection .{LocalInspectionTool , ProblemHighlightType , ProblemsHolder }
3+ import com .intellij .codeInspection .{LocalInspectionTool , ProblemsHolder }
44import com .intellij .psi .PsiMethod
55import com .siyeh .ig .psiutils .MethodUtils
66import org .jetbrains .annotations .Nls
77import org .jetbrains .plugins .scala .codeInspection .collections .MethodRepr
88import org .jetbrains .plugins .scala .codeInspection .typeChecking .ComparingUnrelatedTypesInspection ._
99import org .jetbrains .plugins .scala .codeInspection .{PsiElementVisitorSimple , ScalaInspectionBundle }
1010import org .jetbrains .plugins .scala .extensions ._
11+ import org .jetbrains .plugins .scala .lang .psi .api .base .types .ScParameterizedTypeElement
1112import org .jetbrains .plugins .scala .lang .psi .api .expr .{ScExpression , ScReferenceExpression }
1213import org .jetbrains .plugins .scala .lang .psi .api .statements .ScFunction
13- import org .jetbrains .plugins .scala .lang .psi .api .toplevel .typedef .ScClass
14+ import org .jetbrains .plugins .scala .lang .psi .api .toplevel .typedef .{ ScClass , ScGiven }
1415import org .jetbrains .plugins .scala .lang .psi .impl .toplevel .synthetic .ScSyntheticFunction
1516import org .jetbrains .plugins .scala .lang .psi .types ._
1617import org .jetbrains .plugins .scala .lang .psi .types .api ._
@@ -127,12 +128,37 @@ object ComparingUnrelatedTypesInspection {
127128 }
128129 }
129130 }
131+
132+ private def hasCanEqual (expr : ScExpression , source : ScType , target : ScType ): Boolean = {
133+ lazy val expressionTypes : Seq [ScType ] = List (source, target)
134+ lazy val canEqualExists : Boolean = expr
135+ .contexts
136+ .flatMap(_.children)
137+ .filterByType[ScGiven ]
138+ .filter(_.`type`().map(_.canonicalText.matches(" _root_\\ .scala\\ .CanEqual\\ [.+?, .+?]" )).getOrElse(false ))
139+ .flatMap(_.children.filterByType[ScParameterizedTypeElement ])
140+ .map(_.typeArgList.typeArgs.flatMap(_.`type`().map(_.tryExtractDesignatorSingleton).toSeq))
141+ .exists(_
142+ .zip(expressionTypes)
143+ .forall {
144+ case (givenType, compType) =>
145+ ! checkComparability(givenType, compType, isBuiltinOperation = true ).shouldNotBeCompared
146+ }
147+ )
148+
149+
150+ val wideSource : ScType = source.widenIfLiteral
151+ // Even though CanEqual[Primitive | String, _] can be defined and will satisfy compiler in strictEquals mode,
152+ // it is not possible to override equals method on the primitives or Strings
153+ ! wideSource.isPrimitive &&
154+ ! wideSource.canonicalText.matches(" _root_\\ .java\\ .lang\\ .String" ) &&
155+ (expr.isCompilerStrictEqualityMode || canEqualExists)
156+ }
130157}
131158
132159class ComparingUnrelatedTypesInspection extends LocalInspectionTool {
133160
134161 override def buildVisitor (holder : ProblemsHolder , isOnTheFly : Boolean ): PsiElementVisitorSimple = {
135- case e if e.isInScala3File => () // TODO Handle Scala 3 code (`CanEqual` instances, etc.), SCL-19722
136162 case MethodRepr (expr, Some (left), Some (oper), Seq (right)) if isComparingFunctions(oper.refName) =>
137163 // "blub" == 3
138164 val needHighlighting = oper.resolve() match {
@@ -145,7 +171,8 @@ class ComparingUnrelatedTypesInspection extends LocalInspectionTool {
145171 case Seq (Right (leftType), Right (rightType)) =>
146172 val isBuiltinOperation = isIdentityFunction(oper.refName) || ! hasNonDefaultEquals(leftType)
147173 val comparability = checkComparability(leftType, rightType, isBuiltinOperation)
148- if (comparability.shouldNotBeCompared) {
174+ if ((! expr.isInScala3File && comparability.shouldNotBeCompared) ||
175+ (expr.isInScala3File && comparability.shouldNotBeCompared && ! hasCanEqual(expr, leftType, rightType))) {
149176 val message = generateComparingUnrelatedTypesMsg(leftType, rightType)(expr)
150177 holder.registerProblem(expr, message)
151178 }
@@ -158,7 +185,8 @@ class ComparingUnrelatedTypesInspection extends LocalInspectionTool {
158185 ParameterizedType (_, Seq (elemType)) <- receiverType(baseExpr, ref).map(_.tryExtractDesignatorSingleton)
159186 argType <- arg.`type`().toOption
160187 comparability = checkComparability(elemType, argType, isBuiltinOperation = ! hasNonDefaultEquals(elemType))
161- if comparability.shouldNotBeCompared
188+ if (! baseExpr.isInScala3File && comparability.shouldNotBeCompared) ||
189+ (baseExpr.isInScala3File && comparability.shouldNotBeCompared && ! hasCanEqual(baseExpr, elemType, argType))
162190 } {
163191 val message = generateComparingUnrelatedTypesMsg(elemType, argType)(arg)
164192 holder.registerProblem(arg, message)
0 commit comments