Review notes for this revision (text before the first "diff --git" line is not part of the patch):
 * RegexFilterReaderTest.readSize: the final partial read returned buffer.length - remaining
   instead of remaining; it now returns min(buffer.length, remaining).
 * testMultiStreamExtendedMark: the second length assertion re-checked moreData; it now checks moreData2.
 * Test byte/String conversions use UTF8cs throughout instead of mixing it with the platform default.
 * Javadoc and comments were added to the new code; the "+" side of the hunk headers was adjusted to match.
 * NOTE(review): this patch was recovered from a whitespace-collapsed copy. Indentation and the
   position of blank context lines are reconstructed, and the index hashes predate these edits.
   Regenerate with git diff before applying.

diff --git a/src/main/java/org/cobbzilla/util/io/multi/MultiStream.java b/src/main/java/org/cobbzilla/util/io/multi/MultiStream.java
index 74ea54b..a22845b 100644
--- a/src/main/java/org/cobbzilla/util/io/multi/MultiStream.java
+++ b/src/main/java/org/cobbzilla/util/io/multi/MultiStream.java
@@ -15,6 +15,10 @@ public class MultiStream extends InputStream {
     private InputStream currentStream;
     private int streamIndex = 0;
     private boolean endOfStreams = false;
+
+    private Integer markedStreamIndex = null;
+    private int markReadLimit = 0;
+
     @Getter private final MultiUnderflowHandler underflow = new MultiUnderflowHandler();
 
     public MultiStream (InputStream r, boolean last) { this(r, last, "no-name"); }
@@ -34,6 +38,38 @@ public class MultiStream extends InputStream {
 
     public MultiStream (InputStream r) { this(r, false); }
 
+    /** Reports whether the underlying stream currently being read supports mark/reset. */
+    @Override public boolean markSupported() { return currentStream.markSupported(); }
+
+    /**
+     * Marks the current position of the stream being read and remembers that stream's index,
+     * so that {@link #reset()} can rewind across streams that are read after this call.
+     * @param readlimit passed to the current stream now, and to each later stream when it becomes current
+     */
+    @Override public synchronized void mark(int readlimit) {
+        this.markReadLimit = readlimit;
+        currentStream.mark(readlimit);
+        markedStreamIndex = streamIndex;
+    }
+
+    /**
+     * Rewinds to the last {@link #mark(int)}: resets every stream from the current one back to
+     * the marked one, then continues reading from the marked stream.
+     * The mark is cleared, so {@link #mark(int)} must be called again before another reset.
+     * @throws IOException if no mark is set, or if an underlying stream fails to reset
+     */
+    @Override public synchronized void reset() throws IOException {
+        if (markedStreamIndex == null) throw new IOException("cannot reset stream that was never marked");
+        int marked = streamIndex;
+        while (marked >= markedStreamIndex) {
+            streams.get(marked).reset();
+            marked--;
+        }
+        streamIndex = markedStreamIndex;
+        currentStream = streams.get(streamIndex);
+        markedStreamIndex = null;
+    }
+
     public int pendingStreamCount () { return streams.size() - streamIndex; }
 
     public MultiStream setUnderflowTimeout(long timeout) { getUnderflow().setUnderflowTimeout(timeout); return this; }
@@ -68,9 +104,14 @@ public class MultiStream extends InputStream {
                 underflow.handleUnderflow();
                 return 0;
             }
-            currentStream.close();
+            // while a mark is set, leave the finished stream open: reset() has to rewind it
+            if (markedStreamIndex == null) {
+                currentStream.close();
+            }
             streamIndex++;
             currentStream = streams.get(streamIndex);
+            // mark each newly entered stream so reset() can rewind it too
+            if (markedStreamIndex != null) currentStream.mark(markReadLimit);
             if (log.isTraceEnabled()) log.trace(logPrefix()+"read(byte): end of all stream, advanced to next stream ("+currentStream.getClass().getSimpleName()+")");
             return read();
 
@@ -92,9 +133,14 @@ public class MultiStream extends InputStream {
                 underflow.handleUnderflow();
                 return 0;
             }
-            currentStream.close();
+            // while a mark is set, leave the finished stream open: reset() has to rewind it
+            if (markedStreamIndex == null) {
+                currentStream.close();
+            }
             streamIndex++;
             currentStream = streams.get(streamIndex);
+            // mark each newly entered stream so reset() can rewind it too
+            if (markedStreamIndex != null) currentStream.mark(markReadLimit);
             if (log.isTraceEnabled()) log.trace(logPrefix()+"read(byte[]): end of all stream, advanced to next stream ("+currentStream.getClass().getSimpleName()+")");
             return read(buf, off, len);
 
diff --git a/src/main/java/org/cobbzilla/util/io/regex/RegexFilterReader.java b/src/main/java/org/cobbzilla/util/io/regex/RegexFilterReader.java
index 39a64ac..2f5dbbb 100644
--- a/src/main/java/org/cobbzilla/util/io/regex/RegexFilterReader.java
+++ b/src/main/java/org/cobbzilla/util/io/regex/RegexFilterReader.java
@@ -8,6 +8,7 @@ import org.apache.commons.lang3.ArrayUtils;
 import org.cobbzilla.util.system.Bytes;
 
 import java.io.*;
+import java.nio.charset.Charset;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.apache.commons.lang3.ArrayUtils.addAll;
@@ -36,7 +37,12 @@ public class RegexFilterReader extends BufferedReader {
     public RegexFilterReader(InputStream in, RegexStreamFilter filter) { this(in, DEFAULT_BUFFER_SIZE, filter); }
 
     public RegexFilterReader(InputStream in, int bufsiz, RegexStreamFilter filter) {
-        super(new InputStreamReader(in, UTF8cs), bufsiz);
+        this(in, UTF8cs, bufsiz, filter);
+    }
+
+    /** Same as the (in, bufsiz, filter) constructor, but decodes {@code in} with the given charset instead of UTF8cs. */
+    public RegexFilterReader(InputStream in, Charset charset, int bufsiz, RegexStreamFilter filter) {
+        super(new InputStreamReader(in, charset), bufsiz);
         this.bufsiz = bufsiz;
         this.filter = filter;
     }
diff --git a/src/test/java/org/cobbzilla/util/io/regex/RegexFilterReaderTest.java b/src/test/java/org/cobbzilla/util/io/regex/RegexFilterReaderTest.java
index 9d84a07..8fc8437 100644
--- a/src/test/java/org/cobbzilla/util/io/regex/RegexFilterReaderTest.java
+++ b/src/test/java/org/cobbzilla/util/io/regex/RegexFilterReaderTest.java
@@ -5,6 +5,7 @@ import org.apache.commons.io.IOUtils;
 import org.cobbzilla.util.io.BlockedInputStream;
 import org.cobbzilla.util.io.multi.MultiReader;
 import org.cobbzilla.util.io.multi.MultiStream;
+import org.cobbzilla.util.system.Bytes;
 import org.junit.Test;
 
 import java.io.*;
@@ -14,6 +15,7 @@ import static org.cobbzilla.util.daemon.ZillaRuntime.background;
 import static org.cobbzilla.util.daemon.ZillaRuntime.die;
 import static org.cobbzilla.util.io.multi.MultiUnderflowHandlerMonitor.DEFAULT_UNDERFLOW_MONITOR;
 import static org.cobbzilla.util.io.regex.RegexReplacementFilter.DEFAULT_PREFIX_REPLACEMENT_WITH_MATCH;
+import static org.cobbzilla.util.string.StringUtil.UTF8cs;
 import static org.cobbzilla.util.system.Sleep.sleep;
 import static org.junit.Assert.*;
 
@@ -223,4 +225,146 @@ public class RegexFilterReaderTest {
         assertTrue("expected multi stream failed to get data2 output", result.toString().contains(" bogus data2 "));
     }
 
+    @Test public void testSimpleMultiStreamMark() throws Exception {
+        final String data1 = "dat1\n".repeat(1024);
+        final InputStream stream1 = new ByteArrayInputStream(data1.getBytes(UTF8cs));
+        final MultiStream multiStream = new MultiStream(stream1, true);
+
+        multiStream.mark(data1.length());
+        final byte[] buffer = new byte[(int) (Bytes.KB)];
+        final String initialData = readStream(multiStream, buffer, buffer.length);
+        assertTrue(data1.startsWith(initialData));
+
+        multiStream.reset();
+        final ByteArrayOutputStream out = new ByteArrayOutputStream(data1.length());
+        IOUtils.copyLarge(multiStream, out);
+        assertEquals("expected output == data1", data1, out.toString(UTF8cs));
+    }
+
+    @Test public void testMultiStreamMark() throws Exception {
+        final String data1 = "dt1\n".repeat(1024);
+        final String data2 = "dt2\n".repeat(1024);
+        final String allData = data1 + data2;
+        final InputStream stream1 = new ByteArrayInputStream(data1.getBytes(UTF8cs));
+        final InputStream stream2 = new ByteArrayInputStream(data2.getBytes(UTF8cs));
+
+        final MultiStream multiStream = new MultiStream(stream1);
+        multiStream.addLastStream(stream2);
+        final byte[] buffer = new byte[(int) (2 * Bytes.KB)];
+
+        // read 5K of data
+        final int initialReadSize = (int) (5 * Bytes.KB);
+        final String initialData = readStream(multiStream, buffer, initialReadSize);
+        assertEquals(initialReadSize, initialData.length());
+        assertTrue("expected initial read to start with dt1", initialData.startsWith(data1));
+        assertTrue("expected initial read to contain some of dt2", initialData.contains("dt2\n"));
+
+        // then mark
+        multiStream.mark(allData.length());
+
+        // then read some more
+        final String moreData = readStream(multiStream, buffer, buffer.length);
+        assertEquals(buffer.length, moreData.length());
+
+        // verify what we read was the remainder of data2
+        assertTrue("expected remainder read to contain dt2", moreData.contains("dt2\n"));
+        assertFalse("expected remainder read to NOT contain dt1", moreData.contains("dt1"));
+
+        // reset the stream
+        multiStream.reset();
+
+        // now read the remainder
+        final ByteArrayOutputStream out = new ByteArrayOutputStream(allData.length());
+        IOUtils.copyLarge(multiStream, out);
+        final String remainderData = out.toString(UTF8cs);
+
+        assertEquals( "expected initial + remainder == all", allData, initialData + remainderData);
+    }
+
+    @Test public void testMultiStreamExtendedMark() throws Exception {
+        final String data1 = "dt1\n".repeat(1024);
+        final String data2 = "dt2\n".repeat(1024);
+        final String data3 = "dt3\n".repeat(1024);
+        final String data4 = "dt4\n".repeat(1024);
+        final String allData = data1 + data2 + data3 + data4;
+        final InputStream stream1 = new ByteArrayInputStream(data1.getBytes(UTF8cs));
+        final InputStream stream2 = new ByteArrayInputStream(data2.getBytes(UTF8cs));
+        final InputStream stream3 = new ByteArrayInputStream(data3.getBytes(UTF8cs));
+        final InputStream stream4 = new ByteArrayInputStream(data4.getBytes(UTF8cs));
+
+        final MultiStream multiStream = new MultiStream(stream1);
+        multiStream.addStream(stream2);
+        multiStream.addStream(stream3);
+        multiStream.addLastStream(stream4);
+        final byte[] buffer = new byte[(int) (2 * Bytes.KB)];
+
+        // read 5K of data
+        final int initialReadSize = (int) (5 * Bytes.KB);
+        final String initialData = readStream(multiStream, buffer, initialReadSize);
+        assertEquals(initialReadSize, initialData.length());
+        assertTrue("expected initial read to start with dt1", initialData.startsWith(data1));
+        assertTrue("expected initial read to contain some of dt2", initialData.contains("dt2\n"));
+
+        // then mark
+        multiStream.mark(allData.length());
+
+        // then read some more
+        final String moreData = readStream(multiStream, buffer, data2.length());
+        assertEquals(data2.length(), moreData.length());
+
+        // verify what we read was the remainder of data2
+        assertTrue("expected remainder read to contain dt2", moreData.contains("dt2\n"));
+        assertTrue("expected remainder read to contain dt3", moreData.contains("dt3\n"));
+        assertFalse("expected remainder read to NOT contain dt4", moreData.contains("dt4"));
+        assertFalse("expected remainder read to NOT contain dt1", moreData.contains("dt1"));
+
+        // reset the stream and re-mark
+        multiStream.reset();
+        multiStream.mark(allData.length());
+
+        // do the same read again, should get the same data
+        final String moreData2 = readStream(multiStream, buffer, data2.length());
+        assertEquals(data2.length(), moreData2.length());
+
+        // verify what we read was the remainder of data2
+        assertTrue("expected remainder read to contain dt2", moreData2.contains("dt2\n"));
+        assertTrue("expected remainder read to contain dt3", moreData2.contains("dt3\n"));
+        assertFalse("expected remainder read to NOT contain dt4", moreData2.contains("dt4"));
+        assertFalse("expected remainder read to NOT contain dt1", moreData2.contains("dt1"));
+
+        // verify reset read was the same
+        assertEquals("expected to read the same data in moreData and moreData2", moreData, moreData2);
+
+        // reset the stream again
+        multiStream.reset();
+
+        // now read the all the remainder
+        final ByteArrayOutputStream out = new ByteArrayOutputStream(allData.length());
+        IOUtils.copyLarge(multiStream, out);
+        final String remainderData = out.toString(UTF8cs);
+
+        assertEquals( "expected initial + remainder == all", allData, initialData + remainderData);
+    }
+
+    /**
+     * Reads from the stream until {@code size} bytes have been read or the stream ends.
+     * @return the bytes read, decoded as UTF8cs text
+     */
+    public String readStream(MultiStream multiStream, byte[] buffer, int size) throws IOException {
+        final StringBuilder b = new StringBuilder(size);
+        int bytesRead = 0;
+        int count;
+        while ((bytesRead < size
+                && ((count = multiStream.read(buffer, 0, readSize(buffer, size, bytesRead))) != -1))) {
+            bytesRead += count;
+            b.append(new String(buffer, 0, count, UTF8cs));
+        }
+        return b.toString();
+    }
+
+    /** @return number of bytes to request next: a full buffer, or only what is still missing to reach initialReadSize */
+    private int readSize(byte[] buffer, int initialReadSize, int bytesRead) {
+        return Math.min(buffer.length, initialReadSize - bytesRead);
+    }
+
 }