From 4fe4918eb9de60765d42f54d2dd8c1b18d99ba6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 16 Apr 2026 18:01:53 +0200 Subject: [PATCH 1/5] Add __call__ to :members: in aggregator and weighting autoclasses --- docs/source/docs/aggregation/aligned_mtl.rst | 2 ++ docs/source/docs/aggregation/cagrad.rst | 2 ++ docs/source/docs/aggregation/config.rst | 1 + docs/source/docs/aggregation/constant.rst | 2 ++ docs/source/docs/aggregation/dualproj.rst | 2 ++ docs/source/docs/aggregation/flattening.rst | 1 + docs/source/docs/aggregation/graddrop.rst | 1 + docs/source/docs/aggregation/gradvac.rst | 4 ++-- docs/source/docs/aggregation/imtl_g.rst | 2 ++ docs/source/docs/aggregation/index.rst | 3 +++ docs/source/docs/aggregation/krum.rst | 2 ++ docs/source/docs/aggregation/mean.rst | 2 ++ docs/source/docs/aggregation/mgda.rst | 2 ++ docs/source/docs/aggregation/nash_mtl.rst | 2 +- docs/source/docs/aggregation/pcgrad.rst | 2 ++ docs/source/docs/aggregation/random.rst | 2 ++ docs/source/docs/aggregation/sum.rst | 2 ++ docs/source/docs/aggregation/trimmed_mean.rst | 1 + docs/source/docs/aggregation/upgrad.rst | 2 ++ 19 files changed, 34 insertions(+), 3 deletions(-) diff --git a/docs/source/docs/aggregation/aligned_mtl.rst b/docs/source/docs/aggregation/aligned_mtl.rst index 2b9d818c..527d1370 100644 --- a/docs/source/docs/aggregation/aligned_mtl.rst +++ b/docs/source/docs/aggregation/aligned_mtl.rst @@ -4,5 +4,7 @@ Aligned-MTL =========== .. autoclass:: torchjd.aggregation.AlignedMTL + :members: __call__ .. autoclass:: torchjd.aggregation.AlignedMTLWeighting + :members: __call__ diff --git a/docs/source/docs/aggregation/cagrad.rst b/docs/source/docs/aggregation/cagrad.rst index 8a8976b3..337c0661 100644 --- a/docs/source/docs/aggregation/cagrad.rst +++ b/docs/source/docs/aggregation/cagrad.rst @@ -4,5 +4,7 @@ CAGrad ====== .. autoclass:: torchjd.aggregation.CAGrad + :members: __call__ .. 
autoclass:: torchjd.aggregation.CAGradWeighting + :members: __call__ diff --git a/docs/source/docs/aggregation/config.rst b/docs/source/docs/aggregation/config.rst index fba4b384..1c327312 100644 --- a/docs/source/docs/aggregation/config.rst +++ b/docs/source/docs/aggregation/config.rst @@ -4,3 +4,4 @@ ConFIG ====== .. autoclass:: torchjd.aggregation.ConFIG + :members: __call__ diff --git a/docs/source/docs/aggregation/constant.rst b/docs/source/docs/aggregation/constant.rst index 1ad4fb29..2033dd0d 100644 --- a/docs/source/docs/aggregation/constant.rst +++ b/docs/source/docs/aggregation/constant.rst @@ -4,5 +4,7 @@ Constant ======== .. autoclass:: torchjd.aggregation.Constant + :members: __call__ .. autoclass:: torchjd.aggregation.ConstantWeighting + :members: __call__ diff --git a/docs/source/docs/aggregation/dualproj.rst b/docs/source/docs/aggregation/dualproj.rst index a326530b..7038de58 100644 --- a/docs/source/docs/aggregation/dualproj.rst +++ b/docs/source/docs/aggregation/dualproj.rst @@ -4,5 +4,7 @@ DualProj ======== .. autoclass:: torchjd.aggregation.DualProj + :members: __call__ .. autoclass:: torchjd.aggregation.DualProjWeighting + :members: __call__ diff --git a/docs/source/docs/aggregation/flattening.rst b/docs/source/docs/aggregation/flattening.rst index 0214f131..b6d7f698 100644 --- a/docs/source/docs/aggregation/flattening.rst +++ b/docs/source/docs/aggregation/flattening.rst @@ -4,3 +4,4 @@ Flattening ========== .. autoclass:: torchjd.aggregation.Flattening + :members: __call__ diff --git a/docs/source/docs/aggregation/graddrop.rst b/docs/source/docs/aggregation/graddrop.rst index 3dae9f04..e8d64605 100644 --- a/docs/source/docs/aggregation/graddrop.rst +++ b/docs/source/docs/aggregation/graddrop.rst @@ -4,3 +4,4 @@ GradDrop ======== .. 
autoclass:: torchjd.aggregation.GradDrop + :members: __call__ diff --git a/docs/source/docs/aggregation/gradvac.rst b/docs/source/docs/aggregation/gradvac.rst index 471afc00..287ff066 100644 --- a/docs/source/docs/aggregation/gradvac.rst +++ b/docs/source/docs/aggregation/gradvac.rst @@ -4,7 +4,7 @@ GradVac ======= .. autoclass:: torchjd.aggregation.GradVac - :members: reset + :members: __call__, reset .. autoclass:: torchjd.aggregation.GradVacWeighting - :members: reset + :members: __call__, reset diff --git a/docs/source/docs/aggregation/imtl_g.rst b/docs/source/docs/aggregation/imtl_g.rst index 482dd2de..ad0da1a5 100644 --- a/docs/source/docs/aggregation/imtl_g.rst +++ b/docs/source/docs/aggregation/imtl_g.rst @@ -4,5 +4,7 @@ IMTL-G ====== .. autoclass:: torchjd.aggregation.IMTLG + :members: __call__ .. autoclass:: torchjd.aggregation.IMTLGWeighting + :members: __call__ diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 73442a93..4d62f820 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -8,10 +8,13 @@ Abstract base classes --------------------- .. autoclass:: torchjd.aggregation.Aggregator + :members: __call__ .. autoclass:: torchjd.aggregation.Weighting + :members: __call__ .. autoclass:: torchjd.aggregation.GeneralizedWeighting + :members: __call__ .. autoclass:: torchjd.aggregation.Stateful :members: reset diff --git a/docs/source/docs/aggregation/krum.rst b/docs/source/docs/aggregation/krum.rst index 6cf4bb07..7290e40b 100644 --- a/docs/source/docs/aggregation/krum.rst +++ b/docs/source/docs/aggregation/krum.rst @@ -4,5 +4,7 @@ Krum ==== .. autoclass:: torchjd.aggregation.Krum + :members: __call__ .. 
autoclass:: torchjd.aggregation.KrumWeighting + :members: __call__ diff --git a/docs/source/docs/aggregation/mean.rst b/docs/source/docs/aggregation/mean.rst index 848c820e..30771dde 100644 --- a/docs/source/docs/aggregation/mean.rst +++ b/docs/source/docs/aggregation/mean.rst @@ -4,5 +4,7 @@ Mean ==== .. autoclass:: torchjd.aggregation.Mean + :members: __call__ .. autoclass:: torchjd.aggregation.MeanWeighting + :members: __call__ diff --git a/docs/source/docs/aggregation/mgda.rst b/docs/source/docs/aggregation/mgda.rst index 06f67cce..f285ff64 100644 --- a/docs/source/docs/aggregation/mgda.rst +++ b/docs/source/docs/aggregation/mgda.rst @@ -4,5 +4,7 @@ MGDA ==== .. autoclass:: torchjd.aggregation.MGDA + :members: __call__ .. autoclass:: torchjd.aggregation.MGDAWeighting + :members: __call__ diff --git a/docs/source/docs/aggregation/nash_mtl.rst b/docs/source/docs/aggregation/nash_mtl.rst index a23f25bb..7aca6715 100644 --- a/docs/source/docs/aggregation/nash_mtl.rst +++ b/docs/source/docs/aggregation/nash_mtl.rst @@ -4,4 +4,4 @@ Nash-MTL ======== .. autoclass:: torchjd.aggregation.NashMTL - :members: reset + :members: __call__, reset diff --git a/docs/source/docs/aggregation/pcgrad.rst b/docs/source/docs/aggregation/pcgrad.rst index 5d59c885..8f70eadb 100644 --- a/docs/source/docs/aggregation/pcgrad.rst +++ b/docs/source/docs/aggregation/pcgrad.rst @@ -4,5 +4,7 @@ PCGrad ====== .. autoclass:: torchjd.aggregation.PCGrad + :members: __call__ .. autoclass:: torchjd.aggregation.PCGradWeighting + :members: __call__ diff --git a/docs/source/docs/aggregation/random.rst b/docs/source/docs/aggregation/random.rst index 1b1a0e28..54f0493f 100644 --- a/docs/source/docs/aggregation/random.rst +++ b/docs/source/docs/aggregation/random.rst @@ -4,5 +4,7 @@ Random ====== .. autoclass:: torchjd.aggregation.Random + :members: __call__ .. 
autoclass:: torchjd.aggregation.RandomWeighting + :members: __call__ diff --git a/docs/source/docs/aggregation/sum.rst b/docs/source/docs/aggregation/sum.rst index b2400322..510a719d 100644 --- a/docs/source/docs/aggregation/sum.rst +++ b/docs/source/docs/aggregation/sum.rst @@ -4,5 +4,7 @@ Sum === .. autoclass:: torchjd.aggregation.Sum + :members: __call__ .. autoclass:: torchjd.aggregation.SumWeighting + :members: __call__ diff --git a/docs/source/docs/aggregation/trimmed_mean.rst b/docs/source/docs/aggregation/trimmed_mean.rst index b332662b..07dc8299 100644 --- a/docs/source/docs/aggregation/trimmed_mean.rst +++ b/docs/source/docs/aggregation/trimmed_mean.rst @@ -4,3 +4,4 @@ Trimmed Mean ============ .. autoclass:: torchjd.aggregation.TrimmedMean + :members: __call__ diff --git a/docs/source/docs/aggregation/upgrad.rst b/docs/source/docs/aggregation/upgrad.rst index a44a46e1..97642058 100644 --- a/docs/source/docs/aggregation/upgrad.rst +++ b/docs/source/docs/aggregation/upgrad.rst @@ -4,5 +4,7 @@ UPGrad ====== .. autoclass:: torchjd.aggregation.UPGrad + :members: __call__ .. 
autoclass:: torchjd.aggregation.UPGradWeighting + :members: __call__ From 1945e6d63ee00fcfdf6daf623c5be19cb2026514 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 01:20:28 +0200 Subject: [PATCH 2/5] Improve aggregator.__call__ docstring --- src/torchjd/aggregation/_aggregator_bases.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index d4be05e9..2ed1505e 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -29,7 +29,11 @@ def forward(self, matrix: Matrix, /) -> Tensor: """Computes the aggregation from the input matrix.""" def __call__(self, matrix: Tensor, /) -> Tensor: - """Computes the aggregation from the input matrix and applies all registered hooks.""" + """ + Computes the aggregation from the input matrix and applies all registered hooks. + + :param matrix: The Jacobian to aggregate. + """ Aggregator._check_is_matrix(matrix) return super().__call__(matrix) From 764efd9cb0a39835143050b9c51083c4c1a60d75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 01:21:42 +0200 Subject: [PATCH 3/5] Improve docstring of Weighting.__call__ --- src/torchjd/aggregation/_weighting_bases.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index e321169c..7aeb39dd 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -28,7 +28,11 @@ def forward(self, stat: _T, /) -> Tensor: """Computes the vector of weights from the input stat.""" def __call__(self, stat: Tensor, /) -> Tensor: - """Computes the vector of weights from the input stat and applies all registered hooks.""" + """ + Computes the vector of weights from the input stat and applies all registered hooks. 
+ + :param stat: The stat from which the weights must be extracted. + """ # The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of # stat to be Tensor. From 630ffb95e23b924707c773a0a1a2884240f3f315 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 01:22:01 +0200 Subject: [PATCH 4/5] Improve docstring of GeneralizedWeighting.__call__ --- src/torchjd/aggregation/_weighting_bases.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 7aeb39dd..d50dcf33 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -78,6 +78,8 @@ def __call__(self, generalized_gramian: Tensor, /) -> Tensor: """ Computes the tensor of weights from the input generalized Gramian and applies all registered hooks. + + :param generalized_gramian: The tensor from which the weights must be extracted. """ assert is_psd_tensor(generalized_gramian) From ba197a9d8ab060b58ef46bf29f6d1038fe36bb10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 01:51:39 +0200 Subject: [PATCH 5/5] Add MatrixWeighting and GramianWeighting --- src/torchjd/aggregation/_aligned_mtl.py | 4 +-- src/torchjd/aggregation/_cagrad.py | 4 +-- src/torchjd/aggregation/_constant.py | 6 ++--- src/torchjd/aggregation/_dualproj.py | 4 +-- src/torchjd/aggregation/_gradvac.py | 4 +-- src/torchjd/aggregation/_imtl_g.py | 4 +-- src/torchjd/aggregation/_krum.py | 4 +-- src/torchjd/aggregation/_mean.py | 6 ++--- src/torchjd/aggregation/_mgda.py | 4 +-- src/torchjd/aggregation/_nash_mtl.py | 5 ++-- src/torchjd/aggregation/_pcgrad.py | 4 +-- src/torchjd/aggregation/_random.py | 6 ++--- src/torchjd/aggregation/_sum.py | 6 ++--- src/torchjd/aggregation/_upgrad.py | 4 +-- src/torchjd/aggregation/_weighting_bases.py | 30 ++++++++++++++++++++- 15 files changed, 57 insertions(+), 38 deletions(-) diff 
--git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index ced7ae45..07574b1b 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -12,12 +12,12 @@ from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import Weighting +from ._weighting_bases import GramianWeighting SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"] -class AlignedMTLWeighting(Weighting[PSDMatrix]): +class AlignedMTLWeighting(GramianWeighting): r""" :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.AlignedMTL`. diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index b008fefb..88ca66e0 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -3,7 +3,7 @@ from torchjd._linalg import PSDMatrix from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import Weighting +from ._weighting_bases import GramianWeighting check_dependencies_are_installed(["cvxpy", "clarabel"]) @@ -18,7 +18,7 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class CAGradWeighting(Weighting[PSDMatrix]): +class CAGradWeighting(GramianWeighting): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.CAGrad`. 
diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 0485e726..8b0f7307 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -1,13 +1,11 @@ from torch import Tensor -from torchjd._linalg import Matrix - from ._aggregator_bases import WeightedAggregator from ._utils.str import vector_to_str -from ._weighting_bases import Weighting +from ._weighting_bases import MatrixWeighting -class ConstantWeighting(Weighting[Matrix]): +class ConstantWeighting(MatrixWeighting): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined weights. diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 372cd18d..087e8805 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -7,10 +7,10 @@ from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import Weighting +from ._weighting_bases import GramianWeighting -class DualProjWeighting(Weighting[PSDMatrix]): +class DualProjWeighting(GramianWeighting): r""" :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.DualProj`. 
diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index cc518fbb..5a1edc35 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -10,10 +10,10 @@ from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import Weighting +from ._weighting_bases import GramianWeighting -class GradVacWeighting(Weighting[PSDMatrix], Stateful): +class GradVacWeighting(GramianWeighting, Stateful): r""" :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 42062e93..672f4d51 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -5,10 +5,10 @@ from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import Weighting +from ._weighting_bases import GramianWeighting -class IMTLGWeighting(Weighting[PSDMatrix]): +class IMTLGWeighting(GramianWeighting): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.IMTLG`. diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index 70e07202..d48f8918 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -5,10 +5,10 @@ from torchjd._linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator -from ._weighting_bases import Weighting +from ._weighting_bases import GramianWeighting -class KrumWeighting(Weighting[PSDMatrix]): +class KrumWeighting(GramianWeighting): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.Krum`. 
diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index 2ebe208d..13a72649 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -1,13 +1,11 @@ import torch from torch import Tensor -from torchjd._linalg import Matrix - from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Weighting +from ._weighting_bases import MatrixWeighting -class MeanWeighting(Weighting[Matrix]): +class MeanWeighting(MatrixWeighting): r""" :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index 510fa725..575f21a4 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -4,10 +4,10 @@ from torchjd._linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator -from ._weighting_bases import Weighting +from ._weighting_bases import GramianWeighting -class MGDAWeighting(Weighting[PSDMatrix]): +class MGDAWeighting(GramianWeighting): r""" :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.MGDA`. diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index e48b32c8..9cd3e7bc 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -1,11 +1,10 @@ # Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon. # See NOTICES for the full license text. 
-from torchjd._linalg import Matrix from torchjd.aggregation._mixins import Stateful from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import Weighting +from ._weighting_bases import MatrixWeighting check_dependencies_are_installed(["cvxpy", "ecos"]) @@ -21,7 +20,7 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class _NashMTLWeighting(Weighting[Matrix], Stateful): +class _NashMTLWeighting(MatrixWeighting, Stateful): """ :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` that extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index 770ffe09..dd965ff7 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -7,10 +7,10 @@ from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import Weighting +from ._weighting_bases import GramianWeighting -class PCGradWeighting(Weighting[PSDMatrix]): +class PCGradWeighting(GramianWeighting): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.PCGrad`. diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 8345a15c..ca32d601 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -2,13 +2,11 @@ from torch import Tensor from torch.nn import functional as F -from torchjd._linalg import Matrix - from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Weighting +from ._weighting_bases import MatrixWeighting -class RandomWeighting(Weighting[Matrix]): +class RandomWeighting(MatrixWeighting): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights at each call. 
diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index 0754f466..7e6beb55 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -1,13 +1,11 @@ import torch from torch import Tensor -from torchjd._linalg import Matrix - from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Weighting +from ._weighting_bases import MatrixWeighting -class SumWeighting(Weighting[Matrix]): +class SumWeighting(MatrixWeighting): r""" :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`. diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 172e55a6..b09c0a59 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -8,10 +8,10 @@ from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import Weighting +from ._weighting_bases import GramianWeighting -class UPGradWeighting(Weighting[PSDMatrix]): +class UPGradWeighting(GramianWeighting): r""" :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.UPGrad`. 
diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index d50dcf33..2655aa28 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -6,7 +6,7 @@ from torch import Tensor, nn -from torchjd._linalg import PSDTensor, is_psd_tensor +from torchjd._linalg import Matrix, PSDMatrix, PSDTensor, is_psd_tensor _T = TypeVar("_T", contravariant=True, bound=Tensor) _FnInputT = TypeVar("_FnInputT", bound=Tensor) @@ -84,3 +84,31 @@ def __call__(self, generalized_gramian: Tensor, /) -> Tensor: assert is_psd_tensor(generalized_gramian) return super().__call__(generalized_gramian) + + +# Subclasses used only to redefine the __call__ method with more specific parameter names and +# docstrings. Note that MatrixWeighting <: Weighting[Matrix] <: Weighting[PSDMatrix], because +# PSDMatrix <: Matrix and Weighting[_T] is contravariant in _T. +# Also note that we don't have: MatrixWeighting <: GramianWeighting. GramianWeighting is not +# just an alias of Weighting[PSDMatrix], it's a subtype of it. So the type Weighting[PSDMatrix] +# should still be used when we expect a Weighting that works at least on PSD matrices. + + +class MatrixWeighting(Weighting[Matrix]): + def __call__(self, matrix: Tensor, /) -> Tensor: + """ + Computes the vector of weights from the input matrix and applies all registered hooks. + + :param matrix: The matrix from which the weights must be extracted. + """ + return super().__call__(matrix) + + +class GramianWeighting(Weighting[PSDMatrix]): + def __call__(self, gramian: Tensor, /) -> Tensor: + """ + Computes the vector of weights from the input Gramian and applies all registered hooks. + + :param gramian: The Gramian from which the weights must be extracted. + """ + return super().__call__(gramian)