diff --git a/Tutorial_zh_CN.md b/Tutorial_zh_CN.md index c07feb8..7d93db9 100644 --- a/Tutorial_zh_CN.md +++ b/Tutorial_zh_CN.md @@ -289,6 +289,8 @@ Question answer matching is a crucial subtask of the question answering problem, CNN (NeuronBlocks) | 0.747 BiLSTM (NeuronBlocks) | 0.767 BiLSTM+Attn (NeuronBlocks) | 0.754 + [ARC-I](https://arxiv.org/abs/1503.03244) (NeuronBlocks) | 0.7508 + [ARC-II](https://arxiv.org/abs/1503.03244) (NeuronBlocks) | 0.7612 [MatchPyramid](https://arxiv.org/abs/1602.06359) (NeuronBlocks) | 0.763 BiLSTM+Match Attention (NeuronBlocks) | 0.786 diff --git a/block_zoo/Conv.py b/block_zoo/Conv.py index 4e4fd54..d97c45f 100644 --- a/block_zoo/Conv.py +++ b/block_zoo/Conv.py @@ -52,7 +52,10 @@ def inference(self): self.output_dim = [-1] if self.input_dims[0][1] != -1: - self.output_dim.append((self.input_dims[0][1] - self.window_size) // self.stride + 1) + if self.padding_type == 'SAME': + self.output_dim.append(self.input_dims[0][1]) + else: + self.output_dim.append((self.input_dims[0][1] - self.window_size) // self.stride + 1) else: self.output_dim.append(-1) self.output_dim.append(self.output_channel_num) diff --git a/block_zoo/Linear.py b/block_zoo/Linear.py index ecdfa5b..6867bd1 100644 --- a/block_zoo/Linear.py +++ b/block_zoo/Linear.py @@ -32,6 +32,7 @@ def default(self): self.activation = 'PReLU' self.last_hidden_activation = True self.last_hidden_softmax = False + self.keep_dim = True # for exmaple if the output shape is [?, len, 1]. you want to squeeze it, set keep_dim=False, the the output shape is [?, len] @DocInherit def declare(self): @@ -42,10 +43,16 @@ def declare(self): def inference(self): if isinstance(self.hidden_dim, int): self.output_dim = copy.deepcopy(self.input_dims[0]) - self.output_dim[-1] = self.hidden_dim + if not self.keep_dim and self.hidden_dim == 1: + self.output_dim.pop() + else: + self.output_dim[-1] = self.hidden_dim elif isinstance(self.hidden_dim, list): self.output_dim = copy.deepcopy(self.input_dims[0]) - self.output_dim[-1] = self.hidden_dim[-1] + if not self.keep_dim and self.hidden_dim[-1] == 1: + self.output_dim.pop() + else: + self.output_dim[-1] = self.hidden_dim[-1] super(LinearConf, self).inference() # PUT THIS LINE AT THE END OF inference() @@ -87,6 +94,7 @@ class Linear(BaseLayer): def __init__(self, layer_conf): super(Linear, self).__init__(layer_conf) + self.layer_conf = layer_conf if layer_conf.input_ranks[0] == 3 and layer_conf.batch_norm is True: layer_conf.batch_norm = False @@ -139,6 +147,8 @@ def forward(self, string, string_len=None): masks = masks.to(device) string = string * masks string_out = self.linear(string.float()) + if not self.layer_conf.keep_dim: + string_out = torch.squeeze(string_out, -1) return string_out, string_len diff --git a/block_zoo/Pooling1D.py b/block_zoo/Pooling1D.py new file mode 100644 index 0000000..236e440 --- /dev/null +++ b/block_zoo/Pooling1D.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +from block_zoo.BaseLayer import BaseLayer, BaseConf +from utils.DocInherit import DocInherit + + +class Pooling1DConf(BaseConf): + """ + + Args: + pool_type (str): 'max' or 'mean', default is 'max'. + stride (int): which axis to conduct pooling, default is 1. + padding (int): implicit zero paddings on both sides of the input. Can be a single number or a tuple (padH, padW). Default: 0 + window_size (int): the size of the pooling + + """ + + def __init__(self, **kwargs): + super(Pooling1DConf, self).__init__(**kwargs) + + @DocInherit + def default(self): + self.pool_type = 'max' # Supported: ['max', mean'] + self.stride = 1 + self.padding = 0 + self.window_size = 3 + + @DocInherit + def declare(self): + self.num_of_inputs = 1 + self.input_ranks = [3] + + + @DocInherit + def inference(self): + + self.output_dim = [self.input_dims[0][0]] + if self.input_dims[0][1] != -1: + self.output_dim.append( + (self.input_dims[0][1] + 2 * self.padding - self.window_size) // self.stride + 1) + else: + self.output_dim.append(-1) + + self.output_dim.append(self.input_dims[0][-1]) + # DON'T MODIFY THIS + self.output_rank = len(self.output_dim) + + @DocInherit + def verify(self): + super(Pooling1DConf, self).verify() + + necessary_attrs_for_user = ['pool_type'] + for attr in necessary_attrs_for_user: + self.add_attr_exist_assertion_for_user(attr) + + self.add_attr_value_assertion('pool_type', ['max', 'mean']) + + assert self.output_dim[ + -1] != -1, "The shape of input is %s , and the input channel number of pooling should not be -1." % ( + str(self.input_dims[0])) + + +class Pooling1D(BaseLayer): + """ Pooling layer + + Args: + layer_conf (PoolingConf): configuration of a layer + """ + + def __init__(self, layer_conf): + super(Pooling1D, self).__init__(layer_conf) + self.pool = None + if layer_conf.pool_type == "max": + self.pool = nn.MaxPool1d(kernel_size=layer_conf.window_size, stride=layer_conf.stride, + padding=layer_conf.padding) + elif layer_conf.pool_type == "mean": + self.pool = nn.AvgPool1d(kernel_size=layer_conf.window_size, stride=layer_conf.stride, + padding=layer_conf.padding) + + def forward(self, string, string_len=None): + """ process inputs + + Args: + string (Tensor): tensor with shape: [batch_size, length, feature_dim] + string_len (Tensor): [batch_size], default is None. + + Returns: + Tensor: Pooling result of string + + """ + + string = string.permute([0, 2, 1]).contiguous() + string = self.pool(string) + string = string.permute([0, 2, 1]).contiguous() + return string, string_len + + diff --git a/block_zoo/Pooling2D.py b/block_zoo/Pooling2D.py index 5c94a8b..46ed2b7 100644 --- a/block_zoo/Pooling2D.py +++ b/block_zoo/Pooling2D.py @@ -19,7 +19,6 @@ class Pooling2DConf(BaseConf): stride (int): which axis to conduct pooling, default is 1. padding (int): implicit zero paddings on both sides of the input. Can be a single number or a tuple (padH, padW). Default: 0 window_size (int): the size of the pooling - activation (string): activation functions, e.g. ReLU """ def __init__(self, **kwargs): diff --git a/block_zoo/__init__.py b/block_zoo/__init__.py index 79b9522..721a3f2 100644 --- a/block_zoo/__init__.py +++ b/block_zoo/__init__.py @@ -16,9 +16,11 @@ from .Dropout import Dropout, DropoutConf from .Conv2D import Conv2D, Conv2DConf +from .Pooling1D import Pooling1D, Pooling1DConf from .Pooling2D import Pooling2D, Pooling2DConf from .embedding import CNNCharEmbedding, CNNCharEmbeddingConf +from .embedding import LSTMCharEmbedding, LSTMCharEmbeddingConf from .CRF import CRFConf, CRF @@ -51,4 +53,7 @@ from .normalizations import LayerNorm, LayerNormConf -from .HighwayLinear import HighwayLinear, HighwayLinearConf \ No newline at end of file +from .HighwayLinear import HighwayLinear, HighwayLinearConf + +from .Gating import Gating, GatingConf +from .HistogramMapping import HistogramMapping, HistogramMappingConf \ No newline at end of file diff --git a/block_zoo/embedding/LSTMCharEmbedding.py b/block_zoo/embedding/LSTMCharEmbedding.py index 23b8126..1fc18a8 100644 --- a/block_zoo/embedding/LSTMCharEmbedding.py +++ b/block_zoo/embedding/LSTMCharEmbedding.py @@ -118,7 +118,7 @@ def forward(self, string): 'input_ranks': [3], 'use_gpu': True } - layer_conf = CNNCharEmbeddingConf(**conf) + layer_conf = LSTMCharEmbeddingConf(**conf) # make a fake input: [bs, seq_len, char num in words] # assume in this batch, the padded sentence length is 3 and the each word has 5 chars, including padding 0. @@ -135,4 +135,3 @@ def forward(self, string): print(output) - diff --git a/block_zoo/op/CalculateDistance.py b/block_zoo/op/CalculateDistance.py new file mode 100644 index 0000000..823d0f9 --- /dev/null +++ b/block_zoo/op/CalculateDistance.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import logging + +from ..BaseLayer import BaseConf, BaseLayer +from utils.DocInherit import DocInherit +from utils.exceptions import ConfigurationError +import copy + + +class CalculateDistanceConf(BaseConf): + """ Configuration of CalculateDistance Layer + + Args: + operations (list): a subset of ["cos", "euclidean", "manhattan", "chebyshev"]. + """ + + # init the args + def __init__(self, **kwargs): + super(CalculateDistanceConf, self).__init__(**kwargs) + + # set default params + @DocInherit + def default(self): + self.operations = ["cos", "euclidean", "manhattan", "chebyshev"] + + @DocInherit + def declare(self): + self.num_of_inputs = 2 + self.input_ranks = [2] + + @DocInherit + def inference(self): + self.output_dim = copy.deepcopy(self.input_dims[0]) + self.output_dim[-1] = 1 + + super(CalculateDistanceConf, self).inference() + + @DocInherit + def verify(self): + super(CalculateDistanceConf, self).verify() + + assert len(self.input_dims) == 2, "Operation requires that there should be two inputs" + + # to check if the ranks of all the inputs are equal + rank_equal_flag = True + for i in range(len(self.input_ranks)): + if self.input_ranks[i] != self.input_ranks[0] or self.input_ranks[i] != 2: + rank_equal_flag = False + break + if rank_equal_flag == False: + raise ConfigurationError("For layer CalculateDistance, the ranks of each inputs should be equal and 2!") + + +class CalculateDistance(BaseLayer): + """ CalculateDistance layer to calculate the distance of sequences(2D representation) + + Args: + layer_conf (CalculateDistanceConf): configuration of a layer + """ + + def __init__(self, layer_conf): + super(CalculateDistance, self).__init__(layer_conf) + self.layer_conf = layer_conf + + + def forward(self, x, x_len, y, y_len): + """ + + Args: + x: [batch_size, dim] + x_len: [batch_size] + y: [batch_size, dim] + y_len: [batch_size] + Returns: + Tensor: [batch_size, 1], None + + """ + + batch_size = x.size()[0] + if "cos" in self.layer_conf.operations: + result = F.cosine_similarity(x , y) + elif "euclidean" in self.layer_conf.operations: + result = torch.sqrt(torch.sum((x-y)**2, dim=1)) + elif "manhattan" in self.layer_conf.operations: + result = torch.sum(torch.abs((x - y)), dim=1) + elif "chebyshev" in self.layer_conf.operations: + result = torch.abs((x - y)).max(dim=1) + else: + raise ConfigurationError("This operation is not supported!") + + result = result.view(batch_size, 1) + return result, None diff --git a/block_zoo/op/Combination.py b/block_zoo/op/Combination.py index eeec7f5..329d262 100644 --- a/block_zoo/op/Combination.py +++ b/block_zoo/op/Combination.py @@ -47,7 +47,6 @@ def inference(self): self.output_dim[-1] += int(np.mean([input_dim[-1] for input_dim in self.input_dims])) # difference operation requires dimension of all the inputs should be equal if "dot_multiply" in self.operations: self.output_dim[-1] += int(np.mean([input_dim[-1] for input_dim in self.input_dims])) # dot_multiply operation requires dimension of all the inputs should be equal - super(CombinationConf, self).inference() @DocInherit diff --git a/block_zoo/op/Expand_plus.py b/block_zoo/op/Expand_plus.py new file mode 100644 index 0000000..17ebb47 --- /dev/null +++ b/block_zoo/op/Expand_plus.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +# Come from http://www.hangli-hl.com/uploads/3/1/6/8/3168008/hu-etal-nips2014.pdf [ARC-II] + +import torch +import torch.nn as nn +import copy + +from block_zoo.BaseLayer import BaseLayer, BaseConf +from utils.DocInherit import DocInherit +from utils.exceptions import ConfigurationError + +class Expand_plusConf(BaseConf): + """Configuration for Expand_plus layer + + """ + def __init__(self, **kwargs): + super(Expand_plusConf, self).__init__(**kwargs) + + @DocInherit + def default(self): + self.operation = 'Plus' + + @DocInherit + def declare(self): + self.num_of_inputs = 2 + self.input_ranks = [3, 3] + + @DocInherit + def inference(self): + self.output_dim = copy.deepcopy(self.input_dims[0]) + if self.input_dims[0][1] == -1 or self.input_dims[1][1] == -1: + raise ConfigurationError("For Expand_plus layer, the sequence length should be fixed") + self.output_dim.insert(2, self.input_dims[1][1]) # y_len + super(Expand_plusConf, self).inference() # PUT THIS LINE AT THE END OF inference() + + @DocInherit + def verify(self): + super(Expand_plusConf, self).verify() + + +class Expand_plus(BaseLayer): + """ Expand_plus layer + Given sequences X and Y, put X and Y expand_dim, and then add. + + Args: + layer_conf (Expand_plusConf): configuration of a layer + + """ + def __init__(self, layer_conf): + + super(Expand_plus, self).__init__(layer_conf) + assert layer_conf.input_dims[0][-1] == layer_conf.input_dims[1][-1] + + + def forward(self, x, x_len, y, y_len): + """ + + Args: + x: [batch_size, x_max_len, dim]. + x_len: [batch_size], default is None. + y: [batch_size, y_max_len, dim]. + y_len: [batch_size], default is None. + + Returns: + output: batch_size, x_max_len, y_max_len, dim]. + + """ + + x_new = torch.stack([x]*y.size()[1], 2) # [batch_size, x_max_len, y_max_len, dim] + y_new = torch.stack([y]*x.size()[1], 1) # [batch_size, x_max_len, y_max_len, dim] + + return x_new + y_new, None + + diff --git a/block_zoo/op/__init__.py b/block_zoo/op/__init__.py index 0be67bb..896cef6 100644 --- a/block_zoo/op/__init__.py +++ b/block_zoo/op/__init__.py @@ -4,4 +4,6 @@ from .Concat3D import Concat3D, Concat3DConf from .Combination import Combination, CombinationConf from .Match import Match, MatchConf -from .Flatten import Flatten, FlattenConf \ No newline at end of file +from .Flatten import Flatten, FlattenConf +from .Expand_plus import Expand_plus, Expand_plusConf +from .CalculateDistance import CalculateDistance, CalculateDistanceConf \ No newline at end of file diff --git a/model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_arci.json b/model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_arci.json new file mode 100644 index 0000000..31854b9 --- /dev/null +++ b/model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_arci.json @@ -0,0 +1,241 @@ +{ + "license": "Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT license.", + "tool_version": "1.1.0", + "model_description": "This model is used for question answer matching task, and it achieved auc: 0.7508 in WikiQACorpus test set", + "language": "English", + "inputs": { + "use_cache": true, + "dataset_type": "classification", + "data_paths": { + "train_data_path": "./dataset/WikiQACorpus/WikiQA-train.tsv", + "valid_data_path": "./dataset/WikiQACorpus/WikiQA-dev.tsv", + "test_data_path": "./dataset/WikiQACorpus/WikiQA-test.tsv", + "pre_trained_emb": "./dataset/Glove/glove.840B.300d.txt" + }, + "file_with_col_header": true, + "add_start_end_for_seq": true, + "file_header": { + "question_id": 0, + "question_text": 1, + "document_id": 2, + "document_title": 3, + "passage_id": 4, + "passage_text": 5, + "label": 6 + }, + "model_inputs": { + "question": ["question_text"], + "passage": ["passage_text"] + }, + "target": ["label"] + }, + "outputs":{ + "save_base_dir": "./models/wikiqa_arci/", + "model_name": "model.nb", + "train_log_name": "train.log", + "test_log_name": "test.log", + "predict_log_name": "predict.log", + "predict_fields": ["prediction"], + "predict_output_name": "predict.tsv", + "cache_dir": ".cache.wikiqa_arci/" + }, + "training_params": { + "vocabulary": { + "min_word_frequency": 1 + }, + "optimizer": { + "name": "Adam", + "params": { + "lr": 0.001 + } + }, + "fixed_lengths": { + "question": 200, + "passage": 200 + }, + "lr_decay": 0.90, + "minimum_lr": 0.00005, + "epoch_start_lr_decay": 20, + "use_gpu": true, + "cpu_num_workers": 1, + "batch_size": 64, + "batch_num_to_show_results": 500, + "max_epoch": 10, + "valid_times_per_epoch": 2 + }, + "architecture":[ + { + "layer": "Embedding", + "conf": { + "word": { + "cols": ["question_text", "passage_text"], + "dim": 300, + "fix_weight": false + } + } + }, + { + "layer_id": "s1_dropout", + "layer": "Dropout", + "conf": { + "dropout": 0.5 + }, + "inputs": ["question"] + }, + { + "layer_id": "s2_dropout", + "layer": "Dropout", + "conf": { + "dropout": 0.5 + }, + "inputs": ["passage"] + }, + { + "layer_id": "s1_conv_1", + "layer": "Conv", + "conf": { + "window_size": 3, + "output_channel_num": 32, + "padding_type": "SAME", + "remind_lengths": false + }, + "inputs": ["s1_dropout"] + }, + { + "layer_id": "s1_pool_1", + "layer": "Pooling1D", + "conf": { + "stride": 1, + "window_size": 2 + }, + "inputs": ["s1_conv_1"] + }, + { + "layer_id": "s1_conv_2", + "layer": "Conv", + "conf": { + "window_size": 3, + "output_channel_num": 32, + "padding_type": "SAME" + }, + "inputs": ["s1_pool_1"] + }, + { + "layer_id": "s1_pool_2", + "layer": "Pooling1D", + "conf": { + "stride": 1, + "window_size": 2 + }, + "inputs": ["s1_conv_2"] + }, + { + "layer_id": "s1_flatten", + "layer": "Flatten", + "conf": { + + }, + "inputs": ["s1_pool_2"] + }, + { + "layer_id": "s2_conv_1", + "layer": "Conv", + "conf": { + "window_size": 3, + "output_channel_num": 32, + "padding_type": "SAME", + "remind_lengths": false + }, + "inputs": ["s2_dropout"] + }, + { + "layer_id": "s2_pool_1", + "layer": "Pooling1D", + "conf": { + "stride": 1, + "window_size": 2 + }, + "inputs": ["s2_conv_1"] + }, + { + "layer_id": "s2_conv_2", + "layer": "Conv", + "conf": { + "window_size": 3, + "output_channel_num": 32, + "padding_type": "SAME" + }, + "inputs": ["s2_pool_1"] + }, + { + "layer_id": "s2_pool_2", + "layer": "Pooling1D", + "conf": { + "stride": 1, + "window_size": 2 + }, + "inputs": ["s2_conv_2"] + }, + { + "layer_id": "s2_flatten", + "layer": "Flatten", + "conf": { + + }, + "inputs": ["s2_pool_2"] + }, + { + "layer_id": "comb", + "layer": "Combination", + "conf": { + "operations": ["origin"] + }, + "inputs": ["s1_flatten", "s2_flatten"] + }, + { + "layer_id": "comb_dropout", + "layer": "Dropout", + "conf": { + "dropout": 0.5 + }, + "inputs": ["comb"] + }, + { + "layer_id": "mlp", + "layer": "Linear", + "conf": { + "hidden_dim": [64, 32], + "activation": "ReLU", + "batch_norm": true, + "last_hidden_activation": true + }, + "inputs": ["comb_dropout"] + }, + { + "output_layer_flag": true, + "layer_id": "output", + "layer": "Linear", + "conf": { + "hidden_dim": [-1], + "activation": "ReLU", + "batch_norm": true, + "last_hidden_activation": false, + "last_hidden_softmax": true + }, + "inputs": ["mlp"] + } + ], + "loss": { + "losses": [ + { + "type": "CrossEntropyLoss", + "conf": { + "weight": [0.1,0.9], + "size_average": true + }, + "inputs": ["output","label"] + } + ] + }, + "metrics": ["auc", "accuracy"] +} \ No newline at end of file diff --git a/model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_arcii.json b/model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_arcii.json new file mode 100644 index 0000000..e9bf0d9 --- /dev/null +++ b/model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_arcii.json @@ -0,0 +1,212 @@ +{ + "license": "Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT license.", + "tool_version": "1.1.0", + "model_description": "This model is used for question answer matching task, and it achieved auc: 0.7612 in WikiQACorpus test set", + "language": "English", + "inputs": { + "use_cache": true, + "dataset_type": "classification", + "data_paths": { + "train_data_path": "./dataset/WikiQACorpus/WikiQA-train.tsv", + "valid_data_path": "./dataset/WikiQACorpus/WikiQA-dev.tsv", + "test_data_path": "./dataset/WikiQACorpus/WikiQA-test.tsv", + "pre_trained_emb": "./dataset/Glove/glove.840B.300d.txt" + }, + "file_with_col_header": true, + "add_start_end_for_seq": true, + "file_header": { + "question_id": 0, + "question_text": 1, + "document_id": 2, + "document_title": 3, + "passage_id": 4, + "passage_text": 5, + "label": 6 + }, + "model_inputs": { + "question": ["question_text"], + "passage": ["passage_text"] + }, + "target": ["label"] + }, + "outputs":{ + "save_base_dir": "./models/wikiqa_arcii/", + "model_name": "model.nb", + "train_log_name": "train.log", + "test_log_name": "test.log", + "predict_log_name": "predict.log", + "predict_fields": ["prediction"], + "predict_output_name": "predict.tsv", + "cache_dir": ".cache.wikiqa_arcii/" + }, + "training_params": { + "vocabulary": { + "min_word_frequency": 1 + }, + "optimizer": { + "name": "Adam", + "params": { + "lr": 0.001 + } + }, + "fixed_lengths": { + "question": 200, + "passage": 200 + }, + "lr_decay": 0.9, + "minimum_lr": 0.00005, + "epoch_start_lr_decay": 20, + "use_gpu": true, + "cpu_num_workers": 1, + "batch_size": 64, + "batch_num_to_show_results": 500, + "max_epoch": 10, + "valid_times_per_epoch": 1 + }, + "architecture":[ + { + "layer": "Embedding", + "conf": { + "word": { + "cols": ["question_text", "passage_text"], + "dim": 300, + "fix_weight": true + } + } + }, + { + "layer_id": "s1_dropout", + "layer": "Dropout", + "conf": { + "dropout": 0.5 + }, + "inputs": ["question"] + }, + { + "layer_id": "s2_dropout", + "layer": "Dropout", + "conf": { + "dropout": 0.5 + }, + "inputs": ["passage"] + }, + { + "layer_id": "s1_conv_1", + "layer": "Conv", + "conf": { + "window_size": 3, + "output_channel_num": 32, + "padding_type": "SAME", + "remind_lengths": false + }, + "inputs": ["s1_dropout"] + }, + { + "layer_id": "s2_conv_1", + "layer": "Conv", + "conf": { + "window_size": 3, + "output_channel_num": 32, + "padding_type": "SAME", + "remind_lengths": false + }, + "inputs": ["s2_dropout"] + }, + { + "layer_id": "match", + "layer": "Expand_plus", + "conf": { + }, + "inputs": ["s1_conv_1", "s2_conv_1"] + }, + { + "layer_id": "conv2D_1", + "layer": "Conv2D", + "conf": { + "window_size": [3,3], + "output_channel_num": 32, + "padding_type": "SAME" + }, + "inputs": ["match"] + }, + { + "layer_id": "pool2D_1", + "layer": "Pooling2D", + "conf": { + "window_size": [2,2] + }, + "inputs": ["conv2D_1"] + }, + { + "layer_id": "conv2D_2", + "layer": "Conv2D", + "conf": { + "window_size": [3,3], + "output_channel_num": 32, + "padding_type": "SAME" + }, + "inputs": ["pool2D_1"] + }, + { + "layer_id": "pool2D_2", + "layer": "Pooling2D", + "conf": { + "window_size": [2,2] + }, + "inputs": ["conv2D_2"] + }, + { + "layer_id": "flatten", + "layer": "Flatten", + "conf": { + + }, + "inputs": ["pool2D_2"] + }, + { + "layer_id": "dropout", + "layer": "Dropout", + "conf": { + "dropout": 0.5 + }, + "inputs": ["flatten"] + }, + { + "layer_id": "mlp", + "layer": "Linear", + "conf": { + "hidden_dim": [64, 32], + "activation": "ReLU", + "batch_norm": true, + "last_hidden_activation": true + }, + "inputs": ["dropout"] + }, + { + "output_layer_flag": true, + "layer_id": "output", + "layer": "Linear", + "conf": { + "hidden_dim": [-1], + "activation": "ReLU", + "batch_norm": true, + "last_hidden_activation": false, + "last_hidden_softmax": true + }, + "inputs": ["mlp"] + } + ], + "loss": { + "losses": [ + { + "type": "CrossEntropyLoss", + "conf": { + "weight": [0.1,0.9], + "size_average": true + }, + "inputs": ["output","label"] + } + ] + }, + "metrics": ["auc", "accuracy"] +} \ No newline at end of file