@@ -290,6 +290,7 @@ impl Client {
290290        } 
291291
292292        let  stream_description = connection. stream_description ( ) ?; 
293+         let  is_sharded = stream_description. initial_server_type  == ServerType :: Mongos ; 
293294        let  mut  cmd = op. build ( stream_description) ?; 
294295        self . inner 
295296            . topology 
@@ -328,15 +329,22 @@ impl Client {
328329                        cmd. set_start_transaction ( ) ; 
329330                        cmd. set_autocommit ( ) ; 
330331                        cmd. set_txn_read_concern ( * session) ?; 
331-                         if  stream_description . initial_server_type  ==  ServerType :: Mongos  { 
332+                         if  is_sharded  { 
332333                            session. pin_mongos ( connection. address ( ) . clone ( ) ) ; 
333334                        } 
334335                        session. transaction . state  = TransactionState :: InProgress ; 
335336                    } 
336-                     TransactionState :: InProgress 
337-                     | TransactionState :: Committed  {  .. } 
338-                     | TransactionState :: Aborted  => { 
337+                     TransactionState :: InProgress  => cmd. set_autocommit ( ) , 
338+                     TransactionState :: Committed  {  .. }  | TransactionState :: Aborted  => { 
339339                        cmd. set_autocommit ( ) ; 
340+ 
341+                         // Append the recovery token to the command if we are committing or aborting 
342+                         // on a sharded transaction. 
343+                         if  is_sharded { 
344+                             if  let  Some ( ref  recovery_token)  = session. transaction . recovery_token  { 
345+                                 cmd. set_recovery_token ( recovery_token) ; 
346+                             } 
347+                         } 
340348                    } 
341349                    _ => { } 
342350                } 
@@ -403,6 +411,9 @@ impl Client {
403411                    Ok ( r)  => { 
404412                        self . update_cluster_time ( & r,  session) . await ; 
405413                        if  r. is_success ( )  { 
414+                             // Retrieve recovery token from successful response. 
415+                             Client :: update_recovery_token ( is_sharded,  & r,  session) . await ; 
416+ 
406417                            Ok ( CommandResult  { 
407418                                raw :  response, 
408419                                deserialized :  r. into_body ( ) , 
@@ -447,7 +458,15 @@ impl Client {
447458                                        } ) ) 
448459                                    } 
449460                                    // for ok: 1 just return the original deserialization error. 
450-                                     _ => Err ( deserialize_error) , 
461+                                     _ => { 
462+                                         Client :: update_recovery_token ( 
463+                                             is_sharded, 
464+                                             & error_response, 
465+                                             session, 
466+                                         ) 
467+                                         . await ; 
468+                                         Err ( deserialize_error) 
469+                                     } 
451470                                } 
452471                            } 
453472                            // We failed to deserialize even that, so just return the original 
@@ -626,6 +645,18 @@ impl Client {
626645            } 
627646        } 
628647    } 
648+ 
649+     async  fn  update_recovery_token < T :  Response > ( 
650+         is_sharded :  bool , 
651+         response :  & T , 
652+         session :  & mut  Option < & mut  ClientSession > , 
653+     )  { 
654+         if  let  Some ( ref  mut  session)  = session { 
655+             if  is_sharded && session. in_transaction ( )  { 
656+                 session. transaction . recovery_token  = response. recovery_token ( ) . cloned ( ) ; 
657+             } 
658+         } 
659+     } 
629660} 
630661
631662impl  Error  { 
0 commit comments