Skip to content

Aggregation Strategies

Module for aggregation strategies.

Author: Daniel Wertheimer

AggregationStrategy

Bases: ABC

Abstract base class for aggregation strategies.

Subclasses should implement the aggregate method to perform a specific type of aggregation.

Methods:

| Name | Description |
| --- | --- |
| aggregate | Abstract method to be overridden by subclasses to perform aggregation. |
| include_col_name | Indicates whether to include the column name in the output. |

Raises:

Type Description
NotImplementedError

If the method is not overridden by a subclass.

Source code in amee_utils/feature_generator/feature_set/aggregation_strategy.py
class AggregationStrategy(ABC):
    """
    Common interface shared by every aggregation strategy.

    Concrete strategies provide their own ``aggregate`` implementation. The
    ``include_col_name`` method defined here returns True unless a subclass
    overrides it.

    Methods
    -------
    aggregate(df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame
        Abstract hook that subclasses override to perform the aggregation.

    include_col_name() -> bool
        Report whether the column name should be included in the output.

    Raises
    ------
    NotImplementedError
        Raised by the base ``aggregate`` body when it is invoked without
        being overridden.
    """

    @abstractmethod
    def aggregate(self, df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame:
        """Aggregate ``agg_col`` of ``df`` grouped by ``key_cols``.

        Parameters
        ----------
        df : DataFrame
            Input data to aggregate.
        agg_col : str
            Name of the column whose values are aggregated.
        key_cols : List[str]
            Names of the columns that define the groups.

        Returns
        -------
        DataFrame
            The aggregated data.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        message = "This method should be overridden by subclasses."
        raise NotImplementedError(message)

    def include_col_name(self) -> bool:
        """
        Report whether the column name should be included in the output.

        Returns
        -------
        bool
            Always True in this base implementation.
        """
        return True

aggregate(df, agg_col, key_cols) abstractmethod

Perform aggregation on the specified column and group by the key columns.

Parameters:

Name Type Description Default
df DataFrame

The input DataFrame containing the data to be aggregated.

required
agg_col str

The column in the DataFrame to be aggregated.

required
key_cols List[str]

The columns in the DataFrame to group by.

required

Returns:

Type Description
DataFrame

A DataFrame containing the aggregated data.

Source code in amee_utils/feature_generator/feature_set/aggregation_strategy.py
@abstractmethod
def aggregate(self, df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame:
    """Aggregate ``agg_col`` of ``df`` grouped by ``key_cols``.

    Parameters
    ----------
    df : DataFrame
        Input data to aggregate.
    agg_col : str
        Name of the column whose values are aggregated.
    key_cols : List[str]
        Names of the columns that define the groups.

    Returns
    -------
    DataFrame
        The aggregated data.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    message = "This method should be overridden by subclasses."
    raise NotImplementedError(message)

include_col_name()

Indicate whether to include the column name in the output.

Returns:

Type Description
bool

True if the column name should be included, False otherwise.

Source code in amee_utils/feature_generator/feature_set/aggregation_strategy.py
def include_col_name(self) -> bool:
    """
    Report whether the column name should be included in the output.

    Returns
    -------
    bool
        Always True in this base implementation.
    """
    return True

CountAggregation

Bases: AggregationStrategy

CountAggregation class for counting occurrences of values in a specified column.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| include_missing | bool | If True, missing values (nulls and NaNs) are included in the count aggregation. If False, missing values are excluded from the count. |

Methods:

Name Description
aggregate

Aggregate data by counting occurrences.

_check_contains_missing

Checks if the specified column contains any missing values (nulls or NaNs).

Source code in amee_utils/feature_generator/feature_set/aggregation_strategy.py
class CountAggregation(AggregationStrategy):
    """CountAggregation class for counting occurrences of values in a specified column.

    Attributes
    ----------
    include_missing : bool
        If True, every row of each group is counted (``count("*")``).
        If False, rows are passed through ``dropna`` on the aggregated column
        before grouping, so missing values are excluded from the count.

    Methods
    -------
    aggregate(df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame
        Aggregate data by counting occurrences.
    _check_contains_missing(df: DataFrame, agg_col: str) -> bool
        Checks if the specified column contains any missing values (nulls or NaNs).
    """

    def __init__(self, include_missing: Optional[bool] = False) -> None:
        """
        Initialise the CountAggregation class with the option to include or exclude missing values.

        Parameters
        ----------
        include_missing : bool, optional
            If True, include missing values in the count (default is False).
            Any falsy value (including None) selects the excluding behaviour.

        Returns
        -------
        None
        """
        super().__init__()
        self.include_missing = include_missing

    def aggregate(self, df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame:
        """Aggregate data by counting occurrences.

        A warning is logged whenever the column contains missing values,
        regardless of ``include_missing``.

        Parameters
        ----------
        df : DataFrame
            The input DataFrame containing the data to be aggregated.
        agg_col : str
            The column in the DataFrame to be aggregated.
        key_cols : List[str]
            The columns in the DataFrame to group by.

        Returns
        -------
        DataFrame
            A DataFrame containing the grouped counts.
        """
        if self._check_contains_missing(df, agg_col):
            logger.warning(
                f"Column {agg_col} contains nulls. "
                "Count aggregation may include null values. "
                "If this is intentional, ignore this warning."
            )

        if self.include_missing:
            # Count every row in the group, missing or not.
            return df.groupby(*key_cols).agg(F.count("*"))

        # Rows are dropped before grouping, so a key combination whose values
        # are all missing does not appear in the result.
        present = df.dropna(subset=[agg_col])
        return present.groupby(*key_cols).agg(F.count(F.col(agg_col)))

    def _check_contains_missing(self, df: DataFrame, agg_col: str) -> bool:
        """
        Check if the specified column contains any missing values (nulls or NaNs).

        The check runs as a single Spark action that stops at the first
        matching row, rather than two full counts. The NaN test is only applied
        to float/double columns (or when the column type cannot be looked up by
        name in ``df.dtypes``), so that columns of other types, such as dates
        or timestamps, are not passed to ``isnan``.

        Parameters
        ----------
        df : DataFrame
            The input DataFrame containing the data to be checked.
        agg_col : str
            The column in the DataFrame to check for missing values.

        Returns
        -------
        bool
            True if the column contains missing values (nulls or NaNs), False otherwise.
        """
        column = F.col(agg_col)
        is_missing = column.isNull()

        # df.dtypes is a list of (column name, type string) pairs. A lookup miss
        # (e.g. a nested or quoted column reference) keeps the NaN test, as the
        # previous implementation always applied it.
        col_type = dict(df.dtypes).get(agg_col)
        if col_type is None or col_type in ("float", "double"):
            is_missing = is_missing | F.isnan(column)

        # limit(1) lets Spark stop once a single missing value has been found.
        return df.filter(is_missing).limit(1).count() > 0

aggregate(df, agg_col, key_cols)

Aggregate data by counting occurrences.

Parameters:

Name Type Description Default
df DataFrame

The input DataFrame containing the data to be aggregated.

required
agg_col str

The column in the DataFrame to be aggregated.

required
key_cols List[str]

The columns in the DataFrame to group by.

required

Returns:

Type Description
DataFrame

A DataFrame containing the aggregated data.

Source code in amee_utils/feature_generator/feature_set/aggregation_strategy.py
def aggregate(self, df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame:
    """Count the values of ``agg_col`` within each ``key_cols`` group.

    A warning is logged whenever the column contains missing values,
    regardless of ``include_missing``.

    Parameters
    ----------
    df : DataFrame
        Input data to aggregate.
    agg_col : str
        Name of the column whose values are counted.
    key_cols : List[str]
        Names of the columns that define the groups.

    Returns
    -------
    DataFrame
        The grouped counts.
    """
    has_missing = self._check_contains_missing(df, agg_col)
    if has_missing:
        logger.warning(
            f"Column {agg_col} contains nulls. "
            "Count aggregation may include null values. "
            "If this is intentional, ignore this warning."
        )

    if self.include_missing:
        # Count every row in the group, missing or not.
        return df.groupby(*key_cols).agg(F.count("*"))

    # Rows with a missing agg_col are removed before grouping.
    present = df.dropna(subset=[agg_col])
    count_expr = F.count(F.col(agg_col))
    return present.groupby(*key_cols).agg(count_expr)

CountIfOneAggregation

Bases: AggregationStrategy

CountIfOneAggregation class for counting the number of values that are equal to one.

Methods:

Name Description
aggregate

Aggregate data by counting occurrences of the number 1.

Source code in amee_utils/feature_generator/feature_set/aggregation_strategy.py
class CountIfOneAggregation(AggregationStrategy):
    """Aggregation strategy that counts, per group, the values equal to one.

    Methods
    -------
    aggregate(df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame
        Keep the rows where the column equals 1 and count them per group.
    """

    def aggregate(self, df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame:
        """Count the rows of each group whose ``agg_col`` equals 1.

        Filtering happens before grouping, so key combinations without any
        value equal to 1 are absent from the result.

        Parameters
        ----------
        df : DataFrame
            Input data to aggregate.
        agg_col : str
            Name of the column compared against 1 and counted.
        key_cols : List[str]
            Names of the columns that define the groups.

        Returns
        -------
        DataFrame
            The grouped counts.
        """
        target = F.col(agg_col)
        ones_only = df.filter(target == 1)
        grouped = ones_only.groupby(*key_cols)
        return grouped.agg(F.count(target))

aggregate(df, agg_col, key_cols)

Aggregate data by counting occurrences of the number 1.

Parameters:

Name Type Description Default
df DataFrame

The input DataFrame containing the data to be aggregated.

required
agg_col str

The column in the DataFrame to be aggregated.

required
key_cols List[str]

The columns in the DataFrame to group by.

required

Returns:

Type Description
DataFrame

A DataFrame containing the aggregated data.

Source code in amee_utils/feature_generator/feature_set/aggregation_strategy.py
def aggregate(self, df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame:
    """Count the rows of each group whose ``agg_col`` equals 1.

    Filtering happens before grouping, so key combinations without any
    value equal to 1 are absent from the result.

    Parameters
    ----------
    df : DataFrame
        Input data to aggregate.
    agg_col : str
        Name of the column compared against 1 and counted.
    key_cols : List[str]
        Names of the columns that define the groups.

    Returns
    -------
    DataFrame
        The grouped counts.
    """
    target = F.col(agg_col)
    ones_only = df.filter(target == 1)
    grouped = ones_only.groupby(*key_cols)
    return grouped.agg(F.count(target))

MeanAggregation

Bases: AggregationStrategy

MeanAggregation class for calculating the mean of values in a specified column.

Methods:

Name Description
aggregate

Aggregate data by calculating the mean.

Source code in amee_utils/feature_generator/feature_set/aggregation_strategy.py
class MeanAggregation(AggregationStrategy):
    """Aggregation strategy that averages a column within each key group.

    Methods
    -------
    aggregate(df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame
        Group by the key columns and compute the mean of the column.
    """

    def aggregate(self, df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame:
        """Compute the mean of ``agg_col`` within each ``key_cols`` group.

        Parameters
        ----------
        df : DataFrame
            Input data to aggregate.
        agg_col : str
            Name of the column to average.
        key_cols : List[str]
            Names of the columns that define the groups.

        Returns
        -------
        DataFrame
            The grouped means.
        """
        grouped = df.groupby(*key_cols)
        mean_expr = F.mean(F.col(agg_col))
        return grouped.agg(mean_expr)

aggregate(df, agg_col, key_cols)

Aggregate data by calculating the mean.

Parameters:

Name Type Description Default
df DataFrame

The input DataFrame containing the data to be aggregated.

required
agg_col str

The column in the DataFrame to be aggregated.

required
key_cols List[str]

The columns in the DataFrame to group by.

required

Returns:

Type Description
DataFrame

A DataFrame containing the aggregated data.

Source code in amee_utils/feature_generator/feature_set/aggregation_strategy.py
def aggregate(self, df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame:
    """Compute the mean of ``agg_col`` within each ``key_cols`` group.

    Parameters
    ----------
    df : DataFrame
        Input data to aggregate.
    agg_col : str
        Name of the column to average.
    key_cols : List[str]
        Names of the columns that define the groups.

    Returns
    -------
    DataFrame
        The grouped means.
    """
    grouped = df.groupby(*key_cols)
    mean_expr = F.mean(F.col(agg_col))
    return grouped.agg(mean_expr)

StddevAggregation

Bases: AggregationStrategy

StddevAggregation class for calculating the standard deviation of values in a specified column.

Methods:

Name Description
aggregate

Aggregate data by calculating the standard deviation.

Source code in amee_utils/feature_generator/feature_set/aggregation_strategy.py
class StddevAggregation(AggregationStrategy):
    """Aggregation strategy that computes a column's standard deviation per key group.

    Methods
    -------
    aggregate(df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame
        Group by the key columns and apply ``F.stddev`` to the column.
    """

    def aggregate(self, df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame:
        """Compute the standard deviation of ``agg_col`` within each ``key_cols`` group.

        Parameters
        ----------
        df : DataFrame
            Input data to aggregate.
        agg_col : str
            Name of the column whose standard deviation is computed.
        key_cols : List[str]
            Names of the columns that define the groups.

        Returns
        -------
        DataFrame
            The grouped standard deviations, as produced by ``F.stddev``.
        """
        grouped = df.groupby(*key_cols)
        stddev_expr = F.stddev(F.col(agg_col))
        return grouped.agg(stddev_expr)

aggregate(df, agg_col, key_cols)

Aggregate data by calculating the standard deviation.

Parameters:

Name Type Description Default
df DataFrame

The input DataFrame containing the data to be aggregated.

required
agg_col str

The column in the DataFrame to be aggregated.

required
key_cols List[str]

The columns in the DataFrame to group by.

required

Returns:

Type Description
DataFrame

A DataFrame containing the aggregated data.

Source code in amee_utils/feature_generator/feature_set/aggregation_strategy.py
def aggregate(self, df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame:
    """Compute the standard deviation of ``agg_col`` within each ``key_cols`` group.

    Parameters
    ----------
    df : DataFrame
        Input data to aggregate.
    agg_col : str
        Name of the column whose standard deviation is computed.
    key_cols : List[str]
        Names of the columns that define the groups.

    Returns
    -------
    DataFrame
        The grouped standard deviations, as produced by ``F.stddev``.
    """
    grouped = df.groupby(*key_cols)
    stddev_expr = F.stddev(F.col(agg_col))
    return grouped.agg(stddev_expr)

SumAggregation

Bases: AggregationStrategy

SumAggregation class for summing values in a specified column.

Methods:

Name Description
aggregate

Aggregate data by summing the values.

Source code in amee_utils/feature_generator/feature_set/aggregation_strategy.py
class SumAggregation(AggregationStrategy):
    """Aggregation strategy that totals a column within each key group.

    Methods
    -------
    aggregate(df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame
        Group by the key columns and sum the column.
    """

    def aggregate(self, df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame:
        """Sum ``agg_col`` within each ``key_cols`` group.

        Parameters
        ----------
        df : DataFrame
            Input data to aggregate.
        agg_col : str
            Name of the column to sum.
        key_cols : List[str]
            Names of the columns that define the groups.

        Returns
        -------
        DataFrame
            The grouped sums.
        """
        grouped = df.groupby(*key_cols)
        sum_expr = F.sum(F.col(agg_col))
        return grouped.agg(sum_expr)

aggregate(df, agg_col, key_cols)

Aggregate data by summing the values.

Parameters:

Name Type Description Default
df DataFrame

The input DataFrame containing the data to be aggregated.

required
agg_col str

The column in the DataFrame to be aggregated.

required
key_cols List[str]

The columns in the DataFrame to group by.

required

Returns:

Type Description
DataFrame

A DataFrame containing the aggregated data.

Source code in amee_utils/feature_generator/feature_set/aggregation_strategy.py
def aggregate(self, df: DataFrame, agg_col: str, key_cols: List[str]) -> DataFrame:
    """Sum ``agg_col`` within each ``key_cols`` group.

    Parameters
    ----------
    df : DataFrame
        Input data to aggregate.
    agg_col : str
        Name of the column to sum.
    key_cols : List[str]
        Names of the columns that define the groups.

    Returns
    -------
    DataFrame
        The grouped sums.
    """
    grouped = df.groupby(*key_cols)
    sum_expr = F.sum(F.col(agg_col))
    return grouped.agg(sum_expr)