PySpark:
- PySpark is the Python API (Application Programming Interface) for Apache Spark.
- It lets Python programs use Spark’s big data processing framework.
Core Components:
- SparkSession – The entry point for using Spark in Python.
- DataFrame API – For working with structured data, similar to Pandas.
- RDD API – For working with low-level resilient distributed datasets.
- MLlib – Built-in library for machine learning.
- Spark SQL – Query data using SQL syntax.
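A minimal sketch of how these components fit together (the sample data and names used here are made up for illustration):
from pyspark.sql import SparkSession
# SparkSession: the entry point
spark = SparkSession.builder.appName("CoreComponentsDemo").getOrCreate()
# DataFrame API: structured data, similar to Pandas
df = spark.createDataFrame([("Alice", 34), ("Bob", 45)], ["name", "age"])
df.filter(df["age"] > 40).show()
# Spark SQL: query the same data with SQL syntax
df.createOrReplaceTempView("people")
spark.sql("SELECT name FROM people WHERE age > 40").show()
# RDD API: the low-level view of the same data
print(df.rdd.map(lambda row: row.name).collect())
spark.stop()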
Apache Spark:
- Apache Spark is an open-source distributed computing framework.
- It is designed for big data processing with implicit data parallelism and fault tolerance.
- It is much faster than Hadoop’s MapReduce because it performs in-memory computations.
- Apache Spark is commonly shortened to just Spark.
- It gives a powerful engine for processing large-scale data and has built-in libraries for SQL, machine learning, streaming, and graph processing.
PySpark communicates with Spark’s JVM-based engine (through the Py4J library) to perform parallel analysis.
Built-in Libraries:
- Spark SQL – Query structured data using SQL.
- MLlib – Machine Learning library.
- Spark Streaming – Real-time data processing.
- GraphX – Graph computation.
Resilient Distributed Datasets (RDDs):
- These are fundamental data structures in Apache Spark, designed for efficient, fault-tolerant distributed computing.
- A large dataset is divided into smaller logical parts and distributed across multiple computers in a cluster, allowing parallel evaluation.
Key Features of RDDs:
Resilient (Fault Tolerant)
If a machine fails, Spark can automatically recover lost data using the RDD’s lineage (a history of transformations).
Distributed
- Data is split into smaller chunks and stored across multiple nodes in a cluster, allowing parallel processing.
Immutable
- Once created, an RDD cannot be modified. Instead, transformations create a new RDD.
Lazy Evaluation
- Transformations (like map or filter) are not executed immediately. They are only computed when an action (like collect or count) is triggered.
In-Memory Processing
- RDDs store data in memory (RAM), making computations much faster than traditional disk-based processing.
Operations on RDDs:
RDDs support two types of operations:
Transformations (Lazy Operations)
- Examples: map(), filter(), flatMap(), reduceByKey()
- These return a new RDD without modifying the original.
Actions (Trigger Execution)
- Examples: collect(), count(), take(), saveAsTextFile()
- These return a value or save the result.
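A short sketch of the two kinds of operations (this assumes an existing SparkContext named sc, as in the SparkContext example later in these notes):
# Transformations only build a lineage; nothing runs yet
rdd = sc.parallelize(["spark makes big data simple", "big data needs spark"])
words = rdd.flatMap(lambda line: line.split())   # transformation
pairs = words.map(lambda w: (w, 1))              # transformation
counts = pairs.reduceByKey(lambda a, b: a + b)   # transformation
# Actions trigger the actual computation
print(counts.count())    # number of distinct words
print(counts.collect())  # list of (word, count) pairs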
Transformation in PySpark:
A transformation is an operation that applies a function to a DataFrame or RDD and produces a new DataFrame or RDD without modifying the original data. Transformations are lazy, meaning they are not executed immediately. Instead, Spark builds a logical execution plan and waits until an action (like .show() or .count()) is triggered.
Transformations are divided into two types:
Narrow Transformations
Wide Transformations
01. Narrow Transformations:
- In narrow transformations, each input partition contributes to only one output partition.
- No data shuffling is required across partitions.
- Examples: map(), filter(), select(), union(), sample().
02. Wide Transformations:
- In wide transformations, each input partition may contribute to multiple output partitions.
- Data shuffling is required across partitions, which can be expensive in performance.
- Examples: groupBy(), agg(), join(), distinct(), orderBy().
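A rough way to see the difference is to compare physical plans: narrow transformations stay in one stage, while a wide one adds an Exchange (shuffle) step. This sketch assumes a small made-up DataFrame with customer_id and amount columns:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("NarrowVsWide").getOrCreate()
df = spark.createDataFrame([(1, 500), (1, 1500), (2, 2500)], ["customer_id", "amount"])
# Narrow: filter/select require no shuffle
df.filter(df["amount"] > 1000).select("customer_id").explain()
# Wide: groupBy forces a shuffle (look for an Exchange node in the printed plan)
df.groupBy("customer_id").sum("amount").explain()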
Actions in PySpark:
- Actions are methods that trigger the actual computation and return or save a result.
- Example: collect(), count(), show()
Staging Queries:
Staging Queries refer to the intermediate steps or transformations in the logical execution plan before the final result is computed.
When we apply transformations like filter(), map(), groupBy(), etc. to a DataFrame or RDD, PySpark builds a logical plan of these operations. These intermediate steps are called staging queries because they are staged for execution and only run when an action (e.g., collect(), count(), show()) is triggered.
This is especially useful when dealing with large datasets, expensive transformations, and shuffle operations.
How Staging Works in PySpark:
- Spark automatically creates stages based on transformations applied to DataFrames or RDDs.
- Each stage consists of narrow transformations (e.g., map(), filter(), select()).
- When a wide transformation (like groupBy() or join()) occurs, a new stage is created due to data shuffling.
Example of Staging Queries in PySpark SQL:
from pyspark.sql import SparkSession
spark = SparkSession.builder\
.appName("StagingQueriesExample")\
.getOrCreate()
# Load Data
df = spark.read.csv("sales_data.csv", header=True, inferSchema=True)
# Stage 1: Filtering and Selecting required columns
df_filtered = df.filter(df["amount"] > 1000).select("customer_id", "amount")
# Stage 2: Aggregation (this causes a shuffle and creates a new stage)
df_grouped = df_filtered.groupBy("customer_id").sum("amount")
# Stage 3: Further filtering after aggregation
df_final = df_grouped.filter(df_grouped["sum(amount)"] > 5000)
df_final.show()
How Staging Queries Help with Optimization:
- First Stage: Reads the file, filters data, and selects columns (narrow transformations).
- Second Stage: Groups the data (wide transformation causing shuffle).
- Third Stage: Filters aggregated results (narrow transformation).
- By persisting intermediate DataFrames (df_filtered, df_grouped), Spark avoids recomputing them when multiple actions are triggered.
Optimized Version Using Caching:
df_filtered = df.filter(df["amount"] > 1000).select("customer_id", "amount").cache()
df_grouped = df_filtered.groupBy("customer_id").sum("amount").cache()
df_final = df_grouped.filter(df_grouped["sum(amount)"] > 5000)
df_final.show()
By caching df_filtered and df_grouped, we skip redundant computations and improve performance.
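To see how Spark splits this job into stages, one option (a sketch using the df_final defined above) is to print the physical plan and look for Exchange steps, which mark the shuffle/stage boundaries; the Stages tab of the Spark UI (normally at http://localhost:4040) shows the same information while the job runs:
# Each Exchange in the physical plan marks a shuffle, i.e. a stage boundary
df_final.explain()
# Extended output: parsed, analyzed, optimized and physical plans
df_final.explain(True)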
Initialize Spark in PySpark:
There are two entry points for interacting with Spark:
- SparkContext
- SparkSession
Through these entry points, we can use SQL queries and other functions to analyze and transform data.
SparkContext:
- SparkContext is the older entry point for working with Spark and was used in Spark’s RDD-based API.
- It is the low-level API used to initialize Spark, access Spark clusters, and perform parallel processing of data through RDDs (Resilient Distributed Datasets).
- SparkContext was commonly used before the introduction of SparkSession in Spark 2.0.
Python Implementation of SparkContext:
from pyspark import SparkContext
# Initialize SparkContext
sc = SparkContext("url", "ApplicationName")
# Example: Create an RDD
rdd = sc.parallelize([1, 2, 3, 4, 5])
# Perform an operation on the RDD
result = rdd.map(lambda x: x * 2).collect()  # Doubles each of 1, 2, 3, 4, 5
# Show the result
print(result)
# Stop SparkContext
sc.stop()
Output: [2, 4, 6, 8, 10]
SparkSession:
- SparkSession is the new entry point that was introduced in Spark 2.0 to simplify the Spark API.
- It is the recommended entry point in place of SparkContext and provides a higher-level interface for working with Spark SQL, DataFrames, Datasets, and MLlib (Machine Learning).
- SparkSession wraps SparkContext and provides a unified API for working with structured data (DataFrames), which is easier and more intuitive than working directly with RDDs.
Python Implementation of SparkSession:
from pyspark.sql import SparkSession
# Initialize SparkSession
spark = SparkSession.builder \
.master("local")\
.appName("Name") \
.config('spark.ui.port', '4050')\
.getOrCreate()
# Example: Create a DataFrame
data = [("Rana", 25), ("Musfiq", 24), ("Sayem", 23)]
df = spark.createDataFrame(data, ["Name", "Age"])
# Perform an operation on the DataFrame
df.show()
# Stop SparkSession
spark.stop()
Output:
+------+---+
|  Name|Age|
+------+---+
|  Rana| 25|
|Musfiq| 24|
| Sayem| 23|
+------+---+
Summary of Master URLs:
Master URL | Description | Example |
---|---|---|
local | Runs Spark on a single machine with 1 core | SparkContext("local", "App") |
local[N] | Runs Spark on a single machine with N cores | SparkContext("local[4]", "App") |
local[*] | Runs Spark on a single machine with all cores | SparkContext("local[*]", "App") |
spark://<master-ip>:<port> | Runs Spark on a standalone cluster | SparkContext("spark://192.168.1.1:7077", "App") |
yarn | Runs Spark on a YARN-managed cluster | SparkContext("yarn", "App") |
mesos://<master-ip>:<port> | Runs Spark on a Mesos cluster | SparkContext("mesos://192.168.1.1:5050", "App") |
k8s://<master-ip>:<port> | Runs Spark on a Kubernetes cluster | SparkContext("k8s://192.168.1.1:6443", "App") |
Differences Between SparkContext and SparkSession:
Feature | SparkContext | SparkSession |
---|---|---|
Introduced | Spark 1.x | Spark 2.x and later |
Main Purpose | Low-level API to work with RDDs | High-level API to work with DataFrames, SQL, MLlib |
Access to RDDs | Yes | Yes (via spark.sparkContext ) |
Access to DataFrames | No (RDD-based API) | Yes (DataFrame-based API) |
Access to SQL | No | Yes (via spark.sql() ) |
Unified API | No | Yes |
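A short sketch of how the two relate in practice: the SparkContext still exists inside every SparkSession, while SQL and DataFrames are only available through the session:
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").appName("EntryPoints").getOrCreate()
# The underlying SparkContext is exposed as an attribute
sc = spark.sparkContext
print(sc.parallelize([1, 2, 3]).sum())   # RDD work still goes through SparkContext
# SQL and DataFrames go through the SparkSession
df = spark.createDataFrame([(1, "a")], ["id", "label"])
df.createOrReplaceTempView("t")
spark.sql("SELECT count(*) AS n FROM t").show()
spark.stop()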
Loading a data set (CSV file):
from pyspark.sql import SparkSession
Spark = SparkSession.builder\
.master('local')\
.appName('Housing')\
.getOrCreate()
# Load in the .csv file to a DataFrame
filepath = "/content/drive/MyDrive/Applied Data Science 2/Week 02/housing.csv"
#way 01 for CSV file
data01 = Spark.read.csv(filepath)
data01.show(2)
data01.printSchema()
#way 02 for CSV file
data02 = Spark.read.option('header', True).option('inferSchema', True).csv(filepath)
data02.show(2)
data02.printSchema()
#way 03 for CSV file
data03 = Spark.read.csv(filepath, header=True, inferSchema=True)
data03.show(2)
#Schema will be same as data02, so skipping
Output: +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ | _c0| _c1| _c2| _c3| _c4| _c5| _c6| _c7| _c8| _c9| +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ |longitude|latitude|housing_median_age|total_rooms|total_bedrooms|population|households|median_income|median_house_value|ocean_proximity| | -122.23| 37.88| 41.0| 880.0| 129.0| 322.0| 126.0| 8.3252| 452600.0| NEAR BAY| +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ only showing top 2 rows root |-- _c0: string (nullable = true) |-- _c1: string (nullable = true) |-- _c2: string (nullable = true) |-- _c3: string (nullable = true) |-- _c4: string (nullable = true) |-- _c5: string (nullable = true) |-- _c6: string (nullable = true) |-- _c7: string (nullable = true) |-- _c8: string (nullable = true) |-- _c9: string (nullable = true) +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ |longitude|latitude|housing_median_age|total_rooms|total_bedrooms|population|households|median_income|median_house_value|ocean_proximity| +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ | -122.23| 37.88| 41.0| 880.0| 129.0| 322.0| 126.0| 8.3252| 452600.0| NEAR BAY| | -122.22| 37.86| 21.0| 7099.0| 1106.0| 2401.0| 1138.0| 8.3014| 358500.0| NEAR BAY| +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ only showing top 2 rows root |-- longitude: double (nullable = true) |-- latitude: double (nullable = true) |-- housing_median_age: double (nullable = true) |-- total_rooms: double (nullable = true) |-- total_bedrooms: double (nullable = true) |-- population: double (nullable = true) |-- households: double (nullable = true) |-- median_income: double (nullable = true) |-- median_house_value: double (nullable = true) |-- ocean_proximity: string (nullable = true) +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ |longitude|latitude|housing_median_age|total_rooms|total_bedrooms|population|households|median_income|median_house_value|ocean_proximity| +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ | -122.23| 37.88| 41.0| 880.0| 129.0| 322.0| 126.0| 8.3252| 452600.0| NEAR BAY| | -122.22| 37.86| 21.0| 7099.0| 1106.0| 2401.0| 1138.0| 8.3014| 358500.0| NEAR BAY| +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ only showing top 2 rows
Predefining Schema in PySpark:
- If we do not specify inferSchema=True, Spark treats all columns as strings (StringType) by default when reading a CSV file.
- When working with large datasets, inferring the schema using .option("inferSchema", "true") can be computationally expensive because PySpark has to scan the data.
- To optimize performance, we can manually define the schema using StructType and StructField before loading the DataFrame.
### Load the DataFrame using a schema defined with StructType and StructField
from pyspark.sql.types import DoubleType, StringType, StructType, StructField
# Complete the schema
userDefinedSchema = StructType([
StructField("longitude", DoubleType(), True),
StructField("latitude", DoubleType(), True),
StructField("housing_median_age", DoubleType(), True),
StructField("total_rooms", DoubleType(), True),
StructField("total_bedrooms", DoubleType(), True),
StructField("population", DoubleType(), True),
StructField("households", DoubleType(), True),
StructField("median_income", DoubleType(), True),
StructField("median_house_value", DoubleType(), True),
StructField("ocean_proximity", StringType(), True),
])
#way 04 for CSV file
data04 = Spark.read.option('header', True).schema(userDefinedSchema).csv(filepath)
data04.show(2)
data04.printSchema()
Output: +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ |longitude|latitude|housing_median_age|total_rooms|total_bedrooms|population|households|median_income|median_house_value|ocean_proximity| +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ | -122.23| 37.88| 41.0| 880.0| 129.0| 322.0| 126.0| 8.3252| 452600.0| NEAR BAY| | -122.22| 37.86| 21.0| 7099.0| 1106.0| 2401.0| 1138.0| 8.3014| 358500.0| NEAR BAY| +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ only showing top 2 rows root |-- longitude: double (nullable = true) |-- latitude: double (nullable = true) |-- housing_median_age: double (nullable = true) |-- total_rooms: double (nullable = true) |-- total_bedrooms: double (nullable = true) |-- population: double (nullable = true) |-- households: double (nullable = true) |-- median_income: double (nullable = true) |-- median_house_value: double (nullable = true) |-- ocean_proximity: string (nullable = true)
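A more compact alternative (since Spark 2.3, schema() also accepts a DDL-formatted string) is to write the same schema as a string; this sketch is intended to be equivalent to the StructType version above, with data05 used here only as a new variable name:
#way 05 for CSV file: schema as a DDL string
ddl_schema = ("longitude DOUBLE, latitude DOUBLE, housing_median_age DOUBLE, "
              "total_rooms DOUBLE, total_bedrooms DOUBLE, population DOUBLE, "
              "households DOUBLE, median_income DOUBLE, median_house_value DOUBLE, "
              "ocean_proximity STRING")
data05 = Spark.read.option('header', True).schema(ddl_schema).csv(filepath)
data05.printSchema()  # should match data04's schema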
Saving File:
# Save the DataFrame as a .csv, .parquet and .JSON
data04.write.option('header', 'True').option('delimiter', ',').csv('data04.csv')
data04.write.option('compression', 'snappy').parquet('data04.parquet')
data04.write.json('data04.JSON')
# Alternate way to save DataFrame as a .csv, .parquet and .JSON
data04.write.format('csv').option('header', 'True').option('delimiter', ',').mode('overwrite').save('data04.csv')
data04.write.format('parquet').option('compression', 'snappy').mode('overwrite').save('data04.parquet')
data04.write.format('json').mode('overwrite').save('data04.JSON')
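To sanity-check the saved outputs, they can be read back in (a sketch using the paths from the example above; note that Spark writes each of these as a directory of part files, not a single file):
csv_back = Spark.read.csv('data04.csv', header=True, inferSchema=True)
parquet_back = Spark.read.parquet('data04.parquet')   # Parquet stores the schema, so no inferSchema needed
json_back = Spark.read.json('data04.JSON')
csv_back.printSchema()
parquet_back.show(2)
json_back.show(2)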
Selecting specific columns:
data02.select('median_house_value', 'longitude', 'housing_median_age')\
.show(3)
data02.select('median_house_value', 'longitude', 'housing_median_age')\
.filter(data02['housing_median_age']<21)\
.show(3)
data02.select('median_house_value', 'longitude', 'housing_median_age')\
.orderBy('housing_median_age')\
.show(3)
OutPut: +------------------+---------+------------------+ |median_house_value|longitude|housing_median_age| +------------------+---------+------------------+ | 452600.0| -122.23| 41.0| | 358500.0| -122.22| 21.0| | 352100.0| -122.24| 52.0| +------------------+---------+------------------+ only showing top 3 rows +------------------+---------+------------------+ |median_house_value|longitude|housing_median_age| +------------------+---------+------------------+ | 60000.0| -122.29| 2.0| | 137500.0| -122.29| 20.0| | 177500.0| -122.28| 17.0| +------------------+---------+------------------+ only showing top 3 rows +------------------+---------+------------------+ |median_house_value|longitude|housing_median_age| +------------------+---------+------------------+ | 141700.0| -117.95| 1.0| | 189200.0| -120.93| 1.0| | 55000.0| -116.95| 1.0| +------------------+---------+------------------+ only showing top 3 rows
Filtering:
#Use of filter(); data02 is used here because data01 was read without a header and only has columns named _c0.._c9
data02.filter('ocean_proximity == "NEAR BAY"').show(2)
data02.filter('housing_median_age < 20').show(2)
#Alternate Way of using filter()
data02.filter(data02['ocean_proximity']=='NEAR BAY').show(2)
data02.filter(data02['housing_median_age']<20).show(2)
# Use of where() instead of filter()
data02.where('ocean_proximity == "NEAR BAY"').show(2)
data02.where('housing_median_age < 20').show(2)
#Alternate Way of using where()
data02.where(data02['ocean_proximity']=='NEAR BAY').show(2)
data02.where(data02['housing_median_age']<20).show(2)
Output: +------------------+---------+------------------+ |median_house_value|longitude|housing_median_age| +------------------+---------+------------------+ | 60000.0| -122.29| 2.0| | 137500.0| -122.29| 20.0| | 177500.0| -122.28| 17.0| +------------------+---------+------------------+ only showing top 3 rows
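Conditions can also be combined; a small sketch using data02 (column-expression conditions must be wrapped in parentheses and joined with & for AND, | for OR, ~ for NOT):
import pyspark.sql.functions as F
# AND: both conditions must hold
data02.filter((F.col('ocean_proximity') == 'NEAR BAY') & (F.col('housing_median_age') < 20)).show(2)
# OR: either condition may hold
data02.filter((F.col('median_house_value') > 400000) | (F.col('median_income') > 10)).show(2)
# Membership test with isin()
data02.filter(F.col('ocean_proximity').isin('ISLAND', 'NEAR OCEAN')).show(2)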
Filtering not null values only:
import pyspark.sql.functions as F
data02.where(F.col('total_bedrooms').isNotNull()).show()
Output: +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ |longitude|latitude|housing_median_age|total_rooms|total_bedrooms|population|households|median_income|median_house_value|ocean_proximity| +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+ | -122.23| 37.88| 41.0| 880.0| 129.0| 322.0| 126.0| 8.3252| 452600.0| NEAR BAY| | -122.22| 37.86| 21.0| 7099.0| 1106.0| 2401.0| 1138.0| 8.3014| 358500.0| NEAR BAY| +---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+---------------+
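A related sketch: counting how many rows are null versus non-null in that column, using the F alias imported above together with count() and when():
data02.select(
    F.count(F.when(F.col('total_bedrooms').isNull(), 1)).alias('null_rows'),
    F.count(F.col('total_bedrooms')).alias('non_null_rows')
).show()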
Checking unique values of one variable:
data02.select('ocean_proximity').distinct().show()
Output:
+---------------+
|ocean_proximity|
+---------------+
|         ISLAND|
|     NEAR OCEAN|
|       NEAR BAY|
|      <1H OCEAN|
|         INLAND|
+---------------+
Drop Any Column:
data04.drop(F.col('median_house_value')).show(2)
Output: +---------+--------+------------------+-----------+--------------+----------+----------+-------------+---------------+ |longitude|latitude|housing_median_age|total_rooms|total_bedrooms|population|households|median_income|ocean_proximity| +---------+--------+------------------+-----------+--------------+----------+----------+-------------+---------------+ | -122.23| 37.88| 41.0| 880.0| 129.0| 322.0| 126.0| 8.3252| NEAR BAY| | -122.22| 37.86| 21.0| 7099.0| 1106.0| 2401.0| 1138.0| 8.3014| NEAR BAY| +---------+--------+------------------+-----------+--------------+----------+----------+-------------+---------------+ only showing top 2 rows
Aggregations:
- Group by any variable or column and compute the sum, mean, max, min, etc.
#Loading a new data set to perform aggregation task
from pyspark.sql import SparkSession
Spark = SparkSession.builder\
.master('local')\
.appName('Titanic')\
.getOrCreate()
data_titanic = Spark.read.csv('/content/drive/MyDrive/Applied Data Science 2/Week 02/Titanic Data Set.csv',header=True,inferSchema=True)
data_titanic.show(3)
Output: +-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+ |PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked| +-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+ | 1| 0| 3|Braund, Mr. Owen ...| male|22.0| 1| 0| A/5 21171| 7.25| NULL| S| | 2| 1| 1|Cumings, Mrs. Joh...|female|38.0| 1| 0| PC 17599|71.2833| C85| C| | 3| 1| 3|Heikkinen, Miss. ...|female|26.0| 0| 0|STON/O2. 3101282| 7.925| NULL| S| +-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+ only showing top 3 rows
Basic Aggregations:
Function | Description | Example |
---|---|---|
count() | Counts the number of rows | df.groupBy("col").count() |
sum() | Sums up values | df.groupBy("col").agg(sum("value")) |
avg() / mean() | Computes the average | df.groupBy("col").agg(mean("value")) |
min() | Finds the minimum value | df.groupBy("col").agg(min("value")) |
max() | Finds the maximum value | df.groupBy("col").agg(max("value")) |
Example Code:
data_titanic.groupby('Sex').count().show()
data_titanic.groupby('Sex').mean('Age').show()
data_titanic.groupby('Sex','Pclass').max('Age').show()
data_titanic.groupby('Sex').min('Age').show()
data_titanic.groupby('Sex').sum('Fare').show()
'''
Alternate way below for the same output.
Note: groupBy() creates a GroupedData object. The .sum(), .avg(), .max(), .min(), and .count() methods are natively available on grouped data and do not require an explicit function import.
If we use agg() instead (which is more flexible), we need to import sum(), avg(), max(), min(), and count() explicitly from pyspark.sql.functions, since agg() takes column expressions rather than recognizing aggregation names on its own.
'''
from pyspark.sql.functions import count,mean,max,min,sum
data_titanic.groupby('Sex').agg(count("*")).show()
data_titanic.groupby('Sex').agg(mean('Age')).show()
data_titanic.groupby('Sex', 'Pclass').agg(max('Age')).show()
data_titanic.groupby('Sex').agg(min('Age')).show()
data_titanic.groupby('Sex').agg(sum('Fare')).show()
Output: +------+-----+ | Sex|count| +------+-----+ |female| 314| | male| 577| +------+-----+ +------+------------------+ | Sex| mean(Age)| +------+------------------+ |female|27.915708812260537| | male| 30.72664459161148| +------+------------------+ +------+--------+ | Sex|max(Age)| +------+--------+ |female| 63.0| | male| 80.0| +------+--------+ +------+--------+ | Sex|min(Age)| +------+--------+ |female| 0.75| | male| 0.42| +------+--------+ +------+-----------------+ | Sex| sum(Fare)| +------+-----------------+ |female|13966.66279999999| | male|14727.28649999999| +------+-----------------+
Multiple aggregation operations & naming the new columns:
- Put all aggregation functions inside agg()
- Use .alias() to name each new column
# Efficient Way
data_titanic.groupby("Sex").agg(
count("*").alias("Total_Count"), # Count all rows per Sex
mean("Age").alias("Avg_Age"), # Mean of Age
min("Age").alias("Min_Age"), # Minimum Age
max("Age").alias("Max_Age"), # Maximum Age
sum("Fare").alias("Total_Fare") # Sum of Fare
).show()
Output:
+------+-----------+------------------+-------+-------+-----------------+
|   Sex|Total_Count|           Avg_Age|Min_Age|Max_Age|       Total_Fare|
+------+-----------+------------------+-------+-------+-----------------+
|female|        314|27.915708812260537|   0.75|   63.0|13966.66279999999|
|  male|        577| 30.72664459161148|   0.42|   80.0|14727.28649999999|
+------+-----------+------------------+-------+-------+-----------------+
Statistical Aggregations:
Function | Description | Example |
---|---|---|
variance() / var_samp() | Sample variance | df.groupBy("col").agg(var_samp("value")) |
stddev() / stddev_samp() | Sample standard deviation | df.groupBy("col").agg(stddev("value")) |
approx_count_distinct() | Approximate distinct count | df.groupBy("col").agg(approx_count_distinct("value")) |
corr() | Correlation | df.agg(corr("col1", "col2")) |
Example code:
from pyspark.sql.functions import variance,stddev,approx_count_distinct
data_titanic.groupBy("Sex").agg(variance("Fare")).show()
data_titanic.groupBy("Sex").agg(stddev("Fare")).show()
data_titanic.groupBy("Sex").agg(approx_count_distinct("Fare")).show()
data_titanic.corr('Age','Fare')
Output: +------+-----------------+ | Sex| var_samp(Fare)| +------+-----------------+ |female|3363.732929578914| | male|1860.909702161692| +------+-----------------+ +------+-----------------+ | Sex| stddev(Fare)| +------+-----------------+ |female|57.99769762308599| | male|43.13826262335668| +------+-----------------+ +------+---------------------------+ | Sex|approx_count_distinct(Fare)| +------+---------------------------+ |female| 155| | male| 183| +------+---------------------------+ 0.135515853527051
Conditional Aggregations:
Function | Description | Example |
---|---|---|
sum_distinct() | Sum of unique values | df.agg(sum_distinct("value")) |
first() | First value in group | df.groupBy("col").agg(first("value")) |
last() | Last value in group | df.groupBy("col").agg(last("value")) |
Example Code:
from pyspark.sql.functions import sum_distinct,first,last
data_titanic.agg(sum_distinct('Pclass')).show()
data_titanic.groupBy("Sex").agg(first("age")).show()
data_titanic.groupBy("Sex").agg(last("age")).show()
Output: +--------------------+ |sum(DISTINCT Pclass)| +--------------------+ | 6| +--------------------+ +------+----------+ | Sex|first(age)| +------+----------+ |female| 38.0| | male| 22.0| +------+----------+ +------+---------+ | Sex|last(age)| +------+---------+ |female| NULL| | male| 32.0| +------+---------+
Pivot Tables (Dynamic Aggregation):
- Group by one column and use the unique values of another categorical column as the new columns
data_titanic.groupBy("Sex").pivot('Pclass').sum('Survived').show()
Output:
+------+---+---+---+
|   Sex|  1|  2|  3|
+------+---+---+---+
|female| 91| 70| 72|
|  male| 45| 17| 47|
+------+---+---+---+
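If the pivot column's distinct values are known in advance, they can be passed to pivot() explicitly, which saves Spark an extra pass over the data to discover them (a sketch):
# Passing the pivot values avoids an extra job to compute the distinct Pclass values
data_titanic.groupBy("Sex").pivot("Pclass", [1, 2, 3]).sum("Survived").show()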
Grouping by multiple columns:
data_titanic.groupBy("Pclass",'Sex').sum('Survived').show()
Output:
+------+------+-------------+
|Pclass|   Sex|sum(Survived)|
+------+------+-------------+
|     2|female|           70|
|     3|  male|           47|
|     1|  male|           45|
|     3|female|           72|
|     1|female|           91|
|     2|  male|           17|
+------+------+-------------+
Sorting:
#Default order is ascending.
data_titanic.groupBy("Pclass",'Sex').sum('Survived').sort("Pclass").show()
data_titanic.groupBy("Pclass",'Sex').sum('Survived').orderBy("Pclass").show()
# Sorting in descending order
from pyspark.sql.functions import desc
data_titanic.groupBy("Pclass",'Sex').sum('Survived').sort(desc("Pclass")).show()
data_titanic.groupBy("Pclass",'Sex').sum('Survived').orderBy(desc("Pclass")).show()
data_titanic.groupBy("Pclass",'Sex').sum('Survived').sort("Pclass",ascending = False).show()
data_titanic.groupBy("Pclass",'Sex').sum('Survived').orderBy("Pclass",ascending = False).show()
data_titanic.groupBy("Pclass",'Sex').sum('Survived').sort("Pclass",'Sex').show()
data_titanic.groupBy("Pclass",'Sex').sum('Survived').orderBy("Pclass",'Sex').show()
Output: +------+------+-------------+ |Pclass| Sex|sum(Survived)| +------+------+-------------+ | 1| male| 45| | 1|female| 91| | 2|female| 70| | 2| male| 17| | 3| male| 47| | 3|female| 72| +------+------+-------------+ +------+------+-------------+ |Pclass| Sex|sum(Survived)| +------+------+-------------+ | 1| male| 45| | 1|female| 91| | 2|female| 70| | 2| male| 17| | 3| male| 47| | 3|female| 72| +------+------+-------------+ +------+------+-------------+ |Pclass| Sex|sum(Survived)| +------+------+-------------+ | 3| male| 47| | 3|female| 72| | 2|female| 70| | 2| male| 17| | 1| male| 45| | 1|female| 91| +------+------+-------------+ +------+------+-------------+ |Pclass| Sex|sum(Survived)| +------+------+-------------+ | 3| male| 47| | 3|female| 72| | 2|female| 70| | 2| male| 17| | 1| male| 45| | 1|female| 91| +------+------+-------------+ +------+------+-------------+ |Pclass| Sex|sum(Survived)| +------+------+-------------+ | 3| male| 47| | 3|female| 72| | 2|female| 70| | 2| male| 17| | 1| male| 45| | 1|female| 91| +------+------+-------------+ +------+------+-------------+ |Pclass| Sex|sum(Survived)| +------+------+-------------+ | 3| male| 47| | 3|female| 72| | 2|female| 70| | 2| male| 17| | 1| male| 45| | 1|female| 91| +------+------+-------------+ +------+------+-------------+ |Pclass| Sex|sum(Survived)| +------+------+-------------+ | 1|female| 91| | 1| male| 45| | 2|female| 70| | 2| male| 17| | 3|female| 72| | 3| male| 47| +------+------+-------------+ +------+------+-------------+ |Pclass| Sex|sum(Survived)| +------+------+-------------+ | 1|female| 91| | 1| male| 45| | 2|female| 70| | 2| male| 17| | 3|female| 72| | 3| male| 47| +------+------+-------------+
Checking the number of unique values of a column:
from pyspark.sql.functions import count_distinct
data_titanic.select(count_distinct('Pclass').alias('Number of class')).show()
data_titanic.select(
count_distinct('Pclass').alias('Number of class'),
count_distinct('Sex').alias('Number of Sex')).show()
Output:
+---------------+
|Number of class|
+---------------+
|              3|
+---------------+
+---------------+-------------+
|Number of class|Number of Sex|
+---------------+-------------+
|              3|            2|
+---------------+-------------+
Use of withColumn() :
- Used to create a new column or modify an existing column.
- Requires a column expression (like F.when(), F.col(), etc.).
import pyspark.sql.functions as F
data_titanic.withColumn('Death_Status',F.when(F.col('Survived')==1,'Alive').otherwise('Dead')).show(2)
Output:
+-----------+--------+------+--------------------+------+----+-----+-----+---------+-------+-----+--------+------------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|   Ticket|   Fare|Cabin|Embarked|Death_Status|
+-----------+--------+------+--------------------+------+----+-----+-----+---------+-------+-----+--------+------------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|A/5 21171|   7.25| NULL|       S|        Dead|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0| PC 17599|71.2833|  C85|       C|       Alive|
+-----------+--------+------+--------------------+------+----+-----+-----+---------+-------+-----+--------+------------+
only showing top 2 rows
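when() calls can also be chained to bucket a numeric column into several categories; a small sketch adding a hypothetical Age_Group column:
data_titanic.withColumn(
    'Age_Group',
    F.when(F.col('Age') < 13, 'Child')
     .when(F.col('Age') < 60, 'Adult')
     .otherwise('Senior')   # rows with a NULL Age also fall through to otherwise()
).select('Name', 'Age', 'Age_Group').show(5)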
Use of limit():
- We can use limit(n) to create a new DataFrame containing the first n rows
data_titanic.select('Name').limit(10).show()
Output: +--------------------+ | Name| +--------------------+ |Braund, Mr. Owen ...| |Cumings, Mrs. Joh...| |Heikkinen, Miss. ...| |Futrelle, Mrs. Ja...| |Allen, Mr. Willia...| | Moran, Mr. James| |McCarthy, Mr. Tim...| |Palsson, Master. ...| |Johnson, Mrs. Osc...| |Nasser, Mrs. Nich...| +--------------------+
Use of collect_set():
- Collects unique (distinct) values from a column and returns them as an array.
- Removes duplicates automatically.
Use of sort_array():
- Sorts an array column in ascending or descending order.
- asc=True → Sorts in ascending order (default) and asc=False → Sorts in descending order.
- While the sort() function sorts the entire DataFrame based on one or more columns, sort_array() sorts the elements inside an array column, not the entire DataFrame.
- Does NOT change row order, just modifies array values inside a column.
Df_PcClass = data_titanic.select(
F.sort_array(
F.collect_set('Pclass'))
.alias('PcClass'))
Df_PcClass.show()
Output:
+---------+
|  PcClass|
+---------+
|[1, 2, 3]|
+---------+
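To sort the collected array in descending order instead, pass asc=False (a sketch; the expected result is [3, 2, 1]):
data_titanic.select(
    F.sort_array(F.collect_set('Pclass'), asc=False)
    .alias('PcClass_desc')).show()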