Description
What happened:
Using multioutput=None with the regression metrics in dask_ml.metrics.regression results in an error.
What you expected to happen:
I expected the behavior to be the same as the equivalent scikit-learn metrics functions, where multioutput=None means "all elements have the same weight".
Minimal Complete Verifiable Example:
In scikit-learn, the value of multioutput is passed through to np.average() as its weights argument.
np.average() treats weights=None as "equally weighted", which gives the same result as passing multioutput="uniform_average". From https://numpy.org/doc/stable/reference/generated/numpy.average.html:
"If weights=None, then all data in a are assumed to have a weight equal to one."
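The equivalence described above can be checked directly with NumPy. The sketch below uses made-up per-output error values (an assumption for illustration, not output from any run) and compares np.average() with weights=None against an explicit equal weighting.

```python
import numpy as np

# Hypothetical per-output errors, chosen only for illustration.
raw_errors = np.array([0.2, 0.4, 0.6])

# weights=None: every element is given a weight of one.
unweighted = np.average(raw_errors, weights=None)

# Explicit equal weights, which is what "uniform_average" amounts to.
uniform = np.average(raw_errors, weights=np.ones_like(raw_errors))

print(np.isclose(unweighted, uniform))  # True
```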
import numpy as np
from sklearn.metrics import mean_squared_error
a = np.random.uniform(size=(100, 3))
b = np.random.uniform(size=(100, 3))
raw_output = mean_squared_error(a, b, multioutput="raw_values")
print(raw_output)
# [0.17118814 0.15964742 0.13095381]
mean_squared_error(a, b, multioutput=None)
# 0.15392978995168105
In dask-ml, passing multioutput=None results in an error.
import dask.array as da
from dask_ml.metrics.regression import mean_squared_error
a = da.random.uniform(size=(100, 3))
b = da.random.uniform(size=(100, 3))
raw_output = mean_squared_error(a, b, multioutput="raw_values")
print(raw_output)
# dask.array<mean_agg-aggregate, shape=(3,), dtype=float64, chunksize=(3,), chunktype=numpy.ndarray>
mean_squared_error(a, b, multioutput=None)
# ValueError: Weighted 'multioutput' not supported.
Anything else we need to know?:
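Until multioutput=None is accepted, one possible workaround is to request the per-output values and average them manually. The helper below, uniform_average, is a hypothetical name and not part of dask-ml or scikit-learn. It is a minimal sketch that assumes the metric accepts multioutput="raw_values" and returns an array with a .mean() method, as both NumPy and Dask arrays do. A NumPy stand-in metric (numpy_mse, also hypothetical) is used so the snippet runs without Dask installed.

```python
import numpy as np


def uniform_average(metric, y_true, y_pred, **kwargs):
    """Hypothetical helper: equally weighted average of per-output scores."""
    # Ask the metric for one value per output column...
    raw = metric(y_true, y_pred, multioutput="raw_values", **kwargs)
    # ...then average them with equal weight. For a Dask array this stays
    # lazy until it is computed or converted with float().
    return raw.mean()


def numpy_mse(y_true, y_pred, multioutput="raw_values"):
    """Stand-in metric for illustration; only supports raw_values."""
    if multioutput != "raw_values":
        raise ValueError("stand-in metric only supports 'raw_values'")
    return ((y_true - y_pred) ** 2).mean(axis=0)


rng = np.random.default_rng(0)
a = rng.uniform(size=(100, 3))
b = rng.uniform(size=(100, 3))

score = uniform_average(numpy_mse, a, b)
print(float(score))
```

Passing dask_ml.metrics.regression.mean_squared_error and Dask arrays in place of the stand-in should, under the same assumptions, give a lazy result that float() or .compute() evaluates.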
Environment:
- Dask version:
output of 'conda list | grep -E "dask|distributed"'
dask 2021.3.0 pyhd3eb1b0_0
dask-core 2021.3.0 pyhd3eb1b0_0
dask-glm 0.2.0 pypi_0 pypi
dask-saturn 0.2.2 pypi_0 pypi
dask-sphinx-theme 1.3.5 pypi_0 pypi
distributed 2021.3.0 py37h06a4308_0
- Python version, Operating System
output of 'conda info'
active environment : None
user config file : /home/jlamb/.condarc
populated config files :
conda version : 4.9.2
conda-build version : not installed
python version : 3.7.6.final.0
virtual packages : __glibc=2.27=0
__unix=0=0
__archspec=1=x86_64
base environment : /home/jlamb/miniconda3 (writable)
channel URLs : https://repo.anaconda.com/pkgs/main/linux-64
https://repo.anaconda.com/pkgs/main/noarch
https://repo.anaconda.com/pkgs/r/linux-64
https://repo.anaconda.com/pkgs/r/noarch
package cache : /home/jlamb/miniconda3/pkgs
/home/jlamb/.conda/pkgs
envs directories : /home/jlamb/miniconda3/envs
/home/jlamb/.conda/envs
platform : linux-64
user-agent : conda/4.9.2 requests/2.22.0 CPython/3.7.6 Linux/5.4.0-70-generic ubuntu/18.04.4 glibc/2.27
UID:GID : 1000:1000
netrc file : None
offline mode : False
- Install method (conda, pip, source): installed from source as of f5e5bb4, using
python setup.py install.