Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions pl_examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,40 @@

TORCHVISION_AVAILABLE = _module_available("torchvision")
DALI_AVAILABLE = _module_available("nvidia.dali")


LIGHTNING_LOGO = """
####
###########
####################
############################
#####################################
##############################################
######################### ###################
####################### ###################
#################### ####################
################## #####################
################ ######################
##################### #################
###################### ###################
##################### #####################
#################### #######################
################### #########################
##############################################
#####################################
############################
####################
##########
####
"""


def nice_print(msg, last=False):
print()
print("\033[0;35m" + msg + "\033[0m")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should try to find Pytorch Lightning color for terminal

if last:
print()


def cli_lightning_logo():
nice_print(LIGHTNING_LOGO)
3 changes: 2 additions & 1 deletion pl_examples/basic_examples/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch.utils.data import random_split

import pytorch_lightning as pl
from pl_examples import TORCHVISION_AVAILABLE
from pl_examples import TORCHVISION_AVAILABLE, cli_lightning_logo

if TORCHVISION_AVAILABLE:
from torchvision.datasets.mnist import MNIST
Expand Down Expand Up @@ -105,4 +105,5 @@ def cli_main():


if __name__ == '__main__':
cli_lightning_logo()
cli_main()
3 changes: 2 additions & 1 deletion pl_examples/basic_examples/backbone_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torch.utils.data import DataLoader, random_split

import pytorch_lightning as pl
from pl_examples import DATASETS_PATH, TORCHVISION_AVAILABLE
from pl_examples import DATASETS_PATH, TORCHVISION_AVAILABLE, cli_lightning_logo

if TORCHVISION_AVAILABLE:
from torchvision.datasets.mnist import MNIST
Expand Down Expand Up @@ -125,4 +125,5 @@ def cli_main():


if __name__ == '__main__':
cli_lightning_logo()
cli_main()
2 changes: 2 additions & 0 deletions pl_examples/basic_examples/conv_sequential_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import torchvision

import pytorch_lightning as pl
from pl_examples import cli_lightning_logo
from pytorch_lightning import Trainer
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin
Expand Down Expand Up @@ -190,6 +191,7 @@ def instantiate_datamodule(args):


if __name__ == "__main__":
cli_lightning_logo()
parser = ArgumentParser(description="Pipe Example")
parser.add_argument("--use_ddp_sequential", action="store_true")
parser = Trainer.add_argparse_args(parser)
Expand Down
3 changes: 2 additions & 1 deletion pl_examples/basic_examples/dali_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch.utils.data import random_split

import pytorch_lightning as pl
from pl_examples import TORCHVISION_AVAILABLE, DALI_AVAILABLE
from pl_examples import TORCHVISION_AVAILABLE, DALI_AVAILABLE, cli_lightning_logo

if TORCHVISION_AVAILABLE:
from torchvision.datasets.mnist import MNIST
Expand Down Expand Up @@ -204,4 +204,5 @@ def cli_main():


if __name__ == "__main__":
cli_lightning_logo()
cli_main()
2 changes: 2 additions & 0 deletions pl_examples/basic_examples/simple_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torch.nn import functional as F

import pytorch_lightning as pl
from pl_examples import cli_lightning_logo
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule


Expand Down Expand Up @@ -103,4 +104,5 @@ def cli_main():


if __name__ == '__main__':
cli_lightning_logo()
cli_main()
3 changes: 3 additions & 0 deletions pl_examples/bug_report_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import os
import torch
from torch.utils.data import Dataset

from pl_examples import cli_lightning_logo
from pytorch_lightning import Trainer, LightningModule


Expand Down Expand Up @@ -137,4 +139,5 @@ def on_train_epoch_start(self) -> None:


if __name__ == '__main__':
cli_lightning_logo()
run_test()
15 changes: 15 additions & 0 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computer vision example on Transfer Learning.

This computer vision example illustrates how one could fine-tune a pre-trained
Expand Down Expand Up @@ -40,6 +53,7 @@
from torchvision.datasets.utils import download_and_extract_archive

import pytorch_lightning as pl
from pl_examples import cli_lightning_logo
from pytorch_lightning import _logger as log

BN_TYPES = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)
Expand Down Expand Up @@ -451,4 +465,5 @@ def get_args() -> argparse.Namespace:


if __name__ == '__main__':
cli_lightning_logo()
main(get_args())
15 changes: 15 additions & 0 deletions pl_examples/domain_templates/generative_adversarial_net.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this template just do:
python generative_adversarial_net.py
Expand All @@ -18,6 +31,7 @@
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from pl_examples import cli_lightning_logo
from pytorch_lightning.core import LightningModule, LightningDataModule
from pytorch_lightning.trainer import Trainer

Expand Down Expand Up @@ -211,6 +225,7 @@ def main(args: Namespace) -> None:


if __name__ == '__main__':
cli_lightning_logo()
parser = ArgumentParser()

# Add program level args, if any.
Expand Down
15 changes: 15 additions & 0 deletions pl_examples/domain_templates/imagenet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py

Expand Down Expand Up @@ -32,6 +45,7 @@
import torchvision.transforms as transforms

import pytorch_lightning as pl
from pl_examples import cli_lightning_logo
from pytorch_lightning.core import LightningModule


Expand Down Expand Up @@ -246,4 +260,5 @@ def run_cli():


if __name__ == '__main__':
cli_lightning_logo()
run_cli()
15 changes: 15 additions & 0 deletions pl_examples/domain_templates/reinforce_learn_Qnet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Deep Reinforcement Learning: Deep Q-network (DQN)

Expand Down Expand Up @@ -33,6 +46,7 @@
from torch.utils.data.dataset import IterableDataset

import pytorch_lightning as pl
from pl_examples import cli_lightning_logo


class DQN(nn.Module):
Expand Down Expand Up @@ -349,6 +363,7 @@ def main(args) -> None:


if __name__ == '__main__':
cli_lightning_logo()
torch.manual_seed(0)
np.random.seed(0)

Expand Down
16 changes: 16 additions & 0 deletions pl_examples/domain_templates/semantic_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
from argparse import ArgumentParser, Namespace
Expand All @@ -10,6 +24,7 @@
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pl_examples import cli_lightning_logo
from pl_examples.domain_templates.unet import UNet
from pytorch_lightning.loggers import WandbLogger

Expand Down Expand Up @@ -225,6 +240,7 @@ def main(hparams: Namespace):


if __name__ == '__main__':
cli_lightning_logo()
parser = ArgumentParser()
parser.add_argument("--data_path", type=str, help="path where dataset is stored")
parser.add_argument("--gpus", type=int, default=-1, help="number of available GPUs")
Expand Down
14 changes: 14 additions & 0 deletions pl_examples/domain_templates/unet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down
13 changes: 13 additions & 0 deletions pl_examples/pytorch_ecosystem/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
38 changes: 0 additions & 38 deletions pl_examples/pytorch_ecosystem/pytorch_geometric/README.md

This file was deleted.

Empty file.
Loading