Source code for fts_examples.profiling.config
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
MemProfiler Configuration Dataclasses
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This module defines the configuration dataclasses for the MemProfiler.
"""
from typing import Optional, Callable, List, Set, Union
from dataclasses import dataclass, field, fields
from pathlib import Path
import torch
from lightning.fabric.utilities import rank_zero_warn


@dataclass
class MemProfilerHooks:
    pre_forward_hooks: List[Union[str, Callable]] = \
        field(default_factory=lambda: ['fts_examples.profiling.npp_hooks._hook_npp_pre_forward'])
    post_forward_hooks: List[Union[str, Callable]] = \
        field(default_factory=lambda: ['fts_examples.profiling.npp_hooks._hook_npp_post_forward'])
    # the provided reset_state_hooks will be called with the model and the `save_hook_attrs` list
    reset_state_hooks: List[Union[str, Callable]] = \
        field(default_factory=lambda: ['fts_examples.profiling.npp_hooks._reset_memory_hooks_state'])
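

# Example (illustrative only, not part of the original module): MemProfilerHooks accepts either import-path strings
# (as in the defaults above) or callables. The hook names and signatures below are hypothetical; they assume the
# pre/post forward hooks follow the usual `torch.nn.Module` forward (pre-)hook conventions and that the reset hook
# receives the model and the configured `save_hook_attrs` list as noted above. Any attributes a custom hook sets
# should also be added to `MemProfilerCfg.save_hook_attrs` (see below).
def _example_custom_hooks() -> MemProfilerHooks:
    def _pre_forward(module, *args, **kwargs):
        # stash a marker attribute on the module before its forward pass
        module.custom_pre_forward = 0

    def _post_forward(module, *args, **kwargs):
        # stash a marker attribute on the module after its forward pass
        module.custom_post_forward = 0

    def _reset_state(module, save_hook_attrs):
        # zero out all tracked attributes between collection iterations
        for attr in save_hook_attrs:
            if hasattr(module, attr):
                setattr(module, attr, 0)

    return MemProfilerHooks(pre_forward_hooks=[_pre_forward], post_forward_hooks=[_post_forward],
                            reset_state_hooks=[_reset_state])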


@dataclass
class MemProfilerFuncs:  # can specify an arbitrary set of `memprofilable`-decorated function names
    # funcs that will be added to all memory collection types
    default: Set[str] = field(default_factory=lambda: {'training_step'})
    cpu: Set[str] = field(default_factory=set)
    cuda: Set[str] = field(default_factory=set)
    cuda_allocator_history: Set[str] = field(default_factory=set)
    fsdp: Set[str] = field(default_factory=set)
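

# Example (illustrative only, not part of the original module): collect `training_step` for every memory collection
# type by default while additionally collecting CUDA allocator history for a hypothetical `validation_step` (any
# `memprofilable`-decorated function name can be used). The `default` set is composed into the non-default sets by
# `MemProfilerCfg.__post_init__` below.
def _example_collect_funcs() -> MemProfilerFuncs:
    return MemProfilerFuncs(default={'training_step'}, cuda_allocator_history={'validation_step'})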


@dataclass
class MemProfilerSchedule:
    # keeping the schedule as simple as possible for now, may expand to support more flexible schedules in the future
    warmup_iters: int = 1
    max_iter: Optional[int] = None
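

# Example (illustrative only, not part of the original module): with the default `max_iter=None`,
# `MemProfilerCfg.__post_init__` below resolves `max_iter` to `warmup_iters + 1` when profiling is enabled, i.e. a
# single profiled iteration after warmup.
def _example_schedule() -> MemProfilerSchedule:
    return MemProfilerSchedule(warmup_iters=1)  # max_iter resolved to 2 by an enabled MemProfilerCfg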


@dataclass
class MemProfilerCfg:
    """Configuration dataclass for the MemProfiler.

    :param enabled: Whether to enable memory profiling.
    :param collect_funcs: A MemProfilerFuncs instance specifying the functions to profile per memory collection type.
    :param cuda_allocator_history: Whether to collect CUDA memory allocator history.
    :param track_fsdp_mem: Whether to collect FSDP memory statistics.
    :param fsdp_mem_track_module_depth: The depth of FSDP modules to track.
    :param fsdp_mem_tracker_tabulate: Whether to print FSDP memory statistics in a tabular format.
    :param fsdp_mem_tracker_units: The units to use for FSDP memory statistics.
    :param fsdp_mem_tracker_root_module: The root module to use for FSDP memory statistics.
    :param dump_memorystats_pickle: Whether to dump memory statistics to a pickle file.
    :param dump_memorystats_yaml: Whether to dump memory statistics to a yaml file.
    :param schedule: A MemProfilerSchedule instance specifying the schedule for memory collection.
    :param save_dir: The directory in which to save the collected memory statistics.
    :param enable_memory_hooks: Whether to enable memory hooks.
    :param enable_saved_tensors_hooks: Whether to enable saved tensors hooks.
    :param memory_hooks: A MemProfilerHooks instance specifying the memory hooks.
    :param saved_tensors_funcs: A list of pack and unpack functions to use for saved tensors hooks.
    :param save_hook_attrs: A list of module state attributes to save.
    :param retain_hooks_for_funcs: The set of functions for which memory hooks are retained until they reach
        ``max_iter``.
    """
    enabled: bool = False
    # specify funcs to collect per memory collection type, a default set applied to all types, or both composed
    collect_funcs: MemProfilerFuncs = field(default_factory=MemProfilerFuncs)
    cuda_allocator_history: bool = False
    track_fsdp_mem: bool = False
    fsdp_mem_track_module_depth: int = 2
    fsdp_mem_tracker_tabulate: bool = False
    fsdp_mem_tracker_units: str = "MiB"
    fsdp_mem_tracker_root_module: str = ""
    dump_memorystats_pickle: bool = False
    dump_memorystats_yaml: bool = True
    schedule: MemProfilerSchedule = field(default_factory=MemProfilerSchedule)
    save_dir: Optional[Union[str, Path]] = None
    enable_memory_hooks: bool = True
    enable_saved_tensors_hooks: bool = True
    memory_hooks: MemProfilerHooks = field(default_factory=MemProfilerHooks)
    # because an identity function is frequently used for unpacking and to ensure this dataclass remains
    # serializable, the string 'identity_lambda' may be specified and will resolve to `lambda x: x`
    saved_tensors_funcs: List = field(default_factory=lambda: list(('fts_examples.profiling.npp_hooks._npp_hook',
                                                                    'identity_lambda')))
    # if you add custom hooks, make sure to add the desired module state attributes to save to `save_hook_attrs`
    save_hook_attrs: List = field(default_factory=lambda: ["rss_pre_forward", "rss_post_forward", "rss_diff",
                                                           "npp_pre_forward", "npp_post_forward", "npp_diff"])
    # since we cannot reliably ascertain when all MemProfilerFuncs will be executed, memory hooks will
    # only be removed once the funcs in this set have reached `max_iter`
    retain_hooks_for_funcs: Set[str] = field(default_factory=lambda: {'training_step'})

    def __post_init__(self) -> None:
        if not self.enabled:
            return
        if not torch.cuda.is_available() and any((self.collect_funcs.cuda_allocator_history, self.collect_funcs.cuda,
                                                  self.cuda_allocator_history)):
            rank_zero_warn("Disabling CUDA memory profiling functionality since no CUDA device detected.")
            self.collect_funcs.cuda, self.collect_funcs.cuda_allocator_history = set(), set()
            self.cuda_allocator_history = False
        has_hooks = any(getattr(self.memory_hooks, ht.name) for ht in fields(self.memory_hooks))
        if not has_hooks:
            rank_zero_warn("MemProfilerCfg is configured to enable memory hooks but MemProfilerHooks does not have"
                           " any specified.")
        if self.schedule.max_iter is None:
            self.schedule.max_iter = self.schedule.warmup_iters + 1
        # compose all non-default func sets with the default set
        default_funcs = self.collect_funcs.default
        for k in self.collect_funcs.__dataclass_fields__.keys():
            if k != 'default':
                getattr(self.collect_funcs, k).update(default_funcs)
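

# Example (illustrative only, not part of the original module): a minimal MemProfilerCfg sketch that enables
# profiling, collects CUDA allocator history and saves results to a hypothetical `./memprofiler` directory. The
# non-default field values below are assumptions chosen purely for demonstration.
def _example_memprofiler_cfg() -> MemProfilerCfg:
    return MemProfilerCfg(
        enabled=True,
        collect_funcs=MemProfilerFuncs(default={'training_step'}),
        cuda_allocator_history=True,
        schedule=MemProfilerSchedule(warmup_iters=2, max_iter=5),
        save_dir="./memprofiler",
    )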