To compute the exact median for a group of rows, we can use the built-in MEDIAN() function as a window function. However, not every database provides this function. In that case, we can compute the median with the window functions row_number() and count().
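To make the positional idea concrete, here is a minimal sketch in plain Scala, without Spark. It is an illustration under stated assumptions, not Spark code: the names RowNumberMedian and median are hypothetical. Sorting plays the role of ORDER BY val, zipWithIndex stands in for row_number(), and size stands in for count(). The middle row (odd count) or the two middle rows (even count) are then averaged.

```scala
// Hypothetical plain-Scala sketch (no Spark): median of one group computed
// from 1-based row numbers over the sorted values plus the group's row count.
object RowNumberMedian {
  // values: the "val" column of one group (one "key").
  def median(values: Seq[Int]): Double = {
    require(values.nonEmpty, "the median of an empty group is undefined")
    val cN = values.size // like count(val) OVER (PARTITION BY key)
    // Like row_number() OVER (PARTITION BY key ORDER BY val): 1-based positions.
    val numbered: Seq[(Int, Int)] =
      values.sorted.zipWithIndex.map { case (v, i) => (v, i + 1) }
    // Middle positions: m1 == m2 for an odd count, m2 == m1 + 1 for an even count.
    val m1 = if (cN % 2 == 0) cN / 2 else cN / 2 + 1
    val m2 = cN / 2 + 1
    // Like WHERE rN BETWEEN m1 AND m2, followed by avg(val).
    val middle = numbered.collect { case (v, rN) if rN >= m1 && rN <= m2 => v }
    middle.sum.toDouble / middle.size
  }
}
```

Ties do not affect the result: rows with equal values may receive their row numbers in any order, but the values found at positions m1 and m2 are the same either way.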
For example, take a table with two columns, “key” and “val” (I use spark-shell with Spark 2.1.2 in local mode for this demonstration):
val data1 = Array( ("a", 1), ("a", 2), ("a", 2), ("a", 5), ("a", 6), ("a", 7), ("a", 8), ("a", 8), ("a", 9) )
val data2 = Array( ("b", 1), ("b", 2), ("b", 2), ("b", 2), ("b", 5), ("b", 7), ("b", 8), ("b", 8))
val union = data1.union(data2)
val df = sc.parallelize(union).toDF("key", "val")
df.cache.createOrReplaceTempView("kvTable")
spark.sql("SET spark.sql.shuffle.partitions=2")
The goal is to compute the median of “val” for each group (by “key”):
+---+---+
|key|val|
+---+---+
| a| 1|
| a| 2|
| a| 2|
| a| 5|
| a| 6|
| a| 7|
| a| 8|
| a| 8|
| a| 9|
| b| 1|
| b| 2|
| b| 2|
| b| 2|
| b| 5|
| b| 7|
| b| 8|
| b| 8|
+---+---+
Since Spark 2.0 and 2.1 have no built-in median function that can be used as an analytic (window) function, we can implement the computation with the following SQL query:
val ds = spark.sql("""
SELECT *
FROM kvTable k NATURAL JOIN (
SELECT key, avg(val) as median
FROM (
SELECT key, val, rN, (CASE WHEN cN % 2 = 0 then (cN DIV 2) ELSE (cN DIV 2) + 1 end) as m1, (cN DIV 2) + 1 as m2
FROM (
SELECT key, val, row_number() OVER (PARTITION BY key ORDER BY val ) as rN, count(val) OVER (PARTITION BY key ) as cN
FROM kvTable
) s
) r
WHERE rN BETWEEN m1 and m2
GROUP BY key
) t
""")
ds.show
Query result:
+---+---+------+
|key|val|median|
+---+---+------+
| a| 1| 6.0|
| a| 2| 6.0|
| a| 2| 6.0|
| a| 5| 6.0|
| a| 6| 6.0|
| a| 7| 6.0|
| a| 8| 6.0|
| a| 8| 6.0|
| a| 9| 6.0|
| b| 1| 3.5|
| b| 2| 3.5|
| b| 2| 3.5|
| b| 2| 3.5|
| b| 5| 3.5|
| b| 7| 3.5|
| b| 8| 3.5|
| b| 8| 3.5|
+---+---+------+
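Why does WHERE rN BETWEEN m1 AND m2 pick the right rows? For an odd count both bounds point at the single middle row; for an even count they point at the two middle rows, which avg(val) then averages. The small plain-Scala sketch below reproduces that arithmetic (the names MiddlePositions, positions, m1ClosedForm and closedFormMatches are hypothetical) and checks that, for positive counts, the CASE expression for m1 equals the simpler (cN + 1) DIV 2. Scala's Int division stands in for DIV here, which is safe because a group's count is always at least 1.

```scala
// Hypothetical plain-Scala sketch of the m1/m2 arithmetic used in the query.
object MiddlePositions {
  // (m1, m2) exactly as in the query; Int division `/` matches DIV for positive counts.
  def positions(cN: Int): (Int, Int) = {
    val m1 = if (cN % 2 == 0) cN / 2 else cN / 2 + 1
    val m2 = cN / 2 + 1
    (m1, m2)
  }

  // Simpler closed form for the lower bound, equivalent for cN >= 1.
  def m1ClosedForm(cN: Int): Int = (cN + 1) / 2

  // Checks that the CASE form and the closed form agree for counts 1..maxCount.
  def closedFormMatches(maxCount: Int): Boolean =
    (1 to maxCount).forall(cN => positions(cN)._1 == m1ClosedForm(cN))
}
```

For group “a” (9 rows), positions(9) gives (5, 5), so only the 5th row (val = 6) is kept. For group “b” (8 rows), positions(8) gives (4, 5), so the 4th and 5th rows (2 and 5) are averaged to 3.5.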
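The subquery (see below) computes the median per group, and the result is then joined with the source table. Spark can execute this query efficiently because the window functions and the aggregation are all partitioned by key, so the same hash partitioning can be reused instead of shuffling the data again between these steps.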
-- sub query
SELECT key, avg(val) as median
FROM (
SELECT key, val, rN, (CASE WHEN cN % 2 = 0 then (cN DIV 2) ELSE (cN DIV 2) + 1 end) as m1, (cN DIV 2) + 1 as m2
FROM (
SELECT key, val, row_number() OVER (PARTITION BY key ORDER BY val ) as rN, count(val) OVER (PARTITION BY key ) as cN
FROM kvTable
) s
) r
WHERE rN BETWEEN m1 and m2
GROUP BY key
Subquery result:
+---+------+
|key|median|
+---+------+
| a| 6.0|
| b| 3.5|
+---+------+
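The two steps of the query (a per-key median, then attaching it back to every row) can be emulated on plain Scala collections with the same data. This is a sketch for understanding only, with no Spark involved; MedianJoinSketch and its members are hypothetical names. groupBy plays the role of the subquery's per-key aggregation, and looking up each row's key in the resulting map plays the role of the NATURAL JOIN.

```scala
// Hypothetical plain-Scala sketch (no Spark) of the query's two steps.
object MedianJoinSketch {
  // Same (key, val) rows as data1 and data2 above.
  val rows: Seq[(String, Int)] = Seq(
    ("a", 1), ("a", 2), ("a", 2), ("a", 5), ("a", 6), ("a", 7), ("a", 8), ("a", 8), ("a", 9),
    ("b", 1), ("b", 2), ("b", 2), ("b", 2), ("b", 5), ("b", 7), ("b", 8), ("b", 8))

  // Step 1 (the subquery): one median per key, using the query's m1/m2 rule.
  val medians: Map[String, Double] = rows.groupBy(_._1).map { case (key, group) =>
    val sorted = group.map(_._2).sorted
    val cN = sorted.size
    val m1 = if (cN % 2 == 0) cN / 2 else cN / 2 + 1 // 1-based lower middle position
    val m2 = cN / 2 + 1                              // 1-based upper middle position
    val middle = sorted.slice(m1 - 1, m2)            // rows m1..m2 inclusive
    key -> (middle.sum.toDouble / middle.size)
  }

  // Step 2 (the NATURAL JOIN): attach the per-key median to every source row.
  val joined: Seq[(String, Int, Double)] =
    rows.map { case (key, v) => (key, v, medians(key)) }
}
```

Here medians corresponds to the subquery result (a → 6.0, b → 3.5), and joined holds the same 17 (key, val, median) rows as the full query result.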
To see the query plans and the code generated by Spark, run these commands:
ds.explain(true)
ds.queryExecution.debug.codegen