diff --git a/caller/shared/src/main/scala/com/thoughtworks/Caller.scala b/caller/shared/src/main/scala/com/thoughtworks/Caller.scala index 7032750..627e2a1 100644 --- a/caller/shared/src/main/scala/com/thoughtworks/Caller.scala +++ b/caller/shared/src/main/scala/com/thoughtworks/Caller.scala @@ -6,10 +6,10 @@ import scala.reflect.macros.Context final case class Caller[+A](value: A) object Caller { - implicit def generate: Caller[Any] = macro thisCaller + implicit def generate[A]: Caller[A] = macro thisCaller[A] - def thisCaller(c: Context) = { + def thisCaller[A](c: Context): c.Expr[Caller[A]] = { import c.universe._ - c.Expr[Caller[Any]](q"new _root_.com.thoughtworks.Caller[this.type](this)") + c.Expr[Caller[A]](q"new _root_.com.thoughtworks.Caller[this.type](this)") } } diff --git a/caller/shared/src/test/scala/com/thoughtworks/CallerSpec.scala b/caller/shared/src/test/scala/com/thoughtworks/CallerSpec.scala index 2612fd8..03142ec 100644 --- a/caller/shared/src/test/scala/com/thoughtworks/CallerSpec.scala +++ b/caller/shared/src/test/scala/com/thoughtworks/CallerSpec.scala @@ -1,12 +1,26 @@ package com.thoughtworks import org.scalatest.{FreeSpec, Matchers} + object CallerSpec { object Foo { def call(implicit caller: Caller[_]): String = { caller.value.getClass.getName } } + + class IKnowWhatImDoing + + object Foo2{ + def runDangerous()(implicit caller: Caller[IKnowWhatImDoing]) = { + println(caller.value) + } + } + + object Bar2 extends IKnowWhatImDoing { + Foo2.runDangerous() // ok, prints Bar2 + } + } final class CallerSpec extends FreeSpec with Matchers { @@ -16,4 +30,12 @@ final class CallerSpec extends FreeSpec with Matchers { className should be(this.getClass.getName) } + "restricted" in { + """ + object Bar { + Foo2.runDangerous() + } + """ shouldNot typeCheck + } + }