y_pred = []
y_true = []
for inputs, labels in tqdm(data_loader, total=len(data_loader)):
with torch.no_grad():
inputs, labels = inputs.cuda(), labels.cuda()
output = bert_punc(inputs)
y_pred += list(output.argmax(dim=1).cpu().data.numpy().flatten())
y_true += list(labels.cpu().data.numpy().flatten())
return y_pred, y_true
Где:
data_set = TensorDataset(torch.from_numpy(X).long(), torch.from_numpy(np.array(y)).long())
data_loader = DataLoader(data_set, batch_size=batch_size, shuffle=shuffle)
И всё летало без проблем. Не знаешь, почему здесь не так?
То что из него выходит это inputs, labels в данном случае
Обсуждают сегодня