diff --git a/README.md b/README.md
index 95486aa..f097d5a 100644
--- a/README.md
+++ b/README.md
@@ -27,24 +27,25 @@ Please consider citing the [paper](https://arxiv.org/abs/1708.00524) of DeepMoji
 ## Installation
 
-We assume that you're using [Python 2.7-3.5](https://www.python.org/downloads/) with [pip](https://pip.pypa.io/en/stable/installing/) installed.
-
-First you need to install [pyTorch (version 0.2+)](http://pytorch.org/), currently by:
-```bash
-conda install pytorch -c pytorch
-```
-At the present stage the model can't make efficient use of CUDA. See details in the [Hugging Face blog post](https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983).
-
-When pyTorch is installed, run the following in the root directory to install the remaining dependencies:
+Assuming you have [Conda](https://conda.io) installed, run:
 ```bash
+conda env create -f environment.yml
+conda activate torchMoji
 pip install -e .
 ```
+
 This will install the following dependencies:
+
+* [PyTorch](https://pytorch.org)
 * [scikit-learn](https://github.com/scikit-learn/scikit-learn)
 * [text-unidecode](https://github.com/kmike/text-unidecode)
 * [emoji](https://github.com/carpedm20/emoji)
 
+If you do not want to use Conda, please install `torch==1.3.1` with pip and then run `pip install -e .` from the root directory (don't forget to set up a virtual environment).
+
+At the present stage the model can't make efficient use of CUDA. See details in the [Hugging Face blog post](https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983).
+
 Then, run the download script to download the pretrained torchMoji weights (~85MB) from [here](https://www.dropbox.com/s/q8lax9ary32c7t9/pytorch_model.bin?dl=0) and put them in the model/ directory:
 
 ```bash
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..74e8ef4
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,41 @@
+name: torchMoji
+channels:
+  - pytorch
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1
+  - blas=1.0
+  - ca-certificates=2019.11.27
+  - certifi=2019.11.28
+  - cffi=1.13.2
+  - cudatoolkit=10.1.243
+  - intel-openmp=2019.4
+  - libedit=3.1.20181209
+  - libffi=3.2.1
+  - libgcc-ng=9.1.0
+  - libgfortran-ng=7.3.0
+  - libstdcxx-ng=9.1.0
+  - mkl=2018.0.3
+  - ncurses=6.1
+  - ninja=1.9.0
+  - nose=1.3.7
+  - numpy=1.13.1
+  - openssl=1.1.1d
+  - pip=19.3.1
+  - pycparser=2.19
+  - python=3.6.9
+  - pytorch=1.3.1
+  - readline=7.0
+  - scikit-learn=0.19.0
+  - scipy=0.19.1
+  - setuptools=42.0.2
+  - sqlite=3.30.1
+  - text-unidecode=1.0
+  - tk=8.6.8
+  - wheel=0.33.6
+  - xz=5.2.4
+  - zlib=1.2.11
+  - pip:
+    - emoji==0.4.5
+prefix: /home/cbowdon/miniconda3/envs/torchMoji
+
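As a quick sanity check for the new installation path, something along these lines should work (a minimal sketch assuming the `torchMoji` environment is active, `pip install -e .` has been run, and the pretrained weights have already been downloaded into `model/`; `PRETRAINED_PATH` and `VOCAB_PATH` are the constants defined in `torchmoji/global_variables.py`):

```python
# Minimal smoke test for the pinned environment (assumes the weights were already downloaded to model/).
import torch
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH

print(torch.__version__)            # expected: 1.3.1, as pinned in environment.yml
print(torch.cuda.is_available())    # CUDA is optional; the model cannot use it efficiently anyway
print(PRETRAINED_PATH, VOCAB_PATH)  # paths used by the download script and the tokenizer
```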
diff --git a/torchmoji/finetuning.py b/torchmoji/finetuning.py
index 513dd14..f8617b3 100644
--- a/torchmoji/finetuning.py
+++ b/torchmoji/finetuning.py
@@ -18,7 +18,7 @@ from torch.autograd import Variable
 from torch.utils.data import Dataset, DataLoader
 from torch.utils.data.sampler import BatchSampler, SequentialSampler
 
-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_
 
 from sklearn.metrics import f1_score
 
@@ -521,7 +521,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
         torch.save(model.state_dict(), checkpoint_path)
 
     model.eval()
-    best_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
+    best_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy() for xv, yv in val_gen])
     print("original val loss", best_loss)
 
     epoch_without_impr = 0
@@ -535,17 +535,17 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
             output = model(X_train)
             loss = calc_loss(loss_op, output, y_train)
             loss.backward()
-            clip_grad_norm(model.parameters(), 1)
+            clip_grad_norm_(model.parameters(), 1)
             optim_op.step()
 
             acc = evaluate_using_acc(model, [(X_train.data, y_train.data)])
-            print("== Epoch", epoch, "step", i, "train loss", loss.data.cpu().numpy()[0], "train acc", acc)
+            print("== Epoch", epoch, "step", i, "train loss", loss.data.cpu().numpy(), "train acc", acc)
 
         model.eval()
         acc = evaluate_using_acc(model, val_gen)
         print("val acc", acc)
-        val_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
+        val_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy() for xv, yv in val_gen])
         print("val loss", val_loss)
 
         if best_loss is not None and val_loss >= best_loss:
             epoch_without_impr += 1
diff --git a/torchmoji/lstm.py b/torchmoji/lstm.py
index 67ed0e1..0ea7098 100644
--- a/torchmoji/lstm.py
+++ b/torchmoji/lstm.py
@@ -75,7 +75,8 @@ def reset_parameters(self):
     def forward(self, input, hx=None):
         is_packed = isinstance(input, PackedSequence)
         if is_packed:
-            input, batch_sizes = input
+            batch_sizes = input.batch_sizes
+            input = input.data
             max_batch_size = batch_sizes[0]
         else:
             batch_sizes = None
@@ -337,11 +338,11 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
 
     ingate = hard_sigmoid(ingate)
     forgetgate = hard_sigmoid(forgetgate)
-    cellgate = F.tanh(cellgate)
+    cellgate = torch.tanh(cellgate)
     outgate = hard_sigmoid(outgate)
 
     cy = (forgetgate * cx) + (ingate * cellgate)
-    hy = outgate * F.tanh(cy)
+    hy = outgate * torch.tanh(cy)
 
     return hy, cy
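The finetuning.py and lstm.py hunks above all track PyTorch API changes between 0.2 and 1.3: gradient clipping is now the in-place `clip_grad_norm_`, a scalar loss comes back from `.numpy()` as a 0-d array (so the old `[0]` indexing fails), `PackedSequence` carries extra fields and is read through its attributes instead of tuple unpacking, and `F.tanh` is deprecated in favour of `torch.tanh`. A small stand-alone illustration against stock PyTorch 1.3 (a sketch, not code from this repo):

```python
import torch
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.rnn import pack_padded_sequence

# Scalar tensors: .numpy() returns a 0-d array, so the old [0] indexing raises an IndexError.
loss = torch.tensor(0.25)
print(loss.data.cpu().numpy())    # array(0.25)
print(loss.item())                # plain Python float, the usual way to log a scalar

# Gradient clipping now uses the trailing-underscore (in-place) name.
layer = torch.nn.Linear(4, 2)
layer(torch.randn(3, 4)).sum().backward()
clip_grad_norm_(layer.parameters(), max_norm=1.0)

# PackedSequence has more than two fields, so it is read through attributes rather than unpacked.
packed = pack_padded_sequence(torch.randn(3, 2, 4), lengths=[2, 2, 1], batch_first=True)
print(packed.data.shape, packed.batch_sizes)   # flattened timesteps and per-step batch sizes
```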
diff --git a/torchmoji/model_def.py b/torchmoji/model_def.py
index 9290371..410300b 100644
--- a/torchmoji/model_def.py
+++ b/torchmoji/model_def.py
@@ -144,7 +144,7 @@ def __init__(self, nb_classes, nb_tokens, feature_output=False, output_logits=Fa
             self.add_module('output_layer', nn.Sequential(nn.Linear(attention_size, nb_classes if self.nb_classes > 2 else 1)))
         else:
             self.add_module('output_layer', nn.Sequential(nn.Linear(attention_size, nb_classes if self.nb_classes > 2 else 1),
-                                                          nn.Softmax() if self.nb_classes > 2 else nn.Sigmoid()))
+                                                          nn.Softmax(dim=1) if self.nb_classes > 2 else nn.Sigmoid()))
         self.init_weights()
         # Put model in evaluation mode by default
         self.eval()
@@ -156,15 +156,15 @@ def init_weights(self):
         ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
         hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
         b = (param.data for name, param in self.named_parameters() if 'bias' in name)
-        nn.init.uniform(self.embed.weight.data, a=-0.5, b=0.5)
+        nn.init.uniform_(self.embed.weight.data, a=-0.5, b=0.5)
         for t in ih:
-            nn.init.xavier_uniform(t)
+            nn.init.xavier_uniform_(t)
         for t in hh:
-            nn.init.orthogonal(t)
+            nn.init.orthogonal_(t)
         for t in b:
-            nn.init.constant(t, 0)
+            nn.init.constant_(t, 0)
         if not self.feature_output:
-            nn.init.xavier_uniform(self.output_layer[0].weight.data)
+            nn.init.xavier_uniform_(self.output_layer[0].weight.data)
 
     def forward(self, input_seqs):
         """ Forward pass.
@@ -177,10 +177,8 @@ def forward(self, input_seqs):
         """
         # Check if we have Torch.LongTensor inputs or not Torch.Variable (assume Numpy array in this case), take note to return same format
         return_numpy = False
-        return_tensor = False
         if isinstance(input_seqs, (torch.LongTensor, torch.cuda.LongTensor)):
             input_seqs = Variable(input_seqs)
-            return_tensor = True
         elif not isinstance(input_seqs, Variable):
             input_seqs = Variable(torch.from_numpy(input_seqs.astype('int64')).long())
             return_numpy = True
@@ -246,8 +244,6 @@ def forward(self, input_seqs):
         outputs = reorered
 
         # Adapt return format if needed
-        if return_tensor:
-            outputs = outputs.data
         if return_numpy:
             outputs = outputs.data.numpy()
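With `return_tensor` gone, a `torch.LongTensor` input is simply returned as a tensor (in PyTorch 1.x, `Variable` and `Tensor` are the same thing), while numpy input still comes back as a numpy array, so existing callers keep working. Typical usage stays what the bundled example script does; a sketch, assuming the vocabulary and pretrained weights from the download step are in place and using the repo's `torchmoji_emojis` and `SentenceTokenizer` helpers:

```python
import json
import numpy as np
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH
from torchmoji.model_def import torchmoji_emojis
from torchmoji.sentence_tokenizer import SentenceTokenizer

with open(VOCAB_PATH, 'r') as f:
    vocabulary = json.load(f)

# Tokenize a sentence into a padded numpy array of token ids.
tokenizer = SentenceTokenizer(vocabulary, 30)
tokenized, _, _ = tokenizer.tokenize_sentences(["I love mom's cooking"])

model = torchmoji_emojis(PRETRAINED_PATH)
prob = model(tokenized)                       # numpy in, numpy out (softmax over the 64 emoji classes)
print(prob.shape, np.argsort(prob[0])[-5:])   # output shape and the top-5 emoji indices
```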