Implementation: DataTalksClub Data Engineering Zoomcamp Spark SQL Aggregation
| Page Metadata | |
|---|---|
| Knowledge Sources | repo: DataTalksClub/data-engineering-zoomcamp, Spark docs: PySpark API Reference |
| Domains | Data_Engineering, Batch_Processing |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
Concrete tool for executing a multi-dimensional SQL aggregation query in PySpark to compute monthly revenue statistics grouped by pickup location, month, and taxi service type.
Description
This implementation uses spark.sql() to execute a SQL query against the trips_data temporary table. The query performs a three-dimensional aggregation that computes 10 aggregate measures (8 SUM columns and 2 AVG columns) grouped by 3 dimensions: pickup location (as revenue_zone), truncated month (as revenue_month), and service type.
The SQL query uses:
- `date_trunc('month', pickup_datetime)` to bucket timestamps into calendar months
- `PULocationID AS revenue_zone` to alias the geographic dimension
- `GROUP BY 1, 2, 3` using ordinal position references for conciseness
- Eight `SUM()` aggregations for revenue-related monetary amounts
- Two `AVG()` aggregations for passenger count and trip distance
The result is a DataFrame with 13 total output columns (3 dimensions + 10 measures).
This is an API Doc entry documenting the use of PySpark's spark.sql() API for analytical SQL queries.
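The grouping logic can be made concrete without Spark at all. The sketch below simulates GROUP BY over the three dimensions in plain Python with toy rows and just one SUM and one AVG measure (the column subset and values are illustrative, not from the course data):

```python
from collections import defaultdict

# Toy trip records: (PULocationID, pickup month, service_type, fare, passengers)
trips = [
    (132, "2020-01", "green",  12.5, 1),
    (132, "2020-01", "green",   7.0, 3),
    (43,  "2020-01", "yellow", 22.0, 1),
]

# GROUP BY 1, 2, 3: the grouping key is the tuple of the three dimensions
groups = defaultdict(list)
for zone, month, service, fare, pax in trips:
    groups[(zone, month, service)].append((fare, pax))

# One output row per unique (zone, month, service_type) combination,
# with SUM over fares and AVG over passenger counts
result = {
    key: {
        "revenue_monthly_fare": sum(f for f, _ in rows),
        "avg_passenger_count": sum(p for _, p in rows) / len(rows),
    }
    for key, rows in groups.items()
}
```

The real query does exactly this, distributed, with ten measures instead of two.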
Usage
Use this implementation when:
- Computing monthly revenue reports from unified taxi trip data
- You need multi-dimensional aggregation across zone, time, and service type
- SQL syntax is preferred over the DataFrame API for expressing complex aggregations
- The source data has already been unified and registered as a temporary table
Code Reference
Source Location: 06-batch/code/06_spark_sql.py, lines 79-103
Signature:
```python
spark.sql("SELECT ... GROUP BY ...") -> DataFrame
```
Import:
```python
from pyspark.sql import SparkSession
```
Full SQL Query:
```python
df_result = spark.sql("""
SELECT
    -- Revenue grouping
    PULocationID AS revenue_zone,
    date_trunc('month', pickup_datetime) AS revenue_month,
    service_type,

    -- Revenue calculation
    SUM(fare_amount) AS revenue_monthly_fare,
    SUM(extra) AS revenue_monthly_extra,
    SUM(mta_tax) AS revenue_monthly_mta_tax,
    SUM(tip_amount) AS revenue_monthly_tip_amount,
    SUM(tolls_amount) AS revenue_monthly_tolls_amount,
    SUM(improvement_surcharge) AS revenue_monthly_improvement_surcharge,
    SUM(total_amount) AS revenue_monthly_total_amount,
    SUM(congestion_surcharge) AS revenue_monthly_congestion_surcharge,

    -- Additional calculations
    AVG(passenger_count) AS avg_montly_passenger_count,
    AVG(trip_distance) AS avg_montly_trip_distance
FROM
    trips_data
GROUP BY
    1, 2, 3
""")
```
I/O Contract
Inputs:
| Input | Type | Description |
|---|---|---|
| trips_data (temp table) | SQL Table | Unified taxi trip data registered via registerTempTable, containing 19 columns (18 common columns + service_type) |
GROUP BY Dimensions:
| Position | Column | Output Alias | Description |
|---|---|---|---|
| 1 | PULocationID | revenue_zone | Pickup location identifier (geographic zone) |
| 2 | date_trunc('month', pickup_datetime) | revenue_month | Pickup timestamp truncated to the first day of the month |
| 3 | service_type | service_type | Taxi service type discriminator ('green' or 'yellow') |
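The `date_trunc('month', ...)` dimension zeroes out everything below the month, so every trip in the same calendar month maps to one revenue_month value. Its semantics can be shown in plain Python (an illustrative helper, not a Spark API):

```python
from datetime import datetime

def trunc_month(ts: datetime) -> datetime:
    # Equivalent of Spark's date_trunc('month', ts): keep year and month,
    # reset the day to 1 and the time-of-day to midnight
    return ts.replace(day=1, hour=0, minute=0, second=0, microsecond=0)

# Two pickups anywhere in January 2020 land in the same bucket
a = trunc_month(datetime(2020, 1, 15, 8, 30, 45))
b = trunc_month(datetime(2020, 1, 31, 23, 59, 59))
```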
Aggregate Measures:
| Function | Source Column | Output Alias |
|---|---|---|
| SUM | fare_amount | revenue_monthly_fare |
| SUM | extra | revenue_monthly_extra |
| SUM | mta_tax | revenue_monthly_mta_tax |
| SUM | tip_amount | revenue_monthly_tip_amount |
| SUM | tolls_amount | revenue_monthly_tolls_amount |
| SUM | improvement_surcharge | revenue_monthly_improvement_surcharge |
| SUM | total_amount | revenue_monthly_total_amount |
| SUM | congestion_surcharge | revenue_monthly_congestion_surcharge |
| AVG | passenger_count | avg_montly_passenger_count |
| AVG | trip_distance | avg_montly_trip_distance |
Outputs:
| Output | Type | Description |
|---|---|---|
| df_result | DataFrame | Aggregated revenue DataFrame with 13 columns (3 dimensions + 10 measures), one row per unique (zone, month, service_type) combination |
Usage Examples
Executing the revenue aggregation query:
```python
df_result = spark.sql("""
SELECT
    PULocationID AS revenue_zone,
    date_trunc('month', pickup_datetime) AS revenue_month,
    service_type,
    SUM(fare_amount) AS revenue_monthly_fare,
    SUM(extra) AS revenue_monthly_extra,
    SUM(mta_tax) AS revenue_monthly_mta_tax,
    SUM(tip_amount) AS revenue_monthly_tip_amount,
    SUM(tolls_amount) AS revenue_monthly_tolls_amount,
    SUM(improvement_surcharge) AS revenue_monthly_improvement_surcharge,
    SUM(total_amount) AS revenue_monthly_total_amount,
    SUM(congestion_surcharge) AS revenue_monthly_congestion_surcharge,
    AVG(passenger_count) AS avg_montly_passenger_count,
    AVG(trip_distance) AS avg_montly_trip_distance
FROM
    trips_data
GROUP BY
    1, 2, 3
""")
```
Inspecting the results:
```python
df_result.printSchema()
# root
#  |-- revenue_zone: integer (nullable = true)
#  |-- revenue_month: timestamp (nullable = true)
#  |-- service_type: string (nullable = false)
#  |-- revenue_monthly_fare: double (nullable = true)
#  |-- ...

df_result.show(5, truncate=False)
print(f"Total aggregated rows: {df_result.count()}")
```
Filtering for a specific zone:
```python
df_zone_132 = spark.sql("""
SELECT *
FROM trips_data
WHERE PULocationID = 132
""")
df_zone_132.show(5)
```