106 changes: 85 additions & 21 deletions stanza/models/common/packed_lstm.py
@@ -26,7 +26,18 @@ def forward(self, input, lengths, hx=None):

class LSTMwRecDropout(nn.Module):
""" An LSTM implementation that supports recurrent dropout """
def __init__(self, input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, pad=False, rec_dropout=0):
def __init__(
self,
input_size,
hidden_size,
num_layers,
bias=True,
batch_first=False,
dropout=0,
bidirectional=False,
pad=False,
rec_dropout=0
):
super().__init__()
self.batch_first = batch_first
self.pad = pad
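A minimal sketch of the recurrent-dropout idea this class implements, using a plain nn.LSTMCell and illustrative sizes (none of these names come from the diff): one dropout mask is sampled per sequence and reused at every time step, rather than resampling a fresh mask each step.

import torch
import torch.nn as nn

# Illustrative sizes only; not taken from the diff
batch_size, input_size, hidden_size, num_steps = 4, 8, 16, 5

cell = nn.LSTMCell(input_size, hidden_size)
rec_drop = nn.Dropout(p=0.2)

h = torch.zeros(batch_size, hidden_size)
c = torch.zeros(batch_size, hidden_size)

# Sample the recurrent dropout mask once per sequence...
h_drop_mask = rec_drop(torch.ones(batch_size, hidden_size))

for x_t in torch.randn(num_steps, batch_size, input_size):
    # ...and reuse the same mask on the hidden state at every step
    h, c = cell(x_t, (h * h_drop_mask, c))

Because the dropped units stay fixed across time, the rnn_loop below can compute h_drop_mask once per call instead of once per step.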
@@ -49,57 +60,110 @@ def forward(self, input, hx=None):
def rnn_loop(x, batch_sizes, cell, inits, reverse=False):
# RNN loop for one layer in one direction with recurrent dropout
# Assumes input is PackedSequence, returns PackedSequence as well

batch_size = batch_sizes[0].item()
states = [list(init.split([1] * batch_size)) for init in inits]

# Pre-calculate recurrent dropout mask only once per batch, not per time step
h_drop_mask = x.new_ones(batch_size, self.hidden_size)
h_drop_mask = self.rec_drop(h_drop_mask)

resh = []

if not reverse:
st = 0
for bs in batch_sizes:
s1 = cell(x[st:st+bs], (torch.cat(states[0][:bs], 0) * h_drop_mask[:bs], torch.cat(states[1][:bs], 0)))
end = st + bs
# Gather the running states of the bs sequences still active at this step
h = torch.cat(states[0][:bs], 0)
c = torch.cat(states[1][:bs], 0)
# Mask only the hidden state with recurrent dropout
h = h * h_drop_mask[:bs]
s1 = cell(x[st:end], (h, c))
resh.append(s1[0])
# Split the new states once per step instead of indexing and unsqueezing per element
h_split = s1[0].split(1, 0)
c_split = s1[1].split(1, 0)
for j in range(bs):
states[0][j] = s1[0][j].unsqueeze(0)
states[1][j] = s1[1][j].unsqueeze(0)
st += bs
states[0][j] = h_split[j]
states[1][j] = c_split[j]
st = end
else:
en = x.size(0)
for i in range(batch_sizes.size(0)-1, -1, -1):
for i in range(batch_sizes.size(0) - 1, -1, -1):
bs = batch_sizes[i]
s1 = cell(x[en-bs:en], (torch.cat(states[0][:bs], 0) * h_drop_mask[:bs], torch.cat(states[1][:bs], 0)))
st = en - bs
h = torch.cat(states[0][:bs], 0)
c = torch.cat(states[1][:bs], 0)
h = h * h_drop_mask[:bs]
s1 = cell(x[st:en], (h, c))
resh.append(s1[0])
h_split = s1[0].split(1, 0)
c_split = s1[1].split(1, 0)
for j in range(bs):
states[0][j] = s1[0][j].unsqueeze(0)
states[1][j] = s1[1][j].unsqueeze(0)
en -= bs
states[0][j] = h_split[j]
states[1][j] = c_split[j]
en = st
resh = list(reversed(resh))

return torch.cat(resh, 0), tuple(torch.cat(s, 0) for s in states)
# Avoid multiple torch.cat calls in loop, concat once at end
output = torch.cat(resh, 0)
# Concatenate the final per-sequence hidden and cell states
final_states = tuple(torch.cat(s, 0) for s in states)
return output, final_states

all_states = [[], []]
inputdata, batch_sizes = input.data, input.batch_sizes
for l in range(self.num_layers):
new_input = []

if self.dropout > 0 and l > 0:
inputdata = self.drop(inputdata)
for d in range(self.num_directions):
idx = l * self.num_directions + d
cell = self.cells[idx]
out, states = rnn_loop(inputdata, batch_sizes, cell, (hx[i][idx] for i in range(2)) if hx is not None else (input.data.new_zeros(input.batch_sizes[0].item(), self.hidden_size, requires_grad=False) for _ in range(2)), reverse=(d == 1))
# Cache frequently used attributes in locals to cut attribute-lookup
# overhead inside the layer/direction loops below
num_layers = self.num_layers
num_directions = self.num_directions
cells = self.cells
drop = self.drop
dropout = self.dropout
hidden_size = self.hidden_size

# Hoist loop-invariant values (hx check, batch size) out of the loops
hx_is_not_none = hx is not None
batch_size = batch_sizes[0].item()

for l in range(num_layers):
new_input = []

# Apply dropout only between layers if needed
if dropout > 0 and l > 0:
inputdata = drop(inputdata)
for d in range(num_directions):
idx = l * num_directions + d
cell = cells[idx]

if hx_is_not_none:
hx_gen = (hx[i][idx] for i in range(2))
else:
hx_gen = (
inputdata.new_zeros(batch_size, hidden_size, requires_grad=False)
for _ in range(2)
)

# Run this layer in the current direction through rnn_loop
out, states = rnn_loop(
inputdata, batch_sizes, cell, hx_gen, reverse=(d == 1)
)
new_input.append(out)
# Add a leading layer/direction axis so the final states can be concatenated below
all_states[0].append(states[0].unsqueeze(0))
all_states[1].append(states[1].unsqueeze(0))

if self.num_directions > 1:
# concatenate both directions
if num_directions > 1:
# Concatenate forward and backward outputs along the feature dimension
inputdata = torch.cat(new_input, 1)
else:
inputdata = new_input[0]

input = PackedSequence(inputdata, batch_sizes)

# Concatenate per-layer/direction hidden and cell states into the final (h, c) pair
return input, tuple(torch.cat(x, 0) for x in all_states)
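For context, a minimal end-to-end sketch of driving this module with a PackedSequence. The import path follows the file being changed and the call shape follows the forward(self, input, hx=None) hunk above; all sizes, lengths, and dropout rates are illustrative assumptions, not part of the diff.

import torch
from torch.nn.utils.rnn import pack_padded_sequence
from stanza.models.common.packed_lstm import LSTMwRecDropout

lstm = LSTMwRecDropout(input_size=8, hidden_size=16, num_layers=2,
                       bidirectional=True, dropout=0.3, rec_dropout=0.2)

# Three sequences of lengths 5, 3, 2, padded to length 5: (seq_len, batch, input_size)
padded = torch.randn(5, 3, 8)
lengths = torch.tensor([5, 3, 2])
packed = pack_padded_sequence(padded, lengths)   # batch_sizes == [3, 3, 2, 1, 1]

# rnn_loop walks packed.data in chunks of batch_sizes[t] rows per time step,
# reusing a single recurrent-dropout mask for the whole sequence
output, (hn, cn) = lstm(packed)
print(output.data.shape)    # (sum(lengths), 2 * hidden_size) for a biLSTM
print(hn.shape, cn.shape)   # (num_layers * num_directions, batch, hidden_size)

With the lengths above, the forward branch of rnn_loop processes chunks of 3, 3, 2, 1, and 1 rows of packed.data, which is exactly the st/end slicing the loop performs.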