Distributed Data Processing with PySpark

From the big data models python curriculum · Updated May 31, 2026


TL;DR

You'll learn how PySpark splits big datasets across multiple machines to process them in parallel. You'll understand RDDs and DataFrames as the two main data structures for distributed computing. You'll write real PySpark code that can scale from your laptop to a 100-node cluster.

1. The Mental Model

When your data gets too big for one machine, PySpark splits it into chunks called partitions and sends them to different machines in a cluster. Each machine processes its partitions independently, and PySpark then combines the partial results into the final answer. That's the whole idea.
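A toy, single-machine sketch of the split, process, and combine idea in plain Python. Threads stand in for machines here, and `split_into_partitions`, `process_partition`, and `sum_of_squares` are hypothetical helpers. This illustrates the mental model only. It is not how Spark itself is implemented, since Spark uses a driver process that coordinates executor processes.

```python
from concurrent.futures import ThreadPoolExecutor

def split_into_partitions(data, num_partitions):
    """Split a list into num_partitions contiguous chunks of near-equal size."""
    size, remainder = divmod(len(data), num_partitions)
    partitions, start = [], 0
    for i in range(num_partitions):
        end = start + size + (1 if i < remainder else 0)
        partitions.append(data[start:end])
        start = end
    return partitions

def process_partition(chunk):
    """Independent work on one chunk: here, a partial sum of squares."""
    return sum(x * x for x in chunk)

def sum_of_squares(data, num_partitions=4):
    partitions = split_into_partitions(data, num_partitions)
    # Each worker (a thread standing in for a machine) handles one chunk
    with ThreadPoolExecutor(max_workers=num_partitions) as pool:
        partial_results = list(pool.map(process_partition, partitions))
    # Combine the small partial results into the final answer
    return sum(partial_results)

print(split_into_partitions(list(range(1, 9)), 3))  # [[1, 2, 3], [4, 5, 6], [7, 8]]
print(sum_of_squares(list(range(1, 9))))            # 204
```

The combine step only works this simply because addition can be applied to partial results in any grouping. Operations such as averages need a little more care, as the aggregation sketch further on shows.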

2. The Core Material

What Makes PySpark Different

PySpark is the Python API for Apache Spark, a distributed computing framework. Unlike pandas, which runs on a single machine, PySpark spreads your data and computations across the machines in a cluster (or across the CPU cores of your laptop when running locally).

Here's the key difference:

import pandas as pd
# This runs on ONE machine and loads the whole file into its memory
df_pandas = pd.read_csv("huge_file.csv")  # Fails if the file doesn't fit in RAM

from pyspark.sql import SparkSession
# This can run across MANY machines
spark = SparkSession.builder.appName("MyApp").getOrCreate()
df_spark = spark.read.csv("huge_file.csv", header=True)  # Rows are split into partitions processed in parallel

Core Data Structures

PySpark gives you two main ways to work with distributed data:

RDDs (Resilient Distributed Datasets) - The lower-level building blocks:

from pyspark import SparkContext
sc = SparkContext.getOrCreate()  # Reuses the context behind an existing SparkSession instead of failing

# Create an RDD from a list
numbers = sc.parallelize([1, 2, 3, 4, 5, 6, 7, 8])

# Operations are lazy - nothing happens yet
squared = numbers.map(lambda x: x ** 2)
filtered = squared.filter(lambda x: x > 10)

# Action triggers computation across cluster
result = filtered.collect()  # [16, 25, 36, 49, 64]

DataFrames - The high-level, SQL-like interface:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, avg, count

spark = SparkSession.builder.appName("Analysis").getOrCreate()

# Read data (automatically distributed); inferSchema makes "sales" numeric instead of string
df = spark.read.option("header", True).option("inferSchema", True).csv("sales_data.csv")

# Operations look like pandas but run distributed
result = df.groupBy("region") \
          .agg(avg("sales").alias("avg_sales"), 
               count("*").alias("total_orders")) \
          .orderBy("avg_sales", ascending=False)

# Show results
result.show()
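Why can a distributed `groupBy().agg(avg(...))` be correct when no single machine sees all the rows? Spark typically plans it as a partial aggregation on each partition, a shuffle, and a final merge; you can usually see this as two aggregate steps in `result.explain()`. A plain-Python sketch of that two-phase idea follows, using made-up `(region, sales)` rows and hypothetical helper names. Each partition reports `(sum, count)` rather than an average, because averaging per-partition averages gives the wrong answer when partitions differ in size.

```python
from collections import defaultdict

def partial_aggregate(partition):
    """Phase 1, per partition: reduce rows to [sum, count] per region."""
    acc = defaultdict(lambda: [0.0, 0])
    for region, sales in partition:
        acc[region][0] += sales
        acc[region][1] += 1
    return dict(acc)

def merge_partials(partials):
    """Phase 2, after the shuffle: merge the summaries, then finish the average."""
    totals = defaultdict(lambda: [0.0, 0])
    for partial in partials:
        for region, (s, c) in partial.items():
            totals[region][0] += s
            totals[region][1] += c
    return {region: {"avg_sales": s / c, "total_orders": c}
            for region, (s, c) in totals.items()}

# Two hypothetical partitions of (region, sales) rows
partitions = [
    [("West", 100.0), ("East", 50.0), ("West", 300.0)],
    [("East", 150.0), ("West", 200.0)],
]
result = merge_partials(partial_aggregate(p) for p in partitions)
print(result["West"])  # {'avg_sales': 200.0, 'total_orders': 3}
print(result["East"])  # {'avg_sales': 100.0, 'total_orders': 2}
```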

Lazy Evaluation and Actions

PySpark uses lazy evaluation: transformations only build up a plan of what to do, and Spark doesn't execute that plan until an action asks for results:

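from pyspark.sql.functions import col

# TRANSFORMATIONS: Spark only records these in the query plan
# (inferSchema is the exception: it scans the file once up front to detect column types)
df1 = spark.read.option("header", True).option("inferSchema", True).csv("data.csv")
df2 = df1.filter(col("age") > 18)
df3 = df2.select("name", "salary")

# ACTIONS: each one triggers actual computation
df3.show()                # Display the first 20 rows
df3.count()               # Count rows
df3.collect()             # Bring all rows to the driver (only safe for small results)
df3.write.csv("output/")  # Save to files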
[Diagram: Raw Data -> Filter Transform -> Select Transform -> Group Transform -> Action: show() -> Results]
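To make the transformation/action split concrete, here is a toy `LazyDataset` class in plain Python (hypothetical, not part of PySpark). Its `map` and `filter` only append a step to a recorded plan, and `collect` and `count` are the only methods that execute anything. Each element flows through every step in a single pass, loosely like how Spark pipelines narrow transformations within a stage.

```python
class LazyDataset:
    """Toy dataset that records transformations and runs them only on an action."""

    def __init__(self, data, plan=None):
        self._data = data
        self._plan = plan or []  # list of ("map" | "filter", function) steps

    # Transformations: return a NEW dataset with one more step in the plan
    def map(self, fn):
        return LazyDataset(self._data, self._plan + [("map", fn)])

    def filter(self, predicate):
        return LazyDataset(self._data, self._plan + [("filter", predicate)])

    # Actions: execute the whole recorded plan from the original data
    def collect(self):
        results = []
        for item in self._data:
            keep = True
            for kind, fn in self._plan:
                if kind == "map":
                    item = fn(item)
                elif not fn(item):  # a filter step that rejects the item
                    keep = False
                    break
            if keep:
                results.append(item)
        return results

    def count(self):
        return len(self.collect())

calls = []

def square(x):
    calls.append(x)  # record each time the function really runs
    return x ** 2

ds = LazyDataset([1, 2, 3, 4, 5, 6, 7, 8]).map(square).filter(lambda x: x > 10)
print(len(calls))    # 0: building the plan ran nothing
print(ds.collect())  # [16, 25, 36, 49, 64]
print(len(calls))    # 8: the action ran square() once per element
print(ds.count())    # 5
print(len(calls))    # 16: a second action recomputed everything (no caching)
```

The last line is the reason caching matters in real Spark jobs: without `cache()` or `persist()`, every action replays the plan from the source data.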

Key Functions You'll Use Daily

Data Loading:

# CSV files
df = spark.read.option("header", True).csv("path/to/file.csv")

# JSON files
df = spark.read.json("path/to/file.json")

# Parquet (columnar and compressed; usually the most efficient choice)
df = spark.read.parquet("path/to/file.parquet")

# From databases (requires the PostgreSQL JDBC driver jar on Spark's classpath)
df = spark.read.jdbc(url="jdbc:postgresql://localhost/db",
                     table="users",
                     properties={"user": "admin", "password": "secret",
                                 "driver": "org.postgresql.Driver"})

Data Processing:

from pyspark.sql.functions import col, when, avg, count

# Filtering
young_users = df.filter(col("age") < 30)

# Adding columns
df_with_category = df.withColumn("age_group", 
                                when(col("age") < 25, "young")
                                .when(col("age") < 50, "middle")
                                .otherwise("senior"))

# Aggregations
summary = df.groupBy("department") \
           .agg(avg("salary").alias("avg_salary"),
                count("*").alias("employee_count"))
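The `when(...).when(...).otherwise(...)` chain above is checked in order, and the first true condition wins. A subtle consequence is that a null age makes every comparison null, which counts as "not true", so null ages fall through to `otherwise("senior")`. The function below is a plain-Python sketch that mirrors those semantics; `age_group` is a hypothetical name, not Spark code.

```python
def age_group(age):
    """Mirror when(age < 25).when(age < 50).otherwise(...): first true branch wins.

    In Spark, comparing a null age yields null, which is treated as "not true",
    so a null age skips every when() branch and lands in otherwise().
    """
    if age is not None and age < 25:
        return "young"
    if age is not None and age < 50:
        return "middle"
    return "senior"  # otherwise(): also catches null ages

print([age_group(a) for a in [22, 25, 49, 50, None]])
# ['young', 'middle', 'middle', 'senior', 'senior']
```

If null ages should not be labeled "senior", one option is to start the chain with `when(col("age").isNull(), "unknown")` so that branch is checked first.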

3. Worked Example

Let's process a sales dataset to find the top-performing regions:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, avg, count, desc

# Initialize Spark
spark = SparkSession.builder \
    .appName("SalesAnalysis") \
    .config("spark.executor.memory", "2g") \
    .getOrCreate()

# Load the data
sales_df = spark.read.option("header", True) \
    .option("inferSchema", True) \
    .csv("sales_data.csv")

# Check the schema
sales_df.printSchema()
# root
#  |-- order_id: integer (nullable = true)
#  |-- region: string (nullable = true)
#  |-- sales_amount: double (nullable = true)
#  |-- order_date: string (nullable = true)
#  |-- customer_id: integer (nullable = true)

# Clean and process the data
processed_df = sales_df.filter(col("sales_amount") > 0) \
    .filter(col("region").isNotNull())

# Calculate regional performance (revenue_per_order recomputes sum / count, so it always matches avg_order_value)
regional_stats = processed_df.groupBy("region") \
    .agg(
        sum("sales_amount").alias("total_sales"),
        avg("sales_amount").alias("avg_order_value"),
        count("*").alias("total_orders")
    ) \
    .withColumn("revenue_per_order", col("total_sales") / col("total_orders")) \
    .orderBy(desc("total_sales"))

# Show top 5 regions
print("Top 5 Regions by Total Sales:")
regional_stats.show(5)

# Save results (with a header row so the column names survive)
regional_stats.write.mode("overwrite").option("header", True).csv("output/regional_analysis")

spark.stop()

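Illustrative output (the actual figures depend on your data, and show() prints raw doubles without thousands separators or rounding):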

Top 5 Regions by Total Sales:
+---------+-----------+---------------+------------+-----------------+
|   region|total_sales|avg_order_value|total_orders|revenue_per_order|
+---------+-----------+---------------+------------+-----------------+
|     West|  2,450,000|         245.50|      10,000|           245.50|
|     East|  2,100,000|         210.75|       9,967|           210.75|
|    North|  1,890,000|         189.45|       9,978|           189.45|
|    South|  1,675,000|         167.89|       9,981|           167.89|
|  Central|  1,234,000|         135.67|       9,095|           135.67|
+---------+-----------+---------------+------------+-----------------+
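Because the table above is illustrative, a practical way to sanity-check a job like this is to run the same logic on a handful of rows you can verify by hand. The function below is a plain-Python sketch that mirrors the filter, `groupBy`/`agg`, and `orderBy` steps; the rows and the `regional_stats` helper are hypothetical. It also makes visible that `revenue_per_order` is computed exactly the way `avg` is, which is why those two columns always match.

```python
from collections import defaultdict

def regional_stats(rows):
    """Plain-Python mirror of the filter, groupBy/agg, and orderBy steps."""
    totals = defaultdict(lambda: {"total_sales": 0.0, "total_orders": 0})
    for row in rows:
        amount, region = row["sales_amount"], row["region"]
        # Same cleaning rules as the Spark job: positive amount, non-null region
        if amount is None or amount <= 0 or region is None:
            continue
        totals[region]["total_sales"] += amount
        totals[region]["total_orders"] += 1

    stats = []
    for region, t in totals.items():
        avg_order_value = t["total_sales"] / t["total_orders"]    # what avg() computes
        revenue_per_order = t["total_sales"] / t["total_orders"]  # the withColumn step
        stats.append((region, t["total_sales"], t["total_orders"],
                      avg_order_value, revenue_per_order))
    # orderBy(desc("total_sales"))
    return sorted(stats, key=lambda s: s[1], reverse=True)

# Hypothetical rows you can check by hand
rows = [
    {"region": "West", "sales_amount": 120.0},
    {"region": "West", "sales_amount": 80.0},
    {"region": "East", "sales_amount": 90.0},
    {"region": "East", "sales_amount": -5.0},  # dropped: not positive
    {"region": None, "sales_amount": 40.0},    # dropped: null region
]
for row in regional_stats(rows):
    print(row)
# ('West', 200.0, 2, 100.0, 100.0)
# ('East', 90.0, 1, 90.0, 90.0)
```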

4. Production Pitfalls & Best Practices

4.1 Real-World Best Practices

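  1. Cache DataFrames you reuse - If you'll run several actions on the same DataFrame, call df.cache() so Spark keeps it after the first action instead of recomputing it (for DataFrames the default storage level uses memory and spills to disk)
  2. Use broadcast joins for small tables - When joining a large table with one small enough to fit comfortably in each executor's memory, wrap the small one in broadcast(small_df) from pyspark.sql.functions; Spark already does this automatically below spark.sql.autoBroadcastJoinThreshold (10MB by default)
  3. Partition your data intelligently - Use .repartition(col("date")) to group related data together
  4. Choose the right file format - Parquet (columnar, compressed, schema included) is almost always better than CSV for performance and storage
  5. Configure executor resources properly - Set spark.executor.memory and spark.executor.cores based on your cluster
  6. Use column pruning - Only select the columns you need, e.g. df.select("col1", "col2"), instead of carrying the full DataFrame through; with Parquet, Spark then reads only those columns from disk
  7. Choose a storage level explicitly when needed - Use .persist(StorageLevel.MEMORY_AND_DISK) (from pyspark import StorageLevel) for DataFrames used in multiple actions, and call .unpersist() when you're done with them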
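Tip 2 works because of how a broadcast hash join is structured: the small table is turned into a hash table, a copy is shipped to every executor, and each partition of the large table is joined locally, so the large table is never shuffled. Below is a plain-Python sketch of that structure with hypothetical data and a hypothetical `broadcast_hash_join` helper. It assumes the small table has unique keys (a later duplicate would overwrite an earlier one in the dict) and performs an inner join. In PySpark itself you would write `large_df.join(broadcast(small_df), "key")`.

```python
def broadcast_hash_join(large_partitions, small_table, key):
    """Inner-join each partition of a large table against a broadcast lookup dict.

    Assumes key values in small_table are unique (dimension-table style).
    """
    lookup = {row[key]: row for row in small_table}  # the "broadcast" copy
    joined_partitions = []
    for partition in large_partitions:  # in Spark, each runs independently
        joined = [{**row, **lookup[row[key]]}
                  for row in partition if row[key] in lookup]
        joined_partitions.append(joined)
    return joined_partitions

# Hypothetical large table, already split into two partitions
orders = [
    [{"order_id": 1, "region_id": 10, "amount": 5.0},
     {"order_id": 2, "region_id": 20, "amount": 7.5}],
    [{"order_id": 3, "region_id": 10, "amount": 2.0},
     {"order_id": 4, "region_id": 99, "amount": 1.0}],  # no match: dropped
]
regions = [{"region_id": 10, "region": "West"}, {"region_id": 20, "region": "East"}]

result = broadcast_hash_join(orders, regions, "region_id")
print([(r["order_id"], r["region"]) for p in result for r in p])
# [(1, 'West'), (2, 'East'), (3, 'West')]
```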

4.2 Common Bugs & Anti-Patterns

BAD - Collecting huge datasets:

# Pulls every row into the driver's memory: expect an out-of-memory crash
huge_df = spark.read.parquet("10TB_dataset.parquet")
all_data = huge_df.collect()  # DON'T DO THIS

GOOD - Process in distributed fashion:

from pyspark.sql.functions import avg

# Keep data distributed
huge_df = spark.read.parquet("10TB_dataset.parquet")
summary = huge_df.groupBy("category").agg(avg("value"))
summary.show()  # Only brings back the small summary

BAD - Using Python UDFs unnecessarily:

from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType

# Slow: every row is shipped from the JVM to a Python worker and back
# (and None < 30 raises a TypeError, so a null age fails the task)
categorize_udf = udf(lambda age: "young" if age < 30 else "old", StringType())
df.withColumn("category", categorize_udf(col("age")))

GOOD - Use built-in functions:

from pyspark.sql.functions import col, when

# Fast: stays in the JVM, where Spark's optimizer can see and optimize it
df.withColumn("category", 
              when(col("age") < 30, "young").otherwise("old"))

BAD - Creating too many small partitions:

# Writes up to 1,000 output files, most of them tiny unless the dataset is large
df.repartition(1000).write.csv("output/")

GOOD - Right-size your partitions:

# Aim for roughly 128MB-1GB per partition; coalesce() merges partitions without a full shuffle
df.coalesce(10).write.csv("output/")
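The 128MB-1GB rule of thumb turns into a partition count by dividing the data size by the target partition size and rounding up. The sketch below assumes you already have an estimate of the total size (for example, from the input files; note that on-disk compressed size and in-memory size can differ a lot). `target_partitions` is a hypothetical helper. You would then pass the result to `df.coalesce(n)`, which can only lower the partition count and avoids a full shuffle, or to `df.repartition(n)`, which can raise or lower it at the cost of a shuffle.

```python
import math

MiB = 1024 ** 2
GiB = 1024 ** 3

def target_partitions(total_bytes, target_partition_bytes=128 * MiB, minimum=1):
    """Rough partition count so each partition holds about target_partition_bytes."""
    return max(minimum, math.ceil(total_bytes / target_partition_bytes))

print(target_partitions(10 * GiB))           # 80  (10 GiB at 128 MiB each)
print(target_partitions(10 * GiB, 1 * GiB))  # 10  (10 GiB at 1 GiB each)
print(target_partitions(50 * MiB))           # 1   (small data: one partition)
```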

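BAD - Not handling null values:

# A null age makes BOTH comparisons null, so those rows match neither filter
adults = df.filter(col("age") > 18)
minors = df.filter(col("age") <= 18)  # adults.count() + minors.count() can be < df.count()

GOOD - Explicitly handle nulls:

# Be explicit about where null ages go
adults = df.filter(col("age").isNotNull() & (col("age") > 18))
unknown_age = df.filter(col("age").isNull())  # Handle or report these separately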
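The reason null rows vanish is SQL's three-valued logic: a comparison involving null evaluates to null, and `filter` keeps a row only when the condition is true. The following is a plain-Python sketch of those rules (not Spark code), with `None` standing in for null; `sql_compare` and `sql_filter` are hypothetical helpers.

```python
import operator

def sql_compare(left, op, right):
    """SQL-style comparison: a null (None) operand makes the result null."""
    if left is None or right is None:
        return None
    return op(left, right)

def sql_filter(rows, predicate):
    """Like DataFrame.filter: keep a row only if the predicate is exactly True."""
    return [row for row in rows if predicate(row) is True]

people = [{"age": 15}, {"age": 30}, {"age": None}]
adults = sql_filter(people, lambda r: sql_compare(r["age"], operator.gt, 18))
minors = sql_filter(people, lambda r: sql_compare(r["age"], operator.le, 18))
print(len(adults), len(minors), len(people))  # 1 1 3: the null-age row is in neither
```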

4.3 When to Use What (Decision Matrix)

| Scenario | Use this | Why |
|---|---|---|
| Data < 1GB, complex logic | pandas | Faster for small data, more features |
| Data > 10GB, simple aggregations | PySpark DataFrames | Distributed processing, SQL-like syntax |
| Complex transformations, need fine control | PySpark RDDs | Full control over partitioning and operations |
| Real-time streaming data | Structured Streaming | Built-in support for continuous processing |
| Machine learning on big data | PySpark MLlib | Distributed ML algorithms |
| Need SQL queries | Spark SQL | Direct SQL interface with DataFrame benefits |

5. Now Try It

Build a simple data pipeline that processes a sales dataset:

  1. Create a CSV file called sample_sales.csv with columns: date,product,region,sales_amount,customer_id
  2. Add at least 1000 rows of sample data (you can generate this programmatically)
  3. Write a PySpark script that:
    - Loads the data
    - Finds the top 3 products by total sales
    - Calculates monthly sales trends
    - Identifies regions with declining performance (month-over-month)
    - Saves the results to separate output files

What success looks like: You should have three output directories with your results, and your script should run without errors. The monthly trends should show actual month names, and you should be able to spot patterns in your data. Try running it with different partition sizes to see how it affects performance.
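For step 2, here is one way to generate the sample file programmatically with only the Python standard library. The product and region names, the 2025 date range, and the value ranges are all arbitrary assumptions; a full year of dates gives the monthly-trend step twelve months to work with. `write_sample_sales` is a hypothetical helper, and the fixed seed makes the file reproducible between runs.

```python
import csv
import random
from datetime import date, timedelta

PRODUCTS = ["laptop", "phone", "tablet", "monitor", "keyboard"]  # hypothetical
REGIONS = ["North", "South", "East", "West", "Central"]          # hypothetical

def write_sample_sales(path, num_rows=1000, seed=42):
    """Write a CSV with columns date,product,region,sales_amount,customer_id."""
    rng = random.Random(seed)  # fixed seed so every run produces the same file
    start = date(2025, 1, 1)
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["date", "product", "region", "sales_amount", "customer_id"])
        for _ in range(num_rows):
            day = start + timedelta(days=rng.randrange(365))  # any day in 2025
            writer.writerow([
                day.isoformat(),
                rng.choice(PRODUCTS),
                rng.choice(REGIONS),
                round(rng.uniform(5, 500), 2),
                rng.randint(1, 200),
            ])

if __name__ == "__main__":
    write_sample_sales("sample_sales.csv")
```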

