-
Notifications
You must be signed in to change notification settings - Fork 15
typing(aggregation): Add structure and none based weightings #655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| from dataclasses import dataclass | ||
|
|
||
| import torch | ||
|
|
||
| from torchjd._linalg import Matrix | ||
|
|
||
|
|
||
| @dataclass | ||
| class Structure: | ||
| m: int | ||
| device: torch.device | ||
| dtype: torch.dtype | ||
|
|
||
|
|
||
| def extract_structure(matrix: Matrix) -> Structure: | ||
| return Structure(m=matrix.shape[0], device=matrix.device, dtype=matrix.dtype) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,11 +6,11 @@ | |
|
|
||
| from torch import Tensor, nn | ||
|
|
||
| from torchjd._linalg import PSDTensor, is_psd_tensor | ||
| from torchjd._linalg import Matrix, PSDTensor, Structure, extract_structure, is_psd_tensor | ||
|
|
||
| _T = TypeVar("_T", contravariant=True, bound=Tensor) | ||
| _FnInputT = TypeVar("_FnInputT", bound=Tensor) | ||
| _FnOutputT = TypeVar("_FnOutputT", bound=Tensor) | ||
| _T = TypeVar("_T", contravariant=True) | ||
| _FnInputT = TypeVar("_FnInputT") | ||
| _FnOutputT = TypeVar("_FnOutputT") | ||
|
|
||
|
|
||
| class Weighting(nn.Module, ABC, Generic[_T]): | ||
|
|
@@ -27,11 +27,9 @@ def __init__(self) -> None: | |
| def forward(self, stat: _T, /) -> Tensor: | ||
| """Computes the vector of weights from the input stat.""" | ||
|
|
||
| def __call__(self, stat: Tensor, /) -> Tensor: | ||
| def __call__(self, stat: object, /) -> Tensor: | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is wrong: the type of `stat` should be `_T`. Before, `Tensor` was the public equivalent of both `Matrix` and `PSDMatrix`, so it worked, but it was still incorrect. If we do type this as `_T`, we need `Matrix` and `PSDMatrix` to become public, and I think users would need to cast to `Matrix` or `PSDMatrix` in order to use `__call__`. Finding a workaround would be huge: it would solve both this problem and the gap between public and private types at once. |
||
| """Computes the vector of weights from the input stat and applies all registered hooks.""" | ||
|
|
||
| # The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of | ||
| # stat to be Tensor. | ||
| return super().__call__(stat) | ||
|
|
||
| def _compose(self, fn: Callable[[_FnInputT], _T]) -> Weighting[_FnInputT]: | ||
|
|
@@ -55,6 +53,32 @@ def forward(self, stat: _T, /) -> Tensor: | |
| return self.weighting(self.fn(stat)) | ||
|
|
||
|
|
||
| class FromStructureWeighting(_Composition[Matrix]): | ||
| """ | ||
| Weighting that extracts the structure of the input matrix before applying a Weighting to it. | ||
|
|
||
| :param structure_weighting: The object responsible for extracting the vector of weights from the | ||
| structure. | ||
| """ | ||
|
|
||
| def __init__(self, structure_weighting: Weighting[Structure]) -> None: | ||
| super().__init__(structure_weighting, extract_structure) | ||
| self.structure_weighting = structure_weighting | ||
|
|
||
|
|
||
| class FromNothingWeighting(_Composition[Matrix]): | ||
| """ | ||
| Weighting that extracts nothing from the input matrix before applying a Weighting to it (i.e. to | ||
| None). | ||
|
|
||
| :param none_weighting: The object responsible for extracting the vector of weights from nothing. | ||
| """ | ||
|
|
||
| def __init__(self, none_weighting: Weighting[None]) -> None: | ||
| super().__init__(none_weighting, lambda _: None) | ||
| self.none_weighting = none_weighting | ||
|
|
||
|
|
||
| class GeneralizedWeighting(nn.Module, ABC): | ||
| r""" | ||
| Abstract base class for all weightings that operate on generalized Gramians. It has the role of | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this apply not only to matrices but to any non-scalar?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. That's my plan for a future PR.