|
94 | 94 | train_datapipe = SST2(split="train") |
95 | 95 | dev_datapipe = SST2(split="dev") |
96 | 96 |
|
| 97 | + |
97 | 98 | # Transform the raw dataset using non-batched API (i.e apply transformation line by line) |
98 | | -train_datapipe = train_datapipe.map(lambda x: (text_transform(x[0]), x[1])) |
| 99 | +def apply_transform(x): |
| 100 | + return text_transform(x[0]), x[1] |
| 101 | + |
| 102 | + |
| 103 | +train_datapipe = train_datapipe.map(apply_transform) |
99 | 104 | train_datapipe = train_datapipe.batch(batch_size) |
100 | 105 | train_datapipe = train_datapipe.rows2columnar(["token_ids", "target"]) |
101 | 106 | train_dataloader = DataLoader(train_datapipe, batch_size=None) |
102 | 107 |
|
103 | | -dev_datapipe = dev_datapipe.map(lambda x: (text_transform(x[0]), x[1])) |
| 108 | +dev_datapipe = dev_datapipe.map(apply_transform) |
104 | 109 | dev_datapipe = dev_datapipe.batch(batch_size) |
105 | 110 | dev_datapipe = dev_datapipe.rows2columnar(["token_ids", "target"]) |
106 | 111 | dev_dataloader = DataLoader(dev_datapipe, batch_size=None) |
|
111 | 116 | # |
112 | 117 | # :: |
113 | 118 | # |
114 | | -# train_datapipe = train_datapipe.batch(batch_size).rows2columnar(["text", "label"]) |
115 | | -# train_datapipe = train_datapipe.map(lambda x: {"token_ids": text_transform(x["text"]), "target": label_transform(x["label"])}) |
116 | | -# dev_datapipe = dev_datapipe.batch(batch_size).rows2columnar(["text", "label"]) |
117 | | -# dev_datapipe = dev_datapipe.map(lambda x: {"token_ids": text_transform(x["text"]), "target": label_transform(x["label"])}) |
| 119 | +# def batch_transform(x): |
| 120 | +# return {"token_ids": text_transform(x["text"]), "target": x["label"]} |
| 121 | +# |
| 122 | +# |
| 123 | +# train_datapipe = train_datapipe.batch(batch_size).rows2columnar(["text", "label"]) |
| 124 | +# train_datapipe = train_datapipe.map(lambda x: batch_transform) |
| 125 | +# dev_datapipe = dev_datapipe.batch(batch_size).rows2columnar(["text", "label"]) |
| 126 | +# dev_datapipe = dev_datapipe.map(lambda x: batch_transform) |
118 | 127 | # |
119 | 128 |
|
120 | 129 | ###################################################################### |
|
0 commit comments