Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/torchjd/_linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from ._generalized_gramian import flatten, movedim, reshape
from ._gramian import compute_gramian, normalize, regularize
from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor
from ._structure import Structure, extract_structure

__all__ = [
"extract_structure",
"Structure",
"compute_gramian",
"normalize",
"regularize",
Expand Down
16 changes: 16 additions & 0 deletions src/torchjd/_linalg/_structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from dataclasses import dataclass

import torch

from torchjd._linalg import Matrix


@dataclass
class Structure:
    """Lightweight metadata describing a matrix: its row count, device and dtype.

    Built by :func:`extract_structure`; lets a Weighting compute weights from a
    matrix's layout alone, without access to the matrix values.
    """

    m: int  # number of rows of the matrix this structure was extracted from
    device: torch.device  # torch device the matrix is stored on
    dtype: torch.dtype  # element dtype of the matrix


def extract_structure(matrix: Matrix) -> Structure:
    """Return the :class:`Structure` (row count, device and dtype) of ``matrix``."""
    n_rows = matrix.shape[0]
    return Structure(m=n_rows, device=matrix.device, dtype=matrix.dtype)
31 changes: 14 additions & 17 deletions src/torchjd/aggregation/_constant.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
from torch import Tensor

from torchjd._linalg import Matrix
from torchjd.aggregation._weighting_bases import FromNothingWeighting

from ._aggregator_bases import WeightedAggregator
from ._utils.str import vector_to_str
from ._weighting_bases import Weighting


class ConstantWeighting(Weighting[Matrix]):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined
weights.

:param weights: The weights to return at each call.
"""

class _ConstantWeighting(Weighting[None]):
def __init__(self, weights: Tensor) -> None:
if weights.dim() != 1:
raise ValueError(
Expand All @@ -25,16 +18,20 @@ def __init__(self, weights: Tensor) -> None:
super().__init__()
self.weights = weights

def forward(self, matrix: Tensor, /) -> Tensor:
self._check_matrix_shape(matrix)
def forward(self, _: None, /) -> Tensor:
return self.weights

def _check_matrix_shape(self, matrix: Tensor) -> None:
if matrix.shape[0] != len(self.weights):
raise ValueError(
f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified "
f"weights). Found `matrix` with {matrix.shape[0]} rows.",
)

class ConstantWeighting(FromNothingWeighting):
    """
    :class:`~torchjd.aggregation._weighting_bases.Weighting` that always returns the same
    constant, pre-determined weights, ignoring the input matrix.

    :param weights: The weights to return at each call.
    """

    def __init__(self, weights: Tensor) -> None:
        # The inner weighting validates the weights and returns them unchanged on each call.
        inner = _ConstantWeighting(weights)
        super().__init__(inner)


class Constant(WeightedAggregator):
Expand Down
22 changes: 14 additions & 8 deletions src/torchjd/aggregation/_mean.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
import torch
from torch import Tensor

from torchjd._linalg import Matrix
from torchjd._linalg import Structure
from torchjd.aggregation._weighting_bases import FromStructureWeighting

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Weighting


class MeanWeighting(Weighting[Matrix]):
class _MeanWeighting(Weighting[Structure]):
    # Internal weighting: produces the uniform vector [1/m, ..., 1/m] from the
    # matrix structure alone, without looking at any matrix values.
    def forward(self, structure: Structure, /) -> Tensor:
        """Return a length-``structure.m`` vector with every entry equal to ``1 / m``."""
        m = structure.m
        return torch.full(
            size=[m],
            fill_value=1 / m,
            device=structure.device,
            dtype=structure.dtype,
        )


class MeanWeighting(FromStructureWeighting):
    r"""
    :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the uniform weights
    :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in
    \mathbb{R}^m`.
    """

    def __init__(self) -> None:
        # Delegate the weight computation to the structure-based inner weighting;
        # the weights depend only on the number of rows m, never on matrix values.
        super().__init__(_MeanWeighting())


class Mean(WeightedAggregator):
Expand Down
18 changes: 12 additions & 6 deletions src/torchjd/aggregation/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,28 @@
from torch import Tensor
from torch.nn import functional as F

from torchjd._linalg import Matrix
from torchjd._linalg import Structure
from torchjd.aggregation._weighting_bases import FromStructureWeighting

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Weighting


class RandomWeighting(Weighting[Matrix]):
class _RandomWeighting(Weighting[Structure]):
    # Internal weighting: draws m standard-normal samples and applies a softmax,
    # yielding positive weights that sum to 1.
    def forward(self, structure: Structure, /) -> Tensor:
        """Return ``structure.m`` softmax-normalized random weights."""
        noise = torch.randn(structure.m, device=structure.device, dtype=structure.dtype)
        return F.softmax(noise, dim=-1)


class RandomWeighting(FromStructureWeighting):
    """
    :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights
    at each call.
    """

    def __init__(self) -> None:
        # The random weights only depend on the matrix structure (row count,
        # device, dtype), never on the matrix values.
        super().__init__(_RandomWeighting())


class Random(WeightedAggregator):
Expand Down
18 changes: 11 additions & 7 deletions src/torchjd/aggregation/_sum.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
import torch
from torch import Tensor

from torchjd._linalg import Matrix
from torchjd._linalg import Structure
from torchjd.aggregation._weighting_bases import FromStructureWeighting

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Weighting


class SumWeighting(Weighting[Matrix]):
class _SumWeighting(Weighting[Structure]):
    # Internal weighting: all-ones weights, so the weighted aggregation sums the rows.
    def forward(self, structure: Structure, /) -> Tensor:
        """Return a length-``structure.m`` vector of ones."""
        return torch.ones(structure.m, device=structure.device, dtype=structure.dtype)


class SumWeighting(FromStructureWeighting):
    r"""
    :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
    :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`.
    """

    def __init__(self) -> None:
        # The all-ones weights depend only on the row count m, so the inner
        # weighting works from the matrix structure alone.
        super().__init__(_SumWeighting())


class Sum(WeightedAggregator):
Expand Down
38 changes: 31 additions & 7 deletions src/torchjd/aggregation/_weighting_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

from torch import Tensor, nn

from torchjd._linalg import PSDTensor, is_psd_tensor
from torchjd._linalg import Matrix, PSDTensor, Structure, extract_structure, is_psd_tensor

_T = TypeVar("_T", contravariant=True, bound=Tensor)
_FnInputT = TypeVar("_FnInputT", bound=Tensor)
_FnOutputT = TypeVar("_FnOutputT", bound=Tensor)
_T = TypeVar("_T", contravariant=True)
_FnInputT = TypeVar("_FnInputT")
_FnOutputT = TypeVar("_FnOutputT")


class Weighting(nn.Module, ABC, Generic[_T]):
Expand All @@ -27,11 +27,9 @@ def __init__(self) -> None:
def forward(self, stat: _T, /) -> Tensor:
"""Computes the vector of weights from the input stat."""

def __call__(self, stat: Tensor, /) -> Tensor:
def __call__(self, stat: object, /) -> Tensor:
Copy link
Copy Markdown
Contributor Author

@ValerianRey ValerianRey Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is wrong. This type should be _T. Before, Tensor was the public equivalent of both Matrix and PSDMatrix, so it worked, but it's quite wrong.

If we do type this as _T, we need Matrix and PSDMatrix to become public, and I think we need users to cast to Matrix and PSDMatrix in order to use call.

Finding a workaround would be huge (this problem + the issue of having a gap between public and private type, solved at once).

"""Computes the vector of weights from the input stat and applies all registered hooks."""

# The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of
# stat to be Tensor.
return super().__call__(stat)

def _compose(self, fn: Callable[[_FnInputT], _T]) -> Weighting[_FnInputT]:
Expand All @@ -55,6 +53,32 @@ def forward(self, stat: _T, /) -> Tensor:
return self.weighting(self.fn(stat))


class FromStructureWeighting(_Composition[Matrix]):
    """
    Weighting that reduces the input matrix to its :class:`Structure` before handing it to an
    inner Weighting.

    :param structure_weighting: The object responsible for extracting the vector of weights from the
        structure.
    """

    def __init__(self, structure_weighting: Weighting[Structure]) -> None:
        # Compose: matrix -> extract_structure(matrix) -> structure_weighting(structure).
        super().__init__(structure_weighting, extract_structure)
        # Keep a direct reference to the wrapped weighting.
        self.structure_weighting = structure_weighting


class FromNothingWeighting(_Composition[Matrix]):
    """
    Weighting that discards the input matrix entirely and calls an inner Weighting with ``None``.

    :param none_weighting: The object responsible for extracting the vector of weights from nothing.
    """

    def __init__(self, none_weighting: Weighting[None]) -> None:
        # Compose with a function mapping any matrix to None: the inner weighting
        # never sees the matrix.
        super().__init__(none_weighting, lambda _: None)
        # Keep a direct reference to the wrapped weighting.
        self.none_weighting = none_weighting


class GeneralizedWeighting(nn.Module, ABC):
r"""
Abstract base class for all weightings that operate on generalized Gramians. It has the role of
Expand Down
23 changes: 0 additions & 23 deletions tests/unit/aggregation/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,29 +63,6 @@ def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionCon
_ = Constant(weights=weights)


@mark.parametrize(
["weights_shape", "n_rows", "expectation"],
[
([0], 0, does_not_raise()),
([1], 1, does_not_raise()),
([5], 5, does_not_raise()),
([0], 1, raises(ValueError)),
([1], 0, raises(ValueError)),
([4], 5, raises(ValueError)),
([5], 4, raises(ValueError)),
],
)
def test_matrix_shape_check(
weights_shape: list[int], n_rows: int, expectation: ExceptionContext
) -> None:
matrix = ones_([n_rows, 5])
weights = ones_(weights_shape)
aggregator = Constant(weights)

with expectation:
_ = aggregator(matrix)


def test_representations() -> None:
A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu"))
assert repr(A) == "Constant(weights=tensor([1., 2.]))"
Expand Down
Loading