Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ before_test:

# to run your custom scripts instead of automatic tests
test_script:
- py.test pytorch_lightning tests pl_examples -v --doctest-modules --flake8
- coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules --flake8
#- python setup.py sdist
#- twine check dist/*

Expand Down
30 changes: 16 additions & 14 deletions pytorch_lightning/logging/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import ABC
from functools import wraps


def rank_zero_only(fn):
"""Decorate a logger method to run it only on the process with rank 0
"""Decorate a logger method to run it only on the process with rank 0.

:param fn: Function to decorate
"""
Expand All @@ -15,12 +16,16 @@ def wrapped_fn(self, *args, **kwargs):
return wrapped_fn


class LightningLoggerBase(object):
"""Base class for experiment loggers"""
class LightningLoggerBase(ABC):
"""Base class for experiment loggers."""

def __init__(self):
self._rank = 0

@property
def experiment(self):
raise NotImplementedError()

def log_metrics(self, metrics, step):
"""Record metrics.

Expand All @@ -30,46 +35,43 @@ def log_metrics(self, metrics, step):
raise NotImplementedError()

def log_hyperparams(self, params):
"""Record hyperparameters
"""Record hyperparameters.

:param params: argparse.Namespace containing the hyperparameters
"""
raise NotImplementedError()

def save(self):
"""Save log data"""
"""Save log data."""
pass

def finalize(self, status):
"""Do any processing that is necessary to finalize an experiment
"""Do any processing that is necessary to finalize an experiment.

:param status: Status that the experiment finished with (e.g. success, failed, aborted)
"""
pass

def close(self):
"""Do any cleanup that is necessary to close an experiment"""
"""Do any cleanup that is necessary to close an experiment."""
pass

@property
def rank(self):
"""
Process rank. In general, metrics should only be logged by the process
with rank 0
"""
"""Process rank. In general, metrics should only be logged by the process with rank 0."""
return self._rank

@rank.setter
def rank(self, value):
"""Set the process rank"""
"""Set the process rank."""
self._rank = value

@property
def name(self):
"""Return the experiment name"""
"""Return the experiment name."""
raise NotImplementedError("Sub-classes must provide a name property")

@property
def version(self):
"""Return the experiment version"""
"""Return the experiment version."""
raise NotImplementedError("Sub-classes must provide a version property")
4 changes: 2 additions & 2 deletions pytorch_lightning/logging/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def any_lightning_module_function_or_hook(...):
class CometLogger(LightningLoggerBase):
def __init__(self, api_key=None, save_dir=None, workspace=None,
rest_api_key=None, project_name=None, experiment_name=None, **kwargs):
"""
Initialize a Comet.ml logger. Requires either an API Key (online mode) or a local directory path (offline mode)
"""Initialize a Comet.ml logger.
Requires either an API Key (online mode) or a local directory path (offline mode)

:param str api_key: Required in online mode. API key, found on Comet.ml
:param str save_dir: Required in offline mode. The path for the directory to save local comet logs
Expand Down
22 changes: 13 additions & 9 deletions pytorch_lightning/logging/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,29 @@ def any_lightning_module_function_or_hook(...):
class MLFlowLogger(LightningLoggerBase):
def __init__(self, experiment_name, tracking_uri=None, tags=None):
super().__init__()
self.experiment = mlflow.tracking.MlflowClient(tracking_uri)
self._mlflow_client = mlflow.tracking.MlflowClient(tracking_uri)
self.experiment_name = experiment_name
self._run_id = None
self.tags = tags

@property
def experiment(self):
return self._mlflow_client

@property
def run_id(self):
if self._run_id is not None:
return self._run_id

experiment = self.experiment.get_experiment_by_name(self.experiment_name)
if experiment is None:
logger.warning(
f"Experiment with name f{self.experiment_name} not found. Creating it."
)
self.experiment.create_experiment(self.experiment_name)
experiment = self.experiment.get_experiment_by_name(self.experiment_name)
expt = self._mlflow_client.get_experiment_by_name(self.experiment_name)

if expt:
self._expt_id = expt.experiment_id
else:
logger.warning(f"Experiment with name f{self.experiment_name} not found. Creating it.")
self._expt_id = self._mlflow_client.create_experiment(name=self.experiment_name)

run = self.experiment.create_run(experiment.experiment_id, tags=self.tags)
run = self._mlflow_client.create_run(experiment_id=self._expt_id, tags=self.tags)
self._run_id = run.info.run_id
return self._run_id

Expand Down
12 changes: 4 additions & 8 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,13 @@ def test_mlflow_logger(tmpdir):
model = LightningTestModel(hparams)

mlflow_dir = os.path.join(tmpdir, "mlruns")

logger = MLFlowLogger("test", f"file://{mlflow_dir}")
logger = MLFlowLogger("test", tracking_uri=f"file:{os.sep * 2}{mlflow_dir}")

trainer_options = dict(
max_epochs=1,
train_percent_check=0.01,
logger=logger
)

trainer = Trainer(**trainer_options)
result = trainer.fit(model)

Expand All @@ -88,13 +86,11 @@ def test_mlflow_pickle(tmpdir):
except ModuleNotFoundError:
return

hparams = tutils.get_hparams()
model = LightningTestModel(hparams)
# hparams = tutils.get_hparams()
# model = LightningTestModel(hparams)

mlflow_dir = os.path.join(tmpdir, "mlruns")

logger = MLFlowLogger("test", f"file://{mlflow_dir}")

logger = MLFlowLogger("test", tracking_uri=f"file:{os.sep * 2}{mlflow_dir}")
trainer_options = dict(
max_epochs=1,
logger=logger
Expand Down