# Source code for dimarray.core.bases

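""" Base classes shared by dimarray objects: attribute and metadata access,
axes, indexing proxies and arithmetic operator overloading.
"""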
from __future__ import print_function
from __future__ import absolute_import
from future.utils import string_types, PY2
import warnings
import copy
from collections import OrderedDict as odict
import numpy as np
from dimarray.config import get_option
from dimarray.tools import is_numeric
from dimarray.core.indexing import locate_one, locate_many, locate_slice, expanded_indexer
from dimarray.prettyprinting import repr_axis, repr_dataset, repr_axes, str_axes, str_dataset, str_dimarray

class GetSetDelAttrMixin(object):
    """ Mixin overloading __getattr__, __setattr__ and __delattr__ for easy
    access to metadata and axes.

    Public attributes that are not defined on the class are read from, and
    written to, the `attrs` dictionary attribute. If the class also defines
    `dims` (for reading) or `axes` (for writing), an attribute named after a
    dimension maps to that axis' values instead.

    Class-level controls:

    - `__metadata_exclude__` : names never routed to `attrs`
    - `__metadata_include__` : names always routed to `attrs`, even if they
      start with an underscore or are listed in `__metadata_exclude__`
    """
    __metadata_exclude__ = [] # do not add to attrs
    __metadata_include__ = [] # always add to attrs

    def __getattr__(self, name):
        # only called by Python when normal attribute lookup has failed
        if hasattr(self.__class__, name):
            return object.__getattribute__(self, name)
        elif name not in self.__metadata_include__ \
                and (name.startswith('_') or name in self.__metadata_exclude__):
            pass # raise error
        elif hasattr(type(self), 'dims') and name in self.dims:
            return self.axes[name].values # return axis values
        elif name in self.attrs:
            return self.attrs[name]
        raise AttributeError("{} object has no attribute {}".format(self.__class__.__name__, name))

    def __setattr__(self, name, value):
        if name not in self.__metadata_include__ and \
                (name.startswith('_') \
                 or name in self.__metadata_exclude__ \
                 or hasattr(self.__class__, name)):
            object.__setattr__(self, name, value) # do nothing special
        elif hasattr(type(self), 'axes') and name in self.dims:
            self.axes[name][()] = value # modify axis values
        else:
            self.attrs[name] = value # add as metadata

    def __delattr__(self, name):
        # Mirror the routing rules of __setattr__ so that anything stored in
        # `attrs` through attribute assignment (including names listed in
        # __metadata_include__) can be deleted the same way.
        if name in self.__metadata_include__:
            is_metadata = True
        else:
            # short-circuit: private/excluded names never touch `attrs`
            is_metadata = not (name.startswith('_')
                               or name in self.__metadata_exclude__
                               or hasattr(self.__class__, name))
        if is_metadata and name in self.attrs:
            del self.attrs[name]
        else:
            return object.__delattr__(self, name)



# class GetSetDelAttrMixin(object):
#     pass
#
class AbstractHasMetadata(object):
    """ Base class for objects carrying a metadata dictionary, `attrs`.

    The default `attrs` property reads the `_attrs` attribute; subclasses are
    expected to provide `_attrs` (or override `attrs`) and to implement `_repr`.
    """

    @property
    def attrs(self):
        """ metadata dictionary """
        return self._attrs # only for the in-memory array

    @attrs.setter
    def attrs(self, value):
        # Replace all metadata with the content of `value` (mapping or
        # sequence of (key, value) pairs). Snapshot `value` before clearing:
        # it may be `self.attrs` itself (e.g. `a.attrs = a.attrs`), which
        # would otherwise be emptied first and silently lose all metadata.
        new_attrs = odict(value)
        del self.attrs
        self.attrs.update(new_attrs)

    @attrs.deleter
    def attrs(self):
        # remove every key, but keep the container object itself
        for k in list(self.attrs.keys()):
            del self.attrs[k]

    def _repr(self, metadata=False):
        """ string representation, to be implemented by subclasses """
        raise NotImplementedError()

    def __repr__(self):
        return self._repr()

    __str__ = str_dimarray

    def summary(self):
        """ print a representation that includes the metadata """
        print(self.summary_repr())

    def summary_repr(self):
        """ return a representation that includes the metadata """
        return self._repr(metadata=True)

    def _metadata(self, meta=None):
        """ for back compatibility: return `attrs` if `meta` is None,
        otherwise update `attrs` with `meta`
        """
        if meta is None:
            return self.attrs
        else:
            self.attrs.update(meta)


class AbstractAxis(AbstractHasMetadata):
    """ Base class for a single labelled axis (values + name + metadata) """

    _tol = None # default tolerance used by `loc` when `tol` is not given

    def loc(self, val, tol=None, issorted=False, mode='raise'):
        """ Locate one or several values, using numpy functions

        Parameters
        ----------
        val: scalar or array-like or slice, optional
            The value(s) to look for in the axis.
        tol: None or float, optional
            If different from None, search nearest neighbour, with a certain 
            tolerance (np.argmin is used). None by default, in which case the
            axis' `_tol` attribute is used. An explicit 0 is honoured.
            This option is applicable for numerical axes only. It does not
            apply for slices.
        issorted: bool, optional
            If True, assume the axis is sorted with increasing values (faster search)
        mode: {'raise', 'clip'}, optional
            Only used when `val` is a non-boolean array-like and no tolerance
            applies; ignored otherwise.
            If mode == 'raise' (the default), a check is performed on the result
            to ensure that all values were present, and an IndexError is raised
            otherwise.
            If mode == 'clip', any label not present in the axis is clipped to 
            the nearest values (see np.searchsorted).

        Returns
        -------
        matches: integer position(s) of val in the axis

        Raises
        ------
        IndexError: if mode == 'raise' and some array-like values are missing

        Notes
        -----
        For single values and slices, non-zero elements in the array `a == val` are 
        searched for, and the first element is returned.
        For array-like indexer, the axis is sorted and np.searchsorted is applied.
        If tol is provided, a `np.argmin(np.abs(a-val))` is used.

        Examples
        --------
        >>> from dimarray import Axis
        >>> ax = Axis([1.5, 2.5, 3.5], 'myaxis')
        >>> ax.loc(2.5)
        1
        >>> ax.loc(slice(2, 3))
        slice(1, 2, None)
        >>> ax.loc(slice(2, 3.5))
        slice(1, 3, None)
        """
        values = self.values[()]

        # explicit `is None` test: `tol=0` is a valid request and must not be
        # replaced by the axis default (as `tol or self._tol` would do)
        if tol is None:
            tol = self._tol

        if tol is not None and not self.is_numeric():
            tol = None # ignore tol parameter for non-numeric axes (an error will be raised if element is not found)

        if type(val) is slice:
            istart, istop = locate_slice(values, val.start, val.stop, val.step, issorted=issorted)
            matches = slice(istart, istop, val.step)

        elif np.isscalar(val):
            matches = locate_one(values, val, tol=tol, issorted=issorted)

        elif val is None:
            matches = values.tolist().index(val)

        elif hasattr(val, 'dtype') and val.dtype.kind == 'b':
            matches = val  # boolean indexing, do nothing

        elif tol is not None: # no scalar, but tolerance parameter provided
            matches = [locate_one(values, v, tol=tol, issorted=issorted) for v in val]

        else:
            matches = locate_many(values, val, issorted=issorted)

            if mode != 'clip':
                # locate_many returns positions even for absent labels:
                # compare back to detect them
                test = values[matches] != val
                if np.any(test):
                    mismatch = np.asarray(val)[test]
                    raise IndexError("Some values where not found in the axis ({}): {}.".format(self.name, mismatch))

        return matches

    _repr = repr_axis

    def __str__(self):
        """ simple string representation: name(size)=first:last
        """
        return "{}({})={}:{}".format(self.name, self.size, *self._bounds())

    @property
    def dtype(self):
        """ data type of the axis values """
        return self.values.dtype

    def is_numeric(self):
        """ True if the axis values are numeric """
        return is_numeric(self.values)

class AbstractAxes(object):
    """ Base class for an ordered container of axes.

    Public attributes cannot be assigned on an Axes object: only private
    (underscore-prefixed) names are accepted by attribute assignment.
    """
    _Axis = AbstractAxis # axis class associated with this container (presumably used by subclasses)

    def __setattr__(self, name, value):
        if name.startswith('_'):
            object.__setattr__(self, name, value)
            return
        raise AttributeError("Cannot set attribute to an Axes object")

    __repr__ = repr_axes
    __str__ = str_axes

class Indexable(object):
    """ Proxy object turning `[]` access into function calls.

    `proxy[idx]`, `proxy[idx] = item` and `del proxy[idx]` call respectively
    `getitem`, `setitem` and `delitem` with the stored positional `args`,
    the keyword `indices=idx` (plus `values=item` for assignment), and the
    stored keyword arguments.
    """
    def __init__(self, getitem, setitem, delitem, args=(), **kwargs):
        self.getitem, self.setitem, self.delitem = getitem, setitem, delitem
        self.args, self.kwargs = args, kwargs

    def __getitem__(self, idx):
        getter = self.getitem
        return getter(*self.args, indices=idx, **self.kwargs)

    def __setitem__(self, idx, item):
        setter = self.setitem
        return setter(*self.args, indices=idx, values=item, **self.kwargs)

    def __delitem__(self, idx):
        deleter = self.delitem
        return deleter(*self.args, indices=idx, **self.kwargs)

class AbstractHasAxes(AbstractHasMetadata):
    """ class to handle things related to axes: dimension names, labels,
    translation of user indices into numpy-compatible indexers, and the
    `ix`, `loc`, `iloc` and `nloc` indexing proxies
    """
    _indexing = None # default indexing mode; None means option 'indexing.by'
    _tol = None # default tolerance for label look-up

    # indexing methods to be overloaded
    def _getitem(self, indices=None, **kwargs):
        raise NotImplementedError()
    def _setitem(self, indices=None, values=None, **kwargs):
        raise NotImplementedError()
    def _delitem(self, indices=None, **kwargs):
        raise NotImplementedError()

    @property
    def dims(self):
        """ tuple of dimension names, in axis order """
        return tuple([ax.name for ax in self.axes])

    @dims.setter
    def dims(self, newdims):
        self._set_dims(newdims)

    def _set_dims(self, newdims):
        """ rename dimensions

        Parameters
        ----------
        newdims : sequence of new names (one per dimension) or dict {old: new}

        Raises
        ------
        TypeError : if `newdims` is not iterable
        ValueError : if a sequence of the wrong length is provided
        """
        if not np.iterable(newdims): 
            raise TypeError("new dims must be iterable")
        if not isinstance(newdims, dict):
            if len(newdims) != len(self.dims):
                raise ValueError("dimensions number mismatch")
            newdims = dict(zip(self.dims, newdims))
        # Resolve every axis *before* renaming any of them: renaming one at a
        # time breaks permutations such as ('x', 'y') -> ('y', 'x'), where the
        # second look-up by name would find the axis that was just renamed.
        renames = [(self.axes[old], new) for old, new in newdims.items()]
        for ax, new in renames:
            ax.name = new

    @property
    def axes(self):
        raise NotImplementedError('need to be overloaded !')

    @property
    def ndim(self):
        return len(self.axes)

    @property
    def labels(self):
        """ axis values, as a tuple with one entry per dimension
        """
        return tuple([ax.values for ax in self.axes])

    @labels.setter
    def labels(self, newlabels):
        """ change all labels at once
        """
        if not np.iterable(newlabels): 
            raise TypeError("new labels must be iterable")
        if not len(newlabels) == self.ndim:
            raise ValueError("dimension mistmatch")
        for i, lab in enumerate(newlabels):
            self.axes[i][()] = lab

    def _get_indices(self, indices, axis=0, indexing=None, tol=None, keepdims=False):
        """ Return an n-D indexer  

        Parameters
        ----------
        indices : None, scalar, slice, array-like, tuple, dict or Axes
            None selects everything. A dict maps dimension names (str) or
            positions (int) to per-dimension indices; it is not modified.
            An Axes object selects the values of each of its axes.
        axis : int or str, optional
            If different from 0 and None, `indices` apply to this dimension
            only (numpy-like `take(indices, axis)`).
        indexing : {'label', 'position'}, optional
            Defaults to the object's `_indexing`, then to option 'indexing.by'.
        tol : float, optional
            Tolerance for label look-up, defaults to the object's `_tol`.
        keepdims : bool, optional
            If True, scalar indices are wrapped in a list so that the
            corresponding dimension is not collapsed.

        (same as DimArray.take or DimArrayOnDisk.read)

        Returns
        -------
        indexer : tuple of numpy-compatible indices, of length equal to the number of 
            dimensions.

        Raises
        ------
        ValueError : if a dimension name is not found, or a dict specifies
            the same dimension more than once.
        """
        indexing = indexing or getattr(self,'_indexing',None) or get_option('indexing.by')
        dims = self.dims

        if indices is None:
            indices = ()

        if tol is None:
            tol = getattr(self, '_tol', None)

        #
        # Convert indices to tuple, from a variety of input formats
        #
        # special case: numpy like (idx, axis)
        if axis not in (0, None):
            indices = {axis:indices}

        # special case: Axes is provided as index
        elif isinstance(indices, AbstractAxes):
            indices = {ax.name:ax.values for ax in indices}

        # should always be a tuple
        if isinstance(indices, dict):
            # Build a new name-keyed dict instead of rewriting keys in place:
            # mutating a dict while iterating over it raises RuntimeError in
            # Python 3, and the caller's dict must not be modified.
            by_name = {}
            for k, ix in indices.items():
                if isinstance(k, string_types):
                    if k not in dims:
                        raise ValueError("Dimension {} not found. Existing dimensions: {}".format(k, dims))
                    name = k
                else:
                    name = dims[k] # integer position -> dimension name
                if name in by_name:
                    raise ValueError("Dimension {} is indexed more than once".format(name))
                by_name[name] = ix
            indices = tuple(by_name.get(d, slice(None)) for d in dims)

        # expand to N-D tuple, and expands ellipsis
        indices = expanded_indexer(indices, self.ndim)

        # load each dimension as necessary
        indexer = ()
        for i, ix in enumerate(indices):
            dim = dims[i]

            if not np.isscalar(ix) and not isinstance(ix, slice):
                ix = np.asarray(ix)

            # boolean indices are fine
            if isinstance(ix, np.ndarray) and ix.dtype.kind == 'b':
                pass

            # in case of label-based indexing, need to read the whole dimension
            # and look for the appropriate values
            elif indexing != 'position' and not (type(ix) is slice and ix == slice(None)):
                # find the index corresponding to the required axis value
                ix = self.axes[dim].loc(ix, tol=tol)

            # numpy rule: a singleton list does not collapse the axis
            if keepdims and np.isscalar(ix):
                ix = [ix]

            indexer += (ix,)

        return indexer

    def _getaxes_ortho(self, idx_tuple):
        " idx: tuple of position indices  of length = ndim (orthogonal indexing)"
        axes = []
        for i, ix in enumerate(idx_tuple):
            ax = self.axes[i][ix]
            if not np.isscalar(ax): # do not include scalar axes
                axes.append(ax)
        return axes

    #
    # returns axis position and name based on either of them
    #
    def _get_axis_info(self, axis):
        """ axis position and name

        Parameters
        ----------
        axis : `int` (including numpy integers) or `str` or None

        Returns
        -------
        idx : `int`, axis position
        name : `str` or None, axis name

        Raises
        ------
        TypeError : if `axis` is neither an integer nor a string
        """
        if axis is None:
            return None, None

        if isinstance(axis, string_types):
            idx = self.dims.index(axis)

        # accept numpy integers too; bool is a subclass of int but is rejected
        elif isinstance(axis, (int, np.integer)) and not isinstance(axis, bool):
            idx = int(axis)

        else:
            raise TypeError("axis must be int or str, got:"+repr(axis))

        name = self.axes[idx].name
        return idx, name

    def _get_axes_info(self, axes):
        """ return axis (dimension) positions AND names from a sequence of axis (dimension) positions OR names

        Parameters
        ----------
        axes : sequence of str or int, representing axis (dimension) 
            names or positions, possibly mixed up.

        Returns
        -------
        pos : tuple of `int` indicating dimension's rank in the array
        names : tuple of dimension names
        (both empty if `axes` is empty)
        """
        infos = [self._get_axis_info(x) for x in axes]
        if not infos:
            return (), ()
        pos, names = zip(*infos)
        return pos, names

    @property
    def ix(self):
        """ toggle between position-based and label-based indexing """
        # resolve the effective default mode the same way `_get_indices`
        # does, so the toggle is also right when option 'indexing.by' is
        # 'position' and `_indexing` is unset
        current = self._indexing or get_option('indexing.by')
        indexing = 'label' if current == 'position' else 'position'
        return Indexable(self._getitem, self._setitem, self._delitem, indexing=indexing)

    # after xray: add sel, isel, loc, iloc methods
    def sel(self, **indices):
        """ label-based selection with dimension names as keyword arguments """
        return self.loc[indices]

    def isel(self, **indices):
        """ position-based selection with dimension names as keyword arguments """
        return self.iloc[indices]

    @property
    def loc(self):
        """ label-based indexing proxy """
        return Indexable(self._getitem, self._setitem, self._delitem, indexing='label')

    @property
    def iloc(self):
        """ position-based indexing proxy """
        return Indexable(self._getitem, self._setitem, self._delitem, indexing='position')

    @property
    def nloc(self):
        """ nearest-neighbour label-based indexing proxy (infinite tolerance) """
        return Indexable(self._getitem, self._setitem, self._delitem, indexing='label', tol=np.inf)


class OpMixin(object):
    """ overload basic operations

    Python arithmetic operators are delegated to `_unary_op(func)`,
    `_binary_op(func, other)` and `_rbinary_op(func, other)` (reflected
    operations, where this object is the right operand), which subclasses
    must implement.
    """
    def _unary_op(self, func):
        raise NotImplementedError()
    def _binary_op(self, func, other):
        raise NotImplementedError()
    def _rbinary_op(self, func, other):
        return other._binary_op(func, self) # default only

    def __neg__(self): return self._unary_op(np.ndarray.__neg__)
    def __pos__(self): return self._unary_op(np.ndarray.__pos__)
    # not a Python protocol method; `other` is unused and kept optional only
    # for backward compatibility
    def __sqrt__(self, other=None): return self._unary_op(np.sqrt)
    def __invert__(self): return self._unary_op(np.invert)

    def __add__(self, other): return self._binary_op(np.add, other)
    def __sub__(self, other): return self._binary_op(np.subtract, other)
    def __mul__(self, other): return self._binary_op(np.multiply, other)

    def __div__(self, other): return self._binary_op(np.true_divide, other) # TRUE DIVIDE (Python 2)
    def __truediv__(self, other): return self._binary_op(np.true_divide, other)
    def __floordiv__(self, other): return self._binary_op(np.floor_divide, other)

    def __pow__(self, other): return self._binary_op(np.power, other)

    # reverse order operation
    def __radd__(self, other): return self + other
    def __rmul__(self, other): return self * other
    def __rsub__(self, other): return self._rbinary_op(np.subtract, other)
    def __rdiv__(self, other): return self._rbinary_op(np.true_divide, other) # Python 2
    # Python 3 looks up __rtruediv__ / __rfloordiv__ for `x / a` and `x // a`
    def __rtruediv__(self, other): return self._rbinary_op(np.true_divide, other)
    def __rfloordiv__(self, other): return self._rbinary_op(np.floor_divide, other)
    def __rpow__(self, other): return self._rbinary_op(np.power, other)



class AbstractDimArray(AbstractHasAxes):
    """ Abstract n-dimensional labelled array.

    Subclasses provide `values`, `_constructor`, `copy` and the
    `_getvalues_*` / `_setvalues_*` / `_getaxes_broadcast` hooks; this class
    wires them into label- or position-based indexing via `[]`.
    """

    # default for the `broadcast` indexing argument;
    # None means: read option 'indexing.broadcast'
    _broadcast = False

    @property
    def values(self):
        raise NotImplementedError()

    @property
    def shape(self):
        return self.values.shape

    @property
    def ndim(self):
        return self.values.ndim

    @property
    def size(self):
        return self.values.size

    @property
    def item(self):
        # the bound `item` method of the underlying values, so `a.item()` works
        return self.values.item

    def is_numeric(self):
        return is_numeric(self.values)

    def _getitem(self, indices=None, axis=0, indexing=None, tol=None, broadcast=None, keepdims=False,
                 broadcast_arrays=None,  # back-compatibility for broadcast
                 ):
        """ Index the array (functional form of `a[...]`)

        Parameters
        ----------
        indices, axis, indexing, tol, keepdims : passed to `_get_indices`
        broadcast : bool, optional
            numpy-like broadcast (True) or orthogonal (False) indexing.
            Defaults to `_broadcast`, or to option 'indexing.broadcast'
            when `_broadcast` is None.
        broadcast_arrays : deprecated alias of `broadcast`

        Returns
        -------
        a scalar if the selection is scalar, otherwise a new object built by
        `_constructor`, whose `attrs` are updated from this object's
        """
        if indices is None:
            indices = ()

        if broadcast_arrays is not None:
            warnings.warn(FutureWarning("broadcast_arrays is deprecated, use broadcast instead"))
            broadcast = broadcast_arrays

        if broadcast is None:
            default = self._broadcast
            broadcast = get_option('indexing.broadcast') if default is None else default

        # a[mask] with an n-d boolean mask (will fail with netCDF4)
        if self._is_boolean_index_nd(indices):
            if not hasattr(self, 'compress'):
                raise TypeError("{} does not support boolean indexing".format(self.__class__.__name__))
            return self.compress(indices)

        idx = self._get_indices(indices, axis=axis, indexing=indexing, tol=tol, keepdims=keepdims)

        # axes are computed before values, in both modes
        if broadcast:
            axes, values = self._getaxes_broadcast(idx), self._getvalues_broadcast(idx)
        else:
            axes, values = self._getaxes_ortho(idx), self._getvalues_ortho(idx)

        if np.isscalar(values):
            return values

        result = self._constructor(values, axes)
        result.attrs.update(self.attrs)
        return result

    def _setitem(self, indices, values, axis=0, indexing=None, tol=None, broadcast=None, cast=False, inplace=True):
        """ Assign `values` at `indices` (functional form of `a[...] = values`)

        Returns None when `inplace` is True (the default), otherwise the
        modified copy.

        See Also
        --------
        DimArray.read, DimArrayOnDisk.write
        """
        if broadcast is None:
            default = self._broadcast
            broadcast = get_option('indexing.broadcast') if default is None else default

        target = self if inplace else self.copy()

        # a[mask] = ... with an n-d boolean mask (will fail with netCDF4)
        if target._is_boolean_index_nd(indices):
            target._setvalues_bool(indices, values, cast=cast)
        else:
            idx = target._get_indices(indices, tol=tol, indexing=indexing, axis=axis)
            setter = target._setvalues_broadcast if broadcast else target._setvalues_ortho
            setter(idx, values, cast=cast)

        if not inplace:
            return target

    __getitem__ = _getitem 
    __setitem__ = _setitem

    # hooks for orthogonal or broadcast indexing, to be overloaded
    def _setvalues_broadcast(self, idx_tuple, values, cast=False):
        raise NotImplementedError()

    def _getvalues_broadcast(self, idx_tuple):
        raise NotImplementedError()

    def _getaxes_broadcast(self, idx_tuple):
        raise NotImplementedError()

    def _setvalues_ortho(self, idx_tuple, values, cast=False):
        raise NotImplementedError()

    def _setvalues_bool(self, mask, values, cast=False):
        raise NotImplementedError("boolean indexing is not implemented")

    def _getvalues_ortho(self, idx_tuple):
        raise NotImplementedError()

    def _is_boolean_index_nd(self, idx):
        """ True if `idx` is a boolean array with ndim > 1, as in `a[a > 2]` """
        if not (hasattr(idx, 'dtype') and hasattr(idx, 'ndim')):
            return False
        return idx.dtype.kind == 'b' and idx.ndim > 1

    def copy(self):
        raise NotImplementedError()

class AbstractDataset(AbstractHasAxes):
    """ Abstract collection of labelled arrays sharing a common set of axes """

    def _getitems(self, indices=None, axis=0, indexing=None, tol=None, broadcast=None, keepdims=False,
                  raise_error=False):
        """ Index every variable of the dataset along the shared axes

        Parameters
        ----------
        indices, axis, indexing, tol, keepdims : see `_get_indices`
        broadcast : accepted for signature compatibility, currently unused
            (each variable is indexed one dimension at a time, in 'position'
            mode)
        raise_error : bool, optional
            If True, raise a ValueError when a variable lacks a dimension that
            is actually being indexed. If False (default), such variables are
            left unchanged along that dimension.

        Returns
        -------
        newdata : new instance of the same class, holding the indexed variables
        """
        # first find the (position) index for the shared axes
        tuple_indices = self._get_indices(indices, axis=axis, tol=tol, keepdims=keepdims, indexing=indexing)

        # only dimensions that are actually restricted need processing:
        # full slices are no-ops
        indices_dict = {ax.name: ix for ix, ax in zip(tuple_indices, self.axes)
                        if not (isinstance(ix, slice) and ix == slice(None))}

        newdata = self.__class__()

        # loop over variables, then apply take in 'position' mode along each
        # indexed dimension (dimensions are addressed by name, so collapsing
        # one does not affect the others)
        for k in self.keys():
            v = self[k]
            for dim, ix in indices_dict.items():
                if np.ndim(v) == 0 or dim not in v.dims:
                    if raise_error:
                        raise ValueError("{} does not have dimension {} ==> set raise_error=False to keep this variable unchanged".format(k, dim))
                    continue
                v = v.take({dim: ix}, indexing='position')
            newdata[k] = v

        return newdata

    _repr = repr_dataset
    __str__ = str_dataset

# Add docstrings