Skip to content

Commit f53ecd9

Browse files
authored
Merge pull request #555 from patrickvonplaten/add_tf_hub
Proposal to integrate into 🤗 Hub
2 parents ba46b47 + f4efa38 commit f53ecd9

File tree

13 files changed

+117
-3
lines changed

13 files changed

+117
-3
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"tensorflow-gpu==2.3.1",
2626
"tensorflow-addons>=0.10.0",
2727
"setuptools>=38.5.1",
28+
"huggingface_hub==0.0.8",
2829
"librosa>=0.7.0",
2930
"soundfile>=0.10.2",
3031
"matplotlib>=3.1.0",

tensorflow_tts/inference/auto_config.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818
import yaml
19+
import os
1920
from collections import OrderedDict
2021

2122
from tensorflow_tts.configs import (
@@ -28,6 +29,10 @@
2829
ParallelWaveGANGeneratorConfig,
2930
)
3031

32+
from tensorflow_tts.utils import CACHE_DIRECTORY, CONFIG_FILE_NAME, LIBRARY_NAME
33+
from tensorflow_tts import __version__ as VERSION
34+
from huggingface_hub import hf_hub_url, cached_download
35+
3136
CONFIG_MAPPING = OrderedDict(
3237
[
3338
("fastspeech", FastSpeechConfig),
@@ -50,6 +55,20 @@ def __init__(self):
5055

5156
@classmethod
5257
def from_pretrained(cls, pretrained_path, **kwargs):
58+
# load weights from hf hub
59+
if not os.path.isfile(pretrained_path):
60+
# retrieve correct hub url
61+
download_url = hf_hub_url(repo_id=pretrained_path, filename=CONFIG_FILE_NAME)
62+
63+
pretrained_path = str(
64+
cached_download(
65+
url=download_url,
66+
library_name=LIBRARY_NAME,
67+
library_version=VERSION,
68+
cache_dir=CACHE_DIRECTORY,
69+
)
70+
)
71+
5372
with open(pretrained_path) as f:
5473
config = yaml.load(f, Loader=yaml.SafeLoader)
5574

tensorflow_tts/inference/auto_model.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import logging
1818
import warnings
19+
import os
20+
1921
from collections import OrderedDict
2022

2123
from tensorflow_tts.configs import (
@@ -40,6 +42,9 @@
4042
SavableTFFastSpeech2,
4143
SavableTFTacotron2
4244
)
45+
from tensorflow_tts.utils import CACHE_DIRECTORY, MODEL_FILE_NAME, LIBRARY_NAME
46+
from tensorflow_tts import __version__ as VERSION
47+
from huggingface_hub import hf_hub_url, cached_download
4348

4449

4550
TF_MODEL_MAPPING = OrderedDict(
@@ -62,8 +67,35 @@ def __init__(self):
6267
raise EnvironmentError("Cannot be instantiated using `__init__()`")
6368

6469
@classmethod
65-
def from_pretrained(cls, config, pretrained_path=None, **kwargs):
70+
def from_pretrained(cls, config=None, pretrained_path=None, **kwargs):
6671
is_build = kwargs.pop("is_build", True)
72+
73+
# load weights from hf hub
74+
if pretrained_path is not None:
75+
if not os.path.isfile(pretrained_path):
76+
# retrieve correct hub url
77+
download_url = hf_hub_url(repo_id=pretrained_path, filename=MODEL_FILE_NAME)
78+
79+
downloaded_file = str(
80+
cached_download(
81+
url=download_url,
82+
library_name=LIBRARY_NAME,
83+
library_version=VERSION,
84+
cache_dir=CACHE_DIRECTORY,
85+
)
86+
)
87+
88+
# load config from repo as well
89+
if config is None:
90+
from tensorflow_tts.inference import AutoConfig
91+
92+
config = AutoConfig.from_pretrained(pretrained_path)
93+
94+
pretraine_path = downloaded_file
95+
96+
97+
assert config is not None, "Please make sure to pass a config along to load a model from a local file"
98+
6799
for config_class, model_class in TF_MODEL_MAPPING.items():
68100
if isinstance(config, config_class) and str(config_class.__name__) in str(
69101
config
@@ -79,6 +111,7 @@ def from_pretrained(cls, config, pretrained_path=None, **kwargs):
79111
pretrained_path, by_name=True, skip_mismatch=True
80112
)
81113
return model
114+
82115
raise ValueError(
83116
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
84117
"Model type should be one of {}.".format(

tensorflow_tts/inference/auto_processor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818
import json
19+
import os
1920
from collections import OrderedDict
2021

2122
from tensorflow_tts.processor import (
@@ -26,6 +27,10 @@
2627
ThorstenProcessor,
2728
)
2829

30+
from tensorflow_tts.utils import CACHE_DIRECTORY, PROCESSOR_FILE_NAME, LIBRARY_NAME
31+
from tensorflow_tts import __version__ as VERSION
32+
from huggingface_hub import hf_hub_url, cached_download
33+
2934
CONFIG_MAPPING = OrderedDict(
3035
[
3136
("LJSpeechProcessor", LJSpeechProcessor),
@@ -46,6 +51,19 @@ def __init__(self):
4651

4752
@classmethod
4853
def from_pretrained(cls, pretrained_path, **kwargs):
54+
# load weights from hf hub
55+
if not os.path.isfile(pretrained_path):
56+
# retrieve correct hub url
57+
download_url = hf_hub_url(repo_id=pretrained_path, filename=PROCESSOR_FILE_NAME)
58+
59+
pretrained_path = str(
60+
cached_download(
61+
url=download_url,
62+
library_name=LIBRARY_NAME,
63+
library_version=VERSION,
64+
cache_dir=CACHE_DIRECTORY,
65+
)
66+
)
4967
with open(pretrained_path, "r") as f:
5068
config = json.load(f)
5169

tensorflow_tts/processor/baker.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from pypinyin.converter import DefaultConverter
2828
from pypinyin.core import Pinyin
2929
from tensorflow_tts.processor import BaseProcessor
30+
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
3031

3132
_pad = ["pad"]
3233
_eos = ["eos"]
@@ -552,6 +553,13 @@ def __post_init__(self):
552553
def setup_eos_token(self):
553554
return _eos[0]
554555

556+
def save_pretrained(self, saved_path):
557+
os.makedirs(saved_path, exist_ok=True)
558+
self._save_mapper(
559+
os.path.join(saved_path, PROCESSOR_FILE_NAME),
560+
{"pinyin_dict": self.pinyin_dict},
561+
)
562+
555563
def create_items(self):
556564
items = []
557565
if self.data_dir:

tensorflow_tts/processor/base_processor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,8 @@ def _save_mapper(self, saved_path: str = None, extra_attrs_to_save: dict = None)
224224
if extra_attrs_to_save:
225225
full_mapper = {**full_mapper, **extra_attrs_to_save}
226226
json.dump(full_mapper, f)
227+
228+
@abc.abstractmethod
229+
def save_pretrained(self, saved_path):
230+
"""Save mappers to file"""
231+
pass

tensorflow_tts/processor/kss.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tensorflow_tts.processor import BaseProcessor
2424
from tensorflow_tts.utils import cleaners
2525
from tensorflow_tts.utils.korean import symbols as KSS_SYMBOLS
26+
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
2627

2728
# Regular expression matching text enclosed in curly braces:
2829
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
@@ -57,6 +58,10 @@ def split_line(self, data_dir, line, split):
5758
def setup_eos_token(self):
5859
return "eos"
5960

61+
def save_pretrained(self, saved_path):
62+
os.makedirs(saved_path, exist_ok=True)
63+
self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})
64+
6065
def get_one_sample(self, item):
6166
text, wav_path, speaker_name = item
6267

tensorflow_tts/processor/libritts.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from g2p_en import g2p as grapheme_to_phonem
2525

2626
from tensorflow_tts.processor.base_processor import BaseProcessor
27+
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
2728

2829
g2p = grapheme_to_phonem.G2p()
2930

@@ -84,7 +85,11 @@ def get_one_sample(self, item):
8485
return sample
8586

8687
def setup_eos_token(self):
87-
return None # because we do not use this
88+
return None # because we do not use this
89+
90+
def save_pretrained(self, saved_path):
91+
os.makedirs(saved_path, exist_ok=True)
92+
self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})
8893

8994
def text_to_sequence(self, text):
9095
if (

tensorflow_tts/processor/ljspeech.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from dataclasses import dataclass
2323
from tensorflow_tts.processor import BaseProcessor
2424
from tensorflow_tts.utils import cleaners
25+
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
2526

2627
valid_symbols = [
2728
"AA",
@@ -158,6 +159,10 @@ def split_line(self, data_dir, line, split):
158159
def setup_eos_token(self):
159160
return _eos
160161

162+
def save_pretrained(self, saved_path):
163+
os.makedirs(saved_path, exist_ok=True)
164+
self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})
165+
161166
def get_one_sample(self, item):
162167
text, wav_path, speaker_name = item
163168

tensorflow_tts/processor/thorsten.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from dataclasses import dataclass
2323
from tensorflow_tts.processor import BaseProcessor
2424
from tensorflow_tts.utils import cleaners
25+
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
2526

2627
_pad = "pad"
2728
_eos = "eos"
@@ -67,6 +68,10 @@ def split_line(self, data_dir, line, split):
6768
def setup_eos_token(self):
6869
return _eos
6970

71+
def save_pretrained(self, saved_path):
72+
os.makedirs(saved_path, exist_ok=True)
73+
self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})
74+
7075
def get_one_sample(self, item):
7176
text, wav_path, speaker_name = item
7277

0 commit comments

Comments
 (0)