aboutsummaryrefslogtreecommitdiffstats
path: root/src/theming/lua-lib/expect.lua
blob: 20772b2586c8a470ff9f53bf83bfd2d0f71dd872 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
-- SPDX-FileCopyrightText: 2019 The CC: Tweaked Developers
--
-- SPDX-License-Identifier: MPL-2.0

--[[- The [`cc.expect`] library provides helper functions for verifying that
function arguments are well-formed and of the correct type.

@module cc.expect
@since 1.84.0
@changed 1.96.0 The module can now be called directly as a function, which wraps around `expect.expect`.
@usage Define a basic function and check it has the correct arguments.

    local expect = require "cc.expect"
    local expect, field = expect.expect, expect.field

    local function add_person(name, info)
        expect(1, name, "string")
        expect(2, info, "table", "nil")

        if info then
            print("Got age=", field(info, "age", "number"))
            print("Got gender=", field(info, "gender", "string", "nil"))
        end
    end

    add_person("Anastazja") -- `info' is optional
    add_person("Kion", { age = 23 }) -- `gender' is optional
    add_person("Caoimhin", { age = 23, gender = true }) -- error!
]]

local native_select, native_type = select, type

local function get_type_names(...)
    local types = table.pack(...)
    for i = types.n, 1, -1 do
        if types[i] == "nil" then table.remove(types, i) end
    end

    if #types <= 1 then
        return tostring(...)
    else
        return table.concat(types, ", ", 1, #types - 1) .. " or " .. types[#types]
    end
end


local function get_display_type(value, t)
    -- Lua is somewhat inconsistent in whether it obeys __name just for values which
    -- have a per-instance metatable (so tables/userdata) or for everything. We follow
    -- Cobalt and only read the metatable for tables/userdata.
    if t ~= "table" and t ~= "userdata" then return t end

    local metatable = getmetatable(value)
    if not metatable then return t end

    local name = rawget(metatable, "__name")
    if type(name) == "string" then return name else return t end
end

--- Expect an argument to have a specific type.
--
-- @tparam number index The 1-based argument index.
-- @param value The argument's value.
-- @tparam string ... The allowed types of the argument.
-- @return The given `value`.
-- @throws If the value is not one of the allowed types.
local function expect(index, value, ...)
    local t = native_type(value)
    for i = 1, native_select("#", ...) do
        if t == native_select(i, ...) then return value end
    end

    -- If we can determine the function name with a high level of confidence, try to include it.
    local name
    -- local ok, info = pcall(debug.getinfo, 3, "nS")
    -- if ok and info.name and info.name ~= "" and info.what ~= "C" then name = info.name end

    t = get_display_type(value, t)

    local type_names = get_type_names(...)
    if name then
        error(("bad argument #%d to '%s' (%s expected, got %s)"):format(index, name, type_names, t), 3)
    else
        error(("bad argument #%d (%s expected, got %s)"):format(index, type_names, t), 3)
    end
end

--- Expect an field to have a specific type.
--
-- @tparam table tbl The table to index.
-- @tparam string index The field name to check.
-- @tparam string ... The allowed types of the argument.
-- @return The contents of the given field.
-- @throws If the field is not one of the allowed types.
local function field(tbl, index, ...)
    expect(1, tbl, "table")
    expect(2, index, "string", "number")

    local value = tbl[index]
    local t = native_type(value)
    for i = 1, native_select("#", ...) do
        if t == native_select(i, ...) then return value end
    end

    t = get_display_type(value, t)

    if value == nil then
        error(("field '%s' missing from table"):format(index), 3)
    else
        error(("bad field '%s' (%s expected, got %s)"):format(index, get_type_names(...), t), 3)
    end
end

local function is_nan(num)
  return num ~= num
end

--- Expect a number to be within a specific range.
--
-- @tparam number num The value to check.
-- @tparam number min The minimum value, if nil then `-math.huge` is used.
-- @tparam number max The maximum value, if nil then `math.huge` is used.
-- @return The given `value`.
-- @throws If the value is outside of the allowed range.
-- @since 1.96.0
local function range(num, min, max)
  expect(1, num, "number")
  min = expect(2, min, "number", "nil") or -math.huge
  max = expect(3, max, "number", "nil") or math.huge
  if min > max then
      error("min must be less than or equal to max)", 2)
  end

  if is_nan(num) or num < min or num > max then
      error(("number outside of range (expected %s to be within %s and %s)"):format(num, min, max), 3)
  end

  return num
end

return setmetatable({
    expect = expect,
    field = field,
    range = range,
}, { __call = function(_, ...) return expect(...) end })