Forward server works for multiple tunnels.

Http.lua has been patched to support non-blocking everything.
Makefile for linux has been updated with new names.
This commit is contained in:
Diego Nehab 2005-03-10 02:15:04 +00:00
parent b18021e22d
commit 63e3d7c5b0
3 changed files with 63 additions and 60 deletions

View File

@ -1,11 +1,5 @@
-- load our favourite library -- load our favourite library
local socket = require"socket" local socket = require"socket"
-- timeout before an inactive thread is kicked
local TIMEOUT = 10
-- local address to bind to
local ihost, iport = arg[1] or "localhost", arg[2] or 8080
-- address to forward all data to
local ohost, oport = arg[3] or "localhost", arg[4] or 3128
-- creates a new set data structure -- creates a new set data structure
function newset() function newset()
@ -32,12 +26,44 @@ function newset()
}}) }})
end end
-- timeout before an inactive thread is kicked
local TIMEOUT = 10
-- set of connections waiting to receive data
local receiving = newset() local receiving = newset()
-- set of sockets waiting to send data
local sending = newset() local sending = newset()
-- context for connections and servers
local context = {} local context = {}
-- starts a non-blocking connect -- initializes the forward server
function nconnect(host, port) function init()
if table.getn(arg) < 1 then
print("Usage")
print(" lua forward.lua <iport:ohost:oport> ...")
os.exit(1)
end
-- for each tunnel, start a new server socket
for i, v in ipairs(arg) do
-- capture forwarding parameters
local iport, ohost, oport =
socket.skip(2, string.find(v, "([^:]+):([^:]+):([^:]+)"))
assert(iport, "invalid arguments")
-- create our server socket
local server = assert(socket.bind("*", iport))
server:settimeout(0.1) -- we don't want to be killed by bad luck
-- make sure server is tested for readability
receiving:insert(server)
-- add server context
context[server] = {
thread = coroutine.create(accept),
ohost = ohost,
oport = oport
}
end
end
-- starts a connection in a non-blocking way
function nbkcon(host, port)
local peer, err = socket.tcp() local peer, err = socket.tcp()
if not peer then return nil, err end if not peer then return nil, err end
peer:settimeout(0) peer:settimeout(0)
@ -52,7 +78,6 @@ end
-- gets rid of a client -- gets rid of a client
function kick(who) function kick(who)
if who == server then error("FUDEU") end
if context[who] then if context[who] then
sending:remove(who) sending:remove(who)
receiving:remove(who) receiving:remove(who)
@ -63,7 +88,6 @@ end
-- decides what to do with a thread based on coroutine return -- decides what to do with a thread based on coroutine return
function route(who, status, what) function route(who, status, what)
print(who, status, what)
if status and what then if status and what then
if what == "receiving" then receiving:insert(who) end if what == "receiving" then receiving:insert(who) end
if what == "sending" then sending:insert(who) end if what == "sending" then sending:insert(who) end
@ -73,12 +97,13 @@ end
-- loops accepting connections and creating new threads to deal with them -- loops accepting connections and creating new threads to deal with them
function accept(server) function accept(server)
while true do while true do
print(server, "accepting a new client")
-- accept a new connection and start a new coroutine to deal with it -- accept a new connection and start a new coroutine to deal with it
local client = server:accept() local client = server:accept()
if client then if client then
-- start a new connection, non-blockingly, to the forwarding address -- start a new connection, non-blockingly, to the forwarding address
local peer = nconnect(ohost, oport) local ohost = context[server].ohost
local oport = context[server].oport
local peer = nbkcon(ohost, oport)
if peer then if peer then
context[client] = { context[client] = {
last = socket.gettime(), last = socket.gettime(),
@ -90,7 +115,7 @@ print(server, "accepting a new client")
sending:insert(peer) sending:insert(peer)
context[peer] = { context[peer] = {
peer = client, peer = client,
thread = coroutine.create(check), thread = coroutine.create(chkcon),
last = socket.gettime() last = socket.gettime()
} }
-- put both in non-blocking mode -- put both in non-blocking mode
@ -109,14 +134,12 @@ end
-- forwards all data arriving to the appropriate peer -- forwards all data arriving to the appropriate peer
function forward(who) function forward(who)
while true do while true do
print(who, "getting data")
-- try to read as much as possible -- try to read as much as possible
local data, rec_err, partial = who:receive("*a") local data, rec_err, partial = who:receive("*a")
-- if we had an error other than timeout, abort -- if we had an error other than timeout, abort
if rec_err and rec_err ~= "timeout" then return error(rec_err) end if rec_err and rec_err ~= "timeout" then return error(rec_err) end
-- if we got a timeout, we probably have partial results to send -- if we got a timeout, we probably have partial results to send
data = data or partial data = data or partial
print(who, " got ", string.len(data))
-- renew our timestamp so scheduler sees we are active -- renew our timestamp so scheduler sees we are active
context[who].last = socket.gettime() context[who].last = socket.gettime()
-- forward what we got right away -- forward what we got right away
@ -126,7 +149,6 @@ print(who, " got ", string.len(data))
coroutine.yield("sending") coroutine.yield("sending")
local ret, snd_err local ret, snd_err
local start = 0 local start = 0
print(who, "sending data")
ret, snd_err, start = peer:send(data, start+1) ret, snd_err, start = peer:send(data, start+1)
if ret then break if ret then break
elseif snd_err ~= "timeout" then return error(snd_err) end elseif snd_err ~= "timeout" then return error(snd_err) end
@ -143,51 +165,22 @@ end
-- checks if a connection completed successfully and if it did, starts -- checks if a connection completed successfully and if it did, starts
-- forwarding all data -- forwarding all data
function check(who) function chkcon(who)
local ret, err = who:connected() local ret, err = who:connected()
if ret then if ret then
print(who, "connection completed")
receiving:insert(context[who].peer) receiving:insert(context[who].peer)
context[who].last = socket.gettime() context[who].last = socket.gettime()
print(who, "yielding until there is input data")
coroutine.yield("receiving") coroutine.yield("receiving")
return forward(who) return forward(who)
else return error(err) end else return error(err) end
end end
-- initializes the forward server
function init()
-- socket sets to test for events
-- create our server socket
server = assert(socket.bind(ihost, iport))
server:settimeout(0.1) -- we don't want to be killed by bad luck
-- we initially
receiving:insert(server)
context[server] = { thread = coroutine.create(accept) }
end
-- loop waiting until something happens, restarting the thread to deal with -- loop waiting until something happens, restarting the thread to deal with
-- what happened, and routing it to wait until something else happens -- what happened, and routing it to wait until something else happens
function go() function go()
while true do while true do
print("will select for readability")
for i,v in ipairs(receiving) do
print(i, v)
end
print("will select for writability")
for i,v in ipairs(sending) do
print(i, v)
end
-- check which sockets are interesting and act on them -- check which sockets are interesting and act on them
readable, writable = socket.select(receiving, sending, 3) readable, writable = socket.select(receiving, sending, 3)
print("returned as readable")
for i,v in ipairs(readable) do
print(i, v)
end
print("returned as writable")
for i,v in ipairs(writable) do
print(i, v)
end
-- for all readable connections, resume its thread and route it -- for all readable connections, resume its thread and route it
for _, who in ipairs(readable) do for _, who in ipairs(readable) do
receiving:remove(who) receiving:remove(who)
@ -207,7 +200,6 @@ end
local deathrow local deathrow
for who, data in pairs(context) do for who, data in pairs(context) do
if data.last then if data.last then
print("hung for" , now - data.last, who)
if now - data.last > TIMEOUT then if now - data.last > TIMEOUT then
-- only create table if someone is doomed -- only create table if someone is doomed
deathrow = deathrow or {} deathrow = deathrow or {}
@ -217,13 +209,10 @@ print("hung for" , now - data.last, who)
end end
-- finally kick everyone in deathrow -- finally kick everyone in deathrow
if deathrow then if deathrow then
print("in death row")
for i,v in pairs(deathrow) do
print(i, v)
end
for who in pairs(deathrow) do kick(who) end for who in pairs(deathrow) do kick(who) end
end end
end end
end end
go(init()) init()
go()

View File

@ -32,13 +32,26 @@ USERAGENT = socket.VERSION
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
local metat = { __index = {} } local metat = { __index = {} }
function open(host, port) -- default connect function, respecting the timeout
local c = socket.try(socket.tcp()) local function connect(host, port)
local c, e = socket.tcp()
if not c then return nil, e end
c:settimeout(TIMEOUT)
local r, e = c:connect(host, port or PORT)
if not r then
c:close()
return nil, e
end
return c
end
function open(host, port, user)
-- create socket with user connect function, or with default
local c = socket.try((user or connect)(host, port))
-- create our http request object, pointing to the socket
local h = base.setmetatable({ c = c }, metat) local h = base.setmetatable({ c = c }, metat)
-- make sure the connection gets closed on exception -- make sure the object close gets called on exception
h.try = socket.newtry(function() h:close() end) h.try = socket.newtry(function() h:close() end)
h.try(c:settimeout(TIMEOUT))
h.try(c:connect(host, port or PORT))
return h return h
end end
@ -215,13 +228,14 @@ function tredirect(reqt, headers)
sink = reqt.sink, sink = reqt.sink,
headers = reqt.headers, headers = reqt.headers,
proxy = reqt.proxy, proxy = reqt.proxy,
nredirects = (reqt.nredirects or 0) + 1 nredirects = (reqt.nredirects or 0) + 1,
connect = reqt.connect
} }
end end
function trequest(reqt) function trequest(reqt)
reqt = adjustrequest(reqt) reqt = adjustrequest(reqt)
local h = open(reqt.host, reqt.port) local h = open(reqt.host, reqt.port, reqt.connect)
h:sendrequestline(reqt.method, reqt.uri) h:sendrequestline(reqt.method, reqt.uri)
h:sendheaders(reqt.headers) h:sendheaders(reqt.headers)
h:sendbody(reqt.headers, reqt.source, reqt.step) h:sendbody(reqt.headers, reqt.source, reqt.step)

View File

@ -23,7 +23,7 @@ http.TIMEOUT = 10
local t = socket.gettime() local t = socket.gettime()
host = host or "diego.student.princeton.edu" host = host or "diego.student.princeton.edu"
proxy = proxy or "http://localhost:3128" proxy = proxy or "http://dell-diego:3128"
prefix = prefix or "/luasocket-test" prefix = prefix or "/luasocket-test"
cgiprefix = cgiprefix or "/luasocket-test-cgi" cgiprefix = cgiprefix or "/luasocket-test-cgi"
index_file = "test/index.html" index_file = "test/index.html"