support multiple filters in ltn12.{sink,source}.chain()

This commit is contained in:
Fabien Fleutot 2013-06-18 11:01:46 +02:00
parent 22cd5833fc
commit 480a818bf0
2 changed files with 35 additions and 3 deletions

View File

@ -139,7 +139,9 @@ function source.rewind(src)
end
end
function source.chain(src, f)
-- chains a source with one or several filter(s)
function source.chain(src, f, ...)
if ... then f=filter.chain(f, ...) end
base.assert(src and f)
local last_in, last_out = "", ""
local state = "feeding"
@ -254,8 +256,13 @@ function sink.error(err)
end
end
-- chains a sink with a filter
function sink.chain(f, snk)
-- chains a sink with one or several filter(s)
function sink.chain(f, snk, ...)
if ... then
local args = { f, snk, ... }
snk = table.remove(args, #args)
f = filter.chain(unpack(args))
end
base.assert(f and snk)
return function(chunk, err)
if chunk ~= "" then

View File

@ -191,6 +191,21 @@ assert(table.concat(t) == s, "mismatch")
assert(filter(nil, 1), "filter not empty")
print("ok")
--------------------------------
io.write("testing source.chain (with several filters): ")
local function double(x) -- filter turning "ABC" into "AABBCC"
if not x then return end
local b={}
for k in x:gmatch'.' do table.insert(b, k..k) end
return table.concat(b)
end
source = ltn12.source.string(s)
source = ltn12.source.chain(source, double, double, double)
sink, t = ltn12.sink.table()
assert(ltn12.pump.all(source, sink), "returned error")
assert(table.concat(t) == double(double(double(s))), "mismatch")
print("ok")
--------------------------------
io.write("testing source.chain (with split) and sink.chain (with merge): ")
source = ltn12.source.string(s)
@ -205,6 +220,15 @@ assert(filter(nil, 1), "filter not empty")
assert(filter2(nil, 1), "filter2 not empty")
print("ok")
--------------------------------
io.write("testing sink.chain (with several filters): ")
source = ltn12.source.string(s)
sink, t = ltn12.sink.table()
sink = ltn12.sink.chain(double, double, double, sink)
assert(ltn12.pump.all(source, sink), "returned error")
assert(table.concat(t) == double(double(double(s))), "mismatch")
print("ok")
--------------------------------
io.write("testing filter.chain (and sink.chain, with split, merge): ")
source = ltn12.source.string(s)
@ -272,3 +296,4 @@ assert(filter3(nil, 1), "filter3 not empty")
assert(filter4(nil, 1), "filter4 not empty")
assert(filter5(nil, 1), "filter5 not empty")
print("ok")