Source code for dimarray.dataset

""" collection of base obeje
"""
from __future__ import absolute_import
from collections import OrderedDict as odict
import warnings, copy
from future.utils import string_types
import numpy as np

import dimarray as da  # for the doctest, so that they are testable via py.test
from dimarray.tools import format_doc, isscalar
from dimarray.config import get_option
from dimarray.compat.pycompat import dictkeys, dictvalues

from .core import DimArray, array, Axis, Axes
from .core import align as align_axes, stack, concatenate
from .core.align import _check_stack_args, _get_axes, stack, concatenate, _check_stack_axis, get_dims as _get_dims, reindex_like
from .core.transform import interp_like, _interp_internal_from_weight, _interp_internal_get_weights, _interp_internal_maybe_sort
from .core import pandas_obj
from .core.bases import AbstractDataset, GetSetDelAttrMixin, OpMixin
from .prettyprinting import repr_dataset

class DatasetAxes(Axes):
    """Axes container attached to a Dataset

    Assigning an axis here also assigns it on every DimArray of the attached
    Dataset that has a dimension of that name.
    """
    def __init__(self, ds):
        super(DatasetAxes, self).__init__()
        assert isinstance(ds, Dataset), "DatasetAxes can only be initialized empty with a Dataset"
        # handle on the owning dataset, needed to forward axis changes
        self._ds = ds

    def __setitem__(self, key, item):
        super(DatasetAxes, self).__setitem__(key, item)
        # forward the freshly stored axis to the variables that use it
        dataset = self._ds
        for name in dataset.keys():
            variable = dataset[name]
            if key in variable.dims:
                variable.axes[key] = self[key]

class Dataset(AbstractDataset, odict, OpMixin, GetSetDelAttrMixin):
# class Dataset(AbstractDataset, odict):
    """ Container for a set of aligned objects
    
    >>> ds = da.Dataset()
    >>> ds['a'] = np.arange(3) 
    >>> ds['b'] = [1,2,3]       
    >>> ds.set_axis(['a','b','c'], name='myaxis')
    >>> ds
    Dataset of 2 variables
    0 / myaxis (3): 'a' to 'c'
    a: ('myaxis',)
    b: ('myaxis',)
    """
    _constructor = DimArray

    def __init__(self, *args, **kwargs):
        """ Build a Dataset from a set of objects of varying dimensions

        Positional arguments are handed over to an ordered dict, so anything
        accepted there works here. Key-word arguments are added afterwards,
        in sorted key order. Values that are not DimArray instances are
        converted with `_constructor`, and all values are aligned before
        being stored.
        """
        assert not {'axes','keys'}.issubset(kwargs.keys()) # just to check bugs due to back-compat ==> TO BE REMOVED AFTER DEBUGGING

        # gather the input the same way an odict would
        data = odict(*args)
        for name in sorted(kwargs):  # sorted keys for consistent output
            data[name] = kwargs[name]

        # axes and metadata containers are created before any item is set
        self._axes = DatasetAxes(self)
        self._attrs = odict()

        # start from an empty ordered dictionary
        super(Dataset, self).__init__()

        # make sure everything is a DimArray
        keys = dictkeys(data)
        values = [v if isinstance(v, DimArray) else self._constructor(v)
                  for v in dictvalues(data)]

        # align the objects, then store them (this also fills self.axes)
        for name, value in zip(keys, align_axes(values)):
            self[name] = value

    @property
    def axes(self):
        """ axes of the Dataset
        """
        return self._axes

    @axes.setter
    def axes(self, newaxes):
        """ replace the axes whose name already exists, append the others
        """
        for axis in newaxes:
            if axis.name not in self.dims:
                self.axes.append(axis)
            else:
                self.axes[axis.name] = axis
        assert isinstance(self.axes, Axes)

    @property
    def dims(self):
        """ names of the Dataset axes as a tuple, consistently with DimArray's `dims`
        """
        return tuple(ax.name for ax in self._axes)

    @dims.setter
    def dims(self, newdims):
        """ rename all axes at once, in axis order
        """
        if not np.iterable(newdims): 
            raise TypeError("new dims must be iterable")
        if len(newdims) != len(self.axes):
            raise ValueError("dimension mistmatch")

        # rename the axes one by one, by position
        for position, newname in enumerate(newdims):
            self.axes[position].name = newname

    @property
    def labels(self):
        """ values of each Dataset axis as a tuple, consistently with DimArray's `labels`
        """
        return tuple(ax.values for ax in self.axes)

    # add pretty printer for ipython notebook 4, to avoid it just using OrderedDict methods...
    def _repr_pretty_(self, p, cycle):
        """ pretty-printer hook: print the same text as `repr_dataset`
        """
        text = repr_dataset(self)
        p.text(text)

    #
    # overload dictionary methods
    #
    def __delitem__(self, item):
        """ remove a variable, then the axes that no remaining variable uses
        """
        candidate_axes = self[item].axes
        super(Dataset, self).__delitem__(item)
        self._maybe_delete_axes(candidate_axes)

    def __getitem__(self, key):
        """ return a variable, or an axis as a DimArray if `key` only names a dimension
        """
        axis_only = key not in self.keys() and key in self.dims
        if not axis_only:
            return super(Dataset, self).__getitem__(key)
        # the axis is returned as a copied DimArray defined on itself
        axis = self.axes[key]
        return DimArray(axis, axes=[axis], copy=True)

    def _maybe_delete_axes(self, axes):
        """ Remove from the Dataset the axes that no variable uses anymore

        Parameters
        ----------
        axes : iterable of Axis
            candidate axes; each of them is removed from `self.axes` unless
            its name is among the dimensions of at least one variable
        """
        for ax in axes:
            # any() stops at the first variable that still has the dimension
            if not any(ax.name in self[k].dims for k in self):
                self.axes.remove(ax)

    def __setitem__(self, key, val):
        """ Make sure the object is a DimArray with appropriate axes

        Parameters
        ----------
        key : variable name
        val : DimArray, or anything accepted by `_constructor`
            The array is shallow-copied and its axes deep-copied before being
            stored, so the object passed in is not modified.

        Raises
        ------
        ValueError
            if an axis of `val` has the name of a Dataset axis but does not
            compare equal to it

        Examples
        --------
        >>> ds = Dataset()
        >>> ds
        Dataset of 0 variables
        >>> a = DimArray([0, 1, 2], dims=('time',))
        >>> ds['yo'] = a 
        >>> ds['yo']
        dimarray: 3 non-null elements (0 null)
        0 / time (3): 0 to 2
        array([0, 1, 2])
        """
        if not isinstance(val, DimArray):
            val = self._constructor(val)

        # Remove any superfluous axis
        # when an existing variable is replaced, remember its axes that the
        # new value does not have: they may become unused and are cleaned up
        # at the end
        if key in self.keys(): 
            _maybe_obsolete_axes = [ax for ax in self[key].axes if ax.name not in val.dims]
        else:
            _maybe_obsolete_axes = []

        # shallow copy of the DimArray so that its axes attribute can be 
        # modified without affecting the original array
        val = copy.copy(val)  
        val._axes = copy.deepcopy(val.axes)

        # Check dimensions
        # make sure axes match those of the dataset
        for i, newaxis in enumerate(val.axes):

            # Check dimensions if already existing axis
            if newaxis.name in [ax.name for ax in self.axes]:
                existing_axis = self.axes[newaxis.name]
                if not newaxis == existing_axis:
                    raise ValueError("axes values do not match, align data first.\
                            \nDataset: {}, \nGot: {}".format(existing_axis, newaxis))

                # assign the Dataset axis : they all must share the same axis
                val.axes[i] = existing_axis

            # Append new axis
            else:
                self.axes.append(newaxis)  

            # in both branches the variable and the Dataset now hold the very
            # same axis object
            assert val.axes[i] is self.axes[newaxis.name]

        super(Dataset, self).__setitem__(key, val)

        # Remove obsolete axes
        # (only checked once the new value is stored, so that the axes it
        # uses are taken into account)

        if len(_maybe_obsolete_axes) > 0:
            self._maybe_delete_axes(_maybe_obsolete_axes)

    def copy(self):
        """ copy of the Dataset, metadata included
        """
        duplicate = super(Dataset, self).copy() # odict method, copy axes but not metadata
        duplicate.attrs.update(self.attrs)
        return duplicate

    def __eq__(self, other):
        """ equality test: same type, same keys, same axes and same values for every key
        """
        same_keys = isinstance(other, Dataset) and self.keys() == other.keys()
        # the values are compared last, only if keys and axes agree
        return same_keys and self.axes == other.axes \
                and np.all([np.all(self[name] == other[name]) for name in self.keys()])

    #
    # Backends
    #
    def write_nc(self, f, mode='w', clobber=True, format=None, **kwargs):
        """ Write Dataset to netCDF file.

        Wrapper around DatasetOnDisk: the axes are written first, then the
        variables, then the Dataset attributes.

        Parameters
        ----------
        f : file name
        mode, clobber, format : see netCDF4-python doc
        **kwargs : passed on when appending the axes and writing the variables
            (e.g. compression parameters of netCDF4.Dataset.createVariable)
        """
        from dimarray.io.nc import DatasetOnDisk, nc, _maybe_open_file
        f, close = _maybe_open_file(f, mode=mode, clobber=clobber,format=format)
        store = DatasetOnDisk(f)
        try:
            # first write all axes
            for ax in self.axes:
                if ax.name not in store.dims:
                    store.axes.append(ax, **kwargs)
            # then variables
            for name in self.keys():
                store.write(name, self[name], **kwargs)
            store.attrs.update(self.attrs) # attributes
        finally:
            # close the store even if writing failed, so that the file handle
            # is not leaked
            # NOTE(review): the `close` flag returned by _maybe_open_file is
            # not used; presumably it tells whether the file was opened here
            # -- confirm whether it should drive this decision instead
            if isinstance(f, nc.Dataset): store.close() # do not close (deprecated)

    def write(self, *args, **kwargs):
        """ Deprecated alias of `write_nc`: emits a FutureWarning and forwards all arguments
        """
        warnings.warn("Deprecated. Use write_nc.", FutureWarning)
        self.write_nc(*args, **kwargs)

    @classmethod
    def read_nc(cls, f, *args, **kwargs):
        """ Read dataset from netCDF file.

        Deprecated: emits a FutureWarning and forwards all arguments to
        `dimarray.io.nc.read_nc`.
        """
        warnings.warn("Deprecated. Use dimarray.read_nc or dimarray.open_nc", FutureWarning)
        return da.io.nc.read_nc(f, *args, **kwargs)
    # `read` is an alias of the deprecated `read_nc`
    read = read_nc

    def to_array(self, axis=None, keys=None):
        """ Convert to DimArray

        Every selected variable is broadcast onto all Dataset axes, and the
        results are stacked along a new first axis whose values are the
        variable names.

        Parameters
        ----------
        axis : str, optional
            name of the new axis. By default "unnamed", or the first of
            "unnamed_1", "unnamed_2", ... that is not already a dimension.
        keys : sequence, optional
            names of the variables to include, by default all of them

        Returns
        -------
        array built by `_constructor`, with the new axis followed by the
        Dataset axes

        Raises
        ------
        ValueError
            if `axis` is already a dimension of the Dataset
        """
        if axis is None:
            axis = "unnamed"
            if axis in self.dims:
                i = 1
                while "unnamed_{}".format(i) in self.dims:
                    i+=1
                axis = "unnamed_{}".format(i)

        if axis in self.dims:
            raise ValueError("please provide an axis name which does not \
                    already exist in Dataset")

        # materialize the keys as a list: they are iterated below and then
        # used as axis values, which neither a dict view nor a one-shot
        # iterator supports reliably
        keys = list(self.keys() if keys is None else keys)

        # align all variables to the same dimensions, and make it a numpy array
        data = np.array([self[k].reshape(self.dims).broadcast(self.axes).values
                         for k in keys])

        # determine axes
        axes = [Axis(keys, axis)] + self.axes 

        return self._constructor(data, axes)

    def take(self, names=None, indices=None, axis=0, indexing=None, tol=None, keepdims=False,
            raise_error=None, #
             ):
        """ Analogous to DimArray's take, but for each DimArray of the Dataset

        Parameters
        ----------
        names : list of variables to read, optional
        indices : int or list or slice (single-dimensional indices)
                   or a tuple of those (multi-dimensional)
                   or `dict` of { axis name : axis indices }
            Indices refer to Dataset axes. Any item that does not possess
            one of the dimensions will not be indexed along that dimension.
            For example, scalar items will be left unchanged whatever indices
            are provided.
        axis : None or int or str, optional
            if specified and indices is a slice, scalar or an array, assumes 
            indexing is along this axis.
        indexing : {'label', 'position'}, optional
            Indexing mode. 
            - "label": indexing on axis labels (default)
            - "position": use numpy-like position index
            Default value can be changed in dimarray.rcParams['indexing.by']
        tol : float, optional
            tolerance when looking for numerical values, e.g. to use nearest 
            neighbor search, default `None`.
        keepdims : bool, optional 
            keep singleton dimensions (default False)
        raise_error : deprecated
            has no effect; a FutureWarning is emitted if it is not None

        Returns
        -------
        Dataset

        Raises
        ------
        TypeError
            if `names` is a single string instead of a sequence of names

        See Also
        --------
        DimArrayOnDisk.read, DimArray.take

        Examples
        --------
        >>> a = DimArray([1,2,3], axes=('time', [1950, 1951, 1952]))
        >>> b = DimArray([11,22,33], axes=('time', [1951, 1952, 1953]))
        >>> ds = Dataset(a=a, b=b)
        >>> ds
        Dataset of 2 variables
        0 / time (4): 1950 to 1953
        a: ('time',)
        b: ('time',)
        >>> ds.take(indices=1951, axis='time')
        Dataset of 2 variables
        a: 2.0
        b: 11.0
        >>> ds.take(indices=0, axis='time', indexing='position')
        Dataset of 2 variables
        a: 1.0
        b: nan
        >>> ds['c'] = DimArray([[1,2],[11,22],[111,222],[3,4]], axes=[('time', [1950,1951,1952,1953]),('item',['a','b'])])
        >>> ds.take(indices={'time':1950})
        Dataset of 3 variables
        0 / item (2): 'a' to 'b'
        a: 1.0
        b: nan
        c: ('item',)
        >>> ds.take(indices={'time':1950})['c']
        dimarray: 2 non-null elements (0 null)
        0 / item (2): 'a' to 'b'
        array([1, 2])
        >>> ds.take(indices={'item':'b'})
        Dataset of 3 variables
        0 / time (4): 1950 to 1953
        a: ('time',)
        b: ('time',)
        c: ('time',)
        """
        if raise_error is not None:
            warnings.warn(FutureWarning("raise_error is deprecated: now always True"))

        # automatically read all variables to load (except for the dimensions)
        if names is None:
            names = self.keys()
        elif isinstance(names, string_types):
            raise TypeError("Please provide a sequence of variables to read.")

        # indices are resolved once on the Dataset axes (one entry per
        # Dataset dimension), then looked up by dimension name for each variable
        tuple_indices = self._get_indices(indices, axis=axis, tol=tol, keepdims=keepdims, indexing=indexing)
        dict_indices = {dim:tuple_indices[i] for i, dim in enumerate(self.dims)}

        data = Dataset()
        # start with the axes, to make sure the ordering is maintained
        data.axes = self._getaxes_ortho(tuple_indices) 
        for nm in names:
            # each variable is indexed along its own dimensions only; the
            # indices were already resolved, hence indexing='position'
            data[nm] = self[nm].take(indices={dim:dict_indices[dim] for dim in self[nm].dims}, indexing='position')
        data.attrs.update(self.attrs) # dataset's metadata
        return data

    def _apply_dimarray_axis(self, funcname, *args, **kwargs):
        """ Call the DimArray method `funcname` on the variables of the Dataset

        If the 'axis=' parameter is passed, it is converted to an axis name
        and only the variables with that dimension are transformed; the other
        variables are passed on unchanged.
        """
        axis = kwargs.pop('axis',None)
        if axis is not None: axis = self.axes[axis].name
        kwargs['axis'] = axis

        # start from all variables, then overwrite the transformed ones
        result = self.to_odict()
        for name in self.keys():
            variable = self[name]
            if axis is None or axis in variable.dims:
                result[name] = getattr(variable, funcname)(*args, **kwargs)

        return Dataset(result)

    def mean(self, axis=0, **kwargs):
        """ Apply DimArray's `mean` on every variable of the Dataset

        Parameters
        ----------
        axis : int or str, optional
            Dataset axis along which to operate (default 0). Variables
            without that dimension are left unchanged.
        **kwargs : passed to DimArray's `mean`

        Returns
        -------
        Dataset

        Examples
        --------
        >>> a = DimArray([1,2,3], axes=('time', [1950, 1951, 1952]))
        >>> b = DimArray([[11,22,33],[44,55,66]], axes=[('items',['a','b']), ('time', [1950, 1951, 1952])])
        >>> ds = Dataset(([('a',a),('b',b)]))  # keep order for doctest
        >>> ds.mean(axis='time')
        Dataset of 2 variables
        0 / items (2): 'a' to 'b'
        a: 2.0
        b: ('items',)
        >>> ds.mean(axis='items')
        Dataset of 2 variables
        0 / time (3): 1950 to 1952
        a: ('time',)
        b: ('time',)
        """
        return self._apply_dimarray_axis('mean', axis=axis, **kwargs)

    def std(self, axis=0, **kwargs):
        """ Apply DimArray's `std` on every variable having `axis` """
        return self._apply_dimarray_axis('std', axis=axis, **kwargs)

    def var(self, axis=0, **kwargs):
        """ Apply DimArray's `var` on every variable having `axis` """
        return self._apply_dimarray_axis('var', axis=axis, **kwargs)

    def median(self, axis=0, **kwargs):
        """ Apply DimArray's `median` on every variable having `axis` """
        return self._apply_dimarray_axis('median', axis=axis, **kwargs)

    def sum(self, axis=0, **kwargs):
        """ Apply DimArray's `sum` on every variable having `axis` """
        return self._apply_dimarray_axis('sum', axis=axis, **kwargs)

    def to_dict(self):
        """ export the variables to a plain dict
        """
        exported = {}
        for nm in self.keys():
            exported[nm] = self[nm]
        return exported

    def to_odict(self):
        """ export the variables to an ordered dict, in Dataset order
        """
        exported = odict()
        for nm in self.keys():
            exported[nm] = self[nm]
        return exported

    def set_axis(self, values=None, axis=0, name=None, inplace=True, **kwargs):
        """ Set axis values, name and attributes of the Dataset
        
        Parameters
        ----------
        values : numpy array-like or mapper (callable or dict), optional
            - array-like : new axis values, must have exactly the same 
            length as original axis
            - dict : establish a map between original and new axis values
            - callable : transform each axis value into a new one
            - if None, axis values are left unchanged
            Default to None.
        axis : int or str, optional
            axis to be (re)set
        name : str, optional
            rename axis
        inplace : bool, optional
            modify dataset axis in-place (True) or return copy (False)? 
            (default True)
        **kwargs : key-word arguments
            Also reset other axis attributes, which can be single metadata
            or other axis attributes, via using `setattr`
            This includes special attribute `attrs` (which reset all attributes)

        Returns
        -------
        Dataset instance, or None if inplace is True

        Notes
        -----
        This affects all DimArray present in the Dataset, since they share the same
        axes.

        Examples
        --------
        >>> ds = Dataset()
        >>> ds['a'] = da.zeros(shape=(3,))  # some dimarray with dimension 'x0'
        >>> ds['b'] = da.zeros(shape=(3,4)) # dimensions 'x0', 'x1'
        >>> ds.set_axis(['a','b','c'], axis='x0')
        >>> ds
        Dataset of 2 variables
        0 / x0 (3): 'a' to 'c'
        1 / x1 (4): 0 to 3
        a: ('x0',)
        b: ('x0', 'x1')
        """
        # work on a copy unless an in-place modification is requested
        target = self if inplace else self.copy()
        # the axis object itself is always modified in place
        target.axes[axis].set(values=values, inplace=True, name=name, **kwargs)
        if not inplace:
            return target

    def reset(self, values=None, axis=0, name=None, **kwargs):
        "deprecated, see Dataset.set" 
        warnings.warn("Deprecated. Use Dataset.set", FutureWarning)
        # NOTE(review): neither `size` nor `set` is defined in the visible
        # part of this class (only `set_axis` is). They may be inherited from
        # a base class -- confirm, otherwise this deprecated method fails
        # with an AttributeError.
        # values=None means "reset to 0..size-1", values=False means "leave
        # the values unchanged"
        if values is None: values = np.arange(self.size)
        if values is False: values = None
        return self.set(values, axis=axis, name=name, **kwargs)

    def rename_keys(self, mapper, inplace=True):
        """ Rename all variables in the Dataset

        Possible speedup compared to a classical dict-like operation 
        since an additional check on the axes is avoided.

        Parameters
        ----------
        mapper : dict-like or function to map oldname -> newname
        inplace : bool, optional
            if True, in-place modification, otherwise a copy with modified
            keys is returned (default: True)

        Returns
        -------
        None, or Dataset if inplace is False

        Raises
        ------
        TypeError
            if `mapper` is neither a dict nor a callable
        KeyError
            if `mapper` is a dict and one of its keys is not a variable

        Notes
        -----
        All old names are removed before the new names are inserted, so that
        chained or swapped names (e.g. ``{'a':'b', 'b':'a'}``) do not
        overwrite one another. Renamed variables are moved to the end.

        Examples
        --------
        >>> ds = da.Dataset(a=da.zeros(shape=(3,)), b=da.zeros(shape=(3,2)))
        >>> ds
        Dataset of 2 variables
        0 / x0 (3): 0 to 2
        1 / x1 (2): 0 to 1
        a: ('x0',)
        b: ('x0', 'x1')
        >>> ds.rename_keys({'b':'c'})
        >>> ds
        Dataset of 2 variables
        0 / x0 (3): 0 to 2
        1 / x1 (2): 0 to 1
        a: ('x0',)
        c: ('x0', 'x1')
        """
        ds = self if inplace else self.copy()

        if isinstance(mapper, dict):
            pairs = list(mapper.items())
        else:
            if not callable(mapper):
                raise TypeError("mapper must be callable")
            pairs = [(old, mapper(old)) for old in ds.keys()]

        # the parent class methods are used throughout: the arrays are not
        # modified, so the axes check of Dataset.__setitem__ is not needed
        base = super(Dataset, ds)

        # look up every old name before touching anything (raises KeyError on
        # unknown names), and ignore names that are mapped onto themselves
        moves = [(old, new, base.__getitem__(old)) for old, new in pairs]
        moves = [move for move in moves if move[0] != move[1]]

        # two phases, so that a new name equal to another old name does not
        # overwrite a variable that still has to be renamed
        for old, _, _ in moves:
            base.__delitem__(old)
        for _, new, val in moves:
            base.__setitem__(new, val)

        if not inplace:
            return ds

    def rename_axes(self, mapper, inplace=True):
        """ Rename axes, analogous to rename_keys for axis names

        `mapper` is a dict or a function to map oldname -> newname. The
        Dataset is modified in place and None is returned, unless `inplace`
        is False, in which case a modified copy is returned.
        """
        ds = self if inplace else self.copy()

        if isinstance(mapper, dict):
            renames = mapper.items()
        elif callable(mapper):
            renames = [(old, mapper(old)) for old in ds.dims]
        else:
            raise TypeError("mapper must be callable")

        for old, new in renames:
            ds.axes[old].name = new

        if not inplace:
            return ds

    def reduce_axis(self, func, axis=0, keepdims=False, keepattrs=False, **kwargs):
        """ reduce an axis in a Dataset

        Parameters
        ----------
        func : operation that can be applied on a numpy array, 
            which takes `axis` int argument
        axis : int or str, optional
            Dataset axis to operate on (default 0)
        keepdims : bool, optional
            if True, the axis is kept and `func` is also applied on the axis
            values; otherwise the axis is removed (default False)
        keepattrs : bool, optional
            if True, copy the metadata of the Dataset and of the transformed
            variables to the result (default False)
        **kwargs : passed to func

        Returns
        -------
        Dataset (same class as self). Variables that do not have the
        dimension are passed on unchanged.
        """
        # prepare new axes
        _, name = self._get_axis_info(axis)
        if keepdims:
            newaxes = [ax.copy() if ax.name != name else Axis(func(ax.values, axis=0, **kwargs), ax.name) for ax in self.axes]
        else:
            newaxes = [ax.copy() for ax in self.axes if ax.name != name ]
        newdims = [ax.name for ax in newaxes]

        # initialize dataset
        dataset = self.__class__()
        dataset.axes = newaxes

        # apply function to all elements
        for k in self.keys():
            item = self[k]
            # skip DimArrays without the dimension of interest
            # (only regular errors are caught, so that KeyboardInterrupt and
            # SystemExit are not swallowed)
            try:
                pos, _ = item._get_axis_info(name)
            except Exception:
                dataset[k] = item # no axis is present
                continue
            newval = func(item.values, axis=pos, **kwargs)
            dima = DimArray(newval, [newaxes[newdims.index(dim)] for dim in item.dims if dim in newdims])
            if keepattrs: 
                dima.attrs.update(item.attrs)
            dataset[k] = dima # for now with check

        if keepattrs: 
            dataset.attrs.update(self.attrs) # keep metadata?
        return dataset

    _getitem = take # for ix, nloc, iloc, loc, sel, isel to work

    def take_axis(self, indices, axis=0, indexing=None, mode='raise'):
        """ Analogous to DimArray.take_axis, via `reduce_axis` and `numpy.take`
        """
        if not np.iterable(indices):
            raise TypeError("indices must be iterable")

        # indexing mode: argument, then `_indexing` attribute, then global option
        if not indexing:
            indexing = getattr(self, "_indexing", None) or get_option("indexing.by")

        # labels are converted to positions on the Dataset axis
        if indexing == "label":
            indices = self.axes[axis].loc(indices, mode=mode)

        # any other mode than those three falls back on 'raise'
        take_mode = mode if mode in ('raise', 'clip', 'wrap') else 'raise'
        return self.reduce_axis(np.take, indices=indices, axis=axis, mode=take_mode, keepattrs=True, keepdims=True)

    def sort_axis(self, axis=0, kind='quicksort'):
        """Analogous to DimArray.sort_axis, for each element in a Dataset

        `kind` is passed to the `argsort` of the axis values.
        """
        order = self.axes[axis].values.argsort(kind=kind)
        return self.take_axis(order, axis=axis, indexing='position')

    def reindex_axis(self, values, axis=0, fill_value=np.nan, raise_error=False, method=None):
        """ analogous to DimArray.reindex_axis, but for a whole Dataset 

        See DimArray.reindex_axis for documentation.

        Parameters
        ----------
        values : list, array-like or Axis
            new axis values. If an Axis is provided, its name determines the
            axis to re-index and `axis` is ignored.
        axis : int or str, optional
            Dataset axis to re-index (default 0)
        fill_value : optional
            value written where the new axis values differ from those found
            by the look-up (default nan)
        raise_error : bool, optional
            if True, the look-up is done with mode 'raise', otherwise with
            mode 'clip' (default False)
        method : optional
            if not None, the variables are not filled with `fill_value`

        Returns
        -------
        Dataset

        Raises
        ------
        TypeError
            if `values` is a scalar or a slice
        """
        if isinstance(values, Axis):
            newaxis = values
            values = newaxis.values
            axis = newaxis.name
        elif isscalar(values) or isinstance(values, slice):
            raise TypeError("Please provide list, array-like or Axis object to perform re-indexing")
        else:
            values = np.asarray(values)

        # take axis, do not raise error
        dataset = self.take_axis(values, axis=axis, indexing='label', 
                                 mode='raise' if raise_error else 'clip')

        # Replace mismatch with missing values?
        newax = dataset.axes[axis]
        mask = newax.values != values
        any_nan = np.any(mask)

        if any_nan:
            # Make sure the axis values match the requested new axis
            dataset.axes[axis][mask] = values[mask]

            if method is None:
                # The axis is addressed by name: an integer `axis` is a
                # position among the Dataset axes, which need not be the
                # position of that dimension in each variable. Variables
                # without the dimension have nothing to fill.
                name = newax.name
                for k in dataset.keys():
                    if name not in dataset[k].dims:
                        continue
                    dataset[k].put(mask, fill_value, axis=name, inplace=True, indexing="position", cast=True)

        return dataset

    def reindex_like(self, other, **kwargs):
        """Analogous to DimArray.reindex_like

        Thin wrapper: `self`, `other` and the key-word arguments are passed
        to the module-level `reindex_like` function, whose result is returned.

        >>> ds1 = da.Dataset(a=da.DimArray(axes=[[1,2,3]]))
        >>> ds2 = da.Dataset(b=da.DimArray(axes=[[1.,3.],['a','b']]))
        >>> ds2.reindex_like(ds1)
        Dataset of 1 variable
        0 / x0 (3): 1.0 to 3.0
        1 / x1 (2): 'a' to 'b'
        b: ('x0', 'x1')
        """
        # inside the method body, the name refers to the function imported at
        # module level, not to this method
        return reindex_like(self, other, **kwargs)

    def interp_axis(self, values, axis=0, left=np.nan, right=np.nan, issorted=None):
        """ Analogous to DimArray.interp_axis

        The interpolation weights are computed once from the Dataset axis
        and re-used for every variable.
        """
        # building an Axis performs the necessary array & type checks
        target_axis = Axis(values, self.axes[axis].name)

        # sort the axis if needed, to apply numpy interp
        sorted_ds = _interp_internal_maybe_sort(self, axis, issorted)
        source_values = sorted_ds.axes[axis].values

        weights = _interp_internal_get_weights(source_values, target_axis.values)

        # loop over all dimarray
        return sorted_ds.reduce_axis(_interp_internal_from_weight, axis=axis, keepdims=True, keepattrs=True, left=left, right=right, **weights)

    def interp_like(self, other, **kwargs):
        """Analogous to DimArray.interp_like

        Thin wrapper: `self`, `other` and the key-word arguments are passed
        to the module-level `interp_like` function, whose result is returned.
        """
        return interp_like(self, other, **kwargs)

    #
    # Operations
    #
    def _binary_op(self, func, other):
        """ generalize DimArray operation to a Dataset, for each key

        In case the keys differ, returns the intersection of the two datasets

        Parameters
        ----------
        func : binary function, passed to the `_binary_op` of each variable
        other : Dataset or scalar

        Returns
        -------
        Dataset (same class as self)

        Just for testing:
        >>> ds = Dataset(b=DimArray([[0.,1],[1,2]]))
        >>> -ds
        Dataset of 1 variable
        0 / x0 (2): 0 to 1
        1 / x1 (2): 0 to 1
        b: ('x0', 'x1')
        >>> -ds["b"]
        dimarray: 4 non-null elements (0 null)
        0 / x0 (2): 0 to 1
        1 / x1 (2): 0 to 1
        array([[-0., -1.],
               [-1., -2.]])
        >>> np.all(ds == ds)
        True
        >>> assert isinstance(-ds, Dataset)
        >>> assert isinstance(ds/0.5, Dataset)
        >>> assert isinstance(ds*0, Dataset)
        >>> (-ds -ds + ds/0.5 + ds*0+1)['b']
        dimarray: 4 non-null elements (0 null)
        0 / x0 (2): 0 to 1
        1 / x1 (2): 0 to 1
        array([[1., 1.],
               [1., 1.]])
        >>> ds += 1
        >>> ds['b']
        dimarray: 4 non-null elements (0 null)
        0 / x0 (2): 0 to 1
        1 / x1 (2): 0 to 1
        array([[1., 2.],
               [2., 3.]])
        """
        assert isinstance(other, Dataset) or isscalar(other), "can only combine Datasets objects (func={})".format(func.__name__)
        # align all axes first
        reindex = get_option("op.reindex")
        if reindex and hasattr(other, 'axes') and other.axes != self.axes:
            # NOTE(review): the result of reindex_like is discarded, so this
            # call does not change `other` -- confirm whether it was meant to
            # be assigned, or whether the per-variable operation is expected
            # to take care of the alignment
            other.reindex_like(self)
        # now proceed to operation
        res = self.__class__()
        if hasattr(other, 'keys'):
            # only the variables present in both datasets, in the order of self
            common = set(other.keys())
            for k in self.keys():
                if k in common:
                    res[k] = self[k]._binary_op(func, other[k])
        else:
            for k in self.keys():
                res[k] = self[k]._binary_op(func, other)
        return res

    def _unary_op(self, func):
        """ apply the `_unary_op` of every variable, and collect the results in a new Dataset
        """
        result = self.__class__()
        for name in self.keys():
            variable = self[name]
            result[name] = variable._unary_op(func)
        return result


def stack_ds(datasets, axis, keys=None, align=False, **kwargs):
    """ stack dataset along a new dimension

    Parameters
    ----------
    datasets: sequence or dict of datasets
    axis: str, new dimension along which to stack the dataset 
    keys, optional: stack axis values, useful if dataset is a sequence, or a non-ordered dictionary
    align, optional: if True, align axes (via reindexing) *prior* to stacking
    **kwargs : optional key-word arguments passed to align, if align is True

    Returns
    -------
    stacked dataset

    See Also
    --------
    concatenate_ds, stack, sort_axis

    Examples
    --------
    >>> a = DimArray([1,2,3], dims=('dima',))
    >>> b = DimArray([11,22], dims=('dimb',))
    >>> ds = Dataset(a=a,b=b) # dataset of 2 variables from an experiment
    >>> ds2 = Dataset(a=a*2,b=b*2) # dataset of 2 variables from a second experiment
    >>> stack_ds([ds, ds2], axis='stackdim', keys=['exp1','exp2'])
    Dataset of 2 variables
    0 / stackdim (2): 'exp1' to 'exp2'
    1 / dima (3): 0 to 2
    2 / dimb (2): 0 to 1
    a: ('stackdim', 'dima')
    b: ('stackdim', 'dimb')
    """
    # make a sequence of datasets
    datasets, keys = _check_stack_args(datasets, keys)

    # make sure the stacking dimension is ok
    dims = _get_dims(*datasets)
    axis = _check_stack_axis(axis, dims)

    if align:
        datasets = da.align(datasets, strict=True, **kwargs)

    # find the list of variables common to all datasets
    variables = None
    for ds in datasets:

        # check that stack axis is not already present
        assert axis not in ds.dims, axis+" already exists in the dataset"

        # check that all datasets have the same variables
        if variables is None:
            variables = ds.keys()
        else:
            assert sorted(ds.keys()) == sorted(variables), "variables differ across datasets"

    # Compute stacked dataset
    dataset = Dataset()
    for v in variables:
        arrays = [ds[v] for ds in datasets]
        # the arrays were aligned above if requested, hence align=False
        stacked = stack(arrays, axis=axis, keys=keys, align=False)
        dataset[v] = stacked

    return dataset

def concatenate_ds(datasets, axis=0, align=False, **kwargs):
    """ concatenate two datasets along an existing dimension

    Parameters
    ----------
    datasets: sequence of datasets 
    axis: axis along which to concatenate
    align, optional: if True, align secondary axes (via reindexing) prior to concatenating
    **kwargs : optional key-word arguments passed to align, if align is True

    Returns
    -------
    joint Dataset along axis

    NOTE: will raise an error if variables are there which do not contain the required dimension

    See Also
    --------
    stack_ds, concatenate, sort_axis

    Examples
    --------
    >>> a = da.zeros(axes=[list('abc')], dims=('x0',))  # 1-D DimArray
    >>> b = da.zeros(axes=[list('abc'), [1,2]], dims=('x0','x1')) # 2-D DimArray
    >>> ds = Dataset(a=a,b=b) # dataset of 2 variables from an experiment
    >>> a2 = da.ones(axes=[list('def')], dims=('x0',)) 
    >>> b2 = da.ones(axes=[list('def'), [1,2]], dims=('x0','x1')) # 2-D DimArray
    >>> ds2 = Dataset(a=a2,b=b2) # dataset of 2 variables from a second experiment
    >>> concatenate_ds([ds, ds2])
    Dataset of 2 variables
    0 / x0 (6): 'a' to 'f'
    1 / x1 (2): 1 to 2
    a: ('x0',)
    b: ('x0', 'x1')
    """
    # find the list of variables common to all datasets
    variables = None
    for ds in datasets:
        # check that all datasets have the same variables
        if variables is None:
            variables = ds.keys()
        else:
            assert sorted(ds.keys()) == sorted(variables), "variables differ across datasets"

    if align:
        # align every dataset axis except the one to concatenate along
        axis_nm = datasets[0].axes[axis].name
        aligned_dims = [d for d in _get_dims(*datasets) if d != axis_nm]
        for d in aligned_dims:
            datasets = da.align(datasets, axis=d, strict=True, **kwargs)

    # Compute concatenated dataset
    dataset = Dataset()
    for v in variables:
        arrays = [ds[v] for ds in datasets]
        # the arrays were aligned above if requested, hence align=False
        joined = concatenate(arrays, axis=axis, align=False, _no_check=align)
        dataset[v] = joined

    return dataset