
Implementation:DataTalksClub Data engineering zoomcamp Spark SQL Aggregation

From Leeroopedia


Page Metadata
Knowledge Sources repo: DataTalksClub/data-engineering-zoomcamp, Spark docs: PySpark API Reference
Domains Data_Engineering, Batch_Processing
Last Updated 2026-02-09 14:00 GMT

Overview

A concrete recipe for executing a multi-dimensional SQL aggregation query in PySpark that computes monthly revenue statistics grouped by pickup location, month, and taxi service type.

Description

This implementation uses spark.sql() to execute a SQL query against the trips_data temporary table. The query performs a three-dimensional aggregation that computes 10 aggregate measures (8 SUM columns and 2 AVG columns) grouped by 3 dimensions: pickup location (as revenue_zone), truncated month (as revenue_month), and service type.

The SQL query uses:

  • date_trunc('month', pickup_datetime) to bucket timestamps into calendar months
  • PULocationID AS revenue_zone to alias the geographic dimension
  • GROUP BY 1, 2, 3 using ordinal position references for conciseness
  • Eight SUM() aggregations for revenue-related monetary amounts
  • Two AVG() aggregations for passenger count and trip distance (aliased avg_montly_*; the montly spelling is carried over verbatim from the source file)

The result is a DataFrame with 13 total output columns (3 dimensions + 10 measures).
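The date_trunc('month', …) bucketing above maps each pickup timestamp to midnight on the first day of its calendar month. A minimal plain-Python sketch of that rule (the month_floor helper is hypothetical, for illustration only):

```python
from datetime import datetime

def month_floor(ts: datetime) -> datetime:
    # Equivalent of Spark SQL's date_trunc('month', ts): keep year and month,
    # reset day-of-month to 1 and zero out the time-of-day components.
    return ts.replace(day=1, hour=0, minute=0, second=0, microsecond=0)

print(month_floor(datetime(2020, 1, 15, 13, 45, 7)))  # 2020-01-01 00:00:00
```

Every trip in the same calendar month therefore lands in the same revenue_month group, regardless of its day or time.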

This page is an API Doc implementation entry that documents the use of PySpark's spark.sql() API for analytical SQL queries.

Usage

Use this implementation when:

  • Computing monthly revenue reports from unified taxi trip data
  • You need multi-dimensional aggregation across zone, time, and service type
  • SQL syntax is preferred over the DataFrame API for expressing complex aggregations
  • The source data has already been unified and registered as a temporary table
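The query's shape (ordinal GROUP BY 1, 2, 3 with mixed SUM/AVG measures) can be exercised outside Spark with the stdlib sqlite3 module, which accepts the same ordinal syntax; here strftime('%Y-%m-01', …) stands in for date_trunc, and the tiny trips_data table is made-up sample data, not the zoomcamp dataset:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE trips_data (
    PULocationID INTEGER, pickup_datetime TEXT, service_type TEXT,
    fare_amount REAL, passenger_count INTEGER
);
INSERT INTO trips_data VALUES
    (132, '2020-01-05 08:30:00', 'green', 10.0, 1),
    (132, '2020-01-20 19:00:00', 'green', 20.0, 3),
    (132, '2020-02-01 07:15:00', 'green', 7.5, 2);
""")

rows = conn.execute("""
SELECT
    PULocationID AS revenue_zone,
    strftime('%Y-%m-01', pickup_datetime) AS revenue_month,  -- date_trunc stand-in
    service_type,
    SUM(fare_amount) AS revenue_monthly_fare,
    AVG(passenger_count) AS avg_passenger_count
FROM trips_data
GROUP BY 1, 2, 3  -- ordinal references, same style as the Spark query
ORDER BY revenue_month
""").fetchall()

print(rows)
# [(132, '2020-01-01', 'green', 30.0, 2.0), (132, '2020-02-01', 'green', 7.5, 2.0)]
```

The two January trips collapse into one row (fares summed, passenger counts averaged), while the February trip forms its own group.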

Code Reference

Source Location: 06-batch/code/06_spark_sql.py, lines 79-103

Signature:

spark.sql("SELECT ... GROUP BY ...") -> DataFrame

Import:

from pyspark.sql import SparkSession

Full SQL Query:

df_result = spark.sql("""
SELECT
    -- Revenue grouping
    PULocationID AS revenue_zone,
    date_trunc('month', pickup_datetime) AS revenue_month,
    service_type,

    -- Revenue calculation
    SUM(fare_amount) AS revenue_monthly_fare,
    SUM(extra) AS revenue_monthly_extra,
    SUM(mta_tax) AS revenue_monthly_mta_tax,
    SUM(tip_amount) AS revenue_monthly_tip_amount,
    SUM(tolls_amount) AS revenue_monthly_tolls_amount,
    SUM(improvement_surcharge) AS revenue_monthly_improvement_surcharge,
    SUM(total_amount) AS revenue_monthly_total_amount,
    SUM(congestion_surcharge) AS revenue_monthly_congestion_surcharge,

    -- Additional calculations
    AVG(passenger_count) AS avg_montly_passenger_count,
    AVG(trip_distance) AS avg_montly_trip_distance
FROM
    trips_data
GROUP BY
    1, 2, 3
""")

I/O Contract

Inputs:

Input                   | Type      | Description
trips_data (temp table) | SQL Table | Unified taxi trip data registered via registerTempTable (a deprecated alias of createOrReplaceTempView), containing 19 columns (18 common columns + service_type)

GROUP BY Dimensions:

Position | Column                               | Output Alias  | Description
1        | PULocationID                         | revenue_zone  | Pickup location identifier (geographic zone)
2        | date_trunc('month', pickup_datetime) | revenue_month | Pickup timestamp truncated to the first day of the month
3        | service_type                         | service_type  | Taxi service type discriminator ('green' or 'yellow')

Aggregate Measures:

Function | Source Column         | Output Alias
SUM      | fare_amount           | revenue_monthly_fare
SUM      | extra                 | revenue_monthly_extra
SUM      | mta_tax               | revenue_monthly_mta_tax
SUM      | tip_amount            | revenue_monthly_tip_amount
SUM      | tolls_amount          | revenue_monthly_tolls_amount
SUM      | improvement_surcharge | revenue_monthly_improvement_surcharge
SUM      | total_amount          | revenue_monthly_total_amount
SUM      | congestion_surcharge  | revenue_monthly_congestion_surcharge
AVG      | passenger_count       | avg_montly_passenger_count
AVG      | trip_distance         | avg_montly_trip_distance
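One behavior worth remembering for these measures: SQL SUM and AVG skip NULL values rather than treating them as zero (standard SQL and Spark SQL both behave this way), so a trip with a missing passenger_count does not drag the average down. A quick stdlib sqlite3 check of that rule:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (passenger_count INTEGER)")
conn.executemany("INSERT INTO t VALUES (?)", [(2,), (4,), (None,)])

# AVG divides by the number of non-NULL rows (2), not the total row count (3).
avg, = conn.execute("SELECT AVG(passenger_count) FROM t").fetchone()
print(avg)  # 3.0
```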

Outputs:

Output    | Type      | Description
df_result | DataFrame | Aggregated revenue DataFrame with 13 columns (3 dimensions + 10 measures), one row per unique (zone, month, service_type) combination
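The one-row-per-(zone, month, service_type) contract can be illustrated with a hand-rolled dict aggregation over made-up sample rows (not Spark output; only one measure is shown to keep it short):

```python
from collections import defaultdict

# Sample rows: (zone, month, service_type, fare)
trips = [
    (132, "2020-01", "green",  10.0),
    (132, "2020-01", "green",  20.0),
    (132, "2020-01", "yellow",  5.0),
    (43,  "2020-02", "green",   8.0),
]

# (zone, month, service_type) -> [running fare sum, trip count]
acc = defaultdict(lambda: [0.0, 0])
for zone, month, service, fare in trips:
    acc[(zone, month, service)][0] += fare
    acc[(zone, month, service)][1] += 1

# Exactly one output row per distinct grouping key
result = {k: {"revenue_monthly_fare": s, "trips": n} for k, (s, n) in acc.items()}
print(len(result))  # 3 distinct (zone, month, service_type) combinations
```

The two (132, "2020-01", "green") trips merge into a single row, mirroring how GROUP BY collapses the DataFrame.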

Usage Examples

Executing the revenue aggregation query:

df_result = spark.sql("""
SELECT
    PULocationID AS revenue_zone,
    date_trunc('month', pickup_datetime) AS revenue_month,
    service_type,
    SUM(fare_amount) AS revenue_monthly_fare,
    SUM(extra) AS revenue_monthly_extra,
    SUM(mta_tax) AS revenue_monthly_mta_tax,
    SUM(tip_amount) AS revenue_monthly_tip_amount,
    SUM(tolls_amount) AS revenue_monthly_tolls_amount,
    SUM(improvement_surcharge) AS revenue_monthly_improvement_surcharge,
    SUM(total_amount) AS revenue_monthly_total_amount,
    SUM(congestion_surcharge) AS revenue_monthly_congestion_surcharge,
    AVG(passenger_count) AS avg_montly_passenger_count,
    AVG(trip_distance) AS avg_montly_trip_distance
FROM
    trips_data
GROUP BY
    1, 2, 3
""")

Inspecting the results:

df_result.printSchema()
# root
#  |-- revenue_zone: integer (nullable = true)
#  |-- revenue_month: timestamp (nullable = true)
#  |-- service_type: string (nullable = false)
#  |-- revenue_monthly_fare: double (nullable = true)
#  |-- ...

df_result.show(5, truncate=False)
print(f"Total aggregated rows: {df_result.count()}")

Filtering for a specific zone:

df_zone_132 = spark.sql("""
SELECT *
FROM trips_data
WHERE PULocationID = 132
""")
df_zone_132.show(5)
