在 api 请求的参数校验中,使用 marshmallow 非常方便

如果要针对大量json数据做检验,推荐使用 python-jsonschema

1. 定义Schema

  # 定义模型
import datetime as dt


class User:
    """Plain model object used as the (de)serialization target in the examples.

    Args:
        name: Display name of the user.
        email: Email address of the user.
        created_at: Creation timestamp. When omitted, the current local time
            is used, so ``User(name, email)`` keeps working as before.
    """

    def __init__(self, name, email, created_at=None):
        self.name = name
        self.email = email
        # Accepting created_at lets ``User(**data)`` work for payloads that
        # already carry the timestamp.
        self.created_at = dt.datetime.now() if created_at is None else created_at

    def __repr__(self):
        # The template was empty before, so repr() always returned "".
        return "<User(name={self.name!r})>".format(self=self)

# 定义Schema
class UserSchema(Schema):
    """Serialization schema for ``User``; each field name matches a ``User`` attribute."""

    name = fields.String()
    email = fields.Email()
    # Spelled like the ``User.created_at`` attribute so dump() can read it.
    created_at = fields.DateTime()

2. 使用动态Schema

from marshmallow import Schema, fields

# Build the schema class at runtime from a mapping of field names to fields.
UserSchema = Schema.from_dict(
    {
        "name": fields.Str(),
        "email": fields.Email(),
        "created_at": fields.DateTime(),
    }
)

3. dump(obj -> dict)

from pprint import pprint

# Serialize a User instance into a plain dict.
user = User(name="Monty", email="monty@python.org")
schema = UserSchema()
result = schema.dump(user)
pprint(result)
# {'created_at': '2024-06-16T21:37:04.711994',
#  'email': 'monty@python.org',
#  'name': 'Monty'}

print(result["name"])  # Monty

4. dumps(obj -> str)

# dumps() produces a JSON string instead of a dict.
json_result = schema.dumps(user)
pprint(json_result)
# '{"name": "Monty", "email": "monty@python.org", "created_at": "2014-08-17T14:54:16.049594+00:00"}'

输出过滤

only

# only= keeps just the listed fields in the output.
summary_schema = UserSchema(only=("name", "email"))
summary_schema.dump(user)
# {"name": "Monty", "email": "monty@python.org"}

exclude

# exclude= drops the listed fields from the output.
summary_schema = UserSchema(exclude=("email",))
s = summary_schema.dump(user)
print(s)
# {'name': 'Monty', 'created_at': '2024-06-16T21:43:28.035327'}

5. load

Deserializing -> dict (校验)

fields.DateTime() 字段在反序列化(load)时会被转换为 Python 的 datetime 对象

# Raw input payload; created_at is still a string at this point.
user_data = {
        "created_at": "2014-08-11T05:26:03.869245",
        "email": "ken@yahoo.com",
        "name": "Ken",
}
schema = UserSchema()
# load() validates the payload and returns the deserialized data.
result = schema.load(user_data)
pprint(result)
# {'created_at': datetime.datetime(2014, 8, 11, 5, 26, 3, 869245),
#  'email': 'ken@yahoo.com',
#  'name': 'Ken'}

print(type(result))  # <class 'dict'>

Deserializing -> obj

为了将数据反序列化为对象,请为模式定义一个方法,并用 post_load 对其进行装饰。该方法接收反序列化数据的字典。

from marshmallow import Schema, fields, post_load


class UserSchema(Schema):
    """Schema whose load() result is turned into a ``User`` by a post_load hook."""

    name = fields.Str()
    email = fields.Email()
    created_at = fields.DateTime()

    @post_load
    def make_user(self, data, **kwargs):
        """Build a ``User`` from the deserialized ``data`` dict.

        Every key in ``data`` is passed to ``User`` as a keyword argument.
        """
        return User(**data)

使用load方法返回一个User实例

  # Same payload shape as the dict-based load example.
  user_data = {
      "created_at": "2014-08-11T05:26:03.869245",
      "email": "ken@yahoo.com",
      "name": "Ken",
  }
  schema = UserSchema()
  # With the post_load hook defined, load() returns what make_user() returns.
  result = schema.load(user_data)
  pprint(result)  # prints the repr of the returned User
  print(result.email) # ken@yahoo.com

6. 处理集合对象

当处理可迭代的集合对象时,设置 `many=True`

dump

user1 = User(name="Mick", email="mick@tones.com")
user2 = User(name="Keith", email="keith@stones.com")
users = [user1, user2]
# many=True makes the schema treat its input as a collection of objects.
schema = UserSchema(many=True)
result = schema.dump(users)  # OR UserSchema().dump(users, many=True)
pprint(result)

# [{'created_at': '2024-06-18T21:29:14.942278',
#   'email': 'mick@tones.com',
#   'name': 'Mick'},
#  {'created_at': '2024-06-18T21:29:14.942287',
#   'email': 'keith@stones.com',
#   'name': 'Keith'}]

load

# Two serialized users, in the same shape that dump(many=True) produces.
s = [
    {
        "created_at": "2024-06-18T21:29:14.942278",
        "email": "mick@tones.com",
        "name": "Mick",
    },
    {
        "created_at": "2024-06-18T21:29:14.942287",
        "email": "keith@stones.com",
        "name": "Keith",
    },
]

schema = UserSchema(many=True)
result = schema.load(s)  # OR UserSchema().load(s, many=True)
print(result)

7. Validation校验

单个校验

from marshmallow import ValidationError

try:
    result = UserSchema().load({"name": "John", "email": "foo"})
except ValidationError as err:
    # messages maps each failing field to its list of error strings.
    print(err.messages)  # => {"email": ['Not a valid email address.']}
    # valid_data holds the fields that passed validation.
    print(err.valid_data)  # => {"name": "John"}

校验集合

from marshmallow import Schema, fields, post_load, ValidationError

class UserSchema(Schema):
    """Schema for the collection-validation example; ``name`` is mandatory."""

    # required=True: loading fails when the key is missing from the input.
    name = fields.String(required=True)
    email = fields.Email()

# Four input items; the ones at index 1 and index 3 are deliberately invalid.
user_data = [
    {"email": "mick@stones.com", "name": "Mick"},
    {"email": "invalid", "name": "Invalid"},  # invalid email
    {"email": "keith@stones.com", "name": "Keith"},
    {"email": "charlie@stones.com"},  # missing "name"
]

try:
    UserSchema(many=True).load(user_data)
except ValidationError as err:
    # With many=True the errors are keyed by the index of the failing item.
    pprint(err.messages)
    # {1: {'email': ['Not a valid email address.']},
    #  3: {'name': ['Missing data for required field.']}}

8. 添加校验参数

from pprint import pprint
from marshmallow import Schema, fields, validate, ValidationError

class UserSchema(Schema):
    """Schema showing the built-in validators passed through ``validate=``."""

    # Length(min=1): rejects the empty string.
    name = fields.Str(validate=validate.Length(min=1))
    # OneOf: the value must be one of the listed choices.
    permission = fields.Str(validate=validate.OneOf(["read", "write", "admin"]))
    # Range: the bounds 18 and 40 are both inclusive.
    age = fields.Int(validate=validate.Range(min=18, max=40))

# Every field violates its validator, so all three errors are reported together.
in_data = {"name": "", "permission": "invalid", "age": 71}
try:
    UserSchema().load(in_data)
except ValidationError as err:
    pprint(err.messages)
    # {'age': ['Must be greater than or equal to 18 and less than or equal to 40.'],
    #  'name': ['Shorter than minimum length 1.'],
    #  'permission': ['Must be one of: read, write, admin.']}

9. 自定义校验规则

校验函数接收一个单一的参数,校验失败应该raise一个ValidationError或者返回False

from marshmallow import Schema, fields, ValidationError
  			  
def validate_quantity(n):
    """Custom validator for a quantity value.

    Returns:
        ``False`` when ``n`` is above 60, ``None`` when ``n`` is accepted.

    Raises:
        ValidationError: when ``n`` is negative, or when ``30 < n <= 60``.
    """
    if n > 60:
        # Signals failure through the return value instead of an exception.
        return False
    if n > 30:
        raise ValidationError("Quantity must not be greater than 30.")
    if n < 0:
        raise ValidationError("Quantity must be greater than 0.")
    return None

class ItemSchema(Schema):
    """Schema whose ``quantity`` field is checked by a custom validator function."""

    # validate= accepts any callable that takes the field value.
    quantity = fields.Integer(validate=validate_quantity)

# 50 falls in the 30 < n <= 60 branch of validate_quantity, which raises.
in_data = {"quantity": 50}
try:
    result = ItemSchema().load(in_data)
except ValidationError as err:
    print(err.messages)
    # {'quantity': ['Quantity must not be greater than 30.']}

也可以给定一个校验函数的集合(list, tuple, generator)

def validate_quantity_2(n):
    """Reject even quantities.

    Raises:
        ValidationError: when ``n`` is divisible by 2.
    """
    is_odd = n % 2 != 0
    if is_odd:
        return
    raise ValidationError("The quantity must not be even number")

class ItemSchema(Schema):
    """Schema whose ``quantity`` field is checked by a list of validators."""

    quantity = fields.Integer(validate=[validate_quantity, validate_quantity_2])

# Example output when loading {"quantity": 50}; both validators fail and both messages are reported:
# {'quantity': ['Quantity must not be greater than 30.', 'The quantity must not be even number']}
  			  

注意

  • 验证在反序列化时进行,但不在序列化时进行。为提高序列化性能,传递给 Schema.dump() 的数据被视为有效数据。