Dev/shining #75
Merged
Changes from all commits
48 commits
ac6765e
add text matching compression & delete regression
5cba830
modify the Tutorial.md
7876d64
modify the Tutorial.md
2e7982c
add teacher model name
3be3052
modify tutorial.md
3a3ff7c
Merge branch 'master' into dev/quanjia
fcb20ac
fix transform params2tensors problem
efb2270
Merge branch 'master' into dev/quanjia
3778da7
add softmax output layer for slot tagging
1da9927
add slot_tagging metrics
b42c747
Merge branch 'master' into dev/quanjia
44bfb43
Merge branch 'master' into dev/quanjia
53e7233
modify make word emb matrix
309921b
Delete dev.tsv
adolphk-yk cfad91e
Delete test.tsv
adolphk-yk de9b3b3
Delete train.tsv
adolphk-yk b8a34a1
delete conll data
385ec7a
Merge branch 'dev/quanjia' of https://github.com/Microsoft/NeuronBloc…
cf57b98
Update Contributing.md
boshining b18fe12
Update tools
ShiningBo 0957603
Merge branch 'dev/shining' of github.com:microsoft/NeuronBlocks into …
ShiningBo 3d9b3a0
Update README.md
boshining 2897020
Update Contributing.md
boshining 4fdcf3e
Update README.md
boshining 861e6bd
Update autotest.sh
ShiningBo 4d3c70f
update get_results.py
ShiningBo 5c5f841
fix sequence tagging workflow
f7e122f
Merge branch 'master' into dev/quanjia
f1daf76
add model type judgement for optimizer
a0fd463
delete full atis dataset and unused config file
a085554
add slot_tagging sample data
5579545
fix load embedding slow problem
9f98dd3
fix Char embedding CNN problem
a4e2644
Merge branch 'dev/quanjia' into dev/shining
ShiningBo 9acd52d
add lower token when load embedding matrix
d22ed5b
Merge branch 'master' into dev/quanjia
31bca9e
add word level length for char emb
3143519
Update Conv
ShiningBo b430e9f
Merge branch 'dev/quanjia' into dev/shining
ShiningBo 92cd783
merge quanjia
ShiningBo 992070b
Merge branch 'master' into dev/shining
b715945
Add ARCI & ARCII module and Modify Conv block
10c8d8f
Merge branch 'master' into dev/shining
8ce3e7e
Update to the same as master
30ce22e
update Linear layer
6380f32
Add block - Calculate Distance of Two Vectors
c8e8d47
Merge branch 'master' into dev/shining
883785c
update tutorial_zh_CN
New file (@@ -0,0 +1,104 @@): the `Pooling1D` block.

```python
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from block_zoo.BaseLayer import BaseLayer, BaseConf
from utils.DocInherit import DocInherit


class Pooling1DConf(BaseConf):
    """ Configuration of Pooling1D layer

    Args:
        pool_type (str): 'max' or 'mean', default is 'max'.
        stride (int): stride of the pooling window, default is 1.
        padding (int): implicit zero padding on both sides of the input. Default: 0
        window_size (int): the size of the pooling window, default is 3.
    """

    def __init__(self, **kwargs):
        super(Pooling1DConf, self).__init__(**kwargs)

    @DocInherit
    def default(self):
        self.pool_type = 'max'  # Supported: ['max', 'mean']
        self.stride = 1
        self.padding = 0
        self.window_size = 3

    @DocInherit
    def declare(self):
        self.num_of_inputs = 1
        self.input_ranks = [3]

    @DocInherit
    def inference(self):
        self.output_dim = [self.input_dims[0][0]]
        if self.input_dims[0][1] != -1:
            self.output_dim.append(
                (self.input_dims[0][1] + 2 * self.padding - self.window_size) // self.stride + 1)
        else:
            self.output_dim.append(-1)
        self.output_dim.append(self.input_dims[0][-1])

        # DON'T MODIFY THIS
        self.output_rank = len(self.output_dim)

    @DocInherit
    def verify(self):
        super(Pooling1DConf, self).verify()

        necessary_attrs_for_user = ['pool_type']
        for attr in necessary_attrs_for_user:
            self.add_attr_exist_assertion_for_user(attr)

        self.add_attr_value_assertion('pool_type', ['max', 'mean'])

        assert self.output_dim[-1] != -1, \
            "The shape of input is %s, and the input channel number of pooling should not be -1." % (
                str(self.input_dims[0]))


class Pooling1D(BaseLayer):
    """ Pooling layer

    Args:
        layer_conf (Pooling1DConf): configuration of a layer
    """

    def __init__(self, layer_conf):
        super(Pooling1D, self).__init__(layer_conf)
        self.pool = None
        if layer_conf.pool_type == "max":
            self.pool = nn.MaxPool1d(kernel_size=layer_conf.window_size, stride=layer_conf.stride,
                                     padding=layer_conf.padding)
        elif layer_conf.pool_type == "mean":
            self.pool = nn.AvgPool1d(kernel_size=layer_conf.window_size, stride=layer_conf.stride,
                                     padding=layer_conf.padding)

    def forward(self, string, string_len=None):
        """ process inputs

        Args:
            string (Tensor): tensor with shape: [batch_size, length, feature_dim]
            string_len (Tensor): [batch_size], default is None.

        Returns:
            Tensor: pooling result of string

        """
        # nn.MaxPool1d/nn.AvgPool1d pool over the last axis, so move feature_dim
        # to the channel axis, pool over length, then move it back.
        string = string.permute([0, 2, 1]).contiguous()
        string = self.pool(string)
        string = string.permute([0, 2, 1]).contiguous()
        return string, string_len
```
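As a sanity check of the output-length formula in `Pooling1DConf.inference()`, here is a minimal standalone sketch in plain PyTorch; the tensor sizes are invented for illustration and this is not code from the PR:

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
batch_size, length, feature_dim = 4, 10, 8
x = torch.randn(batch_size, length, feature_dim)

# Mirror Pooling1D.forward(): pool over the length axis via permute.
window_size, stride, padding = 3, 1, 0
pool = nn.MaxPool1d(kernel_size=window_size, stride=stride, padding=padding)
y = pool(x.permute(0, 2, 1)).permute(0, 2, 1)

# Output length follows the formula in inference():
# (length + 2*padding - window_size) // stride + 1 = (10 + 0 - 3) // 1 + 1 = 8
print(y.shape)  # torch.Size([4, 8, 8])
```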
New file (@@ -0,0 +1,97 @@): the `CalculateDistance` block.

```python
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F
import logging

from ..BaseLayer import BaseConf, BaseLayer
from utils.DocInherit import DocInherit
from utils.exceptions import ConfigurationError
import copy


class CalculateDistanceConf(BaseConf):
    """ Configuration of CalculateDistance layer

    Args:
        operations (list): a subset of ["cos", "euclidean", "manhattan", "chebyshev"].
    """

    # init the args
    def __init__(self, **kwargs):
        super(CalculateDistanceConf, self).__init__(**kwargs)

    # set default params
    @DocInherit
    def default(self):
        self.operations = ["cos", "euclidean", "manhattan", "chebyshev"]

    @DocInherit
    def declare(self):
        self.num_of_inputs = 2
        self.input_ranks = [2]

    @DocInherit
    def inference(self):
        self.output_dim = copy.deepcopy(self.input_dims[0])
        self.output_dim[-1] = 1

        super(CalculateDistanceConf, self).inference()

    @DocInherit
    def verify(self):
        super(CalculateDistanceConf, self).verify()

        assert len(self.input_dims) == 2, "Operation requires that there should be two inputs"

        # check that the ranks of all the inputs are equal (and equal to 2)
        rank_equal_flag = True
        for i in range(len(self.input_ranks)):
            if self.input_ranks[i] != self.input_ranks[0] or self.input_ranks[i] != 2:
                rank_equal_flag = False
                break
        if not rank_equal_flag:
            raise ConfigurationError("For layer CalculateDistance, the rank of each input should be 2!")


class CalculateDistance(BaseLayer):
    """ CalculateDistance layer to calculate the distance of sequences (2D representation)

    Args:
        layer_conf (CalculateDistanceConf): configuration of a layer
    """

    def __init__(self, layer_conf):
        super(CalculateDistance, self).__init__(layer_conf)
        self.layer_conf = layer_conf

    def forward(self, x, x_len, y, y_len):
        """ process inputs

        Args:
            x: [batch_size, dim]
            x_len: [batch_size]
            y: [batch_size, dim]
            y_len: [batch_size]

        Returns:
            Tensor: [batch_size, 1], None

        """
        batch_size = x.size()[0]
        # Only the first operation found in self.layer_conf.operations is applied.
        if "cos" in self.layer_conf.operations:
            result = F.cosine_similarity(x, y)
        elif "euclidean" in self.layer_conf.operations:
            result = torch.sqrt(torch.sum((x - y) ** 2, dim=1))
        elif "manhattan" in self.layer_conf.operations:
            result = torch.sum(torch.abs(x - y), dim=1)
        elif "chebyshev" in self.layer_conf.operations:
            result = torch.abs(x - y).max(dim=1)[0]  # max(dim=...) returns (values, indices)
        else:
            raise ConfigurationError("This operation is not supported!")

        result = result.view(batch_size, 1)
        return result, None
```
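For reference, the four supported distances in a minimal plain-PyTorch sketch (random inputs, shapes assumed; not code from the PR). Note the block applies only the first operation found in `operations`, so listing all four still yields a single [batch_size, 1] output:

```python
import torch
import torch.nn.functional as F

# Hypothetical inputs: a batch of 4 vector pairs with dim 16.
x = torch.randn(4, 16)
y = torch.randn(4, 16)

cos = F.cosine_similarity(x, y)                     # [4]
euclidean = torch.sqrt(torch.sum((x - y) ** 2, 1))  # [4]
manhattan = torch.sum(torch.abs(x - y), dim=1)      # [4]
chebyshev = torch.abs(x - y).max(dim=1)[0]          # [4]; [0] picks values, not indices

# Each result is reshaped to [batch_size, 1] before the block returns it.
print(cos.view(4, 1).shape)  # torch.Size([4, 1])
```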
New file (@@ -0,0 +1,76 @@): the `Expand_plus` block.

```python
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

# Comes from http://www.hangli-hl.com/uploads/3/1/6/8/3168008/hu-etal-nips2014.pdf [ARC-II]

import torch
import torch.nn as nn
import copy

from block_zoo.BaseLayer import BaseLayer, BaseConf
from utils.DocInherit import DocInherit
from utils.exceptions import ConfigurationError


class Expand_plusConf(BaseConf):
    """ Configuration of Expand_plus layer

    """
    def __init__(self, **kwargs):
        super(Expand_plusConf, self).__init__(**kwargs)

    @DocInherit
    def default(self):
        self.operation = 'Plus'

    @DocInherit
    def declare(self):
        self.num_of_inputs = 2
        self.input_ranks = [3, 3]

    @DocInherit
    def inference(self):
        self.output_dim = copy.deepcopy(self.input_dims[0])
        if self.input_dims[0][1] == -1 or self.input_dims[1][1] == -1:
            raise ConfigurationError("For Expand_plus layer, the sequence length should be fixed")
        self.output_dim.insert(2, self.input_dims[1][1])  # y_len
        super(Expand_plusConf, self).inference()  # PUT THIS LINE AT THE END OF inference()

    @DocInherit
    def verify(self):
        super(Expand_plusConf, self).verify()


class Expand_plus(BaseLayer):
    """ Expand_plus layer
    Given sequences X and Y, expand both to [batch_size, x_max_len, y_max_len, dim] and then add them.

    Args:
        layer_conf (Expand_plusConf): configuration of a layer

    """
    def __init__(self, layer_conf):
        super(Expand_plus, self).__init__(layer_conf)
        # x and y must share the same feature dimension
        assert layer_conf.input_dims[0][-1] == layer_conf.input_dims[1][-1]

    def forward(self, x, x_len, y, y_len):
        """ process inputs

        Args:
            x: [batch_size, x_max_len, dim].
            x_len: [batch_size], default is None.
            y: [batch_size, y_max_len, dim].
            y_len: [batch_size], default is None.

        Returns:
            output: [batch_size, x_max_len, y_max_len, dim].

        """
        x_new = torch.stack([x] * y.size()[1], 2)  # [batch_size, x_max_len, y_max_len, dim]
        y_new = torch.stack([y] * x.size()[1], 1)  # [batch_size, x_max_len, y_max_len, dim]

        return x_new + y_new, None
```
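For intuition about what `Expand_plus` computes, a small standalone sketch (shapes invented for illustration; not code from the PR). Broadcasting gives the same result without materializing the stacked copies:

```python
import torch

# Hypothetical shapes: batch 2, x length 3, y length 5, dim 4.
x = torch.randn(2, 3, 4)
y = torch.randn(2, 5, 4)

# Mirror Expand_plus.forward(): tile both inputs to a common 4D grid, then add.
out = torch.stack([x] * y.size(1), 2) + torch.stack([y] * x.size(1), 1)
print(out.shape)  # torch.Size([2, 3, 5, 4])

# Equivalent via broadcasting, avoiding the intermediate copies:
out2 = x.unsqueeze(2) + y.unsqueeze(1)
assert torch.equal(out, out2)
```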
Review comment: What about this?