MultiOptimizer

minnt.optimizers.MultiOptimizer

Bases: Optimizer

A class managing multiple PyTorch optimizers, delegating all method calls to them.

Warning

The implementation is quite hacky (MultiOptimizer does not call the parent constructor of torch.optim.Optimizer and provides only a subset of its functionality), but it works well enough to support minnt.optimizers.LazyAdam.

Info

The current limitations of the MultiOptimizer are:

  • it does not provide defaults and state properties;
  • it does not support passing a closure to the step() method;
  • it does not support hooks.
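
For example, a MultiOptimizer can pair a sparse-aware optimizer for an embedding table with a dense optimizer for the rest of the model. A minimal usage sketch, assuming a hypothetical two-part model; the embedding/head names, optimizer choices, and hyperparameters are illustrative, not part of minnt:

import torch
from minnt.optimizers import MultiOptimizer

# A hypothetical model: a sparse embedding table plus a dense linear head.
embedding = torch.nn.Embedding(1000, 64, sparse=True)
head = torch.nn.Linear(64, 10)

# Each optimizer owns a disjoint set of parameters.
optimizer = MultiOptimizer([
    torch.optim.SparseAdam(list(embedding.parameters()), lr=1e-3),
    torch.optim.AdamW(head.parameters(), lr=1e-3, weight_decay=0.01),
])
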
Source code in minnt/optimizers/multi_optimizer.py
class MultiOptimizer(torch.optim.Optimizer):
    """A class managing multiple PyTorch optimizers, deferring all function calls to them.

    Warning:
        The implementation is quite hacky (MultiOptimizer does not call the parent constructor
        of the [torch.optim.Optimizer][] and provides only a subset of its functionality),
        but it works well enough to support [minnt.optimizers.LazyAdam][].

    Info:
        The current limitations of the MultiOptimizer are:

        - it does not provide `defaults` and `state` properties;
        - it does not support passing a `closure` to the `step()` method;
        - it does not support hooks.
    """
    def __init__(self, optimizers: Iterable[torch.optim.Optimizer]):
        """Initializes the MultiOptimizer with a list of optimizers.

        Raises:
          ValueError: If the provided list of optimizers is empty,
            or if any parameter appears in more than one optimizer.
        """
        # Materialize first so the emptiness check also covers generators.
        self.optimizers = list(optimizers)
        if not self.optimizers:
            raise ValueError("At least one optimizer must be provided")

        for i in range(len(self.optimizers)):
            for param_group_i in self.optimizers[i].param_groups:
                for j in range(i + 1, len(self.optimizers)):
                    for param_group_j in self.optimizers[j].param_groups:
                        overlap = set(param_group_i["params"]) & set(param_group_j["params"])
                        if overlap:
                            raise ValueError(f"Parameters appear in more than one optimizer: {overlap}")

    optimizers: list[torch.optim.Optimizer]
    """A list of optimizers managed by this MultiOptimizer instance."""

    @property
    def param_groups(self) -> list[dict]:
        """Returns a list of all parameter groups from all managed optimizers."""
        param_groups = []
        for optimizer in self.optimizers:
            param_groups.extend(optimizer.param_groups)
        return param_groups

    @property
    def defaults(self) -> None:
        raise RuntimeError("MultiOptimizer does not have defaults")

    @property
    def state(self) -> None:
        raise RuntimeError("MultiOptimizer does not have state")

    def step(self, closure: None = None) -> None:
        """Performs a single optimization step for all managed optimizers.

        While the `closure` argument is accepted to follow [torch.optim.Optimizer.step][],
        it must be `None` because the MultiOptimizer cannot pass it to multiple optimizers.

        Parameters:
          closure: Must be `None`, while it might be `Callable[[], float]` in a
            [torch.optim.Optimizer.step][] method.
        """
        assert closure is None, "MultiOptimizer.step does not support a closure"
        for optimizer in self.optimizers:
            optimizer.step()

    def zero_grad(self, set_to_none: bool = True) -> None:
        """Sets the gradients of all optimized parameters to zero for all managed optimizers.

        Parameters:
          set_to_none: See [torch.optim.Optimizer.zero_grad][].
        """
        for optimizer in self.optimizers:
            optimizer.zero_grad(set_to_none)

    def state_dict(self) -> dict[str, Any]:
        """Returns the state dictionary containing the state of all managed optimizers."""
        state = {}
        for i, optimizer in enumerate(self.optimizers):
            for key, value in optimizer.state_dict().items():
                state[f"optimizer_{i}_{key}"] = value
        return state

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Loads the state dictionary in all managed optimizers."""
        for i, optimizer in enumerate(self.optimizers):
            optimizer_state_dict = {}
            for key in state_dict:
                if key.startswith(f"optimizer_{i}_"):
                    optimizer_state_dict[key.removeprefix(f"optimizer_{i}_")] = state_dict[key]
            optimizer.load_state_dict(optimizer_state_dict)

__init__

__init__(optimizers: Iterable[Optimizer])

Initializes the MultiOptimizer with a list of optimizers.

Raises:

  • ValueError

    If the provided list of optimizers is empty, or if any parameter appears in more than one optimizer.

Source code in minnt/optimizers/multi_optimizer.py
def __init__(self, optimizers: Iterable[torch.optim.Optimizer]):
    """Initializes the MultiOptimizer with a list of optimizers.

    Raises:
      ValueError: If the provided list of optimizers is empty,
        or if any parameter appears in more than one optimizer.
    """
    # Materialize first so the emptiness check also covers generators.
    self.optimizers = list(optimizers)
    if not self.optimizers:
        raise ValueError("At least one optimizer must be provided")

    for i in range(len(self.optimizers)):
        for param_group_i in self.optimizers[i].param_groups:
            for j in range(i + 1, len(self.optimizers)):
                for param_group_j in self.optimizers[j].param_groups:
                    overlap = set(param_group_i["params"]) & set(param_group_j["params"])
                    if overlap:
                        raise ValueError(f"Parameters appear in more than one optimizer: {overlap}")

optimizers instance-attribute

optimizers: list[Optimizer]

A list of optimizers managed by this MultiOptimizer instance.

param_groups property

param_groups: list[dict]

Returns a list of all parameter groups from all managed optimizers.
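
Because the property aggregates the groups of every managed optimizer (returning references to the same group dictionaries), code that tweaks per-group hyperparameters keeps working unchanged. A small sketch, assuming an already constructed instance, here called multi_optimizer:

# Halve the learning rate in every group of every managed optimizer.
for group in multi_optimizer.param_groups:
    group["lr"] *= 0.5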

step

step(closure: None = None) -> None

Performs a single optimization step for all managed optimizers.

While the closure argument is accepted to follow torch.optim.Optimizer.step, it must be None because the MultiOptimizer cannot pass it to multiple optimizers.

Parameters:

  • closure

    Must be None, while it might be Callable[[], float] in a torch.optim.Optimizer.step method.

Source code in minnt/optimizers/multi_optimizer.py
def step(self, closure: None = None) -> None:
    """Performs a single optimization step for all managed optimizers.

    While the `closure` argument is accepted to follow [torch.optim.Optimizer.step][],
    it must be `None` because the MultiOptimizer cannot pass it to multiple optimizers.

    Parameters:
      closure: Must be `None`, while it might be `Callable[[], float]` in a
        [torch.optim.Optimizer.step][] method.
    """
    assert closure is None, "MultiOptimizer.step does not support a closure"
    for optimizer in self.optimizers:
        optimizer.step()
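
Driving the MultiOptimizer in a training loop looks exactly like driving a single optimizer. A minimal sketch, continuing the hypothetical embedding/head model and optimizer from the class overview above:

inputs = torch.randint(0, 1000, (32,))
targets = torch.randint(0, 10, (32,))

optimizer.zero_grad()                 # clears gradients in every managed optimizer
logits = head(embedding(inputs))      # forward pass through the sketch model
loss = torch.nn.functional.cross_entropy(logits, targets)
loss.backward()
optimizer.step()                      # steps every managed optimizer; no closure allowed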

zero_grad

zero_grad(set_to_none: bool = True) -> None

Sets the gradients of all optimized parameters to zero for all managed optimizers.

Parameters:

  • set_to_none

    See torch.optim.Optimizer.zero_grad.

Source code in minnt/optimizers/multi_optimizer.py
def zero_grad(self, set_to_none: bool = True) -> None:
    """Sets the gradients of all optimized parameters to zero for all managed optimizers.

    Parameters:
      set_to_none: See [torch.optim.Optimizer.zero_grad][].
    """
    for optimizer in self.optimizers:
        optimizer.zero_grad(set_to_none)

state_dict

state_dict() -> dict[str, Any]

Returns the state dictionary containing the state of all managed optimizers.

Source code in minnt/optimizers/multi_optimizer.py
def state_dict(self) -> dict[str, Any]:
    """Returns the state dictionary containing the state of all managed optimizers."""
    state = {}
    for i, optimizer in enumerate(self.optimizers):
        for key, value in optimizer.state_dict().items():
            state[f"optimizer_{i}_{key}"] = value
    return state

load_state_dict

load_state_dict(state_dict: dict[str, Any]) -> None

Loads the state dictionary into all managed optimizers.

Source code in minnt/optimizers/multi_optimizer.py
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Loads the state dictionary in all managed optimizers."""
    for i, optimizer in enumerate(self.optimizers):
        optimizer_state_dict = {}
        for key in state_dict:
            if key.startswith(f"optimizer_{i}_"):
                optimizer_state_dict[key.removeprefix(f"optimizer_{i}_")] = state_dict[key]
        optimizer.load_state_dict(optimizer_state_dict)
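
Because state_dict() flattens everything under optimizer_{i}_ keys, the combined state is a plain dictionary and round-trips through torch.save like any single optimizer's state. A sketch, assuming a multi_optimizer built as in the overview above; note that load_state_dict matches states to optimizers by position, so the restored instance must wrap the same optimizers in the same order:

import torch

# Save the combined state of all managed optimizers.
torch.save(multi_optimizer.state_dict(), "multi_optimizer.pt")

# Later: restore into a MultiOptimizer constructed with the same
# optimizers in the same order.
multi_optimizer.load_state_dict(torch.load("multi_optimizer.pt"))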