|
12 | 12 | import torch |
13 | 13 |
|
14 | 14 |
|
15 | | -def _get_rotation_to_best_fit_xy( |
16 | | - points: torch.Tensor, centroid: torch.Tensor |
| 15 | +def get_rotation_to_best_fit_xy( |
| 16 | + points: torch.Tensor, centroid: Optional[torch.Tensor] = None |
17 | 17 | ) -> torch.Tensor: |
18 | 18 | """ |
19 | | - Returns a rotation r such that points @ r has a best fit plane |
| 19 | + Returns a rotation R such that `points @ R` has a best fit plane |
20 | 20 | parallel to the xy plane |
21 | 21 |
|
22 | 22 | Args: |
23 | | - points: (N, 3) tensor of points in 3D |
24 | | - centroid: (3,) their centroid |
| 23 | + points: (*, N, 3) tensor of points in 3D |
| 24 | + centroid: (*, 1, 3), (3,) or scalar: their centroid |
25 | 25 |
|
26 | 26 | Returns: |
27 | | - (3,3) tensor rotation matrix |
| 27 | + (*, 3, 3) tensor rotation matrix |
28 | 28 | """ |
29 | | - points_centered = points - centroid[None] |
30 | | - return torch.linalg.eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]] |
| 29 | + if centroid is None: |
| 30 | + centroid = points.mean(dim=-2, keepdim=True) |
| 31 | + |
| 32 | + points_centered = points - centroid |
| 33 | + _, evec = torch.linalg.eigh(points_centered.transpose(-1, -2) @ points_centered) |
| 34 | + # in general, evec can form either right- or left-handed basis, |
| 35 | + # but we need the former to have a proper rotation (not reflection) |
| 36 | + return torch.cat( |
| 37 | + (evec[..., 1:], torch.cross(evec[..., 1], evec[..., 2])[..., None]), dim=-1 |
| 38 | + ) |
31 | 39 |
|
32 | 40 |
|
33 | 41 | def _signed_area(path: torch.Tensor) -> torch.Tensor: |
@@ -191,7 +199,7 @@ def fit_circle_in_3d( |
191 | 199 | Circle3D object |
192 | 200 | """ |
193 | 201 | centroid = points.mean(0) |
194 | | - r = _get_rotation_to_best_fit_xy(points, centroid) |
| 202 | + r = get_rotation_to_best_fit_xy(points, centroid) |
195 | 203 | normal = r[:, 2] |
196 | 204 | rotated_points = (points - centroid) @ r |
197 | 205 | result_2d = fit_circle_in_2d( |
|
0 commit comments