####################################
Model Parallel Scheduled Fine-Tuning
####################################
Overview
********
:class:`~finetuning_scheduler.fts.FinetuningScheduler` (FTS) now supports flexible, multi-phase, scheduled fine-tuning
with the :external+pl:class:`~lightning.pytorch.strategies.model_parallel.ModelParallelStrategy` strategy, enabling use
of PyTorch's composable distributed (e.g. ``fully_shard``, ``checkpoint``) and Tensor Parallelism (TP) APIs.
FTS augments Lightning's Model Parallel strategy by allowing users to apply the ``fully_shard`` API using module
name/pattern-based configuration instead of manually inspecting modules and applying the API in
``LightningModule.configure_model`` (see
:attr:`~finetuning_scheduler.strategy_adapters.ModelParallelStrategyAdapter.fsdp_plan`).
The best way to learn this FTS functionality may be by example, so feel free to skip the discussion below and move
directly to :ref:`reviewing/running the examples <model-parallel-fine-tuning-examples>` in this guide.
FTS 'Auto' FSDP2 Plan Configuration
***********************************
As with standard ``fully_shard`` (a.k.a. ``FSDP2``, used interchangeably in this tutorial) usage, a
:external+pl:class:`~lightning.pytorch.core.module.LightningModule` can be prepared for ``fully_shard`` training by
providing manual FSDP2 sharding plan directives in its ``configure_model`` method.
Conveniently with FTS, however, users can instead apply the ``fully_shard`` composable API using module
name/pattern-based configuration, without manually inspecting modules and applying the API via
``LightningModule.configure_model`` customization.
The desired FSDP2 composition patterns are specified in an optional dictionary of module name or regex pattern keys
(:attr:`~finetuning_scheduler.strategy_adapters.ModelParallelStrategyAdapter.fsdp_plan`):

- The module name/pattern-based keys are associated with a dictionary of ``fully_shard`` API keyword arguments to apply
  to matching modules.
- :attr:`~finetuning_scheduler.strategy_adapters.ModelParallelStrategyAdapter.fsdp_plan` directives can also be composed
  with explicit ``fully_shard`` calls in ``LightningModule.configure_model``, as the ``fsdp_plan`` directives will only
  invoke ``fully_shard`` on a specified module if it was not already applied to that module.
- All valid ``fully_shard`` API keyword arguments are supported.
- :attr:`~finetuning_scheduler.strategy_adapters.ModelParallelStrategyAdapter.fsdp_plan` directives are applied in the
  order provided in the ``fsdp_plan`` dictionary.

Additionally, ``fsdp_plan`` supports the ``act_ckpt`` and ``cpu_offload_policy`` keyword arguments described below.
.. note::

   The 'auto' FSDP2 plan configuration provided by FTS refers to the generation and application of FSDP2 directives
   based upon the user's desired module names or regex patterns. As of this writing, FTS does not automatically
   determine which modules or ``fully_shard`` configurations to apply to a given model.
.. _model-parallel-fsdp2-auto-plan-aliases:
FSDP2 Auto Plan Convenience Aliases
***********************************
In addition to all valid ``fully_shard`` API keyword arguments, ``fsdp_plan`` (and ``fsdp_default_kwargs``) support
``act_ckpt`` and ``cpu_offload_policy`` keyword arguments.
**cpu_offload_policy**: This is a convenience alias that will apply ``CPUOffloadPolicy`` to the matching module(s) along
with any provided ``Dict`` of policy keyword args.
**act_ckpt**: For the specified modules/patterns (or via ``fsdp_default_kwargs``), ``act_ckpt`` accepts a string alias
selecting the desired activation checkpointing (AC) API along with an optional ``Dict`` of AC keyword arguments. The
specified AC API will be applied to the matching module(s) before ``fully_shard``. The currently supported AC APIs are
listed below (non-composable APIs are marked with :sup:`*`); a short configuration sketch follows the list.
.. _model-parallel-supported-ac-apis:

- *composable*: ``torch.distributed._composable.checkpoint_activation.checkpoint``
- *wrapped* :sup:`*`: ``torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper``
- *wrapped_offload* :sup:`*`: ``torch.distributed.algorithms._checkpoint.checkpoint_wrapper.offload_wrapper``
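As a rough illustration of these aliases in dictionary form (the module names are illustrative and the structure is a
sketch only; a YAML/CLI version of the same idea appears later in this guide):

.. code-block:: python

   # Hypothetical sketch of the convenience aliases; module names are illustrative.
   fsdp_default_kwargs = {
       "reshard_after_forward": True,  # an ordinary ``fully_shard`` kwarg
       "act_ckpt": ["composable"],     # select the composable AC API (optionally followed by a Dict of AC kwargs)
       "cpu_offload_policy": {},       # apply ``CPUOffloadPolicy`` with default policy kwargs
   }
   fsdp_plan = {"model.output": {}, r"model.layers.\d*$": {}}  # both entries inherit the defaults above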
.. note::

   If using a non-composable AC API (NCAC API), a user's ``LightningModule`` will be dynamically composed with an
   adapter that allows FTS to use the NCAC API in composition with composable APIs like ``fully_shard``. This is
   similar to FSDP2's approach to compositional enrichment (via dynamic subclassing).
*(Diagram: FSDP2 and FTS dynamic subclasses, NCAC-adapted user module)*
.. warning::

   When specific features of the NCAC APIs aren't required, using the composable AC API is recommended instead.
   Dynamically adapting the NCAC APIs is experimental and not all NCAC API functionality may work as intended in that
   context.
.. _model-parallel-fsdp-default-kwargs:
FSDP2 Default Keyword Arguments
*******************************
Because applying a common set of defaults to all FSDP2 directives is often useful, defaults for all ``fully_shard``
directives can be provided in an optional dictionary
(:attr:`~finetuning_scheduler.strategy_adapters.ModelParallelStrategyAdapter.fsdp_default_kwargs`). Module
name/pattern-specific keyword arguments provided via ``fsdp_plan`` take precedence over these defaults. All keyword
arguments valid for ``fsdp_plan`` (including the convenience aliases above) are supported.
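A brief precedence sketch (the module names are hypothetical, and this sketch assumes per-keyword merging with
``fsdp_plan`` values taking precedence over ``fsdp_default_kwargs`` for matching modules):

.. code-block:: python

   # Hypothetical precedence sketch (assumption: per-keyword merging of defaults and plan kwargs).
   fsdp_default_kwargs = {"reshard_after_forward": True}
   fsdp_plan = {
       r"model.layers.\d*$": {},                           # uses the default: reshard_after_forward=True
       "model.output": {"reshard_after_forward": False},   # overrides the default for this module
   }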
.. _model-parallel-fine-tuning-examples:
FTS Distributed Composable API Training Examples
************************************************
Distributed, multi-phase scheduled fine-tuning is simpler and more powerful with FTS's enhanced support for the
``fully_shard``/FSDP2 API. Using composable distributed APIs like ``fully_shard`` and ``checkpoint`` allows different
forms of parallelism to be composed (e.g. FSDP2 and Tensor Parallel, with support for other forms such as Pipeline and
Context Parallel coming soon).
The three examples in this tutorial assume basic familiarity with FSDP and Tensor Parallel training. For a good
introduction, please see the following PyTorch tutorials for
`FSDP `_ and
`TP `_ respectively.
.. note::

   The examples below are not configured to execute a full training session but rather to generate the minimal
   meaningful profiling statistics for analysis and exposition (e.g. using only 4 batches, a small configuration of
   ``torchtitan``'s latest Llama model, etc.).
Starting from this tutorial's base directory (``fts_examples/model_parallel``), the demo schedule configurations are
composed with the same set of shared defaults (``./config/defaults/*.yaml``) and can be executed as follows:
.. code-block:: bash

   cd ./fts_examples/model_parallel

   # Training with FSDP2 'Auto' Plan:
   python mp_examples.py fit --config config/fts_fsdp_auto_plan.yaml

   # TP Training:
   python mp_examples.py fit --config config/fts_tp_plan.yaml

   # FSDP2 'Auto' Plan thoroughly profiled with MemProfiler:
   python mp_examples.py fit --config config/fts_fsdp_profiling.yaml --config config/profiling/memprofiler_demo.yaml
All of these examples will use the same multi-phase schedule below (based on the
`latest torchtitan `_ Llama model):
.. code-block:: yaml

   0:
     params:
       - model.output.weight
       - model.norm.*
     max_transition_epoch: 1
   1:
     params:
       - model.layers.3.(feed_forward|ffn_norm|attention.w.*|attention_norm).*
     max_transition_epoch: 2
   2:
     params:
       - model.layers.[0-2].(feed_forward|ffn_norm|attention.w.*|attention_norm).*
       - model.tok_embeddings.weight
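If you want to sanity-check which parameters a given phase pattern will select, a quick regex preview can be done along
these lines. This helper is purely illustrative (it is not part of the FTS API) and assumes full-string matching of
parameter names, which is only an approximation of FTS's schedule parsing:

.. code-block:: python

   import re
   from typing import Iterable, List


   def preview_phase_params(param_names: Iterable[str], pattern: str) -> List[str]:
       """Return the parameter names a schedule phase regex would select (illustrative helper only)."""
       compiled = re.compile(pattern)
       return [name for name in param_names if compiled.fullmatch(name)]


   # e.g., with a LightningModule instance ``lit_module``:
   # preview_phase_params((n for n, _ in lit_module.named_parameters()),
   #                      r"model.layers.[0-2].(feed_forward|ffn_norm|attention.w.*|attention_norm).*")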
.. _model-parallel-fsdp2-auto-plan:
FSDP2 'Auto' Plan Generation/Application
****************************************
FTS can leverage FSDP2 without any special accommodation by overriding LightningModule's ``configure_model`` method
and manually applying the ``fully_shard`` API to the desired modules as outlined in the
`Lightning FSDP2 guide `_.
The primary enhancement provided by FTS for this strategy is the ability to automatically apply the FSDP2 API to
modules based upon the user's desired module name or regex patterns without overriding
``LightningModule.configure_model``.
This is done by providing a dictionary of module name/pattern-based FSDP2 API directives via
:attr:`~finetuning_scheduler.strategy_adapters.ModelParallelStrategyAdapter.fsdp_plan`. The keys of
``fsdp_plan`` are either module names or regex patterns and the optional values are valid ``fully_shard`` keyword
arguments or any of the :ref:`FTS convenience aliases <model-parallel-fsdp2-auto-plan-aliases>`.
As :ref:`discussed above <model-parallel-fsdp-default-kwargs>`, ``fsdp_default_kwargs`` can be used to provide default
keyword arguments to compose with all ``fsdp_plan`` ``fully_shard`` directives.
For example, passing the below ``fsdp_plan`` to ``FinetuningScheduler`` via
:attr:`~finetuning_scheduler.strategy_adapters.ModelParallelStrategyAdapter.strategy_adapter_cfg` will apply the
``fully_shard`` API to all ``TransformerBlock`` layers in the Llama model as well as to the final output module.
.. code-block:: python

   from finetuning_scheduler import FinetuningScheduler

   my_plan = {
       "model.output": {"reshard_after_forward": True},  # any ``fully_shard`` API kwargs
       r"model.layers.\d*$": {},  # default ``fully_shard`` kwargs used here (raw string for the regex pattern)
   }
   fts_cfg = dict(
       ft_schedule="config/defaults/llama_ft_schedule.yaml", max_depth=2, strategy_adapter_cfg={"fsdp_plan": my_plan}
   )
   fts_callback = FinetuningScheduler(**fts_cfg)
We can also use ``fsdp_default_kwargs`` to provide default keyword arguments to compose with all ``fsdp_plan``
``fully_shard`` directives. This example does so via the CLI with a YAML config and uses the
:ref:`FTS convenience aliases <model-parallel-fsdp2-auto-plan-aliases>` to enable CPU offloading and composable
activation checkpointing for all specified FSDP2 instances like so:
.. code-block:: yaml
   :emphasize-lines: 3-5

   strategy_adapter_cfg:
     fsdp_default_kwargs:
       reshard_after_forward: True  # default value of a normal ``fully_shard`` kwarg
       act_ckpt: ['composable']  # use composable AC with default kwargs
       cpu_offload_policy: {}  # apply default cpu offload policy
     fsdp_plan: {'model.output': {}, 'model.layers.\d*$': {}}
That's it! We've configured composable, distributed, multi-phase scheduled fine-tuning and didn't even need to override
``LightningModule.configure_model``!
.. list-table:: Resulting Composition
   :widths: 50 50
   :header-rows: 0

   * - .. figure:: ../_static/images/fts/pl_module_first_outer_tformer_noac.png
          :alt: FSDP2 modules are composed with the provided modules as specified.

          FSDP2 modules are composed with the provided modules as specified.
     - .. figure:: ../_static/images/fts/last_tblock_output_noac.png
          :alt: Modules not specified as separate FSDP2 instances remain normal modules.

          Modules not specified as separate FSDP2 instances remain normal modules (e.g. ``norm``, ``feed_forward`` etc.).
.. code-block:: bash

   cd ./fts_examples/model_parallel
   python mp_examples.py fit --config config/fts_fsdp_auto_plan.yaml
.. tip::

   FTS will only apply ``fully_shard`` to a specified module if it has not already been applied to that module, so
   ``fsdp_plan`` (and ``fsdp_default_kwargs``) directives can be composed with existing ``fully_shard`` (or Tensor
   Parallel) directives in ``LightningModule.configure_model``, as sketched below.
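For example, a ``configure_model`` override along these lines (a minimal sketch; the module attribute names are
hypothetical and the ``fully_shard`` import path reflects recent PyTorch releases, so yours may differ) would compose
cleanly with an ``fsdp_plan`` that also targets ``model.output``, since FTS skips modules that are already sharded:

.. code-block:: python

   from lightning.pytorch import LightningModule
   from torch.distributed.fsdp import fully_shard  # import path may vary by PyTorch version


   class MyLitModule(LightningModule):  # hypothetical module; assumes ``self.model`` is assigned in ``__init__``
       def configure_model(self):
           # Explicitly shard one submodule here; an ``fsdp_plan`` entry that also matches
           # ``model.output`` would be skipped because ``fully_shard`` was already applied to it.
           fully_shard(self.model.output, reshard_after_forward=True)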
.. note::

   As with manual application of the API,
   :attr:`~finetuning_scheduler.strategy_adapters.ModelParallelStrategyAdapter.fsdp_plan` directives should be
   applied bottom-up. For instance, one should compose ``self.model.layer`` before ``self.model``, e.g.
   ``fsdp_plan: {'model.layer': {}, 'model': {}}``.
.. tip::

   At the time of writing, some optimizer operations do not support parameter groups with mixed DTensor/non-DTensor
   (usually ``torch.Tensor``) parameters. FTS will inspect the provided fine-tuning schedule and FSDP plan for this
   condition and, if it is detected, provide the user with ``INFO``-level feedback.
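If you want to perform a similar sanity check yourself, a rough sketch (not FTS's actual implementation; it assumes a
recent PyTorch where ``DTensor`` is importable from ``torch.distributed.tensor``) might look like this:

.. code-block:: python

   import torch
   from torch.distributed.tensor import DTensor  # import path may vary by PyTorch version


   def mixed_dtensor_param_groups(optimizer: torch.optim.Optimizer) -> list:
       """Return indices of optimizer param groups that mix DTensor and plain ``torch.Tensor`` params."""
       flagged = []
       for idx, group in enumerate(optimizer.param_groups):
           param_kinds = {isinstance(p, DTensor) for p in group["params"]}
           if len(param_kinds) > 1:  # both DTensor and non-DTensor parameters present
               flagged.append(idx)
       return flagged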
In the next section, we'll cover Tensor Parallel (TP) training with FTS.
.. _model-parallel-tp-plan:
FTS TP Plan
***********
FTS works with Tensor Parallel (TP) training without any special accommodation by overriding LightningModule's
``configure_model`` method and manually applying the relevant parallelism plan. Unlike the enhanced FSDP2 API, the
current version of FTS does not provide any auto-configuration enhancements for Tensor Parallel. For more on
constructing TP plans, see this
`Lightning TP guide `_.
As you can observe in ``./mp_examples.py``, our TP plan in this example is applied as usual by overriding
``LightningModule.configure_model`` like so:
.. code-block:: python

   def configure_model(self):
       if self.device_mesh["tensor_parallel"].size() > 1:
           # User-defined function that applies a given TP plan if desired
           apply_tp_plan(self.model, device_mesh=self.device_mesh, loss_parallel=self.hparams.exp_cfg.loss_parallel)
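For reference, a user-defined helper like ``apply_tp_plan`` might be implemented roughly as follows. This is a minimal
sketch using PyTorch's TP parallelization APIs; the ``model.layers``/``attention.*`` names and the decision to
parallelize only the attention projections are illustrative assumptions and differ from the full plan used in
``mp_examples.py``:

.. code-block:: python

   from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


   def apply_tp_plan(model, device_mesh, loss_parallel: bool = False):
       # Illustrative only: shard each transformer block's attention projections across the
       # "tensor_parallel" mesh dimension (column-wise for q/k/v, row-wise for the output projection).
       # ``loss_parallel`` handling (and embedding/output sharding) is omitted from this sketch.
       tp_mesh = device_mesh["tensor_parallel"]
       for block in model.layers.children():  # assumes the model exposes a ``layers`` container
           layer_plan = {
               "attention.wq": ColwiseParallel(),
               "attention.wk": ColwiseParallel(),
               "attention.wv": ColwiseParallel(),
               "attention.wo": RowwiseParallel(),
           }
           parallelize_module(block, tp_mesh, layer_plan)
       return model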
.. note::

   FTS FSDP2 auto plan (and/or manual FSDP2 directives in ``LightningModule.configure_model``) can also be composed
   with TP plan directives in ``LightningModule.configure_model`` for 2D parallelism, similar
   `to this example `_. Any specified TP plan directives will be applied before subsequent FSDP2 directives.
.. code-block:: bash

   cd ./fts_examples/model_parallel
   python mp_examples.py fit --config config/fts_tp_plan.yaml