@@ -54,34 +54,34 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
5454
5555 Workspace<float > workspace (
5656 /* options=*/ options,
57- /* dtype_data=*/ float_workspace.data <float >(),
57+ /* dtype_data=*/ float_workspace.data_ptr <float >(),
5858 /* dtype_size=*/ float_workspace.numel (),
59- /* int_data=*/ int_workspace.data <int >(),
59+ /* int_data=*/ int_workspace.data_ptr <int >(),
6060 /* int_size=*/ int_workspace.numel ());
6161
62- switch (logits.type (). scalarType ()) {
62+ switch (logits.scalar_type ()) {
6363 case torch::ScalarType::Float: {
6464 Compute</* DTYPE=*/ float , /* CAST_DTYPE=*/ float >(
6565 /* workspace=*/ workspace,
66- /* logits=*/ logits.data <float >(),
67- /* targets=*/ targets.data <int >(),
68- /* src_lengths=*/ src_lengths.data <int >(),
69- /* tgt_lengths=*/ tgt_lengths.data <int >(),
70- /* costs=*/ costs.data <float >(),
66+ /* logits=*/ logits.data_ptr <float >(),
67+ /* targets=*/ targets.data_ptr <int >(),
68+ /* src_lengths=*/ src_lengths.data_ptr <int >(),
69+ /* tgt_lengths=*/ tgt_lengths.data_ptr <int >(),
70+ /* costs=*/ costs.data_ptr <float >(),
7171 /* gradients=*/
72- (gradients == c10::nullopt ) ? nullptr : gradients->data <float >());
72+ (gradients == c10::nullopt ) ? nullptr : gradients->data_ptr <float >());
7373 break ;
7474 }
7575 case torch::ScalarType::Half: {
7676 Compute</* DTYPE=*/ c10::Half, /* CAST_DTYPE=*/ float >(
7777 /* workspace=*/ workspace,
78- /* logits=*/ logits.data <c10::Half>(),
79- /* targets=*/ targets.data <int >(),
80- /* src_lengths=*/ src_lengths.data <int >(),
81- /* tgt_lengths=*/ tgt_lengths.data <int >(),
82- /* costs=*/ costs.data <c10::Half>(),
78+ /* logits=*/ logits.data_ptr <c10::Half>(),
79+ /* targets=*/ targets.data_ptr <int >(),
80+ /* src_lengths=*/ src_lengths.data_ptr <int >(),
81+ /* tgt_lengths=*/ tgt_lengths.data_ptr <int >(),
82+ /* costs=*/ costs.data_ptr <c10::Half>(),
8383 /* gradients=*/
84- (gradients == c10::nullopt ) ? nullptr : gradients->data <c10::Half>());
84+ (gradients == c10::nullopt ) ? nullptr : gradients->data_ptr <c10::Half>());
8585 break ;
8686 }
8787 default : {
0 commit comments