----------------------------------------------------------------------------- -- A hacked dispatcher module -- LuaSocket sample files -- Author: Diego Nehab -- RCS ID: $$ ----------------------------------------------------------------------------- local base = _G local socket = require("socket") local coroutine = require("coroutine") module("dispatch") -- if too much time goes by without any activity in one of our sockets, we -- just kill it TIMEOUT = 10 ----------------------------------------------------------------------------- -- Mega hack. Don't try to do this at home. ----------------------------------------------------------------------------- -- Lua 5.1 has coroutine.running(). We need it here, so we use this terrible -- hack to emulate it in Lua itself -- This is very inefficient, but is very good for debugging. local running local resume = coroutine.resume function coroutine.resume(co, ...) running = co return resume(co, unpack(arg)) end function coroutine.running() return running end ----------------------------------------------------------------------------- -- Mega hack. Don't try to do this at home. ----------------------------------------------------------------------------- -- we can't yield across calls to protect, so we rewrite it with coxpcall -- make sure you don't require any module that uses socket.protect before -- loading our hack function socket.protect(f) return f end function socket.protect(f) return function(...) local co = coroutine.create(f) while true do local results = {resume(co, unpack(arg))} local status = table.remove(results, 1) if not status then if type(results[1]) == 'table' then return nil, results[1][1] else error(results[1]) end end if coroutine.status(co) == "suspended" then arg = {coroutine.yield(unpack(results))} else return unpack(results) end end end end ----------------------------------------------------------------------------- -- socket.tcp() replacement for non-blocking I/O ----------------------------------------------------------------------------- local function newtrap(dispatcher) -- try to create underlying socket local tcp, error = socket.tcp() if not tcp then return nil, error end -- put it in non-blocking mode right away tcp:settimeout(0) -- metatable for trap produces new methods on demand for those that we -- don't override explicitly. local metat = { __index = function(table, key) table[key] = function(...) return tcp[key](tcp, unpack(arg)) end end} -- does user want to do his own non-blocking I/O? local zero = false -- create a trap object that will behave just like a real socket object local trap = { } -- we ignore settimeout to preserve our 0 timeout, but record whether -- the user wants to do his own non-blocking I/O function trap:settimeout(mode, value) if value == 0 then zero = true else zero = false end return 1 end -- send in non-blocking mode and yield on timeout function trap:send(data, first, last) first = (first or 1) - 1 local result, error while true do -- tell dispatcher we want to keep sending before we yield dispatcher.sending:insert(tcp) -- mark time we started waiting dispatcher.context[tcp].last = socket.gettime() -- return control to dispatcher -- if upon return the dispatcher tells us we timed out, -- return an error to whoever called us if coroutine.yield() == "timeout" then return nil, "timeout" end -- try sending result, error, first = tcp:send(data, first+1, last) -- if we are done, or there was an unexpected error, -- break away from loop if error ~= "timeout" then return result, error, first end end end -- receive in non-blocking mode and yield on timeout -- or simply return partial read, if user requested timeout = 0 function trap:receive(pattern, partial) local error = "timeout" local value while true do -- tell dispatcher we want to keep receiving before we yield dispatcher.receiving:insert(tcp) -- mark time we started waiting dispatcher.context[tcp].last = socket.gettime() -- return control to dispatcher -- if upon return the dispatcher tells us we timed out, -- return an error to whoever called us if coroutine.yield() == "timeout" then return nil, "timeout" end -- try receiving value, error, partial = tcp:receive(pattern, partial) -- if we are done, or there was an unexpected error, -- break away from loop if (error ~= "timeout") or zero then return value, error, partial end end end -- connect in non-blocking mode and yield on timeout function trap:connect(host, port) local result, error = tcp:connect(host, port) -- mark time we started waiting dispatcher.context[tcp].last = socket.gettime() if error == "timeout" then -- tell dispatcher we will be able to write uppon connection dispatcher.sending:insert(tcp) -- return control to dispatcher -- if upon return the dispatcher tells us we have a -- timeout, just abort if coroutine.yield() == "timeout" then return nil, "timeout" end -- when we come back, check if connection was successful result, error = tcp:connect(host, port) if result or error == "already connected" then return 1 else return nil, "non-blocking connect failed" end else return result, error end end -- accept in non-blocking mode and yield on timeout function trap:accept() local result, error = tcp:accept() while error == "timeout" do -- mark time we started waiting dispatcher.context[tcp].last = socket.gettime() -- tell dispatcher we will be able to read uppon connection dispatcher.receiving:insert(tcp) -- return control to dispatcher -- if upon return the dispatcher tells us we have a -- timeout, just abort if coroutine.yield() == "timeout" then return nil, "timeout" end end return result, error end -- remove thread from context function trap:close() dispatcher.context[tcp] = nil return tcp:close() end -- add newly created socket to context dispatcher.context[tcp] = { thread = coroutine.running() } return setmetatable(trap, metat) end ----------------------------------------------------------------------------- -- Our set data structure ----------------------------------------------------------------------------- local function newset() local reverse = {} local set = {} return setmetatable(set, {__index = { insert = function(set, value) if not reverse[value] then table.insert(set, value) reverse[value] = table.getn(set) end end, remove = function(set, value) local index = reverse[value] if index then reverse[value] = nil local top = table.remove(set) if top ~= value then reverse[top] = index set[index] = top end end end }}) end ----------------------------------------------------------------------------- -- Our dispatcher API. ----------------------------------------------------------------------------- local metat = { __index = {} } function metat.__index:start(func) local co = coroutine.create(func) assert(coroutine.resume(co)) end function newhandler() local dispatcher = { context = {}, sending = newset(), receiving = newset() } function dispatcher.tcp() return newtrap(dispatcher) end return setmetatable(dispatcher, metat) end -- step through all active threads function metat.__index:step() -- check which sockets are interesting and act on them local readable, writable = socket.select(self.receiving, self.sending, 1) -- for all readable connections, resume their threads for _, who in ipairs(readable) do if self.context[who] then self.receiving:remove(who) assert(coroutine.resume(self.context[who].thread)) end end -- for all writable connections, do the same for _, who in ipairs(writable) do if self.context[who] then self.sending:remove(who) assert(coroutine.resume(self.context[who].thread)) end end -- politely ask replacement I/O functions in idle threads to -- return reporting a timeout local now = socket.gettime() for who, data in pairs(self.context) do if data.last and now - data.last > TIMEOUT then self.sending:remove(who) self.receiving:remove(who) assert(coroutine.resume(self.context[who].thread, "timeout")) end end end