diff --git a/samples/forward.lua b/samples/forward.lua index de651b4..c3f0605 100644 --- a/samples/forward.lua +++ b/samples/forward.lua @@ -1,11 +1,5 @@ -- load our favourite library local socket = require"socket" --- timeout before an inactive thread is kicked -local TIMEOUT = 10 --- local address to bind to -local ihost, iport = arg[1] or "localhost", arg[2] or 8080 --- address to forward all data to -local ohost, oport = arg[3] or "localhost", arg[4] or 3128 -- creates a new set data structure function newset() @@ -32,12 +26,44 @@ function newset() }}) end +-- timeout before an inactive thread is kicked +local TIMEOUT = 10 +-- set of connections waiting to receive data local receiving = newset() +-- set of sockets waiting to send data local sending = newset() +-- context for connections and servers local context = {} --- starts a non-blocking connect -function nconnect(host, port) +-- initializes the forward server +function init() + if table.getn(arg) < 1 then + print("Usage") + print(" lua forward.lua ...") + os.exit(1) + end + -- for each tunnel, start a new server socket + for i, v in ipairs(arg) do + -- capture forwarding parameters + local iport, ohost, oport = + socket.skip(2, string.find(v, "([^:]+):([^:]+):([^:]+)")) + assert(iport, "invalid arguments") + -- create our server socket + local server = assert(socket.bind("*", iport)) + server:settimeout(0.1) -- we don't want to be killed by bad luck + -- make sure server is tested for readability + receiving:insert(server) + -- add server context + context[server] = { + thread = coroutine.create(accept), + ohost = ohost, + oport = oport + } + end +end + +-- starts a connection in a non-blocking way +function nbkcon(host, port) local peer, err = socket.tcp() if not peer then return nil, err end peer:settimeout(0) @@ -52,7 +78,6 @@ end -- gets rid of a client function kick(who) -if who == server then error("FUDEU") end if context[who] then sending:remove(who) receiving:remove(who) @@ -63,7 +88,6 @@ end -- decides what to do with a thread based on coroutine return function route(who, status, what) -print(who, status, what) if status and what then if what == "receiving" then receiving:insert(who) end if what == "sending" then sending:insert(who) end @@ -73,12 +97,13 @@ end -- loops accepting connections and creating new threads to deal with them function accept(server) while true do -print(server, "accepting a new client") -- accept a new connection and start a new coroutine to deal with it local client = server:accept() if client then -- start a new connection, non-blockingly, to the forwarding address - local peer = nconnect(ohost, oport) + local ohost = context[server].ohost + local oport = context[server].oport + local peer = nbkcon(ohost, oport) if peer then context[client] = { last = socket.gettime(), @@ -90,7 +115,7 @@ print(server, "accepting a new client") sending:insert(peer) context[peer] = { peer = client, - thread = coroutine.create(check), + thread = coroutine.create(chkcon), last = socket.gettime() } -- put both in non-blocking mode @@ -109,14 +134,12 @@ end -- forwards all data arriving to the appropriate peer function forward(who) while true do -print(who, "getting data") -- try to read as much as possible local data, rec_err, partial = who:receive("*a") -- if we had an error other than timeout, abort if rec_err and rec_err ~= "timeout" then return error(rec_err) end -- if we got a timeout, we probably have partial results to send data = data or partial -print(who, " got ", string.len(data)) -- renew our timestamp so scheduler sees we are active context[who].last = socket.gettime() -- forward what we got right away @@ -126,7 +149,6 @@ print(who, " got ", string.len(data)) coroutine.yield("sending") local ret, snd_err local start = 0 -print(who, "sending data") ret, snd_err, start = peer:send(data, start+1) if ret then break elseif snd_err ~= "timeout" then return error(snd_err) end @@ -143,51 +165,22 @@ end -- checks if a connection completed successfully and if it did, starts -- forwarding all data -function check(who) +function chkcon(who) local ret, err = who:connected() if ret then -print(who, "connection completed") receiving:insert(context[who].peer) context[who].last = socket.gettime() -print(who, "yielding until there is input data") coroutine.yield("receiving") return forward(who) else return error(err) end end --- initializes the forward server -function init() - -- socket sets to test for events - -- create our server socket - server = assert(socket.bind(ihost, iport)) - server:settimeout(0.1) -- we don't want to be killed by bad luck - -- we initially - receiving:insert(server) - context[server] = { thread = coroutine.create(accept) } -end - -- loop waiting until something happens, restarting the thread to deal with -- what happened, and routing it to wait until something else happens function go() while true do -print("will select for readability") -for i,v in ipairs(receiving) do - print(i, v) -end -print("will select for writability") -for i,v in ipairs(sending) do - print(i, v) -end -- check which sockets are interesting and act on them readable, writable = socket.select(receiving, sending, 3) -print("returned as readable") -for i,v in ipairs(readable) do - print(i, v) -end -print("returned as writable") -for i,v in ipairs(writable) do - print(i, v) -end -- for all readable connections, resume its thread and route it for _, who in ipairs(readable) do receiving:remove(who) @@ -207,7 +200,6 @@ end local deathrow for who, data in pairs(context) do if data.last then -print("hung for" , now - data.last, who) if now - data.last > TIMEOUT then -- only create table if someone is doomed deathrow = deathrow or {} @@ -217,13 +209,10 @@ print("hung for" , now - data.last, who) end -- finally kick everyone in deathrow if deathrow then -print("in death row") -for i,v in pairs(deathrow) do - print(i, v) -end for who in pairs(deathrow) do kick(who) end end end end -go(init()) +init() +go() diff --git a/src/http.lua b/src/http.lua index 1dff11a..38b93e2 100644 --- a/src/http.lua +++ b/src/http.lua @@ -32,13 +32,26 @@ USERAGENT = socket.VERSION ----------------------------------------------------------------------------- local metat = { __index = {} } -function open(host, port) - local c = socket.try(socket.tcp()) +-- default connect function, respecting the timeout +local function connect(host, port) + local c, e = socket.tcp() + if not c then return nil, e end + c:settimeout(TIMEOUT) + local r, e = c:connect(host, port or PORT) + if not r then + c:close() + return nil, e + end + return c +end + +function open(host, port, user) + -- create socket with user connect function, or with default + local c = socket.try((user or connect)(host, port)) + -- create our http request object, pointing to the socket local h = base.setmetatable({ c = c }, metat) - -- make sure the connection gets closed on exception + -- make sure the object close gets called on exception h.try = socket.newtry(function() h:close() end) - h.try(c:settimeout(TIMEOUT)) - h.try(c:connect(host, port or PORT)) return h end @@ -215,13 +228,14 @@ function tredirect(reqt, headers) sink = reqt.sink, headers = reqt.headers, proxy = reqt.proxy, - nredirects = (reqt.nredirects or 0) + 1 + nredirects = (reqt.nredirects or 0) + 1, + connect = reqt.connect } end function trequest(reqt) reqt = adjustrequest(reqt) - local h = open(reqt.host, reqt.port) + local h = open(reqt.host, reqt.port, reqt.connect) h:sendrequestline(reqt.method, reqt.uri) h:sendheaders(reqt.headers) h:sendbody(reqt.headers, reqt.source, reqt.step) diff --git a/test/httptest.lua b/test/httptest.lua index 2335fcb..8862ceb 100644 --- a/test/httptest.lua +++ b/test/httptest.lua @@ -23,7 +23,7 @@ http.TIMEOUT = 10 local t = socket.gettime() host = host or "diego.student.princeton.edu" -proxy = proxy or "http://localhost:3128" +proxy = proxy or "http://dell-diego:3128" prefix = prefix or "/luasocket-test" cgiprefix = cgiprefix or "/luasocket-test-cgi" index_file = "test/index.html"