-- load our favourite library local socket = require"socket" -- creates a new set data structure function newset() local reverse = {} local set = {} return setmetatable(set, {__index = { insert = function(set, value) if not reverse[value] then table.insert(set, value) reverse[value] = table.getn(set) end end, remove = function(set, value) local index = reverse[value] if index then reverse[value] = nil local top = table.remove(set) if top ~= value then reverse[top] = index set[index] = top end end end }}) end -- timeout before an inactive thread is kicked local TIMEOUT = 10 -- set of connections waiting to receive data local receiving = newset() -- set of sockets waiting to send data local sending = newset() -- context for connections and servers local context = {} function wait(who, what) if what == "input" then receiving:insert(who) else sending:insert(who) end context[who].last = socket.gettime() coroutine.yield() end -- initializes the forward server function init() if table.getn(arg) < 1 then print("Usage") print(" lua forward.lua ...") os.exit(1) end -- for each tunnel, start a new server socket for i, v in ipairs(arg) do -- capture forwarding parameters local iport, ohost, oport = socket.skip(2, string.find(v, "([^:]+):([^:]+):([^:]+)")) assert(iport, "invalid arguments") -- create our server socket local server = assert(socket.bind("*", iport)) server:settimeout(0) -- we don't want to be killed by bad luck -- make sure server is tested for readability receiving:insert(server) -- add server context context[server] = { thread = coroutine.create(accept), ohost = ohost, oport = oport } end end -- starts a connection in a non-blocking way function connect(who, host, port) who:settimeout(0) local ret, err = who:connect(host, port) if not ret and err == "timeout" then wait(who, "output") ret, err = who:connect(host, port) end if not ret then kick(who) kick(context[who].peer) else return forward(who) end end -- gets rid of a client function kick(who) if who and context[who] then sending:remove(who) receiving:remove(who) context[who] = nil who:close() end end -- loops accepting connections and creating new threads to deal with them function accept(server) while true do -- accept a new connection and start a new coroutine to deal with it local client = server:accept() if client then -- create contexts for client and peer. local peer, err = socket.tcp() if peer then context[client] = { last = socket.gettime(), -- client goes straight to forwarding loop thread = coroutine.create(forward), peer = peer, } context[peer] = { last = socket.gettime(), peer = client, -- peer first tries to connect to forwarding address thread = coroutine.create(connect), last = socket.gettime() } -- resume peer and client so they can do their thing local ohost = context[server].ohost local oport = context[server].oport coroutine.resume(context[peer].thread, peer, ohost, oport) coroutine.resume(context[client].thread, client) else print(err) client:close() end end -- tell scheduler we are done for now wait(server, "input") end end -- forwards all data arriving to the appropriate peer function forward(who) who:settimeout(0) while true do -- wait until we have something to read wait(who, "input") -- try to read as much as possible local data, rec_err, partial = who:receive("*a") -- if we had an error other than timeout, abort if rec_err and rec_err ~= "timeout" then return kick(who) end -- if we got a timeout, we probably have partial results to send data = data or partial -- forward what we got right away local peer = context[who].peer while true do -- tell scheduler we need to wait until we can send something wait(who, "output") local ret, snd_err local start = 0 ret, snd_err, start = peer:send(data, start+1) if ret then break elseif snd_err ~= "timeout" then return kick(who) end end -- if we are done receiving, we are done if not rec_err then kick(who) kick(peer) end end end -- loop waiting until something happens, restarting the thread to deal with -- what happened, and routing it to wait until something else happens function go() while true do -- check which sockets are interesting and act on them readable, writable = socket.select(receiving, sending) -- for all readable connections, resume its thread for _, who in ipairs(readable) do receiving:remove(who) coroutine.resume(context[who].thread, who) end -- for all writable connections, do the same for _, who in ipairs(writable) do sending:remove(who) coroutine.resume(context[who].thread, who) end -- put all inactive threads in death row local now = socket.gettime() local deathrow for who, data in pairs(context) do if data.peer then if now - data.last > TIMEOUT then -- only create table if at least one is doomed deathrow = deathrow or {} deathrow[who] = true end end end -- finally kick everyone in deathrow if deathrow then for who in pairs(deathrow) do kick(who) end end end end init() go()