Closed
Description
What you expected to happen:
Using the regression metrics such as dask_ml.metrics.regression.mean_squared_error, I expected that passing y_true and y_pred with shape (n_samples, n_outputs) and multioutput="raw_values" would return:
- a Dask Array with shape (n_outputs,), if compute=False
- a NumPy array with shape (n_outputs,), if compute=True
What happened:
Using dask_ml.metrics.regression.mean_squared_error() or dask_ml.metrics.regression.mean_absolute_error() with multioutput="raw_values" always returns a Dask Array, even when compute=True.
Minimal Complete Verifiable Example:
```python
import dask.array as da
from dask_ml.metrics.regression import mean_squared_error

# Two multi-output inputs with shape (n_samples, n_outputs) = (100, 3)
a = da.random.uniform(size=(100, 3))
b = da.random.uniform(size=(100, 3))

raw_output = mean_squared_error(a, b, multioutput="raw_values", compute=True)
type(raw_output)
# dask.array.core.Array  (expected: numpy.ndarray)
```
Anything else we need to know?:
I started looking into this while working on #756
Environment:
- Dask version:
output of `conda list | grep -E "dask|distributed"`:
```
dask                      2021.3.0    pyhd3eb1b0_0
dask-core                 2021.3.0    pyhd3eb1b0_0
dask-glm                  0.2.0       pypi_0    pypi
dask-saturn               0.2.2       pypi_0    pypi
dask-sphinx-theme         1.3.5       pypi_0    pypi
distributed               2021.3.0    py37h06a4308_0
```
- Python version, Operating System
output of `conda info`:
```
     active environment : None
       user config file : /home/jlamb/.condarc
 populated config files :
          conda version : 4.9.2
    conda-build version : not installed
         python version : 3.7.6.final.0
       virtual packages : __glibc=2.27=0
                          __unix=0=0
                          __archspec=1=x86_64
       base environment : /home/jlamb/miniconda3  (writable)
           channel URLs : https://repo.anaconda.com/pkgs/main/linux-64
                          https://repo.anaconda.com/pkgs/main/noarch
                          https://repo.anaconda.com/pkgs/r/linux-64
                          https://repo.anaconda.com/pkgs/r/noarch
          package cache : /home/jlamb/miniconda3/pkgs
                          /home/jlamb/.conda/pkgs
       envs directories : /home/jlamb/miniconda3/envs
                          /home/jlamb/.conda/envs
               platform : linux-64
             user-agent : conda/4.9.2 requests/2.22.0 CPython/3.7.6 Linux/5.4.0-70-generic ubuntu/18.04.4 glibc/2.27
                UID:GID : 1000:1000
             netrc file : None
           offline mode : False
```
- Install method (conda, pip, source): installed from source as of f5e5bb4, using `python setup.py install`.