
Commit b98d772

PR feedback.
1 parent 3a6537d commit b98d772

2 files changed (+47, -50 lines)


mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 4 additions & 2 deletions
@@ -488,9 +488,10 @@ class LogisticRegression @Since("1.2.0") (
     train(dataset, handlePersistence)
   }
 
+  import Instrumentation.instrumented
   protected[spark] def train(
       dataset: Dataset[_],
-      handlePersistence: Boolean): LogisticRegressionModel = Instrumentation.instrumented { instr =>
+      handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr =>
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
       dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
@@ -500,7 +501,8 @@ class LogisticRegression @Since("1.2.0") (
 
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
-    instr.logContext(this, dataset)
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
     instr.logParams(regParam, elasticNetParam, standardization, threshold,
       maxIter, tol, fitIntercept)
 
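For orientation, the calling pattern this change moves algorithms toward looks roughly like the sketch below. It is a sketch only: MyModel and fitModel are hypothetical placeholders for a real spark.ml estimator, and the method body is not the actual LogisticRegression code; only the instrumented, logPipelineStage, logDataset, and logParams calls mirror the diff above.

// Hypothetical estimator method illustrating the new Instrumentation API.
// Only the instr.* calls are taken from this commit; everything else is a placeholder.
import org.apache.spark.ml.util.Instrumentation.instrumented

protected[spark] def train(dataset: Dataset[_]): MyModel = instrumented { instr =>
  instr.logPipelineStage(this)   // was: instr.logContext(this, dataset)
  instr.logDataset(dataset)
  instr.logParams(maxIter, tol)  // params resolved against the stage logged above
  val model = fitModel(dataset)  // hypothetical fitting logic
  model
}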
mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala

Lines changed: 43 additions & 48 deletions
@@ -27,8 +27,8 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.{Estimator, Model, PipelineStage}
+import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.util.Utils
@@ -41,41 +41,40 @@ private[spark] class Instrumentation extends Logging {
 
   private val id = UUID.randomUUID()
   private val shortId = id.toString.take(8)
-  private var prefix = s"$shortId:"
+  private val prefix = s"[$shortId] "
 
   // TODO: update spark.ml to use new Instrumentation APIs and remove this constructor
-  var estimator: Estimator[_] = _
+  var stage: Params = _
   private def this(estimator: Estimator[_], dataset: RDD[_]) = {
     this()
-    logContext(estimator, dataset)
+    logPipelineStage(estimator)
+    logDataset(dataset)
   }
 
   /**
-   * Log info about the estimator and dataset being fit.
-   *
-   * @param estimator the estimator that is being fit
-   * @param dataset the training dataset
+   * Log some info about the pipeline stage being fit.
    */
-  def logContext(estimator: Estimator[_], dataset: RDD[_]): Unit = {
-    this.estimator = estimator
-    prefix = {
-      // estimator.getClass.getSimpleName can cause Malformed class name error,
-      // call safer `Utils.getSimpleName` instead
-      val className = Utils.getSimpleName(estimator.getClass)
-      s"$shortId-$className-${estimator.uid}-${dataset.hashCode()}:"
-    }
-
-    log(s"training: numPartitions=${dataset.partitions.length}" +
-      s" storageLevel=${dataset.getStorageLevel}")
+  def logPipelineStage(stage: PipelineStage): Unit = {
+    this.stage = stage
+    // estimator.getClass.getSimpleName can cause Malformed class name error,
+    // call safer `Utils.getSimpleName` instead
+    val className = Utils.getSimpleName(stage.getClass)
+    logInfo(s"Stage class: $className")
+    logInfo(s"Stage uid: ${stage.uid}")
   }
 
   /**
-   * Log info about the estimator and dataset being fit.
-   *
-   * @param e the estimator that is being fit
-   * @param dataset the training dataset
+   * Log some data about the dataset being fit.
    */
-  def logContext(e: Estimator[_], dataset: Dataset[_]): Unit = logContext(e, dataset.rdd)
+  def logDataset(dataset: Dataset[_]): Unit = logDataset(dataset.rdd)
+
+  /**
+   * Log some data about the dataset being fit.
+   */
+  def logDataset(dataset: RDD[_]): Unit = {
+    logInfo(s"training: numPartitions=${dataset.partitions.length}" +
+      s" storageLevel=${dataset.getStorageLevel}")
+  }
 
   /**
    * Logs a debug message with a prefix that uniquely identifies the training session.
@@ -105,29 +104,25 @@ private[spark] class Instrumentation extends Logging {
     super.logInfo(prefix + msg)
   }
 
-  /**
-   * Alias for logInfo, see above.
-   */
-  def log(msg: String): Unit = logInfo(msg)
-
   /**
    * Logs the value of the given parameters for the estimator being used in this session.
    */
-  def logParams(estimator: Estimator[_], params: Param[_]*): Unit = {
+  def logParams(hasParams: Params, params: Param[_]*): Unit = {
     val pairs: Seq[(String, JValue)] = for {
       p <- params
-      value <- estimator.get(p)
+      value <- hasParams.get(p)
     } yield {
       val cast = p.asInstanceOf[Param[Any]]
       p.name -> parse(cast.jsonEncode(value))
     }
-    log(compact(render(map2jvalue(pairs.toMap))))
+    logInfo(compact(render(map2jvalue(pairs.toMap))))
   }
 
   // TODO: remove this
   def logParams(params: Param[_]*): Unit = {
-    require(estimator != null, "`logContext` must be called before `logParams`.")
-    logParams(estimator, params: _*)
+    require(stage != null, "`logStageParams` must be called before `logParams` (or an instance of" +
+      " Params must be provided explicitly).")
+    logParams(stage, params: _*)
   }
 
   def logNumFeatures(num: Long): Unit = {
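For reference, the single log line produced by logParams is just the json4s rendering shown above. A tiny standalone sketch of that encoding, with made-up parameter names and values, would be:

// Standalone json4s sketch of the logParams output format; the parameter names and
// values are invented, only the compact/render/map2jvalue calls mirror the code above.
import org.json4s.{JDouble, JInt, JValue}
import org.json4s.JsonDSL.map2jvalue
import org.json4s.jackson.JsonMethods.{compact, render}

object LogParamsEncodingSketch extends App {
  val pairs: Map[String, JValue] = Map("regParam" -> JDouble(0.1), "maxIter" -> JInt(100))
  println(compact(render(map2jvalue(pairs))))  // prints {"regParam":0.1,"maxIter":100}
}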
@@ -146,27 +141,27 @@ private[spark] class Instrumentation extends Logging {
    * Logs the value with customized name field.
    */
   def logNamedValue(name: String, value: String): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }
 
   def logNamedValue(name: String, value: Long): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }
 
   def logNamedValue(name: String, value: Double): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }
 
   def logNamedValue(name: String, value: Array[String]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }
 
   def logNamedValue(name: String, value: Array[Long]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }
 
   def logNamedValue(name: String, value: Array[Double]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }
 
@@ -175,19 +170,19 @@ private[spark] class Instrumentation extends Logging {
    * Logs the successful completion of the training session.
    */
   def logSuccess(model: Model[_]): Unit = {
-    log(s"training finished")
+    logInfo(s"training finished")
   }
 
   def logSuccess(): Unit = {
-    log("training finished")
+    logInfo("training finished")
   }
 
   /**
   * Logs an exception raised during a training session.
   */
   def logFailure(e: Throwable): Unit = {
     val msg = e.getStackTrace.mkString("\n")
-    super.logInfo(msg)
+    super.logError(msg)
   }
 }
 
@@ -222,13 +217,13 @@ private[spark] object Instrumentation {
 
   def instrumented[T](body: (Instrumentation => T)): T = {
     val instr = new Instrumentation()
-    Try(body(new Instrumentation())) match {
+    Try(body(instr)) match {
       case Failure(NonFatal(e)) =>
         instr.logFailure(e)
         throw e
-      case Success(model) =>
+      case Success(result) =>
         instr.logSuccess()
-        model
+        result
     }
   }
 }
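The substantive fix in this hunk is that the Instrumentation passed to body is now the same instance whose logSuccess/logFailure report the outcome; previously a second, throwaway instance was handed to the body. A dependency-free sketch of the corrected pattern, with a hypothetical Session class standing in for Instrumentation:

// Plain-Scala sketch of the instrumented pattern after the fix. Session is a
// hypothetical stand-in for Instrumentation; no Spark classes are used.
import scala.util.{Failure, Success, Try}
import scala.util.control.NonFatal

object InstrumentedSketch {
  final class Session {
    private val id = java.util.UUID.randomUUID().toString.take(8)
    def logSuccess(): Unit = println(s"[$id] training finished")
    def logFailure(e: Throwable): Unit = Console.err.println(s"[$id] training failed: $e")
  }

  def instrumented[T](body: Session => T): T = {
    val instr = new Session()
    Try(body(instr)) match {       // the same instance is passed to the body
      case Failure(NonFatal(e)) =>
        instr.logFailure(e)
        throw e
      case Success(result) =>
        instr.logSuccess()
        result
    }
  }

  def main(args: Array[String]): Unit = {
    val n = instrumented { _ => 1 + 1 }  // logs "training finished" and returns 2
    println(n)
  }
}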
@@ -273,7 +268,7 @@ private[spark] object OptionalInstrumentation {
    */
   def create(instr: Instrumentation): OptionalInstrumentation = {
     new OptionalInstrumentation(Some(instr),
-      instr.estimator.getClass.getName.stripSuffix("$"))
+      instr.stage.getClass.getName.stripSuffix("$"))
   }
 
   /**
