@@ -466,29 +466,54 @@ def _clamp_y_intercept(
466466 then applies various constraints to ensure the clamping conditions are respected.
467467 """
468468
469+ # Calculate slopes and y-intercepts for bounding boxes
469470 a , b = _get_slope_and_intercept (bounding_boxes )
470471 a1 , a2 , a3 , a4 = a .unbind (- 1 )
471472 b1 , b2 , b3 , b4 = b .unbind (- 1 )
472473
473- # Clamp y-intercepts (soft clamping)
474+ # Get y-intercepts from original bounding boxes
475+ _ , bm = _get_slope_and_intercept (original_bounding_boxes )
476+ b1m , b2m , b3m , b4m = bm .unbind (- 1 )
477+
478+ # Soft clamping: Clamp y-intercepts within canvas boundaries
474479 b1 = b2 .clamp (b1 , b3 ).clamp (0 , canvas_size [0 ])
475480 b4 = b3 .clamp (b2 , b4 ).clamp (0 , canvas_size [0 ])
476481
477482 if clamping_mode == "hard" :
478- # Get y-intercepts from original bounding boxes
479- _ , b = _get_slope_and_intercept (original_bounding_boxes )
480- _ , b2 , b3 , _ = b .unbind (- 1 )
481-
482- # Set b1 and b4 to the average of their clamped values
483- b1 = b4 = (b1 .clamp (0 , canvas_size [0 ]) + b4 .clamp (0 , canvas_size [0 ])) / 2
484-
485- # Ensure b2 and b3 defined the box of maximum area after clamping b1 and b4
486- b2 .clamp_ (b1 * a2 / a1 , b4 ).clamp_ ((a1 - a2 ) * canvas_size [1 ] + b1 )
487- b2 .clamp_ (b3 * a2 / a3 , b4 ).clamp_ ((a3 - a2 ) * canvas_size [1 ] + b3 )
488- b3 .clamp_ (max = canvas_size [0 ] * (1 - a3 / a4 ) + b4 * a3 / a4 )
489- b3 .clamp_ (max = canvas_size [0 ] * (1 - a3 / a2 ) + b2 * a3 / a2 )
490- b3 .clamp_ (b1 , (a2 - a3 ) * canvas_size [1 ] + b2 )
491- b3 .clamp_ (b1 , (a4 - a3 ) * canvas_size [1 ] + b4 )
483+ # Hard clamping: Average b1 and b4, and adjust b2 and b3 for maximum area
484+ b1 = b4 = (b1 + b4 ) / 2
485+
486+ # Calculate candidate values for b2 based on geometric constraints
487+ b2_candidates = torch .stack (
488+ [
489+ b1 * a2 / a1 , # Constraint at y=0
490+ b3 * a2 / a3 , # Constraint at y=0
491+ (a1 - a2 ) * canvas_size [1 ] + b1 , # Constraint at x=canvas_width
492+ (a3 - a2 ) * canvas_size [1 ] + b3 , # Constraint at x=canvas_width
493+ ],
494+ dim = 1 ,
495+ )
496+ # Take maximum value that doesn't exceed original b2
497+ b2 = torch .max (b2_candidates , dim = 1 )[0 ].clamp (max = b2 )
498+
499+ # Calculate candidate values for b3 based on geometric constraints
500+ b3_candidates = torch .stack (
501+ [
502+ canvas_size [0 ] * (1 - a3 / a4 ) + b4 * a3 / a4 , # Constraint at y=canvas_height
503+ canvas_size [0 ] * (1 - a3 / a2 ) + b2 * a3 / a2 , # Constraint at y=canvas_height
504+ (a2 - a3 ) * canvas_size [1 ] + b2 , # Constraint at x=canvas_width
505+ (a4 - a3 ) * canvas_size [1 ] + b4 , # Constraint at x=canvas_width
506+ ],
507+ dim = 1 ,
508+ )
509+ # Take minimum value that doesn't go below original b3
510+ b3 = torch .min (b3_candidates , dim = 1 )[0 ].clamp (min = b3 )
511+
512+ # Final clamping to ensure y-intercepts are within original box bounds
513+ b1 .clamp_ (b1m , b3m )
514+ b3 .clamp_ (b1m , b3m )
515+ b2 .clamp_ (b2m , b4m )
516+ b4 .clamp_ (b2m , b4m )
492517
493518 return torch .stack ([b1 , b2 , b3 , b4 ], dim = - 1 )
494519
0 commit comments