在 API 请求的参数校验中,使用 marshmallow 非常方便。
如果要对大量 JSON 数据做校验,推荐使用 python-jsonschema。
1. 定义Schema
# 定义模型
import datetime as dt
class User:
    """Plain model object used throughout the marshmallow examples.

    Args:
        name: The user's display name.
        email: The user's e-mail address.
        created_at: Optional creation timestamp. When omitted it defaults to
            the current local time (a naive ``datetime``), so the original
            two-argument form ``User(name, email)`` keeps working while
            ``User(**data)`` also accepts loaded data that contains
            ``created_at``.
    """

    def __init__(self, name, email, created_at=None):
        self.name = name
        self.email = email
        self.created_at = dt.datetime.now() if created_at is None else created_at

    def __repr__(self):
        # e.g. <User(name='Monty')>
        return "<User(name={self.name!r})>".format(self=self)
# 定义Schema
from marshmallow import Schema, fields


class UserSchema(Schema):
    """Serialization schema for ``User``.

    ``dump`` reads each field from the object attribute with the same name,
    so the field names here must match the ``User`` attribute names.
    """

    name = fields.String()
    email = fields.Email()
    # Must match ``User.created_at``; a misspelt name finds no attribute and
    # the value is silently left out of the dump output.
    created_at = fields.DateTime()
2. 使用动态Schema
from marshmallow import Schema, fields
# Dynamically generated equivalent of the class-based UserSchema:
# Schema.from_dict builds a Schema subclass from a {field_name: Field} mapping.
UserSchema = Schema.from_dict(
    dict(
        name=fields.String(),
        email=fields.Email(),
        created_at=fields.DateTime(),
    )
)
3. dump(obj -> dict)
from pprint import pprint
user = User(name="Monty", email="monty@python.org")
schema = UserSchema()
# dump() serializes the object into a dict of primitive Python values.
result = schema.dump(user)
pprint(result)
# {'created_at': '2024-06-16T21:37:04.711994',
#  'email': 'monty@python.org',
#  'name': 'Monty'}
print(result["name"])  # Monty
4. dumps(obj -> str)
# dumps() serializes straight to a JSON string.
json_result = schema.dumps(user)
pprint(json_result)
# '{"name": "Monty", "email": "monty@python.org", "created_at": "2024-06-16T21:37:04.711994"}'
# created_at carries no UTC offset because User uses a naive dt.datetime.now().
输出过滤
only
# `only` whitelists the fields that get serialized.
summary_schema = UserSchema(
    only=("name", "email"),
)
summary_schema.dump(user)  # {"name": "Monty", "email": "monty@python.org"}
exclude
# `exclude` blacklists fields; every other field is serialized.
summary_schema = UserSchema(exclude=("email",))
s = summary_schema.dump(user)
print(s)
# {'name': 'Monty', 'created_at': '2024-06-16T21:43:28.035327'}
5. load
Deserializing -> dict (校验)
`fields.DateTime()` 字段在反序列化(load)时会被转换为 Python 的 `datetime` 对象
# Input as it might arrive in an API request body.
user_data = dict(
    created_at="2014-08-11T05:26:03.869245",
    email="ken@yahoo.com",
    name="Ken",
)
schema = UserSchema()
# load() validates and deserializes; DateTime strings become datetime objects.
result = schema.load(user_data)
pprint(result)
# {'created_at': datetime.datetime(2014, 8, 11, 5, 26, 3, 869245),
#  'email': 'ken@yahoo.com',
#  'name': 'Ken'}
print(type(result))  # <class 'dict'>
Deserializing -> obj
要将数据反序列化为对象,需要在 Schema 中定义一个方法,并用 `post_load` 装饰。该方法接收反序列化后的数据字典,它的返回值就是 `load()` 的返回值。
from marshmallow import Schema, fields, post_load
class UserSchema(Schema):
    """User schema whose ``load`` returns a ``User`` object instead of a dict."""

    name = fields.Str()
    email = fields.Email()
    created_at = fields.DateTime()

    @post_load
    def make_user(self, data, **kwargs):
        """Build a ``User`` from the validated, deserialized ``data`` dict.

        ``created_at`` is assigned after construction rather than passed to
        ``User(...)``. This keeps the hook working even when ``User.__init__``
        accepts only ``name`` and ``email``; passing the whole dict with
        ``User(**data)`` would then raise ``TypeError``.
        """
        init_kwargs = {key: value for key, value in data.items() if key != "created_at"}
        user = User(**init_kwargs)
        if "created_at" in data:
            user.created_at = data["created_at"]
        return user
使用load方法返回一个User实例
user_data = dict(
    created_at="2014-08-11T05:26:03.869245",
    email="ken@yahoo.com",
    name="Ken",
)
schema = UserSchema()
# The post_load hook (make_user) turns the loaded dict into a User object.
result = schema.load(user_data)
pprint(result)
print(result.email)  # ken@yahoo.com
6. 处理集合对象
当处理可迭代的集合对象时,设置 `many=True`
dump
user1 = User(name="Mick", email="mick@tones.com")
user2 = User(name="Keith", email="keith@stones.com")
users = [user1, user2]
# many=True makes dump() take an iterable and return a list of dicts.
# Equivalent one-off form: UserSchema().dump(users, many=True)
schema = UserSchema(many=True)
result = schema.dump(users)
pprint(result)
# [{'created_at': '2024-06-18T21:29:14.942278',
#   'email': 'mick@tones.com',
#   'name': 'Mick'},
#  {'created_at': '2024-06-18T21:29:14.942287',
#   'email': 'keith@stones.com',
#   'name': 'Keith'}]
load
# A list of serialized users, loaded back in a single call.
s = [
    dict(created_at="2024-06-18T21:29:14.942278", email="mick@tones.com", name="Mick"),
    dict(created_at="2024-06-18T21:29:14.942287", email="keith@stones.com", name="Keith"),
]
schema = UserSchema(many=True)  # or: UserSchema().load(s, many=True)
result = schema.load(s)
print(result)
7. Validation校验
单个校验
from marshmallow import ValidationError

try:
    result = UserSchema().load({"name": "John", "email": "foo"})
except ValidationError as err:
    # messages holds per-field errors; valid_data keeps the fields that passed.
    print(err.messages)  # => {'email': ['Not a valid email address.']}
    print(err.valid_data)  # => {'name': 'John'}
校验集合
from marshmallow import Schema, fields, post_load, ValidationError
class UserSchema(Schema):
    """Minimal user schema: ``name`` is mandatory, ``email`` optional but must be valid."""

    name = fields.Str(required=True)
    email = fields.Email()
# Items 1 (bad email) and 3 (missing "name") are invalid.
user_data = [
    dict(email="mick@stones.com", name="Mick"),
    dict(email="invalid", name="Invalid"),
    dict(email="keith@stones.com", name="Keith"),
    dict(email="charlie@stones.com"),
]
try:
    UserSchema(many=True).load(user_data)
except ValidationError as err:
    # With many=True the errors are keyed by the index of the offending item.
    pprint(err.messages)
    # {1: {'email': ['Not a valid email address.']},
    #  3: {'name': ['Missing data for required field.']}}
8. 添加校验参数
from pprint import pprint
from marshmallow import Schema, fields, validate, ValidationError
class UserSchema(Schema):
    """Schema demonstrating marshmallow's built-in validators."""

    name = fields.String(validate=validate.Length(min=1))  # must be non-empty
    permission = fields.String(validate=validate.OneOf(["read", "write", "admin"]))
    age = fields.Integer(validate=validate.Range(min=18, max=40))  # bounds are inclusive
# Every field breaks its validator, so all three errors are reported together.
in_data = dict(name="", permission="invalid", age=71)
try:
    UserSchema().load(in_data)
except ValidationError as err:
    pprint(err.messages)
    # {'age': ['Must be greater than or equal to 18 and less than or equal to 40.'],
    #  'name': ['Shorter than minimum length 1.'],
    #  'permission': ['Must be one of: read, write, admin.']}
9. 自定义校验规则
校验函数接收单个参数(待校验的字段值);校验失败时应抛出 `ValidationError`,或者返回 `False`(此时 marshmallow 使用默认错误信息 "Invalid value.")。
使用函数
from marshmallow import Schema, fields, ValidationError
def validate_quantity(n):
    """Validate an item quantity.

    Raises ``ValidationError`` for negative values and for values in the range
    (30, 60]. Returns ``False`` for values above 60, which makes marshmallow
    report its generic "Invalid value." error. Returns ``None`` otherwise.
    """
    # The three ranges do not overlap, so checking from the top down is
    # equivalent to the bottom-up order.
    if n > 60:
        return False
    if n > 30:
        raise ValidationError("Quantity must not be greater than 30.")
    if n < 0:
        raise ValidationError("Quantity must be greater than 0.")
class ItemSchema(Schema):
    """Schema whose ``quantity`` is checked by a plain validator function."""

    quantity = fields.Int(validate=validate_quantity)
in_data = dict(quantity=50)
try:
    result = ItemSchema().load(in_data)
except ValidationError as err:
    print(err.messages)  # {'quantity': ['Quantity must not be greater than 30.']}
也可以给定一个校验函数的集合(list, tuple, generator)
def validate_quantity_2(n):
    """Reject even quantities; odd quantities pass (returns ``None``)."""
    is_even = n % 2 == 0
    if is_even:
        raise ValidationError("The quantity must not be even number")
class ItemSchema(Schema):
    """``quantity`` is run through every listed validator; all failures are reported."""

    quantity = fields.Int(validate=(validate_quantity, validate_quantity_2))
    # Loading {"quantity": 50} gives:
    # {'quantity': ['Quantity must not be greater than 30.', 'The quantity must not be even number']}
使用校验装饰器
from marshmallow import (
    Schema,
    ValidationError,
    fields,
    post_dump,
    post_load,
    pre_load,
    validates,
    validates_schema,
)


class UserSchema(Schema):
    """Schema demonstrating marshmallow's processing and validation decorators.

    Loading expects the payload inside an envelope: ``{"result": {...}}`` for
    a single item, or ``{"results": [...]}`` with ``many=True``. Dumping wraps
    the output the same way.
    """

    email = fields.Str(required=True)
    age = fields.Integer(required=True)

    @post_load
    def lowerstrip_email(self, item, many, **kwargs):
        """Lower-case the e-mail and strip surrounding whitespace after validation."""
        item["email"] = item["email"].lower().strip()
        return item

    @pre_load(pass_many=True)
    def remove_envelope(self, data, many, **kwargs):
        """Unwrap the raw input from its ``result``/``results`` envelope before validation."""
        namespace = "results" if many else "result"
        return data[namespace]

    @post_dump(pass_many=True)
    def add_envelope(self, data, many, **kwargs):
        """Wrap the serialized output in a ``result``/``results`` envelope."""
        namespace = "results" if many else "result"
        return {namespace: data}

    @validates_schema
    def validate_email(self, data, **kwargs):
        """Reject e-mails shorter than 3 characters.

        The length is measured on the stripped value. ``lowerstrip_email``
        only strips after validation, so padding such as ``"  a "`` would
        otherwise pass here and then be stored as ``"a"``.
        """
        # NOTE(review): the message says "more than 3", but exactly 3 characters passes.
        if len(data["email"].strip()) < 3:
            raise ValidationError("Email must be more than 3 characters", "email")

    @validates("age")
    def validate_age(self, value, **kwargs):
        """Reject ages below 14 (``value`` is the deserialized ``age``)."""
        if value < 14:
            raise ValidationError("Too young!")
注意
- 验证在反序列化时进行,但不在序列化时进行。为提高序列化性能,传递给 Schema.dump() 的数据被视为有效数据。
10. 全局校验
当校验规则需要同时参考多个字段(例如字段之间互斥)时,可以使用装饰器 `validates_schema` 对整条输入数据做整体校验。
class CreateInstanceSchema(Ec2InstanceSchema, AccountIdSchema):
    """Instance-creation request schema with cross-field rules via ``validates_schema``."""

    count = fields.Int(required=True, validate=validate.Range(min=1, max=10))
    associate_public_ip = fields.Bool(required=False)
    public_ip = fields.Str(required=False)
    private_ip = fields.Str(required=False)

    @validates_schema
    def mutex(self, data, **kwargs):
        """Reject mutually incompatible option combinations.

        Raises:
            ValidationError: if a fixed private/public IP is requested while
                ``count`` is not 1, or if ``public_ip`` is given while
                ``associate_public_ip`` is falsy.
        """
        has_fixed_ip = data.get("private_ip") or data.get("public_ip")
        if data.get("count") != 1 and has_fixed_ip:
            raise ValidationError(
                "You cannot provide private_ip or public_ip for multiple machines at the same time"
            )
        # NOTE(review): an omitted associate_public_ip (data.get -> None) also
        # trips this check, not only an explicit False; confirm that is intended.
        if data.get("public_ip") and not data.get("associate_public_ip"):
            raise ValidationError(
                "You cannot provide associate_public_ip=False and public_ip at the same time"
            )
11. 动态字段
一般动态字段
from marshmallow import Schema, fields
# from_dict builds a Schema class on the fly from {field_name: Field}.
PersonSchema = Schema.from_dict(dict(name=fields.String()))
print(PersonSchema().load(dict(name="David")))  # => {'name': 'David'}
字段名(key)是动态的
import yaml
from marshmallow import Schema, fields, ValidationError
class PortSchema(Schema):
    """Port settings of a single region."""

    firstPorts = fields.Int()
    secondPorts = fields.Int()
class RegionSchema(Schema):
    """Base class for schemas whose field names (region keys) are only known at runtime."""

    @classmethod
    def from_regions(cls, regions):
        """Return a new ``RegionSchema`` subclass with one nested ``PortSchema`` field per region.

        Args:
            regions: Iterable of region names (e.g. ``"us-west-2"``) to accept as keys.

        Returns:
            A freshly generated Schema class. ``RegionSchema`` itself is not
            modified, so fields from one call never leak into later calls
            with different regions.
        """
        return cls.from_dict({region: fields.Nested(PortSchema()) for region in regions})


def validate_yaml(yaml_content):
    """Parse ``yaml_content`` and validate its ``ports`` section.

    The keys under ``ports`` are treated as region names. Each region must map
    to data accepted by ``PortSchema``.

    Args:
        yaml_content: YAML document text.

    Returns:
        The deserialized data returned by the root schema's ``load``.

    Raises:
        ValidationError: if the document does not match the schema (including
            a document that is empty or not a mapping).
        yaml.YAMLError: if the text is not valid YAML.
    """
    try:
        data = yaml.safe_load(yaml_content)
        # Take the dynamic field names from the input itself. Anything that is
        # not a mapping gets no region fields and is rejected by load() below
        # with a ValidationError, instead of crashing with TypeError/KeyError.
        ports = data.get("ports") if isinstance(data, dict) else None
        regions = list(ports) if isinstance(ports, dict) else []
        region_schema = RegionSchema.from_regions(regions)

        class RootSchema(Schema):
            ports = fields.Nested(region_schema())

        return RootSchema().load(data)
    except ValidationError as err:
        print(f"验证错误: {err.messages}")
        raise
    except yaml.YAMLError as err:
        print(f"YAML解析错误: {err}")
        raise
if __name__ == "__main__":
    # The region names ("us-west-2", ...) are the dynamic keys validated by
    # RegionSchema. The nesting matters: without indentation "ports" parses as
    # null and the region keys end up at the top level of the document.
    yaml_data = """
ports:
  us-west-2:
    firstPorts: 2
    secondPorts: 3
  us-east-2:
    firstPorts: 1
    secondPorts: 2
"""
    s = validate_yaml(yaml_data)
    print(s)
...