Source code for xarray_validate.dataset

from __future__ import annotations

from typing import Callable, Dict, Iterable, Literal, Optional, Union

import attrs as _attrs
import xarray as xr

from . import _match
from .base import (
    BaseSchema,
    SchemaError,
    ValidationContext,
    ValidationMode,
    ValidationResult,
)
from .components import AttrsSchema
from .dataarray import CoordsSchema, DataArraySchema


[docs] @_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate]) class DatasetSchema(BaseSchema): r""" A lightweight xarray.Dataset validator. Parameters ---------- data_vars : dict, optional Per-variable :class:`.DataArraySchema`\ s. Keys can be either exact variable names or patterns: - Exact match: ``'temperature'`` matches only 'temperature' - Glob pattern: ``'x_*'`` matches x_0, x_1, x_foo, etc. - Regex pattern: ``'{x_\\d+}'`` matches x_0, x_1, but not x_foo require_all_keys : bool, default: True Whether to require all data variables included in ``data_vars``. Only applies to exact keys, not pattern keys. allow_extra_keys : bool, default: True Whether to allow data variables not included in ``data_vars`` dict. Variables matching pattern keys are not considered "extra". coords : CoordsSchema, optional Coordinate validation schema. attrs : AttrsSchema, optional Attributes value validation schema. checks : list of callables, optional List of callables that will further validate the Dataset. """ data_vars: Optional[Dict[str, Optional[DataArraySchema]]] = _attrs.field( default=None, converter=_attrs.converters.optional( lambda x: { k: v if isinstance(v, DataArraySchema) else DataArraySchema(**v) for k, v in x.items() } ), ) require_all_keys: bool = _attrs.field(default=True) allow_extra_keys: bool = _attrs.field(default=True) coords: Union[CoordsSchema, None] = _attrs.field( default=None, converter=_attrs.converters.optional(CoordsSchema.convert) ) attrs: Union[AttrsSchema, None] = _attrs.field( default=None, converter=_attrs.converters.optional( lambda x: x if isinstance(x, AttrsSchema) else AttrsSchema(x) ), ) checks: Iterable[Callable] = _attrs.field( factory=list, validator=_attrs.validators.deep_iterable(_attrs.validators.is_callable()), )
[docs] def serialize(self): obj = { "require_all_keys": self.require_all_keys, "allow_extra_keys": self.allow_extra_keys, "data_vars": {}, "attrs": self.attrs.serialize() if self.attrs is not None else {}, } if self.data_vars: for key, var in self.data_vars.items(): obj["data_vars"][key] = var.serialize() if self.coords: obj["coords"] = self.coords.serialize() return obj
[docs] @classmethod def deserialize(cls, obj: dict): kwargs = {} if "require_all_keys" in obj: kwargs["require_all_keys"] = obj["require_all_keys"] if "allow_extra_keys" in obj: kwargs["allow_extra_keys"] = obj["allow_extra_keys"] if "data_vars" in obj: kwargs["data_vars"] = { k: DataArraySchema.convert(v) for k, v in obj["data_vars"].items() } if "coords" in obj: kwargs["coords"] = CoordsSchema.convert(obj["coords"]) if "attrs" in obj: kwargs["attrs"] = AttrsSchema.convert(obj["attrs"]) return cls(**kwargs)
@classmethod def from_dataset(cls, value): ds_schema = value.to_dict(data=False) return cls.deserialize(ds_schema)
[docs] def validate( self, ds: xr.Dataset, context: ValidationContext | None = None, mode: Literal["eager", "lazy"] | None = None, ) -> ValidationResult | None: """ Validate an xarray.DataArray against this schema. Parameters ---------- ds : Dataset Dataset to validate. context : ValidationContext, optional Validation context for tracking tree traversal state. mode : {"eager", "lazy"}, optional Validation mode. If unset, the global default mode (eager) is used. Returns ------- ValidationResult or None In eager mode, this method returns ``None``. In lazy mode, it returns a :class:`ValidationResult` object. """ if mode is None: mode = "eager" if context is None: context = ValidationContext(mode=mode) if self.data_vars is not None: # Separate exact keys from pattern keys and compile patterns exact_keys, pattern_keys, compiled_patterns = _match.separate_keys( self.data_vars ) if self.require_all_keys: # Only check exact keys for require_all_keys missing_keys = set(exact_keys.keys()) - set(ds.data_vars.keys()) if missing_keys: error = SchemaError(f"data_vars has missing keys: {missing_keys}") if context: context.handle_error(error) else: raise error if not self.allow_extra_keys: # Check that all dataset variables match either exact or pattern keys matched_vars = _match.find_matched_keys( ds.data_vars, exact_keys, compiled_patterns ) extra_keys = set(ds.data_vars.keys()) - matched_vars if extra_keys: error = SchemaError(f"data_vars has extra keys: {extra_keys}") if context: context.handle_error(error) else: raise error # Validate variables matching exact keys for key, da_schema in exact_keys.items(): if da_schema is not None and key in ds.data_vars: data_var_context = context.push(f"data_vars.{key}") da_schema.validate(ds.data_vars[key], data_var_context) # Validate variables matching pattern keys for pattern_key, da_schema in pattern_keys.items(): if da_schema is None: continue regex = compiled_patterns[pattern_key] for var_name in ds.data_vars.keys(): if regex.fullmatch(var_name) and var_name not in exact_keys: data_var_context = context.push(f"data_vars.{var_name}") da_schema.validate(ds.data_vars[var_name], data_var_context) if self.coords is not None: # pragma: no cover coords_context = context.push("coords") self.coords.validate(ds.coords, coords_context) if self.attrs: attrs_context = context.push("attrs") self.attrs.validate(ds.attrs, attrs_context) if self.checks: for check in self.checks: check(ds) return None if context.mode is ValidationMode.EAGER else context.result