在 API 请求的参数校验中,使用 marshmallow 非常方便。

如果要针对大量 JSON 数据做校验,推荐使用 python-jsonschema。

1. 定义Schema

  # 定义模型
import datetime as dt


class User:
    """Plain model object used by the marshmallow examples.

    Attributes:
        name: display name.
        email: email address.
        created_at: creation timestamp; defaults to the naive local time
            (``dt.datetime.now()``) when not supplied.
    """

    def __init__(self, name, email, created_at=None):
        self.name = name
        self.email = email
        # Optional so that a post_load hook doing ``User(**data)`` can pass a
        # deserialized created_at through instead of failing with TypeError.
        self.created_at = dt.datetime.now() if created_at is None else created_at

    def __repr__(self):
        # The template used to be empty, so repr()/pprint of a User showed "".
        return "<User(name={self.name!r})>".format(self=self)

# Define the schema (declarative style)
from marshmallow import Schema, fields


class UserSchema(Schema):
    """Serializes/deserializes User objects.

    Field names must match the User attribute names: dump() reads each
    declared field as an attribute of the object being serialized.
    """

    name = fields.String()
    email = fields.Email()
    # Was ``create_at``; User stores ``created_at``, so the timestamp was
    # silently missing from dump() output.
    created_at = fields.DateTime()

2. 使用动态Schema

from marshmallow import Schema, fields

# Same schema, generated at runtime from a {field name: field instance}
# mapping instead of a class statement.
UserSchema = Schema.from_dict(
    dict(
        name=fields.Str(),
        email=fields.Email(),
        created_at=fields.DateTime(),
    )
)

3. dump(obj -> dict)

from pprint import pprint

# Serialize a model instance to a plain dict with dump().
user = User(name="Monty", email="monty@python.org")
schema = UserSchema()
result = schema.dump(user)
pprint(result)
# {'created_at': '2024-06-16T21:37:04.711994',
#  'email': 'monty@python.org',
#  'name': 'Monty'}

print(result["name"])  # Monty

4. dumps(obj -> str)

# dumps() returns a JSON string rather than a dict.
json_result = schema.dumps(user)
print(json_result)
# {"name": "Monty", "email": "monty@python.org", "created_at": "2014-08-17T14:54:16.049594"}

输出过滤

only

# only=...: serialize just the listed fields.
summary_schema = UserSchema(only=("name", "email"))
print(summary_schema.dump(user))
# {'name': 'Monty', 'email': 'monty@python.org'}

exclude

# exclude=...: serialize every field except the listed ones.
summary_schema = UserSchema(exclude=("email",))
s = summary_schema.dump(user)
print(s)
# {'name': 'Monty', 'created_at': '2024-06-16T21:43:28.035327'}

5. load

Deserializing -> dict (校验)

load 时,fields.DateTime() 字段会把 ISO 8601 格式的字符串反序列化为 Python 的 datetime 对象。

# Raw input, e.g. a parsed JSON request body.
user_data = {
        "created_at": "2014-08-11T05:26:03.869245",
        "email": "ken@yahoo.com",
        "name": "Ken",
}
schema = UserSchema()
# load() validates user_data and deserializes it; the ISO string in
# created_at comes back as a datetime (see the output below).
result = schema.load(user_data)
pprint(result)
# {'created_at': datetime.datetime(2014, 8, 11, 5, 26, 3, 869245),
#  'email': 'ken@yahoo.com',
#  'name': 'Ken'}

print(type(result))  # <class 'dict'>

Deserializing -> obj

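为了将数据反序列化为对象,请为模式定义一个方法,并用 post_load 对其进行装饰。该方法接收反序列化后的数据字典,它的返回值就是 load 的结果。注意:如果直接 User(**data),data 中的每个键(包括 created_at)都会作为关键字参数传给 User,因此 User.__init__ 必须能接收这些参数,否则会抛出 TypeError。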

from marshmallow import Schema, fields, post_load


class UserSchema(Schema):
    """Schema whose load() returns a User instance instead of a dict."""

    name = fields.Str()
    email = fields.Email()
    created_at = fields.DateTime()

    @post_load
    def make_user(self, data, **kwargs):
        """Build a User from the validated data.

        created_at is assigned after construction rather than passed through
        ``User(**data)``, so loading input that contains created_at works
        whether or not ``User.__init__`` accepts a created_at argument.
        """
        attrs = dict(data)
        created_at = attrs.pop("created_at", None)
        user = User(**attrs)
        if created_at is not None:
            user.created_at = created_at
        return user

此时 load 方法返回的是一个 User 实例:

  # Deserialize straight into a User: the post_load hook builds the object.
user_data = {
    "created_at": "2014-08-11T05:26:03.869245",
    "email": "ken@yahoo.com",
    "name": "Ken",
}
schema = UserSchema()
result = schema.load(user_data)
pprint(result)  # the User instance, printed through User.__repr__
print(result.email)  # ken@yahoo.com

6. 处理集合对象

当处理可迭代的集合对象时,设置 `many=True`。

dump

user1 = User(name="Mick", email="mick@tones.com")
user2 = User(name="Keith", email="keith@stones.com")
users = [user1, user2]
# many=True: the input is a list and each element is serialized.
schema = UserSchema(many=True)
result = schema.dump(users)  # OR UserSchema().dump(users, many=True)
pprint(result)

# [{'created_at': '2024-06-18T21:29:14.942278',
#   'email': 'mick@tones.com',
#   'name': 'Mick'},
#  {'created_at': '2024-06-18T21:29:14.942287',
#   'email': 'keith@stones.com',
#   'name': 'Keith'}]

load

# A list of serialized users, loaded back in one call with many=True.
s = [
    {
        'created_at': '2024-06-18T21:29:14.942278',
        'email': 'mick@tones.com',
        'name': 'Mick',
    },
    {
        'created_at': '2024-06-18T21:29:14.942287',
        'email': 'keith@stones.com',
        'name': 'Keith',
    },
]

schema = UserSchema(many=True)
result = schema.load(s)  # equivalent: UserSchema().load(s, many=True)
print(result)

7. Validation校验

校验单个对象

from marshmallow import ValidationError

try:
    result = UserSchema().load({"name": "John", "email": "foo"})
except ValidationError as err:
    # The error carries the per-field messages and the fields that did pass.
    print(err.messages)  # => {'email': ['Not a valid email address.']}
    print(err.valid_data)  # => {'name': 'John'}

校验集合

from marshmallow import Schema, fields, post_load, ValidationError

class UserSchema(Schema):
    """name is mandatory on load; email, when given, must be a valid address."""

    name = fields.Str(required=True)
    email = fields.Email()

# Four records: the second has a malformed email, the fourth omits the
# required name.
user_data = [
    {"email": "mick@stones.com", "name": "Mick"},
    {"email": "invalid", "name": "Invalid"},  # invalid email
    {"email": "keith@stones.com", "name": "Keith"},
    {"email": "charlie@stones.com"},  # missing "name"
]

try:
    UserSchema(many=True).load(user_data)
except ValidationError as err:
    # With many=True the errors are keyed by the index of the failing item.
    pprint(err.messages)
    # {1: {'email': ['Not a valid email address.']},
    #  3: {'name': ['Missing data for required field.']}}

8. 添加校验参数

from pprint import pprint
from marshmallow import Schema, fields, validate, ValidationError

class UserSchema(Schema):
    """Field-level constraints expressed with marshmallow's built-in validators."""

    name = fields.String(validate=validate.Length(min=1))  # no empty names
    permission = fields.String(
        validate=validate.OneOf(["read", "write", "admin"]),
    )
    age = fields.Integer(validate=validate.Range(min=18, max=40))  # inclusive

# Every field breaks its validator, so all three errors are reported together.
in_data = {"name": "", "permission": "invalid", "age": 71}
try:
    UserSchema().load(in_data)
except ValidationError as err:
    pprint(err.messages)
    # {'age': ['Must be greater than or equal to 18 and less than or equal to 40.'],
    #  'name': ['Shorter than minimum length 1.'],
    #  'permission': ['Must be one of: read, write, admin.']}

9. 自定义校验规则

校验函数接收一个参数(字段的值)。校验失败时应当 raise 一个 ValidationError,或者返回 False(返回 False 时 marshmallow 使用默认错误信息 'Invalid value.')。

使用函数

from marshmallow import Schema, fields, ValidationError
  			  
def validate_quantity(n):
    """Validate a quantity value for a marshmallow field.

    Raises ValidationError for negative values and for values in (30, 60].
    Returns False, which marshmallow treats as a failed validation, for
    values above 60. Returns None (passes) otherwise.
    """
    if n < 0:
        # Only negatives are rejected (0 passes), so the message must not
        # claim "greater than 0".
        raise ValidationError("Quantity must not be negative.")
    if 30 < n <= 60:
        raise ValidationError("Quantity must not be greater than 30.")
    if n > 60:
        # The alternative way to signal failure: return False.
        return False

class ItemSchema(Schema):
    """A single integer field checked by the custom validate_quantity function."""

    quantity = fields.Int(validate=validate_quantity)

# 50 falls in (30, 60], so validate_quantity raises the "not greater than 30" error.
in_data = {"quantity": 50}
try:
    result = ItemSchema().load(in_data)
except ValidationError as err:
    print(err.messages)  # {'quantity': ['Quantity must not be greater than 30.']}

也可以给定一个校验函数的集合(list, tuple, generator)

def validate_quantity_2(n):
    """Reject even quantities with ValidationError; odd ones pass (returns None)."""
    is_even = n % 2 == 0
    if is_even:
        raise ValidationError("The quantity must not be even number")

class ItemSchema(Schema):
    """quantity runs every validator in the list and collects all failures."""

    quantity = fields.Integer(
        validate=[validate_quantity, validate_quantity_2],
    )

# {'quantity': ['Quantity must not be greater than 30.', 'The quantity must not be even number']}
  			  

使用装饰器(数据预处理/后处理与校验)

from marshmallow import post_dump, post_load, pre_load, validates, validates_schema


class UserSchema(Schema):
    """Example of the processing and validation decorators.

    On load: remove_envelope unwraps the input, the fields are deserialized
    and validate_age runs, then validate_email checks the whole record, and
    finally lowerstrip_email normalizes the email.
    On dump: add_envelope wraps the serialized output.
    """

    email = fields.Str(required=True)
    age = fields.Integer(required=True)

    @post_load
    def lowerstrip_email(self, item, many, **kwargs):
        """Lower-case and strip the email of each loaded item."""
        item["email"] = item["email"].lower().strip()
        return item

    @pre_load(pass_many=True)
    def remove_envelope(self, data, many, **kwargs):
        """Unwrap input sent as {"result": {...}} or, with many=True, {"results": [...]}."""
        namespace = "results" if many else "result"
        return data[namespace]

    @post_dump(pass_many=True)
    def add_envelope(self, data, many, **kwargs):
        """Wrap dump output under "result" (single) or "results" (many)."""
        namespace = "results" if many else "result"
        return {namespace: data}

    @validates_schema
    def validate_email(self, data, **kwargs):
        """Whole-record check: the email must be at least 3 characters long."""
        # NOTE: this runs before lowerstrip_email (post_load), so surrounding
        # whitespace still counts toward the length here.
        if len(data["email"]) < 3:
            # A 3-character email passes, so "more than 3" was misleading.
            raise ValidationError("Email must be at least 3 characters", "email")

    @validates("age")
    def validate_age(self, data, **kwargs):
        """Reject ages below 14."""
        if data < 14:
            raise ValidationError("Too young!")

注意

  • 验证在反序列化时进行,但不在序列化时进行。为提高序列化性能,传递给 Schema.dump() 的数据被视为有效数据。

10. 全局校验

有时候我们需要使用传入的数据做一次整体的校验,这时候需要使用到装饰器 validates_schema

class CreateInstanceSchema(Ec2InstanceSchema, AccountIdSchema):
    """Instance-creation request with cross-field checks on the IP options."""

    count = fields.Int(required=True, validate=validate.Range(min=1, max=10))
    associate_public_ip = fields.Bool(required=False)
    public_ip = fields.Str(required=False)
    private_ip = fields.Str(required=False)

    @validates_schema
    def mutex(self, data, **kwargs):
        """Reject incompatible combinations of count and IP options."""
        pins_an_ip = data.get("private_ip") or data.get("public_ip")
        # A specific address can only be pinned when exactly one machine is requested.
        if data.get("count") != 1 and pins_an_ip:
            raise ValidationError(
                "You cannot provide private_ip or public_ip for multiple machines at the same time"
            )
        # NOTE(review): this also fires when associate_public_ip is omitted,
        # not only when it is explicitly False — confirm that is intended.
        if data.get("public_ip") and not data.get("associate_public_ip"):
            raise ValidationError(
                "You cannot provide associate_public_ip=False and public_ip at the same time"
            )

11. 动态字段

一般动态字段

from marshmallow import Schema, fields

# from_dict() builds a Schema class from a {name: field} mapping at runtime.
PersonSchema = Schema.from_dict({"name": fields.Str()})
print(PersonSchema().load({"name": "David"}))  # => {'name': 'David'}

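字段名(key)是动态的。下面的示例根据 YAML 中的 key 动态生成 Schema;如果只是 key 不固定、value 结构固定,也可以直接使用 fields.Dict(keys=fields.Str(), values=fields.Nested(PortSchema)),无需动态生成类。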

import yaml
from marshmallow import Schema, fields, ValidationError

class PortSchema(Schema):
    """Port settings of a single region: two integer values."""

    firstPorts = fields.Int()
    secondPorts = fields.Int()

class RegionSchema(Schema):
    """Base for schemas whose field names are region keys known only at runtime."""

    @classmethod
    def from_regions(cls, regions):
        """Build a schema class with one nested PortSchema field per region.

        Args:
            regions: iterable of region names (e.g. the keys under ``ports``).

        Returns:
            A new Schema subclass of ``cls``; ``cls`` itself is not modified.
        """
        # The old version copied the generated class's attributes back onto
        # ``cls``. Because from_dict() subclasses ``cls``, every later call then
        # inherited the regions of earlier calls, so the accepted key set kept
        # growing. Returning a fresh class keeps each call independent.
        schema_dict = {region: fields.Nested(PortSchema()) for region in regions}
        return cls.from_dict(schema_dict)


def validate_yaml(yaml_content):
    """Parse ``yaml_content`` and validate its ``ports`` mapping.

    Each key under ``ports`` becomes a field of a schema generated for this
    call, and its value is validated with PortSchema.

    Returns:
        The data as deserialized by ``Schema.load``.

    Raises:
        ValidationError: the document does not match the schema (including a
            missing, empty or non-mapping ``ports``).
        yaml.YAMLError: the content is not valid YAML.
    """
    try:
        # safe_load: never construct arbitrary Python objects from the input.
        data = yaml.safe_load(yaml_content)
        ports = data.get("ports") if isinstance(data, dict) else None
        # Only a mapping has region keys. Anything else (empty document, list,
        # scalar) is left for the schema to reject with a ValidationError
        # instead of crashing here with TypeError/KeyError.
        regions = list(ports) if isinstance(ports, dict) else []
        region_schema = RegionSchema.from_regions(regions)

        class RootSchema(Schema):
            # required: a document without "ports" used to fail with KeyError;
            # report it as a validation error instead.
            ports = fields.Nested(region_schema(), required=True)

        return RootSchema().load(data)
    except ValidationError as err:
        print(f"验证错误: {err.messages}")
        raise
    except yaml.YAMLError as err:
        print(f"YAML解析错误: {err}")
        raise


if __name__ == "__main__":
    # Demo input: "ports" maps each region name to its port settings; the
    # region names become schema fields at runtime via RegionSchema.from_regions.
    yaml_data = """
ports:
  us-west-2:
    firstPorts: 2
    secondPorts: 3
  us-east-2:
    firstPorts: 1
    secondPorts: 2
    """
    s = validate_yaml(yaml_data)
    print(s)