Implementation:Spotify Luigi CopyToTable Run
| Knowledge Sources | Spotify Luigi Repository |
|---|---|
| Domains | Pipeline_Orchestration, Database, ETL |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Concrete tool, provided by Luigi, for loading processed pipeline data into database tables.
Description
Luigi's CopyToTable classes implement the full data loading lifecycle through their run(), rows(), copy(), init_copy(), and post_copy() methods. Each database backend provides a specialized implementation tuned to its bulk loading capabilities.
PostgreSQL (postgres.CopyToTable.run()): The rows() method yields tuples by reading lines from the upstream task's output, splitting on tab characters. The run() method iterates over these rows, applies map_column() to escape special characters (backslash, tab, newline, carriage return, vertical tab, backspace, form feed) and convert null values to \\N, then writes each row as a tab-separated line to a tempfile.TemporaryFile. After writing, it seeks back to position 0 and calls copy(cursor, file), which issues a COPY table (columns) FROM STDIN SQL command using cursor.copy_expert() (psycopg2) or cursor.execute() with a stream parameter (pg8000).
MySQL (mysqldb.CopyToTable.run()): The copy(cursor) method constructs a parameterized INSERT INTO table (columns) VALUES (%s, ...) query and uses cursor.executemany() to insert rows in batches of bulk_size (default 10,000). No temporary file is used.
SQLAlchemy (sqla.CopyToTable.run()): The run() method creates the table via create_table(engine), then iterates over rows() in chunks of chunk_size (default 5,000). Each chunk is converted to a list of dictionaries keyed by column names (prefixed with underscore for SQLAlchemy bindparam compatibility) and inserted via conn.execute(table.insert().values(bound_cols), rows).
All backends call init_copy() before and post_copy() after the data transfer, and all mark completion via output().touch() within the same transaction.
Usage
Use these run() / rows() / copy() methods when:
- You have a Luigi task that produces tabular data and needs to load it into a database table.
- You want to customize row generation by overriding `rows()` to yield data from a non-standard source.
- You need pre-copy or post-copy processing by overriding `init_copy()` or `post_copy()`.
- You want to tune bulk loading performance by adjusting `bulk_size` (MySQL) or `chunk_size` (SQLAlchemy).
Code Reference
Source Location
- Postgres rows, copy, run: `luigi/contrib/postgres.py`, lines 314-430
- Postgres map_column: `luigi/contrib/postgres.py`, lines 322-331
- Abstract init_copy, post_copy: `luigi/contrib/rdbms.py`, lines 237-265
- MySQL copy, run: `luigi/contrib/mysqldb.py`, lines 195-253
- SQLAlchemy run, copy: `luigi/contrib/sqla.py`, lines 388-418
Signature
postgres.CopyToTable.rows:
def rows(self):
    """
    Produce the records to insert, one per line of the upstream output.

    Opens ``self.input()`` for reading and yields each line as a list of
    fields: newline characters are stripped from both ends of the line and
    what remains is split on tab characters.
    """
    with self.input().open('r') as source:
        yield from (record.strip('\n').split('\t') for record in source)
postgres.CopyToTable.map_column:
def map_column(self, value):
    """
    Convert a single column value into the text written for that column.

    Values contained in ``self.null_values`` are replaced by the NULL
    marker string; any other value is converted with ``str`` and passed
    through ``default_escape``.
    """
    is_null = value in self.null_values
    return r'\\N' if is_null else default_escape(str(value))
postgres.CopyToTable.copy:
def copy(self, cursor, file):
    """
    Stream ``file`` into ``self.table`` with a ``COPY ... FROM STDIN`` statement.

    ``self.columns`` may hold plain column-name strings or two-item
    ``(name, type)`` entries; only the first entry is inspected to decide
    which form is in use.  The statement is sent through
    ``cursor.copy_expert`` when the cursor has that attribute, otherwise
    through ``cursor.execute`` with the file passed as ``stream``.

    Raises:
        Exception: if the first column entry is neither a string nor a
            two-item sequence.
    """
    head = self.columns[0]
    if isinstance(head, str):
        names = self.columns
    elif len(head) == 2:
        names = [column[0] for column in self.columns]
    else:
        raise Exception(
            'columns must consist of column strings or '
            '(column string, type string) tuples (was %r ...)' % (head,)
        )

    template = (
        "COPY {table} ({column_list}) FROM STDIN "
        "WITH (FORMAT text, NULL '{null_string}', DELIMITER '{delimiter}')"
    )
    statement = template.format(
        table=self.table,
        column_list=", ".join(names),
        null_string=r'\\N',
        delimiter=self.column_separator,
    )

    # Cursors without copy_expert take the file via the ``stream`` keyword.
    if not hasattr(cursor, 'copy_expert'):
        cursor.execute(statement, stream=file)
        return
    cursor.copy_expert(statement, file)
postgres.CopyToTable.run (core logic):
def run(self):
    """
    Load every row from ``rows()`` into ``self.table`` and mark the task done.

    Rows are first spooled to a temporary file, one line per row, with each
    value passed through ``map_column()`` and the values joined with
    ``self.column_separator``.  The file is then handed to ``copy()``,
    bracketed by ``init_copy()`` and ``post_copy()``.  If the first attempt
    fails because the table does not exist, the connection is reset,
    ``create_table()`` is called and the load is retried once.  On success
    the marker is touched and the transaction is committed.

    The temporary file and the connection are released on every exit path,
    including when an exception propagates.

    Raises:
        Exception: if ``self.table`` or ``self.columns`` is not set.
        dbapi.DatabaseError: any database error other than a missing table
            on the first attempt.
    """
    if not (self.table and self.columns):
        raise Exception("table and columns need to be specified")

    connection = self.output().connect()
    try:
        tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None)
        # The context manager closes the spool file even if loading fails.
        with tempfile.TemporaryFile(dir=tmp_dir) as tmp_file:
            for n, row in enumerate(self.rows(), start=1):
                if n % 100000 == 0:
                    logger.info("Wrote %d lines", n)
                rowstr = self.column_separator.join(self.map_column(val) for val in row)
                rowstr += "\n"
                tmp_file.write(rowstr.encode('utf-8'))

            logger.info("Done writing, importing at %s", datetime.datetime.now())

            for attempt in range(2):
                # Rewind before each attempt so a retry always starts at the
                # first row, whatever the failed attempt read.
                tmp_file.seek(0)
                try:
                    cursor = connection.cursor()
                    self.init_copy(connection)
                    self.copy(cursor, tmp_file)
                    self.post_copy(connection)
                    if self.enable_metadata_columns:
                        self.post_copy_metacolumns(cursor)
                except dbapi.DatabaseError as e:
                    if db_error_code(e) == ERROR_UNDEFINED_TABLE and attempt == 0:
                        # First attempt failed because the table is missing:
                        # reset the connection, create the table and retry.
                        logger.info("Creating table %s", self.table)
                        if hasattr(connection, 'reset'):
                            connection.reset()
                        else:
                            _pg8000_connection_reset(connection)
                        self.create_table(connection)
                    else:
                        raise
                else:
                    break

            # Touch the marker before committing so it is committed together
            # with the loaded rows.
            self.output().touch(connection)
            connection.commit()
    finally:
        connection.close()
rdbms.CopyToTable.init_copy:
def init_copy(self, connection):
    """
    Hook for custom queries that must run before the data is copied.

    The default implementation rejects tasks that still define the removed
    ``clear_table`` attribute and, when ``self.enable_metadata_columns`` is
    truthy, calls ``self._add_metadata_columns(connection)``.

    Raises:
        Exception: if the task defines a ``clear_table`` attribute.
    """
    uses_removed_attribute = hasattr(self, "clear_table")
    if uses_removed_attribute:
        raise Exception("The clear_table attribute has been removed. Override init_copy instead!")

    if not self.enable_metadata_columns:
        return
    self._add_metadata_columns(connection)
rdbms.CopyToTable.post_copy:
def post_copy(self, connection):
    """
    Hook for custom queries that must run after the data has been copied.

    The default implementation does nothing; override it to add behaviour.
    """
    return None
Import
import luigi.contrib.postgres
import luigi.contrib.mysqldb
from luigi.contrib import sqla
I/O Contract
Inputs
| Name | Type | Description |
|---|---|---|
| self.input() | luigi.Target | Output of the upstream dependency task; opened with .open('r') to read rows |
| columns | list | Column definitions used to construct the COPY or INSERT SQL command |
| table | str | Target table name |
| column_separator | str | Delimiter between column values (default: "\t") |
| null_values | tuple | Container of values to be treated as SQL NULL (default: (None,)) |
| bulk_size | int | (MySQL) Number of rows per executemany() batch (default: 10,000) |
| chunk_size | int | (SQLAlchemy) Number of rows per insert chunk (default: 5,000) |
Outputs
| Name | Type | Description |
|---|---|---|
| Database rows | Table rows | All rows from rows() are inserted into the target table |
| Marker table entry | Marker row | A completion record is written to the marker table via output().touch() |
Usage Examples
Basic PostgreSQL Data Loading
import luigi
import luigi.contrib.postgres
class LoadTopArtists(luigi.contrib.postgres.CopyToTable):
    """Load the output of ``Top10Artists`` into the ``top10`` table."""

    # Task parameters; both are passed on to the upstream task in requires().
    date_interval = luigi.DateIntervalParameter()
    use_spark = luigi.BoolParameter()

    # Connection settings for the target database.
    host = "localhost"
    database = "toplists"
    user = "luigi"
    password = "abc123"

    # Destination table and its (column name, column type) definitions.
    table = "top10"
    columns = [
        ("date_from", "DATE"),
        ("date_to", "DATE"),
        ("artist", "TEXT"),
        ("streams", "INT"),
    ]

    def requires(self):
        """Return the upstream task whose output supplies the rows."""
        return Top10Artists(self.date_interval, self.use_spark)

    # rows() is inherited: reads tab-separated lines from input
Custom Row Generation
import luigi
import luigi.contrib.postgres
class LoadTransformedData(luigi.contrib.postgres.CopyToTable):
    """Load comma-separated events into ``transformed_events`` via a custom ``rows()``."""

    # Connection settings for the target database.
    host = "db.example.com"
    database = "warehouse"
    user = "etl"
    password = "secret"

    # Destination table and its (column name, column type) definitions.
    table = "transformed_events"
    columns = [
        ("event_id", "BIGINT"),
        ("event_type", "TEXT"),
        ("payload", "JSONB"),
    ]

    def requires(self):
        """Return the upstream task whose output is read by rows()."""
        return ExtractEvents()

    def rows(self):
        """
        Override rows() to apply custom transformation logic.

        Each input line is stripped and split on commas.  Yields
        ``(event_id, event_type, payload)`` tuples where the id is converted
        to ``int``, the type is upper-cased and the payload is a JSON object
        with the third field stored under the ``"raw"`` key.
        """
        # Imported here so the method does not depend on the snippet's
        # import lines.
        import json

        with self.input().open('r') as fobj:
            for line in fobj:
                parts = line.strip().split(',')
                event_id = int(parts[0])
                event_type = parts[1].upper()
                # json.dumps escapes quotes, backslashes and control
                # characters, so the payload is valid JSON for any field
                # content; plain string interpolation is not.
                payload = json.dumps({"raw": parts[2]}, ensure_ascii=False)
                yield (event_id, event_type, payload)
Using init_copy for Rolling Window
import luigi
import luigi.contrib.postgres
class LoadDailyMetrics(luigi.contrib.postgres.CopyToTable):
    """Reload one day of metrics into ``daily_metrics``, replacing that day's rows."""

    # The day whose metrics are (re)loaded.
    date = luigi.DateParameter()

    # Connection settings for the target database.
    host = "localhost"
    database = "metrics"
    user = "luigi"
    password = "pass"

    # Destination table and its (column name, column type) definitions.
    table = "daily_metrics"
    columns = [
        ("metric_date", "DATE"),
        ("metric_name", "TEXT"),
        ("value", "NUMERIC"),
    ]

    def init_copy(self, connection):
        """Remove data for this date before re-loading (rolling window)."""
        # Run the inherited hook first so its clear_table guard and optional
        # metadata-column setup are not skipped by this override.
        super().init_copy(connection)
        cursor = connection.cursor()
        # The table name is a class constant, so it is interpolated directly;
        # the date value is passed as a bound query parameter.
        cursor.execute(
            "DELETE FROM %s WHERE metric_date = %%s" % self.table,
            (str(self.date),),
        )

    def requires(self):
        """Return the upstream task that computes the metrics for ``self.date``."""
        return ComputeDailyMetrics(date=self.date)