11#include " testing.hpp"
22
33#include < ttl/cuda_tensor>
4+ #include < ttl/experimental/copy>
45#include < ttl/range>
56#include < ttl/tensor>
67
@@ -23,11 +24,10 @@ TEST(cuda_tensor_test, test0)
2324{
2425 using R = float ;
2526 cuda_tensor<R, 0 > m0;
26-
2727 tensor<R, 0 > x;
2828
29- m0. from_host (x. data ( ));
30- m0. to_host (x. data ( ));
29+ ttl::copy ( ttl::ref (m0), ttl::view (x ));
30+ ttl::copy ( ttl::ref (x), ttl::view (m0 ));
3131}
3232
3333TEST (cuda_tensor_test, test1)
@@ -42,8 +42,8 @@ TEST(cuda_tensor_test, test2)
4242 cuda_tensor<R, 2 > m1 (10 , 100 );
4343 tensor<R, 2 > m2 (10 , 100 );
4444
45- m1. from_host (m2. data ( ));
46- m1. to_host (m2. data ( ));
45+ ttl::copy ( ttl::ref (m1), ttl::view (m2));
46+ ttl::copy ( ttl::ref (m2), ttl::view (m1 ));
4747
4848 m1.slice (1 , 2 );
4949 auto r = ref (m1);
@@ -58,14 +58,16 @@ TEST(cuda_tensor_test, test_3)
5858 cuda_tensor<R, 2 > m1 (ttl::make_shape (10 , 100 ));
5959}
6060
61- template <typename R, uint8_t r> void test_auto_ref ()
61+ template <typename R, uint8_t r>
62+ void test_auto_ref ()
6263{
6364 static_assert (
6465 std::is_convertible<cuda_tensor<R, r>, cuda_tensor_ref<R, r>>::value,
6566 " can't convert to ref" );
6667}
6768
68- template <typename R, uint8_t r> void test_auto_view ()
69+ template <typename R, uint8_t r>
70+ void test_auto_view ()
6971{
7072 static_assert (
7173 std::is_convertible<cuda_tensor<R, r>, cuda_tensor_view<R, r>>::value,
@@ -87,28 +89,30 @@ TEST(cuda_tensor_test, test_convert)
8789 test_auto_view<int , 2 >();
8890}
8991
90- template <typename R, uint8_t r> void test_copy (const ttl::shape<r> &shape)
92+ template <typename R, uint8_t r>
93+ void test_copy (const ttl::shape<r> &shape)
9194{
9295 tensor<R, r> x (shape);
9396 cuda_tensor<R, r> y (shape);
9497 tensor<R, r> z (shape);
9598
9699 std::iota (x.data (), x.data_end (), 1 );
97- y.from_host (x.data ());
98- y.to_host (z.data ());
100+
101+ ttl::copy (ttl::ref (y), ttl::view (x));
102+ ttl::copy (ttl::ref (z), ttl::view (y));
99103
100104 for (auto i : ttl::range (shape.size ())) {
101105 ASSERT_EQ (x.data ()[i], z.data ()[i]);
102106 }
103107
104108 {
105109 cuda_tensor_ref<R, r> ry = ref (y);
106- ry. from_host (x. data ( ));
107- ry. to_host (x. data ( ));
110+ ttl::copy (ry, ttl::view (x ));
111+ ttl::copy ( ttl::ref (z), ttl::view (ry ));
108112 }
109113 {
110114 cuda_tensor_view<R, r> vy = view (y);
111- vy. to_host (x. data () );
115+ ttl::copy ( ttl::ref (x), vy );
112116 }
113117}
114118
0 commit comments