@@ -27,8 +27,8 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.{Estimator, Model, PipelineStage}
+import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.util.Utils
@@ -41,41 +41,40 @@ private[spark] class Instrumentation extends Logging {
 
   private val id = UUID.randomUUID()
   private val shortId = id.toString.take(8)
-  private var prefix = s"$shortId: "
+  private val prefix = s"[$shortId] "
 
   // TODO: update spark.ml to use new Instrumentation APIs and remove this constructor
-  var estimator: Estimator[_] = _
+  var stage: Params = _
   private def this(estimator: Estimator[_], dataset: RDD[_]) = {
     this()
-    logContext(estimator, dataset)
+    logPipelineStage(estimator)
+    logDataset(dataset)
   }
 
   /**
-   * Log info about the estimator and dataset being fit.
-   *
-   * @param estimator the estimator that is being fit
-   * @param dataset the training dataset
+   * Log some info about the pipeline stage being fit.
    */
-  def logContext(estimator: Estimator[_], dataset: RDD[_]): Unit = {
-    this.estimator = estimator
-    prefix = {
-      // estimator.getClass.getSimpleName can cause Malformed class name error,
-      // call safer `Utils.getSimpleName` instead
-      val className = Utils.getSimpleName(estimator.getClass)
-      s"$shortId-$className-${estimator.uid}-${dataset.hashCode()}: "
-    }
-
-    log(s"training: numPartitions=${dataset.partitions.length}" +
-      s" storageLevel=${dataset.getStorageLevel}")
+  def logPipelineStage(stage: PipelineStage): Unit = {
+    this.stage = stage
+    // estimator.getClass.getSimpleName can cause Malformed class name error,
+    // call safer `Utils.getSimpleName` instead
+    val className = Utils.getSimpleName(stage.getClass)
+    logInfo(s"Stage class: $className")
+    logInfo(s"Stage uid: ${stage.uid}")
   }
 
   /**
-   * Log info about the estimator and dataset being fit.
-   *
-   * @param e the estimator that is being fit
-   * @param dataset the training dataset
+   * Log some data about the dataset being fit.
    */
-  def logContext(e: Estimator[_], dataset: Dataset[_]): Unit = logContext(e, dataset.rdd)
+  def logDataset(dataset: Dataset[_]): Unit = logDataset(dataset.rdd)
+
+  /**
+   * Log some data about the dataset being fit.
+   */
+  def logDataset(dataset: RDD[_]): Unit = {
+    logInfo(s"training: numPartitions=${dataset.partitions.length}" +
+      s" storageLevel=${dataset.getStorageLevel}")
+  }
 
   /**
    * Logs a debug message with a prefix that uniquely identifies the training session.
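In short, the single logContext(estimator, dataset) call is split into logPipelineStage and logDataset, and the per-session detail moves from the log prefix into explicit log lines. A rough sketch of the replacement call pattern, not part of the diff; the estimator and dataset values are assumed to be in scope:

    val instr = new Instrumentation()
    instr.logPipelineStage(estimator)   // logs "Stage class: ..." and "Stage uid: ..."
    instr.logDataset(dataset)           // logs numPartitions and storageLevel of dataset.rdd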
@@ -105,29 +104,25 @@ private[spark] class Instrumentation extends Logging {
     super.logInfo(prefix + msg)
   }
 
-  /**
-   * Alias for logInfo, see above.
-   */
-  def log(msg: String): Unit = logInfo(msg)
-
   /**
    * Logs the value of the given parameters for the estimator being used in this session.
    */
-  def logParams(estimator: Estimator[_], params: Param[_]*): Unit = {
+  def logParams(hasParams: Params, params: Param[_]*): Unit = {
     val pairs: Seq[(String, JValue)] = for {
       p <- params
-      value <- estimator.get(p)
+      value <- hasParams.get(p)
     } yield {
       val cast = p.asInstanceOf[Param[Any]]
       p.name -> parse(cast.jsonEncode(value))
     }
-    log(compact(render(map2jvalue(pairs.toMap))))
+    logInfo(compact(render(map2jvalue(pairs.toMap))))
   }
 
   // TODO: remove this
   def logParams(params: Param[_]*): Unit = {
-    require(estimator != null, "`logContext` must be called before `logParams`.")
-    logParams(estimator, params: _*)
+    require(stage != null, "`logStageParams` must be called before `logParams` (or an instance of" +
+      " Params must be provided explicitly).")
+    logParams(stage, params: _*)
   }
 
   def logNumFeatures(num: Long): Unit = {
@@ -146,27 +141,27 @@ private[spark] class Instrumentation extends Logging {
    * Logs the value with customized name field.
    */
   def logNamedValue(name: String, value: String): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }
 
   def logNamedValue(name: String, value: Long): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }
 
   def logNamedValue(name: String, value: Double): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }
 
   def logNamedValue(name: String, value: Array[String]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }
 
   def logNamedValue(name: String, value: Array[Long]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }
 
   def logNamedValue(name: String, value: Array[Double]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }
 
 
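For context, a rough sketch (not part of the diff) of what these named-value calls emit with the new prefix; instr is assumed to be an Instrumentation instance and the shortId value "1a2b3c4d" is made up:

    instr.logNamedValue("numClasses", 2L)
    // logged message: [1a2b3c4d] {"numClasses":2}
    instr.logNamedValue("featureNames", Array("age", "income"))
    // logged message: [1a2b3c4d] {"featureNames":"[\"age\",\"income\"]"}
    // (arrays are JSON-encoded first, so they appear as a quoted JSON string)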
@@ -175,19 +170,19 @@ private[spark] class Instrumentation extends Logging {
    * Logs the successful completion of the training session.
    */
   def logSuccess(model: Model[_]): Unit = {
-    log(s"training finished")
+    logInfo(s"training finished")
   }
 
   def logSuccess(): Unit = {
-    log("training finished")
+    logInfo("training finished")
   }
 
   /**
    * Logs an exception raised during a training session.
    */
   def logFailure(e: Throwable): Unit = {
     val msg = e.getStackTrace.mkString("\n")
-    super.logInfo(msg)
+    super.logError(msg)
   }
 }
 
@@ -222,13 +217,13 @@ private[spark] object Instrumentation {
 
   def instrumented[T](body: (Instrumentation => T)): T = {
     val instr = new Instrumentation()
-    Try(body(new Instrumentation())) match {
+    Try(body(instr)) match {
       case Failure(NonFatal(e)) =>
         instr.logFailure(e)
         throw e
-      case Success(model) =>
+      case Success(result) =>
         instr.logSuccess()
-        model
+        result
     }
   }
 }
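Besides fixing the bug where the body received a fresh Instrumentation instead of the instance that logs success or failure, the rename from model to result reflects that the body may now return any type. A minimal usage sketch, not taken from this diff; MyModel, doTraining, maxIter, and regParam are hypothetical names, and the import assumes the companion object's usual location:

    import org.apache.spark.ml.util.Instrumentation.instrumented

    def train(dataset: Dataset[_]): MyModel = instrumented { instr =>
      instr.logPipelineStage(this)               // class name and uid of the stage
      instr.logDataset(dataset)                  // partition count and storage level
      instr.logParams(this, maxIter, regParam)   // current param values as JSON
      val model = doTraining(dataset)            // hypothetical training routine
      model                                      // returned via Success(result); logSuccess() fires
    }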
@@ -273,7 +268,7 @@ private[spark] object OptionalInstrumentation {
    */
   def create(instr: Instrumentation): OptionalInstrumentation = {
     new OptionalInstrumentation(Some(instr),
-      instr.estimator.getClass.getName.stripSuffix("$"))
+      instr.stage.getClass.getName.stripSuffix("$"))
   }
 
   /**