SKU Feature Set

Module for the SKUFeatureSet and SKUFeatureConfig classes.

Author: Jessica Matthysen, Garett Sidwell

SKUFeatureConfig

Bases: FeatureConfig

Configuration class for SKU features.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| primary_sort_field | str | The primary field used for sorting SKUs before aggregation. |
| lag_months | list[int] | List of lag months to consider for SKU calculations (default is [3]). |
| sku_id_column | str | Name of the column containing SKU IDs (default is an empty string). |
| number_of_skus | int | Number of top and bottom SKUs to calculate (default is 3). |
| secondary_sort_field | str | The secondary field used to resolve ties in SKU sorting (default is an empty string). |
| resolve_tie_break | bool | Whether to resolve ties using the secondary sorting field (default is False). |

Source code in amee_utils/feature_generator/feature_set/sku.py
@attrs.define
class SKUFeatureConfig(FeatureConfig):
    """Configuration class for SKU features.

    Attributes
    ----------
    primary_sort_field : str
        The primary field used for sorting SKUs before aggregation.
    lag_months : list[int]
        List of lag months to consider for SKU calculations (default is [3]).
    sku_id_column : str
        Name of the column containing SKU IDs (default is an empty string).
    number_of_skus : int
        Number of top and bottom SKUs to calculate (default is 3).
    secondary_sort_field : str
        The secondary field used to resolve ties in SKU sorting (default is an empty string).
    resolve_tie_break : bool
        Whether to resolve ties using the secondary sorting field (default is False).
    """

    primary_sort_field: str
    lag_months: list[int] = attrs.field(factory=lambda: [3])
    sku_id_column: str = ""
    number_of_skus: int = 3
    secondary_sort_field: str = ""
    resolve_tie_break: bool = False

    def __attrs_post_init__(self):
        """
        Post-initialization method for SKUFeatureConfig.

        Raises
        ------
        ValueError
            If resolve_tie_break is True but no secondary_sort_field is provided.
        """
        if self.resolve_tie_break and not self.secondary_sort_field:
            raise ValueError(
                "resolve_tie_break is True, but no secondary_sort_field was provided in" " SKUFeatureConfig."
            )
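
A minimal construction sketch for the config (assuming the FeatureConfig base class adds no required fields of its own; the column names quantity, revenue, and sku_id are placeholders for whatever your dataset uses):

```python
from amee_utils.feature_generator.feature_set.sku import SKUFeatureConfig

config = SKUFeatureConfig(
    primary_sort_field="quantity",   # SKUs are ranked by the summed value of this column
    lag_months=[3, 6],               # build both P3 and P6 features
    sku_id_column="sku_id",
    number_of_skus=3,
    secondary_sort_field="revenue",  # only consulted when resolve_tie_break=True
    resolve_tie_break=True,
)

# The post-init check guards against an inconsistent config:
# SKUFeatureConfig(primary_sort_field="quantity", resolve_tie_break=True)
# -> ValueError: resolve_tie_break is True, but no secondary_sort_field was provided ...
```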

SKUFeatureSet

Bases: FeatureSet[SKUFeatureConfig]

A feature set for generating SKU-related features.

Source code in amee_utils/feature_generator/feature_set/sku.py
class SKUFeatureSet(FeatureSet[SKUFeatureConfig]):
    """A feature set for generating SKU related features."""

    def calculate(
        self,
        df: DataFrame,
        dataset_config: DatasetConfig,
        feature_config: SKUFeatureConfig,
        calculation_date: datetime,
    ) -> DataFrame:
        """
        Calculate SKU features for the given DataFrame, DatasetConfig, and SKUFeatureConfig.

        If tie-breaking is enabled (i.e., resolve_tie_break=True), the method will perform
        an additional aggregation based on the secondary_sort_field and join it with the
        primary aggregation results.

        Parameters
        ----------
        df : DataFrame
            The input DataFrame containing the data.
        dataset_config : DatasetConfig
            Configuration details for the dataset.
        feature_config : SKUFeatureConfig
            Configuration details specific to SKU features.
        calculation_date : datetime
            The date for which the calculations are performed.

        Returns
        -------
        DataFrame
            The DataFrame with the calculated SKU features.

        Example
        -------
        Input DataFrame:

        | customer_id | sku_id | quantity | date       |
        |-------------|--------|----------|------------|
        | C1          | SKU1   | 10       | 2024-03-01 |
        | C1          | SKU2   | 5        | 2024-04-01 |
        | C1          | SKU3   | 3        | 2024-05-01 |
        | C1          | SKU4   | 8        | 2024-06-01 |
        | C1          | SKU5   | 1        | 2024-07-01 |
        | C2          | SKU1   | 2        | 2024-03-01 |
        | C2          | SKU2   | 7        | 2024-04-01 |
        | C2          | SKU3   | 6        | 2024-05-01 |
        | C2          | SKU4   | 4        | 2024-06-01 |
        | C2          | SKU5   | 9        | 2024-07-01 |

        Output DataFrame (with lag_months=[3, 6]):

        | customer_id | TOP_SKU_P3_1 | TOP_SKU_P3_2 | TOP_SKU_P3_3 | BOTTOM_SKU_P3_1 |
        |-------------|--------------|--------------|--------------|-----------------|
        | C1          | SKU4         | SKU2         | SKU3         | SKU3            |
        | C2          | SKU5         | SKU3         | SKU4         | SKU1            |

        ...more columns for P3 and P6
        """
        self.feature_config = feature_config
        key_cols = dataset_config.key_cols
        self.df = df
        lag_months = feature_config.lag_months
        base = df.select(key_cols).distinct()

        logger.info("Filtering data for the specified number of months")

        # Filter the data for each lag month separately
        lag_dfs = []
        for lag in lag_months:
            lag_int = int(lag)
            df_filtered = filter_data_by_months(
                df=df,
                number_of_months=lag_int,
                calculation_date=calculation_date,
                date_column=dataset_config.date_col,
            )
            lag_dfs.append((lag, df_filtered))

        logger.info("Applying lagged aggregation")

        result_dfs = []
        for lag, df_filtered in lag_dfs:
            lagged_aggregation = LaggedAggregation(
                periods_list=[lag],
                time_col="month",
                lag_type="period",
            )
            strategy = SumAggregation()

            df_aggregated = lagged_aggregation.apply(
                df_filtered,
                feature_config.primary_sort_field,
                key_cols + [feature_config.sku_id_column],
                [strategy],
            )

            if feature_config.resolve_tie_break and feature_config.secondary_sort_field:
                df_secondary_aggregated = lagged_aggregation.apply(
                    df_filtered,
                    feature_config.secondary_sort_field,
                    key_cols + [feature_config.sku_id_column],
                    [strategy],
                )

                df_aggregated = df_aggregated.join(
                    df_secondary_aggregated,
                    key_cols + [feature_config.sku_id_column],
                    how="left",
                )

            logger.info(f"Calculating top and bottom SKUs for lag {lag}")

            top_skus_df = self.top_skus(df_aggregated, key_cols, lag)
            bottom_skus_df = self.bottom_skus(df_aggregated, key_cols, lag)

            lag_result = join_multiple_to_base(base, [top_skus_df, bottom_skus_df], key_cols)
            result_dfs.append(lag_result)

        result_df = join_multiple_to_base(base, result_dfs, key_cols)

        return result_df

    def top_skus(self, df: DataFrame, key_cols: list, lag: int) -> DataFrame:
        """
        Get the top N SKUs for a specific lag.

        Parameters
        ----------
        df : DataFrame
            The input DataFrame containing the data.
        key_cols : list
            List of key columns for grouping.
        lag : int
            The lag period for which to calculate the top SKUs.

        Returns
        -------
        DataFrame
            The DataFrame with the top N SKUs for the specified lag period.

        Example
        -------
        Input DataFrame:

        | customer_id | sku_id | quantity |
        |-------------|--------|----------|
        | C1          | SKU2   | 5        |
        | C1          | SKU3   | 3        |
        | C1          | SKU4   | 8        |
        | C2          | SKU2   | 7        |
        | C2          | SKU3   | 6        |
        | C2          | SKU4   | 4        |

        Output DataFrame (for lag=3):

        | customer_id | TOP_SKU_P3_1 | TOP_SKU_P3_2 | TOP_SKU_P3_3 |
        |-------------|--------------|--------------|--------------|
        | C1          | SKU4         | SKU2         | SKU3         |
        | C2          | SKU2         | SKU3         | SKU4         |

        """
        return self._get_ranked_skus(df, key_cols, "desc", f"TOP_SKU_P{lag}_")

    def bottom_skus(self, df: DataFrame, key_cols: list, lag: int) -> DataFrame:
        """
        Get the bottom N SKUs for a specific lag.

        Parameters
        ----------
        df : DataFrame
            The input DataFrame containing the data.
        key_cols : list
            List of key columns for grouping.
        lag : int
            The lag period for which to calculate the bottom SKUs.

        Returns
        -------
        DataFrame
            The DataFrame with the bottom N SKUs for the specified lag period.

        Example
        -------
        Input DataFrame:

        | customer_id | sku_id | quantity |
        |-------------|--------|----------|
        | C1          | SKU2   | 5        |
        | C1          | SKU3   | 3        |
        | C1          | SKU4   | 8        |
        | C2          | SKU2   | 7        |
        | C2          | SKU3   | 6        |
        | C2          | SKU4   | 4        |

        Output DataFrame (for lag=3):

        | customer_id | BOTTOM_SKU_P3_1 | BOTTOM_SKU_P3_2 | BOTTOM_SKU_P3_3 |
        |-------------|-----------------|-----------------|-----------------|
        | C1          | SKU3            | SKU2            | SKU4            |
        | C2          | SKU4            | SKU3            | SKU2            |

        """
        return self._get_ranked_skus(df, key_cols, "asc", f"BOTTOM_SKU_P{lag}_")

    def _get_ranked_skus(
        self,
        df: DataFrame,
        key_cols: list,
        order: str,
        prefix: str,
    ) -> DataFrame:
        """
        Get the ranked SKUs.

        Parameters
        ----------
        df : DataFrame
            The input DataFrame containing the data.
        key_cols : list
            List of key columns for grouping.
        order : str
            Order by which to rank SKUs ("asc" for ascending, "desc" for descending).
        prefix : str
            Prefix for the column names in the output DataFrame.

        Returns
        -------
        DataFrame
            The DataFrame with the ranked SKUs by primary and optional secondary field for tie breaks.
        """
        feature_config = self.feature_config
        order_func = F.asc if order == "asc" else F.desc

        primary_sort_field = feature_config.primary_sort_field
        secondary_sort_field = feature_config.secondary_sort_field
        resolve_tie_break = feature_config.resolve_tie_break

        func_name = "SUM"
        lag = int(prefix.split("P")[-1].split("_")[0])  # Extract lag from prefix
        primary_order_func_col = f"{func_name}_{primary_sort_field.upper()}_P{lag}"

        secondary_order_func_col = None
        if resolve_tie_break and secondary_sort_field:
            secondary_order_func_col = f"{func_name}_{secondary_sort_field.upper()}_P{lag}"

        # Create a single orderBy clause
        order_cols = [order_func(primary_order_func_col)]
        if secondary_order_func_col:
            order_cols.append(order_func(secondary_order_func_col))

        window_spec = Window.partitionBy(*key_cols).orderBy(*order_cols)

        columns_to_select = [
            *key_cols,
            feature_config.sku_id_column,
            F.col(primary_order_func_col).alias(f"{prefix}primary_{primary_order_func_col}"),
            "rank",
        ]

        if secondary_order_func_col:
            columns_to_select.append(
                F.col(secondary_order_func_col).alias(f"{prefix}secondary_{secondary_order_func_col}")
            )

        ranked_skus = (
            df.withColumn("rank", F.row_number().over(window_spec))
            .filter(F.col("rank") <= feature_config.number_of_skus)
            .select(*columns_to_select)
        )

        ranked_skus_pivot = ranked_skus.groupBy(*key_cols).pivot("rank").agg(F.first(feature_config.sku_id_column))
        for i in range(1, feature_config.number_of_skus + 1):
            ranked_skus_pivot = ranked_skus_pivot.withColumnRenamed(str(i), f"{prefix}{i}")

        return ranked_skus_pivot

    def _construct_alias(self, field_name: str, func_name: str, lag_months: int) -> str:
        """Dynamically construct the alias for the field based on the aggregation.

        Parameters
        ----------
        field_name : str
            The name of the field (column) to generate the alias for.
        func_name : str
            The name of the aggregation function (e.g., 'sum').
        lag_months : int
            The number of lag months used in the aggregation.

        Returns
        -------
        str
            The dynamically constructed alias for the field.
        """
        return f"{func_name}_{field_name.upper()}_P{lag_months}"
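
The ranking step in _get_ranked_skus reduces to a row_number over a per-key window followed by a pivot on the rank. A standalone PySpark sketch of that pattern, not the helper itself (the SUM_QUANTITY_P3 column name mirrors the alias the lagged aggregation would produce for primary_sort_field="quantity" and lag 3):

```python
from pyspark.sql import SparkSession, Window, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("C1", "SKU2", 5), ("C1", "SKU3", 3), ("C1", "SKU4", 8)],
    ["customer_id", "sku_id", "SUM_QUANTITY_P3"],
)

# Rank SKUs per customer by the aggregated value; descending order gives the top SKUs.
w = Window.partitionBy("customer_id").orderBy(F.desc("SUM_QUANTITY_P3"))
ranked = df.withColumn("rank", F.row_number().over(w)).filter(F.col("rank") <= 3)

# Pivot rank positions into columns and rename them to the TOP_SKU_P3_<rank> scheme.
pivoted = ranked.groupBy("customer_id").pivot("rank").agg(F.first("sku_id"))
for i in range(1, 4):
    pivoted = pivoted.withColumnRenamed(str(i), f"TOP_SKU_P3_{i}")

pivoted.show()  # C1 | SKU4 | SKU2 | SKU3
```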

bottom_skus(df, key_cols, lag)

Get the bottom N SKUs for a specific lag.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| df | DataFrame | The input DataFrame containing the data. | required |
| key_cols | list | List of key columns for grouping. | required |
| lag | int | The lag period for which to calculate the bottom SKUs. | required |

Returns:

| Type | Description |
|------|-------------|
| DataFrame | The DataFrame with the bottom N SKUs for the specified lag period. |

Example

Input DataFrame:

| customer_id | sku_id | quantity |
|-------------|--------|----------|
| C1          | SKU2   | 5        |
| C1          | SKU3   | 3        |
| C1          | SKU4   | 8        |
| C2          | SKU2   | 7        |
| C2          | SKU3   | 6        |
| C2          | SKU4   | 4        |

Output DataFrame (for lag=3):

| customer_id | BOTTOM_SKU_P3_1 | BOTTOM_SKU_P3_2 | BOTTOM_SKU_P3_3 |
|-------------|-----------------|-----------------|-----------------|
| C1          | SKU3            | SKU2            | SKU4            |
| C2          | SKU4            | SKU3            | SKU2            |

Source code in amee_utils/feature_generator/feature_set/sku.py
def bottom_skus(self, df: DataFrame, key_cols: list, lag: int) -> DataFrame:
    """
    Get the bottom N SKUs for a specific lag.

    Parameters
    ----------
    df : DataFrame
        The input DataFrame containing the data.
    key_cols : list
        List of key columns for grouping.
    lag : int
        The lag period for which to calculate the bottom SKUs.

    Returns
    -------
    DataFrame
        The DataFrame with the bottom N SKUs for the specified lag period.

    Example
    -------
    Input DataFrame:

    | customer_id | sku_id | quantity |
    |-------------|--------|----------|
    | C1          | SKU2   | 5        |
    | C1          | SKU3   | 3        |
    | C1          | SKU4   | 8        |
    | C2          | SKU2   | 7        |
    | C2          | SKU3   | 6        |
    | C2          | SKU4   | 4        |

    Output DataFrame (for lag=3):

    | customer_id | BOTTOM_SKU_P3_1 | BOTTOM_SKU_P3_2 | BOTTOM_SKU_P3_3 |
    |-------------|-----------------|-----------------|-----------------|
    | C1          | SKU3            | SKU2            | SKU4            |
    | C2          | SKU4            | SKU3            | SKU2            |

    """
    return self._get_ranked_skus(df, key_cols, "asc", f"BOTTOM_SKU_P{lag}_")

calculate(df, dataset_config, feature_config, calculation_date)

Calculate SKU features for the given DataFrame, DatasetConfig, and SKUFeatureConfig.

If tie-breaking is enabled (i.e., resolve_tie_break=True), the method will perform an additional aggregation based on the secondary_sort_field and join it with the primary aggregation results.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| df | DataFrame | The input DataFrame containing the data. | required |
| dataset_config | DatasetConfig | Configuration details for the dataset. | required |
| feature_config | SKUFeatureConfig | Configuration details specific to SKU features. | required |
| calculation_date | datetime | The date for which the calculations are performed. | required |

Returns:

| Type | Description |
|------|-------------|
| DataFrame | The DataFrame with the calculated SKU features. |

Example

Input DataFrame:

| customer_id | sku_id | quantity | date       |
|-------------|--------|----------|------------|
| C1          | SKU1   | 10       | 2024-03-01 |
| C1          | SKU2   | 5        | 2024-04-01 |
| C1          | SKU3   | 3        | 2024-05-01 |
| C1          | SKU4   | 8        | 2024-06-01 |
| C1          | SKU5   | 1        | 2024-07-01 |
| C2          | SKU1   | 2        | 2024-03-01 |
| C2          | SKU2   | 7        | 2024-04-01 |
| C2          | SKU3   | 6        | 2024-05-01 |
| C2          | SKU4   | 4        | 2024-06-01 |
| C2          | SKU5   | 9        | 2024-07-01 |

Output DataFrame (with lag_months=[3, 6]):

| customer_id | TOP_SKU_P3_1 | TOP_SKU_P3_2 | TOP_SKU_P3_3 | BOTTOM_SKU_P3_1 |
|-------------|--------------|--------------|--------------|-----------------|
| C1          | SKU4         | SKU2         | SKU3         | SKU3            |
| C2          | SKU5         | SKU3         | SKU4         | SKU1            |

...more columns for P3 and P6

Source code in amee_utils/feature_generator/feature_set/sku.py
def calculate(
    self,
    df: DataFrame,
    dataset_config: DatasetConfig,
    feature_config: SKUFeatureConfig,
    calculation_date: datetime,
) -> DataFrame:
    """
    Calculate SKU features for the given DataFrame, DatasetConfig, and SKUFeatureConfig.

    If tie-breaking is enabled (i.e., resolve_tie_break=True), the method will perform
    an additional aggregation based on the secondary_sort_field and join it with the
    primary aggregation results.

    Parameters
    ----------
    df : DataFrame
        The input DataFrame containing the data.
    dataset_config : DatasetConfig
        Configuration details for the dataset.
    feature_config : SKUFeatureConfig
        Configuration details specific to SKU features.
    calculation_date : datetime
        The date for which the calculations are performed.

    Returns
    -------
    DataFrame
        The DataFrame with the calculated SKU features.

    Example
    -------
    Input DataFrame:

    | customer_id | sku_id | quantity | date       |
    |-------------|--------|----------|------------|
    | C1          | SKU1   | 10       | 2024-03-01 |
    | C1          | SKU2   | 5        | 2024-04-01 |
    | C1          | SKU3   | 3        | 2024-05-01 |
    | C1          | SKU4   | 8        | 2024-06-01 |
    | C1          | SKU5   | 1        | 2024-07-01 |
    | C2          | SKU1   | 2        | 2024-03-01 |
    | C2          | SKU2   | 7        | 2024-04-01 |
    | C2          | SKU3   | 6        | 2024-05-01 |
    | C2          | SKU4   | 4        | 2024-06-01 |
    | C2          | SKU5   | 9        | 2024-07-01 |

    Output DataFrame (with lag_months=[3, 6]):

    | customer_id | TOP_SKU_P3_1 | TOP_SKU_P3_2 | TOP_SKU_P3_3 | BOTTOM_SKU_P3_1 |
    |-------------|--------------|--------------|--------------|-----------------|
    | C1          | SKU4         | SKU2         | SKU3         | SKU3            |
    | C2          | SKU5         | SKU3         | SKU4         | SKU1            |

    ...more columns for P3 and P6
    """
    self.feature_config = feature_config
    key_cols = dataset_config.key_cols
    self.df = df
    lag_months = feature_config.lag_months
    base = df.select(key_cols).distinct()

    logger.info("Filtering data for the specified number of months")

    # Filter the data for each lag month separately
    lag_dfs = []
    for lag in lag_months:
        lag_int = int(lag)
        df_filtered = filter_data_by_months(
            df=df,
            number_of_months=lag_int,
            calculation_date=calculation_date,
            date_column=dataset_config.date_col,
        )
        lag_dfs.append((lag, df_filtered))

    logger.info("Applying lagged aggregation")

    result_dfs = []
    for lag, df_filtered in lag_dfs:
        lagged_aggregation = LaggedAggregation(
            periods_list=[lag],
            time_col="month",
            lag_type="period",
        )
        strategy = SumAggregation()

        df_aggregated = lagged_aggregation.apply(
            df_filtered,
            feature_config.primary_sort_field,
            key_cols + [feature_config.sku_id_column],
            [strategy],
        )

        if feature_config.resolve_tie_break and feature_config.secondary_sort_field:
            df_secondary_aggregated = lagged_aggregation.apply(
                df_filtered,
                feature_config.secondary_sort_field,
                key_cols + [feature_config.sku_id_column],
                [strategy],
            )

            df_aggregated = df_aggregated.join(
                df_secondary_aggregated,
                key_cols + [feature_config.sku_id_column],
                how="left",
            )

        logger.info(f"Calculating top and bottom SKUs for lag {lag}")

        top_skus_df = self.top_skus(df_aggregated, key_cols, lag)
        bottom_skus_df = self.bottom_skus(df_aggregated, key_cols, lag)

        lag_result = join_multiple_to_base(base, [top_skus_df, bottom_skus_df], key_cols)
        result_dfs.append(lag_result)

    result_df = join_multiple_to_base(base, result_dfs, key_cols)

    return result_df
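
Putting it together, a sketch of a full invocation (the DatasetConfig keyword arguments and the no-argument SKUFeatureSet() constructor are assumptions for illustration, and transactions_df stands in for a Spark DataFrame shaped like the example input above):

```python
from datetime import datetime

from amee_utils.feature_generator.feature_set.sku import SKUFeatureConfig, SKUFeatureSet

# Assumed shape: DatasetConfig exposes key_cols and date_col, as calculate() expects.
dataset_config = DatasetConfig(key_cols=["customer_id"], date_col="date")

feature_config = SKUFeatureConfig(
    primary_sort_field="quantity",
    lag_months=[3, 6],
    sku_id_column="sku_id",
    secondary_sort_field="revenue",
    resolve_tie_break=True,
)

features = SKUFeatureSet().calculate(
    df=transactions_df,
    dataset_config=dataset_config,
    feature_config=feature_config,
    calculation_date=datetime(2024, 8, 1),
)
# Expected columns per customer_id: TOP_SKU_P3_1..3, BOTTOM_SKU_P3_1..3,
# plus the corresponding P6 columns.
```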

top_skus(df, key_cols, lag)

Get the top N SKUs for a specific lag.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| df | DataFrame | The input DataFrame containing the data. | required |
| key_cols | list | List of key columns for grouping. | required |
| lag | int | The lag period for which to calculate the top SKUs. | required |

Returns:

| Type | Description |
|------|-------------|
| DataFrame | The DataFrame with the top N SKUs for the specified lag period. |

Example

Input DataFrame:

| customer_id | sku_id | quantity |
|-------------|--------|----------|
| C1          | SKU2   | 5        |
| C1          | SKU3   | 3        |
| C1          | SKU4   | 8        |
| C2          | SKU2   | 7        |
| C2          | SKU3   | 6        |
| C2          | SKU4   | 4        |

Output DataFrame (for lag=3):

| customer_id | TOP_SKU_P3_1 | TOP_SKU_P3_2 | TOP_SKU_P3_3 |
|-------------|--------------|--------------|--------------|
| C1          | SKU4         | SKU2         | SKU3         |
| C2          | SKU2         | SKU3         | SKU4         |

Source code in amee_utils/feature_generator/feature_set/sku.py
def top_skus(self, df: DataFrame, key_cols: list, lag: int) -> DataFrame:
    """
    Get the top N SKUs for a specific lag.

    Parameters
    ----------
    df : DataFrame
        The input DataFrame containing the data.
    key_cols : list
        List of key columns for grouping.
    lag : int
        The lag period for which to calculate the top SKUs.

    Returns
    -------
    DataFrame
        The DataFrame with the top N SKUs for the specified lag period.

    Example
    -------
    Input DataFrame:

    | customer_id | sku_id | quantity |
    |-------------|--------|----------|
    | C1          | SKU2   | 5        |
    | C1          | SKU3   | 3        |
    | C1          | SKU4   | 8        |
    | C2          | SKU2   | 7        |
    | C2          | SKU3   | 6        |
    | C2          | SKU4   | 4        |

    Output DataFrame (for lag=3):

    | customer_id | TOP_SKU_P3_1 | TOP_SKU_P3_2 | TOP_SKU_P3_3 |
    |-------------|--------------|--------------|--------------|
    | C1          | SKU4         | SKU2         | SKU3         |
    | C2          | SKU2         | SKU3         | SKU4         |

    """
    return self._get_ranked_skus(df, key_cols, "desc", f"TOP_SKU_P{lag}_")
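
Note that top_skus and bottom_skus rank a DataFrame that has already been aggregated: _get_ranked_skus orders by the SUM_<PRIMARY_SORT_FIELD>_P<lag> column and reads its settings from self.feature_config, which calculate() assigns before calling them. A hedged sketch of calling top_skus directly on such a pre-aggregated frame (config values and column names are illustrative):

```python
feature_set = SKUFeatureSet()
feature_set.feature_config = SKUFeatureConfig(
    primary_sort_field="quantity",
    sku_id_column="sku_id",
    number_of_skus=3,
)

# df_aggregated must already carry SUM_QUANTITY_P3 per (customer_id, sku_id) pair,
# e.g. the output of the LaggedAggregation step inside calculate().
top = feature_set.top_skus(df_aggregated, key_cols=["customer_id"], lag=3)
# -> columns: customer_id, TOP_SKU_P3_1, TOP_SKU_P3_2, TOP_SKU_P3_3
```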