block by daniarleagk 6df2f695f2397fa38a9cf70f9c829d0e

Group Median in Spark SQL

To compute the exact median for a group of rows, we can use the built-in MEDIAN() function together with a window function. However, not every database provides this function. In that case, we can compute the median using row_number() and count() in conjunction with a window function.

For example, take a table with two columns, “key” and “val” (I use spark-shell in local mode, Spark version 2.1.2, for demonstration):

val data1 = Array( ("a", 1), ("a", 2), ("a", 2), ("a", 5), ("a", 6), ("a", 7), ("a", 8), ("a", 8), ("a", 9) )
val data2 = Array( ("b", 1), ("b", 2), ("b", 2), ("b", 2), ("b", 5), ("b", 7), ("b", 8), ("b", 8))
val union = data1.union(data2)
val df = sc.parallelize(union).toDF("key", "val")
df.cache.createOrReplaceTempView("kvTable")
spark.sql("SET spark.sql.shuffle.partitions=2")

The goal is to compute the median for each group (by key).

+---+---+
|key|val|
+---+---+
|  a|  1|
|  a|  2|
|  a|  2|
|  a|  5|
|  a|  6|
|  a|  7|
|  a|  8|
|  a|  8|
|  a|  9|
|  b|  1|
|  b|  2|
|  b|  2|
|  b|  2|
|  b|  5|
|  b|  7|
|  b|  8|
|  b|  8|
+---+---+

Since Spark 2.0 and 2.1 have no built-in median function that can be used as an analytic (window) function, we can implement it with the following SQL query.

val ds = spark.sql("""
SELECT *
FROM kvTable k NATURAL JOIN ( 
  SELECT key, avg(val) as median
  FROM ( 
    SELECT key, val, rN, (CASE WHEN cN % 2 = 0 then (cN DIV 2) ELSE (cN DIV 2) + 1 end) as m1, (cN DIV 2) + 1 as m2 
    FROM ( 
      SELECT key, val, row_number() OVER (PARTITION BY key ORDER BY val ) as rN, count(val) OVER (PARTITION BY key ) as cN
      FROM kvTable
         ) s
    ) r
  WHERE rN BETWEEN m1 and m2
  GROUP BY key 
) t
""")
ds.show

Query result:

+---+---+------+
|key|val|median|
+---+---+------+
|  a|  1|   6.0|
|  a|  2|   6.0|
|  a|  2|   6.0|
|  a|  5|   6.0|
|  a|  6|   6.0|
|  a|  7|   6.0|
|  a|  8|   6.0|
|  a|  8|   6.0|
|  a|  9|   6.0|
|  b|  1|   3.5|
|  b|  2|   3.5|
|  b|  2|   3.5|
|  b|  2|   3.5|
|  b|  5|   3.5|
|  b|  7|   3.5|
|  b|  8|   3.5|
|  b|  8|   3.5|
+---+---+------+
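
The same logic can also be written with the DataFrame API instead of a SQL string. The following is only a sketch, not part of the original query: it assumes the spark-shell session and the `df` DataFrame defined above, and the names `byKey`, `ranked`, `medianPerKey` and `withMedian` are made up for illustration. The expression `floor((cN + 1) / 2)` is an equivalent rewrite of the CASE expression for `m1`.

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{avg, col, count, floor, row_number}

// Assumes `df` (columns "key" and "val") from the spark-shell session above.
val byKey = Window.partitionBy("key")

// Rank every row within its group and attach the group size.
val ranked = df
  .withColumn("rN", row_number().over(byKey.orderBy("val")))
  .withColumn("cN", count(col("val")).over(byKey))

// Keep the middle row (odd cN) or the two middle rows (even cN), then average.
val medianPerKey = ranked
  .where(col("rN").between(floor((col("cN") + 1) / 2), floor(col("cN") / 2) + 1))
  .groupBy("key")
  .agg(avg("val").as("median"))

// Join the per-group median back to the source rows.
val withMedian = df.join(medianPerKey, "key")
withMedian.show()
```

For the sample data this should yield the same per-group medians as the SQL version; the row order of the output may differ.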

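The subquery (shown below) computes the median per group: it keeps only the middle row of each group (or the two middle rows when the group size is even) and averages them. The outer query then joins this result back to the source table. Spark executes the query efficiently because it can reuse the data partitioning by key.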

-- subquery
SELECT key, avg(val) as median
FROM ( 
  SELECT key, val, rN, (CASE WHEN cN % 2 = 0 then (cN DIV 2) ELSE (cN DIV 2) + 1 end) as m1, (cN DIV 2) + 1 as m2 
  FROM ( 
    SELECT key, val, row_number() OVER (PARTITION BY key ORDER BY val ) as rN, count(val) OVER (PARTITION BY key ) as cN
    FROM kvTable
       ) s
  ) r
WHERE rN BETWEEN m1 and m2
GROUP BY key 

Sub query result:

+---+------+
|key|median|
+---+------+
|  a|   6.0|
|  b|   3.5|
+---+------+
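
The rank arithmetic (`m1`, `m2`) can be checked without Spark on plain Scala collections. The sketch below mirrors the subquery step by step; `medianByKey` is a hypothetical helper name used only for this illustration, and it assumes every group is non-empty and contains no nulls.

```scala
// Plain-Scala mirror of the subquery's rank arithmetic (no Spark needed).
// medianByKey is a hypothetical helper, not part of the original query.
def medianByKey(rows: Seq[(String, Int)]): Map[String, Double] =
  rows.groupBy(_._1).map { case (key, group) =>
    val sorted = group.map(_._2).sorted               // ORDER BY val
    val cN = sorted.length                            // count(val) OVER (PARTITION BY key)
    val m1 = if (cN % 2 == 0) cN / 2 else cN / 2 + 1  // lower middle rank (1-based)
    val m2 = cN / 2 + 1                               // upper middle rank (1-based)
    // rN is 1-based, so ranks m1..m2 correspond to indices m1-1 until m2.
    val middle = sorted.slice(m1 - 1, m2)
    key -> middle.sum.toDouble / middle.length        // avg(val)
  }

val rows = Seq(
  ("a", 1), ("a", 2), ("a", 2), ("a", 5), ("a", 6), ("a", 7), ("a", 8), ("a", 8), ("a", 9),
  ("b", 1), ("b", 2), ("b", 2), ("b", 2), ("b", 5), ("b", 7), ("b", 8), ("b", 8)
)
val medians = medianByKey(rows)
println(medians)
```

Group “a” has nine rows, so `m1 = m2 = 5` and the median is the fifth value, 6. Group “b” has eight rows, so `m1 = 4` and `m2 = 5`, and the median is the average of 2 and 5, i.e. 3.5.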

To see the query plan and the code generated by Spark, run these commands:

ds.explain(true)
ds.queryExecution.debug.codegen