diff --git a/TODO b/TODO index 7dadfd9..dfb9178 100644 --- a/TODO +++ b/TODO @@ -1,14 +1,12 @@ - +use wim's filter.chain or something better make sure standard libraries are "required" by modules before use. eliminate globals from namespaces created by module(). ftp.send/recv return bytes transfered? new scheme to choose family/protocol of object to create change ltn13 to make sure drawbacks are obvious - check discussion -make sure errors not thrown by try() are not caught by protect() use mike's "don't set to blocking before closing unless needed" patch? take a look at DB's smtp patch (add "extra argument" table) -move wsocket.c:sock_send kludge to buffer.c:sendraw (probably)? optmize aux_getgroupudata (Mike idea) make aux_newclass receive upvalues @@ -25,7 +23,10 @@ testar os options! - proteger ou atomizar o conjunto (timedout, receive), (timedout, send) - inet_ntoa também é uma merda. -*use wim's filter.chain or something better +*close wasn't returning 1 +*make sure errors not thrown by try() are not caught by protect() +*move wsocket.c:sock_send kludge to buffer.c:sendraw? +*bug on UDP sendto. *fix PROXY in http.lua *use new distribution scheme *create the getstats method. diff --git a/doc/ftp.html b/doc/ftp.html index 158b402..0b52007 100644 --- a/doc/ftp.html +++ b/doc/ftp.html @@ -163,7 +163,7 @@ local ltn12 = require("ltn12") local url = require("url") -- a function that returns a directory listing -function ls(u) +function nlst(u) local t = {} local p = url.parse(u) p.command = "nlst" diff --git a/etc/dict.lua b/etc/dict.lua index 8c197f5..716b8db 100644 --- a/etc/dict.lua +++ b/etc/dict.lua @@ -8,11 +8,14 @@ ----------------------------------------------------------------------------- -- Load required modules ----------------------------------------------------------------------------- +local base = require("base") +local string = require("string") +local table = require("table") local socket = require("socket") local url = require("socket.url") local tp = require("socket.tp") -module("socket.dict") +local dict = module("socket.dict") ----------------------------------------------------------------------------- -- Globals @@ -28,7 +31,7 @@ local metat = { __index = {} } function open(host, port) local tp = socket.try(tp.connect(host or HOST, port or PORT, TIMEOUT)) - return setmetatable({tp = tp}, metat) + return base.setmetatable({tp = tp}, metat) end function metat.__index:greet() @@ -37,7 +40,8 @@ end function metat.__index:check(ok) local code, status = socket.try(self.tp:check(ok)) - return code, tonumber(socket.skip(2, string.find(status, "^%d%d%d (%d*)"))) + return code, + base.tonumber(socket.skip(2, string.find(status, "^%d%d%d (%d*)"))) end function metat.__index:getdef() @@ -116,7 +120,7 @@ local function parse(u) if cmd == "m" then arg = string.gsub(arg, "^:([^:]*)", function(f) t.strat = there(f) end) end - string.gsub(arg, ":([^:]*)$", function(f) t.n = tonumber(f) end) + string.gsub(arg, ":([^:]*)$", function(f) t.n = base.tonumber(f) end) return t end @@ -143,6 +147,8 @@ local function sget(u) end get = socket.protect(function(gett) - if type(gett) == "string" then return sget(gett) + if base.type(gett) == "string" then return sget(gett) else return tget(gett) end end) + +base.setmetatable(dict, nil) diff --git a/etc/lp.lua b/etc/lp.lua index b69cc02..a5327d1 100644 --- a/etc/lp.lua +++ b/etc/lp.lua @@ -9,9 +9,12 @@ if you have any questions: RFC 1179 ]] -- make sure LuaSocket is loaded +local io = require("io") +local base = require("base") +local string = require("string") local socket = require("socket") local ltn12 = require("ltn12") -local test = socket.try +local lp = module("socket.lp") -- default port PORT = 515 @@ -28,7 +31,7 @@ local function connect(localhost, option) local localport = 721 local done, err repeat - skt = test(socket.tcp()) + skt = socket.try(socket.tcp()) try(skt:settimeout(30)) done, err = skt:bind(localhost, localport) if not done then @@ -37,8 +40,8 @@ local function connect(localhost, option) skt = nil else break end until localport > 731 - test(skt, err) - else skt = test(socket.tcp()) end + socket.try(skt, err) + else skt = socket.try(socket.tcp()) end try(skt:connect(host, port)) return { skt = skt, try = try } end @@ -241,9 +244,9 @@ local format_codes = { -- lp.send send = socket.protect(function(file, option) - test(file, "invalid file name") - test(option and type(option) == "table", "invalid options") - local fh = test(io.open(file,"rb")) + socket.try(file, "invalid file name") + socket.try(option and base.type(option) == "table", "invalid options") + local fh = socket.try(io.open(file,"rb")) local datafile_size = fh:seek("end") -- get total size fh:seek("set") -- go back to start of file local localhost = socket.dns.gethostname() or os.getenv("COMPUTERNAME") @@ -270,11 +273,11 @@ send = socket.protect(function(file, option) lpfile, ctlfn); -- mandatory part of ctl file if (option.banner) then cfile = cfile .. 'L'..user..'\10' end - if (option.indent) then cfile = cfile .. 'I'..tonumber(option.indent)..'\10' end + if (option.indent) then cfile = cfile .. 'I'..base.tonumber(option.indent)..'\10' end if (option.mail) then cfile = cfile .. 'M'..string.sub((option.mail),1,128)..'\10' end if (fmt == 'p' and option.title) then cfile = cfile .. 'T'..string.sub((option.title),1,79)..'\10' end if ((fmt == 'p' or fmt == 'l' or fmt == 'f') and option.width) then - cfile = cfile .. 'W'..tonumber(option,width)..'\10' + cfile = cfile .. 'W'..base.tonumber(option,width)..'\10' end con.skt:settimeout(option.timeout or 65) @@ -314,3 +317,5 @@ query = socket.protect(function(p) con.skt:close() return data end) + +base.setmetatable(lp, nil) diff --git a/etc/tftp.lua b/etc/tftp.lua index f4af8bc..83a08b9 100644 --- a/etc/tftp.lua +++ b/etc/tftp.lua @@ -8,11 +8,14 @@ ----------------------------------------------------------------------------- -- Load required files ----------------------------------------------------------------------------- +local base = require("base") +local table = require("table") +local math = require("math") +local string = require("string") local socket = require("socket") local ltn12 = require("ltn12") local url = require("socket.url") - -module("socket.tftp") +local tftp = module("socket.tftp") ----------------------------------------------------------------------------- -- Program constants @@ -73,16 +76,18 @@ end local function tget(gett) local retries, dgram, sent, datahost, dataport, code local last = 0 + socket.try(gett.host, "missing host") local con = socket.try(socket.udp()) local try = socket.newtry(function() con:close() end) -- convert from name to ip if needed gett.host = try(socket.dns.toip(gett.host)) con:settimeout(1) -- first packet gives data host/port to be used for data transfers + local path = string.gsub(gett.path or "", "^/", "") + path = url.unescape(path) retries = 0 repeat - sent = try(con:sendto(RRQ(gett.path, "octet"), - gett.host, gett.port)) + sent = try(con:sendto(RRQ(path, "octet"), gett.host, gett.port)) dgram, datahost, dataport = con:receivefrom() retries = retries + 1 until dgram or datahost ~= "timeout" or retries > 5 @@ -144,6 +149,8 @@ local function sget(u) end get = socket.protect(function(gett) - if type(gett) == "string" then return sget(gett) + if base.type(gett) == "string" then return sget(gett) else return tget(gett) end end) + +base.setmetatable(tftp, nil) diff --git a/luasocket.vcproj b/luasocket.vcproj index 6e2da8a..4699498 100644 --- a/luasocket.vcproj +++ b/luasocket.vcproj @@ -12,17 +12,18 @@ + RelativePath="..\..\lib\liblua.lib"> + RelativePath="..\..\lib\liblualib.lib"> diff --git a/mime.vcproj b/mime.vcproj index 51ce05e..9fe7aa8 100644 --- a/mime.vcproj +++ b/mime.vcproj @@ -12,17 +12,18 @@ + RelativePath="..\..\lib\liblua.lib"> + RelativePath="..\..\lib\liblualib.lib"> diff --git a/src/buffer.c b/src/buffer.c index dbd5d2c..1b1b791 100644 --- a/src/buffer.c +++ b/src/buffer.c @@ -158,6 +158,7 @@ int buf_isempty(p_buf buf) { /*-------------------------------------------------------------------------*\ * Sends a block of data (unbuffered) \*-------------------------------------------------------------------------*/ +#define STEPSIZE 8192 static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent) { p_io io = buf->io; p_tm tm = buf->tm; @@ -165,7 +166,8 @@ static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent) { int err = IO_DONE; while (total < count && err == IO_DONE) { size_t done; - err = io->send(io->ctx, data+total, count-total, &done, tm); + size_t step = (count-total <= STEPSIZE)? count-total: STEPSIZE; + err = io->send(io->ctx, data+total, step, &done, tm); total += done; } *sent = total; diff --git a/src/except.c b/src/except.c index 80d7e5d..dabaf19 100644 --- a/src/except.c +++ b/src/except.c @@ -29,11 +29,21 @@ static luaL_reg func[] = { /*-------------------------------------------------------------------------*\ * Try factory \*-------------------------------------------------------------------------*/ +static void wrap(lua_State *L) { + lua_newtable(L); + lua_pushnumber(L, 1); + lua_pushvalue(L, -3); + lua_settable(L, -3); + lua_insert(L, -2); + lua_pop(L, 1); +} + static int finalize(lua_State *L) { if (!lua_toboolean(L, 1)) { lua_pushvalue(L, lua_upvalueindex(1)); lua_pcall(L, 0, 0, 0); lua_settop(L, 2); + wrap(L); lua_error(L); return 0; } else return lua_gettop(L); @@ -54,13 +64,23 @@ static int global_newtry(lua_State *L) { /*-------------------------------------------------------------------------*\ * Protect factory \*-------------------------------------------------------------------------*/ +static int unwrap(lua_State *L) { + if (lua_istable(L, -1)) { + lua_pushnumber(L, 1); + lua_gettable(L, -2); + lua_pushnil(L); + lua_insert(L, -2); + return 1; + } else return 0; +} + static int protected_(lua_State *L) { lua_pushvalue(L, lua_upvalueindex(1)); lua_insert(L, 1); if (lua_pcall(L, lua_gettop(L) - 1, LUA_MULTRET, 0) != 0) { - lua_pushnil(L); - lua_insert(L, 1); - return 2; + if (unwrap(L)) return 2; + else lua_error(L); + return 0; } else return lua_gettop(L); } diff --git a/src/ftp.lua b/src/ftp.lua index 9902c88..4529acd 100644 --- a/src/ftp.lua +++ b/src/ftp.lua @@ -8,13 +8,15 @@ ----------------------------------------------------------------------------- -- Declare module and import dependencies ----------------------------------------------------------------------------- +local base = require("base") +local table = require("table") +local string = require("string") +local math = require("math") local socket = require("socket") local url = require("socket.url") local tp = require("socket.tp") - local ltn12 = require("ltn12") - -module("socket.ftp") +local ftp = module("socket.ftp") ----------------------------------------------------------------------------- -- Program constants @@ -35,7 +37,7 @@ local metat = { __index = {} } function open(server, port) local tp = socket.try(tp.connect(server, port or PORT, TIMEOUT)) - local f = setmetatable({ tp = tp }, metat) + local f = base.setmetatable({ tp = tp }, metat) -- make sure everything gets closed in an exception f.try = socket.newtry(function() f:close() end) return f @@ -102,7 +104,8 @@ function metat.__index:send(sendt) -- we just get the data connection into self.data if self.pasvt then self:pasvconnect() end -- get the transfer argument and command - local argument = sendt.argument or string.gsub(sendt.path, "^/", "") + local argument = sendt.argument or + url.unescape(string.gsub(sendt.path or "", "^/", "")) if argument == "" then argument = nil end local command = sendt.command or "stor" -- send the transfer command and check the reply @@ -134,7 +137,8 @@ end function metat.__index:receive(recvt) self.try(self.pasvt or self.server, "need port or pasv first") if self.pasvt then self:pasvconnect() end - local argument = recvt.argument or string.gsub(recvt.path, "^/", "") + local argument = recvt.argument or + url.unescape(string.gsub(recvt.path or "", "^/", "")) if argument == "" then argument = nil end local command = recvt.command or "retr" self.try(self.tp:command(command, argument)) @@ -182,7 +186,19 @@ end ----------------------------------------------------------------------------- -- High level FTP API ----------------------------------------------------------------------------- +function override(t) + if t.url then + u = url.parse(t.url) + for i,v in base.pairs(t) do + u[i] = v + end + return u + else return t end +end + local function tput(putt) + putt = override(putt) + socket.try(putt.host, "missing hostname") local f = open(putt.host, putt.port) f:greet() f:login(putt.user, putt.password) @@ -201,8 +217,8 @@ local default = { local function parse(u) local t = socket.try(url.parse(u, default)) - socket.try(t.scheme == "ftp", "invalid scheme '" .. t.scheme .. "'") - socket.try(t.host, "invalid host") + socket.try(t.scheme == "ftp", "wrong scheme '" .. t.scheme .. "'") + socket.try(t.host, "missing hostname") local pat = "^type=(.)$" if t.params then t.type = socket.skip(2, string.find(t.params, pat)) @@ -219,11 +235,13 @@ local function sput(u, body) end put = socket.protect(function(putt, body) - if type(putt) == "string" then return sput(putt, body) + if base.type(putt) == "string" then return sput(putt, body) else return tput(putt) end end) local function tget(gett) + gett = override(gett) + socket.try(gett.host, "missing hostname") local f = open(gett.host, gett.port) f:greet() f:login(gett.user, gett.password) @@ -242,7 +260,22 @@ local function sget(u) return table.concat(t) end +command = socket.protect(function(cmdt) + cmdt = override(cmdt) + socket.try(cmdt.host, "missing hostname") + socket.try(cmdt.command, "missing command") + local f = open(cmdt.host, cmdt.port) + f:greet() + f:login(cmdt.user, cmdt.password) + f.try(f.tp:command(cmdt.command, cmdt.argument)) + if cmdt.check then f.try(f.tp:check(cmdt.check)) end + f:quit() + return f:close() +end) + get = socket.protect(function(gett) - if type(gett) == "string" then return sget(gett) + if base.type(gett) == "string" then return sget(gett) else return tget(gett) end end) + +base.setmetatable(ftp, nil) diff --git a/src/http.lua b/src/http.lua index b265650..a15ea69 100644 --- a/src/http.lua +++ b/src/http.lua @@ -12,8 +12,10 @@ local socket = require("socket") local url = require("socket.url") local ltn12 = require("ltn12") local mime = require("mime") - -module("socket.http") +local string = require("string") +local base = require("base") +local table = require("table") +local http = module("socket.http") ----------------------------------------------------------------------------- -- Program constants @@ -32,7 +34,7 @@ local metat = { __index = {} } function open(host, port) local c = socket.try(socket.tcp()) - local h = setmetatable({ c = c }, metat) + local h = base.setmetatable({ c = c }, metat) -- make sure the connection gets closed on exception h.try = socket.newtry(function() h:close() end) h.try(c:settimeout(TIMEOUT)) @@ -46,7 +48,7 @@ function metat.__index:sendrequestline(method, uri) end function metat.__index:sendheaders(headers) - for i, v in pairs(headers) do + for i, v in base.pairs(headers) do self.try(self.c:send(i .. ": " .. v .. "\r\n")) end -- mark end of request headers @@ -66,7 +68,7 @@ end function metat.__index:receivestatusline() local status = self.try(self.c:receive()) local code = socket.skip(2, string.find(status, "HTTP/%d*%.%d* (%d%d%d)")) - return self.try(tonumber(code), status) + return self.try(base.tonumber(code), status) end function metat.__index:receiveheaders() @@ -97,11 +99,11 @@ end function metat.__index:receivebody(headers, sink, step) sink = sink or ltn12.sink.null() step = step or ltn12.pump.step - local length = tonumber(headers["content-length"]) + local length = base.tonumber(headers["content-length"]) local TE = headers["transfer-encoding"] local mode = "default" -- connection close if TE and TE ~= "identity" then mode = "http-chunked" - elseif tonumber(headers["content-length"]) then mode = "by-length" end + elseif base.tonumber(headers["content-length"]) then mode = "by-length" end return self.try(ltn12.pump.all(socket.source(mode, self.c, length), sink, step)) end @@ -159,9 +161,10 @@ local default = { local function adjustrequest(reqt) -- parse url if provided local nreqt = reqt.url and url.parse(reqt.url, default) or {} + local t = url.parse(reqt.url, default) -- explicit components override url for i,v in reqt do nreqt[i] = reqt[i] end - socket.try(nreqt.host, "invalid host '" .. tostring(nreqt.host) .. "'") + socket.try(nreqt.host, "invalid host '" .. base.tostring(nreqt.host) .. "'") -- compute uri if user hasn't overriden nreqt.uri = reqt.uri or adjusturi(nreqt) -- ajust host and port if there is a proxy @@ -253,6 +256,8 @@ local function srequest(u, body) end request = socket.protect(function(reqt, body) - if type(reqt) == "string" then return srequest(reqt, body) + if base.type(reqt) == "string" then return srequest(reqt, body) else return trequest(reqt) end end) + +base.setmetatable(http, nil) diff --git a/src/ltn12.lua b/src/ltn12.lua index ed39ec8..43c2755 100644 --- a/src/ltn12.lua +++ b/src/ltn12.lua @@ -8,7 +8,11 @@ ----------------------------------------------------------------------------- -- Declare module ----------------------------------------------------------------------------- -module("ltn12") +local string = require("string") +local table = require("table") +local base = require("base") +local coroutine = require("coroutine") +local ltn12 = module("ltn12") filter = {} source = {} @@ -23,7 +27,7 @@ BLOCKSIZE = 2048 ----------------------------------------------------------------------------- -- returns a high level filter that cycles a low-level filter function filter.cycle(low, ctx, extra) - assert(low) + base.assert(low) return function(chunk) local ret ret, ctx = low(ctx, chunk, extra) @@ -121,7 +125,7 @@ end -- turns a fancy source into a simple source function source.simplify(src) - assert(src) + base.assert(src) return function() local chunk, err_or_new = src() src = err_or_new or src @@ -145,7 +149,7 @@ end -- creates rewindable source function source.rewind(src) - assert(src) + base.assert(src) local t = {} return function(chunk) if not chunk then @@ -160,7 +164,7 @@ end -- chains a source with a filter function source.chain(src, f) - assert(src and f) + base.assert(src and f) local co = coroutine.create(function() while true do local chunk, err = src() @@ -215,7 +219,7 @@ end -- turns a fancy sink into a simple sink function sink.simplify(snk) - assert(snk) + base.assert(snk) return function(chunk, err) local ret, err_or_new = snk(chunk, err) if not ret then return nil, err_or_new end @@ -254,7 +258,7 @@ end -- chains a sink with a filter function sink.chain(f, snk) - assert(f and snk) + base.assert(f and snk) return function(chunk, err) local filtered = f(chunk) local done = chunk and "" @@ -279,10 +283,12 @@ end -- pumps all data from a source to a sink, using a step function function pump.all(src, snk, step) - assert(src and snk) + base.assert(src and snk) step = step or pump.step while true do local ret, err = step(src, snk) if not ret then return not err, err end end end + +base.setmetatable(ltn12, nil) diff --git a/src/mime.lua b/src/mime.lua index 3dbcf79..712600c 100644 --- a/src/mime.lua +++ b/src/mime.lua @@ -8,9 +8,10 @@ ----------------------------------------------------------------------------- -- Declare module and import dependencies ----------------------------------------------------------------------------- -module("mime") -local mime = require("lmime") +local base = require("base") local ltn12 = require("ltn12") +local mime = require("lmime") +module("mime") -- encode, decode and wrap algorithm tables mime.encodet = {} @@ -20,11 +21,11 @@ mime.wrapt = {} -- creates a function that chooses a filter by name from a given table local function choose(table) return function(name, opt1, opt2) - if type(name) ~= "string" then + if base.type(name) ~= "string" then name, opt1, opt2 = "default", name, opt1 end local f = table[name or "nil"] - if not f then error("unknown key (" .. tostring(name) .. ")", 3) + if not f then error("unknown key (" .. base.tostring(name) .. ")", 3) else return f(opt1, opt2) end end end @@ -74,3 +75,5 @@ end function mime.stuff() return ltn12.filter.cycle(dot, 2) end + +base.setmetatable(mime, nil) diff --git a/src/smtp.lua b/src/smtp.lua index 974d222..9d49178 100644 --- a/src/smtp.lua +++ b/src/smtp.lua @@ -8,13 +8,16 @@ ----------------------------------------------------------------------------- -- Declare module and import dependencies ----------------------------------------------------------------------------- +local base = require("base") +local coroutine = require("coroutine") +local string = require("string") +local math = require("math") +local os = require("os") local socket = require("socket") local tp = require("socket.tp") - local ltn12 = require("ltn12") local mime = require("mime") - -module("socket.smtp") +local smtp = module("socket.smtp") ----------------------------------------------------------------------------- -- Program constants @@ -98,8 +101,8 @@ end -- send message or throw an exception function metat.__index:send(mailt) self:mail(mailt.from) - if type(mailt.rcpt) == "table" then - for i,v in ipairs(mailt.rcpt) do + if base.type(mailt.rcpt) == "table" then + for i,v in base.ipairs(mailt.rcpt) do self:rcpt(v) end else @@ -110,7 +113,7 @@ end function open(server, port) local tp = socket.try(tp.connect(server or SERVER, port or PORT, TIMEOUT)) - local s = setmetatable({tp = tp}, metat) + local s = base.setmetatable({tp = tp}, metat) -- make sure tp is closed if we get an exception s.try = socket.newtry(function() if s.tp:command("QUIT") then s.tp:check("2..") end @@ -145,7 +148,7 @@ local function send_multipart(mesgt) coroutine.yield("\r\n") end -- send each part separated by a boundary - for i, m in ipairs(mesgt.body) do + for i, m in base.ipairs(mesgt.body) do coroutine.yield("\r\n--" .. bd .. "\r\n") send_message(m) end @@ -191,7 +194,7 @@ end -- yield the headers one by one local function send_headers(mesgt) if mesgt.headers then - for i,v in pairs(mesgt.headers) do + for i,v in base.pairs(mesgt.headers) do coroutine.yield(i .. ':' .. v .. "\r\n") end end @@ -200,8 +203,8 @@ end -- message source function send_message(mesgt) send_headers(mesgt) - if type(mesgt.body) == "table" then send_multipart(mesgt) - elseif type(mesgt.body) == "function" then send_source(mesgt) + if base.type(mesgt.body) == "table" then send_multipart(mesgt) + elseif base.type(mesgt.body) == "function" then send_source(mesgt) else send_string(mesgt) end end @@ -241,3 +244,5 @@ send = socket.protect(function(mailt) s:quit() return s:close() end) + +base.setmetatable(smtp, nil) diff --git a/src/tcp.c b/src/tcp.c index 746c4b6..618f4ce 100644 --- a/src/tcp.c +++ b/src/tcp.c @@ -233,7 +233,8 @@ static int meth_close(lua_State *L) { p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{any}", 1); sock_destroy(&tcp->sock); - return 0; + lua_pushnumber(L, 1); + return 1; } /*-------------------------------------------------------------------------*\ diff --git a/src/tp.lua b/src/tp.lua index ada00d2..0a671fb 100644 --- a/src/tp.lua +++ b/src/tp.lua @@ -8,10 +8,12 @@ ----------------------------------------------------------------------------- -- Declare module and import dependencies ----------------------------------------------------------------------------- +local base = require("base") +local string = require("string") local socket = require("socket") local ltn12 = require("ltn12") -module("socket.tp") +local tp = module("socket.tp") ----------------------------------------------------------------------------- -- Program constants @@ -47,22 +49,27 @@ local metat = { __index = {} } function metat.__index:check(ok) local code, reply = get_reply(self.c) if not code then return nil, reply end - if type(ok) ~= "function" then - if type(ok) == "table" then - for i, v in ipairs(ok) do - if string.find(code, v) then return tonumber(code), reply end + if base.type(ok) ~= "function" then + if base.type(ok) == "table" then + for i, v in base.ipairs(ok) do + if string.find(code, v) then + return base.tonumber(code), reply + end end return nil, reply else - if string.find(code, ok) then return tonumber(code), reply + if string.find(code, ok) then return base.tonumber(code), reply else return nil, reply end end - else return ok(tonumber(code), reply) end + else return ok(base.tonumber(code), reply) end end function metat.__index:command(cmd, arg) - if arg then return self.c:send(cmd .. " " .. arg.. "\r\n") - else return self.c:send(cmd .. "\r\n") end + if arg then + return self.c:send(cmd .. " " .. arg.. "\r\n") + else + return self.c:send(cmd .. "\r\n") + end end function metat.__index:sink(snk, pat) @@ -111,5 +118,7 @@ function connect(host, port, timeout) c:close() return nil, e end - return setmetatable({c = c}, metat) + return base.setmetatable({c = c}, metat) end + +base.setmetatable(tp, nil) diff --git a/src/udp.c b/src/udp.c index 97a6169..7a60080 100644 --- a/src/udp.c +++ b/src/udp.c @@ -288,7 +288,8 @@ static int meth_setpeername(lua_State *L) { static int meth_close(lua_State *L) { p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1); sock_destroy(&udp->sock); - return 0; + lua_pushnumber(L, 1); + return 1; } /*-------------------------------------------------------------------------*\ diff --git a/src/url.lua b/src/url.lua index efe7254..08081f0 100644 --- a/src/url.lua +++ b/src/url.lua @@ -8,7 +8,10 @@ ----------------------------------------------------------------------------- -- Declare module ----------------------------------------------------------------------------- -module("socket.url") +local string = require("string") +local base = require("base") +local table = require("table") +local url = module("socket.url") ----------------------------------------------------------------------------- -- Encodes a string into its escaped hexadecimal representation @@ -18,7 +21,7 @@ module("socket.url") -- escaped representation of string binary ----------------------------------------------------------------------------- function escape(s) - return string.gsub(s, "(.)", function(c) + return string.gsub(s, "([^A-Za-z0-9_])", function(c) return string.format("%%%02x", string.byte(c)) end) end @@ -33,7 +36,7 @@ end ----------------------------------------------------------------------------- local function make_set(t) local s = {} - for i = 1, table.getn(t) do + for i,v in base.ipairs(t) do s[t[i]] = 1 end return s @@ -62,7 +65,7 @@ end ----------------------------------------------------------------------------- function unescape(s) return string.gsub(s, "%%(%x%x)", function(hex) - return string.char(tonumber(hex, 16)) + return string.char(base.tonumber(hex, 16)) end) end @@ -191,7 +194,7 @@ end -- corresponding absolute url ----------------------------------------------------------------------------- function absolute(base_url, relative_url) - local base = type(base_url) == "table" and base_url or parse(base_url) + local base = base.type(base_url) == "table" and base_url or parse(base_url) local relative = parse(relative_url) if not base then return relative_url elseif not relative then return base_url @@ -269,3 +272,5 @@ function build_path(parsed, unsafe) if parsed.is_absolute then path = "/" .. path end return path end + +base.setmetatable(url, nil) diff --git a/src/wsocket.c b/src/wsocket.c index 1b169ed..0294dce 100644 --- a/src/wsocket.c +++ b/src/wsocket.c @@ -180,9 +180,10 @@ int sock_accept(p_sock ps, p_sock pa, SA *addr, socklen_t *len, p_tm tm) { /*-------------------------------------------------------------------------*\ * Send with timeout +* On windows, if you try to send 10MB, the OS will buffer EVERYTHING +* this can take an awful lot of time and we will end up blocked. +* Therefore, whoever calls this function should not pass a huge buffer. \*-------------------------------------------------------------------------*/ -/* has to be larger than UDP_DATAGRAMSIZE !!!*/ -#define MAXCHUNK (64*1024) int sock_send(p_sock ps, const char *data, size_t count, size_t *sent, p_tm tm) { int err; @@ -192,9 +193,7 @@ int sock_send(p_sock ps, const char *data, size_t count, size_t *sent, p_tm tm) *sent = 0; for ( ;; ) { /* try to send something */ - /* on windows, if you try to send 10MB, the OS will buffer EVERYTHING - * this can take an awful lot of time and we will end up blocked. */ - int put = send(*ps, data, (count < MAXCHUNK)? (int)count: MAXCHUNK, 0); + int put = send(*ps, data, count, 0); /* if we sent something, we are done */ if (put > 0) { *sent = put; @@ -221,7 +220,7 @@ int sock_sendto(p_sock ps, const char *data, size_t count, size_t *sent, if (*ps == SOCK_INVALID) return IO_CLOSED; *sent = 0; for ( ;; ) { - int put = send(*ps, data, (int) count, 0); + int put = sendto(*ps, data, (int) count, 0, addr, len); if (put > 0) { *sent = put; return IO_DONE; @@ -298,13 +297,13 @@ void sock_setnonblocking(p_sock ps) { int sock_gethostbyaddr(const char *addr, socklen_t len, struct hostent **hp) { *hp = gethostbyaddr(addr, len, AF_INET); if (*hp) return IO_DONE; - else return h_errno; + else return WSAGetLastError(); } int sock_gethostbyname(const char *addr, struct hostent **hp) { *hp = gethostbyname(addr); if (*hp) return IO_DONE; - else return h_errno; + else return WSAGetLastError(); } /*-------------------------------------------------------------------------*\ diff --git a/test/dicttest.lua b/test/dicttest.lua index fcdd61f..a37ec8d 100644 --- a/test/dicttest.lua +++ b/test/dicttest.lua @@ -1,3 +1,3 @@ local dict = require"socket.dict" -for i,v in dict.get("dict://localhost/d:banana") do print(v) end +for i,v in dict.get("dict://dell-diego/d:banana") do print(v) end diff --git a/test/ftptest.lua b/test/ftptest.lua index f578c82..ef1bf0f 100644 --- a/test/ftptest.lua +++ b/test/ftptest.lua @@ -1,103 +1,110 @@ -local similar = function(s1, s2) - return - string.lower(string.gsub(s1, "%s", "")) == - string.lower(string.gsub(s2, "%s", "")) +local socket = require("socket") +local ftp = require("socket.ftp") +local url = require("socket.url") +local ltn12 = require("ltn12") + +-- override protection to make sure we see all errors +--socket.protect = function(s) return s end + +dofile("testsupport.lua") + +local host, port, index_file, index, back, err, ret + +local t = socket.gettime() + +host = host or "diego.student.princeton.edu" +index_file = "test/index.html" + + +-- a function that returns a directory listing +local function nlst(u) + local t = {} + local p = url.parse(u) + p.command = "nlst" + p.sink = ltn12.sink.table(t) + local r, e = ftp.get(p) + return r and table.concat(t), e end -local readfile = function(name) - local f = io.open(name, "r") - if not f then return nil end - local s = f:read("*a") - f:close() - return s +-- function that removes a remote file +local function dele(u) + local p = url.parse(u) + p.command = "dele" + p.argument = string.gsub(p.path, "^/", "") + if p.argumet == "" then p.argument = nil end + p.check = 250 + return ftp.command(p) end -local capture = function(cmd) - local f = io.popen(cmd) - if not f then return nil end - local s = f:read("*a") - f:close() - return s -end - -local check = function(v, e, o) - e = e or "failed!" - o = o or "ok" - if v then print(o) - else print(e) os.exit() end -end - --- needs an account luasocket:password --- and some directories and files in ~ftp - -local index, err, saved, back, expected - -local t = socket.time() - -index = readfile("test/index.html") +-- read index with CRLF convention +index = readfile(index_file) io.write("testing wrong scheme: ") -back, err = socket.ftp.get("wrong://banana.com/lixo") -check(not back and err == "unknown scheme 'wrong'", err) +back, err = ftp.get("wrong://banana.com/lixo") +assert(not back and err == "wrong scheme 'wrong'", err) +print("ok") io.write("testing invalid url: ") -back, err = socket.ftp.get("localhost/dir1/index.html;type=i") -local c, e = socket.connect("", 21) -check(not back and err == e, err) - -io.write("testing anonymous file upload: ") -os.remove("/var/ftp/pub/index.up.html") -ret, err = socket.ftp.put("ftp://localhost/pub/index.up.html;type=i", index) -saved = readfile("/var/ftp/pub/index.up.html") -check(ret and not err and saved == index, err) +back, err = ftp.get("localhost/dir1/index.html;type=i") +assert(not back and err) +print("ok") io.write("testing anonymous file download: ") -back, err = socket.ftp.get("ftp://localhost/pub/index.up.html;type=i") -check(not err and back == index, err) +back, err = socket.ftp.get("ftp://" .. host .. "/pub/index.html;type=i") +assert(not err and back == index, err) +print("ok") -io.write("testing no directory changes: ") -back, err = socket.ftp.get("ftp://localhost/index.html;type=i") -check(not err and back == index, err) +io.write("erasing before upload: ") +ret, err = dele("ftp://luasocket:password@" .. host .. "/index.up.html") +if not ret then print(err) +else print("ok") end -io.write("testing multiple directory changes: ") -back, err = socket.ftp.get("ftp://localhost/pub/dir1/dir2/dir3/dir4/dir5/index.html;type=i") -check(not err and back == index, err) +io.write("testing upload: ") +ret, err = ftp.put("ftp://luasocket:password@" .. host .. "/index.up.html;type=i", index) +assert(ret and not err, err) +print("ok") -io.write("testing authenticated upload: ") -os.remove("/home/luasocket/index.up.html") -ret, err = socket.ftp.put("ftp://luasocket:password@localhost/index.up.html;type=i", index) -saved = readfile("/home/luasocket/index.up.html") -check(ret and not err and saved == index, err) +io.write("downloading uploaded file: ") +back, err = ftp.get("ftp://luasocket:password@" .. host .. "/index.up.html;type=i") +assert(ret and not err and index == back, err) +print("ok") -io.write("testing authenticated download: ") -back, err = socket.ftp.get("ftp://luasocket:password@localhost/index.up.html;type=i") -check(not err and back == index, err) +io.write("erasing after upload/download: ") +ret, err = dele("ftp://luasocket:password@" .. host .. "/index.up.html") +assert(ret and not err, err) +print("ok") io.write("testing weird-character translation: ") -back, err = socket.ftp.get("ftp://luasocket:password@localhost/%2fvar/ftp/pub/index.html;type=i") -check(not err and back == index, err) +back, err = ftp.get("ftp://luasocket:password@" .. host .. "/%23%3f;type=i") +assert(not err and back == index, err) +print("ok") io.write("testing parameter overriding: ") -back, err = socket.ftp.get { - url = "//stupid:mistake@localhost/index.html", +local back = {} +ret, err = ftp.get{ + url = "//stupid:mistake@" .. host .. "/index.html", user = "luasocket", password = "password", - type = "i" + type = "i", + sink = ltn12.sink.table(back) } -check(not err and back == index, err) +assert(ret and not err and table.concat(back) == index, err) +print("ok") io.write("testing upload denial: ") -ret, err = socket.ftp.put("ftp://localhost/index.up.html;type=a", index) -check(err, err) +ret, err = ftp.put("ftp://" .. host .. "/index.up.html;type=a", index) +assert(not ret and err, "should have failed") +print(err) io.write("testing authentication failure: ") -ret, err = socket.ftp.put("ftp://luasocket:wrong@localhost/index.html;type=a", index) +ret, err = ftp.get("ftp://luasocket:wrong@".. host .. "/index.html;type=a") +assert(not ret and err, "should have failed") print(err) -check(not ret and err, err) io.write("testing wrong file: ") -back, err = socket.ftp.get("ftp://localhost/index.wrong.html;type=a") -check(err, err) +back, err = ftp.get("ftp://".. host .. "/index.wrong.html;type=a") +assert(not back and err, "should have failed") +print(err) print("passed all tests") -print(string.format("done in %.2fs", socket.time() - t)) +print(string.format("done in %.2fs", socket.gettime() - t)) diff --git a/test/httptest.lua b/test/httptest.lua index 71021a4..cade837 100644 --- a/test/httptest.lua +++ b/test/httptest.lua @@ -23,7 +23,7 @@ http.TIMEOUT = 10 local t = socket.gettime() host = host or "diego.student.princeton.edu" -proxy = proxy or "http://localhost:3128" +proxy = proxy or "http://dell-diego.cs.princeton.edu:3128" prefix = prefix or "/luasocket-test" cgiprefix = cgiprefix or "/luasocket-test-cgi" index_file = "test/index.html" diff --git a/test/testclnt.lua b/test/testclnt.lua index f838533..c2c782c 100644 --- a/test/testclnt.lua +++ b/test/testclnt.lua @@ -1,7 +1,7 @@ local socket = require"socket" host = host or "localhost" -port = port or "8080" +port = port or "8383" function pass(...) local s = string.format(unpack(arg)) @@ -590,7 +590,6 @@ test_mixed(200) test_mixed(17) test_mixed(1) - test("binary line") test_rawline(1) test_rawline(17) @@ -630,14 +629,12 @@ test_nonblocking(200) test_nonblocking(17) test_nonblocking(1) - test("total timeout on send") test_totaltimeoutsend(800091, 1, 3) test_totaltimeoutsend(800091, 2, 3) test_totaltimeoutsend(800091, 5, 2) test_totaltimeoutsend(800091, 3, 1) - test("total timeout on receive") test_totaltimeoutreceive(800091, 1, 3) test_totaltimeoutreceive(800091, 2, 3) diff --git a/test/testmesg.lua b/test/testmesg.lua index e29b3cb..5350921 100644 --- a/test/testmesg.lua +++ b/test/testmesg.lua @@ -1,5 +1,5 @@ -- load the smtp support and its friends -local smtp = require("smtp") +local smtp = require("socket.smtp") local mime = require("mime") local ltn12 = require("ltn12") @@ -48,6 +48,11 @@ source = smtp.message{ } } +--[[ +sink = ltn12.sink.file(io.stdout) +ltn12.pump.all(source, sink) +]] + -- finally send it r, e = smtp.send{ rcpt = {"", diff --git a/test/testsrvr.lua b/test/testsrvr.lua index 23e3850..2408e83 100644 --- a/test/testsrvr.lua +++ b/test/testsrvr.lua @@ -1,6 +1,6 @@ socket = require("socket"); host = host or "localhost"; -port = port or "8080"; +port = port or "8383"; server = assert(socket.bind(host, port)); ack = "\n"; while 1 do diff --git a/test/tftptest.lua b/test/tftptest.lua index edb6484..35078e8 100644 --- a/test/tftptest.lua +++ b/test/tftptest.lua @@ -12,7 +12,7 @@ function readfile(file) return a end -host = host or "localhost" +host = host or "diego.student.princeton.edu" retrieved, err = tftp.get("tftp://" .. host .."/index.html") assert(not err, err) original = readfile("test/index.html")