From ce72285c55f37e2209f01788bc5581b5ace3621a Mon Sep 17 00:00:00 2001 From: xhuan8 Date: Wed, 7 Dec 2022 23:26:05 +0800 Subject: [PATCH 1/4] #790 Add Ops\Boxes --- src/Native/LibTorchSharp/THSVision.cpp | 10 + src/Native/LibTorchSharp/THSVision.h | 4 +- src/Native/LibTorchSharp/Utils.h | 84 ++++- src/TorchSharp/Tensor/Tensor.Math.cs | 18 +- src/TorchSharp/Tensor/Tensor.cs | 2 +- src/TorchSharp/Tensor/Tensor.torch.cs | 8 + src/TorchSharp/Utils/DeconstructExtension.cs | 31 ++ src/TorchVision/Ops.cs | 2 +- src/TorchVision/Ops/BoxConvert.cs | 95 +++++ src/TorchVision/Ops/Boxes.cs | 362 +++++++++++++++++++ src/TorchVision/Ops/Utils.cs | 23 ++ test/TorchSharpTest/TestTorchVisionOps.cs | 6 +- 12 files changed, 624 insertions(+), 21 deletions(-) create mode 100644 src/TorchSharp/Utils/DeconstructExtension.cs create mode 100644 src/TorchVision/Ops/BoxConvert.cs create mode 100644 src/TorchVision/Ops/Boxes.cs create mode 100644 src/TorchVision/Ops/Utils.cs diff --git a/src/Native/LibTorchSharp/THSVision.cpp b/src/Native/LibTorchSharp/THSVision.cpp index 30e6b850f..e6e36c80a 100644 --- a/src/Native/LibTorchSharp/THSVision.cpp +++ b/src/Native/LibTorchSharp/THSVision.cpp @@ -358,4 +358,14 @@ void THSVision_RGB_BRGA(const uint8_t* inputBytes, uint8_t* outBytes, int64_t in } outBytes[outputAlpha + j] = inputHasAlpha ? inputBytes[inputBlue + i] : 255; } +} + +Tensor THSVision_nms(const Tensor dets, const Tensor scores, double iou_threshold) +{ + typedef at::Tensor (*TorchVisionFunc)(at::Tensor&, at::Tensor&, double); + auto nms = (TorchVisionFunc)LoadNativeSymbol("libtorchvision.dll", "?nms@ops@vision@@YA?AVTensor@at@@AEBV34@0N@Z"); + if (nms == NULL) + return NULL; + + CATCH_TENSOR(nms(*dets, *scores, iou_threshold)); } \ No newline at end of file diff --git a/src/Native/LibTorchSharp/THSVision.h b/src/Native/LibTorchSharp/THSVision.h index 02ad2a72e..00a73da47 100644 --- a/src/Native/LibTorchSharp/THSVision.h +++ b/src/Native/LibTorchSharp/THSVision.h @@ -23,4 +23,6 @@ EXPORT_API(void) THSVision_ComputeOutputSize(const float* matrix, const int64_t EXPORT_API(void) THSVision_BRGA_RGB(const uint8_t* inputBytes, uint8_t* redBytes, uint8_t* greenBytes, uint8_t* blueBytes, int64_t inputChannelCount, int64_t imageSize); EXPORT_API(void) THSVision_BRGA_RGBA(const uint8_t* inputBytes, uint8_t* redBytes, uint8_t* greenBytes, uint8_t* blueBytes, uint8_t* alphaBytes, int64_t inputChannelCount, int64_t imageSize); -EXPORT_API(void) THSVision_RGB_BRGA(const uint8_t* inputBytes, uint8_t* outBytes, int64_t inputChannelCount, int64_t imageSize); \ No newline at end of file +EXPORT_API(void) THSVision_RGB_BRGA(const uint8_t* inputBytes, uint8_t* outBytes, int64_t inputChannelCount, int64_t imageSize); + +EXPORT_API(Tensor) THSVision_nms(const Tensor dets, const Tensor scores, double iou_threshold); \ No newline at end of file diff --git a/src/Native/LibTorchSharp/Utils.h b/src/Native/LibTorchSharp/Utils.h index adefa5aca..d91635617 100644 --- a/src/Native/LibTorchSharp/Utils.h +++ b/src/Native/LibTorchSharp/Utils.h @@ -2,25 +2,38 @@ #pragma once #include +#include +#include #include "torch/torch.h" -extern thread_local char *torch_last_err; - -typedef torch::Tensor *Tensor; -typedef torch::Scalar *Scalar; +#if _WIN32 +#include +#include +#include +#define GetCurrentDir _getcwd +#else +#include +#include +#define GetCurrentDir getcwd +#endif + +extern thread_local char* torch_last_err; + +typedef torch::Tensor* Tensor; +typedef torch::Scalar* Scalar; typedef torch::Generator* Generator; typedef c10::Storage* Storage; typedef torch::nn::utils::rnn::PackedSequence* PackedSequence; -typedef std::shared_ptr * NNModule; -typedef std::shared_ptr * NNAnyModule; -typedef std::shared_ptr * Optimizer; -typedef std::shared_ptr * JITCompilationUnit; +typedef std::shared_ptr* NNModule; +typedef std::shared_ptr* NNAnyModule; +typedef std::shared_ptr* Optimizer; +typedef std::shared_ptr* JITCompilationUnit; typedef std::shared_ptr* JITModule; typedef std::shared_ptr* JITMethod; -typedef std::shared_ptr * JITFunction; -typedef std::shared_ptr * JITType; +typedef std::shared_ptr* JITFunction; +typedef std::shared_ptr* JITType; typedef std::shared_ptr* JITTensorType; //typedef std::shared_ptr* JITDimensionedTensorType; @@ -49,7 +62,7 @@ typedef std::shared_ptr* JITTensorType; #define CATCH_RETURN_Tensor(stmt) CATCH_RETURN_RES(Tensor, NULL, stmt) // Return undefined tensors as NULL to C# -inline Tensor ResultTensor(const at::Tensor & res) +inline Tensor ResultTensor(const at::Tensor& res) { if (res.defined()) return new torch::Tensor(res); @@ -82,11 +95,11 @@ inline Tensor ResultTensor(const at::Tensor & res) // Utility method used to built sharable strings. -const char * make_sharable_string(const std::string str); +const char* make_sharable_string(const std::string str); // Method concerting arrays of tensor pointers into arrays of tensors. template -std::vector toTensors(torch::Tensor ** tensorPtrs, const int length) +std::vector toTensors(torch::Tensor** tensorPtrs, const int length) { std::vector tensors; @@ -296,4 +309,49 @@ torch::nn::init::NonlinearityType get_nl_type(const int64_t nl) case 9: return torch::kReLU; case 10: return torch::kLeakyReLU; } +} + +inline +void* LoadNativeSymbol(const std::string libName, const std::string symbolName) +{ + void* lib = NULL; +#if _WIN32 +#ifdef UNICODE + auto fullName = libName; + std::wstring widestr = std::wstring(fullName.begin(), fullName.end()); + lib = LoadLibrary(widestr.c_str()); +#else + lib = LoadLibrary(libName.c_str()); +#endif // !UNICODE + +#else + lib = dlopen((libName + ".so").c_str(), RTLD_LAZY); +#endif + + if (lib == NULL) + { + char buff[FILENAME_MAX]; + GetCurrentDir(buff, FILENAME_MAX); + std::string current_working_dir(buff); + + torch_last_err = strdup(("Failed to load library: " + libName + " " + + std::to_string(GetLastError()) + " " + current_working_dir).c_str()); + return NULL; + } + + void* symbol = NULL; +#if _WIN32 + symbol = (void*)GetProcAddress((HMODULE)lib, symbolName.c_str()); +#else + symbol = dlsym(libHandle, symbolName.c_str()); +#endif + + if (symbol == NULL) + { + torch_last_err = strdup(("Cannot find symbol: " + symbolName + " " + + std::to_string(GetLastError())).c_str()); + return NULL; + } + + return symbol; } \ No newline at end of file diff --git a/src/TorchSharp/Tensor/Tensor.Math.cs b/src/TorchSharp/Tensor/Tensor.Math.cs index 47da602f9..d019207b6 100644 --- a/src/TorchSharp/Tensor/Tensor.Math.cs +++ b/src/TorchSharp/Tensor/Tensor.Math.cs @@ -2754,7 +2754,14 @@ public static Tensor einsum(string equation, params Tensor[] tensors) /// /// The first input tensor /// The second input tensor - public static Tensor maximum(Tensor input, Tensor other) => input.maximum(other); + public static Tensor max(Tensor input, Tensor other) => maximum(input, other); + + /// + /// Computes the element-wise maximum of input and other. + /// + /// The first input tensor + /// The second input tensor + static public Tensor maximum(Tensor input, Tensor other) => input.maximum(other); /// /// Returns a named tuple (values, indexes) where values is the maximum value of each row of the input tensor in the given dimension dim. @@ -2797,7 +2804,14 @@ public static Tensor einsum(string equation, params Tensor[] tensors) /// /// The first input tensor /// The second input tensor - public static Tensor minimum(Tensor input, Tensor other) => input.minimum(other); + public static Tensor min(Tensor input, Tensor other) => minimum(input, other); + + /// + /// Computes the element-wise minimum of input and other. + /// + /// The first input tensor + /// The second input tensor + static public Tensor minimum(Tensor input, Tensor other) => input.minimum(other); /// /// Returns a named tuple (values, indexes) where values is the minimum value of each row of the input tensor in the given dimension dim. diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs index f6c8fec68..a1d69e4aa 100644 --- a/src/TorchSharp/Tensor/Tensor.cs +++ b/src/TorchSharp/Tensor/Tensor.cs @@ -36,7 +36,7 @@ public partial class Tensor : IDisposable internal DisposeScope? OwningDisposeScope { get; set; } - internal Tensor(IntPtr handle) + public Tensor(IntPtr handle) { this.handle = handle; System.Threading.Interlocked.Increment(ref _totalCount); diff --git a/src/TorchSharp/Tensor/Tensor.torch.cs b/src/TorchSharp/Tensor/Tensor.torch.cs index a473fad21..a0626b80e 100644 --- a/src/TorchSharp/Tensor/Tensor.torch.cs +++ b/src/TorchSharp/Tensor/Tensor.torch.cs @@ -678,5 +678,13 @@ public static Tensor _sample_dirichlet(Tensor input, torch.Generator generator = /// public static Tensor argsort(Tensor input, long dim = -1, bool descending = false) => input.argsort(dim, descending); + /// + /// Returns the unique elements of the input tensor. + /// + /// + public static Tensor unique(Tensor input) + { + return input.unique().output; + } } } diff --git a/src/TorchSharp/Utils/DeconstructExtension.cs b/src/TorchSharp/Utils/DeconstructExtension.cs new file mode 100644 index 000000000..240395091 --- /dev/null +++ b/src/TorchSharp/Utils/DeconstructExtension.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace TorchSharp.Utils +{ + /// + /// Converts IEnumerable to tuple. + /// + public static class DeconstructExtension + { + public static void Deconstruct(this IEnumerable seq, out T first, out IEnumerable rest) + { + first = seq.FirstOrDefault(); + rest = seq.Skip(1); + } + + public static void Deconstruct(this IEnumerable seq, out T first, out T second, out IEnumerable rest) + => (first, (second, rest)) = seq; + + public static void Deconstruct(this IEnumerable seq, out T first, out T second, out T third, out IEnumerable rest) + => (first, second, (third, rest)) = seq; + + public static void Deconstruct(this IEnumerable seq, out T first, out T second, out T third, out T fourth, out IEnumerable rest) + => (first, second, third, (fourth, rest)) = seq; + + public static void Deconstruct(this IEnumerable seq, out T first, out T second, out T third, out T fourth, out T fifth, out IEnumerable rest) + => (first, second, third, fourth, (fifth, rest)) = seq; + } +} diff --git a/src/TorchVision/Ops.cs b/src/TorchVision/Ops.cs index 987f022f7..d4498d26c 100644 --- a/src/TorchVision/Ops.cs +++ b/src/TorchVision/Ops.cs @@ -42,7 +42,7 @@ public static Tensor sigmoid_focal_loss(Tensor inputs, Tensor targets, float alp /// Scores (Tensor[N]) for each one of the boxes. /// Discards all overlapping boxes with IoU > iou_threshold. /// The indices (Tensor) of the elements that have been kept by NMS, sorted in decreasing order of scores. - public static Tensor nms(Tensor boxes, Tensor scores, double iou_threshold = 0.5) + public static Tensor nms_custom(Tensor boxes, Tensor scores, double iou_threshold = 0.5) { using (var _ = torch.NewDisposeScope()) { var x1 = boxes.select(1, 0); diff --git a/src/TorchVision/Ops/BoxConvert.cs b/src/TorchVision/Ops/BoxConvert.cs new file mode 100644 index 000000000..54ae8613a --- /dev/null +++ b/src/TorchVision/Ops/BoxConvert.cs @@ -0,0 +1,95 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. + +// A number of implementation details in this file have been translated from the Python version of torchvision, +// largely located in the files found in this folder: +// +// https://github.com/pytorch/vision/blob/3d60f498e71ba63b428edb184c9ac38fa3737fa6/torchvision/ops/_box_convert.py +// +// The origin has the following copyright notice and license: +// +// https://github.com/pytorch/vision/blob/main/LICENSE +// + +using System; +using static TorchSharp.torch; +using TorchSharp.Utils; + +#nullable enable +namespace TorchSharp +{ + public static partial class torchvision + { + public static partial class ops + { + /// + /// Converts bounding boxes from (cx, cy, w, h) format to (x1, y1, x2, y2) format. + /// (cx, cy) refers to center of bounding box + /// (w, h) are width and height of bounding box + /// + /// boxes (Tensor[N, 4]): boxes in (cx, cy, w, h) format which will be converted. + /// boxes (Tensor(N, 4)): boxes in (x1, y1, x2, y2) format. + internal static Tensor _box_cxcywh_to_xyxy(Tensor boxes) + { + //# We need to change all 4 of them so some temporary variable is needed. + var (cx, cy, w, h, _) = boxes.unbind(-1); + var x1 = cx - 0.5 * w; + var y1 = cy - 0.5 * h; + var x2 = cx + 0.5 * w; + var y2 = cy + 0.5 * h; + + boxes = torch.stack(new Tensor[] { x1, y1, x2, y2 }, dim: -1); + return boxes; + } + + /// + /// Converts bounding boxes from (x1, y1, x2, y2) format to (cx, cy, w, h) format. + /// (x1, y1) refer to top left of bounding box + /// (x2, y2) refer to bottom right of bounding box + /// + /// boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format which will be converted. + /// boxes (Tensor(N, 4)): boxes in (cx, cy, w, h) format. + internal static Tensor _box_xyxy_to_cxcywh(Tensor boxes) + { + var (x1, y1, x2, y2, _) = boxes.unbind(-1); + var cx = (x1 + x2) / 2; + var cy = (y1 + y2) / 2; + var w = x2 - x1; + var h = y2 - y1; + + boxes = torch.stack(new Tensor[] { cx, cy, w, h }, dim: -1); + + return boxes; + } + + /// + /// Converts bounding boxes from (x, y, w, h) format to (x1, y1, x2, y2) format. + /// (x, y) refers to top left of bouding box. + /// (w, h) refers to width and height of box. + /// + /// boxes (Tensor[N, 4]): boxes in (x, y, w, h) which will be converted. + /// boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format. + internal static Tensor _box_xywh_to_xyxy(Tensor boxes) + { + var (x, y, w, h, _) = boxes.unbind(-1); + boxes = torch.stack(new Tensor[] { x, y, x + w, y + h }, dim: -1); + return boxes; + } + + /// + /// Converts bounding boxes from (x1, y1, x2, y2) format to (x, y, w, h) format. + /// (x1, y1) refer to top left of bounding box + /// (x2, y2) refer to bottom right of bounding box + /// + /// boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) which will be converted. + /// boxes (Tensor[N, 4]): boxes in (x, y, w, h) format. + internal static Tensor _box_xyxy_to_xywh(Tensor boxes) + { + var (x1, y1, x2, y2, _) = boxes.unbind(-1); + var w = x2 - x1; + var h = y2 - y1; + boxes = torch.stack(new Tensor[] { x1, y1, w, h }, dim: -1); + return boxes; + } + } + } +} diff --git a/src/TorchVision/Ops/Boxes.cs b/src/TorchVision/Ops/Boxes.cs new file mode 100644 index 000000000..bc543bb15 --- /dev/null +++ b/src/TorchVision/Ops/Boxes.cs @@ -0,0 +1,362 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. + +// A number of implementation details in this file have been translated from the Python version of torchvision, +// largely located in the files found in this folder: +// +// https://github.com/pytorch/vision/blob/ea0be26b88778b1033d4a176be68bcdd008ff934/torchvision/ops/boxes.py +// +// The origin has the following copyright notice and license: +// +// https://github.com/pytorch/vision/blob/main/LICENSE +// + +using System; +using static TorchSharp.torch; +using TorchSharp.Utils; +using static TorchSharp.torchvision; +using System.Drawing; +using TorchSharp.Modules; +using System.Xml.Linq; +using static TorchSharp.torchvision.models; +using System.Reflection; +using System.Threading.Tasks; +using System.Runtime.InteropServices; +using System.Collections.Generic; + +#nullable enable +namespace TorchSharp +{ + public static partial class torchvision + { + public static partial class ops + { + [DllImport("LibTorchSharp")] + static extern IntPtr THSVision_nms(IntPtr dets, IntPtr scores, double iou_threshold); + + /// + /// Performs non-maximum suppression(NMS) on the boxes according + /// to their intersection-over-union(IoU). + /// NMS iteratively removes lower scoring boxes which have an + /// IoU greater than iou_threshold with another(higher scoring) + /// box. + /// If multiple boxes have the exact same score and satisfy the IoU + /// criterion with respect to a reference box, the selected box is + /// not guaranteed to be the same between CPU and GPU.This is similar + /// to the behavior of argsort in PyTorch when repeated values are present. + /// + /// boxes to perform NMS on. They are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1<x2`` and``0 <= y1<y2``. + /// scores (Tensor[N]): scores for each one of the boxes + /// discards all overlapping boxes with IoU > iou_threshold + /// int64 tensor with the indices of the elements that have been kept by NMS, sorted in decreasing order of scores + public static Tensor nms(Tensor boxes, Tensor scores, double iou_threshold) + { + var res = THSVision_nms(boxes.Handle, scores.Handle, iou_threshold); + if (res == IntPtr.Zero) { torch.CheckForErrors(); } + return new Tensor(res); + } + + /// + /// Performs non-maximum suppression in a batched fashion. + /// Each index value correspond to a category, and NMS + /// will not be applied between elements of different categories. + /// + /// (Tensor[N, 4]): boxes where NMS will be performed. They + /// are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and + /// ``0 <= y1 < y2``. + /// (Tensor[N]): scores for each one of the boxes + /// (Tensor[N]): indices of the categories for each one of the boxes. + /// discards all overlapping boxes with IoU > iou_threshold + /// Tensor: int64 tensor with the indices of the elements that have been kept by NMS, sorted + /// in decreasing order of scores + public static Tensor batched_nms(Tensor boxes, Tensor scores, Tensor idxs, double iou_threshold) + { + // # Benchmarks that drove the following thresholds are at + // # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339 + if (boxes.numel() > (boxes.device == torch.CPU ? 4000 : 20000)) + return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold); + else + return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold); + } + + private static Tensor _batched_nms_coordinate_trick(Tensor boxes, Tensor scores, Tensor idxs, double iou_threshold) + { + // # strategy: in order to perform NMS independently per class, + // # we add an offset to all the boxes. The offset is dependent + // # only on the class idx, and is large enough so that boxes + // # from different classes do not overlap + if (boxes.numel() == 0) + return torch.empty(new long[] { 0 }, dtype: torch.int64, device: boxes.device); + var max_coordinate = boxes.max(); + var offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)); + var boxes_for_nms = boxes + offsets[TensorIndex.Colon, TensorIndex.None]; + var keep = nms(boxes_for_nms, scores, iou_threshold); + return keep; + } + + private static Tensor _batched_nms_vanilla(Tensor boxes, Tensor scores, Tensor idxs, double iou_threshold) + { + // # Based on Detectron2 implementation, just manually call nms() on each class independently + var keep_mask = torch.zeros_like(scores, dtype: torch.@bool); + var unique = torch.unique(idxs); + for (int i = 0; i < unique.NumberOfElements; i++) { + var class_id = unique[i]; + var curr_indices = torch.where(idxs == class_id)[0]; + var curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold); + keep_mask[curr_indices[curr_keep_indices]] = true; + } + var keep_indices = torch.where(keep_mask)[0]; + return keep_indices[scores[keep_indices].sort(descending: true).Indices]; + } + + /// + /// Remove boxes which contains at least one side smaller than min_size. + /// + /// + /// + /// + public static Tensor remove_small_boxes(Tensor boxes, double min_size) + { + var ws = boxes[TensorIndex.Colon, 2] - boxes[TensorIndex.Colon, 0]; + var hs = boxes[TensorIndex.Colon, 3] - boxes[TensorIndex.Colon, 1]; + var keep = (ws >= min_size) & (hs >= min_size); + keep = torch.where(keep)[0]; + return keep; + } + + /// + /// Clip boxes so that they lie inside an image of size `size`. + /// + /// + /// + /// + public static Tensor clip_boxes_to_image(Tensor boxes, long[] size) + { + var dim = boxes.dim(); + var boxes_x = boxes[TensorIndex.Ellipsis, TensorIndex.Slice(start: 0, step: 2)]; + var boxes_y = boxes[TensorIndex.Ellipsis, TensorIndex.Slice(start: 1, step: 2)]; + var height = size[0]; + var width = size[1]; + + boxes_x = boxes_x.clamp(min: 0, max: width); + boxes_y = boxes_y.clamp(min: 0, max: height); + + var clipped_boxes = torch.stack(new Tensor[] { boxes_x, boxes_y }, dim: dim); + return clipped_boxes.reshape(boxes.shape); + } + + /// + /// Converts boxes from given in_fmt to out_fmt. + /// Supported in_fmt and out_fmt are: + /// 'xyxy': boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right. + /// This is the format that torchvision utilities expect. + /// 'xywh' : boxes are represented via corner, width and height, x1, y2 being top left, w, h being width and height. + /// 'cxcywh' : boxes are represented via centre, width and height, cx, cy being center of box, w, h + /// being width and height. + /// + /// + /// + /// + /// + public static Tensor box_convert(Tensor boxes, string in_fmt, string out_fmt) + { + var allowed_fmts = new List { "xyxy", "xywh", "cxcywh" }; + if (!allowed_fmts.Contains(in_fmt) || !allowed_fmts.Contains(out_fmt)) + throw new ArgumentException("Unsupported Bounding Box Conversions for given in_fmt and out_fmt"); + + if (in_fmt == out_fmt) + return boxes.clone(); + + if (in_fmt != "xyxy" && out_fmt != "xyxy") { + //# convert to xyxy and change in_fmt xyxy + if (in_fmt == "xywh") + boxes = _box_xywh_to_xyxy(boxes); + else if (in_fmt == "cxcywh") + boxes = _box_cxcywh_to_xyxy(boxes); + in_fmt = "xyxy"; + } + + if (in_fmt == "xyxy") { + if (out_fmt == "xywh") + boxes = _box_xyxy_to_xywh(boxes); + else if (out_fmt == "cxcywh") + boxes = _box_xyxy_to_cxcywh(boxes); + } else if (out_fmt == "xyxy") { + if (in_fmt == "xywh") + boxes = _box_xywh_to_xyxy(boxes); + else if (in_fmt == "cxcywh") + boxes = _box_cxcywh_to_xyxy(boxes); + } + return boxes; + } + + /// + /// Computes the area of a set of bounding boxes, which are specified by their + /// (x1, y1, x2, y2) coordinates. + /// + /// + /// + public static Tensor box_area(Tensor boxes) + { + boxes = _upcast(boxes); + return (boxes[TensorIndex.Colon, 2] - boxes[TensorIndex.Colon, 0]) * (boxes[TensorIndex.Colon, 3] - boxes[TensorIndex.Colon, 1]); + } + + //# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py + //# with slight modifications + private static (Tensor, Tensor) _box_inter_union(Tensor boxes1, Tensor boxes2) + { + var area1 = box_area(boxes1); + var area2 = box_area(boxes2); + + var lt = torch.max(boxes1[TensorIndex.Colon, TensorIndex.None, TensorIndex.Slice(stop: 2)], boxes2[TensorIndex.Colon, TensorIndex.Slice(stop: 2)]); // [N,M,2]; + var rb = torch.min(boxes1[TensorIndex.Colon, TensorIndex.None, TensorIndex.Slice(start: 2)], boxes2[TensorIndex.Colon, TensorIndex.Slice(start: 2)]); // [N,M,2]; + + var wh = _upcast(rb - lt).clamp(min: 0); // [N,M,2]; + var inter = wh[TensorIndex.Colon, TensorIndex.Colon, 0] * wh[TensorIndex.Colon, TensorIndex.Colon, 1]; // [N,M]; + + var union = area1[TensorIndex.Colon, TensorIndex.None] + area2 - inter; + + return (inter, union); + } + + /// + /// Return intersection-over-union (Jaccard index) between two sets of boxes. + /// Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + /// ``0 <= x1<x2`` and ``0 <= y1<y2``. + /// + /// (Tensor[N, 4]): first set of boxes + /// (Tensor[M, 4]): second set of boxes + /// Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 + public static Tensor box_iou(Tensor boxes1, Tensor boxes2) + { + var (inter, union) = _box_inter_union(boxes1, boxes2); + var iou = inter / union; + return iou; + } + + /// + /// Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py + /// Return generalized intersection-over-union (Jaccard index) between two sets of boxes. + /// Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + /// ``0 <= x1<x2`` and ``0 <= y1<y2``. + /// + /// (Tensor[N, 4]): first set of boxes + /// (Tensor[M, 4]): second set of boxes + /// Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values + /// for every element in boxes1 and boxes2 + public static Tensor generalized_box_iou(Tensor boxes1, Tensor boxes2) + { + var (inter, union) = _box_inter_union(boxes1, boxes2); + var iou = inter / union; + + var lti = torch.min(boxes1[TensorIndex.Colon, TensorIndex.None, TensorIndex.Slice(stop: 2)], boxes2[TensorIndex.Colon, TensorIndex.Slice(stop: 2)]); + var rbi = torch.max(boxes1[TensorIndex.Colon, TensorIndex.None, TensorIndex.Slice(start: 2)], boxes2[TensorIndex.Colon, TensorIndex.Slice(start: 2)]); + + var whi = _upcast(rbi - lti).clamp(min: 0);// # [N,M,2] + var areai = whi[TensorIndex.Colon, TensorIndex.Colon, 0] * whi[TensorIndex.Colon, TensorIndex.Colon, 1]; + + return iou - (areai - union) / areai; + } + + /// + /// Return complete intersection-over-union (Jaccard index) between two sets of boxes. + /// Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + /// ``0 <= x1<x2`` and ``0 <= y1<y2``. + /// + /// (Tensor[N, 4]): first set of boxes + /// (Tensor[M, 4]): second set of boxes + /// (float, optional): small number to prevent division by zero. Default: 1e-7 + /// Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values + /// for every element in boxes1 and boxes2 + public static Tensor complete_box_iou(Tensor boxes1, Tensor boxes2, float eps = 1e-7f) + { + boxes1 = _upcast(boxes1); + boxes2 = _upcast(boxes2); + + var (diou, iou) = _box_diou_iou(boxes1, boxes2, eps); + + var w_pred = boxes1[TensorIndex.Colon, TensorIndex.None, 2] - boxes1[TensorIndex.Colon, TensorIndex.None, 0]; + var h_pred = boxes1[TensorIndex.Colon, TensorIndex.None, 3] - boxes1[TensorIndex.Colon, TensorIndex.None, 1]; + + var w_gt = boxes2[TensorIndex.Colon, 2] - boxes2[TensorIndex.Colon, 0]; + var h_gt = boxes2[TensorIndex.Colon, 3] - boxes2[TensorIndex.Colon, 1]; + + var v = (4 / (torch.pow(Math.PI, 2))) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2); + Tensor alpha; + using (torch.no_grad()) + alpha = v / (1 - iou + v + eps); + return diou - alpha * v; + } + + + /// + /// Return distance intersection-over-union (Jaccard index) between two sets of boxes. + /// Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + /// ``0 <= x1<x2`` and ``0 <= y1<y2``. + /// + /// (Tensor[N, 4]): first set of boxes + /// (Tensor[M, 4]): second set of boxes + /// (float, optional): small number to prevent division by zero. Default: 1e-7 + /// Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values + /// for every element in boxes1 and boxes2 + public static Tensor distance_box_iou(Tensor boxes1, Tensor boxes2, float eps = 1e-7f) + { + boxes1 = _upcast(boxes1); + boxes2 = _upcast(boxes2); + var (diou, _) = _box_diou_iou(boxes1, boxes2, eps: eps); + return diou; + } + + private static (Tensor, Tensor) _box_diou_iou(Tensor boxes1, Tensor boxes2, float eps = 1e-7f) + { + var iou = box_iou(boxes1, boxes2); + var lti = torch.min(boxes1[TensorIndex.Colon, TensorIndex.None, TensorIndex.Slice(stop: 2)], boxes2[TensorIndex.Colon, TensorIndex.Slice(stop: 2)]); + var rbi = torch.max(boxes1[TensorIndex.Colon, TensorIndex.None, TensorIndex.Slice(start: 2)], boxes2[TensorIndex.Colon, TensorIndex.Slice(start: 2)]); + var whi = _upcast(rbi - lti).clamp(min: 0);// # [N,M,2] + var diagonal_distance_squared = (torch.pow(whi[TensorIndex.Colon, TensorIndex.Colon, 0], 2)) + (torch.pow(whi[TensorIndex.Colon, TensorIndex.Colon, 1], 2)) + eps; + //# centers of boxes + var x_p = (boxes1[TensorIndex.Colon, 0] + boxes1[TensorIndex.Colon, 2]) / 2; + var y_p = (boxes1[TensorIndex.Colon, 1] + boxes1[TensorIndex.Colon, 3]) / 2; + var x_g = (boxes2[TensorIndex.Colon, 0] + boxes2[TensorIndex.Colon, 2]) / 2; + var y_g = (boxes2[TensorIndex.Colon, 1] + boxes2[TensorIndex.Colon, 3]) / 2; + //# The distance between boxes' centers squared. + var centers_distance_squared = (torch.pow(_upcast((x_p[TensorIndex.Colon, TensorIndex.None] - x_g[TensorIndex.None, TensorIndex.Colon])), 2)) + ( + torch.pow(_upcast((y_p[TensorIndex.Colon, TensorIndex.None] - y_g[TensorIndex.None, TensorIndex.Colon])), 2)); + //# The distance IoU is the IoU penalized by a normalized + //# distance between boxes' centers squared. + return (iou - (centers_distance_squared / diagonal_distance_squared), iou); + } + + /// + /// Compute the bounding boxes around the provided masks. + /// Returns a[N, 4] tensor containing bounding boxes.The boxes are in ``(x1, y1, x2, y2)`` format with + /// ``0 <= x1<x2`` and ``0 <= y1<y2``. + /// + /// (Tensor[N, H, W]): masks to transform where N is the number of masks + /// and(H, W) are the spatial dimensions. + /// Tensor[N, 4]: bounding boxes + public static Tensor masks_to_boxes(Tensor masks) + { + if (masks.numel() == 0) + return torch.zeros(new long[] { 0, 4 }, device: masks.device, dtype: torch.@float); + + var n = masks.shape[0]; + + var bounding_boxes = torch.zeros(new long[] { n, 4 }, device: masks.device, dtype: torch.@float); + + for (int index = 0; index < masks.shape[0]; index++) { + var mask = masks[index]; + + var (y, x, _) = torch.where(mask != 0); + + bounding_boxes[index, 0] = torch.min(x); + bounding_boxes[index, 1] = torch.min(y); + bounding_boxes[index, 2] = torch.max(x); + bounding_boxes[index, 3] = torch.max(y); + } + + return bounding_boxes; + } + } + } +} diff --git a/src/TorchVision/Ops/Utils.cs b/src/TorchVision/Ops/Utils.cs new file mode 100644 index 000000000..79cfc8071 --- /dev/null +++ b/src/TorchVision/Ops/Utils.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; +using TorchSharp; +using static TorchSharp.torch; + +#nullable enable +namespace TorchSharp +{ + public static partial class torchvision + { + public static partial class ops + { + public static Tensor _upcast(Tensor t) + { + if (t.is_floating_point()) + return t.dtype == torch.float32 || t.dtype == torch.float64 ? t : t.@float(); + else + return t.dtype == torch.int32 || t.dtype == torch.int64 ? t : t.@int(); + } + } + } +} diff --git a/test/TorchSharpTest/TestTorchVisionOps.cs b/test/TorchSharpTest/TestTorchVisionOps.cs index 2cd774f53..9172e86e1 100644 --- a/test/TorchSharpTest/TestTorchVisionOps.cs +++ b/test/TorchSharpTest/TestTorchVisionOps.cs @@ -20,7 +20,7 @@ public void NMS_OneBox() }); var scores = torch.from_array(new[] { 0.9 }); - var nms_boxes = nms(boxes, scores); + var nms_boxes = nms_custom(boxes, scores); Assert.Multiple( () => Assert.Single(nms_boxes.shape), () => Assert.Equal(1, nms_boxes.shape[0]) @@ -41,7 +41,7 @@ public void NMS_MultipleBoxes() torch.Tensor nms_boxes = null; // Less than iou threshold. - nms_boxes = nms(boxes, scores, 0.6); + nms_boxes = nms_custom(boxes, scores, 0.6); Assert.Multiple( () => Assert.Single(nms_boxes.shape), () => Assert.Equal(2, nms_boxes.shape[0]), @@ -50,7 +50,7 @@ public void NMS_MultipleBoxes() ); // Larger than iou threshold. - nms_boxes = nms(boxes, scores, 0.3); + nms_boxes = nms_custom(boxes, scores, 0.3); Assert.Multiple( () => Assert.Single(nms_boxes.shape), () => Assert.Equal(1, nms_boxes.shape[0]), From 8b9e8c7303c6a10124648cc785e64d376db9161f Mon Sep 17 00:00:00 2001 From: xhuan8 Date: Thu, 8 Dec 2022 22:24:45 +0800 Subject: [PATCH 2/4] #867 move windows features inside if block --- src/Native/LibTorchSharp/Utils.h | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/Native/LibTorchSharp/Utils.h b/src/Native/LibTorchSharp/Utils.h index d91635617..6e24a6f74 100644 --- a/src/Native/LibTorchSharp/Utils.h +++ b/src/Native/LibTorchSharp/Utils.h @@ -323,11 +323,6 @@ void* LoadNativeSymbol(const std::string libName, const std::string symbolName) #else lib = LoadLibrary(libName.c_str()); #endif // !UNICODE - -#else - lib = dlopen((libName + ".so").c_str(), RTLD_LAZY); -#endif - if (lib == NULL) { char buff[FILENAME_MAX]; @@ -337,21 +332,23 @@ void* LoadNativeSymbol(const std::string libName, const std::string symbolName) torch_last_err = strdup(("Failed to load library: " + libName + " " + std::to_string(GetLastError()) + " " + current_working_dir).c_str()); return NULL; - } +} +#else + lib = dlopen((libName + ".so").c_str(), RTLD_LAZY); +#endif void* symbol = NULL; #if _WIN32 symbol = (void*)GetProcAddress((HMODULE)lib, symbolName.c_str()); -#else - symbol = dlsym(libHandle, symbolName.c_str()); -#endif - if (symbol == NULL) { torch_last_err = strdup(("Cannot find symbol: " + symbolName + " " + std::to_string(GetLastError())).c_str()); return NULL; } +#else + symbol = dlsym(lib, symbolName.c_str()); +#endif return symbol; } \ No newline at end of file From 6e0e8a9308f7fe2c5cee48394b4013e8b5d3ccaa Mon Sep 17 00:00:00 2001 From: xhuan8 Date: Mon, 12 Dec 2022 11:12:01 +0800 Subject: [PATCH 3/4] #867 move pinvoke to separate file --- src/TorchVision/LibTorchSharp.cs | 5 ++++- src/TorchVision/Ops/Boxes.cs | 5 +---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/TorchVision/LibTorchSharp.cs b/src/TorchVision/LibTorchSharp.cs index d9c68db72..7ea2ac8f4 100644 --- a/src/TorchVision/LibTorchSharp.cs +++ b/src/TorchVision/LibTorchSharp.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. using System; using System.Runtime.InteropServices; @@ -41,5 +41,8 @@ internal static extern IntPtr THSVision_PerspectiveGrid(IntPtr coeffs, long coef [DllImport("LibTorchSharp")] internal static extern void THSVision_RGB_BRGA(IntPtr inputBytes, IntPtr outBytes, long inputChannelCount, long imageSize); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSVision_nms(IntPtr dets, IntPtr scores, double iou_threshold); } } \ No newline at end of file diff --git a/src/TorchVision/Ops/Boxes.cs b/src/TorchVision/Ops/Boxes.cs index bc543bb15..4b8963715 100644 --- a/src/TorchVision/Ops/Boxes.cs +++ b/src/TorchVision/Ops/Boxes.cs @@ -30,9 +30,6 @@ public static partial class torchvision { public static partial class ops { - [DllImport("LibTorchSharp")] - static extern IntPtr THSVision_nms(IntPtr dets, IntPtr scores, double iou_threshold); - /// /// Performs non-maximum suppression(NMS) on the boxes according /// to their intersection-over-union(IoU). @@ -50,7 +47,7 @@ public static partial class ops /// int64 tensor with the indices of the elements that have been kept by NMS, sorted in decreasing order of scores public static Tensor nms(Tensor boxes, Tensor scores, double iou_threshold) { - var res = THSVision_nms(boxes.Handle, scores.Handle, iou_threshold); + var res = LibTorchSharp.THSVision_nms(boxes.Handle, scores.Handle, iou_threshold); if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new Tensor(res); } From a9299c1721bc52cef6a98ad774272a099841858f Mon Sep 17 00:00:00 2001 From: xhuan8 Date: Mon, 12 Dec 2022 12:20:31 +0800 Subject: [PATCH 4/4] #867 add comments to public methods --- src/TorchSharp/Utils/DeconstructExtension.cs | 22 ++++++++++++++++++++ src/TorchVision/Ops/Utils.cs | 3 +++ 2 files changed, 25 insertions(+) diff --git a/src/TorchSharp/Utils/DeconstructExtension.cs b/src/TorchSharp/Utils/DeconstructExtension.cs index 240395091..c721c97e6 100644 --- a/src/TorchSharp/Utils/DeconstructExtension.cs +++ b/src/TorchSharp/Utils/DeconstructExtension.cs @@ -7,24 +7,46 @@ namespace TorchSharp.Utils { /// /// Converts IEnumerable to tuple. + /// Example: + /// int[] rect = new int[4]; + /// var (left, top, width, height, _) = rect; /// public static class DeconstructExtension { + /// + /// Deconstructs a sequence to the first element and rest of elements. + /// + /// + /// + /// + /// public static void Deconstruct(this IEnumerable seq, out T first, out IEnumerable rest) { first = seq.FirstOrDefault(); rest = seq.Skip(1); } + /// + /// Deconstrcts one element out of sequence. + /// public static void Deconstruct(this IEnumerable seq, out T first, out T second, out IEnumerable rest) => (first, (second, rest)) = seq; + /// + /// Deconstrcts two elements out of sequence. + /// public static void Deconstruct(this IEnumerable seq, out T first, out T second, out T third, out IEnumerable rest) => (first, second, (third, rest)) = seq; + /// + /// Deconstrcts three elements out of sequence. + /// public static void Deconstruct(this IEnumerable seq, out T first, out T second, out T third, out T fourth, out IEnumerable rest) => (first, second, third, (fourth, rest)) = seq; + /// + /// Deconstrcts four elements out of sequence. + /// public static void Deconstruct(this IEnumerable seq, out T first, out T second, out T third, out T fourth, out T fifth, out IEnumerable rest) => (first, second, third, fourth, (fifth, rest)) = seq; } diff --git a/src/TorchVision/Ops/Utils.cs b/src/TorchVision/Ops/Utils.cs index 79cfc8071..c85e7b719 100644 --- a/src/TorchVision/Ops/Utils.cs +++ b/src/TorchVision/Ops/Utils.cs @@ -11,6 +11,9 @@ public static partial class torchvision { public static partial class ops { + /// + /// Protects from numerical overflows in multiplications by upcasting to the equivalent higher type. + /// public static Tensor _upcast(Tensor t) { if (t.is_floating_point())