aboutsummaryrefslogtreecommitdiffstats
path: root/compat53/init.lua
blob: a7f0c807eb5bd4e9f9806d8919beceec47fa2b36 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
-- Version tag such as "5.1": the trailing three characters of _VERSION
-- (for example "Lua 5.1").
local lua_version = string.sub(_VERSION, -3)


if lua_version < "5.3" then

   -- Cache the globals used throughout this file in locals.
   local _G, pairs, require = _G, pairs, require
   local select, type = select, type
   local debug, io = debug, io
   -- Lua 5.1 provides unpack as a global; later versions expose it as
   -- table.unpack.
   local unpack = (lua_version == "5.1") and unpack or table.unpack

   -- table holding the compatibility implementations
   local M = require("compat53.module")

   -- Pick the strongest getmetatable available: debug.getmetatable (which
   -- ignores __metatable protection), then the plain getmetatable, and as a
   -- last resort a stub that always returns false.
   local gmt
   if type(debug) == "table" and debug.getmetatable then
      gmt = debug.getmetatable
   elseif getmetatable then
      gmt = getmetatable
   else
      gmt = function() return false end
   end
   -- metatable of io file handles, taken from io.stdout
   local file_meta = gmt(io.stdout)


   -- make '*' optional for file:read and file:lines
   if type(file_meta) == "table" and type(file_meta.__index) == "table" then

      -- Prefix a string format with '*' unless it already starts with one;
      -- non-string formats (e.g. byte counts) are returned unchanged.
      local function addasterisk(fmt)
         if type(fmt) == "string" and fmt:sub(1, 1) ~= "*" then
            return "*"..fmt
         else
            return fmt
         end
      end

      -- Build a replacement for the file method `method` that passes every
      -- format argument through addasterisk before calling `method`.
      local function wrap_formats(method)
         return function(self, ...)
            local n = select('#', ...)
            for i = 1, n do
               local a = select(i, ...)
               local b = addasterisk(a)
               -- as an optimization the arguments are only copied into a
               -- table once one of them actually needs a '*' prepended;
               -- otherwise they are forwarded untouched
               if a ~= b then
                  local args = { ... }
                  args[i] = b
                  for j = i+1, n do
                     args[j] = addasterisk(args[j])
                  end
                  return method(self, unpack(args, 1, n))
               end
            end
            return method(self, ...)
         end
      end

      file_meta.__index.lines = wrap_formats(file_meta.__index.lines)
      file_meta.__index.read = wrap_formats(file_meta.__index.read)

   end -- got a valid metatable for file objects


   -- changes for Lua 5.1 only
   if lua_version == "5.1" then

      -- Local aliases for the standard functions and libraries that the
      -- Lua 5.1 code below relies on.
      local error, pcall, rawset, setmetatable, tostring, xpcall =
            error, pcall, rawset, setmetatable, tostring, xpcall
      local coroutine, package, string = coroutine, package, string
      local coroutine_resume, coroutine_running =
            coroutine.resume, coroutine.running
      local coroutine_status, coroutine_yield =
            coroutine.status, coroutine.yield
      local io_type = io.type


      -- Detect LuaJIT by the "\027LJ" signature of its bytecode dumps, and a
      -- LuaJIT built with LUAJIT_ENABLE_LUA52COMPAT by whether a __len
      -- metamethod on a plain table is honoured by the # operator.
      local bytecode = string.dump(function() end) or ""
      local is_luajit = bytecode:sub(1, 3) == "\027LJ"
      local is_luajit52 = false
      if is_luajit then
         local probe = setmetatable({}, { __len = function() return 1 end })
         is_luajit52 = #probe == 1
      end


      -- make package.searchers available as an alias for package.loaders.
      -- Reads of package.searchers are resolved on demand, so the alias
      -- keeps following package.loaders even when that field is reassigned
      -- directly; assigning package.searchers stores the value as
      -- package.loaders. All other assignments behave as plain rawsets.
      local rawget = rawget
      setmetatable(package, {
         __index = function(p, k)
            if k == "searchers" then
               return rawget(p, "loaders")
            end
         end,
         __newindex = function(p, k, v)
            if k == "searchers" then
               rawset(p, "loaders", v)
            else
               rawset(p, k, v)
            end
         end
      })


      -- Lua 5.2-style file:write and file:lines for plain Lua 5.1 (LuaJIT
      -- is left alone).
      if type(file_meta) == "table" and type(file_meta.__index) == "table"
         and not is_luajit then

         -- Pass read results through unchanged, except that a nil first
         -- result accompanied by a non-nil second value (an error message)
         -- is raised as an error.
         local function check_read_results(_, first, ...)
            if first == nil and (...) ~= nil then
               error((...), 2)
            end
            return first, ...
         end

         -- Iterator for file:lines: reads using the formats stored in
         -- `state` (state.f is the file, state[1..state.n] the formats).
         local function read_next(state)
            return check_read_results(state, state.f:read(unpack(state, 1, state.n)))
         end

         -- file:write returns the file itself on success (for chaining),
         -- otherwise the original failure values.
         local file_write = file_meta.__index.write
         file_meta.__index.write = function(self, ...)
            local ok, msg, errno = file_write(self, ...)
            if not ok then
               return nil, msg, errno
            end
            return self
         end

         -- file:lines accepting read formats; each string format must be
         -- one of a/l/n with an optional leading '*', numbers pass as-is.
         file_meta.__index.lines = function(self, ...)
            if io_type(self) == "closed file" then
               error("attempt to use a closed file", 2)
            end
            local state = { f = self, n = select('#', ...), ... }
            for i = 1, state.n do
               local fmt = state[i]
               local kind = type(fmt)
               local valid = kind == "number"
               if kind == "string" then
                  local letter = fmt:match("^*?([aln])")
                  if letter then
                     state[i] = "*"..letter
                     valid = true
                  end
               end
               if not valid then
                  error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 2)
               end
            end
            return read_next, state
         end
      end -- file_meta valid and not luajit


      -- Bookkeeping for the coroutine-based (x)pcall emulation below. These
      -- pcall/xpcall implementations run the protected function in a fresh
      -- coroutine so that it may yield under Lua 5.1. The tables record
      -- how those internal coroutines relate to the user's coroutines, so
      -- tracebacks and coroutine.running can hide the extra coroutines.
      -- Every table has weak keys and weak values.
      local weak_meta = { __mode = "kv" }
      local function new_weak_table()
         return setmetatable({}, weak_meta)
      end
      -- internal pcall coroutine -> user coroutine that would be running
      -- if pcall did not use coroutines
      local pcall_mainOf = new_weak_table()
      -- running pcall coroutine -> coroutine that resumed it (a user
      -- coroutine or another pcall coroutine)
      local pcall_previous = new_weak_table()
      -- user coroutine -> innermost pcall coroutine currently active in it
      -- (the inverse direction of pcall_mainOf)
      local pcall_callOf = new_weak_table()
      -- coroutine that called xpcall -> internal coroutine of the failed
      -- call; only set while the xpcall message handler runs, so no
      -- nesting has to be tracked
      local xpcall_running = new_weak_table()

      -- handle debug functions
      -- When a `debug` table exists and we are not on LuaJIT, debug.traceback
      -- is replaced by a version that stitches together the stacks of the
      -- internal coroutines created by the (x)pcall emulation in this file.
      if type(debug) == "table" then
         local debug_getinfo = debug.getinfo
         local debug_traceback = debug.traceback

         if not is_luajit then
            -- Fit a requested traceback `level` into the stack of `co` (the
            -- current thread when `co` is nil).
            -- Returns `level` when that stack has at least `level` frames.
            -- Otherwise returns nil plus the number of levels that remain to
            -- be skipped, so the caller can carry the remainder over to the
            -- next stack in the chain. A nil `level` yields 1.
            local function calculate_trace_level(co, level)
               if level ~= nil then
                  -- probe increasing levels until debug.getinfo runs off the
                  -- end of the stack; `out-1` is then the stack depth
                  for out = 1, 1/0 do
                     local info = (co==nil) and debug_getinfo(out, "") or debug_getinfo(co, out, "")
                     if info == nil then
                        local max = out-1
                        if level <= max then
                           return level
                        end
                        return nil, level-max
                     end
                  end
               end
               return 1
            end

            -- header line of a traceback; removed from every stack appended
            -- after the first one so the combined message has one header
            local stack_pattern = "\nstack traceback:"
            local stack_replace = ""
            -- debug.traceback([thread,] [message [, level]])
            -- Returns `message` unchanged if it is neither nil nor a string.
            function debug.traceback(co, msg, level)
               local lvl
               local nilmsg
               -- the thread argument is optional: shift the arguments and
               -- default to the running coroutine
               if type(co) ~= "thread" then
                  co, msg, level = coroutine_running(), co, msg
               end
               if msg == nil then
                  msg = ""
                  nilmsg = true
               elseif type(msg) ~= "string" then
                  return msg
               end
               if co == nil then
                  -- not inside a coroutine (5.1 coroutine.running() returns
                  -- nil there): plain traceback of the current stack.
                  -- NOTE(review): level 1 refers to this wrapper's own frame,
                  -- so the wrapper may show up in the trace — confirm whether
                  -- an offset of +1 was intended.
                  msg = debug_traceback(msg, level or 1)
               else
                  local xpco = xpcall_running[co]
                  if xpco ~= nil then
                     -- inside an xpcall message handler: trace the failed
                     -- internal coroutine first, then the handler's own stack
                     lvl, level = calculate_trace_level(xpco, level)
                     if lvl then
                        msg = debug_traceback(xpco, msg, lvl)
                     else
                        msg = msg..stack_pattern
                     end
                     lvl, level = calculate_trace_level(co, level)
                     if lvl then
                        local trace = debug_traceback(co, "", lvl)
                        msg = msg..trace:gsub(stack_pattern, stack_replace)
                     end
                  else
                     -- start from the innermost pcall coroutine active in
                     -- `co` (if any)
                     co = pcall_callOf[co] or co
                     lvl, level = calculate_trace_level(co, level)
                     if lvl then
                        msg = debug_traceback(co, msg, lvl)
                     else
                        msg = msg..stack_pattern
                     end
                  end
                  -- walk outwards through the coroutines that resumed each
                  -- pcall coroutine, appending their stacks without a header
                  co = pcall_previous[co]
                  while co ~= nil do
                     lvl, level = calculate_trace_level(co, level)
                     if lvl then
                        local trace = debug_traceback(co, "", lvl)
                        msg = msg..trace:gsub(stack_pattern, stack_replace)
                     end
                     co = pcall_previous[co]
                  end
               end
               if nilmsg then
                  msg = msg:gsub("^\n", "")
               end
               -- Compress tail-call noise: "(tail call): ?" lines become
               -- \000 markers and "..." (skipped levels) lines become \001
               -- markers. Runs of markers are then rewritten into a single
               -- "(...N tail call(s)...)" line ("N+" when skipped levels were
               -- adjacent), and any leftover \001 turns back into "...".
               msg = msg:gsub("\n\t%(tail call%): %?", "\000")
               msg = msg:gsub("\n\t%.%.%.\n", "\001\n")
               msg = msg:gsub("\n\t%.%.%.$", "\001")
               msg = msg:gsub("(%z+)\001(%z+)", function(some, other)
                  return "\n\t(..."..#some+#other.."+ tail call(s)...)"
               end)
               msg = msg:gsub("\001(%z+)", function(zeros)
                  return "\n\t(..."..#zeros.."+ tail call(s)...)"
               end)
               msg = msg:gsub("(%z+)\001", function(zeros)
                  return "\n\t(..."..#zeros.."+ tail call(s)...)"
               end)
               msg = msg:gsub("%z+", function(zeros)
                  return "\n\t(..."..#zeros.." tail call(s)...)"
               end)
               msg = msg:gsub("\001", function()
                  return "\n\t..."
               end)
               return msg
            end
         end -- is not luajit
      end -- debug table available


      if not is_luajit52 then
         -- Report the user-visible coroutine: when called from one of the
         -- internal pcall coroutines, return the coroutine that pcall would
         -- be running in instead (the second result flags the main thread).
         local base_running = M.coroutine.running
         M.coroutine.running = function()
            local co, ismain = base_running()
            if ismain then
               return co, true
            end
            return pcall_mainOf[co] or co, false
         end
      end

      if not is_luajit then
         -- Drive the internal pcall coroutine `call` to completion.
         -- `current` is the coroutine pcall was invoked from; `success, ...`
         -- are the results of the latest coroutine_resume(call, ...).
         -- While `call` is suspended (it yielded), the yielded values are
         -- passed on via coroutine_yield, and whatever the yield returns is
         -- forwarded back into `call`. Afterwards the bookkeeping done by
         -- pcall_exec is undone and `success, ...` is returned.
         local function pcall_results(current, call, success, ...)
            if coroutine_status(call) == "suspended" then
               return pcall_results(current, call, coroutine_resume(call, coroutine_yield(...)))
            end
            pcall_previous[call] = nil
            local main = pcall_mainOf[call]
            -- make `current` the active pcall coroutine of `main` again, or
            -- clear the entry when `call` was started directly from `main`
            if main == current then current = nil end
            pcall_callOf[main] = current
            pcall_mainOf[call] = nil
            return success, ...
         end

         -- Register `call` as a pcall coroutine started from `current` in the
         -- weak bookkeeping tables, then start it with the given arguments.
         local function pcall_exec(current, call, ...)
            local main = pcall_mainOf[current] or current
            pcall_mainOf[call] = main
            pcall_previous[call] = current
            pcall_callOf[main] = call
            return pcall_results(current, call, coroutine_resume(call, ...))
         end

         local coroutine_create52 = M.coroutine.create

         -- Create the coroutine that runs `func`; non-function values (e.g.
         -- callable tables) are wrapped in a closure that calls them.
         local function pcall_coroutine(func)
            if type(func) ~= "function" then
               local callable = func
               func = function (...) return callable(...) end
            end
            return coroutine_create52(func)
         end

         -- pcall whose function may yield. When not running inside a
         -- coroutine (coroutine_running() is nil), the native pcall is used.
         function M.pcall(func, ...)
            local current = coroutine_running()
            if not current then return pcall(func, ...) end
            return pcall_exec(current, pcall_coroutine(func), ...)
         end

         -- Finish an xpcall: on failure run the message handler `msgh`
         -- (under the native pcall) while xpcall_running[current] points at
         -- the failed internal coroutine, so debug.traceback can include its
         -- stack. An error inside the handler is reported as
         -- "error in error handling (...)".
         local function xpcall_catch(current, call, msgh, success, ...)
            if not success then
               xpcall_running[current] = call
               local ok, result = pcall(msgh, ...)
               xpcall_running[current] = nil
               if not ok then
                  return false, "error in error handling ("..tostring(result)..")"
               end
               return false, result
            end
            return true, ...
         end

         -- xpcall accepting extra arguments for `f` and allowing `f` to
         -- yield. Outside of a coroutine the native xpcall is used, with the
         -- extra arguments captured in a closure.
         function M.xpcall(f, msgh, ...)
            local current = coroutine_running()
            if not current then
               local args, n = { ... }, select('#', ...)
               return xpcall(function() return f(unpack(args, 1, n)) end, msgh)
            end
            local call = pcall_coroutine(f)
            return xpcall_catch(current, call, msgh, pcall_exec(current, call, ...))
         end
      end -- not luajit

   end -- lua 5.1


   -- Merge table `from` into table `to`: when both sides hold distinct
   -- tables under the same key they are merged recursively, otherwise the
   -- value from `from` overwrites the entry in `to`. Merging a table into
   -- itself does nothing.
   local function extend_table(from, to)
      if from == to then
         return
      end
      for key, value in pairs(from) do
         local merge = type(value) == "table"
                       and type(to[key]) == "table"
                       and value ~= to[key]
         if merge then
            extend_table(value, to[key])
         else
            to[key] = value
         end
      end
   end

   -- copy the contents of M into the global table; nested tables that
   -- already exist in _G are merged rather than replaced (see extend_table)
   extend_table(M, _G)

end -- lua < 5.3

-- vi: set expandtab softtabstop=3 shiftwidth=3 :