Source code for xarray_validate.dataarray

from __future__ import annotations

from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    Literal,
    Mapping,
    Optional,
)

import attrs as _attrs
import xarray as xr

from . import _match
from .base import (
    BaseSchema,
    SchemaError,
    ValidationContext,
    ValidationMode,
    ValidationResult,
)
from .components import (
    ArrayTypeSchema,
    AttrsSchema,
    ChunksSchema,
    DimsSchema,
    DTypeSchema,
    NameSchema,
    ShapeSchema,
)


@_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate])
class CoordsSchema(BaseSchema):
    r"""
    Schema container for Coordinates

    Parameters
    ----------
    coords : dict
        Dict of coordinate keys and ``DataArraySchema`` objects. Keys can be
        either exact coordinate names or patterns:

        - Exact match: ``'time'`` matches only 'time'
        - Glob pattern: ``'x_*'`` matches x_0, x_1, x_foo, etc.
        - Regex pattern: ``'{x_\d+}'`` matches x_0, x_1, but not x_foo

    require_all_keys : bool, default: True
        Whether to require that all coordinates listed in ``coords`` are
        present. Only applies to exact keys, not pattern keys.

    allow_extra_keys : bool, default: True
        Whether to allow coordinates not included in ``coords`` dict.
        Coordinates matching pattern keys are not considered "extra".
    """

    coords: Dict[str, DataArraySchema] = _attrs.field()
    require_all_keys: bool = _attrs.field(default=True)
    allow_extra_keys: bool = _attrs.field(default=True)

    def serialize(self) -> dict:
        """
        Return a dictionary representation of this schema.

        Returns
        -------
        dict
            Dictionary with the ``require_all_keys`` and ``allow_extra_keys``
            flags and a ``coords`` entry mapping each key to the serialized
            form of its schema.
        """
        return {
            "require_all_keys": self.require_all_keys,
            "allow_extra_keys": self.allow_extra_keys,
            "coords": {k: v.serialize() for k, v in self.coords.items()},
        }

    @classmethod
    def deserialize(cls, obj: dict):
        """
        Build a ``CoordsSchema`` from a dictionary.

        Parameters
        ----------
        obj : dict
            Either the full form produced by :meth:`serialize` (a dict with a
            ``"coords"`` entry plus optional constructor keyword arguments),
            or a bare mapping of coordinate keys to schema specifications.

        Returns
        -------
        CoordsSchema
        """
        # NOTE: the presence of a "coords" entry is what selects the full
        # form, so a bare mapping cannot describe a coordinate that is itself
        # named "coords".
        if "coords" in obj:
            coords = obj["coords"]
            kwargs = {k: v for k, v in obj.items() if k != "coords"}
        else:
            coords = obj
            kwargs = {}

        coords = {k: DataArraySchema.convert(v) for k, v in coords.items()}
        return cls(coords=coords, **kwargs)

    def validate(
        self, coords: Mapping[str, Any], context: ValidationContext | None = None
    ) -> None:
        """
        Validate a mapping of coordinates against this schema.

        Parameters
        ----------
        coords : Mapping
            Coordinates to validate, keyed by coordinate name.

        context : ValidationContext, optional
            If supplied, errors are passed to ``context.handle_error`` and
            child schemas are validated with a context pushed to
            ``coords.<name>``. If omitted, errors are raised directly.

        Raises
        ------
        SchemaError
            When no context is supplied and a required exact key is missing
            (``require_all_keys=True``) or an unexpected coordinate is found
            (``allow_extra_keys=False``).
        """
        # Separate exact keys from pattern keys and compile patterns
        exact_keys, pattern_keys, compiled_patterns = _match.separate_keys(self.coords)

        if self.require_all_keys:
            # Only exact keys can be "missing"; a pattern may match nothing.
            missing_keys = set(exact_keys) - set(coords)
            if missing_keys:
                error = SchemaError(f"coords has missing keys: {missing_keys}")
                if context:
                    context.handle_error(error)
                else:
                    raise error

        if not self.allow_extra_keys:
            # A coordinate is "extra" only if it matches neither an exact key
            # nor any pattern key.
            matched_coords = _match.find_matched_keys(
                coords, exact_keys, compiled_patterns
            )
            extra_keys = set(coords) - matched_coords
            if extra_keys:
                error = SchemaError(f"coords has extra keys: {extra_keys}")
                if context:
                    context.handle_error(error)
                else:
                    raise error

        # Validate coordinates matching exact keys
        for key, da_schema in exact_keys.items():
            if key not in coords:
                # Absent keys are the responsibility of the require_all_keys
                # check above: already reported there when the flag is set,
                # and explicitly permitted when it is not.
                continue
            child_context = context.push(f"coords.{key}") if context else None
            da_schema.validate(coords[key], child_context)

        # Validate coordinates matching pattern keys; exact keys take
        # precedence, so a coordinate with its own exact entry is skipped.
        for pattern_key, da_schema in pattern_keys.items():
            regex = compiled_patterns[pattern_key]
            for coord_name in coords:
                if regex.fullmatch(coord_name) and coord_name not in exact_keys:
                    child_context = (
                        context.push(f"coords.{coord_name}") if context else None
                    )
                    da_schema.validate(coords[coord_name], child_context)
@_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate])
class DataArraySchema(BaseSchema):
    """
    A lightweight xarray.DataArray validator.

    Parameters
    ----------
    dtype : DTypeLike or str or DTypeSchema, optional
        Data type validation schema. If a string is specified, it must be a
        valid NumPy data type value.

    shape : ShapeT or tuple or ShapeSchema, optional
        Shape validation schema.

    dims : DimsT or list of str or DimsSchema, optional
        Dimensions validation schema.

    name : str, optional
        Name validation schema.

    coords : CoordsSchema, optional
        Coordinates validation schema.

    chunks : bool or dict or ChunksSchema, optional
        If bool, specifies whether the DataArray is chunked or not, agnostic to
        chunk sizes. If dict, includes the expected chunks for the DataArray.

    attrs : AttrsSchema, optional
        Attributes validation schema.

    array_type : type, optional
        Type of the underlying data in a DataArray (*e.g.*
        :class:`numpy.ndarray`).

    checks : list of callables, optional
        List of callables that will further validate the DataArray.
    """

    # Names of the component fields visited by ``serialize``.
    _schema_slots: ClassVar = [
        "dtype",
        "dims",
        "shape",
        "coords",
        "name",
        "chunks",
        "attrs",
        "array_type",
    ]

    dtype: Optional[DTypeSchema] = _attrs.field(
        default=None,
        converter=_attrs.converters.optional(DTypeSchema.convert),
    )
    shape: Optional[ShapeSchema] = _attrs.field(
        default=None,
        converter=_attrs.converters.optional(ShapeSchema.convert),
    )
    dims: Optional[DimsSchema] = _attrs.field(
        default=None,
        converter=_attrs.converters.optional(DimsSchema.convert),
    )
    name: Optional[NameSchema] = _attrs.field(
        default=None,
        converter=_attrs.converters.optional(NameSchema.convert),
    )
    coords: Optional[CoordsSchema] = _attrs.field(
        default=None,
        converter=_attrs.converters.optional(CoordsSchema.convert),
    )
    chunks: Optional[ChunksSchema] = _attrs.field(
        default=None,
        converter=_attrs.converters.optional(ChunksSchema.convert),
    )
    attrs: Optional[AttrsSchema] = _attrs.field(
        default=None,
        converter=_attrs.converters.optional(AttrsSchema.convert),
    )
    array_type: Optional[ArrayTypeSchema] = _attrs.field(
        default=None,
        converter=_attrs.converters.optional(ArrayTypeSchema.convert),
    )
    checks: List[Callable] = _attrs.field(
        factory=list,
        validator=_attrs.validators.deep_iterable(_attrs.validators.is_callable()),
    )

    def serialize(self) -> dict:
        """
        Return a dictionary representation of this schema.

        Returns
        -------
        dict
            Mapping from component name to that component's serialized form.
            Components that are unset (``None``) are omitted; ``checks`` is
            not part of the output.
        """
        obj = {}
        for slot in self._schema_slots:
            component = getattr(self, slot)
            # Test for None explicitly instead of catching AttributeError, so
            # that an AttributeError raised inside a component's own
            # serialize() is not silently discarded.
            if component is not None:
                obj[slot] = component.serialize()
        return obj

    @classmethod
    def deserialize(cls, obj: dict):
        """
        Build a ``DataArraySchema`` from a dictionary.

        Parameters
        ----------
        obj : dict
            Mapping whose recognised keys are ``dtype``, ``shape``, ``dims``,
            ``name``, ``coords``, ``chunks``, ``array_type`` and ``attrs``.
            Any other key is ignored.

        Returns
        -------
        DataArraySchema
        """
        converters = {
            "dtype": DTypeSchema.convert,
            "shape": ShapeSchema.convert,
            "dims": DimsSchema.convert,
            "name": NameSchema.convert,
            "coords": CoordsSchema.convert,
            "chunks": ChunksSchema.convert,
            "array_type": ArrayTypeSchema.convert,
            "attrs": AttrsSchema.convert,
        }
        kwargs = {
            key: convert(obj[key]) for key, convert in converters.items() if key in obj
        }
        return cls(**kwargs)

    @classmethod
    def from_dataarray(cls, value: xr.DataArray):
        """
        Build a schema describing an existing DataArray.

        Parameters
        ----------
        value : DataArray
            DataArray whose metadata (as returned by
            ``value.to_dict(data=False)``) is turned into a schema.

        Returns
        -------
        DataArraySchema
        """
        da_schema = value.to_dict(data=False)
        # Wrap the coordinate and attribute mappings under their own key
        # before handing the dictionary to ``deserialize``.
        da_schema["coords"] = {"coords": da_schema["coords"]}
        da_schema["attrs"] = {"attrs": da_schema["attrs"]}
        return cls.deserialize(da_schema)

    def validate(
        self,
        da: xr.DataArray,
        context: ValidationContext | None = None,
        mode: Literal["eager", "lazy"] | None = None,
    ) -> ValidationResult | None:
        """
        Validate an xarray.DataArray against this schema.

        Parameters
        ----------
        da : DataArray
            DataArray to validate.

        context : ValidationContext, optional
            Validation context for tracking tree traversal state. If unset, a
            new context is created from ``mode``.

        mode : {"eager", "lazy"}, optional
            Validation mode. If unset, the global default mode (eager) is used.
            Only used when ``context`` is not supplied.

        Returns
        -------
        ValidationResult or None
            In eager mode, this method returns ``None``. In lazy mode, it
            returns a :class:`ValidationResult` object.

        Raises
        ------
        ValueError
            If ``da`` is not an :class:`xarray.DataArray`.

        Notes
        -----
        The callables in ``checks`` are invoked directly with ``da``; any
        exception they raise propagates to the caller.
        """
        if mode is None:
            mode = "eager"

        if context is None:
            context = ValidationContext(mode=mode)

        if not isinstance(da, xr.DataArray):
            raise ValueError("Input must be an xarray.DataArray")

        if self.dtype is not None:
            dtype_context = context.push("dtype")
            self.dtype.validate(da.dtype, dtype_context)

        if self.name is not None:
            name_context = context.push("name")
            self.name.validate(da.name, name_context)

        if self.dims is not None:
            dims_context = context.push("dims")
            self.dims.validate(da.dims, dims_context)

        if self.shape is not None:
            shape_context = context.push("shape")
            self.shape.validate(da.shape, shape_context)

        if self.coords is not None:
            coords_context = context.push("coords")
            self.coords.validate(da.coords, coords_context)

        if self.chunks is not None:
            chunks_context = context.push("chunks")
            self.chunks.validate(da.chunks, da.dims, da.shape, chunks_context)

        # Same "is not None" test as every other component, so whether the
        # attrs schema runs does not depend on the schema object's truthiness.
        if self.attrs is not None:
            attrs_context = context.push("attrs")
            self.attrs.validate(da.attrs, attrs_context)

        if self.array_type is not None:
            array_type_context = context.push("array_type")
            self.array_type.validate(da.data, array_type_context)

        for check in self.checks:
            check(da)

        return None if context.mode is ValidationMode.EAGER else context.result