Lagged Aggregation and Aggregation Strategies

Overview

This repository provides tools for performing lagged aggregation on large datasets using PySpark. The core functionality is organized into two main modules:

  1. lagged_aggregator.py: This module handles the creation of lagged features based on specified periods.
  2. aggregation_strategy.py: This module defines different strategies for aggregating data.

The goal of this guide is to help you understand how the modules are structured, how to use them effectively, and how to extend the functionality by adding new aggregation strategies.

Module Structure

lagged_aggregator.py

This module provides a single class, LaggedAggregation, which allows users to create lagged features for a given DataFrame. The key parameters for setting up the lagged aggregation include the type of lag (e.g., single_month or period_week), the time column representing the period, and a list of aggregation strategies to apply.

Week-Based Feature Behavior

When using week-based lag types (single_week or period_week), the system automatically excludes the current week (the week containing the calculation date) to ensure only complete historical weeks are included. This provides consistent weekly windows regardless of which day of the week the calculation runs. For example, if the calculation date is a Friday, the entire current week (Monday through Sunday) will be excluded from the lag calculations.
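The cutoff logic described above can be sketched with the standard library alone. This is an illustrative sketch, not the repository's actual implementation; it assumes weeks run Monday through Sunday as stated above, and the function name is hypothetical:

```python
from datetime import date, timedelta

def current_week_start(calc_date: date) -> date:
    """Return the Monday of the week containing calc_date.

    Rows on or after this date fall in the (incomplete) current week
    and would be excluded from week-based lag windows.
    """
    return calc_date - timedelta(days=calc_date.weekday())

# A Friday: the whole week beginning on the preceding Monday is excluded
print(current_week_start(date(2024, 5, 17)))  # 2024-05-13 (a Monday)
```

In a Spark job, the equivalent filter would keep only rows whose week column falls strictly before this cutoff.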

Core Components

  • LAG_TYPE_MAPPING Dictionary: This dictionary maps different lag types to their respective functions and prefixes. It serves as a central configuration that defines the behavior of the lagged aggregation based on the selected lag_type (single_month, period, single_week, period_week). It is used during initialization of the LaggedAggregation class to select the appropriate lag function and prefix.

  • LaggedAggregation Class: The primary class for creating lagged features and aggregating them. It accepts:

      • periods_list (List[int]): List of time periods to create lagged features for (months or weeks, depending on lag_type).
      • time_col (str): Column name representing the time period in the DataFrame (month or week).
      • lag_type (str): Type of lag to apply (single_month, period, single_week, or period_week).

  • apply Method: This method applies the specified lagged aggregations to the DataFrame and combines the results with the original data.
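The shape of LAG_TYPE_MAPPING can be sketched as follows. Only the "single_month" entry's function name (create_single_month_lag_df) and the "L" prefix are documented in this guide; the other function names and prefixes are illustrative placeholders, shown here as strings rather than real function references:

```python
# Hypothetical sketch of the lag-type configuration: each lag_type maps
# to the lag-building function and the prefix used for generated columns.
LAG_TYPE_MAPPING = {
    "single_month": ("create_single_month_lag_df", "L"),
    "period":       ("create_period_lag_df", "P"),       # placeholder
    "single_week":  ("create_single_week_lag_df", "LW"), # placeholder
    "period_week":  ("create_period_week_lag_df", "PW"), # placeholder
}

# During __init__, LaggedAggregation looks up its lag_type once:
func_name, prefix = LAG_TYPE_MAPPING["single_month"]
```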

Example usage

from lagged_aggregator import LaggedAggregation
from aggregation_strategy import SumAggregation, MeanAggregation

# Initialise LaggedAggregation with a list of months, month column, and lag type
lagged_agg = LaggedAggregation(
    periods_list=[1, 3, 6],
    time_col="month",
    lag_type="single_month"
)

# Apply lagged aggregations with the specified strategies
result_df = lagged_agg.apply(
    df,
    agg_col="sales",
    key_cols=["customer_id"],
    strategies=[SumAggregation(), MeanAggregation()]
)

Note

In this example, the lag_type is set to "single_month". The LAG_TYPE_MAPPING dictionary automatically selects the create_single_month_lag_df function and the "L" prefix for the generated lagged columns.

Weekly Lag Example

from lagged_aggregator import LaggedAggregation
from aggregation_strategy import SumAggregation

# Single-week lag features (e.g., last 1 and 2 weeks)
lagged_week_single = LaggedAggregation(
    periods_list=[1, 2],
    time_col="week",
    lag_type="single_week"
)
result_weekly_single = lagged_week_single.apply(
    df,
    agg_col="orders",
    key_cols=["customer_id"],
    strategies=[SumAggregation()]
)

# Period-week lag features (e.g., rolling sum over last 1 and 2 weeks)
lagged_week_period = LaggedAggregation(
    periods_list=[1, 2],
    time_col="week",
    lag_type="period_week"
)
result_weekly_period = lagged_week_period.apply(
    df,
    agg_col="orders",
    key_cols=["customer_id"],
    strategies=[SumAggregation()]
)

aggregation_strategy.py

This module defines the abstract base class AggregationStrategy and several concrete implementations. Each strategy specifies a different way to aggregate data.

Core Components

  • AggregationStrategy Class: The abstract base class that defines the interface for all aggregation strategies. All new strategies should inherit from this class and implement the aggregate method.

Concrete Strategies

  • SumAggregation: Sums the values in the specified column.
  • MeanAggregation: Computes the mean of the values in the specified column.
  • StddevAggregation: Computes the standard deviation of the values.
  • CountAggregation: Counts occurrences, with an option to include or exclude missing values.
  • CountIfOneAggregation: Counts the number of occurrences where the value equals one.
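The interface the concrete strategies implement can be outlined as below. This is a hedged sketch of the base class, not the code in aggregation_strategy.py; in particular, the default return value of include_col_name is an assumption:

```python
from abc import ABC, abstractmethod
from typing import List

class AggregationStrategy(ABC):
    """Interface that every aggregation strategy implements."""

    @abstractmethod
    def aggregate(self, df, agg_col: str, key_cols: List[str]):
        """Group df by key_cols, aggregate agg_col, and return a DataFrame."""

    def include_col_name(self) -> bool:
        """Whether the aggregated column's name is used when naming the output.

        Default shown here is an assumption; override in subclasses as needed.
        """
        return True
```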

Example usage

from aggregation_strategy import SumAggregation, StddevAggregation

# Initialise and apply a sum aggregation strategy
sum_agg = SumAggregation()
result_df = sum_agg.aggregate(df, agg_col="sales", key_cols=["customer_id"])

# Initialise and apply a standard deviation aggregation strategy
stddev_agg = StddevAggregation()
result_df = stddev_agg.aggregate(df, agg_col="sales", key_cols=["customer_id"])

Adding a New Aggregation Strategy

To extend the functionality of the repository, you can create new aggregation strategies. Here’s a step-by-step guide on how to add a new strategy:

1. Create a New Class

Your new strategy should inherit from the AggregationStrategy class. Implement the aggregate method, which defines how the aggregation is performed. You can also override the include_col_name method if needed.

2. Implement the Aggregation Logic

In the aggregate method, define how the data should be aggregated based on the input DataFrame, the aggregation column, and the key columns.

3. Add Unit Tests

Test your new strategy to ensure it works as expected. You can create a separate test file or add to existing tests.

Example: Implementing MedianAggregation

from typing import List

from pyspark.sql import DataFrame
import pyspark.sql.functions as F

from aggregation_strategy import AggregationStrategy

class MedianAggregation(AggregationStrategy):
    """Custom aggregation strategy for calculating the (approximate) median."""

    def aggregate(self, df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame:
        """Aggregate data by calculating the median via percentile_approx."""
        # Alias the result so the output column has a stable, readable name
        return df.groupby(*key_cols).agg(
            F.expr(f"percentile_approx({agg_col}, 0.5)").alias(f"median_{agg_col}")
        )

    def include_col_name(self) -> bool:
        """Include the column name in the output."""
        return True
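For step 3 (unit tests), one approach is to compute the expected medians with the standard library first, then assert that the Spark output matches. The expected-value half needs no Spark session at all; the sample rows and column names below are hypothetical:

```python
from statistics import median

# Hypothetical sample data: (customer_id, sales) pairs
rows = [("a", 10.0), ("a", 20.0), ("a", 30.0), ("b", 5.0), ("b", 15.0)]

# Group by key and compute the exact median per group
groups = {}
for key, value in rows:
    groups.setdefault(key, []).append(value)
expected = {k: median(v) for k, v in groups.items()}

print(expected)  # {'a': 20.0, 'b': 10.0}

# In the actual test, build a Spark DataFrame from `rows`, run
# MedianAggregation().aggregate(df, "sales", ["customer_id"]), and assert
# each group's result is close to expected[key]. Note that
# percentile_approx is approximate, so prefer tolerance-based comparisons
# for large groups.
```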

4. Update Documentation

After adding your new strategy, ensure the documentation reflects the addition: update the class and method docstrings, and add the strategy to the list of concrete strategies in this guide.

5. Test and Validate

Test the functionality with different datasets and ensure that the results are as expected.