luasocket/etc/check-links-nb.lua

-----------------------------------------------------------------------------
-- Little program that checks links in HTML files, using coroutines and
-- non-blocking I/O. Thus, faster than the simpler version of the same program
-- LuaSocket sample files
-- Author: Diego Nehab
-- RCS ID: $$
-----------------------------------------------------------------------------
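-- Run from the command line with one or more URLs (or local file names) as
-- arguments; each document is fetched and every link found in it is checked.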
local socket = require("socket")
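-- how long (in seconds) a thread may wait on a socket before the dispatcher
-- tells it to give up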
TIMEOUT = 10
-- we need to yield across calls to protect, so we can't use pcall
-- we borrow and simplify code from coxpcall to reimplement socket.protect
-- before loading http
function socket.protect(f)
    return function(...)
        local co = coroutine.create(f)
        while true do
            local results = {coroutine.resume(co, unpack(arg))}
            local status = results[1]
            table.remove(results, 1)
            if not status then
                return nil, results[1][1]
            end
            if coroutine.status(co) == "suspended" then
                arg = {coroutine.yield(unpack(results))}
            else
                return unpack(results)
            end
        end
    end
end
local http = require("socket.http")
local url = require("socket.url")
-- creates a new set data structure
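-- elements live in the array part (so the set can be handed directly to
-- socket.select) and a reverse index makes insertion and removal O(1);
-- removal swaps the removed element with the last one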
function newset()
    local reverse = {}
    local set = {}
    return setmetatable(set, {__index = {
        insert = function(set, value)
            if not reverse[value] then
                table.insert(set, value)
                reverse[value] = table.getn(set)
            end
        end,
        remove = function(set, value)
            local index = reverse[value]
            if index then
                reverse[value] = nil
                local top = table.remove(set)
                if top ~= value then
                    reverse[top] = index
                    set[index] = top
                end
            end
        end
    }})
end
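-- shared state used by the dispatcher:
-- context[tcp] maps each socket to its owning thread, its wrapper table and
-- the time it started waiting; sending/receiving hold sockets waiting to
-- write/read; nthreads counts live checker threads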
local context = {}
local sending = newset()
local receiving = newset()
local nthreads = 0
-- socket.tcp() replacement for non-blocking I/O
-- implements enough functionality to be used with http.request
-- in Lua 5.1, we have coroutine.running to simplify things...
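-- the send, receive and connect methods below register the socket in the
-- sending or receiving set and yield back to the dispatcher instead of
-- blocking; the dispatcher resumes the thread when the socket is ready,
-- or with "timeout" when it gives up on it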
function newcreate(thread)
    return function()
        -- try to create underlying socket
        local tcp, error = socket.tcp()
        if not tcp then return nil, error end
        -- put it in non-blocking mode right away
        tcp:settimeout(0)
        local trap = {
            -- we ignore settimeout to preserve our 0 timeout
            settimeout = function(self, mode, value)
                return 1
            end,
            -- send in non-blocking mode and yield on timeout
            send = function(self, data, first, last)
                first = (first or 1) - 1
                local result, error
                while true do
                    result, error, first = tcp:send(data, first+1, last)
                    if error == "timeout" then
                        -- tell dispatcher we want to keep sending
                        sending:insert(tcp)
                        -- mark time we started waiting
                        context[tcp].last = socket.gettime()
                        -- return control to dispatcher
                        if coroutine.yield() == "timeout" then
                            return nil, "timeout"
                        end
                    else return result, error, first end
                end
            end,
            -- receive in non-blocking mode and yield on timeout
            receive = function(self, pattern)
                local error, partial = "timeout", ""
                local value
                while true do
                    value, error, partial = tcp:receive(pattern, partial)
                    if error == "timeout" then
                        -- tell dispatcher we want to keep receiving
                        receiving:insert(tcp)
                        -- mark time we started waiting
                        context[tcp].last = socket.gettime()
                        -- return control to dispatcher
                        if coroutine.yield() == "timeout" then
                            return nil, "timeout"
                        end
                    else return value, error, partial end
                end
            end,
            -- connect in non-blocking mode and yield on timeout
            connect = function(self, host, port)
                local result, error = tcp:connect(host, port)
                if error == "timeout" then
                    -- tell dispatcher we will be able to write upon connection
                    sending:insert(tcp)
                    -- mark time we started waiting
                    context[tcp].last = socket.gettime()
                    -- return control to dispatcher
                    if coroutine.yield() == "timeout" then
                        return nil, "timeout"
                    end
                    -- when we come back, check if connection was successful
                    result, error = tcp:connect(host, port)
                    if result or error == "already connected" then return 1
                    else return nil, "non-blocking connect failed" end
                else return result, error end
            end,
            close = function(self)
                context[tcp] = nil
                return tcp:close()
            end
        }
        -- add newly created socket to context
        context[tcp] = {
            thread = thread,
            trap = trap
        }
        return trap
    end
end
-- get the status of a URL, non-blocking
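-- each http: link is checked by its own coroutine, which issues a HEAD
-- request through the non-blocking socket wrapper created by newcreate();
-- links with other schemes are ignored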
function getstatus(from, link)
    local parsed = url.parse(link, {scheme = "file"})
    if parsed.scheme == "http" then
        local thread = coroutine.create(function(thread, from, link)
            local r, c, h, s = http.request{
                method = "HEAD",
                url = link,
                create = newcreate(thread)
            }
            if c == 200 then io.write('\t', link, '\n')
            else io.write('\t', link, ': ', c, '\n') end
            nthreads = nthreads - 1
        end)
        nthreads = nthreads + 1
        assert(coroutine.resume(thread, thread, from, link))
    end
end
-- dispatch all threads until we are done
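-- uses socket.select() to find sockets that became readable or writable,
-- resumes the thread that owns each one, and wakes threads that have been
-- waiting longer than TIMEOUT so their pending I/O call returns an error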
function dispatch()
    while nthreads > 0 do
        -- check which sockets are interesting and act on them
        local readable, writable = socket.select(receiving, sending, 1)
        -- for all readable connections, resume their threads
        for _, who in ipairs(readable) do
            if context[who] then
                receiving:remove(who)
                assert(coroutine.resume(context[who].thread))
            end
        end
        -- for all writable connections, do the same
        for _, who in ipairs(writable) do
            if context[who] then
                sending:remove(who)
                assert(coroutine.resume(context[who].thread))
            end
        end
        -- politely ask replacement I/O functions in idle threads to
        -- return reporting a timeout
        local now = socket.gettime()
        for who, data in pairs(context) do
            if data.last and now - data.last > TIMEOUT then
                assert(coroutine.resume(context[who].thread, "timeout"))
            end
        end
    end
end
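-- reads a local file into a string (the path is unescaped first)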
function readfile(path)
    path = url.unescape(path)
    local file, error = io.open(path, "r")
    if file then
        local body = file:read("*a")
        file:close()
        return body
    else return nil, error end
end
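-- retrieves a document given its URL: over HTTP for http: URLs, from disk
-- for file: URLs; returns the base URL, the document body and an error message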
function retrieve(u)
    local parsed = url.parse(u, { scheme = "file" })
    local body, headers, code, error
    local base = u
    if parsed.scheme == "http" then
        body, code, headers = http.request(u)
        if code == 200 then
            base = base or headers.location
        end
        if not body then
            error = code
        end
    elseif parsed.scheme == "file" then
        body, error = readfile(parsed.path)
    else error = string.format("unhandled scheme '%s'", parsed.scheme) end
    return base, body, error
end
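-- extracts the targets of all href attributes in an HTML body, resolving
-- each one against the base URL; the three patterns below handle
-- double-quoted, single-quoted and unquoted attribute values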
function getlinks(body, base)
    -- get rid of comments
    body = string.gsub(body, "%<%!%-%-.-%-%-%>", "")
    local links = {}
    -- extract links
    body = string.gsub(body, '[Hh][Rr][Ee][Ff]%s*=%s*"([^"]*)"', function(href)
        table.insert(links, url.absolute(base, href))
    end)
    body = string.gsub(body, "[Hh][Rr][Ee][Ff]%s*=%s*'([^']*)'", function(href)
        table.insert(links, url.absolute(base, href))
    end)
    string.gsub(body, "[Hh][Rr][Ee][Ff]%s*=%s*(.-)>", function(href)
        table.insert(links, url.absolute(base, href))
    end)
    return links
end
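-- retrieves a page and schedules a non-blocking status check for every link in it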
function checklinks(from)
    local base, body, error = retrieve(from)
    if not body then print(error) return end
    local links = getlinks(body, base)
    for _, link in ipairs(links) do
        getstatus(from, link)
    end
end
arg = arg or {}
if table.getn(arg) < 1 then
    print("Usage:\n luasocket check-links.lua {<url>}")
    os.exit()
end
for _, a in ipairs(arg) do
    print("Checking ", a)
    checklinks(url.absolute("file:", a))
end
dispatch()