Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ exclude tests
recursive-exclude docs *
exclude docs
recursive-include docs/source/_images/logos/ *
recursive-include docs/source/_images/badges/ *
recursive-include docs/source/_images/general/ pl_overview* tf_* tutorial_* PTL101_*

# Include the Requirements
Expand Down
11 changes: 4 additions & 7 deletions pytorch_lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@
"""

import logging as python_logging
import os

_logger = python_logging.getLogger("lightning")
_logger.addHandler(python_logging.StreamHandler())
_logger.setLevel(python_logging.INFO)

PACKAGE_ROOT = os.path.dirname(__file__)
PROJECT_ROOT = os.path.dirname(PACKAGE_ROOT)

try:
# This variable is injected in the __builtins__ by the build
# process. It used to enable importing subpackages of skimage when
Expand Down Expand Up @@ -68,12 +72,5 @@
'metrics',
]

# necessary for regular bolts imports. Skip exception since bolts is not always installed
try:
from pytorch_lightning import bolts
except ImportError:
pass
# __call__ = __all__

# for compatibility with namespace packages
__import__('pkg_resources').declare_namespace(__name__)
166 changes: 166 additions & 0 deletions pytorch_lightning/setup_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
#!/usr/bin/env python
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import warnings
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen

from pytorch_lightning import PROJECT_ROOT, __homepage__, __version__

_PATH_BADGES = os.path.join('.', 'docs', 'source', '_images', 'badges')
# badge to download
_DEFAULT_BADGES = [
'PyPI - Python Version',
'PyPI Status',
'PyPI Status',
'Conda',
'DockerHub',
'codecov',
'ReadTheDocs',
'Slack',
'Discourse status',
'license',
'Next Release'
]


def _load_requirements(path_dir, file_name='requirements.txt', comment_char='#'):
"""Load requirements from a file

>>> _load_requirements(PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['numpy...', 'torch...', ...]
"""
with open(os.path.join(path_dir, file_name), 'r') as file:
lines = [ln.strip() for ln in file.readlines()]
reqs = []
for ln in lines:
# filer all comments
if comment_char in ln:
ln = ln[:ln.index(comment_char)].strip()
# skip directly installed dependencies
if ln.startswith('http'):
continue
if ln: # if requirement is not empty
reqs.append(ln)
return reqs


def _parse_for_badge(text: str, path_badges: str = _PATH_BADGES, badge_names: list = _DEFAULT_BADGES):
""" Returns the new parsed text with url change with local downloaded files

>>> _parse_for_badge('Some text here... ' # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
... '[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pytorch-lightning)]'
... '(https://pypi.org/project/pytorch-lightning/) and another text later')
'Some text here...
[![PyPI - Python Version](...docs...source..._images...badges...PyPI_Python_Version_badge.png)](https://pypi.org/project/pytorch-lightning/)
and another text later'
>>> import shutil
>>> shutil.rmtree(_PATH_BADGES)
"""
for line in text.split(os.linesep):
search_string = r'\[\!\[(.*)]\((.*)\)]'
match = re.search(search_string, line)
if match is None:
continue

badge_name, badge_url = match.groups()
# check if valid name
if badge_name not in badge_names:
continue

# download badge
saved_badge_name = _download_badge(badge_url, badge_name, path_badges)

# replace url with local file path
text = text.replace(f'[![{badge_name}]({badge_url})]', f'[![{badge_name}]({saved_badge_name})]')

return text


def _save_file(url_badge, save, extension, headers):
"""function for saving the badge either in `.png` or `.svg`"""

# because there are two badge with name `PyPI Status` the second one is download
if 'https://pepy.tech/badge/pytorch-lightning' in url_badge:
save += '_downloads'

try:
req = Request(url=url_badge, headers=headers)
resp = urlopen(req)
except URLError:
warnings.warn("Error while downloading the badge", UserWarning)
else:
save += extension
with open(save, 'wb') as download_file:
download_file.write(resp.read())


def _download_badge(url_badge, badge_name, target_dir):
"""Download badge from url

>>> path_img = _download_badge('https://img.shields.io/pypi/pyversions/pytorch-lightning',
... 'PyPI - Python Version', '.')
>>> os.path.isfile(path_img)
True
>>> path_img # doctest: +ELLIPSIS
'...PyPI_Python_Version_badge.png'
>>> os.remove(path_img)
"""
os.makedirs(target_dir, exist_ok=True)

headers = {
'User-Agent': 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:81.0) Gecko/20100101 Firefox/81.0',
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/svg,*/*;q=0.8',
}

save_path = badge_name.replace(' - ', ' ')
save_path = os.path.join(target_dir, f"{save_path.replace(' ', '_')}_badge")

if "?" in url_badge and ".png" not in url_badge:
_save_file(url_badge, save_path, extension='.svg', headers=headers)
return save_path + '.svg'
else:
try:
# always try to download the png versions (some url have an already png version available)
_save_file(url_badge, save_path, extension='.png', headers=headers)
return save_path + '.png'
except HTTPError as err:
if err.code == 404:
# save the `.svg`
url_badge = url_badge.replace('.png', '.svg')
_save_file(url_badge, save_path, extension='.svg', headers=headers)
return save_path + '.svg'


def _load_long_description(path_dir):
"""Load readme as decribtion

>>> _load_long_description(PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
'<div align="center">...'
>>> import shutil
>>> shutil.rmtree(_PATH_BADGES)
"""
# https://github.com/PyTorchLightning/pytorch-lightning/raw/master/docs/source/_images/lightning_module/pt_to_pl.png
url = os.path.join(__homepage__, 'raw', __version__, 'docs')
path_readme = os.path.join(path_dir, 'README.md')
text = open(path_readme, encoding='utf-8').read()
# replace relative repository path to absolute link to the release
text = text.replace('](docs', f']({url}')
# SVG images are not readable on PyPI, so replace them with PNG
text = text.replace('.svg', '.png')
# download badge and replace url with local file
text = _parse_for_badge(text)
return text
44 changes: 8 additions & 36 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import os
from io import open

# Always prefer setuptools over distutils
from setuptools import find_packages, setup
Expand All @@ -30,45 +29,18 @@
builtins.__LIGHTNING_SETUP__ = True

import pytorch_lightning # noqa: E402


def load_requirements(path_dir=PATH_ROOT, file_name='requirements.txt', comment_char='#'):
with open(os.path.join(path_dir, file_name), 'r') as file:
lines = [ln.strip() for ln in file.readlines()]
reqs = []
for ln in lines:
# filer all comments
if comment_char in ln:
ln = ln[:ln.index(comment_char)].strip()
# skip directly installed dependencies
if ln.startswith('http'):
continue
if ln: # if requirement is not empty
reqs.append(ln)
return reqs


def load_long_description():
# https://github.com/PyTorchLightning/pytorch-lightning/raw/master/docs/source/_images/lightning_module/pt_to_pl.png
url = os.path.join(pytorch_lightning.__homepage__, 'raw', pytorch_lightning.__version__, 'docs')
text = open('README.md', encoding='utf-8').read()
# replace relative repository path to absolute link to the release
text = text.replace('](docs', f']({url}')
# SVG images are not readable on PyPI, so replace them with PNG
text = text.replace('.svg', '.png')
return text

from pytorch_lightning.setup_tools import _load_long_description, _load_requirements # noqa: E402

# https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras
# Define package extras. These are only installed if you specify them.
# From remote, use like `pip install pytorch-lightning[dev, docs]`
# From local copy of repo, use like `pip install ".[dev, docs]"`
extras = {
# 'docs': load_requirements(file_name='docs.txt'),
'examples': load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='examples.txt'),
'loggers': load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='loggers.txt'),
'extra': load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='extra.txt'),
'test': load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='test.txt')
'examples': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='examples.txt'),
'loggers': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='loggers.txt'),
'extra': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='extra.txt'),
'test': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='test.txt')
}
extras['dev'] = extras['extra'] + extras['loggers'] + extras['test']
extras['all'] = extras['dev'] + extras['examples'] # + extras['docs']
Expand All @@ -89,7 +61,7 @@ def load_long_description():
# the goal of the project is simplicity for researchers, don't want to add too much
# engineer specific practices
setup(
name='pytorch-lightning',
name="pytorch-lightning",
version=pytorch_lightning.__version__,
description=pytorch_lightning.__docs__,
author=pytorch_lightning.__author__,
Expand All @@ -99,15 +71,15 @@ def load_long_description():
license=pytorch_lightning.__license__,
packages=find_packages(exclude=['tests', 'tests/*', 'benchmarks']),

long_description=load_long_description(),
long_description=_load_long_description(PATH_ROOT),
long_description_content_type='text/markdown',
include_package_data=True,
zip_safe=False,

keywords=['deep learning', 'pytorch', 'AI'],
python_requires='>=3.6',
setup_requires=[],
install_requires=load_requirements(),
install_requires=_load_requirements(PATH_ROOT),
extras_require=extras,

project_urls={
Expand Down