# Source code for ghonn_models_pytorch.core.ghonu

"""Defines the GHONU (Gated Higher-Order Neural Unit) model."""

from __future__ import annotations

from typing import Any

from torch import Tensor, nn

from .honu import HONU

__version__ = "0.0.1"


[docs] class GHONU(nn.Module): """GHONU (Gated Higher-Order Neural Unit) model. This model combines two Higher-Order Neural Units (HONUs): a predictor HONU and a gate HONU. The gate HONU modulates the output of the predictor HONU using a specified activation function. Methods: __repr__: Returns a string representation of the GHONU model. forward: Performs the forward pass of the GHONU model. """ def __init__( self, in_features: int, predictor_order: int, gate_order: int, *, predictor_activation: str = "identity", gate_activation: str = "sigmoid", **kwargs: dict[str, Any], ) -> None: """Initialize the GHONU (Gated Higher-Order Neural Unit) model. Args: in_features (int): The number of input features for the model. predictor_order (int): The order of the predictor HONU. gate_order (int): The order of the gate HONU. predictor_activation (str, optional): The activation function to use for the predictor. Defaults to "identity". Must be a valid function in `torch.nn.functional`. gate_activation (str, optional): The activation function to use for the gate. Defaults to "sigmoid". Must be a valid activation function in `torch.nn.functional`. **kwargs: Additional keyword arguments passed to the HONU modules (e.g., weight_divisor, bias). Attributes: in_features (int): The number of input features for the model. predictor_order (int): The order of the predictor HONU. gate_order (int): The order of the gate HONU. _gate_activation (str): The activation function used for the gate. _predictor_activation (str): The activation function used for the predictor. predictor (HONU): The predictor HONU instance. gate (HONU): The gate HONU instance. 
""" super().__init__() # Main model parameters self.in_features = in_features self.predictor_order = predictor_order self.gate_order = gate_order # Optional params self._gate_activation = gate_activation self._predictor_activation = predictor_activation # Initialize predictor and gate HONUs self.predictor = HONU( in_features, predictor_order, activation=self._predictor_activation, **kwargs ) self.gate = HONU(in_features, gate_order, activation=self._gate_activation, **kwargs)
[docs] def __repr__(self) -> str: """Return a string representation of the GHONU model.""" lines = [ f"{self.__class__.__name__}(", f" in_features={self.in_features}, ", f" predictor={self.predictor!r},", f" gate={self.gate!r}", ] return "\n".join(lines) + "\n" + ")"
[docs] def forward( self, x: Tensor, *, return_elements: bool = False ) -> Tensor | tuple[Tensor, Tensor, Tensor]: """Perform the forward pass of the GHONU model. Args: x (Tensor): Input tensor. return_elements (bool, optional): If True, return the individual outputs of the predictor and gate along with the final output. Defaults to False. Returns: Tensor: The final output of the model. If `return_elements` is True, returns a tuple containing the final output, predictor output, and gate output. """ # Get the outputs of the predictor and gate HONUs predictor_output = self.predictor(x) gate_output = self.gate(x) # Apply the gate to the predictor output output = predictor_output * gate_output if return_elements: return output, predictor_output, gate_output return output
if __name__ == "__main__":
    # This module only provides the GHONU definition; executing it directly
    # is rejected with an error that names the offending file.
    from pathlib import Path

    script_name = Path(__file__).name
    raise OSError(f"The {script_name} is not meant to be run as a script.")