Flattening Lists in Python: Reducing Dimensionality Without Prior Knowledge of Data

Data wrangling with multi-dimensional data can be tricky in the best of times. Having a few handy ways to flatten lists in Python can help give a quick glimpse into datasets without requiring tedious parsing.
Flattening Lists in Python

One of Python’s strongest features is its flexible list structure. Lists can represent multi-dimensional data of varying types with ease. Flattening a list, that is, collapsing nested sublists into a single sequence of elements, is one way developers can simplify nested representations of structures such as matrices and binary search trees.

Flattening a list in Python requires some careful consideration, however. Nested lists, strings, byte arrays, and non-iterable sub-items can all produce edge cases that disrupt a program’s logic. In this article, we’ll walk through many such cases while building up, step by step, a function that flattens lists in Python regardless of their contents.

Highlights

  • Python lists can contain non-iterable objects (such as numbers), iterable objects (such as other lists), or even functions.
  • List structures can be categorized as either regular or irregular.
  • Regular lists are easy to flatten, often in a single line.
  • Irregular lists are trickier but can be flattened with standard Python tools.
  • Recursion and generators can flatten irregular lists while maintaining performance and syntactic simplicity.

TL;DR – The following generator function flattens arbitrarily nested lists (within Python’s recursion limit) while retaining strings and bytes objects as atomic elements:

from collections.abc import Iterable

def flatten(any_list):
    for element in any_list:
        # Recurse into iterables, but keep str and bytes objects whole
        if isinstance(element, Iterable) and not isinstance(element, (str, bytes)):
            yield from flatten(element)
        else:
            yield element

Introduction: Regular vs. Irregular Lists

Python lists are incredibly flexible: they are dynamically sized, can reference elements of any type, and are mutable. When it comes to flattening lists, however, this flexibility can become a pain. Specifically, Python lists can contain a mixture of iterable and non-iterable elements, so when flattening a list we need to distinguish between two cases:

  1. regular – a list whose elements are either all non-iterable or all nested to the same depth;
  2. irregular – a list that mixes iterable and non-iterable elements, or whose iterable elements are nested to varying depths.

To get a better picture of these two cases, let’s create two such lists in Python:

# regular list of sublists
regular = [
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
]

# irregular list containing sublists,
# elements, and sublists of sublists
irregular = [
    [1, 2, 3],
    [[4, 5, 6], 7],
    [8, [9, 10, [11, 12, 13]], 14],
    15,
    max
]

Note that the irregular list contains iterable elements, non-iterable elements, and even a function object (the built-in max). This presents several cases we’ll have to find a way to handle. First, let’s start slowly by considering how to flatten a regular list.
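To make the regular/irregular distinction concrete, here is a minimal sketch of two hypothetical helpers, leaf_depths and is_regular (neither is part of Python or of the functions developed in this article). It assumes that only lists and tuples count as nesting, and it treats a list as regular when every non-list element sits at the same depth.

```python
def leaf_depths(obj, level=0):
    """Yield the nesting level of every non-list, non-tuple element in obj."""
    if isinstance(obj, (list, tuple)):
        # Descend one level for each list or tuple we enter
        for item in obj:
            yield from leaf_depths(item, level + 1)
    else:
        # A leaf: report how deep it sits
        yield level


def is_regular(any_list):
    """Return True if all leaves sit at the same depth (an empty list counts as regular)."""
    return len(set(leaf_depths(any_list))) <= 1


regular = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
irregular = [[1, 2, 3], [[4, 5, 6], 7], [8, [9, 10, [11, 12, 13]], 14], 15, max]

print(is_regular(regular))    # True
print(is_regular(irregular))  # False
```

Checking leaf depths rather than the depth of each top-level element also catches lists such as [[1, [2]], [[3], 4]], whose sublists are equally deep overall but irregular inside.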

Flattening Regular Lists

Regular lists provide a degree of predictability that allows certain assumptions to be made. Namely, we can assume that all elements are iterable and nested to the same depth.

That is, if one element is a sublist of non-iterable items, all the others will be as well. We can approach flattening a regular list in several ways. First, let’s consider a method using a simple for loop:

Method 1: Looping Over Elements

In this approach, we create a new list to hold each element as we find it. We loop over each element in the list, and then over each element in that sublist:

# Iterate over list
flat_list = []
for item in regular:
    for subitem in item:
        flat_list.append(subitem)

# View result
>>> flat_list
[1, 2, 3, 4, 5, 6, 7, 8, 9]

This showcases the base logic we need to flatten a list: extract each element from any iterable element and copy it to a new list.

Method 2: List Comprehension

List comprehension is a powerful tool in Python for avoiding the cumbersome syntax of writing many loops. We can use a list comprehension to flatten a regular list as follows:

# Create a flattened list using list comprehension
flat_list = [subitem for item in regular for subitem in item]

# View result
>>> flat_list
[1, 2, 3, 4, 5, 6, 7, 8, 9]

Method 3: The sum Function

For 2-dimensional regular lists, we can take a compact approach using Python’s built-in sum() function. This one is a little odd, but it works:

>>> sum(regular, [])
[1, 2, 3, 4, 5, 6, 7, 8, 9]

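In this approach, sum() uses the empty list passed as its second argument (the start value) as the initial total and adds each sublist to it with +, which concatenates lists. Because each + builds a new list, the running time grows roughly quadratically with the number of sublists, so this trick is best kept to small inputs.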
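The loop below spells out what sum(regular, []) does step by step. For comparison, it also uses itertools.chain.from_iterable from the standard library, which likewise removes exactly one level of nesting but visits each element only once instead of re-copying the accumulated list.

```python
from itertools import chain

regular = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# What sum(regular, []) does: start from [] and concatenate each sublist
accumulated = []
for sublist in regular:
    accumulated = accumulated + sublist  # builds a brand-new list every time

# A linear-time alternative from the standard library (flattens one level only)
chained = list(chain.from_iterable(regular))

print(accumulated == sum(regular, []))  # True
print(chained)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
```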

Edge Case 1: 3-Dimensional Data

Using the sum function seems pretty slick, but it only removes a single level of nesting, so it fully flattens only 2D regular lists. Let’s consider what happens when we add an extra dimension to our data, as in the following regular list:

# Create a regular list with iterable sub-items
nested_regular = [
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]]
]

# Attempting to flatten with list comprehension
flat = [subitem for item in nested_regular for subitem in item]

>>> flat
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]

# Attempting to flatten with sum function
summed = sum(nested_regular, [])

>>> summed
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]

Here we’ve only succeeded in extracting the sublists from our list rather than the individual elements inside those sublists. To reach them, we need to go back to for loops and add a third level of nesting:

flat_list = []
for item in nested_regular:
    for subitem in item:
        for single_item in subitem:
            flat_list.append(single_item)

>>> flat_list
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
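The same three-level loop can also be written as a list comprehension with a third for clause; the clauses appear in the same order as the nested for loops. Like the loop, it is hard-coded to exactly three levels of nesting.

```python
nested_regular = [
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
]

# One for clause per level of nesting, outermost first
flat_list = [
    single_item
    for item in nested_regular
    for subitem in item
    for single_item in subitem
]

print(flat_list)  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
```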

This works well enough, but what happens when we have an unknown number of nested levels? In other words, what happens when we don’t know beforehand how many dimensions our data has?

Edge Case 2: Multi-Dimensional Data

It’s not very practical to restrict our approach to cases where we know the structure of our data beforehand. As a first step, we can wrap the nested loops in a function that flattens a given number of levels. Consider the following:

def flatten_levels(items, levels):
    # Each pass removes exactly one level of nesting
    output = items
    for _ in range(levels):
        tmp = []
        for item in output:
            for subitem in item:
                tmp.append(subitem)
        output = tmp
    return output

# Call function with the required levels argument
flat = flatten_levels(nested_regular, 2)

>>> flat
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
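Because the number of passes is fixed, the levels argument has to match the data exactly. The sketch below repeats flatten_levels so it runs on its own, then calls it with one level too few, which leaves the innermost sublists intact, and one level too many, which tries to iterate over integers and raises a TypeError.

```python
def flatten_levels(items, levels):
    """Flatten exactly `levels` levels of nesting."""
    output = items
    for _ in range(levels):
        tmp = []
        for item in output:
            for subitem in item:  # fails if item is not iterable
                tmp.append(subitem)
        output = tmp
    return output


nested_regular = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]

# Too few levels: the innermost sublists survive
print(flatten_levels(nested_regular, 1))
# [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]

# Too many levels: the third pass tries to iterate over ints
try:
    flatten_levels(nested_regular, 3)
except TypeError as error:
    print("TypeError:", error)
```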

The flatten_levels function descends one level of nesting per pass, for as many passes as the levels parameter specifies. While this solves the dimensionality issue for regular lists, it requires that we know how many dimensions our data has beforehand, which is not a flexible solution. We can rework this approach by checking, after each pass, whether the resulting items are still iterable.

def flatten_with_check(items):

    # Assume items are iterable
    is_iterable = True
    output = items

    # Remove one level of nesting per pass
    while is_iterable:
        tmp = []
        for item in output:
            tmp.extend(item)  # <--- Will break on irregularity
        output = tmp

        # Check for another level of nesting by inspecting the first item
        try:
            iter(output[0])
            is_iterable = True
        except (IndexError, TypeError):
            # Empty result or non-iterable first item: we're done
            is_iterable = False

    return output

flat = flatten_with_check(nested_regular)
>>> flat
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

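This works for regular lists with as much nesting as we care to throw at it. It’s not an elegant solution, however, and it only handles lists whose data is regular in nature. Because it inspects only the first element after each pass, it is also fragile: a regular list of strings, for example, would loop forever, since a one-character string is iterable and yields itself. Let’s consider now how we might build upon this solution to handle irregular lists as well.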

Flattening Irregular Lists

Our function to flatten a regular list ended up pretty unsightly. It will also raise a TypeError if we pass in an irregular list, because list.extend() cannot accept a non-iterable element such as an int. To address this, we can use a recursive function both to simplify our syntax and to handle irregular lists. Let’s start off with a function that uses a try/except block to determine whether an object is iterable.

def flatten(iterable, flattened):
    try:
        # Raises TypeError if the object is not iterable
        iter(iterable)
        for item in iterable:
            flatten(item, flattened)
    except:
        # Bare except: also catches errors raised deeper in the recursion
        flattened.append(iterable)
    return flattened

# Flatten our nested regular list
flat = flatten(nested_regular, [])
>>> flat
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]


# Flatten our irregular list
flat = flatten(irregular, [])
>>> flat
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, <built-in function max>]

This approach involves less syntax and seems more elegant. Recursive functions run the risk of obfuscating logic but, in this case, the benefits outweigh the downside.

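Note: Python’s default recursion limit is 1000 (see sys.getrecursionlimit()), so lists nested roughly that deep will exceed it. Worse, the bare except in this version would catch the resulting RecursionError itself, so instead of crashing, the function could silently return a partially flattened result.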
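If deeply nested input is a realistic possibility, one alternative (a swapped-in technique, not the approach this article builds toward) is to replace recursion with an explicit stack of iterators, so that nesting depth is bounded by available memory rather than by the recursion limit. The sketch below uses a hypothetical flatten_iterative name and, like the article’s final version, treats str and bytes objects as atomic.

```python
from collections.abc import Iterable


def flatten_iterative(any_list):
    """Yield the leaves of any_list without recursion (hypothetical helper)."""
    # Stack of iterators; the top one walks the list we are currently inside
    stack = [iter(any_list)]
    while stack:
        try:
            element = next(stack[-1])
        except StopIteration:
            # This level is exhausted: return to the enclosing list
            stack.pop()
            continue
        if isinstance(element, Iterable) and not isinstance(element, (str, bytes)):
            # Descend by pushing a new iterator instead of making a recursive call
            stack.append(iter(element))
        else:
            yield element


# Build a list nested far deeper than the default recursion limit of 1000
deep = [42]
for _ in range(5000):
    deep = [deep]

print(list(flatten_iterative(deep)))  # [42]
```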

While this is unlikely to be an issue in most cases, it is worth keeping in mind. Next, let’s deal with strings and other iterables that we might want to keep as single elements.

Edge Case 1: Strings & Bytes

Strings and bytes often represent conceptually atomic objects. For example, the string "python" isn’t likely to be considered equivalent to the list ['p', 'y', 't', 'h', 'o', 'n']. In most cases, this object needs to be retained as a single word rather than a series of characters.
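Part of what makes strings and bytes awkward is how they iterate. The quick check below shows that iterating a string yields one-character strings, each of which iterates to itself, while iterating a bytes object yields integers rather than smaller bytes objects.

```python
word = "python"
data = b"py"

# Iterating a str yields one-character strings...
chars = list(word)
print(chars)  # ['p', 'y', 't', 'h', 'o', 'n']

# ...and a one-character string iterates to itself, so "descending"
# into it never reaches a non-iterable element
print(list(chars[0]))  # ['p']

# Iterating bytes yields integer byte values, not one-byte bytes objects
print(list(data))  # [112, 121]
```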

Our approaches thus far treat strings, bytes, and lists all as the iterable objects they are. To ensure our strings and bytes stay atomic, we can alter our previous recursive function to check for these data types. Also, we’ll create a new irregular list to demonstrate the issue:

# Create a new irregular list
irregular = [
    [1, 2, 3],
    [4],
    [
        [4, 5, 6], [5, 6, [11, 12, ['alpha']], max(1, 2)],
        [3, 'beta', [['delta', 12, 4 + 5]]]
    ],
    'gamma'
]

# Run our previous function
flat = flatten(irregular, [])
>>> flat
[1, 2, 3, 4, 4, 5, 6, 5, 6, 11, 12, 'a', 'l', 'p', 'h', 'a', 2, 3, 'b', 'e', 't', 'a', 'd', 'e', 'l', 't', 'a', 12, 9, 'g', 'a', 'm', 'm', 'a']

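Here we note that the strings in our irregular list are treated as iterable objects, resulting in single characters being added to our final output. (This only terminates at all because each one-character string is itself iterable, yielding itself, and the bare except swallows the RecursionError that this endless descent eventually raises.) To avoid this issue, we’ll add a single conditional statement to check for strings and bytes: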

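def atomic_flatten(iterable, flattened):

    try:
        # Raises TypeError if the object is not iterable
        iter(iterable)
    except TypeError:
        # Non-iterable (e.g. int or function): keep as a single element
        flattened.append(iterable)
        return flattened

    if isinstance(iterable, (str, bytes)):
        # Iterable, but conceptually atomic: keep it whole
        flattened.append(iterable)
    else:
        # Any other iterable: recurse into each sub-item
        for item in iterable:
            atomic_flatten(item, flattened)
    return flattened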

flat = atomic_flatten(irregular, [])
>>> flat
[1, 2, 3, 4, 4, 5, 6, 5, 6, 11, 12, 'alpha', 2, 3, 'beta', 'delta', 12, 9, 'gamma']

At this point, we’ve achieved our goal of flattening lists. This function handles regular and irregular lists, keeps strings and bytes objects intact rather than splitting them into characters, and requires no prior knowledge of our lists’ dimensionality. All in all, this gets the job done. Rather than leave things here, let’s consider how we might make this function a bit more efficient, tidy up our syntax, and keep things flexible.

Optimized Approach

We can optimize our function in three ways: using a generator, which yields elements lazily instead of building the whole result list up front; using isinstance checks rather than try/except blocks; and using the Iterable abstract base class from collections.abc to rule out single items. This approach is done as follows:

# Import the Iterable abstract base class from collections.abc
from collections.abc import Iterable

# Define our irregularly-shaped list
irregular = [
    [1, 2, 3],
    [4],
    [
        [4, 5, 6], [5, 6, [11, 12, ['alpha']], max(1, 2)],
        [3, 'beta', [['delta', 12, 5+4]]]
    ],
    'gamma'
]

# Recursively flatten any_list, keeping str and bytes objects whole
def flatten(any_list):

    # Iterate over every possible element
    for element in any_list:

        # Check if the object is iterable & not a string or bytes object
        if isinstance(element, Iterable) and not isinstance(element, (str, bytes)):

            # recurse on iterable elements
            yield from flatten(element)
        else:

            # keep single elements
            yield element

# Use function to flatten list
flat = flatten(irregular)

# View results
>>> flat
<generator object flatten at 0x000001DCA4DDAF20>

There are a few new things going on here worth noting:

  1. We get a generator object returned rather than a list;
  2. We use the Iterable class from the Python standard library;
  3. We use the yield keyword rather than return;
  4. We use the yield from syntax when recursing;
  5. We are still treating str and bytes elements as atomic.

To view each item in our newly flattened list, we can consume the generator with list():

# Collect every item from the generator into a list
>>> list(flat)
[1, 2, 3, 4, 4, 5, 6, 5, 6, 11, 12, 'alpha', 2, 3, 'beta', 'delta', 12, 9, 'gamma']
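Because flatten() is a generator, it produces items on demand. The sketch below repeats the generator so it runs on its own and uses itertools.islice from the standard library to pull only the first few items from a large nested list without flattening the rest. It also shows that a generator can be consumed only once.

```python
from collections.abc import Iterable
from itertools import islice


def flatten(any_list):
    for element in any_list:
        if isinstance(element, Iterable) and not isinstance(element, (str, bytes)):
            yield from flatten(element)
        else:
            yield element


# A large, irregular list: [[0, [1]], [2, [3]], [4, [5]], ...]
big = [[i, [i + 1]] for i in range(0, 200_000, 2)]

# Only enough of the list is walked to produce five items
print(list(islice(flatten(big), 5)))  # [0, 1, 2, 3, 4]

# A generator is exhausted after one full pass
flat = flatten([1, [2, [3]]])
print(list(flat))  # [1, 2, 3]
print(list(flat))  # []
```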

Note: This function is available from the GitHub repo here.

One caveat with this approach is our use of the Iterable class from Python’s collections.abc module. A dependency on Python’s standard library is, in my opinion, a fair trade-off. However, for anyone looking to avoid even that import, the isinstance(element, Iterable) condition can be changed to hasattr(element, "__iter__"). Also note that the generator still recurses once per level of nesting, so Python’s recursion limit applies here just as it did earlier.
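Both isinstance(element, Iterable) and hasattr(element, "__iter__") look for an __iter__ method, so neither detects objects that are iterable only through the older __getitem__ sequence protocol; the collections.abc documentation notes that calling iter() is the only reliable test. The hypothetical Squares class below illustrates the gap.

```python
from collections.abc import Iterable


class Squares:
    """Hypothetical sequence that supports iteration only via __getitem__."""

    def __getitem__(self, index):
        if index >= 3:
            raise IndexError(index)  # signals the end of iteration to iter()
        return index * index


squares = Squares()

print(isinstance(squares, Iterable))  # False
print(hasattr(squares, "__iter__"))   # False
print(list(iter(squares)))            # [0, 1, 4]
```

Such classes are uncommon, so the checks used above are a reasonable default; with either check, a Squares instance would simply be kept as a single element.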

Final Thoughts

The concept of flattening lists can be applied when dealing with matrices, binary trees, or any number of other data structures. In particular, hierarchical data often results in irregularly shaped lists. Flattening isn’t always needed for meaningful data analysis, though when it is, the approaches we’ve discussed here can help in weighing the feasibility, practicality, and performance of such an undertaking.

Zαck West
Full-Stack Software Engineer with 10+ years of experience. Expertise in developing distributed systems, implementing object-oriented models with a focus on semantic clarity, driving development with TDD, enhancing interfaces through thoughtful visual design, and developing deep learning agents.