Skip to content

Commit a672b1a

Browse files
committed
refactor so UnytScalar defers to a dtype for the unit
1 parent 48152a5 commit a672b1a

File tree

3 files changed

+32
-7
lines changed

3 files changed

+32
-7
lines changed

unytdtype/tests/test_unytdtype.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,18 @@ def test_dtype_creation():
1313
assert dtype == dtype2
1414

1515

16+
def test_scalar_creation():
17+
dtype = UnytDType("m")
18+
unit = unyt.Unit("m")
19+
unit_s = "m"
20+
21+
s_1 = UnytScalar(1, dtype)
22+
s_2 = UnytScalar(1, unit)
23+
s_3 = UnytScalar(1, unit_s)
24+
25+
assert s_1 == s_2 == s_3
26+
27+
1628
def test_creation_from_zeros():
1729
dtype = UnytDType("m")
1830
arr = np.zeros(3, dtype=dtype)

unytdtype/unytdtype/scalar.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,24 @@
55

66
class UnytScalar:
77
def __init__(self, value, unit):
8+
from . import UnytDType
89
self.value = value
9-
if isinstance(unit, str):
10-
self.unit = Unit(unit)
11-
elif isinstance(unit, Unit):
12-
self.unit = unit
10+
if isinstance(unit, (str, Unit)):
11+
self.dtype = UnytDType(unit)
12+
elif isinstance(unit, UnytDType):
13+
self.dtype = unit
1314
else:
1415
raise RuntimeError
1516

17+
@property
18+
def unit(self):
19+
return self.dtype.unit
20+
1621
def __repr__(self):
17-
return f"{self.value} {self.unit}"
22+
return f"{self.value} {self.dtype.unit}"
1823

1924
def __rmul__(self, other):
20-
return UnytScalar(self.value * other, self.unit)
25+
return UnytScalar(self.value * other, self.dtype.unit)
26+
27+
def __eq__(self, other):
28+
return self.value == other.value and self.dtype == other.dtype

unytdtype/unytdtype/src/dtype.c

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,12 @@ get_unit(PyObject *scalar, UnytDTypeObject *descr)
4343
return descr->unit;
4444
}
4545

46-
PyObject *unit = PyObject_GetAttrString(scalar, "unit");
46+
PyObject *dtype = PyObject_GetAttrString(scalar, "dtype");
47+
if (dtype == NULL) {
48+
return NULL;
49+
}
50+
PyObject *unit = PyObject_GetAttrString(dtype, "unit");
51+
Py_DECREF(dtype);
4752
if (unit == NULL) {
4853
return NULL;
4954
}

0 commit comments

Comments
 (0)