@@ -438,6 +438,43 @@ Tensor softmax(const Tensor& input_, const int64_t dim_, c10::optional<ScalarTyp
   return result;
 }
 
+Tensor& softmax_out(
+    const Tensor& input_,
+    const int64_t dim_,
+    c10::optional<ScalarType> dtype,
+    Tensor& output_) {
+  Tensor output_temp;
+  if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half &&
+      dtype == ScalarType::Float) {
+    if (!output_.is_contiguous()) {
+      auto options =
+          TensorOptions().dtype(output_.dtype()).device(output_.device());
+      output_temp = at::empty(output_.sizes(), options);
+      at::_softmax_out(output_temp, input_, dim_, true);
+    } else {
+      at::_softmax_out(output_, input_, dim_, true);
+    }
+  } else {
+    Tensor converted =
+        dtype.has_value() ? input_.toType(dtype.value()) : input_;
+    if (!output_.is_contiguous()) {
+      auto options =
+          TensorOptions().dtype(output_.dtype()).device(output_.device());
+      output_temp = at::empty(output_.sizes(), options);
+      at::_softmax_out(output_temp, converted, dim_, false);
+    } else {
+      at::_softmax_out(output_, converted, dim_, false);
+    }
+  }
+
+  if (!output_.is_contiguous()) {
+    output_.resize_(output_temp.sizes());
+    output_.copy_(output_temp);
+  }
+
+  return output_;
+}
+
 // special_softmax, alias for softmax
 Tensor special_softmax(const Tensor& input_, const int64_t dim_, c10::optional<ScalarType> dtype) {
   return at::softmax(input_, dim_, dtype);
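
For context, a minimal sketch of how the new out= overload could be exercised once the ATen codegen registers it. The generated call shape shown here (out tensor first, optional dtype last) is an assumption based on other out= overloads, not something this diff shows:

// Hedged usage sketch; assumes codegen exposes
// at::softmax_out(out, self, dim, dtype) like other out= overloads.
#include <ATen/ATen.h>

void example() {
  at::Tensor x = at::randn({2, 3});
  at::Tensor out = at::empty({2, 3});
  // Contiguous destination: the kernel writes into `out` directly.
  at::softmax_out(out, x, /*dim=*/1, /*dtype=*/c10::nullopt);
  // Non-contiguous destination: the wrapper computes into a contiguous
  // temporary and copies the result back into the view.
  at::Tensor nc = at::empty({3, 2}).t();
  at::softmax_out(nc, x, /*dim=*/1, /*dtype=*/c10::nullopt);
}
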
@@ -466,6 +503,43 @@ Tensor log_softmax(const Tensor& input_, const int64_t dim_, c10::optional<Scala
   return result;
 }
 
+Tensor& log_softmax_out(
+    const Tensor& input_,
+    const int64_t dim_,
+    c10::optional<ScalarType> dtype,
+    Tensor& output_) {
+  Tensor output_temp;
+  if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half &&
+      dtype == ScalarType::Float) {
+    if (!output_.is_contiguous()) {
+      auto options =
+          TensorOptions().dtype(output_.dtype()).device(output_.device());
+      output_temp = at::empty(output_.sizes(), options);
+      at::_log_softmax_out(output_temp, input_, dim_, true);
+    } else {
+      at::_log_softmax_out(output_, input_, dim_, true);
+    }
+  } else {
+    Tensor converted =
+        dtype.has_value() ? input_.toType(dtype.value()) : input_;
+    if (!output_.is_contiguous()) {
+      auto options =
+          TensorOptions().dtype(output_.dtype()).device(output_.device());
+      output_temp = at::empty(output_.sizes(), options);
+      at::_log_softmax_out(output_temp, converted, dim_, false);
+    } else {
+      at::_log_softmax_out(output_, converted, dim_, false);
+    }
+  }
+
+  if (!output_.is_contiguous()) {
+    output_.resize_(output_temp.sizes());
+    output_.copy_(output_temp);
+  }
+
+  return output_;
+}
+
 Tensor special_log_softmax(const Tensor& input, const int64_t dim, c10::optional<ScalarType> dtype) {
   return at::log_softmax(input, dim, dtype);
 }
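
Note that log_softmax_out mirrors softmax_out line for line, differing only in the underlying kernel it dispatches to. A hedged refactor sketch (helper name and shape are hypothetical, not part of this diff) showing how the shared control flow could be factored out:

// Hypothetical helper, not in the diff: factors the logic shared by
// softmax_out / log_softmax_out; only the dispatched kernel differs.
template <typename KernelFn>
static Tensor& dispatch_softmax_out(
    KernelFn kernel, // wraps at::_softmax_out or at::_log_softmax_out
    const Tensor& input_,
    const int64_t dim_,
    c10::optional<ScalarType> dtype,
    Tensor& output_) {
  // CUDA half input with float dtype takes the fused half->float path;
  // every other case converts the input up front.
  const bool half_to_float = input_.is_cuda() &&
      input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float;
  Tensor converted = half_to_float
      ? input_
      : (dtype.has_value() ? input_.toType(dtype.value()) : input_);
  if (output_.is_contiguous()) {
    kernel(output_, converted, dim_, half_to_float);
  } else {
    // The underlying kernels expect a contiguous destination, so compute
    // into a contiguous temporary and copy back.
    auto options =
        TensorOptions().dtype(output_.dtype()).device(output_.device());
    Tensor output_temp = at::empty(output_.sizes(), options);
    kernel(output_temp, converted, dim_, half_to_float);
    output_.resize_(output_temp.sizes());
    output_.copy_(output_temp);
  }
  return output_;
}

// Possible call site:
//   return dispatch_softmax_out(
//       [](Tensor& o, const Tensor& i, int64_t d, bool h) -> Tensor& {
//         return at::_softmax_out(o, i, d, h);
//       },
//       input_, dim_, dtype, output_);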