1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
|
-- SPDX-FileCopyrightText: 2019 The CC: Tweaked Developers
--
-- SPDX-License-Identifier: MPL-2.0
--[[- The [`cc.expect`] library provides helper functions for verifying that
function arguments are well-formed and of the correct type.
@module cc.expect
@since 1.84.0
@changed 1.96.0 The module can now be called directly as a function, which wraps around `expect.expect`.
@usage Define a basic function and check it has the correct arguments.
local expect = require "cc.expect"
local expect, field = expect.expect, expect.field
local function add_person(name, info)
expect(1, name, "string")
expect(2, info, "table", "nil")
if info then
print("Got age=", field(info, "age", "number"))
print("Got gender=", field(info, "gender", "string", "nil"))
end
end
add_person("Anastazja") -- `info' is optional
add_person("Kion", { age = 23 }) -- `gender' is optional
add_person("Caoimhin", { age = 23, gender = true }) -- error!
]]
-- Local aliases for the builtins used on the type-checking path, so each use
-- is a local access rather than a global lookup.
local native_select = select
local native_type = type
--- Build a human-readable list of type names for use in error messages.
--
-- Every `"nil"` entry is dropped, and the remaining names are joined in the
-- form `a, b or c`.
--
-- @tparam string ... The allowed type names.
-- @treturn string The formatted list. When every entry was `"nil"`, the first
-- argument converted to a string.
local function get_type_names(...)
    local types = table.pack(...)
    -- Walk backwards so removing an entry never shifts one we have yet to visit.
    for i = types.n, 1, -1 do
        if types[i] == "nil" then table.remove(types, i) end
    end

    local count = #types
    if count == 0 then
        -- Nothing but "nil" was allowed: describe the first argument as given.
        return tostring(...)
    elseif count == 1 then
        -- Use the surviving entry, not the first argument: the first argument
        -- may be a "nil" that was just filtered out.
        return tostring(types[1])
    else
        return table.concat(types, ", ", 1, count - 1) .. " or " .. types[count]
    end
end
--- Work out the type name to show for a value in an error message.
--
-- @param value The value being described.
-- @tparam string t The result of `type(value)`.
-- @treturn string The metatable's `__name` when `value` is a table or userdata
-- whose metatable has a string `__name`, otherwise `t`.
local function get_display_type(value, t)
    -- Lua is somewhat inconsistent in whether it obeys __name just for values which
    -- have a per-instance metatable (so tables/userdata) or for everything. We follow
    -- Cobalt and only read the metatable for tables/userdata.
    if t ~= "table" and t ~= "userdata" then return t end

    local metatable = getmetatable(value)
    -- getmetatable returns the `__metatable` field when one is set, and that can
    -- be any value (commonly a string). Only a real table is safe to pass to
    -- rawget; anything else means there is no readable name.
    if type(metatable) ~= "table" then return t end

    local name = rawget(metatable, "__name")
    if type(name) == "string" then return name else return t end
end
--- Check that an argument has one of a set of allowed types.
--
-- @tparam number index The 1-based position of the argument, used in the error message.
-- @param value The argument's value.
-- @tparam string ... The type names the argument is allowed to have.
-- @return The given `value`, when its type is allowed.
-- @throws If the type of `value` is not one of the allowed types.
local function expect(index, value, ...)
    local actual = native_type(value)
    local allowed_count = native_select("#", ...)
    local position = 1
    while position <= allowed_count do
        local allowed = native_select(position, ...)
        if allowed == actual then return value end
        position = position + 1
    end

    -- The calling function's name is included in the message when it is known.
    -- The lookup below is currently disabled, so `name` stays nil and the
    -- shorter message is always used.
    local name
    -- local ok, info = pcall(debug.getinfo, 3, "nS")
    -- if ok and info.name and info.name ~= "" and info.what ~= "C" then name = info.name end

    local shown_type = get_display_type(value, actual)
    local type_names = get_type_names(...)

    local message
    if name then
        message = ("bad argument #%d to '%s' (%s expected, got %s)"):format(index, name, type_names, shown_type)
    else
        message = ("bad argument #%d (%s expected, got %s)"):format(index, type_names, shown_type)
    end

    -- Level 3 reports the error at the caller of the function that called expect.
    error(message, 3)
end
--- Check that a field of a table has one of a set of allowed types.
--
-- @tparam table tbl The table to index.
-- @tparam string|number index The key of the field to check.
-- @tparam string ... The type names the field is allowed to have.
-- @return The contents of the given field.
-- @throws If `tbl` is not a table, `index` is not a string or number, or the
-- field's type is not one of the allowed types.
local function field(tbl, index, ...)
    expect(1, tbl, "table")
    expect(2, index, "string", "number")

    local value = tbl[index]
    local actual = native_type(value)
    for position = 1, native_select("#", ...) do
        local allowed = native_select(position, ...)
        if allowed == actual then return value end
    end

    -- A nil value that was not accepted above gets its own, clearer message.
    if value == nil then
        error(("field '%s' missing from table"):format(index), 3)
    end

    local shown_type = get_display_type(value, actual)
    error(("bad field '%s' (%s expected, got %s)"):format(index, get_type_names(...), shown_type), 3)
end
--- Test whether a number is NaN.
--
-- NaN is the only number that does not compare equal to itself.
--
-- @tparam number num The number to test.
-- @treturn boolean Whether `num` is NaN.
local function is_nan(num)
    return not (num == num)
end
--- Expect a number to be within a specific range.
--
-- @tparam number num The value to check.
-- @tparam[opt] number min The minimum value, if nil then `-math.huge` is used.
-- @tparam[opt] number max The maximum value, if nil then `math.huge` is used.
-- @return The given `num`.
-- @throws If `min` is greater than `max`, or if `num` is NaN or outside of the
-- allowed range.
-- @since 1.96.0
local function range(num, min, max)
    expect(1, num, "number")
    min = expect(2, min, "number", "nil") or -math.huge
    max = expect(3, max, "number", "nil") or math.huge

    -- Inverted bounds are a mistake in the code calling range, so report the
    -- error there (level 2).
    if min > max then
        error("min must be less than or equal to max", 2)
    end

    -- NaN compares false against both bounds, so it has to be rejected explicitly.
    if is_nan(num) or num < min or num > max then
        error(("number outside of range (expected %s to be within %s and %s)"):format(num, min, max), 3)
    end

    return num
end
-- Public interface of the module.
local exports = {
    expect = expect,
    field = field,
    range = range,
}

-- The __call metamethod lets the module value itself be called as a function;
-- the call is forwarded to `expect`, dropping the module table argument.
return setmetatable(exports, {
    __call = function(_, ...) return expect(...) end,
})
|