Source code for ghonn_models_pytorch.core.ghonn

"""Defines the GHONN (Gated Higher Order Neural Network) model architecture."""

from __future__ import annotations

from typing import Any, Callable

import torch
from torch import Tensor, nn

from .ghonu import GHONU
from .utils import normalize_list_to_size

__version__ = "0.0.1"


class GHONN(nn.Module):
    """GHONN (Gated Higher Order Neural Network) model.

    This class implements a neural network composed of multiple GHONUs (Gated Higher Order
    Neural Units) in a single layer. Each unit applies polynomial transformations to the input
    data and uses a gating mechanism to modulate the output. The model supports flexible
    configurations, including the number of units, polynomial orders for both the predictor
    and gate, the activation functions of predictors and gates, and different output types.
    """

    def __init__(  # noqa: PLR0913
        self,
        in_features: int,
        out_features: int,
        layer_size: int,
        predictor_orders: list[int],
        gate_orders: list[int],
        *,
        predictor_activations: list[str] | tuple[str] | str = "identity",
        gate_activations: list[str] | tuple[str] | str = "sigmoid",
        output_type: str = "linear",
        **kwargs: Any,
    ) -> None:
        """Initialize the Gated Higher Order Neural Network model.

        Args:
            in_features (int): Number of input features.
            out_features (int): Number of output features.
            layer_size (int): Number of GHONUs placed in the single layer of the model.
            predictor_orders (list[int]): Predictor orders, normalized to one per unit.
            gate_orders (list[int]): Gate orders, normalized to one per unit.
            predictor_activations (list[str] | tuple[str] | str, optional): Activation
                function name(s) for the predictors. A single string is wrapped into a
                one-element tuple before normalization. Defaults to "identity".
            gate_activations (list[str] | tuple[str] | str, optional): Activation function
                name(s) for the gates. A single string is wrapped into a one-element tuple
                before normalization. Defaults to "sigmoid".
            output_type (str, optional): Type of output head; one of "sum", "linear" or
                "raw". Defaults to "linear".
            **kwargs: Additional keyword arguments passed to every GHONU.

        Attributes:
            in_features (int): Number of input features.
            out_features (int): Number of output features.
            layer_size (int): Number of GHONUs in the layer.
            predictor_orders (list[int]): Normalized list of predictor orders.
            gate_orders (list[int]): Normalized list of gate orders.
            predictor_activations (list[str]): Normalized list of predictor activations.
            gate_activations (list[str]): Normalized list of gate activations.
            output_type (str): Type of output head.
            ghonus (nn.ModuleList): The GHONU units.
            head (Callable): Output head (an ``nn.Linear`` for "linear", a plain function
                otherwise).

        Raises:
            ValueError: If the configuration is rejected by ``_validate_setup``.
        """
        super().__init__()

        # Main model parameters
        self.in_features = in_features
        self.out_features = out_features
        self.layer_size = layer_size
        self.predictor_orders = normalize_list_to_size(
            self.layer_size, predictor_orders, description="predictor"
        )
        self.gate_orders = normalize_list_to_size(self.layer_size, gate_orders, description="gate")

        # Ensure that even a single str value is passed on as a sequence; otherwise the
        # string itself would be treated as a sequence of characters.
        if isinstance(predictor_activations, str):
            predictor_activations = (predictor_activations,)
        self.predictor_activations = normalize_list_to_size(
            self.layer_size, predictor_activations, description="predictor activations"
        )
        if isinstance(gate_activations, str):
            gate_activations = (gate_activations,)
        self.gate_activations = normalize_list_to_size(
            self.layer_size, gate_activations, description="gate activations"
        )

        # Optional params
        self.output_type = output_type

        self._validate_setup()

        # Initialize the GHONUs, one per (order, activation) configuration
        self.ghonus = nn.ModuleList(
            [
                GHONU(
                    in_features,
                    p,
                    g,
                    predictor_activation=pa,
                    gate_activation=ga,
                    **kwargs,
                )
                for p, g, pa, ga in zip(
                    self.predictor_orders,
                    self.gate_orders,
                    self.predictor_activations,
                    self.gate_activations,
                )
            ]
        )
        self.head = self._get_head()

    def __repr__(self) -> str:
        """Return a string representation of the GHONN model."""
        cls = self.__class__.__name__

        # Describe head
        if self.output_type == "sum":
            head_desc = "SummedHonuOutputs"
        elif self.output_type == "linear":
            head_desc = repr(self.head)
        elif self.output_type == "raw":
            head_desc = "RawHonuOutputs"
        else:
            head_desc = "UnknownHead"

        # dict.fromkeys de-duplicates the activation names while keeping their order
        gate_activations = tuple(dict.fromkeys(self.gate_activations))
        predictor_activations = tuple(dict.fromkeys(self.predictor_activations))

        # Describe the model
        lines = [
            f"{cls}(",
            f"  in_features={self.in_features},",
            f"  out_features={self.out_features},",
            f"  layer_size={self.layer_size},",
            f"  output_type={self.output_type},",
            f"  predictor_activation_functions={predictor_activations},",
            f"  gate_activation_functions={gate_activations},",
            f"  head={head_desc},",
            f"  ghonus={self.ghonus}",
        ]
        return "\n".join(lines) + "\n" + ")"

    @property
    def predictors(self) -> nn.ModuleList:
        """Get the ModuleList of predictor HONUs in the GHONN model.

        A new ModuleList is built on every access.

        Returns:
            nn.ModuleList: ModuleList of predictor HONUs.
        """
        return nn.ModuleList([ghonu.predictor for ghonu in self.ghonus])

    @property
    def gates(self) -> nn.ModuleList:
        """Get the ModuleList of gate HONUs in the GHONN model.

        A new ModuleList is built on every access.

        Returns:
            nn.ModuleList: ModuleList of gate HONUs.
        """
        return nn.ModuleList([ghonu.gate for ghonu in self.ghonus])

    def _validate_setup(self) -> None:
        """Validate the configuration of the model.

        This method checks the following conditions:
            - The `output_type` must be one of the supported types: "sum", "linear", "raw".
            - The `layer_size` must be greater than 0.
            - The `out_features` must be greater than 0.
            - If `output_type` is "sum", `out_features` must be exactly 1.
            - If `output_type` is "raw", `out_features` must match `layer_size`.

        Raises:
            ValueError: If any of the above conditions are not met.
        """
        supported_output_types = ["sum", "linear", "raw"]
        if self.output_type not in supported_output_types:
            msg = (
                f"Invalid output type: {self.output_type}. Must be one of {supported_output_types}."
            )
            raise ValueError(msg)
        if self.layer_size <= 0:
            msg = f"Invalid layer_size: {self.layer_size}. Must be > 0."
            raise ValueError(msg)
        if self.out_features <= 0:
            msg = f"Invalid out_features: {self.out_features}. Must be > 0."
            raise ValueError(msg)
        if self.output_type == "sum" and self.out_features != 1:
            msg = f"Invalid out_features: {self.out_features}. Must be 1 when output_type is 'sum'."
            raise ValueError(msg)
        if self.output_type == "raw" and self.out_features != self.layer_size:
            msg = (
                f"Invalid out_features: {self.out_features}. Must be {self.layer_size} when "
                "output_type is 'raw'."
            )
            raise ValueError(msg)

    @staticmethod
    def _sum_head(x: Tensor) -> Tensor:
        """Head for ``output_type="sum"``: sum the input over its last dimension."""
        return x.sum(dim=-1)

    @staticmethod
    def _raw_head(x: Tensor) -> Tensor:
        """Head for ``output_type="raw"``: return the input tensor unchanged."""
        return x

    def _get_head(self) -> Callable:
        """Construct and return the output head based on the specified output type.

        Supported `output_type` values:
            - "sum": A function that sums the input tensor along the last dimension.
            - "linear": A fully connected layer (`nn.Linear`) with `layer_size` input
              features and `out_features` output features.
            - "raw": A function that returns the input tensor unchanged.

        The "sum" and "raw" heads are static methods rather than lambdas so that the model
        stays picklable (local lambdas cannot be pickled, which breaks pickling of the
        whole model object).

        Returns:
            Callable: A function or module that processes the stacked unit outputs.

        Raises:
            ValueError: If the specified `output_type` is not one of the supported types.
        """
        if self.output_type == "sum":
            return self._sum_head
        if self.output_type == "linear":
            return nn.Linear(self.layer_size, self.out_features)
        if self.output_type == "raw":
            return self._raw_head
        msg = f"Invalid output type: {self.output_type}"
        raise ValueError(msg)

    def forward(
        self, x: Tensor, *, return_elements: bool = False
    ) -> Tensor | tuple[Tensor, tuple[Tensor, Tensor]]:
        """Perform the forward pass of the GHONN model.

        Every unit is evaluated on the same input and the per-unit results are stacked
        along a new last dimension before being passed to the head.

        Args:
            x (Tensor): Input tensor. Its first dimension is used as the leading size when
                flattening for the "linear" and "raw" output types.
            return_elements (bool, optional): If True, return the individual predictor and
                gate outputs along with the network output. Defaults to False.

        Returns:
            Tensor: Model output if return_elements is False.
            tuple: (output, (predictor_outputs, gate_outputs)) if return_elements is True.
        """
        # 'linear' and 'raw' heads consume a 2-D (x.size(0), -1) tensor
        flatten = self.output_type in ["linear", "raw"]

        if return_elements:
            # Each unit yields an (output, predictor_output, gate_output) triple
            outs: list[tuple[Tensor, Tensor, Tensor]] = [
                ghonu(x, return_elements=True) for ghonu in self.ghonus
            ]
            # Transpose the list of triples and stack each group along the last dimension
            outputs, predictor_outputs, gate_outputs = (torch.stack(t, dim=-1) for t in zip(*outs))
            if flatten:
                outputs = outputs.view(x.size(0), -1)
                predictor_outputs = predictor_outputs.view(x.size(0), -1)
                gate_outputs = gate_outputs.view(x.size(0), -1)
            final = self.head(outputs)
            return final, (predictor_outputs, gate_outputs)

        outputs = torch.stack([ghonu(x) for ghonu in self.ghonus], dim=-1)
        if flatten:
            outputs = outputs.view(x.size(0), -1)
        return self.head(outputs)
if __name__ == "__main__":
    import pathlib

    # This module only provides definitions; executing it directly is refused by
    # raising an error that names the file.
    script_name = pathlib.Path(__file__).name
    raise OSError(f"The {script_name} is not meant to be run as a script.")