diff --git a/Factory/.jvm/src/test/scala/com/thoughtworks/feature/FactorySpec.scala b/Factory/.jvm/src/test/scala/com/thoughtworks/feature/FactorySpec.scala index 255ef46..abf04b6 100644 --- a/Factory/.jvm/src/test/scala/com/thoughtworks/feature/FactorySpec.scala +++ b/Factory/.jvm/src/test/scala/com/thoughtworks/feature/FactorySpec.scala @@ -3,6 +3,7 @@ package com.thoughtworks.feature import com.thoughtworks.feature.Factory.inject import org.scalatest.{FreeSpec, Matchers} import shapeless.Witness +import scala.language.higherKinds /** * @author 杨博 (Yang Bo) @@ -44,4 +45,25 @@ class FactorySpec extends FreeSpec with Matchers { Factory[A].newInstance().witness42.value should be(42) } + "Inner abstract types can have type parameter" in { + trait Outer1 { + protected trait InnerApi[A, +B, -C] + type Inner[A, +B, -C] <: InnerApi[A, B, C] + + @inject + def innerFactory[A]: Factory[Inner[A, String, Double]] + + } + + trait Outer2 { + protected trait InnerApi[A, +B, -C] + type Inner[A, +B, -C] <: InnerApi[A, B, C] + + } + + val outer = Factory[Outer1 with Outer2].newInstance() + val inner = outer.innerFactory[Int].newInstance() + "implicitly[inner.type <:< outer.Inner[Int, String, Double]]" should compile + } + } diff --git a/Factory/src/main/scala/com/thoughtworks/feature/Factory.scala b/Factory/src/main/scala/com/thoughtworks/feature/Factory.scala index 20576da..c434fca 100644 --- a/Factory/src/main/scala/com/thoughtworks/feature/Factory.scala +++ b/Factory/src/main/scala/com/thoughtworks/feature/Factory.scala @@ -132,7 +132,6 @@ import scala.collection.mutable.ListBuffer * val outer: Outer = Factory[Outer].newInstance(inner = Some("my value")) * outer.inner should be(Some("my value")) * }}} - * @author 杨博 (Yang Bo) <pop.atry@gmail.com> */ trait Factory[Output] extends Serializable { @@ -217,6 +216,10 @@ object Factory { } private val injectType = typeOf[inject] + private def isAbstractType[Output: WeakTypeTag](symbol: c.universe.TypeSymbol): Boolean = { + symbol.isAbstract && !symbol.isClass + } + def apply[Output: WeakTypeTag]: Tree = { val output = weakTypeOf[Output] @@ -238,16 +241,10 @@ object Factory { } } - def untype(tpe: Type): Tree = { - val untyper = new ThisUntyper - untyper.untype(tpe) - } + def untyper = new ThisUntyper - def dealiasUntype(tpe: Type): Tree = { - val untyper = new ThisUntyper { - override protected def preprocess(tpe: Type): Type = tpe.dealias - } - untyper.untype(tpe) + def dealiasUntyper = new ThisUntyper { + override protected def preprocess(tpe: Type): Type = tpe.dealias } val injectedNames = (for { @@ -267,7 +264,7 @@ object Factory { val methodName = injectedName.toTermName val memberSymbol = linearOutput.member(methodName).asTerm val methodType = memberSymbol.infoIn(linearThis) - val resultTypeTree: Tree = untype(methodType.finalResultType) + val resultTypeTree: Tree = untyper.untype(methodType.finalResultType) val modifiers = Modifiers( Flag.OVERRIDE | @@ -291,13 +288,16 @@ object Factory { } else { val argumentTrees = methodType.paramLists.map(_.map { argumentSymbol => if (argumentSymbol.asTerm.isImplicit) { - q"implicit val ${argumentSymbol.name.toTermName}: ${untype(argumentSymbol.info)}" + q"implicit val ${argumentSymbol.name.toTermName}: ${untyper.untype(argumentSymbol.info)}" } else { - q"val ${argumentSymbol.name.toTermName}: ${untype(argumentSymbol.info)}" + q"val ${argumentSymbol.name.toTermName}: ${untyper.untype(argumentSymbol.info)}" } }) + val typeParameterTrees = methodType.typeParams.map { typeParamSymbol => + untyper.typeDefinition(linearThis)(typeParamSymbol.asType) + } q""" - $modifiers def $methodName[..${methodType.typeArgs}](...$argumentTrees) = { + $modifiers def $methodName[..$typeParameterTrees](...$argumentTrees) = { val $methodName = () _root_.com.thoughtworks.feature.The.apply[$resultTypeTree].value } @@ -315,7 +315,7 @@ object Factory { val methodName = memberSymbol.name.toTermName val argumentName = c.freshName(methodName) val methodType = memberSymbol.infoIn(linearThis) - val resultTypeTree: Tree = dealiasUntype(methodType.finalResultType) + val resultTypeTree: Tree = dealiasUntyper.untype(methodType.finalResultType) if (memberSymbol.isVar || memberSymbol.setter != NoSymbol) { (q"override var $methodName = $argumentName", resultTypeTree, @@ -332,7 +332,7 @@ object Factory { argumentIdTrees: List[List[Ident]]) = methodType.paramLists.map { parameterList => parameterList.map { argumentSymbol => - val argumentTypeTree: Tree = dealiasUntype(argumentSymbol.info) + val argumentTypeTree: Tree = dealiasUntyper.untype(argumentSymbol.info) val argumentName = argumentSymbol.name.toTermName val argumentTree = if (argumentSymbol.asTerm.isImplicit) { q"implicit val $argumentName: $argumentTypeTree" @@ -349,7 +349,10 @@ object Factory { tq"..$arguments => $result" } } - (q"override def $methodName[..${methodType.typeArgs}](...$argumentTrees) = $argumentName", + val typeParameterTrees = methodType.typeParams.map { typeParamSymbol => + untyper.typeDefinition(linearThis)(typeParamSymbol.asType) + } + (q"override def $methodName[..$typeParameterTrees](...$argumentTrees) = $argumentName", functionTypeTree, q"val $argumentName: $functionTypeTree", q"val $methodName: $functionTypeTree") @@ -358,40 +361,31 @@ object Factory { val (proxies, parameterTypeTrees, parameterTrees, refinedTree) = zippedProxies.unzip4 val (defProxies, valProxies) = proxies.partition(_.isDef) - val overridenTypes = - (for { - componentType <- componentTypes - member <- componentType.members - if member.isType - } yield member).distinct - .groupBy(_.name.toString) - .withFilter { - _._2.forall { - _.info match { - case TypeBounds(_, _) => true - case _ => false - } - } - } - .map { - case (name, members) => - val lowerBounds: List[Tree] = members.collect(scala.Function.unlift[Symbol, Tree] { memberSymbol => - val TypeBounds(_, lowerBound) = memberSymbol.infoIn(linearThis) - if (lowerBound =:= definitions.AnyTpe) { - None - } else { - Some(untype(lowerBound)) - } - })(collection.breakOut(List.canBuildFrom)) - val typeTree = if (lowerBounds.isEmpty) { - TypeTree(definitions.AnyTpe) - } else { - CompoundTypeTree(Template(lowerBounds, noSelfType, Nil)) - } - val result = q"override type ${TypeName(name)} = $typeTree" - // c.info(c.enclosingPosition, show(result), true) - result - } + val typeMembers = for { + componentType <- componentTypes + member <- componentType.members + if member.isType + } yield member.asType + + val groupedTypeSymbols = typeMembers.groupBy(_.name.encodedName.toTypeName) + + def overrideType(name: TypeName, members: List[TypeSymbol]): Tree = { + val glbType = glb(members.map { memberSymbol => + memberSymbol.infoIn(linearThis) + }) + val typeParameterTrees = glbType.typeParams.map { typeParamSymbol => + untyper.typeDefinition(linearThis)(typeParamSymbol.asType) + } + val TypeBounds(_, lowerBound) = glbType.resultType + val result = q"override type $name[..$typeParameterTrees] = ${untyper.untype(lowerBound)}" + // c.info(c.enclosingPosition, show(result), true) + result + } + + val overridenTypes = for { + (name, members) <- groupedTypeSymbols + if members.forall(isAbstractType) + } yield overrideType(name, members) val makeNew = TermName(c.freshName("makeNew")) val constructorMethod = TermName(c.freshName("constructor")) diff --git a/SelfType/src/main/scala/com/thoughtworks/feature/SelfType.scala b/SelfType/src/main/scala/com/thoughtworks/feature/SelfType.scala index 40a61f0..622e49e 100644 --- a/SelfType/src/main/scala/com/thoughtworks/feature/SelfType.scala +++ b/SelfType/src/main/scala/com/thoughtworks/feature/SelfType.scala @@ -68,10 +68,6 @@ trait SelfType[A] { object SelfType { -// final class SelfTypeAux[A, Out0] extends SelfType[A] { -// type Out = Out0 -// } - type Aux[A, Out0] = SelfType[A] { type Out = Out0 } diff --git a/Untyper/src/main/scala/com/thoughtworks/feature/Untyper.scala b/Untyper/src/main/scala/com/thoughtworks/feature/Untyper.scala index 8806bc9..55e3f66 100644 --- a/Untyper/src/main/scala/com/thoughtworks/feature/Untyper.scala +++ b/Untyper/src/main/scala/com/thoughtworks/feature/Untyper.scala @@ -61,14 +61,29 @@ class Untyper[Universe <: Singleton with scala.reflect.api.Universe](val univers } private def typeDefinitionOption(symbol: TypeSymbol)(implicit tpe: Type): Option[TypeDef] = { + val flags = { + if (symbol.isCovariant) { + Flag.COVARIANT + } else if (symbol.isContravariant) { + Flag.CONTRAVARIANT + } else { + NoFlags + } + } | { + if (symbol.isParameter) { + Flag.PARAM + } else { + NoFlags + } + } symbol match { case typeDefinitionSymbol.extract(name, typeDefinition.extract.forall(params), TypeBounds(untypeOption.extract(upper), untypeOption.extract(lower))) => - Some(TypeDef(Modifiers(Flag.PARAM), name, params.toList, TypeBoundsTree(upper, lower))) + Some(q"$flags type $name[..$params] >: $upper <: $lower") case typeDefinitionSymbol .extract(name, typeDefinition.extract.forall(params: Seq[TypeDef]), untypeOption.extract(concreteType)) => - Some(q"type $name[..$params] = $concreteType") + Some(q"$flags type $name[..$params] = $concreteType") case _ => None } @@ -89,11 +104,12 @@ class Untyper[Universe <: Singleton with scala.reflect.api.Universe](val univers case varDefinitionSymbol.extract(name, untypeOption.extract(result)) => Some(q"var $name: $result") case valDefinitionSymbol.extract(name, untypeOption.extract(result)) => - Some(if (symbol.isImplicit) { - q"implicit val $name: $result" + val flags = if (symbol.isImplicit) { + Flag.IMPLICIT } else { - q"val $name: $result" - }) + NoFlags + } + Some(q"$flags val $name: $result") case defDefinitionSymbol.extract(name, typeDefinition.extract.forall(typeParams), termDefinition.extract.forall.forall(params),