FSDP Scheduled Fine-Tuning¶
Important
This guide covers PyTorch’s original, wrapper-based FSDP API, sometimes referred to as FSDP1. Using the newer
composable distributed API, fully_shard (a.k.a. FSDP2), is likely preferable for most use cases. See the
FTS guide for using it here.
Overview¶
FinetuningScheduler (FTS) now supports flexible, multi-phase, scheduled fine-tuning
with the Fully Sharded Data Parallel (FSDP) strategy (
FSDPStrategy). This tutorial
assumes a basic understanding of FSDP training; please see
this PyTorch tutorial for a good introduction to
FSDP training.
As with standard FSDP usage, FSDP wrapping of a LightningModule
can be performed either by providing an auto_wrap_policy or (for maximal control) by overriding the
configure_model method of LightningModule and
manually wrapping the module.
This tutorial walks through the configuration of an example multi-phase, scheduled FSDP fine-tuning training session and largely uses the same code as the basic scheduled fine-tuning for SuperGLUE examples.
Example: Multi-Phase Scheduled Fine-Tuning with FSDP¶
Demonstration FTS FSDP training/profiling configurations and a DDP baseline for comparison are available under
./fts_examples/config/advanced/fsdp.
Most of these FTS FSDP training examples have the same dependencies as the basic scheduled fine-tuning for SuperGLUE examples (see Running the basic example).
Note
The examples below are not configured to execute a full training session but instead to generate the minimal meaningful profiling statistics for analysis and exposition (e.g. using only 2 batches, very limited epochs, etc.).
The demo schedule configurations are composed with the basic FTS example’s shared defaults
(./config/fts_defaults.yaml) and can be executed as follows:
cd ./fts_examples
# there is an open issue regarding superfluous profiler messages (still as of 2023.04.15)
# setting the environment variable below is a workaround to keep the example output clean:
export TORCH_CPP_LOG_LEVEL=ERROR
# Profiled demo of basic scheduled fine-tuning with FSDP
python fts_superglue.py fit --config config/advanced/fsdp/fts_fsdp_basic_profile.yaml
# Profiled demo of FSDP scheduled fine-tuning using the ``awp_overrides`` option:
python fts_superglue.py fit --config config/advanced/fsdp/fts_fsdp_awp_overrides_profile.yaml
# Profiled demo of comparable DDP scheduled fine-tuning baseline:
python fts_superglue.py fit --config config/advanced/fsdp/fts_ddp_fsdp_baseline_profile.yaml
# Profiled demo of FSDP scheduled fine-tuning with CPU Offloading but full precision
# (for reference, not reviewed in this tutorial)
python fts_superglue.py fit --config config/advanced/fsdp/fts_fsdp_awp_overrides_offload_profile.yaml
Basic Scheduled Fine-Tuning with FSDP¶
As you’ll see below, scheduled fine-tuning with FSDP is pretty straightforward! All one need do is:
1. Set use_orig_params to True in the FSDP strategy configuration.
2. Provide a simple auto_wrap_policy configuration (not technically required but almost always desired).
For a given fine-tuning schedule:
0:
  params:
  - model.classifier.*
  max_transition_epoch: 1
1:
  params:
  - model.pooler.dense.*
  - model.deberta.encoder.layer.11.(output|attention|intermediate).*
  max_transition_epoch: 2
2:
  params:
  - model.deberta.encoder.layer.([0-9]|10).(output|attention|intermediate).*
  - model.deberta.encoder.LayerNorm.bias
  - model.deberta.encoder.LayerNorm.weight
  - model.deberta.encoder.rel_embeddings.weight
We can just define an auto_wrap_policy for our DeBERTa-v3 module, directing FTS/FSDP to wrap the specified Transformer layers in separate FSDP modules:
strategy:
  class_path: lightning.pytorch.strategies.FSDPStrategy
  init_args:
    # other FSDP args as desired ...
    use_orig_params: True
    auto_wrap_policy:
      class_path: torch.distributed.fsdp.wrap.ModuleWrapPolicy
      init_args:
        module_classes: !!set
          ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Layer
That’s it! Note that we set use_orig_params to True as it allows for more flexible fine-tuning schedules.
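If you prefer to configure the Trainer programmatically rather than via the LightningCLI YAML configs used in these examples, an equivalent setup might look like the following sketch (the device count, schedule path and the DeBERTa-v3 layer class are illustrative assumptions borrowed from the example configs above):

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Layer
from finetuning_scheduler import FinetuningScheduler

# wrap each DebertaV2Layer in its own FSDP instance and retain the original parameters
strategy = FSDPStrategy(
    use_orig_params=True,
    auto_wrap_policy=ModuleWrapPolicy({DebertaV2Layer}),
    # other FSDP args as desired ...
)

trainer = Trainer(
    strategy=strategy,
    devices=2,
    callbacks=[FinetuningScheduler(ft_schedule="./config/RteBoolqModule_ft_schedule_deberta_base_fsdp.yaml")],
)
# trainer.fit(model, datamodule=dm)  # model/datamodule defined as in the SuperGLUE examples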
In the next section, we’ll cover some of the more advanced configuration options available for customizing scheduled fine-tuning with FSDP.
Advanced FSDP Wrapping For Scheduled Fine-Tuning¶
There are a number of usage contexts that might motivate moving beyond the simple configuration above. For instance:
| Potential Use case | Relevant Features & Info |
|---|---|
| Optimize resource utilization (whether memory, compute or network) | |
| More granular control over module wrapping policy w/o manually writing a configure_model method | |
| A desire to use FSDP in the default use_orig_params=False mode | |
As with standard FSDP module wrapping, one can use an auto_wrap_policy to wrap a model for FSDP scheduled
fine-tuning. In the current FTS release, there is only one FTS-specific FSDP configuration enhancement to consider:
the awp_overrides list.
awp_overrides is an optional list of module names (or regex patterns matching module names)
that should be wrapped in separate FSDP instances, complementing the modules that would be individually wrapped by
the auto_wrap_policy provided in the
FSDPStrategy configuration.
Starting with a defined auto_wrap_policy and providing module name-based complements/overrides as needed using
awp_overrides is often the most expedient approach
to auto-wrapping models in alignment with a fine-tuning schedule.
We again start by defining a simple fine-tuning schedule that we would like to ensure our module wrapping supports:
0:
  params:
  - model.classifier.*
  max_transition_epoch: 1
1:
  params:
  - model.pooler.dense.*
  - model.deberta.encoder.layer.11.(output|attention|intermediate).*
  max_transition_epoch: 2
2:
  params:
  - model.deberta.encoder.layer.([0-9]|10).(output|attention|intermediate).*
  - model.deberta.encoder.LayerNorm.bias
  - model.deberta.encoder.LayerNorm.weight
  - model.deberta.encoder.rel_embeddings.weight
  # excluding these parameters from the schedule to enhance the debugging demonstration
  #- model.deberta.embeddings.LayerNorm.bias
  #- model.deberta.embeddings.LayerNorm.weight
  #- model.deberta.embeddings.word_embeddings.weight
We define the auto_wrap_policy for our DeBERTa-v3 module as follows:
strategy:
  class_path: lightning.pytorch.strategies.FSDPStrategy
  init_args:
    # other FSDP args as desired ...
    auto_wrap_policy:
      class_path: torch.distributed.fsdp.wrap.ModuleWrapPolicy
      init_args:
        module_classes: !!set
          ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Layer
          ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Embeddings
          ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Encoder
We’ll inspect the rationale for this policy below, but first, notice we have not referenced our classifier and
pooler layers. Because we would like to thaw our classifier and pooler layers in separate phases from some
other layers, we need to wrap these layers separately as well. If, however, we specified separate wrapping of all Linear
layers in our auto_wrap_policy, we would end up unnecessarily (and in many cases problematically) wrapping the many
Linear layers within our currently FSDP-wrapped modules (DebertaV2Layer etc.) in separate FSDP instances.
To facilitate module wrapping in alignment with fine-tuning schedule phases, FTS provides the
awp_overrides feature which allows users to provide
module name-based complements to a given auto_wrap_policy.
In this case, simply listing the names of (or regex patterns matching) modules we would like to separately wrap allows
us to achieve FSDP wrapping that aligns with our fine-tuning schedule. FTS support for FSDP training is provided via a
StrategyAdapter
(FSDPStrategyAdapter). Configuration for FTS extensions of strategies
like FSDP is passed to FTS via the
strategy_adapter_cfg configuration dictionary.
So in our example, we can pass the awp_overrides
configuration option to FTS like so:
# in ./fts_examples/config/advanced/fsdp/fts_fsdp_awp_overrides_profile.yaml
...
  - class_path: finetuning_scheduler.FinetuningScheduler
    init_args:
      ft_schedule: ./config/RteBoolqModule_ft_schedule_deberta_base_fsdp.yaml
      max_depth: 2
      strategy_adapter_cfg:
        awp_overrides: ["model.pooler.dense", "model.classifier"]
...
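If you are composing the callback in Python rather than YAML, the same options can be passed directly to the FinetuningScheduler constructor. A minimal sketch (the schedule path simply mirrors the example config above):

from finetuning_scheduler import FinetuningScheduler

# ``strategy_adapter_cfg`` is forwarded to the FSDPStrategyAdapter managing FSDP-specific behavior
fts_callback = FinetuningScheduler(
    ft_schedule="./config/RteBoolqModule_ft_schedule_deberta_base_fsdp.yaml",
    max_depth=2,
    strategy_adapter_cfg={"awp_overrides": ["model.pooler.dense", "model.classifier"]},
)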
Finally, we configure the FSDP training strategy as usual, for instance, specifying
activation_checkpointing_policy and cpu_offload configurations in addition to the auto_wrap_policy we defined above:
# in ./fts_examples/config/advanced/fsdp/fts_fsdp_awp_overrides_profile.yaml
...
  strategy:
    class_path: lightning.pytorch.strategies.FSDPStrategy
    init_args:
      cpu_offload: false
      activation_checkpointing_policy: !!set
        ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Layer
      auto_wrap_policy:
        class_path: torch.distributed.fsdp.wrap.ModuleWrapPolicy
        init_args:
          module_classes: !!set
            ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Layer
            ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Embeddings
            ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Encoder
That’s all there is to it! We’ve successfully defined our fine-tuning schedule and FSDP-wrapped our model in a manner that supports multi-phase scheduled fine-tuning.
Additional FSDP Wrapping and Debugging Guidance¶
In order to support multi-phase scheduled fine-tuning with FSDP in use_orig_params=False mode, FTS’s key precondition
is that the defined fine-tuning schedule phases have disjoint sets of FSDP-flattened parameters (a FlatParameter is created when wrapping a set of
modules in an FSDP instance/unit). This constraint derives from the fact that (in use_orig_params=False mode) the requires_grad attribute
must be the same for all parameters flattened into the same FlatParameter. [1]
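To make the disjointness precondition concrete, the following toy, framework-free sketch (a hypothetical helper, not FTS’s actual validation code) illustrates the kind of check involved, treating each FSDP instance as the owner of a single FlatParameter:

from collections import defaultdict
from typing import Dict, Hashable, List


def phases_are_disjoint(param_to_fsdp_instance: Dict[str, Hashable], schedule: Dict[int, List[str]]) -> bool:
    """Return True if no FSDP-flattened parameter spans more than one schedule phase."""
    instance_to_phases = defaultdict(set)
    for phase, params in schedule.items():
        for param in params:
            # all parameters flattened into a given FSDP instance share one FlatParameter
            # (and therefore one requires_grad flag in ``use_orig_params=False`` mode)
            instance_to_phases[param_to_fsdp_instance[param]].add(phase)
    return all(len(phases) == 1 for phases in instance_to_phases.values())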
FTS will attempt to validate that the module is wrapped in a manner that aligns with the defined fine-tuning schedule phases prior to the start of training and provide detailed feedback for the user if a misalignment is discovered.
For example, note that because we wanted to thaw some DebertaV2Layer modules separately from others, we directed FSDP to
wrap each DebertaV2Layer in its own FSDP instance rather than just wrapping the entire DebertaV2Encoder.
What happens if we direct FSDP to wrap only the DebertaV2Layer modules and not the DebertaV2Encoder and
DebertaV2Embeddings modules as well?
FTS stops before beginning training and provides extensive context via this error message:
"Fine-tuning schedule phases do not have disjoint FSDP-flattened parameter sets. Because the `requires_grad` attribute of FSDP-flattened parameters currently must be the same for all flattened parameters (if in ``use_orig_params=False`` mode), fine-tuning schedules must avoid thawing parameters in the same FSDP-flattened parameter in different phases. Please ensure parameters associated with each phase are wrapped in separate phase-aligned FSDP instances.
In this particular case, there are parameters not included in your fine-tuning schedule that span more than one fine-tuning phase. HINT: parameters associated with unwrapped modules will be included in the top-level (aka 'root') FSDP instance so ensuring all modules associated with fine-tuning scheduled parameters are wrapped separately from the top-level FSDP instance may avoid triggering this exception.
The following logical parameters are associated with an FSDP-flattened parameter that spans more than one fine-tuning phase. The mapping of each logical parameter with the module name wrapped by its associated FSDP instance is provided below:
{'model.deberta.embeddings.LayerNorm.bias': 'DebertaV2ForSequenceClassification',
'model.deberta.embeddings.LayerNorm.weight': 'DebertaV2ForSequenceClassification',
'model.deberta.embeddings.word_embeddings.weight': 'DebertaV2ForSequenceClassification',
'model.deberta.encoder.LayerNorm.bias': 'DebertaV2ForSequenceClassification',
'model.deberta.encoder.LayerNorm.weight': 'DebertaV2ForSequenceClassification',
'model.deberta.encoder.rel_embeddings.weight': 'DebertaV2ForSequenceClassification'}"
This helps us understand that we have parameters that all belong to the same top-level FSDP instance (the instance
that wraps DebertaV2ForSequenceClassification). By failing to specify separate wrapping of the DebertaV2Encoder,
parameters associated with that module fell to the top-level/root FSDP instance to be managed. While the
DebertaV2Embeddings parameters were not included in our schedule, they still must be wrapped by FSDP and so are also
included with the DebertaV2Encoder parameters in the same top-level FlatParameter. If training had been permitted
to proceed in this case, the DebertaV2Embeddings parameters would have been thawed along with the DebertaV2Encoder
parameters in phase 2, violating our specified fine-tuning schedule.
To avoid violating the phase-wise disjointness constraint, we add DebertaV2Encoder to our auto_wrap_policy.
While not technically required, we add DebertaV2Embeddings separately as well for future experimental flexibility.
As always, if needed, one can alternatively override configure_model and manually wrap a given
LightningModule to align with a desired fine-tuning schedule.
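For reference, a minimal manual-wrapping sketch is shown below. The class name is hypothetical and the attribute layout (a model attribute holding the Hugging Face DeBERTa-v3 model, matching the example’s RteBoolqModule) is an assumption; as of recent Lightning releases, FSDPStrategy invokes configure_model within its wrapping context, so wrap applies the kwargs configured on the strategy:

from lightning.pytorch import LightningModule
from torch.distributed.fsdp.wrap import wrap


class ManuallyWrappedRteBoolqModule(LightningModule):  # illustrative skeleton, not the example's full module
    def configure_model(self):
        # wrap each transformer layer in its own FSDP instance so layers can be thawed per phase
        for i, layer in enumerate(self.model.deberta.encoder.layer):
            self.model.deberta.encoder.layer[i] = wrap(layer)
        # separately wrap the modules thawed in earlier phases (classifier and pooler)
        self.model.pooler.dense = wrap(self.model.pooler.dense)
        self.model.classifier = wrap(self.model.classifier)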
Warning
FSDPStrategyAdapter is in BETA and subject to change. The
interface may introduce breaking changes and new features with future releases of PyTorch.
Note
The no_decay attribute that FTS supports on
LightningModule with the base
StrategyAdapter is not currently supported in the context of
FSDP fine-tuning.
Note
Resuming across heterogeneous use_orig_params contexts with FTS is not currently supported (e.g.,
use_orig_params=True checkpoints need to be resumed with use_orig_params=True set).
Tip
When FSDP training with use_orig_params=True, DEBUG-level logging will provide parameter shard allocation
diagnostic info where relevant.
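One minimal way to surface DEBUG-level output is sketched below; this globally raises the logging verbosity and assumes no conflicting handler configuration, so in practice you may prefer to target specific loggers (verify the relevant logger names for your installed FTS/Lightning versions):

import logging

# enable DEBUG-level logging (a blunt but simple approach; narrow to specific loggers if too verbose)
logging.basicConfig(level=logging.DEBUG)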
Tip
If you want to extend FTS to use a custom, currently unsupported strategy or override current FTS behavior with a
given training strategy, subclassing StrategyAdapter is a way to do
so.