
Commit 0fa5787

[SPARK-48883][ML][R] Replace RDD read / write API invocation with Dataframe read / write API
### What changes were proposed in this pull request?

Replace RDD read / write API invocations with the DataFrame read / write API.

### Why are the changes needed?

In the Databricks runtime, the RDD read / write API has issues with certain storage types that require an account key, while the DataFrame read / write API works.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47328 from WeichenXu123/ml-df-writer-save-2.

Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
1 parent e20db13 commit 0fa5787

24 files changed: +114 -45 lines
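Every wrapper below applies the same two swaps, write side and read side. A minimal standalone sketch of the write-side pattern, assuming an active Spark session (the `spark` value, path, and JSON payload here are illustrative, not taken from the patch):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("sketch").getOrCreate()
val rMetadataPath = "/tmp/rMetadata"                  // hypothetical path
val rMetadataJson = """{"class":"ExampleWrapper"}"""  // hypothetical payload

// Old write path (RDD API), which fails on storage types needing an account key:
//   spark.sparkContext.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
// New write path (DataFrame API); repartition(1) keeps the one-line metadata in a single part file:
spark.createDataFrame(Seq(Tuple1(rMetadataJson)))
  .repartition(1).write.text(rMetadataPath)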

mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala

Lines changed: 5 additions & 2 deletions
@@ -129,7 +129,9 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
       val rMetadata = ("class" -> instance.getClass.getName) ~
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      sparkSession.createDataFrame(
+        Seq(Tuple1(rMetadataJson))
+      ).repartition(1).write.text(rMetadataPath)
 
       instance.pipeline.save(pipelinePath)
     }
@@ -142,7 +144,8 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val features = (rMetadata \ "features").extract[Array[String]]

mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala

Lines changed: 5 additions & 2 deletions
@@ -94,7 +94,9 @@ private[r] object ALSWrapper extends MLReadable[ALSWrapper] {
       val rMetadata = ("class" -> instance.getClass.getName) ~
         ("ratingCol" -> instance.ratingCol)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      sparkSession.createDataFrame(
+        Seq(Tuple1(rMetadataJson))
+      ).repartition(1).write.text(rMetadataPath)
 
       instance.alsModel.save(modelPath)
     }
@@ -107,7 +109,8 @@ private[r] object ALSWrapper extends MLReadable[ALSWrapper] {
       val rMetadataPath = new Path(path, "rMetadata").toString
       val modelPath = new Path(path, "model").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val ratingCol = (rMetadata \ "ratingCol").extract[String]
       val alsModel = ALSModel.load(modelPath)

mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala

Lines changed: 5 additions & 2 deletions
@@ -120,7 +120,9 @@ private[r] object BisectingKMeansWrapper extends MLReadable[BisectingKMeansWrapp
         ("size" -> instance.size.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
 
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      sparkSession.createDataFrame(
+        Seq(Tuple1(rMetadataJson))
+      ).repartition(1).write.text(rMetadataPath)
       instance.pipeline.save(pipelinePath)
     }
   }
@@ -133,7 +135,8 @@ private[r] object BisectingKMeansWrapper extends MLReadable[BisectingKMeansWrapp
       val pipelinePath = new Path(path, "pipeline").toString
       val pipeline = PipelineModel.load(pipelinePath)
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val features = (rMetadata \ "features").extract[Array[String]]
       val size = (rMetadata \ "size").extract[Array[Long]]

mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassifierWrapper.scala

Lines changed: 5 additions & 2 deletions
@@ -131,7 +131,9 @@ private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeC
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
 
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      sparkSession.createDataFrame(
+        Seq(Tuple1(rMetadataJson))
+      ).repartition(1).write.text(rMetadataPath)
       instance.pipeline.save(pipelinePath)
     }
   }
@@ -144,7 +146,8 @@ private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeC
       val pipelinePath = new Path(path, "pipeline").toString
       val pipeline = PipelineModel.load(pipelinePath)
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val formula = (rMetadata \ "formula").extract[String]
       val features = (rMetadata \ "features").extract[Array[String]]

mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressorWrapper.scala

Lines changed: 5 additions & 2 deletions
@@ -114,7 +114,9 @@ private[r] object DecisionTreeRegressorWrapper extends MLReadable[DecisionTreeRe
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
 
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      sparkSession.createDataFrame(
+        Seq(Tuple1(rMetadataJson))
+      ).repartition(1).write.text(rMetadataPath)
       instance.pipeline.save(pipelinePath)
     }
   }
@@ -127,7 +129,8 @@ private[r] object DecisionTreeRegressorWrapper extends MLReadable[DecisionTreeRe
       val pipelinePath = new Path(path, "pipeline").toString
       val pipeline = PipelineModel.load(pipelinePath)
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val formula = (rMetadata \ "formula").extract[String]
       val features = (rMetadata \ "features").extract[Array[String]]

mllib/src/main/scala/org/apache/spark/ml/r/FMClassifierWrapper.scala

Lines changed: 5 additions & 2 deletions
@@ -151,7 +151,9 @@ private[r] object FMClassifierWrapper
         ("features" -> instance.features.toImmutableArraySeq) ~
         ("labels" -> instance.labels.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      sparkSession.createDataFrame(
+        Seq(Tuple1(rMetadataJson))
+      ).repartition(1).write.text(rMetadataPath)
 
       instance.pipeline.save(pipelinePath)
     }
@@ -164,7 +166,8 @@ private[r] object FMClassifierWrapper
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val features = (rMetadata \ "features").extract[Array[String]]
       val labels = (rMetadata \ "labels").extract[Array[String]]

mllib/src/main/scala/org/apache/spark/ml/r/FMRegressorWrapper.scala

Lines changed: 5 additions & 2 deletions
@@ -132,7 +132,9 @@ private[r] object FMRegressorWrapper
       val rMetadata = ("class" -> instance.getClass.getName) ~
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      sparkSession.createDataFrame(
+        Seq(Tuple1(rMetadataJson))
+      ).repartition(1).write.text(rMetadataPath)
 
       instance.pipeline.save(pipelinePath)
     }
@@ -145,7 +147,8 @@ private[r] object FMRegressorWrapper
      val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val features = (rMetadata \ "features").extract[Array[String]]

mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala

Lines changed: 3 additions & 1 deletion
@@ -78,7 +78,9 @@ private[r] object FPGrowthWrapper extends MLReadable[FPGrowthWrapper] {
         "class" -> instance.getClass.getName
       ))
 
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      sparkSession.createDataFrame(
+        Seq(Tuple1(rMetadataJson))
+      ).repartition(1).write.text(rMetadataPath)
 
       instance.fpGrowthModel.save(modelPath)
     }

mllib/src/main/scala/org/apache/spark/ml/r/GBTClassifierWrapper.scala

Lines changed: 5 additions & 2 deletions
@@ -138,7 +138,9 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper]
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
 
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      sparkSession.createDataFrame(
+        Seq(Tuple1(rMetadataJson))
+      ).repartition(1).write.text(rMetadataPath)
       instance.pipeline.save(pipelinePath)
     }
   }
@@ -151,7 +153,8 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper]
       val pipelinePath = new Path(path, "pipeline").toString
       val pipeline = PipelineModel.load(pipelinePath)
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val formula = (rMetadata \ "formula").extract[String]
       val features = (rMetadata \ "features").extract[Array[String]]

mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressorWrapper.scala

Lines changed: 5 additions & 2 deletions
@@ -122,7 +122,9 @@ private[r] object GBTRegressorWrapper extends MLReadable[GBTRegressorWrapper] {
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
 
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      sparkSession.createDataFrame(
+        Seq(Tuple1(rMetadataJson))
+      ).repartition(1).write.text(rMetadataPath)
       instance.pipeline.save(pipelinePath)
     }
   }
@@ -135,7 +137,8 @@ private[r] object GBTRegressorWrapper extends MLReadable[GBTRegressorWrapper] {
       val pipelinePath = new Path(path, "pipeline").toString
       val pipeline = PipelineModel.load(pipelinePath)
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val formula = (rMetadata \ "formula").extract[String]
       val features = (rMetadata \ "features").extract[Array[String]]
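On the load side the swap is the mirror image; a sketch reusing the illustrative names from the write-side sketch above (the text source yields a single string column, and row 0 holds the one-line JSON):

// Old read path (RDD API):
//   val rMetadataStr = spark.sparkContext.textFile(rMetadataPath, 1).first()
// New read path (DataFrame API):
val rMetadataStr = spark.read.text(rMetadataPath).first().getString(0)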
