Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.thoughtworks.feature
import com.thoughtworks.feature.Factory.inject
import org.scalatest.{FreeSpec, Matchers}
import shapeless.Witness
import scala.language.higherKinds

/**
* @author 杨博 (Yang Bo)
Expand Down Expand Up @@ -44,4 +45,25 @@ class FactorySpec extends FreeSpec with Matchers {
Factory[A].newInstance().witness42.value should be(42)
}

"Inner abstract types can have type parameter" in {
trait Outer1 {
protected trait InnerApi[A, +B, -C]
type Inner[A, +B, -C] <: InnerApi[A, B, C]

@inject
def innerFactory[A]: Factory[Inner[A, String, Double]]

}

trait Outer2 {
protected trait InnerApi[A, +B, -C]
type Inner[A, +B, -C] <: InnerApi[A, B, C]

}

val outer = Factory[Outer1 with Outer2].newInstance()
val inner = outer.innerFactory[Int].newInstance()
"implicitly[inner.type <:< outer.Inner[Int, String, Double]]" should compile
}

}
96 changes: 45 additions & 51 deletions Factory/src/main/scala/com/thoughtworks/feature/Factory.scala
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ import scala.collection.mutable.ListBuffer
* val outer: Outer = Factory[Outer].newInstance(inner = Some("my value"))
* outer.inner should be(Some("my value"))
* }}}

* @author 杨博 (Yang Bo) &lt;[email protected]&gt;
*/
trait Factory[Output] extends Serializable {
Expand Down Expand Up @@ -217,6 +216,10 @@ object Factory {
}
private val injectType = typeOf[inject]

private def isAbstractType[Output: WeakTypeTag](symbol: c.universe.TypeSymbol): Boolean = {
symbol.isAbstract && !symbol.isClass
}

def apply[Output: WeakTypeTag]: Tree = {
val output = weakTypeOf[Output]

Expand All @@ -238,16 +241,10 @@ object Factory {
}
}

def untype(tpe: Type): Tree = {
val untyper = new ThisUntyper
untyper.untype(tpe)
}
def untyper = new ThisUntyper

def dealiasUntype(tpe: Type): Tree = {
val untyper = new ThisUntyper {
override protected def preprocess(tpe: Type): Type = tpe.dealias
}
untyper.untype(tpe)
def dealiasUntyper = new ThisUntyper {
override protected def preprocess(tpe: Type): Type = tpe.dealias
}

val injectedNames = (for {
Expand All @@ -267,7 +264,7 @@ object Factory {
val methodName = injectedName.toTermName
val memberSymbol = linearOutput.member(methodName).asTerm
val methodType = memberSymbol.infoIn(linearThis)
val resultTypeTree: Tree = untype(methodType.finalResultType)
val resultTypeTree: Tree = untyper.untype(methodType.finalResultType)

val modifiers = Modifiers(
Flag.OVERRIDE |
Expand All @@ -291,13 +288,16 @@ object Factory {
} else {
val argumentTrees = methodType.paramLists.map(_.map { argumentSymbol =>
if (argumentSymbol.asTerm.isImplicit) {
q"implicit val ${argumentSymbol.name.toTermName}: ${untype(argumentSymbol.info)}"
q"implicit val ${argumentSymbol.name.toTermName}: ${untyper.untype(argumentSymbol.info)}"
} else {
q"val ${argumentSymbol.name.toTermName}: ${untype(argumentSymbol.info)}"
q"val ${argumentSymbol.name.toTermName}: ${untyper.untype(argumentSymbol.info)}"
}
})
val typeParameterTrees = methodType.typeParams.map { typeParamSymbol =>
untyper.typeDefinition(linearThis)(typeParamSymbol.asType)
}
q"""
$modifiers def $methodName[..${methodType.typeArgs}](...$argumentTrees) = {
$modifiers def $methodName[..$typeParameterTrees](...$argumentTrees) = {
val $methodName = ()
_root_.com.thoughtworks.feature.The.apply[$resultTypeTree].value
}
Expand All @@ -315,7 +315,7 @@ object Factory {
val methodName = memberSymbol.name.toTermName
val argumentName = c.freshName(methodName)
val methodType = memberSymbol.infoIn(linearThis)
val resultTypeTree: Tree = dealiasUntype(methodType.finalResultType)
val resultTypeTree: Tree = dealiasUntyper.untype(methodType.finalResultType)
if (memberSymbol.isVar || memberSymbol.setter != NoSymbol) {
(q"override var $methodName = $argumentName",
resultTypeTree,
Expand All @@ -332,7 +332,7 @@ object Factory {
argumentIdTrees: List[List[Ident]]) =
methodType.paramLists.map { parameterList =>
parameterList.map { argumentSymbol =>
val argumentTypeTree: Tree = dealiasUntype(argumentSymbol.info)
val argumentTypeTree: Tree = dealiasUntyper.untype(argumentSymbol.info)
val argumentName = argumentSymbol.name.toTermName
val argumentTree = if (argumentSymbol.asTerm.isImplicit) {
q"implicit val $argumentName: $argumentTypeTree"
Expand All @@ -349,7 +349,10 @@ object Factory {
tq"..$arguments => $result"
}
}
(q"override def $methodName[..${methodType.typeArgs}](...$argumentTrees) = $argumentName",
val typeParameterTrees = methodType.typeParams.map { typeParamSymbol =>
untyper.typeDefinition(linearThis)(typeParamSymbol.asType)
}
(q"override def $methodName[..$typeParameterTrees](...$argumentTrees) = $argumentName",
functionTypeTree,
q"val $argumentName: $functionTypeTree",
q"val $methodName: $functionTypeTree")
Expand All @@ -358,40 +361,31 @@ object Factory {

val (proxies, parameterTypeTrees, parameterTrees, refinedTree) = zippedProxies.unzip4
val (defProxies, valProxies) = proxies.partition(_.isDef)
val overridenTypes =
(for {
componentType <- componentTypes
member <- componentType.members
if member.isType
} yield member).distinct
.groupBy(_.name.toString)
.withFilter {
_._2.forall {
_.info match {
case TypeBounds(_, _) => true
case _ => false
}
}
}
.map {
case (name, members) =>
val lowerBounds: List[Tree] = members.collect(scala.Function.unlift[Symbol, Tree] { memberSymbol =>
val TypeBounds(_, lowerBound) = memberSymbol.infoIn(linearThis)
if (lowerBound =:= definitions.AnyTpe) {
None
} else {
Some(untype(lowerBound))
}
})(collection.breakOut(List.canBuildFrom))
val typeTree = if (lowerBounds.isEmpty) {
TypeTree(definitions.AnyTpe)
} else {
CompoundTypeTree(Template(lowerBounds, noSelfType, Nil))
}
val result = q"override type ${TypeName(name)} = $typeTree"
// c.info(c.enclosingPosition, show(result), true)
result
}
val typeMembers = for {
componentType <- componentTypes
member <- componentType.members
if member.isType
} yield member.asType

val groupedTypeSymbols = typeMembers.groupBy(_.name.encodedName.toTypeName)

def overrideType(name: TypeName, members: List[TypeSymbol]): Tree = {
val glbType = glb(members.map { memberSymbol =>
memberSymbol.infoIn(linearThis)
})
val typeParameterTrees = glbType.typeParams.map { typeParamSymbol =>
untyper.typeDefinition(linearThis)(typeParamSymbol.asType)
}
val TypeBounds(_, lowerBound) = glbType.resultType
val result = q"override type $name[..$typeParameterTrees] = ${untyper.untype(lowerBound)}"
// c.info(c.enclosingPosition, show(result), true)
result
}

val overridenTypes = for {
(name, members) <- groupedTypeSymbols
if members.forall(isAbstractType)
} yield overrideType(name, members)

val makeNew = TermName(c.freshName("makeNew"))
val constructorMethod = TermName(c.freshName("constructor"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,6 @@ trait SelfType[A] {

object SelfType {

// final class SelfTypeAux[A, Out0] extends SelfType[A] {
// type Out = Out0
// }

type Aux[A, Out0] = SelfType[A] {
type Out = Out0
}
Expand Down
28 changes: 22 additions & 6 deletions Untyper/src/main/scala/com/thoughtworks/feature/Untyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,29 @@ class Untyper[Universe <: Singleton with scala.reflect.api.Universe](val univers
}

private def typeDefinitionOption(symbol: TypeSymbol)(implicit tpe: Type): Option[TypeDef] = {
val flags = {
if (symbol.isCovariant) {
Flag.COVARIANT
} else if (symbol.isContravariant) {
Flag.CONTRAVARIANT
} else {
NoFlags
}
} | {
if (symbol.isParameter) {
Flag.PARAM
} else {
NoFlags
}
}
symbol match {
case typeDefinitionSymbol.extract(name,
typeDefinition.extract.forall(params),
TypeBounds(untypeOption.extract(upper), untypeOption.extract(lower))) =>
Some(TypeDef(Modifiers(Flag.PARAM), name, params.toList, TypeBoundsTree(upper, lower)))
Some(q"$flags type $name[..$params] >: $upper <: $lower")
case typeDefinitionSymbol
.extract(name, typeDefinition.extract.forall(params: Seq[TypeDef]), untypeOption.extract(concreteType)) =>
Some(q"type $name[..$params] = $concreteType")
Some(q"$flags type $name[..$params] = $concreteType")
case _ =>
None
}
Expand All @@ -89,11 +104,12 @@ class Untyper[Universe <: Singleton with scala.reflect.api.Universe](val univers
case varDefinitionSymbol.extract(name, untypeOption.extract(result)) =>
Some(q"var $name: $result")
case valDefinitionSymbol.extract(name, untypeOption.extract(result)) =>
Some(if (symbol.isImplicit) {
q"implicit val $name: $result"
val flags = if (symbol.isImplicit) {
Flag.IMPLICIT
} else {
q"val $name: $result"
})
NoFlags
}
Some(q"$flags val $name: $result")
case defDefinitionSymbol.extract(name,
typeDefinition.extract.forall(typeParams),
termDefinition.extract.forall.forall(params),
Expand Down