diff --git a/python/pyarrow/compat.py b/python/pyarrow/compat.py
index df5e4faadd4..f9c148b14e3 100644
--- a/python/pyarrow/compat.py
+++ b/python/pyarrow/compat.py
@@ -80,6 +80,7 @@ class Categorical(ClassPlaceholder):
     unicode_type = unicode
     lzip = zip
     zip = itertools.izip
+    zip_longest = itertools.izip_longest
 
     def dict_values(x):
         return x.values()
@@ -108,6 +109,7 @@ def lzip(*x):
         return list(zip(*x))
     long = int
     zip = zip
+    zip_longest = itertools.zip_longest
     def dict_values(x):
         return list(x.values())
     from decimal import Decimal
diff --git a/python/pyarrow/pandas_compat.py b/python/pyarrow/pandas_compat.py
index be48aeb442d..141b33f119c 100644
--- a/python/pyarrow/pandas_compat.py
+++ b/python/pyarrow/pandas_compat.py
@@ -15,15 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import re
+import ast
+import collections
 import json
+import re
+
 import numpy as np
 import pandas as pd
 import six
 
 import pyarrow as pa
-from pyarrow.compat import PY2  # noqa
+from pyarrow.compat import PY2, zip_longest  # noqa
 
 INDEX_LEVEL_NAME_REGEX = re.compile(r'^__index_level_\d+__$')
@@ -89,6 +92,52 @@ def get_logical_type(arrow_type):
     raise NotImplementedError(str(arrow_type))
 
 
+_numpy_logical_type_map = {
+    np.bool_: 'bool',
+    np.int8: 'int8',
+    np.int16: 'int16',
+    np.int32: 'int32',
+    np.int64: 'int64',
+    np.uint8: 'uint8',
+    np.uint16: 'uint16',
+    np.uint32: 'uint32',
+    np.uint64: 'uint64',
+    np.float32: 'float32',
+    np.float64: 'float64',
+    'datetime64[D]': 'date',
+    np.str_: 'unicode',
+    np.bytes_: 'bytes',
+}
+
+
+def get_logical_type_from_numpy(pandas_collection):
+    try:
+        return _numpy_logical_type_map[pandas_collection.dtype.type]
+    except KeyError:
+        if hasattr(pandas_collection.dtype, 'tz'):
+            return 'datetimetz'
+        return infer_dtype(pandas_collection)
+
+
+def get_extension_dtype_info(column):
+    dtype = column.dtype
+    if str(dtype) == 'category':
+        cats = getattr(column, 'cat', column)
+        assert cats is not None
+        metadata = {
+            'num_categories': len(cats.categories),
+            'ordered': cats.ordered,
+        }
+        physical_dtype = 'object'
+    elif hasattr(dtype, 'tz'):
+        metadata = {'timezone': str(dtype.tz)}
+        physical_dtype = 'datetime64[ns]'
+    else:
+        metadata = None
+        physical_dtype = str(dtype)
+    return physical_dtype, metadata
+
+
 def get_column_metadata(column, name, arrow_type):
     """Construct the metadata for a given column
 
@@ -102,25 +151,15 @@ def get_column_metadata(column, name, arrow_type):
     -------
     dict
     """
-    dtype = column.dtype
     logical_type = get_logical_type(arrow_type)
 
-    if hasattr(dtype, 'categories'):
-        assert logical_type == 'categorical'
-        extra_metadata = {
-            'num_categories': len(column.cat.categories),
-            'ordered': column.cat.ordered,
-        }
-    elif hasattr(dtype, 'tz'):
-        assert logical_type == 'datetimetz'
-        extra_metadata = {'timezone': str(dtype.tz)}
-    elif logical_type == 'decimal':
+    string_dtype, extra_metadata = get_extension_dtype_info(column)
+    if logical_type == 'decimal':
         extra_metadata = {
             'precision': arrow_type.precision,
             'scale': arrow_type.scale,
         }
-    else:
-        extra_metadata = None
+        string_dtype = 'object'
 
     if not isinstance(name, six.string_types):
         raise TypeError(
@@ -132,7 +171,7 @@ def get_column_metadata(column, name, arrow_type):
     return {
         'name': name,
         'pandas_type': logical_type,
-        'numpy_type': str(dtype),
+        'numpy_type': string_dtype,
         'metadata': extra_metadata,
     }
 
@@ -188,21 +227,69 @@ def construct_metadata(df, column_names, index_levels, preserve_index, types):
         index_column_metadata = [
             get_column_metadata(level, name=index_level_name(level, i),
                                 arrow_type=arrow_type)
-            for i, (level, arrow_type) in enumerate(zip(index_levels,
-                                                        index_types))
+            for i, (level, arrow_type) in enumerate(
+                zip(index_levels, index_types)
+            )
         ]
+
+        column_indexes = []
+
+        for level in getattr(df.columns, 'levels', [df.columns]):
+            string_dtype, extra_metadata = get_extension_dtype_info(level)
+            column_index = {
+                'name': level.name,
+                'pandas_type': get_logical_type_from_numpy(level),
+                'numpy_type': string_dtype,
+                'metadata': extra_metadata,
+            }
+            column_indexes.append(column_index)
     else:
-        index_column_names = index_column_metadata = []
+        index_column_names = index_column_metadata = column_indexes = []
 
     return {
         b'pandas': json.dumps({
             'index_columns': index_column_names,
+            'column_indexes': column_indexes,
             'columns': column_metadata + index_column_metadata,
             'pandas_version': pd.__version__
         }).encode('utf8')
     }
 
 
+def _column_name_to_strings(name):
+    """Convert a column name (or level) to either a string or a recursive
+    collection of strings.
+
+    Parameters
+    ----------
+    name : str or tuple
+
+    Returns
+    -------
+    value : str or tuple
+
+    Examples
+    --------
+    >>> name = 'foo'
+    >>> _column_name_to_strings(name)
+    'foo'
+    >>> name = ('foo', 'bar')
+    >>> _column_name_to_strings(name)
+    ('foo', 'bar')
+    >>> import pandas as pd
+    >>> name = (1, pd.Timestamp('2017-02-01 00:00:00'))
+    >>> _column_name_to_strings(name)
+    ('1', '2017-02-01 00:00:00')
+    """
+    if isinstance(name, six.string_types):
+        return name
+    elif isinstance(name, tuple):
+        return tuple(map(_column_name_to_strings, name))
+    elif isinstance(name, collections.Sequence):
+        raise TypeError("Unsupported type for MultiIndex level")
+    return str(name)
+
+
 def dataframe_to_arrays(df, schema, preserve_index):
     names = []
     arrays = []
@@ -217,7 +304,7 @@ def dataframe_to_arrays(df, schema, preserve_index):
     for name in df.columns:
         col = df[name]
         if not isinstance(name, six.string_types):
-            name = str(name)
+            name = str(_column_name_to_strings(name))
 
         if schema is not None:
             field = schema.field_by_name(name)
@@ -267,25 +354,30 @@ def table_to_blockmanager(options, table, memory_pool, nthreads=1):
     import pyarrow.lib as lib
 
     index_columns = []
+    column_indexes = []
     index_arrays = []
     index_names = []
     schema = table.schema
     row_count = table.num_rows
     metadata = schema.metadata
 
-    if metadata is not None and b'pandas' in metadata:
+    has_pandas_metadata = metadata is not None and b'pandas' in metadata
+
+    if has_pandas_metadata:
         pandas_metadata = json.loads(metadata[b'pandas'].decode('utf8'))
         index_columns = pandas_metadata['index_columns']
+        column_indexes = pandas_metadata.get('column_indexes', [])
         table = _add_any_metadata(table, pandas_metadata)
 
     block_table = table
 
+    # Build up a list of index columns and names while removing those columns
+    # from the original table
     for name in index_columns:
         i = schema.get_field_index(name)
         if i != -1:
             col = table.column(i)
-            index_name = (None if is_unnamed_index_level(name)
-                          else name)
+            index_name = None if is_unnamed_index_level(name) else name
             col_pandas = col.to_pandas()
             values = col_pandas.values
             if not values.flags.writeable:
@@ -299,8 +391,12 @@ def table_to_blockmanager(options, table, memory_pool, nthreads=1):
             block_table.schema.get_field_index(name)
         )
 
+    # Convert an arrow table to Block from the internal pandas API
     result = lib.table_to_blocks(options, block_table, nthreads, memory_pool)
 
+    # Construct the individual blocks converting dictionary types to pandas
+    # categorical types and Timestamps-with-timezones types to the proper
+    # pandas Blocks
     blocks = []
     for item in result:
         block_arr = item['block']
@@ -321,6 +417,7 @@ def table_to_blockmanager(options, table, memory_pool, nthreads=1):
             block = _int.make_block(block_arr, placement=placement)
         blocks.append(block)
 
+    # Construct the row index
     if len(index_arrays) > 1:
         index = pd.MultiIndex.from_arrays(index_arrays, names=index_names)
     elif len(index_arrays) == 1:
@@ -328,11 +425,51 @@ def table_to_blockmanager(options, table, memory_pool, nthreads=1):
     else:
         index = pd.RangeIndex(row_count)
 
-    axes = [
-        [column.name for column in block_table.itercolumns()],
-        index
-    ]
+    column_strings = [x.name for x in block_table.itercolumns()]
+
+    # If we're passed multiple column indexes then evaluate with
+    # ast.literal_eval, since the column index values show up as a list of
+    # tuples
+    to_pair = ast.literal_eval if len(column_indexes) > 1 else lambda x: (x,)
+
+    # Create the column index
+    # Construct the base index
+    if not column_strings:
+        columns = pd.Index(column_strings)
+    else:
+        columns = pd.MultiIndex.from_tuples(
+            list(map(to_pair, column_strings)),
+            names=[col_index['name'] for col_index in column_indexes] or None,
+        )
+
+    # if we're reconstructing the index
+    if has_pandas_metadata:
+
+        # Get levels and labels, and provide sane defaults if the index has a
+        # single level to avoid if/else spaghetti.
+        levels = getattr(columns, 'levels', None) or [columns]
+        labels = getattr(columns, 'labels', None) or [
+            pd.RangeIndex(len(level)) for level in levels
+        ]
+
+        # Convert each level to the dtype provided in the metadata
+        levels_dtypes = [
+            (level, col_index.get('numpy_type', level.dtype))
+            for level, col_index in zip_longest(
+                levels, column_indexes, fillvalue={}
+            )
+        ]
+        new_levels = [
+            level if level.dtype == dtype else level.astype(dtype)
+            for level, dtype in levels_dtypes
+        ]
+        columns = pd.MultiIndex(
+            levels=new_levels,
+            labels=labels,
+            names=columns.names
+        )
+
+    axes = [columns, index]
 
     return _int.BlockManager(blocks, axes)
diff --git a/python/pyarrow/tests/test_convert_pandas.py b/python/pyarrow/tests/test_convert_pandas.py
index 182f3afc7e6..459e782d608 100644
--- a/python/pyarrow/tests/test_convert_pandas.py
+++ b/python/pyarrow/tests/test_convert_pandas.py
@@ -124,6 +124,74 @@ def test_non_string_columns(self):
         table = pa.Table.from_pandas(df)
         assert table.column(0).name == '0'
 
+    def test_column_index_names_are_preserved(self):
+        df = pd.DataFrame({'data': [1, 2, 3]})
+        df.columns.names = ['a']
+        self._check_pandas_roundtrip(df, check_index=True)
+
+    def test_multiindex_columns(self):
+        columns = pd.MultiIndex.from_arrays([
+            ['one', 'two'], ['X', 'Y']
+        ])
+        df = pd.DataFrame([(1, 'a'), (2, 'b'), (3, 'c')], columns=columns)
+        self._check_pandas_roundtrip(df, check_index=True)
+
+    def test_multiindex_columns_with_dtypes(self):
+        columns = pd.MultiIndex.from_arrays(
+            [
+                ['one', 'two'],
+                pd.DatetimeIndex(['2017-08-01', '2017-08-02']),
+            ],
+            names=['level_1', 'level_2'],
+        )
+        df = pd.DataFrame([(1, 'a'), (2, 'b'), (3, 'c')], columns=columns)
+        self._check_pandas_roundtrip(df, check_index=True)
+
+    def test_integer_index_column(self):
+        df = pd.DataFrame([(1, 'a'), (2, 'b'), (3, 'c')])
+        self._check_pandas_roundtrip(df, check_index=True)
+
+    def test_categorical_column_index(self):
+        # I *really* hope no one uses category dtypes for single level column
+        # indexes
+        df = pd.DataFrame(
+            [(1, 'a', 2.0), (2, 'b', 3.0), (3, 'c', 4.0)],
+            columns=pd.Index(list('def'), dtype='category')
+        )
+        t = pa.Table.from_pandas(df, preserve_index=True)
+        raw_metadata = t.schema.metadata
+        js = json.loads(raw_metadata[b'pandas'].decode('utf8'))
+
+        column_indexes, = js['column_indexes']
+        assert column_indexes['name'] is None
+        assert column_indexes['pandas_type'] == 'categorical'
+        assert column_indexes['numpy_type'] == 'object'
+
+        md = column_indexes['metadata']
+        assert md['num_categories'] == 3
+        assert md['ordered'] is False
+
+    def test_datetimetz_column_index(self):
+        # I *really* hope no one uses datetimetz dtypes for single level
+        # column indexes
+        df = pd.DataFrame(
+            [(1, 'a', 2.0), (2, 'b', 3.0), (3, 'c', 4.0)],
+            columns=pd.date_range(
+                start='2017-01-01', periods=3, tz='America/New_York'
+            )
+        )
+        t = pa.Table.from_pandas(df, preserve_index=True)
+        raw_metadata = t.schema.metadata
+        js = json.loads(raw_metadata[b'pandas'].decode('utf8'))
+
+        column_indexes, = js['column_indexes']
+        assert column_indexes['name'] is None
+        assert column_indexes['pandas_type'] == 'datetimetz'
+        assert column_indexes['numpy_type'] == 'datetime64[ns]'
+
+        md = column_indexes['metadata']
+        assert md['timezone'] == 'America/New_York'
+
     def test_float_no_nulls(self):
         data = {}
         fields = []
diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py
index 7a771654e4d..6802c43de40 100644
--- a/python/pyarrow/tests/test_ipc.py
+++ b/python/pyarrow/tests/test_ipc.py
@@ -409,10 +409,6 @@ def test_serialize_pandas_empty_dataframe():
     _check_serialize_pandas_round_trip(df)
 
 
-@pytest.mark.xfail(
-    raises=AssertionError,
-    reason='Non string columns are not supported',
-)
 def test_pandas_serialize_round_trip_not_string_columns():
     df = pd.DataFrame(list(zip([1.5, 1.6, 1.7], 'abc')))
     buf = pa.serialize_pandas(df)
diff --git a/python/pyarrow/tests/test_parquet.py b/python/pyarrow/tests/test_parquet.py
index d51b85d8ee4..deb4b3f35a6 100644
--- a/python/pyarrow/tests/test_parquet.py
+++ b/python/pyarrow/tests/test_parquet.py
@@ -161,6 +161,28 @@ def test_pandas_parquet_custom_metadata(tmpdir):
     assert js['index_columns'] == ['__index_level_0__']
 
 
+@parquet
+def test_pandas_parquet_column_multiindex(tmpdir):
+    import pyarrow.parquet as pq
+
+    df = alltypes_sample(size=10)
+    df.columns = pd.MultiIndex.from_tuples(
+        list(zip(df.columns, df.columns[::-1])),
+        names=['level_1', 'level_2']
+    )
+
+    filename = tmpdir.join('pandas_roundtrip.parquet')
+    arrow_table = pa.Table.from_pandas(df)
+    assert b'pandas' in arrow_table.schema.metadata
+
+    _write_table(arrow_table, filename.strpath, version="2.0",
+                 coerce_timestamps='ms')
+
+    table_read = pq.read_pandas(filename.strpath)
+    df_read = table_read.to_pandas()
+    tm.assert_frame_equal(df, df_read)
+
+
 @parquet
 def test_pandas_parquet_2_0_rountrip_read_pandas_no_index_written(tmpdir):
     import pyarrow.parquet as pq
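
For context, here is a minimal sketch of the round trip this patch enables. It is not part of the diff; it assumes a pyarrow build with these changes applied, and the 'level_1'/'level_2' names simply mirror the new tests:

    import json

    import pandas as pd
    import pyarrow as pa

    # Two-level column index, as in test_multiindex_columns_with_dtypes.
    columns = pd.MultiIndex.from_arrays(
        [['one', 'two'], ['X', 'Y']], names=['level_1', 'level_2']
    )
    df = pd.DataFrame([(1, 'a'), (2, 'b')], columns=columns)

    table = pa.Table.from_pandas(df)

    # construct_metadata now records one descriptor per level of df.columns
    # under the new 'column_indexes' key of the b'pandas' schema metadata.
    meta = json.loads(table.schema.metadata[b'pandas'].decode('utf8'))
    assert [ci['name'] for ci in meta['column_indexes']] == \
        ['level_1', 'level_2']

    # table_to_blockmanager rebuilds the MultiIndex from that metadata,
    # restoring the level values and dtypes on the way back to pandas.
    assert table.to_pandas().columns.equals(df.columns)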