import pandas as pd
df = pd.DataFrame({'name': ['Alice', 'Bob', 'Chris'],
                   'age': [24, 31, 31],
                   'state': ['LA', 'MA', 'CA']})
df
Often when working with a dataset, you'll want to filter it in different ways. In Pandas, you can easily filter based on a single column:
df[df.age >= 25]
And it isn't much harder to filter based on multiple columns:
df[(df.age >= 25) & (df.state == 'CA')]
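A quick aside on why the parentheses are there: each comparison produces a boolean Series, and &, | and ~ combine those Series element-wise (AND, OR, NOT). Because & and | bind more tightly than >= and == in Python, each comparison needs its own parentheses when written inline. A small sketch using the same df:

```python
import pandas as pd

df = pd.DataFrame({'name': ['Alice', 'Bob', 'Chris'],
                   'age': [24, 31, 31],
                   'state': ['LA', 'MA', 'CA']})

# Each comparison yields a boolean Series aligned with df's index
is_older = df.age >= 25        # [False, True, True]
in_ca = df.state == 'CA'       # [False, False, True]

# & is element-wise AND, | is element-wise OR, ~ is element-wise NOT
both = df[is_older & in_ca]          # rows meeting both conditions
either = df[is_older | in_ca]        # rows meeting at least one condition
neither = df[~(is_older | in_ca)]    # rows meeting neither condition

print(both)
print(either)
print(neither)
```

Naming the masks like this also sidesteps the precedence issue entirely, since each comparison is evaluated on its own line.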
Problem
However, I often find myself with a large DataFrame that I want to filter in many different ways. Since I do this often, it makes sense to wrap the filtering in a function.
def bad_filter(df, age_threshold, state_filter):
    return df[(df.age >= age_threshold) & (df.state == state_filter)]
bad_filter(df, 25, 'CA')
However, the problem with this approach is that I might not want to filter by the same set of columns every time (age and state in the above example). It would be much nicer if we could simply pass in a dictionary that defines the filtering we want to do.
Flexible Filtering in Pandas
I'll present the function and some examples, and then step through exactly what is happening. The key thing to note is that you pass in a dictionary describing how you want to filter the DataFrame. This is much more flexible, since you can add or remove columns easily.
def flexible_filter(df, filters):
    """Filter DataFrame using dictionary of filters.
    Args:
        df (DataFrame)
        filters (dict): Mapping of columns (keys) to the values to match (values)
    Returns:
        DataFrame
    """
    return df[(df[list(filters)] == pd.Series(filters)).all(axis=1)]
# Filtering on a single column
flexible_filter(df, {'age': 31})
# Filtering on two columns
flexible_filter(df, {'age': 31, 'state': 'CA'})
The function isn't long (just one line!), but there is a lot going on, so let's step through to make sure we understand what is happening.
First, let's identify the columns that the user has decided to filter on:
filters = {'age': 31, 'state': 'CA'}
list(filters)
Then, we select those columns from the DataFrame:
df[list(filters)]
Cool trick: comparing this with a Series that maps column names to values gives us the boolean values we want. When you compare a DataFrame with a Series, Pandas matches the Series index against the DataFrame's columns, so each column is compared with its own filter value, row by row. Because df[list(filters)] selects the columns in the same order as the keys of pd.Series(filters), the labels line up exactly.
df[list(filters)] == pd.Series(filters)
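To make the alignment concrete, here is a small sketch that inspects the intermediate boolean DataFrame. It also uses the method form DataFrame.eq(..., axis='columns'), which aligns a Series to the columns by label. My understanding is that in pandas 2.x the plain == operator instead requires the Series index to match the columns exactly (same labels, same order) and raises otherwise, which the df[list(filters)] selection guarantees; the reordered Series below is purely illustrative.

```python
import pandas as pd

df = pd.DataFrame({'name': ['Alice', 'Bob', 'Chris'],
                   'age': [24, 31, 31],
                   'state': ['LA', 'MA', 'CA']})
filters = {'age': 31, 'state': 'CA'}

# The Series index ('age', 'state') matches the selected columns, so each column
# is compared with its own value and the Series is broadcast down the rows
matches = df[list(filters)] == pd.Series(filters)
print(matches)

# .eq() aligns by label, so a Series with its keys in a different order
# still pairs 'age' with 31 and 'state' with 'CA'
reordered = pd.Series({'state': 'CA', 'age': 31})
matches_eq = df[list(filters)].eq(reordered, axis='columns')
print(matches_eq)
```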
Last (but definitely not least!), we want only the rows where every value is True. We can easily do this using .all(axis=1).
(df[list(filters)] == pd.Series(filters)).all(axis=1)
Then comes the trivial step of passing that boolean Series to the original DataFrame, which keeps only the rows that satisfy all of our filter values.
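If the one-liner feels dense, here is a sketch of the same logic split into named steps. The name flexible_filter_verbose is hypothetical and only for illustration; it should behave the same as flexible_filter above.

```python
import pandas as pd

def flexible_filter_verbose(df, filters):
    """Same logic as flexible_filter, split into named steps (illustrative only)."""
    columns = list(filters)          # 1. column names, in the dict's insertion order
    subset = df[columns]             # 2. only the columns being filtered on
    targets = pd.Series(filters)     # 3. index = column names, values = filter values
    matches = subset == targets      # 4. boolean DataFrame, one column per filter
    keep = matches.all(axis=1)       # 5. True only where every filter matched
    return df[keep]                  # 6. boolean indexing keeps the True rows

df = pd.DataFrame({'name': ['Alice', 'Bob', 'Chris'],
                   'age': [24, 31, 31],
                   'state': ['LA', 'MA', 'CA']})

print(flexible_filter_verbose(df, {'age': 31}))
print(flexible_filter_verbose(df, {'age': 31, 'state': 'CA'}))
```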
Next Steps
There is more we could do with this function, including:
- Extend to allow OR by passing in a list of values in the filtering dictionary
- Add ability to use thresholds for numeric columns
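As a rough sketch of both ideas, here is one possible extension. The function name extended_filter and its conventions (a list or tuple means "any of these", a callable acts as a threshold or other condition on the column) are my own assumptions, not an established API. It also swaps the one-line DataFrame comparison for a loop that builds the mask column by column, which makes it easier to mix condition types.

```python
import pandas as pd

def extended_filter(df, filters):
    """Hypothetical extension of flexible_filter (illustrative, not from the original post).

    Each value in `filters` may be:
      - a scalar: keep rows where the column equals it
      - a list or tuple: keep rows where the column equals any of the items (OR)
      - a callable: receives the column as a Series and must return a boolean
        Series, e.g. a threshold such as lambda s: s >= 25
    Conditions on different columns are combined with AND.
    """
    # Start with every row selected, then narrow the mask one column at a time
    mask = pd.Series(True, index=df.index)
    for col, value in filters.items():
        if callable(value):
            mask &= value(df[col])          # custom condition, e.g. a threshold
        elif isinstance(value, (list, tuple)):
            mask &= df[col].isin(value)     # OR within a single column
        else:
            mask &= df[col] == value        # plain equality, as in flexible_filter
    return df[mask]

df = pd.DataFrame({'name': ['Alice', 'Bob', 'Chris'],
                   'age': [24, 31, 31],
                   'state': ['LA', 'MA', 'CA']})

# Age of at least 25 AND state is either CA or MA
print(extended_filter(df, {'age': lambda s: s >= 25, 'state': ['CA', 'MA']}))
```

Starting from an all-True mask means an empty dictionary returns the whole DataFrame, which seems like a reasonable default.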