In PySpark, ArrayType and MapType are used to define complex data structures within a DataFrame schema.
☰ ArrayType column, and functions
ArrayType allows you to store and work with arrays, which can hold multiple values of the same data type.
sample dataframe:
+---+---------+
| id| numbers |
+---+---------+
|  1|[1, 2, 3]|
|  2|[4, 5, 6]|
|  3|[7, 8, 9]|
+---+---------+
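For reference, a minimal sketch of how this sample DataFrame could be created with an explicit ArrayType schema (the spark session and the id/numbers column names are assumed from the sample above):
from pyspark.sql.types import StructType, StructField, IntegerType, ArrayType
# Explicit schema: 'numbers' is an array of integers
schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("numbers", ArrayType(IntegerType()), True)
])
data = [(1, [1, 2, 3]), (2, [4, 5, 6]), (3, [7, 8, 9])]
df = spark.createDataFrame(data, schema)
df.show()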
explode ()
Explodes a given array into individual new rows, one row per element, using the explode() function. It is often used to flatten nested JSON.
from pyspark.sql.functions import explode
# Explode the 'numbers' array into separate rows
exploded_df = df.withColumn("number", explode(df.numbers))
display(exploded_df)
==output==
id numbers number
1 [1,2,3] 1
1 [1,2,3] 2
1 [1,2,3] 3
2 [4,5,6] 4
2 [4,5,6] 5
2 [4,5,6] 6
3 [7,8,9] 7
3 [7,8,9] 8
3 [7,8,9] 9
split ()
Splits a string column on a specified delimiter and returns an array column.
from pyspark.sql.functions import split
df.withColumn("Name_Split", split(df["Name"], ","))
sample dataframe
+-------------+
| Name        |
+-------------+
| John,Doe    |
| Jane,Smith  |
| Alice,Cooper|
+-------------+
from pyspark.sql.functions import split
# Split the 'Name' column by comma
df_split = df.withColumn("Name_Split", split(df["Name"], ","))
df_split.show(truncate=False)
==output==
+-------------+----------------+
| Name | Name_Split |
+-------------+----------------+
| John,Doe | [John, Doe] |
| Jane,Smith | [Jane, Smith] |
| Alice,Cooper| [Alice, Cooper]|
+-------------+----------------+
array ()
Creates an array column.
from pyspark.sql.functions import array, col
data=[(1,2,3),(4,5,6)]
schema=['num1','num2','num3']
df1=spark.createDataFrame(data,schema)
df1.show()
# Create a new array-type column 'numbers' whose elements come from num1, num2, num3
df1.withColumn("numbers",array(col("num1"),col("num2"),col("num3"))).show()
==output==
+----+----+----+
|num1|num2|num3|
+----+----+----+
| 1| 2| 3|
| 4| 5| 6|
+----+----+----+
#new array column "numbers" created
+----+----+----+-----------+
|num1|num2|num3| numbers |
+----+----+----+-----------+
| 1| 2| 3| [1, 2, 3] |
| 4| 5| 6| [4, 5, 6] |
+----+----+----+-----------+
array_contains ()
Checks if an array contains a specific element.
from pyspark.sql.functions import array_contains
array_contains(array, value)
sample dataframe
+---+-----------------------+
|id |fruits                 |
+---+-----------------------+
|1  |[apple, banana, cherry]|
|2  |[orange, apple, grape] |
|3  |[pear, peach, plum]    |
+---+-----------------------+
from pyspark.sql.functions import array_contains
# Using array_contains to check if the array contains 'apple'
df.select("id", array_contains("fruits", "apple").alias("has_apple")).show()
==output==
+---+----------+
| id|has_apple |
+---+----------+
| 1| true|
| 2| true|
| 3| false|
+---+----------+
getItem()
Access individual elements of an array by their index using the getItem() method.
# Select the second element (indexing starts at 0) of the 'numbers' array
df1 = df.withColumn("item_1_value", df.numbers.getItem(1))
display(df1)
==output==
id numbers item_1_value
1 [1,2,3] 2
2 [4,5,6] 5
3 [7,8,9] 8
size ()
Returns the size of the array.
from pyspark.sql.functions import size
# Get the size of the 'numbers' array
df.select(size(df.numbers)).show()
==output==
+-------------+
|size(numbers)|
+-------------+
| 3|
| 3|
| 3|
+-------------+
sort_array()
Sorts the array elements.
sort_array(col: 'ColumnOrName', asc: bool = True)
If asc is True (the default), the array is sorted in ascending order; if False, in descending order. Because True is the default, asc can be omitted for an ascending sort.
from pyspark.sql.functions import sort_array
df.withColumn("numbers", sort_array("numbers")).show()
==output==
ascending
+---+---------+
| id| numbers|
+---+---------+
| 1|[1, 2, 3]|
| 2|[4, 5, 6]|
| 3|[7, 8, 9]|
+---+---------+
df.select(sort_array("numbers", asc=False).alias("sorted_desc")).show()
==output==
descending
+-----------+
|sorted_desc|
+-----------+
| [3, 2, 1]|
| [6, 5, 4]|
| [9, 8, 7]|
+-----------+
concat ()
concat() is used to concatenate arrays (or strings) into a single array (or string). When dealing with ArrayType, concat() is typically used to combine two or more arrays into one.
from pyspark.sql.functions import concat
concat(*cols)
sample dataframe
+---+------+------+
|id |array1|array2|
+---+------+------+
|1  |[a, b]|[x, y]|
|2  |[c]   |[z]   |
|3  |[d, e]|null  |
+---+------+------+
from pyspark.sql.functions import concat, col
# Concatenating array columns
df_concat = df.withColumn("concatenated_array", concat(col("array1"), col("array2")))
df_concat.show(truncate=False)
==output==
+---+------+------+------------------+
|id |array1|array2|concatenated_array|
+---+------+------+------------------+
|1 |[a, b]|[x, y]|[a, b, x, y] |
|2 |[c] |[z] |[c, z] |
|3 |[d, e]|null |null |
+---+------+------+------------------+
Handling null values
If any of the input columns are null, the entire result can become null. This is why you're seeing null instead of just the non-null array.
To handle this, you can use coalesce() to substitute null with an empty array before performing the concat(). coalesce() returns the first non-null argument. Here's how you can modify your code:
from pyspark.sql.functions import concat, coalesce, col, array
# Define an empty array to substitute for null values
empty_array = array()
# Concatenate with null handling using coalesce
df_concat = df.withColumn(
"concatenated_array",
concat(coalesce(col("array1"), empty_array), coalesce(col("array2"), empty_array))
)
df_concat.show(truncate=False)
==output==
+---+------+------+------------------+
|id |array1|array2|concatenated_array|
+---+------+------+------------------+
|1 |[a, b]|[x, y]|[a, b, x, y] |
|2 |[c] |[z] |[c, z] |
|3 |[d, e]|null |[d, e] |
+---+------+------+------------------+
arrays_zip ()
Combines multiple arrays element-wise into a single array of structs.
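A minimal sketch of arrays_zip(), assuming a DataFrame with two parallel array columns (the nums and letters columns below are illustrative, not from the samples above):
from pyspark.sql.functions import arrays_zip, col
# Illustrative data: two parallel arrays per row
data_zip = [(1, [1, 2, 3], ["a", "b", "c"])]
df_zip = spark.createDataFrame(data_zip, ["id", "nums", "letters"])
# Zip the arrays element-wise into an array of structs
df_zip.withColumn("zipped", arrays_zip(col("nums"), col("letters"))).show(truncate=False)
# zipped -> [{1, a}, {2, b}, {3, c}]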
☰ MapType column, and functions
MapType is used to represent a map of key-value pairs, similar to a Python dictionary (dict).
from pyspark.sql.types import MapType, StringType, IntegerType
# Define a MapType
my_map = MapType(StringType(), IntegerType(), valueContainsNull=True)
Parameters:
keyType: Data type of the keys in the map. You can use PySpark data types like StringType(), IntegerType(), DoubleType(), etc.
valueType: Data type of the values in the map. It can be any valid PySpark data type.
valueContainsNull: Boolean flag (optional). It indicates whether null values are allowed in the map. Default is True.
sample dataset
# Sample dataset (Product ID and prices in various currencies)
data = [
    (1, {"USD": 100, "EUR": 85, "GBP": 75}),
    (2, {"USD": 150, "EUR": 130, "GBP": 110}),
    (3, {"USD": 200, "EUR": 170, "GBP": 150}),
]
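A minimal sketch of building the DataFrame from this data with an explicit MapType schema (the explicit schema is an assumption; letting Spark infer the map type would also work):
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, MapType
# Explicit schema: 'prices' maps a currency code to an integer price
schema = StructType([
    StructField("product_id", IntegerType(), True),
    StructField("prices", MapType(StringType(), IntegerType(), valueContainsNull=True), True)
])
df = spark.createDataFrame(data, schema)
df.show(truncate=False)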
sample dataframe
+----------+------------------------------------+
|product_id|prices                              |
+----------+------------------------------------+
|1         |{EUR -> 85, GBP -> 75, USD -> 100}  |
|2         |{EUR -> 130, GBP -> 110, USD -> 150}|
|3         |{EUR -> 170, GBP -> 150, USD -> 200}|
+----------+------------------------------------+
Accessing map_keys (), map_values ()
Extract keys (currency codes) and values (prices):
from pyspark.sql.functions import col, map_keys, map_values
# Extract map keys and values
df.select(
col("product_id"),
map_keys(col("prices")).alias("currencies"),
map_values(col("prices")).alias("prices_in_currencies")
).show(truncate=False)
==output==
+----------+---------------+--------------------+
|product_id|currencies |prices_in_currencies|
+----------+---------------+--------------------+
|1 |[EUR, GBP, USD]|[85, 75, 100] |
|2 |[EUR, GBP, USD]|[130, 110, 150] |
|3 |[EUR, GBP, USD]|[170, 150, 200] |
+----------+---------------+--------------------+
explode ()
Use explode() to flatten the map into multiple rows, where each key-value pair from the map becomes a separate row.
from pyspark.sql.functions import explode
# Use explode to flatten the map
df_exploded = df.select("product_id", explode("prices").alias("currency", "price"))
df_exploded.show()
==output==
+----------+--------+-----+
|product_id|currency|price|
+----------+--------+-----+
| 1| EUR| 85|
| 1| GBP| 75|
| 1| USD| 100|
| 2| EUR| 130|
| 2| GBP| 110|
| 2| USD| 150|
| 3| EUR| 170|
| 3| GBP| 150|
| 3| USD| 200|
+----------+--------+-----+
Accessing specific elements in the map
To get the price for a specific currency (e.g., USD) for each product:
from pyspark.sql.functions import col
# Access the value for a specific key in the map
df.select(
col("product_id"),
col("prices").getItem("USD").alias("price_in_usd")
).show(truncate=False)
==output==
+----------+------------+
|product_id|price_in_usd|
+----------+------------+
|1 |100 |
|2 |150 |
|3 |200 |
+----------+------------+
Filtering
Filter rows based on conditions involving the map values.
from pyspark.sql.functions import col
# Filter rows where price in USD is greater than 150
df.filter(col("prices").getItem("USD") > 150).show(truncate=False)
==output==
+----------+------------------------------------+
|product_id|prices |
+----------+------------------------------------+
|3 |{EUR -> 170, GBP -> 150, USD -> 200}|
+----------+------------------------------------+
map_concat ()
Combines two or more map columns by merging their key-value pairs.
from pyspark.sql.functions import map_concat, create_map, lit, col
# Define the additional currency as a new map using create_map()
additional_currency = create_map(lit("CAD"), lit(120))
# Add a new currency (e.g., CAD) with a fixed price to all rows
df.withColumn(
"updated_prices",
map_concat(col("prices"), additional_currency)
).show(truncate=False)
==output==
+----------+------------------------------------+------------------------------------------------+
|product_id|prices                              |updated_prices                                  |
+----------+------------------------------------+------------------------------------------------+
|1         |{EUR -> 85, GBP -> 75, USD -> 100}  |{EUR -> 85, GBP -> 75, USD -> 100, CAD -> 120}  |
|2         |{EUR -> 130, GBP -> 110, USD -> 150}|{EUR -> 130, GBP -> 110, USD -> 150, CAD -> 120}|
|3         |{EUR -> 170, GBP -> 150, USD -> 200}|{EUR -> 170, GBP -> 150, USD -> 200, CAD -> 120}|
+----------+------------------------------------+------------------------------------------------+