Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/docs/aggregation/aligned_mtl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ Aligned-MTL
===========

.. autoclass:: torchjd.aggregation.AlignedMTL
:members: __call__

.. autoclass:: torchjd.aggregation.AlignedMTLWeighting
:members: __call__
2 changes: 2 additions & 0 deletions docs/source/docs/aggregation/cagrad.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ CAGrad
======

.. autoclass:: torchjd.aggregation.CAGrad
:members: __call__

.. autoclass:: torchjd.aggregation.CAGradWeighting
:members: __call__
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ ConFIG
======

.. autoclass:: torchjd.aggregation.ConFIG
:members: __call__
2 changes: 2 additions & 0 deletions docs/source/docs/aggregation/constant.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ Constant
========

.. autoclass:: torchjd.aggregation.Constant
:members: __call__

.. autoclass:: torchjd.aggregation.ConstantWeighting
:members: __call__
2 changes: 2 additions & 0 deletions docs/source/docs/aggregation/dualproj.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ DualProj
========

.. autoclass:: torchjd.aggregation.DualProj
:members: __call__

.. autoclass:: torchjd.aggregation.DualProjWeighting
:members: __call__
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/flattening.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ Flattening
==========

.. autoclass:: torchjd.aggregation.Flattening
:members: __call__
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/graddrop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ GradDrop
========

.. autoclass:: torchjd.aggregation.GradDrop
:members: __call__
4 changes: 2 additions & 2 deletions docs/source/docs/aggregation/gradvac.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ GradVac
=======

.. autoclass:: torchjd.aggregation.GradVac
:members: reset
:members: __call__, reset

.. autoclass:: torchjd.aggregation.GradVacWeighting
:members: reset
:members: __call__, reset
2 changes: 2 additions & 0 deletions docs/source/docs/aggregation/imtl_g.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ IMTL-G
======

.. autoclass:: torchjd.aggregation.IMTLG
:members: __call__

.. autoclass:: torchjd.aggregation.IMTLGWeighting
:members: __call__
3 changes: 3 additions & 0 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ Abstract base classes
---------------------

.. autoclass:: torchjd.aggregation.Aggregator
:members: __call__

.. autoclass:: torchjd.aggregation.Weighting
:members: __call__

.. autoclass:: torchjd.aggregation.GeneralizedWeighting
:members: __call__

.. autoclass:: torchjd.aggregation.Stateful
:members: reset
Expand Down
2 changes: 2 additions & 0 deletions docs/source/docs/aggregation/krum.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ Krum
====

.. autoclass:: torchjd.aggregation.Krum
:members: __call__

.. autoclass:: torchjd.aggregation.KrumWeighting
:members: __call__
2 changes: 2 additions & 0 deletions docs/source/docs/aggregation/mean.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ Mean
====

.. autoclass:: torchjd.aggregation.Mean
:members: __call__

.. autoclass:: torchjd.aggregation.MeanWeighting
:members: __call__
2 changes: 2 additions & 0 deletions docs/source/docs/aggregation/mgda.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ MGDA
====

.. autoclass:: torchjd.aggregation.MGDA
:members: __call__

.. autoclass:: torchjd.aggregation.MGDAWeighting
:members: __call__
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/nash_mtl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ Nash-MTL
========

.. autoclass:: torchjd.aggregation.NashMTL
:members: reset
:members: __call__, reset
2 changes: 2 additions & 0 deletions docs/source/docs/aggregation/pcgrad.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ PCGrad
======

.. autoclass:: torchjd.aggregation.PCGrad
:members: __call__

.. autoclass:: torchjd.aggregation.PCGradWeighting
:members: __call__
2 changes: 2 additions & 0 deletions docs/source/docs/aggregation/random.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ Random
======

.. autoclass:: torchjd.aggregation.Random
:members: __call__

.. autoclass:: torchjd.aggregation.RandomWeighting
:members: __call__
2 changes: 2 additions & 0 deletions docs/source/docs/aggregation/sum.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ Sum
===

.. autoclass:: torchjd.aggregation.Sum
:members: __call__

.. autoclass:: torchjd.aggregation.SumWeighting
:members: __call__
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/trimmed_mean.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ Trimmed Mean
============

.. autoclass:: torchjd.aggregation.TrimmedMean
:members: __call__
2 changes: 2 additions & 0 deletions docs/source/docs/aggregation/upgrad.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ UPGrad
======

.. autoclass:: torchjd.aggregation.UPGrad
:members: __call__

.. autoclass:: torchjd.aggregation.UPGradWeighting
:members: __call__
6 changes: 5 additions & 1 deletion src/torchjd/aggregation/_aggregator_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def forward(self, matrix: Matrix, /) -> Tensor:
"""Computes the aggregation from the input matrix."""

def __call__(self, matrix: Tensor, /) -> Tensor:
"""Computes the aggregation from the input matrix and applies all registered hooks."""
"""
Computes the aggregation from the input matrix and applies all registered hooks.

:param matrix: The Jacobian to aggregate.
"""
Aggregator._check_is_matrix(matrix)
return super().__call__(matrix)

Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import Weighting
from ._weighting_bases import GramianWeighting

SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"]


class AlignedMTLWeighting(Weighting[PSDMatrix]):
class AlignedMTLWeighting(GramianWeighting):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.AlignedMTL`.
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torchjd._linalg import PSDMatrix

from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import Weighting
from ._weighting_bases import GramianWeighting

check_dependencies_are_installed(["cvxpy", "clarabel"])

Expand All @@ -18,7 +18,7 @@
from ._utils.non_differentiable import raise_non_differentiable_error


class CAGradWeighting(Weighting[PSDMatrix]):
class CAGradWeighting(GramianWeighting):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.CAGrad`.
Expand Down
6 changes: 2 additions & 4 deletions src/torchjd/aggregation/_constant.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from torch import Tensor

from torchjd._linalg import Matrix

from ._aggregator_bases import WeightedAggregator
from ._utils.str import vector_to_str
from ._weighting_bases import Weighting
from ._weighting_bases import MatrixWeighting


class ConstantWeighting(Weighting[Matrix]):
class ConstantWeighting(MatrixWeighting):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined
weights.
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import Weighting
from ._weighting_bases import GramianWeighting


class DualProjWeighting(Weighting[PSDMatrix]):
class DualProjWeighting(GramianWeighting):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.DualProj`.
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_gradvac.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import Weighting
from ._weighting_bases import GramianWeighting


class GradVacWeighting(Weighting[PSDMatrix], Stateful):
class GradVacWeighting(GramianWeighting, Stateful):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_imtl_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import Weighting
from ._weighting_bases import GramianWeighting


class IMTLGWeighting(Weighting[PSDMatrix]):
class IMTLGWeighting(GramianWeighting):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.IMTLG`.
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_krum.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._weighting_bases import Weighting
from ._weighting_bases import GramianWeighting


class KrumWeighting(Weighting[PSDMatrix]):
class KrumWeighting(GramianWeighting):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.Krum`.
Expand Down
6 changes: 2 additions & 4 deletions src/torchjd/aggregation/_mean.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import torch
from torch import Tensor

from torchjd._linalg import Matrix

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Weighting
from ._weighting_bases import MatrixWeighting


class MeanWeighting(Weighting[Matrix]):
class MeanWeighting(MatrixWeighting):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_mgda.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._weighting_bases import Weighting
from ._weighting_bases import GramianWeighting


class MGDAWeighting(Weighting[PSDMatrix]):
class MGDAWeighting(GramianWeighting):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.MGDA`.
Expand Down
5 changes: 2 additions & 3 deletions src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon.
# See NOTICES for the full license text.

from torchjd._linalg import Matrix
from torchjd.aggregation._mixins import Stateful

from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import Weighting
from ._weighting_bases import MatrixWeighting

check_dependencies_are_installed(["cvxpy", "ecos"])

Expand All @@ -21,7 +20,7 @@
from ._utils.non_differentiable import raise_non_differentiable_error


class _NashMTLWeighting(Weighting[Matrix], Stateful):
class _NashMTLWeighting(MatrixWeighting, Stateful):
"""
:class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` that
extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_pcgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import Weighting
from ._weighting_bases import GramianWeighting


class PCGradWeighting(Weighting[PSDMatrix]):
class PCGradWeighting(GramianWeighting):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.PCGrad`.
Expand Down
6 changes: 2 additions & 4 deletions src/torchjd/aggregation/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
from torch import Tensor
from torch.nn import functional as F

from torchjd._linalg import Matrix

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Weighting
from ._weighting_bases import MatrixWeighting


class RandomWeighting(Weighting[Matrix]):
class RandomWeighting(MatrixWeighting):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights
at each call.
Expand Down
6 changes: 2 additions & 4 deletions src/torchjd/aggregation/_sum.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import torch
from torch import Tensor

from torchjd._linalg import Matrix

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Weighting
from ._weighting_bases import MatrixWeighting


class SumWeighting(Weighting[Matrix]):
class SumWeighting(MatrixWeighting):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
:math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`.
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import Weighting
from ._weighting_bases import GramianWeighting


class UPGradWeighting(Weighting[PSDMatrix]):
class UPGradWeighting(GramianWeighting):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.UPGrad`.
Expand Down
Loading
Loading