This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Conversation

@parmeet
Contributor

@parmeet parmeet commented Dec 4, 2021

Summary:

  • Updates the annotation types of transforms to Any
  • Adds the Truncate transform

Problem when transform inputs and outputs are type-annotated

The transforms and functionals support multiple styles of input (both batched and non-batched) and sometimes multiple types (see the Truncate transform added in this PR, which works with both int and str primitive types). Composability of these transforms and support for torch scriptability do not always go hand in hand. The code snippets below show a couple of these issues.

import torch
from typing import List, Union

# primitive transform
def foo(input: Union[int, List[int]]) -> Union[int, List[int]]:
    if torch.jit.isinstance(input, List[int]):
        return input
    elif torch.jit.isinstance(input, int):
        return input
    else:
        raise TypeError

# primitive transform
def bar(input: List[int]) -> List[int]:
    if torch.jit.isinstance(input, List[int]):
        return input
    else:
        raise TypeError

# composite transform
def goo(input: List[int]):
    x = bar(foo(input))
    return x

goo_jit = torch.jit.script(goo)

# gives the following error:
RuntimeError: 

bar(int[] input) -> (int[]):
Expected a value of type 'int' for argument '<varargs>' but instead found type 'Union[List[int], int]'.
:
 File "<ipython-input-17-9a79ce8ea4b4>", line 13
def goo(input: List[int]):
   x = bar(foo(input))
       ~~~ <--- HERE
   return x


def foofoo(input: List[int], use_foo: bool = True):
    output = bar(input)
    if use_foo:
        output = foo(output)

    return output

foofoo_jit = torch.jit.script(foofoo)

# gives the following error:
RuntimeError: 
Variable 'output' previously had type List[int] but is now being assigned to a value of type Union[List[int], int]
:
 File "<ipython-input-18-10b17e1c770a>", line 4
   output = bar(input)
   if use_foo:
       output = foo(output)
       ~~~~~~ <--- HERE

   return output

Solution

By changing the input and output types to Any and using type refinement (via torch.jit.isinstance) in primitive transforms, the above problems can be alleviated. As a nice side benefit, this forces the developer to always check input types when implementing primitive transforms.

import torch
from typing import Any, List

# primitive transform
def foo(input: Any) -> Any:
    if torch.jit.isinstance(input, List[int]):
        return input
    elif torch.jit.isinstance(input, int):
        return input
    else:
        raise TypeError

# primitive transform
def bar(input: Any) -> Any:
    if torch.jit.isinstance(input, List[int]):
        return input
    else:
        raise TypeError

# composite transform
def goo(input: List[int]):
    x = bar(foo(input))
    return x

goo_jit = torch.jit.script(goo)

# composite transform
def foofoo(input: List[int], use_foo: bool = True):
    output = bar(input)
    if use_foo:
        output = foo(output)

    return output

foofoo_jit = torch.jit.script(foofoo)
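As a quick sanity check that this pattern actually scripts and behaves like the eager versions, here is a self-contained sketch of the same functions (error messages added, since bare raises are our own simplification):

```python
import torch
from typing import Any, List

def foo(input: Any) -> Any:
    # primitive transform: refine Any before using the value
    if torch.jit.isinstance(input, List[int]):
        return input
    elif torch.jit.isinstance(input, int):
        return input
    else:
        raise TypeError("unsupported input type")

def bar(input: Any) -> Any:
    if torch.jit.isinstance(input, List[int]):
        return input
    else:
        raise TypeError("unsupported input type")

def goo(input: List[int]):
    return bar(foo(input))

def foofoo(input: List[int], use_foo: bool = True):
    output = bar(input)
    if use_foo:
        output = foo(output)
    return output

# both composites now compile; no Union mismatch errors
goo_jit = torch.jit.script(goo)
foofoo_jit = torch.jit.script(foofoo)
assert goo_jit([1, 2, 3]) == [1, 2, 3]
assert foofoo_jit([1, 2, 3]) == [1, 2, 3]
```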

Other concerns

  • Since we remove the type annotations, we will add docstrings to ensure that users are aware of the supported input types.
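For illustration, a docstring in that spirit might look like the following. This is a hypothetical sketch, not the PR's actual Truncate implementation; the function name and parameters are assumed:

```python
import torch
from typing import Any, List

def truncate(input: Any, max_seq_len: int) -> Any:
    """Truncate an input sequence or batch of sequences.

    :param input: the sequence(s) to truncate (hypothetical example)
    :type input: Union[List[int], List[List[int]]]
    :param max_seq_len: maximum length beyond which input is discarded
    :type max_seq_len: int
    :return: the truncated input
    :rtype: Union[List[int], List[List[int]]]
    """
    if torch.jit.isinstance(input, List[int]):
        return input[:max_seq_len]
    elif torch.jit.isinstance(input, List[List[int]]):
        return [seq[:max_seq_len] for seq in input]
    else:
        raise TypeError("unsupported input type")
```

Even though the signature says Any, the docstring keeps the accepted types discoverable for users.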

@parmeet parmeet changed the title Add truncate transform [WIP] Add truncate transform Dec 6, 2021
@codecov

codecov bot commented Dec 7, 2021

Codecov Report

Merging #1453 (79551db) into main (9f2fb3f) will decrease coverage by 0.67%.
The diff coverage is 60.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1453      +/-   ##
==========================================
- Coverage   86.35%   85.67%   -0.68%     
==========================================
  Files          58       58              
  Lines        2220     2262      +42     
==========================================
+ Hits         1917     1938      +21     
- Misses        303      324      +21     
Impacted Files Coverage Δ
torchtext/functional.py 70.17% <53.33%> (-20.45%) ⬇️
torchtext/models/roberta/transforms.py 79.16% <57.14%> (-18.14%) ⬇️
torchtext/transforms.py 95.71% <78.57%> (-4.29%) ⬇️
torchtext/vocab/vectors.py 89.69% <0.00%> (+3.03%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@parmeet parmeet changed the title [WIP] Add truncate transform Update annotation types of transforms and add truncate transform Dec 7, 2021
@parmeet parmeet requested a review from mthrok December 7, 2021 17:59
Contributor

@mthrok mthrok left a comment


Looks good. I learned that TS compiler can handle Any. Thanks.

.. automethod:: forward

Truncate
------------
Contributor


Suggested change
------------
--------

ReST expects the length of the underline to match the title.

self.max_seq_len = max_seq_len

self.token_transform = transforms.SentencePieceTokenizer(spm_model_path)
self.tokenizer = transforms.SentencePieceTokenizer(spm_model_path)
Contributor


Technically speaking, I think this is a BC-breaking change. The state_dict dumped from the previous version will stop working. Even if the BC-breaking is not an issue, as a good practice I recommend splitting the name change into a different PR, considering possibilities such as cherry-picking at a minor release.

Contributor Author


Thanks for raising this. I will change it back to the original and do the naming change in a separate PR.

return self._label_names


class Truncate(Module):
Contributor


The implementation LGTM.

However, since this new transform is not used by the other changes, I think the addition of Truncate can be in a separate, self-contained PR. Splitting it will give a better UX when users read the release notes and look into the commits.

Contributor Author


Again, good suggestion. Actually, this whole idea of changing to Any only became apparent after I started adding this transform, and then I just kept all the changes :). I will open another PR to add Truncate and remove it here.

@parmeet
Contributor Author

parmeet commented Dec 7, 2021

Looks good. I learned that TS compiler can handle Any. Thanks.

Apparently yes. I was also pleasantly surprised when I looked closely at their example https://pytorch.org/docs/stable/generated/torch.jit.isinstance.html :)
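The pattern from that docs page, roughly: torch.jit.isinstance refines an Any value into a concrete container type inside the branch where the check succeeded. A minimal sketch (my own example, not the PR's code):

```python
import torch
from typing import Any, Dict, List

def describe(input: Any) -> str:
    # within each branch, TorchScript treats `input` as the refined type,
    # so container-specific operations like len() are allowed
    if torch.jit.isinstance(input, List[str]):
        return "list of str, length " + str(len(input))
    elif torch.jit.isinstance(input, Dict[str, int]):
        return "dict of str to int"
    return "unrefined"

describe_jit = torch.jit.script(describe)
```

Both the eager and scripted versions dispatch on the runtime container type, which is exactly what makes the Any-annotated primitive transforms composable.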

@parmeet
Contributor Author

parmeet commented Dec 7, 2021

Thanks @mthrok for the quick review and valuable feedback. I am going to break this PR into a few. More importantly, I will keep this PR limited to updating the annotation types of existing transforms and adding docstrings. I will update the summary accordingly.

@parmeet parmeet changed the title Update annotation types of transforms and add truncate transform Update annotation types of transforms and functionals Dec 8, 2021
@parmeet parmeet merged commit 3736f13 into pytorch:main Dec 8, 2021
@parmeet parmeet deleted the truncate_transform branch December 8, 2021 20:33
