
LazyAdam

minnt.optimizers.LazyAdam

Bases: MultiOptimizer

A class implementing the LazyAdam optimizer.

The optimizer applies torch.optim.SparseAdam to the parameters of all torch.nn.Embedding and torch.nn.EmbeddingBag layers with sparse gradients, and torch.optim.Adam to all other parameters. By default, all embedding layers in the module are first switched to sparse gradients (their sparse attribute is set to True).

Warning

The implementation of MultiOptimizer is quite hacky (it does not call the parent constructor of torch.optim.Optimizer and provides only a subset of its functionality), but it works well enough for LazyAdam.

Info

The current limitations of the MultiOptimizer and thus LazyAdam are:

  • it does not provide defaults and state properties;
  • it does not support passing a closure to the step() method;
  • it does not support hooks.
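
A minimal usage sketch follows. The toy module, batch shapes, and loss are illustrative only, and the import path is assumed from the page title; gradients are cleared via the module because MultiOptimizer exposes only part of the torch.optim.Optimizer interface.

import torch
from minnt.optimizers import LazyAdam  # import path assumed from the page title

# A toy module mixing an embedding layer with a dense output layer.
model = torch.nn.Sequential(
    torch.nn.EmbeddingBag(num_embeddings=10_000, embedding_dim=64),
    torch.nn.Linear(64, 1),
)

# LazyAdam switches the EmbeddingBag to sparse=True and routes its parameters
# to torch.optim.SparseAdam; the Linear parameters are handled by torch.optim.Adam.
optimizer = LazyAdam(model, lr=0.001)

token_ids = torch.randint(0, 10_000, (32, 20))  # batch of 32 bags, 20 token ids each
targets = torch.rand(32, 1)

model.zero_grad()  # zero grads via the module, not the (partial) optimizer API
loss = torch.nn.functional.mse_loss(model(token_ids), targets)
loss.backward()
optimizer.step()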
Source code in minnt/optimizers/lazy_adam.py
class LazyAdam(MultiOptimizer):
    """A class implementing the LazyAdam optimizer.

    The optimizer applies [torch.optim.SparseAdam][] to parameters of all [torch.nn.Embedding][] and
    [torch.nn.EmbeddingBag][] layers with sparse gradients, and [torch.optim.Adam][] to all other parameters.
    By default, all embedding layers in the module are set to have sparse gradients first.

    Warning:
        The implementation of [MultiOptimizer][minnt.optimizers.MultiOptimizer] is quite hacky (it does not call
        the parent constructor of the [torch.optim.Optimizer][] and provides only a subset of the functionality),
        but it seems to work well enough for `LazyAdam` to work.

    Info:
        The current limitations of the [MultiOptimizer][minnt.optimizers.MultiOptimizer] and thus `LazyAdam` are:

        - it does not provide `defaults` and `state` properties;
        - it does not support passing a `closure` to the `step()` method;
        - it does not support hooks.
    """
    def __init__(
        self,
        module: torch.nn.Module,
        lr: float = 0.001,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-7,
        *,
        make_embeddings_sparse: bool = True,
        adam_param_groups: Iterable | None = None,
    ) -> None:
        """Initializes the LazyAdam optimizer.

        Parameters:
          module: The module containing the embedding and non-embedding layers to be optimized.
          lr: The learning rate for both optimizers. Default is `0.001`.
          betas: The beta coefficients for both optimizers. Default is `(0.9, 0.999)`.
          eps: The epsilon value for both optimizers. Beware that the default value `1e-7` is
            different from `eps=1e-8` in [torch.optim.Adam][] and [torch.optim.SparseAdam][].
          make_embeddings_sparse: If `True` (default), sets the `sparse` attribute of all
            [torch.nn.Embedding][] and [torch.nn.EmbeddingBag][] layers to `True`; otherwise `LazyAdam` will
            consider only those layers that already have `sparse=True`.
          adam_param_groups: An optional iterable of parameters to optimize using [torch.optim.Adam][].
            If `None` (default), all `module.parameters()` that are not part of embedding layers with
            sparse gradients will be optimized using [torch.optim.Adam][].
        """
        sparse_params = []

        def collect_sparse_params(m: torch.nn.Module) -> None:
            if isinstance(m, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
                if make_embeddings_sparse:
                    m.sparse = True
                if m.sparse:
                    nonlocal sparse_params
                    sparse_params.extend(m.parameters())
        module.apply(collect_sparse_params)
        assert sparse_params, "No embedding layers with sparse gradients found in the module."

        if adam_param_groups is None:
            adam_param_groups = []
            sparse_param_set = set(sparse_params)
            for param in module.parameters():
                if param not in sparse_param_set:
                    adam_param_groups.append(param)

        self.adam = torch.optim.Adam(adam_param_groups, lr=lr, betas=betas, eps=eps)
        self.sparse_adam = torch.optim.SparseAdam(sparse_params, lr=lr, betas=betas, eps=eps)

        super().__init__([self.adam, self.sparse_adam])

    adam: torch.optim.Optimizer
    """The [torch.optim.Adam][] optimizer for non-embedding parameters."""

    sparse_adam: torch.optim.Optimizer
    """The [torch.optim.SparseAdam][] optimizer for embedding parameters."""

__init__

__init__(
    module: Module,
    lr: float = 0.001,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-07,
    *,
    make_embeddings_sparse: bool = True,
    adam_param_groups: Iterable | None = None
) -> None

Initializes the LazyAdam optimizer.

Parameters:

  • module (Module) – The module containing the embedding and non-embedding layers to be optimized.

  • lr (float, default: 0.001) – The learning rate for both optimizers. Default is 0.001.

  • betas (tuple[float, float], default: (0.9, 0.999)) – The beta coefficients for both optimizers. Default is (0.9, 0.999).

  • eps (float, default: 1e-07) – The epsilon value for both optimizers. Beware that the default value 1e-7 is different from eps=1e-8 in torch.optim.Adam and torch.optim.SparseAdam.

  • make_embeddings_sparse (bool, default: True) – If True (default), sets the sparse attribute of all torch.nn.Embedding and torch.nn.EmbeddingBag layers to True; otherwise LazyAdam will consider only those layers that already have sparse=True.

  • adam_param_groups (Iterable | None, default: None) – An optional iterable of parameters to optimize using torch.optim.Adam. If None (default), all module.parameters() that are not part of embedding layers with sparse gradients will be optimized using torch.optim.Adam.
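
Since adam_param_groups is handed to torch.optim.Adam unchanged, it may use the standard torch.optim parameter-group format. A hedged sketch: the encoder/head split and the per-group learning rate below are invented for illustration, and the groups must not include the sparse embedding parameters, since they go straight to torch.optim.Adam.

# Hypothetical split: `model.encoder` and `model.head` stand in for the
# non-embedding parts of the module.
optimizer = LazyAdam(
    model,
    lr=0.001,
    adam_param_groups=[
        {"params": model.encoder.parameters(), "lr": 1e-4},
        {"params": model.head.parameters()},  # inherits lr=0.001
    ],
)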

Source code in minnt/optimizers/lazy_adam.py
def __init__(
    self,
    module: torch.nn.Module,
    lr: float = 0.001,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-7,
    *,
    make_embeddings_sparse: bool = True,
    adam_param_groups: Iterable | None = None,
) -> None:
    """Initializes the LazyAdam optimizer.

    Parameters:
      module: The module containing the embedding and non-embedding layers to be optimized.
      lr: The learning rate for both optimizers. Default is `0.001`.
      betas: The beta coefficients for both optimizers. Default is `(0.9, 0.999)`.
      eps: The epsilon value for both optimizers. Beware that the default value `1e-7` is
        different from `eps=1e-8` in [torch.optim.Adam][] and [torch.optim.SparseAdam][].
      make_embeddings_sparse: If `True` (default), sets the `sparse` attribute of all
        [torch.nn.Embedding][] and [torch.nn.EmbeddingBag][] layers to `True`; otherwise `LazyAdam` will
        consider only those layers that already have `sparse=True`.
      adam_param_groups: An optional iterable of parameters to optimize using [torch.optim.Adam][].
        If `None` (default), all `module.parameters()` that are not part of embedding layers with
        sparse gradients will be optimized using [torch.optim.Adam][].
    """
    sparse_params = []

    def collect_sparse_params(m: torch.nn.Module) -> None:
        if isinstance(m, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
            if make_embeddings_sparse:
                m.sparse = True
            if m.sparse:
                nonlocal sparse_params
                sparse_params.extend(m.parameters())
    module.apply(collect_sparse_params)
    assert sparse_params, "No embedding layers with sparse gradients found in the module."

    if adam_param_groups is None:
        adam_param_groups = []
        sparse_param_set = set(sparse_params)
        for param in module.parameters():
            if param not in sparse_param_set:
                adam_param_groups.append(param)

    self.adam = torch.optim.Adam(adam_param_groups, lr=lr, betas=betas, eps=eps)
    self.sparse_adam = torch.optim.SparseAdam(sparse_params, lr=lr, betas=betas, eps=eps)

    super().__init__([self.adam, self.sparse_adam])

adam instance-attribute

adam: Optimizer

The torch.optim.Adam optimizer for non-embedding parameters.

sparse_adam instance-attribute

sparse_adam: Optimizer

The torch.optim.SparseAdam optimizer for embedding parameters.
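
Because the two underlying optimizers are exposed as attributes, they can be inspected or scheduled independently. A small sketch, assuming the standard torch.optim.lr_scheduler API and a hypothetical train_one_epoch helper; the schedulers attach to the sub-optimizers directly rather than to the MultiOptimizer.

from torch.optim.lr_scheduler import StepLR

# One independent schedule per underlying optimizer.
dense_schedule = StepLR(optimizer.adam, step_size=10, gamma=0.5)
sparse_schedule = StepLR(optimizer.sparse_adam, step_size=10, gamma=0.5)

for epoch in range(30):
    train_one_epoch(model, optimizer)  # hypothetical training loop
    dense_schedule.step()
    sparse_schedule.step()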