@@ -760,3 +760,88 @@ int mca_coll_hcoll_ialltoallv(const void *sbuf, int *scounts, int *sdisps,
760760 return rc ;
761761}
762762#endif
763+
764+ #if HCOLL_API > HCOLL_VERSION (4 ,5 )
765+ int mca_coll_hcoll_reduce_scatter_block (const void * sbuf , void * rbuf , int rcount ,
766+ struct ompi_datatype_t * dtype ,
767+ struct ompi_op_t * op ,
768+ struct ompi_communicator_t * comm ,
769+ mca_coll_base_module_t * module ) {
770+ dte_data_representation_t Dtype ;
771+ hcoll_dte_op_t * Op ;
772+ int rc ;
773+ HCOL_VERBOSE (20 ,"RUNNING HCOL REDUCE SCATTER BLOCK" );
774+ mca_coll_hcoll_module_t * hcoll_module = (mca_coll_hcoll_module_t * )module ;
775+ Dtype = ompi_dtype_2_hcoll_dtype (dtype , NO_DERIVED );
776+ if (OPAL_UNLIKELY (HCOL_DTE_IS_ZERO (Dtype ))){
777+ /*If we are here then datatype is not simple predefined datatype */
778+ /*In future we need to add more complex mapping to the dte_data_representation_t */
779+ /* Now use fallback */
780+ HCOL_VERBOSE (20 ,"Ompi_datatype is not supported: dtype = %s; calling fallback allreduce;" ,
781+ dtype -> super .name );
782+ goto fallback ;
783+ }
784+
785+ Op = ompi_op_2_hcolrte_op (op );
786+ if (OPAL_UNLIKELY (HCOL_DTE_OP_NULL == Op -> id )){
787+ /*If we are here then datatype is not simple predefined datatype */
788+ /*In future we need to add more complex mapping to the dte_data_representation_t */
789+ /* Now use fallback */
790+ HCOL_VERBOSE (20 ,"ompi_op_t is not supported: op = %s; calling fallback allreduce;" ,
791+ op -> o_name );
792+ goto fallback ;
793+ }
794+
795+ rc = hcoll_collectives .coll_reduce_scatter_block ((void * )sbuf ,rbuf ,rcount ,Dtype ,Op ,hcoll_module -> hcoll_context );
796+ if (HCOLL_SUCCESS != rc ){
797+ fallback :
798+ HCOL_VERBOSE (20 ,"RUNNING FALLBACK ALLREDUCE" );
799+ rc = hcoll_module -> previous_reduce_scatter_block (sbuf ,rbuf ,
800+ rcount ,dtype ,op ,
801+ comm , hcoll_module -> previous_allreduce_module );
802+ }
803+ return rc ;
804+ }
805+
806+ int mca_coll_hcoll_reduce_scatter (const void * sbuf , void * rbuf , const int * rcounts ,
807+ struct ompi_datatype_t * dtype ,
808+ struct ompi_op_t * op ,
809+ struct ompi_communicator_t * comm ,
810+ mca_coll_base_module_t * module ) {
811+ dte_data_representation_t Dtype ;
812+ hcoll_dte_op_t * Op ;
813+ int rc ;
814+ HCOL_VERBOSE (20 ,"RUNNING HCOL REDUCE SCATTER" );
815+ mca_coll_hcoll_module_t * hcoll_module = (mca_coll_hcoll_module_t * )module ;
816+ Dtype = ompi_dtype_2_hcoll_dtype (dtype , NO_DERIVED );
817+ if (OPAL_UNLIKELY (HCOL_DTE_IS_ZERO (Dtype ))){
818+ /*If we are here then datatype is not simple predefined datatype */
819+ /*In future we need to add more complex mapping to the dte_data_representation_t */
820+ /* Now use fallback */
821+ HCOL_VERBOSE (20 ,"Ompi_datatype is not supported: dtype = %s; calling fallback allreduce;" ,
822+ dtype -> super .name );
823+ goto fallback ;
824+ }
825+
826+ Op = ompi_op_2_hcolrte_op (op );
827+ if (OPAL_UNLIKELY (HCOL_DTE_OP_NULL == Op -> id )){
828+ /*If we are here then datatype is not simple predefined datatype */
829+ /*In future we need to add more complex mapping to the dte_data_representation_t */
830+ /* Now use fallback */
831+ HCOL_VERBOSE (20 ,"ompi_op_t is not supported: op = %s; calling fallback allreduce;" ,
832+ op -> o_name );
833+ goto fallback ;
834+ }
835+
836+ rc = hcoll_collectives .coll_reduce_scatter ((void * )sbuf , rbuf , (int * )rcounts ,
837+ Dtype , Op , hcoll_module -> hcoll_context );
838+ if (HCOLL_SUCCESS != rc ){
839+ fallback :
840+ HCOL_VERBOSE (20 ,"RUNNING FALLBACK ALLREDUCE" );
841+ rc = hcoll_module -> previous_reduce_scatter (sbuf ,rbuf ,
842+ rcounts ,dtype ,op ,
843+ comm , hcoll_module -> previous_allreduce_module );
844+ }
845+ return rc ;
846+ }
847+ #endif
0 commit comments