Skip to content

Commit

Permalink
fix: don't throw on invalid columns in DropColumns (#1695)
Browse files Browse the repository at this point in the history
* remove unused imports

* batch prompts

* don't throw for invalid columns in DropColumn

* remove now unused verify method from DropColumns
  • Loading branch information
niehaus59 authored Oct 26, 2022
1 parent c531bbb commit f4af33f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,28 +38,14 @@ class DropColumns(val uid: String) extends Transformer with Wrappable with Defau
*/
override def transform(dataset: Dataset[_]): DataFrame = {
logTransform[DataFrame]({
verifySchema(dataset.schema)
dataset.toDF().drop(getCols: _*)
})
}

def transformSchema(schema: StructType): StructType = {
verifySchema(schema)
val droppedCols = getCols.toSet
StructType(schema.fields.filter(f => !droppedCols(f.name)))
}

def copy(extra: ParamMap): DropColumns = defaultCopy(extra)

private def verifySchema(schema: StructType): Unit = {
val providedCols = schema.fields.map(_.name).toSet
val invalidCols = getCols.filter(!providedCols(_))

if (invalidCols.length > 0) {
throw new NoSuchElementException(
s"DataFrame does not contain specified columns: ${invalidCols.reduce(_ + "," + _)}")
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ class DropColumnsSuite extends TestBase with TransformerFuzzing[DropColumns] {

test("Invalid column specified") {
try {
new DropColumns().setCol("four").transform(makeBasicDF())
fail()
val df = makeBasicDF()
new DropColumns().setCol("four").transform(df)
val result = new DropColumns().setCol("four").transform(df)
assert(df.schema == result.schema)
} catch {
case _: NoSuchElementException =>
case _: Exception =>
fail("DropColumns should not throw when for invalid column input")
}
}

Expand Down

0 comments on commit f4af33f

Please sign in to comment.