.readthedocs.yml (2 changes: 1 addition & 1 deletion)
@@ -16,4 +16,4 @@ formats: all
 python:
   version: 3.7
   install:
-    - requirements: docs/doc_requirements.txt
+    - requirements: docs/requirements.txt
.travis.yml (3 changes: 2 additions & 1 deletion)
@@ -6,11 +6,12 @@ cache: pip
 install:
   - pip install -e .
   - pip install -r requirements.txt
+  - pip install -r tests/requirements.txt
   - pip install -U numpy
 
 # keep build from timing out
 dist: xenial
 
 # command to run tests
 script:
-  - py.test # or py.test for Python versions 3.5 and below
+  - py.test -v # or py.test for Python versions 3.5 and below
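For reference, the verbose run configured above can also be launched from Python through pytest's documented entry point; a minimal sketch equivalent to the "py.test -v" command:

    import pytest

    # same as running "py.test -v" from the shell; returns an exit code
    exit_code = pytest.main(['-v'])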
docs/doc_requirements.txt → docs/requirements.txt (1 change: 1 addition & 0 deletions)
@@ -1 +1,2 @@
+mkdocs-material==4.4.0
 mkdocs==1.0.4
File renamed without changes.
(file name not captured in this diff view)
@@ -14,17 +14,7 @@
 torch.manual_seed(SEED)
 np.random.seed(SEED)
 
-# ---------------------
-# DEFINE MODEL HERE
-# ---------------------
-from lightning_module_template import LightningTemplateModel
-# ---------------------
-
-"""
-Allows training by using command line arguments
-Run by:
-# TYPE YOUR RUN COMMAND HERE
-"""
+from .lightning_module_template import LightningTemplateModel
 
 
 def main_local(hparams):
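The hunk above swaps a bare module import for an explicit relative one. A minimal sketch of why that matters, assuming a hypothetical package layout like the one these templates live in:

    # Hypothetical layout (names illustrative):
    #   new_project_templates/
    #       __init__.py
    #       lightning_module_template.py
    #       cpu_template.py          # <- the file being patched
    #
    # The bare form only resolves when the template's own directory is on
    # sys.path, i.e. when the file is run as a loose script:
    #   from lightning_module_template import LightningTemplateModel
    #
    # The explicit relative form resolves against the enclosing package, so
    # it keeps working once the package is pip-installed:
    from .lightning_module_template import LightningTemplateModel

The same one-line change is applied to the remaining templates below.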
(file name not captured in this diff view)
@@ -13,7 +13,7 @@
 torch.manual_seed(SEED)
 np.random.seed(SEED)
 
-from lightning_module_template import LightningTemplateModel
+from .lightning_module_template import LightningTemplateModel
 
 
 def main(hparams):
(file name not captured in this diff view)
@@ -13,7 +13,7 @@
 torch.manual_seed(SEED)
 np.random.seed(SEED)
 
-from lightning_module_template import LightningTemplateModel
+from .lightning_module_template import LightningTemplateModel
 
 
 def main(hparams):
(file name not captured in this diff view)
@@ -13,7 +13,7 @@
 torch.manual_seed(SEED)
 np.random.seed(SEED)
 
-from lightning_module_template import LightningTemplateModel
+from .lightning_module_template import LightningTemplateModel
 
 
 def main(hparams):
(file name not captured in this diff view)
@@ -13,7 +13,7 @@
 torch.manual_seed(SEED)
 np.random.seed(SEED)
 
-from lightning_module_template import LightningTemplateModel
+from .lightning_module_template import LightningTemplateModel
 
 
 def main(hparams):
(file name not captured in this diff view)
@@ -3,9 +3,10 @@
 
 from test_tube import HyperOptArgumentParser, Experiment
 from pytorch_lightning.models.trainer import Trainer
-from pytorch_lightning.utils.arg_parse import add_default_args
+from pytorch_lightning.utilities.arg_parse import add_default_args
 from pytorch_lightning.callbacks.pt_callbacks import EarlyStopping, ModelCheckpoint
-from lightning_module_template import LightningTemplateModel
+
+from .lightning_module_template import LightningTemplateModel
 
 
 def main(hparams):
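Note the second rename folded into the hunk above: pytorch_lightning.utils becomes pytorch_lightning.utilities. A minimal usage sketch of add_default_args under the new path (argument values are illustrative; the signature is the one shown later in this diff):

    from test_tube import HyperOptArgumentParser

    from pytorch_lightning.utilities.arg_parse import add_default_args

    parser = HyperOptArgumentParser(strategy='grid_search')
    # add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None)
    add_default_args(parser, root_dir='/tmp/runs', rand_seed=2334)
    hparams = parser.parse_args()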
pytorch_lightning/callbacks/pt_callbacks.py (7 changes: 4 additions & 3 deletions)
@@ -1,7 +1,9 @@
-import numpy as np
 import os
 import shutil
-from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel
+
+import numpy as np
+
+from ..pt_overrides.override_data_parallel import LightningDistributedDataParallel
 
 
 class Callback(object):
@@ -261,4 +263,3 @@ def on_epoch_end(self, epoch, logs=None):
         print(loss)
         if should_stop:
             break
-
pytorch_lightning/models/trainer.py (12 changes: 6 additions & 6 deletions)
@@ -8,17 +8,17 @@
 import pdb
 import re
 
+import numpy as np
+import tqdm
 import torch
 from torch.utils.data.distributed import DistributedSampler
 import torch.multiprocessing as mp
 import torch.distributed as dist
-import numpy as np
-import tqdm
 
-from pytorch_lightning.root_module.memory import get_gpu_memory_map
-from pytorch_lightning.root_module.model_saving import TrainerIO
-from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel, LightningDataParallel
-from pytorch_lightning.utils.debugging import MisconfigurationException
+from ..root_module.memory import get_gpu_memory_map
+from ..root_module.model_saving import TrainerIO
+from ..pt_overrides.override_data_parallel import LightningDistributedDataParallel, LightningDataParallel
+from ..utilities.debugging import MisconfigurationException
 
 try:
     from apex import amp
pytorch_lightning/root_module/grads.py (3 changes: 1 addition & 2 deletions)
@@ -1,9 +1,8 @@
-from torch import nn
-
 """
 Module to describe gradients
 """
 
+from torch import nn
 
 class GradInformation(nn.Module):
 
pytorch_lightning/root_module/memory.py (12 changes: 6 additions & 6 deletions)
@@ -1,15 +1,15 @@
-import torch
+'''
+Generates a summary of a model's layers and dimensionality
+'''
+
 import gc
+
+import torch
 import subprocess
 import numpy as np
 import pandas as pd
 
 
-'''
-Generates a summary of a model's layers and dimensionality
-'''
-
-
 class ModelSummary(object):
 
     def __init__(self, model):
pytorch_lightning/root_module/model_saving.py (6 changes: 4 additions & 2 deletions)
@@ -1,7 +1,9 @@
-import torch
 import os
 import re
-from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel, LightningDataParallel
+
+import torch
+
+from ..pt_overrides.override_data_parallel import LightningDistributedDataParallel, LightningDataParallel
 
 
 class ModelIO(object):
pytorch_lightning/root_module/root_module.py (11 changes: 6 additions & 5 deletions)
@@ -1,9 +1,10 @@
 import torch
-from pytorch_lightning.root_module.memory import ModelSummary
-from pytorch_lightning.root_module.grads import GradInformation
-from pytorch_lightning.root_module.model_saving import ModelIO, load_hparams_from_tags_csv
-from pytorch_lightning.root_module.hooks import ModelHooks
-from pytorch_lightning.root_module.decorators import data_loader
+
+from .memory import ModelSummary
+from .grads import GradInformation
+from .model_saving import ModelIO, load_hparams_from_tags_csv
+from .hooks import ModelHooks
+from .decorators import data_loader
 
 
 class LightningModule(GradInformation, ModelIO, ModelHooks):
(file name not captured in this diff view)
@@ -1,17 +1,18 @@
 import os
 from collections import OrderedDict
-import torch.nn as nn
-from torchvision.datasets import MNIST
-import torchvision.transforms as transforms
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
+from test_tube import HyperOptArgumentParser
 from torch import optim
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
+from torchvision.datasets import MNIST
+from torchvision import transforms
-from test_tube import HyperOptArgumentParser
 
-from pytorch_lightning.root_module.root_module import LightningModule
-import pytorch_lightning as ptl
+from ..root_module.root_module import LightningModule
+from pytorch_lightning import data_loader
 
 
 class LightningTestModel(LightningModule):
@@ -214,15 +215,15 @@ def __dataloader(self, train):

         return loader
 
-    @ptl.data_loader
+    @data_loader
     def tng_dataloader(self):
         return self.__dataloader(train=True)
 
-    @ptl.data_loader
+    @data_loader
     def val_dataloader(self):
         return self.__dataloader(train=False)
 
-    @ptl.data_loader
+    @data_loader
     def test_dataloader(self):
         return self.__dataloader(train=False)
 
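The decorator swap above is purely an import-style change: data_loader is now imported directly instead of being reached through the ptl alias. A condensed sketch of the resulting pattern, using the imports this diff introduces:

    from pytorch_lightning import data_loader
    from pytorch_lightning.root_module.root_module import LightningModule

    class LightningTestModel(LightningModule):
        # __dataloader is the private helper defined earlier in the real
        # file; body elided here
        def __dataloader(self, train):
            ...

        @data_loader                     # was: @ptl.data_loader
        def tng_dataloader(self):
            return self.__dataloader(train=True)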
(file name not captured in this diff view)
@@ -1,4 +1,7 @@
-import pdb
+"""
+List of default args which might be useful for all the available flags
+Might need to update with the new flags
+"""
 
 def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None):
 
@@ -73,4 +76,4 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None):
     parser.add_argument('--local', dest='local', action='store_true', help='enables local tng')
 
     # optimizer
-    parser.add_argument('--lr_scheduler_milestones', default=None, type=str)
\ No newline at end of file
+    parser.add_argument('--lr_scheduler_milestones', default=None, type=str)
(file name not captured in this diff view)
@@ -1,5 +1,2 @@
-import pdb
-import sys
-
 class MisconfigurationException(Exception):
-    pass
\ No newline at end of file
+    pass
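MisconfigurationException stays a bare Exception subclass; the hunk only drops unused imports and fixes the missing trailing newline. A minimal usage sketch (the guard and message are illustrative, not taken from this PR):

    import torch

    from pytorch_lightning.utilities.debugging import MisconfigurationException

    def check_gpu_request(requested_gpus):
        # raise the library's configuration error instead of a generic one
        if requested_gpus and not torch.cuda.is_available():
            raise MisconfigurationException('GPUs requested but CUDA is not available')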
requirements.txt (3 changes: 0 additions & 3 deletions)
@@ -1,6 +1,3 @@
-coverage==4.5.3
-mkdocs==1.0.4
-pytest==5.0.1
 scikit-learn==0.20.2
 tqdm==4.32.1
 twine==1.13.0
tests/debug.py (7 changes: 1 addition & 6 deletions)
@@ -1,15 +1,10 @@
 import pytest
 from pytorch_lightning import Trainer
-from pytorch_lightning.examples.new_project_templates.lightning_module_template import LightningTemplateModel
+from examples import LightningTemplateModel
 from argparse import Namespace
 from test_tube import Experiment
 from pytorch_lightning.callbacks import ModelCheckpoint
-import numpy as np
-import warnings
 import torch
-import os
-import shutil
-import pdb
 
 import pytorch_lightning as ptl
 import torch
tests/requirements.txt (2 changes: 2 additions & 0 deletions)
@@ -0,0 +1,2 @@
+coverage==4.5.3
+pytest==5.0.1
tests/test_models.py (20 changes: 9 additions & 11 deletions)
@@ -1,20 +1,20 @@
+import os
+import shutil
+import warnings
+
 import pytest
+import numpy as np
+import torch
 from pytorch_lightning import Trainer
-from pytorch_lightning.examples.new_project_templates.lightning_module_template import LightningTemplateModel
-from pytorch_lightning.testing_models.lm_test_module import LightningTestModel
+from examples import LightningTemplateModel
+from pytorch_lightning.testing.lm_test_module import LightningTestModel
 from argparse import Namespace
 from test_tube import Experiment, SlurmCluster
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
-from pytorch_lightning.utils.debugging import MisconfigurationException
+from pytorch_lightning.utilities.debugging import MisconfigurationException
 from pytorch_lightning.root_module import memory
 from pytorch_lightning.models.trainer import reduce_distributed_output
 from pytorch_lightning.root_module import model_saving
-import numpy as np
-import warnings
-import torch
-import os
-import shutil
-import pdb
 
 SEED = 2334
 torch.manual_seed(SEED)
@@ -232,8 +232,6 @@ def test_model_saving_loading():
     clear_save_dir()
 
 
-
-
 def test_model_freeze_unfreeze():
     hparams = get_hparams()
     model = LightningTestModel(hparams)
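The final hunk only tightens blank lines around test_model_freeze_unfreeze. For context, a sketch of what that test presumably exercises, assuming freeze/unfreeze toggle requires_grad on the module's parameters (reusing get_hparams and LightningTestModel from this file):

    hparams = get_hparams()
    model = LightningTestModel(hparams)

    model.freeze()    # expected to disable gradients everywhere
    assert not any(p.requires_grad for p in model.parameters())

    model.unfreeze()  # expected to re-enable them
    assert all(p.requires_grad for p in model.parameters())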