diff --git a/pyproject.toml b/pyproject.toml index 6d973aa0dde51..3ba8409b6ebf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,6 @@ module = [ "pytorch_lightning.demos.boring_classes", "pytorch_lightning.demos.mnist_datamodule", "pytorch_lightning.loggers.neptune", - "pytorch_lightning.profilers.base", "pytorch_lightning.profilers.pytorch", "pytorch_lightning.profilers.simple", "pytorch_lightning.strategies.ddp", diff --git a/src/pytorch_lightning/profilers/base.py b/src/pytorch_lightning/profilers/base.py index b91f628013a33..1d3b670207f70 100644 --- a/src/pytorch_lightning/profilers/base.py +++ b/src/pytorch_lightning/profilers/base.py @@ -13,7 +13,8 @@ # limitations under the License. """Profiler to check if there are any bottlenecks in your code.""" from abc import ABC, abstractmethod -from typing import Any +from pathlib import Path +from typing import Any, Optional, Union from pytorch_lightning.profilers.profiler import Profiler from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation @@ -57,11 +58,11 @@ class BaseProfiler(Profiler): Please use `Profiler` instead. """ - def __init__(self, *args, **kwargs): + def __init__(self, dirpath: Optional[Union[str, Path]], filename: Optional[str]): rank_zero_deprecation( "`BaseProfiler` was deprecated in v1.6 and will be removed in v1.8. Please use `Profiler` instead." ) - super().__init__(*args, **kwargs) + super().__init__(dirpath=dirpath, filename=filename) class PassThroughProfiler(Profiler):