模块:Schema

来自维阿百科
跳转至: 导航搜索

此模块用于验证数据结构。契机是去年的棒球赛期间,模块:IconLink/data被填错了,过了一个多月才发现,便想弄点校验。

基础

local schema = require('Module:Schema')  -- 随便起个名,再比如scm,再比如s

local valid, msg = schema.Number:test(1)
--> true, nil
local valid, msg = schema.Number:test('啊啊啊啊啊啊啊啊啊啊可爱的字符串')
--> false, '报错信息,我忘了是啥了,等它出来了自己看'
schema.Number:assert(1)  --> true
schema.Number:assert('OwO')  --> 抛出错误
-- 后面讲的几个类型都有test、assert这两个方法。

Any

schema.Any  -- 任何值都能通过测试,是很随性的孩子

-- 约束选项有validator:
schema.Any{validator=func}
schema.Any{validator={func1, func2, ...}}

-- validator的函数为(any) -> truthy | falsy,比如判断是否为字符串、是否为非空字符串:
local function is_string(v)
	return type(v) == 'string'
end
local function is_not_empty_string(v)
	return v ~= ''
end
local NonEmptyString = schema.Any{validator={is_string, is_not_empty_string}}
NonEmptyString:test('只有非空字符串能通过')  --> true
-- 仅作示例,实际使用String更便捷。

-- 可以在已有schema的基础上继续增加约束,这对于Number、String等其他支持约束的类型也适用。
-- 还是拿NonEmptyString举例
local String = schema.Any{validator=is_string}
local NonEmptyString = String{validator=is_not_empty_string}

Nil、Boolean

schema.Nil  -- nil类型的值可以通过
schema.Boolean  -- boolean类型的值可以通过
-- 这两个无约束选项,因为Nil只能是nil,Boolean只能是true或false。
-- 如果需要“只能是true”或“只能是false”,请看Const。

Number

schema.Number:test(0721)  --> true

-- 约束选项:
-- gt(大于)、lt(小于)、ge/min(大于等于)、le/max(小于等于)、ne(不等于),
-- 以及validator(自定义约束)
-- 例如:
schema.Number{gt=0, le=100}:test(22/7)  --> true

String

schema.String:test('Ciallo ~')  --> true

-- 约束选项:
-- min_len(最短长度)、max_len(最长长度)——长度按字节计算,一个中文字符通常占3字节;pattern(Lua模式匹配,并非正则表达式),
-- 以及validator(自定义约束)。
-- 例如,Any中的例子可以改为:
local NonEmptyString = schema.String{min_len=1}
-- 例二:
local HttpUrl = schema.String{pattern='^https?://'}

Function

schema.Function:test(function() return 42 end)
-- 约束选项只有validator

Table

schema.Table  -- table类型的值可以通过

-- 示例
local CharacterInfo = schema.Table{
	name = schema.String{min_len=1},
	age = schema.Number{min=0},
	sex = 'female',  -- 等同于 sex = schema.Const('female')
}
local info1 = {
	name = '缠流子',
	age = '17',
	sex = 'female',
}
local info2 = {
	name = '满舰饰真子',
	age = 16,
	sex = 'female',
}
CharacterInfo:test(info1)  --> false, '好像会说什么age应该是数字而不是字符串'
CharacterInfo:test(info2)  --> true


-- 'validator'字段被自定义校验器占用了,
-- 若你的结构中包含这个字段,请用schema.Const('validator')作为键代替。
-- 作用对比如下:
local function check_children_num(t)
	return #t > 5
end
local array_with_6_children = {1, 2, 3, 4, 5, 6}
local table_with_validator_field = {
	validator = check_children_num
}

-- 这是为表格添加自定义校验:
schema.Table{
	validator = check_children_num
}:test(array_with_6_children)  --> true

-- 而这是设定表格validator字段的类型:
schema.Table{
	[schema.Const('validator')] = schema.Function,
}:test(table_with_validator_field)  --> true


-- 可以使用schema作为键来匹配多个字段:
local hanzi_number_conversion = {
	'一', '二', '三',
	['一'] = 1, ['二'] = 2, ['三'] = 3,
}
schema.Table{
	[schema.String] = schema.Number,
	[schema.Number] = schema.String,
}:test(hanzi_number_conversion)  --> true
-- 这个例子中,只有数字键对应的值为字符串、字符串键对应的值为数字的表才能通过测试。

Const

schema.Const(val)
-- Const是一个函数,本身不可以用作校验,它接收一个参数并返回一个schema,
-- 只有等于这个参数的值才能通过校验。
schema.Const('?'):test('!')  --> false, '好像是说这个不等于那个'
schema.Const('?'):test('?')  --> true

-- 当你传入相同的对象时,返回的schema也是同一个对象
local t = {}
local A, B = schema.Const(t), schema.Const(t)
rawequal(A, B)  --> true

Union

schema.Union(scm1, scm2, ...)
-- 这是并集:被测值只要符合参数中的任意一种就能通过。
-- 与Const一样,Union本身不可以用作校验,调用后返回的对象才可。
local QAQ = schema.Union('来测', '求你了,来测吧', schema.Boolean)
QAQ:test('来测')  --> true
QAQ:test('求你了,来测吧')  --> true
QAQ:test(false)  --> true
QAQ:test(6)  --> false, '说是union里的类型都不符合'

-- 有更方便的写法:
local WhatName = schema.String / schema.Number / nil
-- 等同于schema.Union(schema.String, schema.Number, nil)。

其他

schema.Integer  -- 用法同schema.Number,只不过这个只有整数能通过测试
schema.Callable  -- 可以调用的值,包含function和设置了__call元方法的表
schema.Truthy  -- 所有不是nil或false的值
schema.Falsy  -- nil或false

-- Cache the standard-library functions used throughout the module as locals.
local fmt = string.format
local next, type, error = next, type, error
local pairs, ipairs, assert = pairs, ipairs, assert
local getmt, setmt = getmetatable, setmetatable


-- The module table; every public schema and constructor is attached to it.
local schema = {}


--- Render a value for use in error messages.
--- Strings are quoted with `%q`; a table's rendering always starts with
--- "table", even when its `__tostring` metamethod produces something else.
---@param v any
---@return string
local function repr(v)
	local kind = type(v)
	if kind == 'string' then
		return fmt('%q', v)
	end
	if kind ~= 'table' then
		return tostring(v)
	end
	local text = tostring(v)
	-- A custom __tostring may hide the fact that the value is a table.
	if text:match('^table') then
		return text
	end
	return 'table: '..text
end


--- Indent every line after the first by two spaces, so that a nested
--- (multi-line) failure message lines up under its "- " bullet.
---@param str string
---@return string
local function indent_new_ln(str)
	-- Parenthesized to drop gsub's second result (the substitution count),
	-- which would otherwise leak into callers' argument lists.
	return (str:gsub('(\r?\n)', '%1  '))
end


--- Apply `func` to each element of the sequence `array` and collect the
--- first return value of each call into a new sequence.
---@param array any[]
---@param func fun(v: any): any
---@return any[]
local function map(array, func)
	local mapped = {}
	for index, value in ipairs(array) do
		local result = func(value)
		mapped[index] = result
	end
	return mapped
end


-- Registry of every metatable created by reg_mt (weak keys, so the registry
-- itself never keeps a metatable alive).
local schema_mts = setmt({}, {__mode = 'k'})  ---@type {[metatable]: true}

--- Create and register the metatable for a schema kind.
---@param name string                stored as `__name`; reported by get_scm_type
---@param super_mt metatable?        metatable whose `__index` is inherited
---@param without_override boolean?  share the parent's `__index` table directly
---                                  instead of layering a fresh one on top of it
---@return metatable
local function reg_mt(name, super_mt, without_override)
	local parent_index = super_mt and super_mt.__index or {}
	local own_index
	if without_override then
		own_index = parent_index
	else
		-- A fresh layer lets this kind override methods without touching its parent.
		own_index = setmt({}, {__index = parent_index})
	end

	local mt = {__name = name, __index = own_index}
	schema_mts[mt] = true
	return mt
end

--- Name of the schema kind `v` belongs to (e.g. 'Number', 'Union'), or nil
--- when `v`'s metatable was not registered through reg_mt.
---@param v any
---@return string | nil
local function get_scm_type(v)
	local mt = getmt(v)
	return schema_mts[mt] and mt.__name or nil
end


--- Whether `v` is callable: a function, or a value whose metatable has a
--- `__call` entry that is itself callable (the chain is followed iteratively).
--- Returns false, instead of raising or looping forever, when the metatable
--- is hidden behind a non-table `__metatable` value or the `__call` chain cycles.
---@param v any
---@return boolean
local function is_callable(v)
	if type(v) == 'function' then return true end
	local seen = {}
	while type(v) ~= 'function' do
		local mt = getmt(v)
		-- getmetatable returns the `__metatable` field when it is set, which
		-- need not be a table; rawget would raise on such a value.
		if type(mt) ~= 'table' or seen[mt] then return false end
		seen[mt] = true
		v = rawget(mt, '__call')
	end
	return true
end


--- Collect the custom validators given under `constraints.validator`.
--- Accepts a single callable, or a plain (metatable-less) array of callables.
--- Raises at level 3 (intended to blame the code that called the schema
--- constructor) when an entry is not callable.
---@param constraints table
---@return table? validators  nil when no `validator` field was given
local function get_validators_from_constraints(constraints)
	local given = constraints.validator
	if not given then
		return nil
	end
	-- A bare callable (including a callable table) becomes a one-element list.
	local is_plain_list = type(given) == 'table' and not getmt(given)
	local list = is_plain_list and given or {given}

	local validators = {}
	for index, candidate in ipairs(list) do
		if not is_callable(candidate) then
			error(fmt("%s isn't callable", repr(candidate)), 3)
		end
		validators[index] = candidate
	end
	return validators
end


-- Private field names inside schema objects: the parent schema, and the list
-- of custom validators.
local SUPER = '__s__'
local VALIDATORS = '__v__'
local always_true = function() return true end

-- Root schema kind; every other kind's method table inherits from its __index.
local Any_mt = reg_mt('Any', nil)
-- The bare `Any` schema accepts everything, so its own `test` short-circuits.
schema.Any = setmt({test = always_true}, Any_mt)

--- Calling a schema with a constraints table derives a new schema with the
--- same metatable, whose parent is `self`.
function Any_mt:__call(constraints)
	local derived = {}
	derived[SUPER] = self
	derived[VALIDATORS] = get_validators_from_constraints(constraints)
	return setmt(derived, getmt(self))
end

-- Default kind-specific check: kinds that don't define `_test` accept anything.
Any_mt.__index._test = always_true

--- Check `testee` against this schema.
--- The parent schema is checked first, then this kind's own `_test`, then each
--- custom validator in order; the first failure is reported.
---@param testee any
---@return boolean valid
---@return string? msg  reason for the failure
function Any_mt.__index:test(testee)
	local parent = self[SUPER]
	if parent then
		local ok, why = parent:test(testee)
		if not ok then return false, why end
	end

	local ok, why = self:_test(testee)
	if not ok then return false, why end

	local validators = self[VALIDATORS]
	if validators then
		for _, check in ipairs(validators) do
			local passed, reason = check(testee)
			if not passed then
				if reason then
					return false, 'custom validation failed: '..reason
				end
				return false, 'custom validation failed'
			end
		end
	end
	return true
end

--- Like `test`, but raises instead of returning false.
--- The raised message is exactly the one `test` would return, including the
--- 'custom validation failed' prefix for validator failures. The error is
--- attributed to the caller of `assert` rather than to a line in this module.
---@param testee any
---@return true
function Any_mt.__index:assert(testee)
	local valid, msg = self:test(testee)
	if not valid then
		error(msg or 'validation failed', 2)
	end
	return true
end


--- Build a `_test` function that accepts only values whose `type()` is `typ`.
---@param typ string  Lua type name to accept, e.g. 'number'
---@param a string?   article put before the type name in the message, e.g. 'a'
---@return function
local function TypeChecker(typ, a)
	local expected = a and (a..' '..typ) or typ
	local fmt_str = "%s (type: %s) isn't "..expected
	return function(_self, testee)
		local actual = type(testee)
		if actual ~= typ then
			return false, fmt(fmt_str, repr(testee), actual)
		end
		return true
	end
end


-- Nil and Boolean take no constraints: no __call is defined for their
-- metatables, and they share Any's method table directly.
local Nil_mt = reg_mt('Nil', Any_mt, true)
local Boolean_mt = reg_mt('Boolean', Any_mt, true)

schema.Nil = setmt({_test = TypeChecker('nil')}, Nil_mt)
schema.Boolean = setmt({_test = TypeChecker('boolean', 'a')}, Boolean_mt)


-- Number: accepts numbers; derived schemas add lt/gt/le/ge/ne bounds.
local Number_mt = reg_mt('Number', Any_mt)
schema.Number = setmt(
	{_test = TypeChecker('number', 'a')},
	Number_mt
)

--- Comparison checks keyed by constraint name; each returns true, or false
--- plus a message describing the violated bound.
---@type {[string]: fun(testee: number, n: number): boolean, string?}
local num_cmps = {}

function num_cmps.lt(testee, n)
	if testee < n then return true end
	return false, fmt("%s isn't < %s", testee, n)
end

function num_cmps.gt(testee, n)
	if testee > n then return true end
	return false, fmt("%s isn't > %s", testee, n)
end

function num_cmps.le(testee, n)
	if testee <= n then return true end
	return false, fmt("%s isn't <= %s", testee, n)
end

function num_cmps.ge(testee, n)
	if testee >= n then return true end
	return false, fmt("%s isn't >= %s", testee, n)
end

function num_cmps.ne(testee, n)
	if testee ~= n then return true end
	return false, fmt('testee equals %s', n)
end

--- Derive a constrained Number schema.
--- `max`/`min` are aliases for `le`/`ge`; the explicit name wins when both are given.
function Number_mt:__call(constraints)
	local c = constraints
	local derived = {
		[SUPER] = self,
		cmp = {
			lt = c.lt,
			gt = c.gt,
			le = c.le or c.max,
			ge = c.ge or c.min,
			ne = c.ne,
		},
		[VALIDATORS] = get_validators_from_constraints(c),
	}
	return setmt(derived, Number_mt)
end

--- Apply every bound stored in `self.cmp` and report the first one violated.
--- (Via `test`, this runs only after the parent chain confirmed a number.)
function Number_mt.__index:_test(testee)
	for op, bound in pairs(self.cmp) do
		local ok, why = num_cmps[op](testee, bound)
		if not ok then return false, why end
	end
	return true
end


-- String: accepts strings; derived schemas add min_len/max_len/pattern checks.
local String_mt = reg_mt('String', Any_mt)
schema.String = setmt({_test = TypeChecker('string', 'a')}, String_mt)

--- Derive a constrained String schema.
function String_mt:__call(constraints)
	local c = constraints
	return setmt({
		[SUPER] = self,
		max_len = c.max_len,
		min_len = c.min_len,
		pattern = c.pattern,
		[VALIDATORS] = get_validators_from_constraints(c),
	}, String_mt)
end

--- Enforce the length and pattern constraints of a derived String schema.
--- Lengths are byte counts (`#`), so a multi-byte UTF-8 character counts as
--- several bytes. NOTE(review): confirm that is intended for CJK text.
--- `pattern` is a Lua pattern passed to string.match, not a regular expression.
function String_mt.__index:_test(testee)
	local max_len, min_len, pattern = self.max_len, self.min_len, self.pattern
	if max_len and #testee > max_len then
		return false, fmt("the length of %q (%d) exceeds %s", testee, #testee, max_len)
	end
	if min_len and #testee < min_len then
		return false, fmt("the length of %q (%d) is under %s", testee, #testee, min_len)
	end
	if pattern and not testee:match(pattern) then
		return false, fmt("%q doesn't match the pattern %q", testee, pattern)
	end
	return true
end


-- Function: accepts values of type 'function' only. Its only constraint is
-- `validator`, so it reuses Any's constructor.
local Function_mt = reg_mt('Function', Any_mt)
Function_mt.__call = Any_mt.__call
schema.Function = setmt({_test = TypeChecker('function', 'a')}, Function_mt)


-- Table: accepts tables; derived schemas describe the expected fields.
local Table_mt = reg_mt('Table', Any_mt)
schema.Table = setmt({_test = TypeChecker('table', 'a')}, Table_mt)

--- Derive a Table schema from a field specification:
--- * a plain key, or a Const schema used as a key, names one specific field;
--- * any other schema used as a key applies to every field whose key passes it;
--- * values that are not schemas are wrapped with schema.Const;
--- * the plain string key 'validator' holds custom validators, not a field.
function Table_mt:__call(constraints)
	local specific, generic = {}, {}
	for key, value in pairs(constraints) do
		local value_scm = value
		if not get_scm_type(value_scm) then
			value_scm = schema.Const(value_scm)
		end
		local key_type = get_scm_type(key)
		if key_type == 'Const' then
			specific[key[1]] = value_scm
		elseif key_type then
			generic[key] = value_scm
		elseif key ~= 'validator' then
			specific[key] = value_scm
		end
	end
	return setmt({
		[SUPER] = self,
		specific = specific,
		generic = generic,
		[VALIDATORS] = get_validators_from_constraints(constraints)
	}, Table_mt)
end

--- Wrap a nested failure message with the table and the field it came from.
local function field_failure(tbl, key, msg)
	return fmt(
		'in %s, field %s:\n- %s',
		repr(tbl), repr(key), indent_new_ln(msg or 'no message provided')
	)
end

--- Check the field specification of a derived Table schema.
--- Generic (schema-keyed) rules are applied to every matching field first;
--- then each specific field is checked (a missing field is tested as nil).
function Table_mt.__index:_test(testee)
	for key_scm, val_scm in pairs(self.generic) do
		for field, value in pairs(testee) do
			if key_scm:test(field) then
				local ok, msg = val_scm:test(value)
				if not ok then
					return false, field_failure(testee, field, msg)
				end
			end
		end
	end
	for field, val_scm in pairs(self.specific) do
		local ok, msg = val_scm:test(testee[field])
		if not ok then
			return false, field_failure(testee, field, msg)
		end
	end
	return true
end


local Const_mt = reg_mt('Const', Any_mt)
-- Cache of Const schemas keyed by their value. Weak keys and values mean the
-- cache alone keeps neither the values nor the schemas alive.
local existing_const_scms = setmt({}, {__mode = 'kv'})

--- Get the schema that accepts only values equal (`==`) to `val`.
--- Repeated calls with the same `val` return the same object while it is
--- still cached. `Const(nil)` returns schema.Nil.
--- Raises for NaN: NaN equals nothing (not even itself), so such a schema
--- could never match, and NaN cannot be used as a cache key anyway.
---@param val any
---@return table
function schema.Const(val)
	if val == nil then
		return schema.Nil
	end
	if val ~= val then
		error("NaN can't be used as a Const value (NaN never equals anything)", 2)
	end
	local cached = existing_const_scms[val]
	if cached then
		return cached
	end
	local obj = setmt({val}, Const_mt)
	existing_const_scms[val] = obj
	return obj
end

--- Accept only values equal (`==`) to the stored constant `self[1]`.
function Const_mt.__index:_test(testee)
	local expected = self[1]
	local matches = expected == testee
	if not matches then
		return false, fmt("%s doesn't equals %s", repr(testee), repr(expected))
	end
	return true
end


local Union_mt = reg_mt('Union', Any_mt)

--- Build a schema that accepts any value matching at least one argument.
--- Arguments may be schemas, unions (flattened into this one), nil (becomes
--- schema.Nil), or plain values (wrapped with schema.Const). Members are
--- stored as keys of the union object itself, each mapped to true.
function schema.Union(...)
	local members = {}
	local count = select('#', ...)
	for i = 1, count do
		local item = select(i, ...)
		local kind = get_scm_type(item)
		if item == nil then
			members[schema.Nil] = true
		elseif kind == 'Union' then
			for member in next, item do
				members[member] = true
			end
		elseif kind then
			members[item] = true
		else
			members[schema.Const(item)] = true
		end
	end
	return setmt(members, Union_mt)
end

--- Pass as soon as any member schema accepts `testee`; otherwise report
--- every member's reason for rejecting it.
function Union_mt.__index:_test(testee)
	local reasons = {}
	for member in next, self do
		local ok, why = member:test(testee)
		if ok then
			return true
		end
		reasons[#reasons + 1] = why or 'no message provided'
	end
	local shown = repr(testee)
	local bullets = table.concat(map(reasons, indent_new_ln), '\n- ')
	return false, fmt('%s fails to match any value in the union:\n- %s', shown, bullets)
end

-- Make `a / b` (and `a | b`, whose __bor metamethod exists on Lua 5.3+) build
-- schema.Union(a, b) for every schema kind. This must run after every reg_mt() call.
for mt in pairs(schema_mts) do
	mt.__div = schema.Union
	mt.__bor = schema.Union
end
-- Clear the registrar so later code can't accidentally register new kinds.
reg_mt = nil


-- Integer: a Number whose fractional part is zero.
schema.Integer = schema.Number{
	validator = function(v)
		if math.fmod(v, 1) ~= 0 then
			return false, fmt("%s isn't an integer", v)
		end
		return true
	end,
}

-- Callable: a function, or a value whose metatable provides a callable __call.
schema.Callable = schema.Any{
	validator = function(v)
		if not is_callable(v) then
			return false, fmt("%s isn't callable", repr(v))
		end
		return true
	end,
}

-- Truthy: every value except nil and false.
schema.Truthy = schema.Any{validator=function(v)
	if v then return true end
	-- repr() is required: Lua 5.1's string.format `%s` (the version MediaWiki's
	-- Scribunto runs) accepts only strings and numbers, so formatting nil/false
	-- directly raised an error instead of reporting the failure.
	return false, fmt("%s isn't truthy", repr(v))
end}

-- Falsy: only nil and false.
schema.Falsy = schema.Any{
	validator = function(v)
		if v then
			return false, fmt("%s isn't falsy", repr(v))
		end
		return true
	end,
}


return schema