import re
import dataclasses as dc
import importlib
import copy
from yankee.util import is_valid, AttrDict, clean_whitespace, unzip_records, import_class
from yankee.data import Row, AttrDict
from .deserializer import Deserializer, DefaultMeta
from .accessor import python_accessor
from yankee.data.collection import ListCollection, Collection
class Schema(Deserializer):
    """Deserializer composed of named child deserializers ("fields").

    Fields are the ``Deserializer`` instances found as class attributes on
    this class and its bases. ``deserialize`` loads every field from the
    accessed object and gathers the valid results into an ``AttrDict``.

    Class attributes:
        output_type: Output type reported for this deserializer (``dict``).
        __model__: Model class used by ``load_model``. When it is ``None`` or
            a string, ``get_model`` replaces it on the instance at bind time.
    """

    output_type = dict
    __model__ = None

    def __init__(self, *args, flatten=False, prefix=False, **kwargs):
        """Store the schema options, initialise the base class, then bind.

        Args:
            *args: Forwarded to the base class initialiser.
            flatten: Stored on the instance. A parent schema reads this
                attribute in ``deserialize`` to decide whether this schema's
                dict output is merged into the parent's output.
            prefix: When truthy, child fields are bound under
                ``"<self.name>_<field name>"`` instead of the bare field name.
            **kwargs: Forwarded to the base class initialiser.
        """
        self.flatten = flatten
        self.prefix = prefix
        super().__init__(*args, **kwargs)
        self.bind()

    def bind(self, name=None, parent=None, meta=None):
        """Bind this schema, then collect and bind its fields and its model.

        Args:
            name: Forwarded to the base class ``bind``.
            parent: Forwarded to the base class ``bind``.
            meta: Accepted but not used by this method.
        """
        super().bind(name, parent)
        # Collected at bind time so fields declared on superclasses are included.
        self.fields = self.get_fields()
        self.bind_fields()
        self.get_model()

    def bind_fields(self, meta=None):
        """Bind every field to this schema.

        Args:
            meta: Passed through to each field's ``bind`` call.
        """
        for name, field in self.fields.items():
            bound_name = f"{self.name}_{name}" if self.prefix else name
            field.bind(bound_name, self, meta)

    def get_fields(self):
        """Return a dict of the ``Deserializer`` class attributes in the MRO.

        The MRO is walked from the most basic class to the most derived one,
        so an attribute redefined on a subclass replaces the inherited value.
        """
        class_fields = {}
        for klass in reversed(self.__class__.mro()):
            class_fields.update(
                (key, value)
                for key, value in klass.__dict__.items()
                if isinstance(value, Deserializer)
            )
        return class_fields

    def get_model(self):
        """Resolve ``__model__`` for this instance if it is not already a class.

        Nothing happens when ``__model__`` is set to a non-string value.
        Otherwise the model name is ``__model_name__`` when that attribute is
        truthy, else this class's name with ``"Schema"`` removed. The name is
        looked up in the module whose name is this class's module name with
        ``".schema"`` replaced by ``".model"``. If the import or the lookup
        fails, a dataclass generated by ``make_dataclass`` is used instead.

        NOTE(review): a string value of ``__model__`` and any dotted prefix
        in the model name are not consulted; only the final dotted component
        is used. Confirm whether that is intended.
        """
        # Model class is expressly there
        if self.__model__ is not None and not isinstance(self.__model__, str):
            return
        # Model should be obtained from a string
        model_str = (
            getattr(self, "__model_name__", False)
            or self.__class__.__name__.replace("Schema", "")
        )
        model_name = model_str.rsplit(".", 1)[-1]
        module_name = self.__class__.__module__.replace(".schema", ".model")
        try:
            self.__model__ = getattr(importlib.import_module(module_name), model_name)
        except (ImportError, AttributeError):
            self.__model__ = self.make_dataclass()

    def deserialize(self, obj) -> "Dict":
        """Load every field from ``obj`` and return the results as an ``AttrDict``.

        Args:
            obj: Raw object; it is passed through ``self.accessor`` first.

        Returns:
            An ``AttrDict`` keyed by each field's ``output_name``. Values that
            fail ``is_valid`` are left out. A dict value from a field whose
            ``flatten`` attribute is truthy is merged into the output rather
            than nested under the field's name.
        """
        output = AttrDict()
        obj = self.accessor(obj)
        for field in self.fields.values():
            value = field.load(obj)
            # If there is no value, don't include anything in the output dictionary
            if not is_valid(value):
                continue
            if isinstance(value, dict) and getattr(field, "flatten", False):
                # Merge in flattened fields
                output.update(value)
            else:
                output[field.output_name] = value
        return output

    def load_model(self, obj):
        """Instantiate ``__model__`` with the items of ``obj`` as keyword arguments."""
        return self.__model__(**obj)

    def make_field(self, t):
        """Return the dataclass field definition for an output type.

        Args:
            t: The field's output type.

        Returns:
            A ``dataclasses.field`` whose default is a new ``ListCollection``
            when ``t`` is ``list``, and ``None`` otherwise.
        """
        if t == list:
            return dc.field(default_factory=ListCollection)
        return dc.field(default=None)

    def make_dataclass(self):
        """Build a ``Row``-based dataclass with one attribute per schema field.

        The dataclass is named after this class with ``"Schema"`` removed.
        """
        fields = [
            (f.output_name, f.output_type, self.make_field(f.output_type))
            for f in self.fields.values()
        ]
        return dc.make_dataclass(
            cls_name=self.__class__.__name__.replace("Schema", ""),
            fields=fields,
            bases=(Row,),
        )

    def load_batch(self, objs):
        """Call ``self.load`` on each item of ``objs``; return a ``ListCollection``."""
        return ListCollection(self.load(o) for o in objs)
class PolymorphicSchema(Schema):
    """Schema that hands each object to one of several candidate schemas.

    ``bind`` iterates ``self.schemas``, which is not defined in this class
    and must be supplied by the subclass. Subclasses must also implement
    ``choose_schema``.
    """

    def bind(self, name=None, parent=None, meta=None):
        """Bind this schema, then bind every candidate schema under ``name``.

        Args:
            name: Name for this schema; also passed to each candidate schema.
            parent: Parent deserializer, forwarded to ``Schema.bind``.
            meta: Forwarded to ``Schema.bind``.
        """
        # ``super().bind`` is already bound to ``self``; passing ``self``
        # again would shift it into ``name`` and ``name`` into ``parent``.
        super().bind(name, parent, meta)
        for schema in self.schemas:
            schema.bind(name)

    def choose_schema(self, obj):
        """Return the schema to use for ``obj``; subclasses must override.

        Args:
            obj: Output of this schema's own fields for the raw object.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Must be implemented in subclass!")

    def deserialize(self, raw_obj) -> "Dict":
        """Deserialize ``raw_obj`` with the schema chosen for it.

        This schema's own fields are deserialized first to produce the
        selector object handed to ``choose_schema``. The chosen schema then
        deserializes the original ``raw_obj``.

        Args:
            raw_obj: The raw object to deserialize.

        Returns:
            Whatever the chosen schema's ``deserialize`` returns.
        """
        selector = super().deserialize(raw_obj)
        schema = self.choose_schema(selector)
        return schema.deserialize(raw_obj)
class RegexSchema(Schema):
    """Schema whose fields read from the named groups of a regex match.

    The accessed value is converted to a string, whitespace-cleaned and
    searched with the pattern in ``__regex__``. The match's group dict is
    then treated as the object the fields load from.
    """

    __regex__ = None

    def __init__(self, *args, **kwargs):
        """Compile ``__regex__`` and then run the normal schema initialisation."""
        # Compiled first because the base initialiser calls ``bind``.
        self._regex = re.compile(self.__regex__)
        super().__init__(*args, **kwargs)

    def deserialize(self, obj):
        """Match the pattern against ``obj`` and load fields from its groups.

        Args:
            obj: Raw object; it is passed through ``self.accessor`` first.

        Returns:
            An empty ``dict`` when the accessed value is ``None`` or the
            pattern does not match. Otherwise an ``AttrDict`` keyed by each
            field's ``output_name``, leaving out values that fail
            ``is_valid`` and merging dict values of fields whose ``flatten``
            attribute is truthy.
        """
        accessed = self.accessor(obj)
        if accessed is None:
            return {}
        match = self._regex.search(clean_whitespace(self.to_string(accessed)))
        if match is None:
            return {}
        groups = self.convert_groupdict(match.groupdict())
        output = AttrDict()
        for field in self.fields.values():
            value = field.load(groups)
            # Fields with no valid value are left out of the output entirely.
            if not is_valid(value):
                continue
            if isinstance(value, dict) and getattr(field, "flatten", False):
                # Flattened dict values are merged in rather than nested.
                output.update(value)
            else:
                output[field.output_name] = value
        return output

    def bind_fields(self, meta=None):
        """Bind the fields with ``DefaultMeta``; the ``meta`` argument is ignored."""
        return super().bind_fields(DefaultMeta)

    def convert_groupdict(self, obj):
        """Hook for adjusting the match group dict; returns it unchanged here."""
        return obj

    def to_string(self, elem):
        """Return the text to search, which is ``str(elem)`` by default."""
        return str(elem)
class ZipSchema(Schema):
    """Schema whose multi-valued fields are zipped together into records.

    Each field is wrapped in the list-field class named by ``list_field``.
    The result of deserializing the wrapped fields is passed through
    ``unzip_records``.
    """

    def __init__(self, *args, **kwargs):
        """Resolve ``list_field`` with ``import_class`` and then initialise.

        ``list_field`` is not defined in this class.
        NOTE(review): presumably a subclass sets it to an import path; confirm.
        """
        # Resolved first because the base initialiser calls ``bind``.
        self.list_field = import_class(self.list_field)
        super().__init__(*args, **kwargs)

    def bind(self, name=None, parent=None, meta=None):
        """Bind as a normal schema, then wrap each field in a list field.

        Args:
            name: Forwarded to ``Schema.bind``.
            parent: Forwarded to ``Schema.bind``.
            meta: Accepted but not used by this method.
        """
        super().bind(name, parent)
        wrapped_fields = {}
        for field_name, field in self.fields.items():
            inner = copy.deepcopy(field)
            # The wrapper receives the original data key ...
            wrapper = self.list_field(inner, field.data_key)
            # ... and the copied field has its own key cleared and its
            # accessor rebuilt.
            inner.data_key = False
            inner.make_accessor()
            wrapper.output_name = field.output_name
            wrapped_fields[field_name] = wrapper
        self.fields = wrapped_fields

    def deserialize(self, obj) -> "Dict":
        """Deserialize as a normal schema and pass the result to ``unzip_records``."""
        columns = super().deserialize(obj)
        return unzip_records(columns)

    def load_model(self, obj):
        """Return a list with one ``__model__`` instance per record in ``obj``."""
        return [self.__model__(**record) for record in obj]