Implementation:DataTalksClub Data engineering zoomcamp Spark UnionAll
| Page Metadata | |
|---|---|
| Knowledge Sources | repo: DataTalksClub/data-engineering-zoomcamp, Spark docs: PySpark API Reference |
| Domains | Data_Engineering, Batch_Processing |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
Concrete tool for selecting common columns, adding a service type discriminator, vertically concatenating green and yellow taxi DataFrames, and registering the result as a temporary SQL table in PySpark.
Description
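This implementation combines four PySpark operations into a unified dataset preparation step:
- Column selection with `df.select(columns)`: projects each DataFrame down to the 18 common columns shared between the green and yellow taxi data.
- Discriminator injection with `df.withColumn('service_type', F.lit(value))`: adds a literal string column (`'green'` or `'yellow'`) to each DataFrame so that the origin of every row is still known after merging.
- Vertical concatenation with `df.unionAll(other)`: appends all rows of both DataFrames into a single combined DataFrame. `unionAll` is an alias of `union` with SQL `UNION ALL` semantics: duplicates are kept and columns are matched by position, not by name. This is why both inputs are first projected with the same ordered column list.
- Table registration with `df.registerTempTable('trips_data')`: makes the unified DataFrame queryable via Spark SQL under the name `trips_data`. This method is deprecated since Spark 2.0; `df.createOrReplaceTempView('trips_data')` is the non-deprecated equivalent.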
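Because `unionAll` aligns columns by position, two inputs with the same column names in a different order merge silently but incorrectly. The sketch below is a toy model in plain Python (no Spark) that contrasts positional union with name-based union, which PySpark exposes as `DataFrame.unionByName`. The functions `union_by_position` and `union_by_name` are hypothetical illustrations, and the model ignores Spark's type coercion.

```python
from typing import List, Sequence, Tuple

Row = Tuple[object, ...]


def union_by_position(
    cols_a: Sequence[str], rows_a: List[Row],
    cols_b: Sequence[str], rows_b: List[Row],
) -> Tuple[List[str], List[Row]]:
    """Toy model of DataFrame.unionAll / union: align columns by position."""
    if len(cols_a) != len(cols_b):
        raise ValueError("union needs the same number of columns on both sides")
    # The first input's names win; cols_b is only checked for length,
    # so rows of b are appended exactly as they are laid out.
    return list(cols_a), rows_a + rows_b  # duplicates are kept (UNION ALL)


def union_by_name(
    cols_a: Sequence[str], rows_a: List[Row],
    cols_b: Sequence[str], rows_b: List[Row],
) -> Tuple[List[str], List[Row]]:
    """Toy model of DataFrame.unionByName: reorder b's columns to match a."""
    if sorted(cols_a) != sorted(cols_b):
        raise ValueError("both sides must have the same column names")
    positions = [list(cols_b).index(c) for c in cols_a]
    reordered = [tuple(row[i] for i in positions) for row in rows_b]
    return list(cols_a), rows_a + reordered


# Same column names, different order on the yellow side
green_cols, green_rows = ["service_type", "fare_amount"], [("green", 12.5)]
yellow_cols, yellow_rows = ["fare_amount", "service_type"], [(8.0, "yellow")]

print(union_by_position(green_cols, green_rows, yellow_cols, yellow_rows)[1])
# [('green', 12.5), (8.0, 'yellow')]  -> 8.0 ends up under 'service_type'
print(union_by_name(green_cols, green_rows, yellow_cols, yellow_rows)[1])
# [('green', 12.5), ('yellow', 8.0)]
```

Selecting the same `common_colums` list on both sides, as this page does, makes the positional match safe without needing name-based resolution.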
The 18 common columns are: VendorID, pickup_datetime, dropoff_datetime, store_and_fwd_flag, RatecodeID, PULocationID, DOLocationID, passenger_count, trip_distance, fare_amount, extra, mta_tax, tip_amount, tolls_amount, improvement_surcharge, total_amount, payment_type, and congestion_surcharge.
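The list above is hard-coded. As a sketch, it could instead be derived from the two DataFrames' `df.columns` lists (plain `list[str]` in PySpark) with a hypothetical helper `common_columns`. Iterating over one list, rather than intersecting two sets, keeps the resulting order stable between runs (Python's set iteration order for strings can change from one interpreter run to the next), so the unified schema stays the same.

```python
from typing import List, Sequence


def common_columns(left: Sequence[str], right: Sequence[str]) -> List[str]:
    """Names present in both inputs, kept in the order they appear in `left`."""
    right_names = set(right)  # set used only for fast membership tests
    return [name for name in left if name in right_names]


# Hypothetical, shortened column lists standing in for df.columns
green_columns = ["VendorID", "pickup_datetime", "ehail_fee", "trip_type", "fare_amount"]
yellow_columns = ["VendorID", "fare_amount", "pickup_datetime"]

print(common_columns(green_columns, yellow_columns))
# ['VendorID', 'pickup_datetime', 'fare_amount']
```

With real inputs this would be called as `common_colums = common_columns(df_green.columns, df_yellow.columns)`.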
This is a Wrapper Doc implementation combining PySpark's DataFrame selection, literal column creation, union, and temp table registration APIs.
Usage
Use this implementation when:
- Combining green and yellow taxi data into a single queryable dataset
- You need a SQL-accessible table for downstream aggregation queries
- Both DataFrames have already been schema-normalized with matching column names
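The last precondition can be checked explicitly before the union. Spark's union may widen compatible types (for example `int` and `bigint`) or raise an `AnalysisException` for incompatible ones, so surfacing differences up front makes the precondition visible. The sketch below is a hypothetical helper, `schema_problems`, that works on lists shaped like PySpark's `DataFrame.dtypes`, which are `(column_name, type_string)` tuples. It is pure Python and does not call Spark.

```python
from typing import Dict, List, Sequence, Tuple

# Same shape as PySpark's DataFrame.dtypes: [(column_name, type_string), ...]
DTypes = Sequence[Tuple[str, str]]


def schema_problems(green: DTypes, yellow: DTypes, columns: Sequence[str]) -> List[str]:
    """Return readable problems that would make a positional union unsafe."""
    green_types: Dict[str, str] = dict(green)
    yellow_types: Dict[str, str] = dict(yellow)
    problems: List[str] = []
    for col in columns:
        in_green, in_yellow = col in green_types, col in yellow_types
        if not in_green:
            problems.append(f"missing in green: {col}")
        if not in_yellow:
            problems.append(f"missing in yellow: {col}")
        if in_green and in_yellow and green_types[col] != yellow_types[col]:
            problems.append(
                f"type mismatch for {col}: green={green_types[col]}, yellow={yellow_types[col]}"
            )
    return problems


# Hypothetical dtypes; in PySpark these would come from df_green.dtypes / df_yellow.dtypes
green_dtypes = [("VendorID", "int"), ("pickup_datetime", "timestamp"), ("fare_amount", "double")]
yellow_dtypes = [("VendorID", "bigint"), ("pickup_datetime", "timestamp")]

for problem in schema_problems(green_dtypes, yellow_dtypes, ["VendorID", "pickup_datetime", "fare_amount"]):
    print(problem)
# type mismatch for VendorID: green=int, yellow=bigint
# missing in yellow: fare_amount
```

An empty result means the common columns exist on both sides with identical type strings.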
Code Reference
Source Location: 06-batch/code/06_spark_sql.py, lines 42-76
Signature:
```
df.select(columns).withColumn('service_type', F.lit('green')) -> DataFrame
df_green_sel.unionAll(df_yellow_sel) -> DataFrame
df_trips_data.registerTempTable('trips_data') -> None
```
Import:
```python
from pyspark.sql import functions as F
```
I/O Contract
Inputs:
| Parameter | Type | Required | Description |
|---|---|---|---|
| df_green | DataFrame | Yes | Schema-normalized green taxi DataFrame |
| df_yellow | DataFrame | Yes | Schema-normalized yellow taxi DataFrame |
| common_colums | list[str] | Yes | List of 18 column names shared between both DataFrames |
Outputs:
| Output | Type | Description |
|---|---|---|
| df_green_sel | DataFrame | Green taxi data projected to 18 common columns plus service_type='green' (19 columns total) |
| df_yellow_sel | DataFrame | Yellow taxi data projected to 18 common columns plus service_type='yellow' (19 columns total) |
| df_trips_data | DataFrame | Unified DataFrame containing all rows from both sources (19 columns) |
| trips_data (temp table) | SQL Table | Registered temporary table name for Spark SQL queries |
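The output contract (19 columns, all rows from both sources) can be turned into checks. The sketch below uses two hypothetical helpers, `expected_output_columns` and `union_row_count_ok`; they are pure Python and only encode the rules stated in the table above.

```python
from typing import List, Sequence


def expected_output_columns(common: Sequence[str], discriminator: str = "service_type") -> List[str]:
    """Columns that df_green_sel, df_yellow_sel and df_trips_data should all expose."""
    if len(set(common)) != len(common):
        raise ValueError("common column list contains duplicates")
    if discriminator in common:
        # withColumn would silently overwrite the existing column instead of adding one
        raise ValueError(f"{discriminator!r} is already a common column")
    return list(common) + [discriminator]


def union_row_count_ok(n_green: int, n_yellow: int, n_total: int) -> bool:
    """unionAll does not de-duplicate, so every input row must survive."""
    return n_total == n_green + n_yellow


cols = expected_output_columns([f"col_{i}" for i in range(18)])
print(len(cols), cols[-1])  # 19 service_type
print(union_row_count_ok(2, 3, 5))  # True
```

Against real DataFrames, this could look like `df_trips_data.columns == expected_output_columns(common_colums)` and `union_row_count_ok(df_green.count(), df_yellow.count(), df_trips_data.count())`. Each `count()` runs a Spark job. Because `select` and `withColumn` do not change row counts, the input counts equal the counts of the `_sel` DataFrames.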
Usage Examples
Selecting common columns and adding discriminator:
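```python
from pyspark.sql import functions as F

# df_green and df_yellow are the schema-normalized inputs (see I/O Contract).
# Column names shared by both DataFrames, in a fixed order:
common_colums = [
    'VendorID', 'pickup_datetime', 'dropoff_datetime',
    'store_and_fwd_flag', 'RatecodeID', 'PULocationID',
    'DOLocationID', 'passenger_count', 'trip_distance',
    'fare_amount', 'extra', 'mta_tax', 'tip_amount',
    'tolls_amount', 'improvement_surcharge', 'total_amount',
    'payment_type', 'congestion_surcharge'
]

# Project each source to the same ordered column list, then tag every row
# with its origin via a literal 'service_type' column.
df_green_sel = df_green \
    .select(common_colums) \
    .withColumn('service_type', F.lit('green'))

df_yellow_sel = df_yellow \
    .select(common_colums) \
    .withColumn('service_type', F.lit('yellow'))
```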
Combining and registering as a SQL table:
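```python
# unionAll is an alias of union(): rows are appended by column position and
# duplicates are kept (SQL UNION ALL semantics).
df_trips_data = df_green_sel.unionAll(df_yellow_sel)

# registerTempTable is deprecated since Spark 2.0;
# df_trips_data.createOrReplaceTempView('trips_data') is the non-deprecated equivalent.
df_trips_data.registerTempTable('trips_data')

# Verify the combined dataset
df_trips_data.groupBy('service_type').count().show()
# +------------+--------+
# |service_type|   count|
# +------------+--------+
# |       green| 2304517|
# |      yellow| 3972891|
# +------------+--------+
```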
Querying the registered table:
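```python
# Simple query against the temp table; `spark` is the active SparkSession,
# and the temp table is visible only within that session.
spark.sql("SELECT service_type, COUNT(*) AS cnt FROM trips_data GROUP BY service_type").show()
```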