summaryrefslogtreecommitdiff
path: root/src/luarocks/type_check.lua
blob: 21085ef9a8897eb38c42891fda5918d044f70b19 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213

local type_check = {}

local cfg = require("luarocks.core.cfg")
local fun = require("luarocks.fun")
local util = require("luarocks.util")
local vers = require("luarocks.core.vers")
--------------------------------------------------------------------------------

-- Sentinel value for schema definitions: a schema entry whose value is
-- this constant marks a spot where per-platform overrides are allowed.
-- declare_schemas (via expand_magic_platforms) replaces each occurrence
-- with a table whose `_any` entry is a deep copy of the enclosing schema
-- table, minus the sentinel's own key.
-- The number is arbitrary; it only needs to differ from every other value
-- used in a schema, and as a plain number it still compares equal after
-- the schema table is deep-copied.
type_check.MAGIC_PLATFORMS = 0xEBABEFAC

do
   --- Stamp a schema version onto the tables directly inside `tbl`.
   -- Every table-valued entry of `tbl` that has no `_version` yet receives
   -- `_version = version`. Existing `_version` values are left untouched, so
   -- entries inherited from an older schema keep their original version.
   -- NOTE(review): the recursive call passes no version, so nested tables are
   -- walked but never stamped; only the first level is affected. Forwarding
   -- `version` is not a drop-in change: the per-platform tables built by
   -- expand_magic_platforms carry no `_version`, so in a schema derived from
   -- an older one they would be stamped with the newer version.
   -- @param tbl table: schema table whose children get stamped.
   -- @param version string or nil: the version to record.
   local function fill_in_version(tbl, version)
      for _, child in pairs(tbl) do
         if type(child) == "table" then
            if child._version == nil then
               child._version = version
            end
            fill_in_version(child)
         end
      end
   end

   --- Replace every MAGIC_PLATFORMS sentinel in `tbl`, at any depth.
   -- A sentinel stored under `key` becomes `{ _any = copy }`, where `copy` is
   -- a deep copy of the table that held the sentinel, with `key` removed.
   -- Table values are processed recursively.
   -- @param tbl table: schema table, modified in place.
   local function expand_magic_platforms(tbl)
      for key, value in pairs(tbl) do
         if type(value) == "table" then
            expand_magic_platforms(value)
         elseif value == type_check.MAGIC_PLATFORMS then
            local per_platform = util.deep_copy(tbl)
            per_platform[key] = nil
            tbl[key] = { _any = per_platform }
         end
      end
   end

   --- Build a table of schemas, one per version.
   -- Schema inputs are incremental: each one only needs to list what is new
   -- or changed relative to the version before it. Every version after the
   -- first is built by deep-copying the previous, fully processed schema and
   -- deep-merging its own input on top. Each resulting schema then gets its
   -- first-level tables version-stamped and its MAGIC_PLATFORMS sentinels
   -- expanded.
   -- NOTE(review): the first input processed is used, and modified, in
   -- place rather than copied.
   -- @param inputs table: maps version strings (such as "1.0") to the schema
   -- specification for that version.
   -- @return table: maps each version string to its complete schema.
   -- @return table: the version strings in processing order (presumably
   -- oldest first, since each schema builds on the preceding one).
   function type_check.declare_schemas(inputs)
      local sorted = fun.sort_in(util.keys(inputs), vers.compare_versions)
      local versions = fun.reverse_in(sorted)

      local schemas = {}
      local previous
      for _, version in ipairs(versions) do
         local schema
         if previous == nil then
            schema = inputs[version]
         else
            schema = util.deep_copy(previous)
            util.deep_merge(schema, inputs[version])
         end
         fill_in_version(schema, version)
         expand_magic_platforms(schema)
         schemas[version] = schema
         previous = schema
      end

      return schemas, versions
   end
end

--------------------------------------------------------------------------------

--- Check that a schema entry may be used with the given format version.
-- The entry's `_version` (defaulting to "1.0") is compared against `version`
-- with vers.compare_versions. When that comparison holds (going by the error
-- text, when the entry requires a newer rockspec format) a message is returned.
-- @param version string: the rockspec format version being checked against.
-- @param typetbl table: the schema entry, possibly carrying `_version`.
-- @param context string: field path; an empty string denotes the top-level
-- table and produces a generic message about the rockspec_format field.
-- @return true, or nil and an error message.
local function check_version(version, typetbl, context)
   local required = typetbl._version or "1.0"
   if not vers.compare_versions(required, version) then
      return true
   end
   if context == "" then
      return nil, "Invalid rockspec_format version number in rockspec? Please fix rockspec accordingly."
   end
   return nil, context.." is not supported in rockspec format "..version.." (requires version "..required.."), please fix the rockspec_format field accordingly."
end

--- Type check a single value against its schema entry.
-- The schema entry names the expected type in `_type` (defaulting to
-- "table"). For numbers, anything tonumber() accepts passes; for strings,
-- an optional `_pattern` must match the whole value; tables are checked
-- recursively with type_check_table; any other type name is compared
-- against type(item).
-- @param version string: the schema (rockspec format) version to check against.
-- @param item any: the value being checked.
-- @param typetbl table: the schema entry for the value.
-- @param context string: full path of the field, used in error messages.
-- @return boolean or (nil, string): true if type checking succeeded,
-- or nil and an error message if it failed.
-- @see type_check_table
local function type_check_item(version, item, typetbl, context)
   assert(type(version) == "string")

   -- Entries with no `_version`, or with "1.0", skip the version check here.
   local required = typetbl._version
   if required and required ~= "1.0" then
      local ok, err = check_version(version, typetbl, context)
      if not ok then
         return nil, err
      end
   end

   local item_type = type(item)
   local expected_type = typetbl._type or "table"
   local prefix = "Type mismatch on field "..context..": "

   if expected_type == "number" then
      -- Numeric strings are accepted too, since only tonumber() is consulted.
      if not tonumber(item) then
         return nil, prefix.."expected a number"
      end
      return true
   end

   if expected_type == "string" then
      if item_type ~= "string" then
         return nil, prefix.."expected a string, got "..item_type
      end
      local pattern = typetbl._pattern
      -- Anchored on both ends so the pattern must cover the entire value.
      if pattern and not item:match("^"..pattern.."$") then
         local what = typetbl._name or ("'"..pattern.."'")
         return nil, prefix.."invalid value '"..item.."' does not match " .. what
      end
      return true
   end

   if expected_type == "table" then
      if item_type ~= "table" then
         return nil, prefix.."expected a table"
      end
      return type_check.type_check_table(version, item, typetbl, context)
   end

   if item_type ~= expected_type then
      return nil, prefix.."expected "..expected_type
   end
   return true
end

--- Build the display path of `field` inside `context`.
-- At the top level (empty context) the field is rendered on its own;
-- string fields are joined with a dot, anything else uses bracket syntax.
-- @param context string: path of the enclosing table ("" at top level).
-- @param field any: the key being described.
-- @return string: e.g. "build.type" or "build.modules[2]".
local function mkfield(context, field)
   if context ~= "" then
      if type(field) == "string" then
         return context .. "." .. field
      end
      return context .. "[" .. tostring(field) .. "]"
   end
   return tostring(field)
end

--- Type check the contents of a table against a schema table.
-- For each key of `tbl`, the schema entry `typetbl[key]` is used when it is
-- a table; otherwise `typetbl._any` (if any) is used instead. Values with a
-- schema entry are checked recursively. A key with neither is accepted when
-- `typetbl._more` is set or `cfg.accept_unknown_fields` is enabled, and is
-- reported as an unknown field otherwise. Finally, every schema entry whose
-- name does not start with "_" and that has `_mandatory` set must be present
-- (non-nil) in `tbl`.
-- @param version string: the schema (rockspec format) version to check against.
-- @param tbl table: the table to be type checked.
-- @param typetbl table: the schema table. Keys starting with "_" hold
-- metadata (`_any`, `_more`, `_version`, ...); other keys describe fields.
-- @param context string: path of `tbl` ("" at the top level), used to
-- build field paths in error messages.
-- @return boolean or (nil, string): true if type checking succeeded,
-- or nil and an error message if it failed.
function type_check.type_check_table(version, tbl, typetbl, context)
   assert(type(version) == "string")
   assert(type(tbl) == "table")
   assert(type(typetbl) == "table")

   local ok, err = check_version(version, typetbl, context)
   if not ok then
      return nil, err
   end

   for k, v in pairs(tbl) do
      -- Only table entries describe fields. A user key that collides with
      -- schema metadata (e.g. `_more = true`) must not be treated as a field
      -- schema, or type_check_item would raise a Lua error instead of
      -- returning a message.
      local t = typetbl[k]
      if type(t) ~= "table" then
         t = typetbl._any
      end
      if t then
         local item_ok, item_err = type_check_item(version, v, t, mkfield(context, k))
         if not item_ok then
            return nil, item_err
         end
      elseif not (typetbl._more or cfg.accept_unknown_fields) then
         -- mkfield reports the full path and, unlike plain concatenation,
         -- copes with non-string keys such as booleans.
         return nil, "Unknown field "..mkfield(context, k)
      end
   end

   for k, v in pairs(typetbl) do
      if k:sub(1,1) ~= "_" and v._mandatory then
         -- Compare against nil so a field explicitly set to false counts
         -- as present.
         if tbl[k] == nil then
            return nil, "Mandatory field "..mkfield(context, k).." is missing."
         end
      end
   end
   return true
end

--- Report global names that the schema does not declare.
-- A global `name` counts as declared when `typetbl[name]` or the legacy
-- `typetbl["MUST_"..name]` is set.
-- @param globals table: keys are the global names to check; values are ignored.
-- @param typetbl table: the schema table.
-- @return true if every global is declared, or nil and an error message
-- listing the undeclared names in sorted order.
function type_check.check_undeclared_globals(globals, typetbl)
   local undeclared = {}
   for glob in pairs(globals) do
      if not (typetbl[glob] or typetbl["MUST_"..glob]) then
         undeclared[#undeclared + 1] = glob
      end
   end
   -- pairs() visits keys in an unspecified order; sort so the message is
   -- reproducible. tostring keeps the comparison valid for mixed key types.
   table.sort(undeclared, function(a, b) return tostring(a) < tostring(b) end)
   if #undeclared == 1 then
      return nil, "Unknown variable: "..undeclared[1]
   elseif #undeclared > 1 then
      return nil, "Unknown variables: "..table.concat(undeclared, ", ")
   end
   return true
end

return type_check