@@ -37,7 +37,10 @@ use rustc_middle::traits::ObligationCause;
3737use  rustc_middle:: ty:: error:: { ExpectedFound ,  TypeError } ; 
3838use  rustc_middle:: ty:: relate:: { self ,  Relate ,  RelateResult ,  TypeRelation } ; 
3939use  rustc_middle:: ty:: subst:: SubstsRef ; 
40- use  rustc_middle:: ty:: { self ,  InferConst ,  Ty ,  TyCtxt ,  TypeVisitable } ; 
40+ use  rustc_middle:: ty:: { 
41+     self ,  FallibleTypeFolder ,  InferConst ,  Ty ,  TyCtxt ,  TypeFoldable ,  TypeSuperFoldable , 
42+     TypeVisitable , 
43+ } ; 
4144use  rustc_middle:: ty:: { IntType ,  UintType } ; 
4245use  rustc_span:: { Span ,  DUMMY_SP } ; 
4346
@@ -140,8 +143,6 @@ impl<'tcx> InferCtxt<'tcx> {
140143        let  a = self . shallow_resolve ( a) ; 
141144        let  b = self . shallow_resolve ( b) ; 
142145
143-         let  a_is_expected = relation. a_is_expected ( ) ; 
144- 
145146        match  ( a. kind ( ) ,  b. kind ( ) )  { 
146147            ( 
147148                ty:: ConstKind :: Infer ( InferConst :: Var ( a_vid) ) , 
@@ -158,11 +159,11 @@ impl<'tcx> InferCtxt<'tcx> {
158159            } 
159160
160161            ( ty:: ConstKind :: Infer ( InferConst :: Var ( vid) ) ,  _)  => { 
161-                 return  self . unify_const_variable ( relation . param_env ( ) ,   vid,  b,  a_is_expected ) ; 
162+                 return  self . unify_const_variable ( vid,  b) ; 
162163            } 
163164
164165            ( _,  ty:: ConstKind :: Infer ( InferConst :: Var ( vid) ) )  => { 
165-                 return  self . unify_const_variable ( relation . param_env ( ) ,   vid,  a,  !a_is_expected ) ; 
166+                 return  self . unify_const_variable ( vid,  a) ; 
166167            } 
167168            ( ty:: ConstKind :: Unevaluated ( ..) ,  _)  if  self . tcx . lazy_normalization ( )  => { 
168169                // FIXME(#59490): Need to remove the leak check to accommodate 
@@ -223,10 +224,8 @@ impl<'tcx> InferCtxt<'tcx> {
223224     #[ instrument( level = "debug" ,  skip( self ) ) ]  
224225    fn  unify_const_variable ( 
225226        & self , 
226-         param_env :  ty:: ParamEnv < ' tcx > , 
227227        target_vid :  ty:: ConstVid < ' tcx > , 
228228        ct :  ty:: Const < ' tcx > , 
229-         vid_is_expected :  bool , 
230229    )  -> RelateResult < ' tcx ,  ty:: Const < ' tcx > >  { 
231230        let  ( for_universe,  span)  = { 
232231            let  mut  inner = self . inner . borrow_mut ( ) ; 
@@ -239,8 +238,12 @@ impl<'tcx> InferCtxt<'tcx> {
239238                ConstVariableValue :: Unknown  {  universe }  => ( universe,  var_value. origin . span ) , 
240239            } 
241240        } ; 
242-         let  value = ConstInferUnifier  {  infcx :  self ,  span,  param_env,  for_universe,  target_vid } 
243-             . relate ( ct,  ct) ?; 
241+         let  value = ct. try_fold_with ( & mut  ConstInferUnifier  { 
242+             infcx :  self , 
243+             span, 
244+             for_universe, 
245+             target_vid, 
246+         } ) ?; 
244247
245248        self . inner . borrow_mut ( ) . const_unification_table ( ) . union_value ( 
246249            target_vid, 
@@ -800,8 +803,6 @@ struct ConstInferUnifier<'cx, 'tcx> {
800803
801804    span :  Span , 
802805
803-     param_env :  ty:: ParamEnv < ' tcx > , 
804- 
805806    for_universe :  ty:: UniverseIndex , 
806807
807808    /// The vid of the const variable that is in the process of being 
@@ -810,69 +811,23 @@ struct ConstInferUnifier<'cx, 'tcx> {
810811     target_vid :  ty:: ConstVid < ' tcx > , 
811812} 
812813
813- // We use `TypeRelation` here to propagate `RelateResult` upwards. 
814- // 
815- // Both inputs are expected to be the same. 
816- impl < ' tcx >  TypeRelation < ' tcx >  for  ConstInferUnifier < ' _ ,  ' tcx >  { 
817-     fn  tcx ( & self )  -> TyCtxt < ' tcx >  { 
818-         self . infcx . tcx 
819-     } 
820- 
821-     fn  intercrate ( & self )  -> bool  { 
822-         assert ! ( !self . infcx. intercrate) ; 
823-         false 
824-     } 
825- 
826-     fn  param_env ( & self )  -> ty:: ParamEnv < ' tcx >  { 
827-         self . param_env 
828-     } 
829- 
830-     fn  tag ( & self )  -> & ' static  str  { 
831-         "ConstInferUnifier" 
832-     } 
833- 
834-     fn  a_is_expected ( & self )  -> bool  { 
835-         true 
836-     } 
837- 
838-     fn  mark_ambiguous ( & mut  self )  { 
839-         bug ! ( ) 
840-     } 
841- 
842-     fn  relate_with_variance < T :  Relate < ' tcx > > ( 
843-         & mut  self , 
844-         _variance :  ty:: Variance , 
845-         _info :  ty:: VarianceDiagInfo < ' tcx > , 
846-         a :  T , 
847-         b :  T , 
848-     )  -> RelateResult < ' tcx ,  T >  { 
849-         // We don't care about variance here. 
850-         self . relate ( a,  b) 
851-     } 
814+ impl < ' tcx >  FallibleTypeFolder < ' tcx >  for  ConstInferUnifier < ' _ ,  ' tcx >  { 
815+     type  Error  = TypeError < ' tcx > ; 
852816
853-     fn  binders < T > ( 
854-         & mut  self , 
855-         a :  ty:: Binder < ' tcx ,  T > , 
856-         b :  ty:: Binder < ' tcx ,  T > , 
857-     )  -> RelateResult < ' tcx ,  ty:: Binder < ' tcx ,  T > > 
858-     where 
859-         T :  Relate < ' tcx > , 
860-     { 
861-         Ok ( a. rebind ( self . relate ( a. skip_binder ( ) ,  b. skip_binder ( ) ) ?) ) 
817+     fn  tcx < ' a > ( & ' a  self )  -> TyCtxt < ' tcx >  { 
818+         self . infcx . tcx 
862819    } 
863820
864821    #[ instrument( level = "debug" ,  skip( self ) ,  ret) ]  
865-     fn  tys ( & mut  self ,  t :  Ty < ' tcx > ,  _t :  Ty < ' tcx > )  -> RelateResult < ' tcx ,  Ty < ' tcx > >  { 
866-         debug_assert_eq ! ( t,  _t) ; 
867- 
822+     fn  try_fold_ty ( & mut  self ,  t :  Ty < ' tcx > )  -> Result < Ty < ' tcx > ,  TypeError < ' tcx > >  { 
868823        match  t. kind ( )  { 
869824            & ty:: Infer ( ty:: TyVar ( vid) )  => { 
870825                let  vid = self . infcx . inner . borrow_mut ( ) . type_variables ( ) . root_var ( vid) ; 
871826                let  probe = self . infcx . inner . borrow_mut ( ) . type_variables ( ) . probe ( vid) ; 
872827                match  probe { 
873828                    TypeVariableValue :: Known  {  value :  u }  => { 
874829                        debug ! ( "ConstOccursChecker: known value {:?}" ,  u) ; 
875-                         self . tys ( u ,  u ) 
830+                         u . try_fold_with ( self ) 
876831                    } 
877832                    TypeVariableValue :: Unknown  {  universe }  => { 
878833                        if  self . for_universe . can_name ( universe)  { 
@@ -892,16 +847,15 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
892847                } 
893848            } 
894849            ty:: Infer ( ty:: IntVar ( _)  | ty:: FloatVar ( _) )  => Ok ( t) , 
895-             _ => relate :: super_relate_tys ( self ,  t ,  t ) , 
850+             _ => t . try_super_fold_with ( self ) , 
896851        } 
897852    } 
898853
899-     fn  regions ( 
854+     #[ instrument( level = "debug" ,  skip( self ) ,  ret) ]  
855+     fn  try_fold_region ( 
900856        & mut  self , 
901857        r :  ty:: Region < ' tcx > , 
902-         _r :  ty:: Region < ' tcx > , 
903-     )  -> RelateResult < ' tcx ,  ty:: Region < ' tcx > >  { 
904-         debug_assert_eq ! ( r,  _r) ; 
858+     )  -> Result < ty:: Region < ' tcx > ,  TypeError < ' tcx > >  { 
905859        debug ! ( "ConstInferUnifier: r={:?}" ,  r) ; 
906860
907861        match  * r { 
@@ -930,14 +884,8 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
930884        } 
931885    } 
932886
933-     #[ instrument( level = "debug" ,  skip( self ) ) ]  
934-     fn  consts ( 
935-         & mut  self , 
936-         c :  ty:: Const < ' tcx > , 
937-         _c :  ty:: Const < ' tcx > , 
938-     )  -> RelateResult < ' tcx ,  ty:: Const < ' tcx > >  { 
939-         debug_assert_eq ! ( c,  _c) ; 
940- 
887+     #[ instrument( level = "debug" ,  skip( self ) ,  ret) ]  
888+     fn  try_fold_const ( & mut  self ,  c :  ty:: Const < ' tcx > )  -> Result < ty:: Const < ' tcx > ,  TypeError < ' tcx > >  { 
941889        match  c. kind ( )  { 
942890            ty:: ConstKind :: Infer ( InferConst :: Var ( vid) )  => { 
943891                // Check if the current unification would end up 
@@ -958,7 +906,7 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
958906                let  var_value =
959907                    self . infcx . inner . borrow_mut ( ) . const_unification_table ( ) . probe_value ( vid) ; 
960908                match  var_value. val  { 
961-                     ConstVariableValue :: Known  {  value :  u }  => self . consts ( u ,  u ) , 
909+                     ConstVariableValue :: Known  {  value :  u }  => u . try_fold_with ( self ) , 
962910                    ConstVariableValue :: Unknown  {  universe }  => { 
963911                        if  self . for_universe . can_name ( universe)  { 
964912                            Ok ( c) 
@@ -977,17 +925,7 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
977925                    } 
978926                } 
979927            } 
980-             ty:: ConstKind :: Unevaluated ( ty:: UnevaluatedConst  {  def,  substs } )  => { 
981-                 let  substs = self . relate_with_variance ( 
982-                     ty:: Variance :: Invariant , 
983-                     ty:: VarianceDiagInfo :: default ( ) , 
984-                     substs, 
985-                     substs, 
986-                 ) ?; 
987- 
988-                 Ok ( self . tcx ( ) . mk_const ( ty:: UnevaluatedConst  {  def,  substs } ,  c. ty ( ) ) ) 
989-             } 
990-             _ => relate:: super_relate_consts ( self ,  c,  c) , 
928+             _ => c. try_super_fold_with ( self ) , 
991929        } 
992930    } 
993931} 
0 commit comments