From 480a818bf0ef6de32527ba14fc2bb27e754d0612 Mon Sep 17 00:00:00 2001 From: Fabien Fleutot Date: Tue, 18 Jun 2013 11:01:46 +0200 Subject: [PATCH] support multiple filters in ltn12.{sink,source}.chain() --- src/ltn12.lua | 13 ++++++++++--- test/ltn12test.lua | 25 +++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/ltn12.lua b/src/ltn12.lua index 5b10f56..1014de2 100644 --- a/src/ltn12.lua +++ b/src/ltn12.lua @@ -139,7 +139,9 @@ function source.rewind(src) end end -function source.chain(src, f) +-- chains a source with one or several filter(s) +function source.chain(src, f, ...) + if ... then f=filter.chain(f, ...) end base.assert(src and f) local last_in, last_out = "", "" local state = "feeding" @@ -254,8 +256,13 @@ function sink.error(err) end end --- chains a sink with a filter -function sink.chain(f, snk) +-- chains a sink with one or several filter(s) +function sink.chain(f, snk, ...) + if ... then + local args = { f, snk, ... } + snk = table.remove(args, #args) + f = filter.chain(unpack(args)) + end base.assert(f and snk) return function(chunk, err) if chunk ~= "" then diff --git a/test/ltn12test.lua b/test/ltn12test.lua index 74a45e8..e3f85fb 100644 --- a/test/ltn12test.lua +++ b/test/ltn12test.lua @@ -191,6 +191,21 @@ assert(table.concat(t) == s, "mismatch") assert(filter(nil, 1), "filter not empty") print("ok") +-------------------------------- +io.write("testing source.chain (with several filters): ") +local function double(x) -- filter turning "ABC" into "AABBCC" + if not x then return end + local b={} + for k in x:gmatch'.' do table.insert(b, k..k) end + return table.concat(b) +end +source = ltn12.source.string(s) +source = ltn12.source.chain(source, double, double, double) +sink, t = ltn12.sink.table() +assert(ltn12.pump.all(source, sink), "returned error") +assert(table.concat(t) == double(double(double(s))), "mismatch") +print("ok") + -------------------------------- io.write("testing source.chain (with split) and sink.chain (with merge): ") source = ltn12.source.string(s) @@ -205,6 +220,15 @@ assert(filter(nil, 1), "filter not empty") assert(filter2(nil, 1), "filter2 not empty") print("ok") +-------------------------------- +io.write("testing sink.chain (with several filters): ") +source = ltn12.source.string(s) +sink, t = ltn12.sink.table() +sink = ltn12.sink.chain(double, double, double, sink) +assert(ltn12.pump.all(source, sink), "returned error") +assert(table.concat(t) == double(double(double(s))), "mismatch") +print("ok") + -------------------------------- io.write("testing filter.chain (and sink.chain, with split, merge): ") source = ltn12.source.string(s) @@ -272,3 +296,4 @@ assert(filter3(nil, 1), "filter3 not empty") assert(filter4(nil, 1), "filter4 not empty") assert(filter5(nil, 1), "filter5 not empty") print("ok") +