From b18021e22d5c192c88372889def02e6edb21b5be Mon Sep 17 00:00:00 2001 From: Diego Nehab Date: Thu, 10 Mar 2005 01:49:17 +0000 Subject: [PATCH] Added forward server. --- samples/forward.lua | 229 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 samples/forward.lua diff --git a/samples/forward.lua b/samples/forward.lua new file mode 100644 index 0000000..de651b4 --- /dev/null +++ b/samples/forward.lua @@ -0,0 +1,229 @@ +-- load our favourite library +local socket = require"socket" +-- timeout before an inactive thread is kicked +local TIMEOUT = 10 +-- local address to bind to +local ihost, iport = arg[1] or "localhost", arg[2] or 8080 +-- address to forward all data to +local ohost, oport = arg[3] or "localhost", arg[4] or 3128 + +-- creates a new set data structure +function newset() + local reverse = {} + local set = {} + return setmetatable(set, {__index = { + insert = function(set, value) + if not reverse[value] then + table.insert(set, value) + reverse[value] = table.getn(set) + end + end, + remove = function(set, value) + local index = reverse[value] + if index then + reverse[value] = nil + local top = table.remove(set) + if top ~= value then + reverse[top] = index + set[index] = top + end + end + end + }}) +end + +local receiving = newset() +local sending = newset() +local context = {} + +-- starts a non-blocking connect +function nconnect(host, port) + local peer, err = socket.tcp() + if not peer then return nil, err end + peer:settimeout(0) + local ret, err = peer:connect(host, port) + if ret then return peer end + if err ~= "timeout" then + peer:close() + return nil, err + end + return peer +end + +-- gets rid of a client +function kick(who) +if who == server then error("FUDEU") end + if context[who] then + sending:remove(who) + receiving:remove(who) + context[who] = nil + who:close() + end +end + +-- decides what to do with a thread based on coroutine return +function route(who, status, what) +print(who, status, what) + if status and what then + if what == "receiving" then receiving:insert(who) end + if what == "sending" then sending:insert(who) end + else kick(who) end +end + +-- loops accepting connections and creating new threads to deal with them +function accept(server) + while true do +print(server, "accepting a new client") + -- accept a new connection and start a new coroutine to deal with it + local client = server:accept() + if client then + -- start a new connection, non-blockingly, to the forwarding address + local peer = nconnect(ohost, oport) + if peer then + context[client] = { + last = socket.gettime(), + thread = coroutine.create(forward), + peer = peer, + } + -- make sure peer will be tested for writing in the next select + -- round, which means the connection attempt has finished + sending:insert(peer) + context[peer] = { + peer = client, + thread = coroutine.create(check), + last = socket.gettime() + } + -- put both in non-blocking mode + client:settimeout(0) + peer:settimeout(0) + else + -- otherwise just dump the client + client:close() + end + end + -- tell scheduler we are done for now + coroutine.yield("receiving") + end +end + +-- forwards all data arriving to the appropriate peer +function forward(who) + while true do +print(who, "getting data") + -- try to read as much as possible + local data, rec_err, partial = who:receive("*a") + -- if we had an error other than timeout, abort + if rec_err and rec_err ~= "timeout" then return error(rec_err) end + -- if we got a timeout, we probably have partial results to send + data = data or partial +print(who, " got ", string.len(data)) + -- renew our timestamp so scheduler sees we are active + context[who].last = socket.gettime() + -- forward what we got right away + local peer = context[who].peer + while true do + -- tell scheduler we need to wait until we can send something + coroutine.yield("sending") + local ret, snd_err + local start = 0 +print(who, "sending data") + ret, snd_err, start = peer:send(data, start+1) + if ret then break + elseif snd_err ~= "timeout" then return error(snd_err) end + -- renew our timestamp so scheduler sees we are active + context[who].last = socket.gettime() + end + -- if we are done receiving, we are done with this side of the + -- connection + if not rec_err then return nil end + -- otherwise tell schedule we have to wait for more data to arrive + coroutine.yield("receiving") + end +end + +-- checks if a connection completed successfully and if it did, starts +-- forwarding all data +function check(who) + local ret, err = who:connected() + if ret then +print(who, "connection completed") + receiving:insert(context[who].peer) + context[who].last = socket.gettime() +print(who, "yielding until there is input data") + coroutine.yield("receiving") + return forward(who) + else return error(err) end +end + +-- initializes the forward server +function init() + -- socket sets to test for events + -- create our server socket + server = assert(socket.bind(ihost, iport)) + server:settimeout(0.1) -- we don't want to be killed by bad luck + -- we initially + receiving:insert(server) + context[server] = { thread = coroutine.create(accept) } +end + +-- loop waiting until something happens, restarting the thread to deal with +-- what happened, and routing it to wait until something else happens +function go() + while true do +print("will select for readability") +for i,v in ipairs(receiving) do + print(i, v) +end +print("will select for writability") +for i,v in ipairs(sending) do + print(i, v) +end + -- check which sockets are interesting and act on them + readable, writable = socket.select(receiving, sending, 3) +print("returned as readable") +for i,v in ipairs(readable) do + print(i, v) +end +print("returned as writable") +for i,v in ipairs(writable) do + print(i, v) +end + -- for all readable connections, resume its thread and route it + for _, who in ipairs(readable) do + receiving:remove(who) + if context[who] then + route(who, coroutine.resume(context[who].thread, who)) + end + end + -- for all writable connections, do the same + for _, who in ipairs(writable) do + sending:remove(who) + if context[who] then + route(who, coroutine.resume(context[who].thread, who)) + end + end + -- put all inactive threads in death row + local now = socket.gettime() + local deathrow + for who, data in pairs(context) do + if data.last then +print("hung for" , now - data.last, who) + if now - data.last > TIMEOUT then + -- only create table if someone is doomed + deathrow = deathrow or {} + deathrow[who] = true + end + end + end + -- finally kick everyone in deathrow + if deathrow then +print("in death row") +for i,v in pairs(deathrow) do + print(i, v) +end + for who in pairs(deathrow) do kick(who) end + end + end +end + +go(init())