From 3d74934b6f057212174ee64a7216266373c93626 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Wed, 26 Jun 2013 23:57:21 +0200 Subject: [PATCH] Add a length-prefixed protobuf connection server/client. This forms the base for payment channel network connections, with a simple single-threaded server which accepts new connections, a simple single-threaded client which makes a single connection to a server, and a ProtobufParser which handles data generated by a connection, splits them into Protobufs and provides a reasonable interface to users who wish to create/accept protobuf-based connections. --- .../niowrapper/MessageWriteTarget.java | 25 + .../protocols/niowrapper/ProtobufClient.java | 123 +++++ .../protocols/niowrapper/ProtobufParser.java | 240 ++++++++ .../niowrapper/ProtobufParserFactory.java | 32 ++ .../protocols/niowrapper/ProtobufServer.java | 217 ++++++++ .../protocols/niowrapper/NioWrapperTest.java | 511 ++++++++++++++++++ 6 files changed, 1148 insertions(+) create mode 100644 core/src/main/java/com/google/bitcoin/protocols/niowrapper/MessageWriteTarget.java create mode 100644 core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufClient.java create mode 100644 core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufParser.java create mode 100644 core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufParserFactory.java create mode 100644 core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufServer.java create mode 100644 core/src/test/java/com/google/bitcoin/protocols/niowrapper/NioWrapperTest.java diff --git a/core/src/main/java/com/google/bitcoin/protocols/niowrapper/MessageWriteTarget.java b/core/src/main/java/com/google/bitcoin/protocols/niowrapper/MessageWriteTarget.java new file mode 100644 index 00000000..b3bcba46 --- /dev/null +++ b/core/src/main/java/com/google/bitcoin/protocols/niowrapper/MessageWriteTarget.java @@ -0,0 +1,25 @@ +/* + * Copyright 2013 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.bitcoin.protocols.niowrapper; + +/** + * A target to which messages can be written/connection can be closed + */ +abstract class MessageWriteTarget { + abstract void writeBytes(byte[] message); + abstract void closeConnection(); +} diff --git a/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufClient.java b/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufClient.java new file mode 100644 index 00000000..d1c412b2 --- /dev/null +++ b/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufClient.java @@ -0,0 +1,123 @@ +/* + * Copyright 2013 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.bitcoin.protocols.niowrapper; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousCloseException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.SocketChannel; +import javax.annotation.Nonnull; + +import org.slf4j.LoggerFactory; + +import static com.google.common.base.Preconditions.checkState; + +/** + * Creates a simple connection to a server using a {@link ProtobufParser} to process data. + */ +public class ProtobufClient extends MessageWriteTarget { + private static final org.slf4j.Logger log = LoggerFactory.getLogger(ProtobufClient.class); + + private static final int BUFFER_SIZE_LOWER_BOUND = 4096; + private static final int BUFFER_SIZE_UPPER_BOUND = 65536; + + @Nonnull private final ByteBuffer dbuf; + @Nonnull private final SocketChannel sc; + + /** + *

Creates a new client to the given server address using the given {@link ProtobufParser} to decode the data. + * The given parser MUST be unique to this object. This does not block while waiting for the connection to + * open, but will call either the {@link ProtobufParser#connectionOpen()} or {@link ProtobufParser#connectionClosed()} + * callback on the created network event processing thread.

+ * + * @param connectTimeoutMillis The connect timeout set on the connection (in milliseconds). 0 is interpreted as no + * timeout. + */ + public ProtobufClient(final InetSocketAddress serverAddress, final ProtobufParser parser, + final int connectTimeoutMillis) throws IOException { + // Try to fit at least one message in the network buffer, but place an upper and lower limit on its size to make + // sure it doesnt get too large or have to call read too often. + dbuf = ByteBuffer.allocate(Math.min(Math.max(parser.maxMessageSize, BUFFER_SIZE_LOWER_BOUND), BUFFER_SIZE_UPPER_BOUND)); + parser.setWriteTarget(this); + sc = SocketChannel.open(); + + new Thread() { + @Override + public void run() { + try { + sc.socket().connect(serverAddress, connectTimeoutMillis); + parser.connectionOpen(); + + while (true) { + int read = sc.read(dbuf); + if (read == 0) + continue; + else if (read == -1) + return; + // "flip" the buffer - setting the limit to the current position and setting position to 0 + dbuf.flip(); + // Use parser.receive's return value as a double-check that it stopped reading at the right + // location + int bytesConsumed = parser.receive(dbuf); + checkState(dbuf.position() == bytesConsumed); + // Now drop the bytes which were read by compacting dbuf (resetting limit and keeping relative + // position) + dbuf.compact(); + } + } catch (AsynchronousCloseException e) {// Expected if the connection is closed + } catch (ClosedChannelException e) { // Expected if the connection is closed + } catch (Exception e) { + log.error("Error trying to open/read from connection", e); + } finally { + try { + sc.close(); + } catch (IOException e1) { + // At this point there isn't much we can do, and we can probably assume the channel is closed + } + parser.connectionClosed(); + } + } + }.start(); + } + + /** + * Closes the connection to the server, triggering the {@link ProtobufParser#connectionClosed()} + * event on the network-handling thread where all callbacks occur. + */ + public void closeConnection() { + // Closes the channel, triggering an exception in the network-handling thread triggering connectionClosed() + try { + sc.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + // Writes raw bytes to the channel (used by the write method in ProtobufParser) + @Override + synchronized void writeBytes(byte[] message) { + try { + if (sc.write(ByteBuffer.wrap(message)) != message.length) + throw new IOException("Couldn't write all of message to socket"); + } catch (IOException e) { + log.error("Error writing message to connection, closing connection", e); + closeConnection(); + } + } +} diff --git a/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufParser.java b/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufParser.java new file mode 100644 index 00000000..6f252232 --- /dev/null +++ b/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufParser.java @@ -0,0 +1,240 @@ +/* + * Copyright 2013 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.bitcoin.protocols.niowrapper; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Timer; +import java.util.TimerTask; + +import com.google.bitcoin.core.Utils; +import com.google.protobuf.ByteString; +import com.google.protobuf.MessageLite; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +/** + *

A handler which is used in {@link ProtobufServer} and {@link ProtobufClient} to split up incoming data streams + * into protobufs and provide an interface for writing protobufs to the connections.

+ * + *

Messages are encoded with a 4-byte signed integer (big endian) prefix to indicate their length followed by the + * serialized protobuf

+ */ +public class ProtobufParser { + /** + * An interface which can be implemented to handle callbacks as new messages are generated and socket events occur. + * @param The protobuf type which is used on this socket. + * This MUST match the MessageType used in the parent {@link ProtobufParser} + */ + public interface Listener { + /** Called when a new protobuf is received from the remote side. */ + public void messageReceived(ProtobufParser handler, MessageType msg); + /** Called when the connection is opened and available for writing data to. */ + public void connectionOpen(ProtobufParser handler); + /** Called when the connection is closed and no more data should be provided. */ + public void connectionClosed(ProtobufParser handler); + } + + // The callback listener + private final Listener handler; + // The prototype which is used to deserialize messages + private final MessageLite prototype; + + // The maximum message size (NOT INCLUDING LENGTH PREFIX) + final int maxMessageSize; + + // A temporary buffer used when the message size is larger than the buffer being used by the network code + // Because the networking code uses a constant size buffer and we want to allow for very large message sizes, we use + // a smaller network buffer per client and only allocate more memory when we need it to deserialize large messages. + // Though this is not in of itself a DoS protection, it allows for handling more legitimate clients per server and + // attacking clients can be made to timeout/get blocked if they are sending crap to fill buffers. + private int messageBytesOffset = 0; + private byte[] messageBytes; + + private MessageWriteTarget writeTarget; + + // TimerTask and timeout value which are added to a timer to kill the connection on timeout + private TimerTask timeoutTask; + private long timeoutMillis; + + // A timer which manages expiring channels as their timeouts occur + private static final Timer timeoutTimer = new Timer(); + + /** + * Creates a new protobuf handler. + * + * @param handler The callback listener + * @param prototype The default instance of the message type used in both directions of this channel. + * This should be the return value from {@link MessageType#getDefaultInstanceForType()} + * @param maxMessageSize The maximum message size (not including the 4-byte length prefix). + * Note that this has an upper bound of {@link Integer#MAX_VALUE} - 4 + * @param timeoutMillis The timeout between messages before the connection is automatically closed. Only enabled + * after the connection is established. + */ + public ProtobufParser(Listener handler, MessageType prototype, int maxMessageSize, int timeoutMillis) { + this.handler = handler; + this.prototype = prototype; + this.timeoutMillis = timeoutMillis; + this.maxMessageSize = Math.min(maxMessageSize, Integer.MAX_VALUE - 4); + } + + // Sets the upstream write channel + synchronized void setWriteTarget(MessageWriteTarget writeTarget) { + checkState(this.writeTarget == null); + this.writeTarget = checkNotNull(writeTarget); + } + + /** + * Closes this connection, eventually triggering a {@link ProtobufParser.Listener#connectionClosed()} event. + */ + public synchronized void closeConnection() { + this.writeTarget.closeConnection(); + } + + // Deserializes and provides a listener event (buff must not have the length prefix in it) + // Does set the buffers's position to its limit + private void deserializeMessage(ByteBuffer buff) throws Exception { + MessageType msg = (MessageType) prototype.newBuilderForType().mergeFrom(ByteString.copyFrom(buff)).build(); + resetTimeout(); + handler.messageReceived(this, msg); + } + + /** + * Called when new bytes are available from the remote end. + * * buff will start with its limit set to the position we can read to and its position set to the location we will + * start reading at + * * May read more than one message (recursively) if there are enough bytes available + * * Uses messageBytes/messageBytesOffset to store message which are larger (incl their length prefix) than buff's + * capacity(), ie it is up to this method to ensure we dont run out of buffer space to decode the next message. + * * buff will end with its limit the same as it was previously, and its position set to the position up to which + * bytes have been read (the same as its return value) + * @return The amount of bytes consumed which should not be provided again + */ + synchronized int receive(ByteBuffer buff) throws Exception { + if (messageBytes != null) { + // Just keep filling up the currently being worked on message + int bytesToGet = Math.min(messageBytes.length - messageBytesOffset, buff.remaining()); + buff.get(messageBytes, messageBytesOffset, bytesToGet); + messageBytesOffset += bytesToGet; + if (messageBytesOffset == messageBytes.length) { + // Filled up our buffer, decode the message + deserializeMessage(ByteBuffer.wrap(messageBytes)); + messageBytes = null; + if (buff.hasRemaining()) + return bytesToGet + receive(buff); + } + return bytesToGet; + } + + // If we cant read the length prefix yet, give up + if (buff.remaining() < 4) + return 0; + + // Read one integer in big endian + buff.order(ByteOrder.BIG_ENDIAN); + final int len = buff.getInt(); + + // If length is larger than the maximum message size (or is negative/overflows) throw an exception and close the + // connection + if (len > maxMessageSize || len + 4 < 4) + throw new IllegalStateException("Message too large or length underflowed"); + + // If the buffer's capacity is less than the next messages length + 4 (length prefix), we must use messageBytes + // as a temporary buffer to store the message + if (buff.capacity() < len + 4) { + messageBytes = new byte[len]; + // Now copy all remaining bytes into the new buffer, set messageBytesOffset and tell the caller how many + // bytes we consumed + int bytesToRead = buff.remaining(); + buff.get(messageBytes, 0, bytesToRead); + messageBytesOffset = bytesToRead; + return bytesToRead + 4; + } + + if (buff.remaining() < len) { + // Wait until the whole message is available in the buffer + buff.position(buff.position() - 4); // Make sure the buffer's position is right at the end + return 0; + } + + // Temporarily limit the buffer to the size of the message so that the protobuf decode doesn't get messed up + int limit = buff.limit(); + buff.limit(buff.position() + len); + deserializeMessage(buff); + checkState(buff.remaining() == 0); + buff.limit(limit); // Reset the limit in case we have to recurse + + // If there are still bytes remaining, see if we can pull out another message since we won't get called again + if (buff.hasRemaining()) + return len + 4 + receive(buff); + else + return len + 4; + } + + /** Called by the upstream connection manager if this connection closes */ + void connectionClosed() { + handler.connectionClosed(this); + } + + /** Called by the upstream connection manager when this connection is open */ + void connectionOpen() { + resetTimeout(); + handler.connectionOpen(this); + } + + /** + *

Writes the given message to the other side of the connection, prefixing it with the proper 4-byte prefix.

+ * + *

Provides a write-order guarantee.

+ * + * @throws IllegalStateException If the encoded message is larger than the maximum message size. + */ + public synchronized void write(MessageType msg) throws IllegalStateException { + byte[] messageBytes = msg.toByteArray(); + checkState(messageBytes.length <= maxMessageSize); + byte[] messageLength = new byte[4]; + Utils.uint32ToByteArrayBE(messageBytes.length, messageLength, 0); + writeTarget.writeBytes(messageLength); + writeTarget.writeBytes(messageBytes); + } + + /** + *

Sets the receive timeout to the given number of milliseconds, automatically killing the connection if no + * messages are received for this long

+ * + *

A timeout of 0 is interpreted as no timeout

+ */ + public synchronized void setSocketTimeout(int timeoutMillis) { + this.timeoutMillis = timeoutMillis; + resetTimeout(); + } + + private synchronized void resetTimeout() { + if (timeoutTask != null) + timeoutTask.cancel(); + if (timeoutMillis == 0) + return; + timeoutTask = new TimerTask() { + @Override + public void run() { + closeConnection(); + } + }; + timeoutTimer.schedule(timeoutTask, timeoutMillis); + } +} diff --git a/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufParserFactory.java b/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufParserFactory.java new file mode 100644 index 00000000..c2a115ef --- /dev/null +++ b/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufParserFactory.java @@ -0,0 +1,32 @@ +/* + * Copyright 2013 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.bitcoin.protocols.niowrapper; + +import java.net.InetAddress; +import javax.annotation.Nullable; + +/** + * A factory which generates new {@link ProtobufParser}s when a new connection is opened. + */ +public interface ProtobufParserFactory { + /** + * Returns a new handler or null to have the connection close. + * @param inetAddress The client's (IP) address + * @param port The remote port on the client side + */ + @Nullable public ProtobufParser getNewParser(InetAddress inetAddress, int port); +} diff --git a/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufServer.java b/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufServer.java new file mode 100644 index 00000000..400d2f11 --- /dev/null +++ b/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufServer.java @@ -0,0 +1,217 @@ +/* + * Copyright 2013 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.bitcoin.protocols.niowrapper; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.channels.spi.SelectorProvider; +import java.util.Iterator; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.LoggerFactory; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +/** + * Creates a simple server listener which listens for incoming client connections and uses a {@link ProtobufParser} to + * process data. + */ +public class ProtobufServer { + private static final org.slf4j.Logger log = LoggerFactory.getLogger(ProtobufServer.class); + + private final ProtobufParserFactory parserFactory; + + @VisibleForTesting final Thread handlerThread; + private final ServerSocketChannel sc; + + private static final int BUFFER_SIZE_LOWER_BOUND = 4096; + private static final int BUFFER_SIZE_UPPER_BOUND = 65536; + + private class ConnectionHandler extends MessageWriteTarget { + private final ByteBuffer dbuf; + private final SocketChannel channel; + private final ProtobufParser parser; + private boolean closeCalled = false; + + ConnectionHandler(SocketChannel channel) throws IOException { + this.channel = checkNotNull(channel); + ProtobufParser newParser = parserFactory.getNewParser(channel.socket().getInetAddress(), channel.socket().getPort()); + if (newParser == null) { + closeConnection(); + throw new IOException("Parser factory.getNewParser returned null"); + } + this.parser = newParser; + dbuf = ByteBuffer.allocate(Math.min(Math.max(newParser.maxMessageSize, BUFFER_SIZE_LOWER_BOUND), BUFFER_SIZE_UPPER_BOUND)); + newParser.setWriteTarget(this); + } + + @Override + synchronized void writeBytes(byte[] message) { + try { + if (channel.write(ByteBuffer.wrap(message)) != message.length) + throw new IOException("Couldn't write all of message to socket"); + } catch (IOException e) { + log.error("Error writing message to connection, closing connection", e); + closeConnection(); + } + } + + @Override + void closeConnection() { + try { + channel.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + connectionClosed(); + } + + private synchronized void connectionClosed() { + if (!closeCalled) + parser.connectionClosed(); + closeCalled = true; + } + } + + // Handle a SelectionKey which was selected + private void handleKey(Selector selector, SelectionKey key) throws IOException { + if (key.isValid() && key.isAcceptable()) { + // Accept a new connection, give it a parser as an attachment + SocketChannel newChannel = sc.accept(); + newChannel.configureBlocking(false); + ConnectionHandler handler = new ConnectionHandler(newChannel); + newChannel.register(selector, SelectionKey.OP_READ).attach(handler); + handler.parser.connectionOpen(); + } else { // Got a closing channel or a channel to a client connection + ConnectionHandler handler = ((ConnectionHandler)key.attachment()); + try { + if (!key.isValid() && handler != null) + handler.closeConnection(); // Key has been cancelled, make sure the socket gets closed + else if (handler != null && key.isReadable()) { + // Do a socket read and invoke the parser's receive message + int read = handler.channel.read(handler.dbuf); + if (read == 0) + return; // Should probably never happen, but just in case it actually can just return 0 + else if (read == -1) { // Socket was closed + key.cancel(); + handler.closeConnection(); + return; + } + // "flip" the buffer - setting the limit to the current position and setting position to 0 + handler.dbuf.flip(); + // Use parser.receive's return value as a double-check that it stopped reading at the right location + int bytesConsumed = handler.parser.receive(handler.dbuf); + checkState(handler.dbuf.position() == bytesConsumed); + // Now drop the bytes which were read by compacting dbuf (resetting limit and keeping relative + // position) + handler.dbuf.compact(); + } + } catch (Exception e) { + // This can happen eg if the channel closes while the tread is about to get killed + // (ClosedByInterruptException), or if parser.parser.receive throws something + log.error("Error handling SelectionKey", e); + if (handler != null) + handler.closeConnection(); + } + } + } + + /** + * Creates a new server which is capable of listening for incoming connections and processing client provided data + * using {@link ProtobufParser}s created by the given {@link ProtobufParserFactory} + * + * @throws IOException If there is an issue opening the server socket (note that we don't bind yet) + */ + public ProtobufServer(final ProtobufParserFactory parserFactory) throws IOException { + this.parserFactory = parserFactory; + + sc = ServerSocketChannel.open(); + sc.configureBlocking(false); + final Selector selector = SelectorProvider.provider().openSelector(); + + handlerThread = new Thread() { + @Override + public void run() { + try { + sc.register(selector, SelectionKey.OP_ACCEPT); + + while (selector.select() > 0) { // Will get 0 on stop() due to thread interrupt + Iterator keyIterator = selector.selectedKeys().iterator(); + while (keyIterator.hasNext()) { + SelectionKey key = keyIterator.next(); + keyIterator.remove(); + + handleKey(selector, key); + } + } + } catch (Exception e) { + log.error("Error trying to open/read from connection: {}", e); + } finally { + // Go through and close everything, without letting IOExceptions getting in our way + for (SelectionKey key : selector.keys()) { + try { + key.channel().close(); + } catch (IOException e) { + log.error("Error closing channel", e); + } + try { + key.cancel(); + handleKey(selector, key); + } catch (IOException e) { + log.error("Error closing selection key", e); + } + } + try { + selector.close(); + } catch (IOException e) { + log.error("Error closing server selector", e); + } + try { + sc.close(); + } catch (IOException e) { + log.error("Error closing server channel", e); + } + } + } + }; + } + + /** + * Starts the server by binding to the given address and starting the connection handling thread. + * + * @throws IOException If binding fails for some reason. + */ + public void start(InetSocketAddress bindAddress) throws IOException { + sc.socket().bind(bindAddress); + handlerThread.start(); + } + + /** + * Attempts to gracefully close all open connections, calling their connectionClosed() events. + * @throws InterruptedException If we are interrupted while waiting for the process to finish + */ + public void stop() throws InterruptedException { + handlerThread.interrupt(); + handlerThread.join(); + } +} diff --git a/core/src/test/java/com/google/bitcoin/protocols/niowrapper/NioWrapperTest.java b/core/src/test/java/com/google/bitcoin/protocols/niowrapper/NioWrapperTest.java new file mode 100644 index 00000000..d54a251d --- /dev/null +++ b/core/src/test/java/com/google/bitcoin/protocols/niowrapper/NioWrapperTest.java @@ -0,0 +1,511 @@ +/* + * Copyright 2013 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.bitcoin.protocols.niowrapper; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.bitcoin.core.Utils; +import com.google.common.util.concurrent.SettableFuture; +import com.google.protobuf.ByteString; +import org.bitcoin.paymentchannel.Protos; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static com.google.common.base.Preconditions.checkState; +import static org.junit.Assert.*; + +public class NioWrapperTest { + private AtomicBoolean fail; + + @Before + public void setUp() { + fail = new AtomicBoolean(false); + } + + @After + public void checkFail() { + assertFalse(fail.get()); + } + + @Test + public void basicClientServerTest() throws Exception { + // Tests creating a basic server, opening a client connection and sending a few messages + + final SettableFuture serverConnectionOpen = SettableFuture.create(); + final SettableFuture clientConnectionOpen = SettableFuture.create(); + final SettableFuture serverConnectionClosed = SettableFuture.create(); + final SettableFuture clientConnectionClosed = SettableFuture.create(); + final SettableFuture clientMessage1Received = SettableFuture.create(); + final SettableFuture clientMessage2Received = SettableFuture.create(); + ProtobufServer server = new ProtobufServer(new ProtobufParserFactory() { + @Override + public ProtobufParser getNewParser(InetAddress inetAddress, int port) { + return new ProtobufParser(new ProtobufParser.Listener() { + @Override + public void messageReceived(ProtobufParser handler, Protos.TwoWayChannelMessage msg) { + handler.write(msg); + handler.write(msg); + } + + @Override + public void connectionOpen(ProtobufParser handler) { + serverConnectionOpen.set(null); + } + + @Override + public void connectionClosed(ProtobufParser handler) { + serverConnectionClosed.set(null); + } + }, Protos.TwoWayChannelMessage.getDefaultInstance(), 1000, 0); + } + }); + server.start(new InetSocketAddress("localhost", 4243)); + + ProtobufParser clientHandler = new ProtobufParser( + new ProtobufParser.Listener() { + @Override + public synchronized void messageReceived(ProtobufParser handler, Protos.TwoWayChannelMessage msg) { + if (clientMessage1Received.isDone()) + clientMessage2Received.set(msg); + else + clientMessage1Received.set(msg); + } + + @Override + public void connectionOpen(ProtobufParser handler) { + clientConnectionOpen.set(null); + } + + @Override + public void connectionClosed(ProtobufParser handler) { + clientConnectionClosed.set(null); + } + }, Protos.TwoWayChannelMessage.getDefaultInstance(), 1000, 0); + + ProtobufClient client = new ProtobufClient(new InetSocketAddress("localhost", 4243), clientHandler, 0); + + clientConnectionOpen.get(); + serverConnectionOpen.get(); + + Protos.TwoWayChannelMessage msg = Protos.TwoWayChannelMessage.newBuilder().setType(Protos.TwoWayChannelMessage.MessageType.CHANNEL_OPEN).build(); + clientHandler.write(msg); + assertEquals(msg, clientMessage1Received.get()); + assertEquals(msg, clientMessage2Received.get()); + + client.closeConnection(); + serverConnectionClosed.get(); + clientConnectionClosed.get(); + + server.stop(); + } + + @Test + public void basicTimeoutTest() throws Exception { + // Tests various timeout scenarios + + final SettableFuture serverConnection1Open = SettableFuture.create(); + final SettableFuture clientConnection1Open = SettableFuture.create(); + final SettableFuture serverConnection1Closed = SettableFuture.create(); + final SettableFuture clientConnection1Closed = SettableFuture.create(); + + final SettableFuture serverConnection2Open = SettableFuture.create(); + final SettableFuture clientConnection2Open = SettableFuture.create(); + final SettableFuture serverConnection2Closed = SettableFuture.create(); + final SettableFuture clientConnection2Closed = SettableFuture.create(); + ProtobufServer server = new ProtobufServer(new ProtobufParserFactory() { + @Override + public ProtobufParser getNewParser(InetAddress inetAddress, int port) { + return new ProtobufParser(new ProtobufParser.Listener() { + @Override + public void messageReceived(ProtobufParser handler, Protos.TwoWayChannelMessage msg) { + fail.set(true); + } + + @Override + public synchronized void connectionOpen(ProtobufParser handler) { + if (serverConnection1Open.isDone()) { + handler.setSocketTimeout(0); + serverConnection2Open.set(null); + } else + serverConnection1Open.set(null); + } + + @Override + public synchronized void connectionClosed(ProtobufParser handler) { + if (serverConnection1Closed.isDone()) { + serverConnection2Closed.set(null); + } else + serverConnection1Closed.set(null); + } + }, Protos.TwoWayChannelMessage.getDefaultInstance(), 1000, 10); + } + }); + server.start(new InetSocketAddress("localhost", 4243)); + + new ProtobufClient(new InetSocketAddress("localhost", 4243), new ProtobufParser( + new ProtobufParser.Listener() { + @Override + public void messageReceived(ProtobufParser handler, Protos.TwoWayChannelMessage msg) { + fail.set(true); + } + + @Override + public void connectionOpen(ProtobufParser handler) { + clientConnection1Open.set(null); + } + + @Override + public void connectionClosed(ProtobufParser handler) { + clientConnection1Closed.set(null); + } + }, Protos.TwoWayChannelMessage.getDefaultInstance(), 1000, 0), 0); + + clientConnection1Open.get(); + serverConnection1Open.get(); + Thread.sleep(15); + assertTrue(clientConnection1Closed.isDone() && serverConnection1Closed.isDone()); + + ProtobufParser client2Handler = new ProtobufParser( + new ProtobufParser.Listener() { + @Override + public void messageReceived(ProtobufParser handler, Protos.TwoWayChannelMessage msg) { + fail.set(true); + } + + @Override + public void connectionOpen(ProtobufParser handler) { + clientConnection2Open.set(null); + } + + @Override + public void connectionClosed(ProtobufParser handler) { + clientConnection2Closed.set(null); + } + }, Protos.TwoWayChannelMessage.getDefaultInstance(), 1000, 0); + ProtobufClient client2 = new ProtobufClient(new InetSocketAddress("localhost", 4243), client2Handler, 0); + + clientConnection2Open.get(); + serverConnection2Open.get(); + Thread.sleep(15); + assertFalse(clientConnection2Closed.isDone() || serverConnection2Closed.isDone()); + + client2Handler.setSocketTimeout(10); + Thread.sleep(15); + assertTrue(clientConnection2Closed.isDone() && serverConnection2Closed.isDone()); + + server.stop(); + } + + @Test + public void largeDataTest() throws Exception { + /** Test various large-data handling, essentially testing {@link ProtobufParser#receive(java.nio.ByteBuffer)} */ + final SettableFuture serverConnectionOpen = SettableFuture.create(); + final SettableFuture clientConnectionOpen = SettableFuture.create(); + final SettableFuture serverConnectionClosed = SettableFuture.create(); + final SettableFuture clientConnectionClosed = SettableFuture.create(); + final SettableFuture clientMessage1Received = SettableFuture.create(); + final SettableFuture clientMessage2Received = SettableFuture.create(); + final SettableFuture clientMessage3Received = SettableFuture.create(); + final SettableFuture clientMessage4Received = SettableFuture.create(); + ProtobufServer server = new ProtobufServer(new ProtobufParserFactory() { + @Override + public ProtobufParser getNewParser(InetAddress inetAddress, int port) { + return new ProtobufParser(new ProtobufParser.Listener() { + @Override + public void messageReceived(ProtobufParser handler, Protos.TwoWayChannelMessage msg) { + handler.write(msg); + } + + @Override + public void connectionOpen(ProtobufParser handler) { + serverConnectionOpen.set(null); + } + + @Override + public void connectionClosed(ProtobufParser handler) { + serverConnectionClosed.set(null); + } + }, Protos.TwoWayChannelMessage.getDefaultInstance(), 0x10000, 0); + } + }); + server.start(new InetSocketAddress("localhost", 4243)); + + ProtobufParser clientHandler = new ProtobufParser( + new ProtobufParser.Listener() { + @Override + public synchronized void messageReceived(ProtobufParser handler, Protos.TwoWayChannelMessage msg) { + if (clientMessage1Received.isDone()) { + if (clientMessage2Received.isDone()) { + if (clientMessage3Received.isDone()) { + if (clientMessage4Received.isDone()) + fail.set(true); + clientMessage4Received.set(msg); + } else + clientMessage3Received.set(msg); + } else + clientMessage2Received.set(msg); + } else + clientMessage1Received.set(msg); + } + + @Override + public void connectionOpen(ProtobufParser handler) { + clientConnectionOpen.set(null); + } + + @Override + public void connectionClosed(ProtobufParser handler) { + clientConnectionClosed.set(null); + } + }, Protos.TwoWayChannelMessage.getDefaultInstance(), 0x10000, 0); + + ProtobufClient client = new ProtobufClient(new InetSocketAddress("localhost", 4243), clientHandler, 0); + + clientConnectionOpen.get(); + serverConnectionOpen.get(); + + // Large message that is larger than buffer and equal to maximum message size + Protos.TwoWayChannelMessage msg = Protos.TwoWayChannelMessage.newBuilder() + .setType(Protos.TwoWayChannelMessage.MessageType.CHANNEL_OPEN) + .setClientVersion(Protos.ClientVersion.newBuilder() + .setMajor(1) + .setPreviousChannelContractHash(ByteString.copyFrom(new byte[0x10000 - 12]))) + .build(); + // Small message that fits in the buffer + Protos.TwoWayChannelMessage msg2 = Protos.TwoWayChannelMessage.newBuilder() + .setType(Protos.TwoWayChannelMessage.MessageType.CHANNEL_OPEN) + .setClientVersion(Protos.ClientVersion.newBuilder() + .setMajor(1) + .setPreviousChannelContractHash(ByteString.copyFrom(new byte[1]))) + .build(); + // Break up the message into chunks to simulate packet network (with strange MTUs...) + byte[] messageBytes = msg.toByteArray(); + byte[] messageLength = new byte[4]; + Utils.uint32ToByteArrayBE(messageBytes.length, messageLength, 0); + client.writeBytes(new byte[]{messageLength[0], messageLength[1]}); + Thread.sleep(10); + client.writeBytes(new byte[]{messageLength[2], messageLength[3]}); + Thread.sleep(10); + client.writeBytes(new byte[]{messageBytes[0], messageBytes[1]}); + Thread.sleep(10); + client.writeBytes(Arrays.copyOfRange(messageBytes, 2, messageBytes.length - 1)); + Thread.sleep(10); + + // Now send the end of msg + msg2 + msg3 all at once + byte[] messageBytes2 = msg2.toByteArray(); + byte[] messageLength2 = new byte[4]; + Utils.uint32ToByteArrayBE(messageBytes2.length, messageLength2, 0); + byte[] sendBytes = Arrays.copyOf(new byte[] {messageBytes[messageBytes.length-1]}, 1 + messageBytes2.length*2 + messageLength2.length*2); + System.arraycopy(messageLength2, 0, sendBytes, 1, 4); + System.arraycopy(messageBytes2, 0, sendBytes, 5, messageBytes2.length); + System.arraycopy(messageLength2, 0, sendBytes, 5 + messageBytes2.length, 4); + System.arraycopy(messageBytes2, 0, sendBytes, 9 + messageBytes2.length, messageBytes2.length); + client.writeBytes(sendBytes); + assertEquals(msg, clientMessage1Received.get()); + assertEquals(msg2, clientMessage2Received.get()); + assertEquals(msg2, clientMessage3Received.get()); + + // Now resent msg2 in chunks, by itself + Utils.uint32ToByteArrayBE(messageBytes2.length, messageLength2, 0); + client.writeBytes(new byte[]{messageLength2[0], messageLength2[1]}); + Thread.sleep(10); + client.writeBytes(new byte[]{messageLength2[2], messageLength2[3]}); + Thread.sleep(10); + client.writeBytes(new byte[]{messageBytes2[0], messageBytes2[1]}); + Thread.sleep(10); + client.writeBytes(new byte[]{messageBytes2[2], messageBytes2[3]}); + Thread.sleep(10); + client.writeBytes(Arrays.copyOfRange(messageBytes2, 4, messageBytes2.length)); + assertEquals(msg2, clientMessage4Received.get()); + + Protos.TwoWayChannelMessage msg5 = Protos.TwoWayChannelMessage.newBuilder() + .setType(Protos.TwoWayChannelMessage.MessageType.CHANNEL_OPEN) + .setClientVersion(Protos.ClientVersion.newBuilder() + .setMajor(1) + .setPreviousChannelContractHash(ByteString.copyFrom(new byte[0x10000 - 11]))) + .build(); + try { + clientHandler.write(msg5); + } catch (IllegalStateException e) {} + + // Override max size and make sure the server drops our connection + byte[] messageBytes5 = msg5.toByteArray(); + byte[] messageLength5 = new byte[4]; + Utils.uint32ToByteArrayBE(messageBytes5.length, messageLength5, 0); + client.writeBytes(messageBytes5); + client.writeBytes(messageLength5); + + serverConnectionClosed.get(); + clientConnectionClosed.get(); + + server.stop(); + } + + @Test + public void testConnectionEventHandlers() throws Exception { + final SettableFuture serverConnection1Open = SettableFuture.create(); + final SettableFuture serverConnection2Open = SettableFuture.create(); + final SettableFuture serverConnection3Open = SettableFuture.create(); + final SettableFuture client1ConnectionOpen = SettableFuture.create(); + final SettableFuture client2ConnectionOpen = SettableFuture.create(); + final SettableFuture client3ConnectionOpen = SettableFuture.create(); + final SettableFuture serverConnectionClosed1 = SettableFuture.create(); + final SettableFuture serverConnectionClosed2 = SettableFuture.create(); + final SettableFuture serverConnectionClosed3 = SettableFuture.create(); + final SettableFuture client1ConnectionClosed = SettableFuture.create(); + final SettableFuture client2ConnectionClosed = SettableFuture.create(); + final SettableFuture client3ConnectionClosed = SettableFuture.create(); + final SettableFuture client1MessageReceived = SettableFuture.create(); + final SettableFuture client2MessageReceived = SettableFuture.create(); + final SettableFuture client3MessageReceived = SettableFuture.create(); + ProtobufServer server = new ProtobufServer(new ProtobufParserFactory() { + @Override + public ProtobufParser getNewParser(InetAddress inetAddress, int port) { + return new ProtobufParser(new ProtobufParser.Listener() { + @Override + public void messageReceived(ProtobufParser handler, Protos.TwoWayChannelMessage msg) { + handler.write(msg); + } + + @Override + public synchronized void connectionOpen(ProtobufParser handler) { + if (serverConnection1Open.isDone()) { + if (serverConnection2Open.isDone()) + serverConnection3Open.set(null); + else + serverConnection2Open.set(null); + } else + serverConnection1Open.set(null); + } + + @Override + public synchronized void connectionClosed(ProtobufParser handler) { + if (serverConnectionClosed1.isDone()) { + if (serverConnectionClosed2.isDone()) { + checkState(!serverConnectionClosed3.isDone()); + serverConnectionClosed3.set(null); + } else + serverConnectionClosed2.set(null); + } else + serverConnectionClosed1.set(null); + } + }, Protos.TwoWayChannelMessage.getDefaultInstance(), 1000, 0); + } + }); + server.start(new InetSocketAddress("localhost", 4243)); + + ProtobufParser client1Handler = new ProtobufParser( + new ProtobufParser.Listener() { + @Override + public void messageReceived(ProtobufParser handler, Protos.TwoWayChannelMessage msg) { + client1MessageReceived.set(msg); + } + + @Override + public void connectionOpen(ProtobufParser handler) { + client1ConnectionOpen.set(null); + } + + @Override + public void connectionClosed(ProtobufParser handler) { + client1ConnectionClosed.set(null); + } + }, Protos.TwoWayChannelMessage.getDefaultInstance(), 1000, 0); + ProtobufClient client1 = new ProtobufClient(new InetSocketAddress("localhost", 4243), client1Handler, 0); + + client1ConnectionOpen.get(); + serverConnection1Open.get(); + + ProtobufParser client2Handler = new ProtobufParser( + new ProtobufParser.Listener() { + @Override + public void messageReceived(ProtobufParser handler, Protos.TwoWayChannelMessage msg) { + client2MessageReceived.set(msg); + } + + @Override + public void connectionOpen(ProtobufParser handler) { + client2ConnectionOpen.set(null); + } + + @Override + public void connectionClosed(ProtobufParser handler) { + client2ConnectionClosed.set(null); + } + }, Protos.TwoWayChannelMessage.getDefaultInstance(), 1000, 0); + ProtobufClient client2 = new ProtobufClient(new InetSocketAddress("localhost", 4243), client2Handler, 0); + + client2ConnectionOpen.get(); + serverConnection2Open.get(); + + ProtobufParser client3Handler = new ProtobufParser( + new ProtobufParser.Listener() { + @Override + public void messageReceived(ProtobufParser handler, Protos.TwoWayChannelMessage msg) { + client3MessageReceived.set(msg); + } + + @Override + public void connectionOpen(ProtobufParser handler) { + client3ConnectionOpen.set(null); + } + + @Override + public synchronized void connectionClosed(ProtobufParser handler) { + checkState(!client3ConnectionClosed.isDone()); + client3ConnectionClosed.set(null); + } + }, Protos.TwoWayChannelMessage.getDefaultInstance(), 1000, 0); + ProtobufClient client3 = new ProtobufClient(new InetSocketAddress("localhost", 4243), client3Handler, 0); + + client3ConnectionOpen.get(); + serverConnection3Open.get(); + + Protos.TwoWayChannelMessage msg = Protos.TwoWayChannelMessage.newBuilder().setType(Protos.TwoWayChannelMessage.MessageType.CHANNEL_OPEN).build(); + client1Handler.write(msg); + assertEquals(msg, client1MessageReceived.get()); + + Protos.TwoWayChannelMessage msg2 = Protos.TwoWayChannelMessage.newBuilder().setType(Protos.TwoWayChannelMessage.MessageType.INITIATE).build(); + client2Handler.write(msg2); + assertEquals(msg2, client2MessageReceived.get()); + + client1.closeConnection(); + serverConnectionClosed1.get(); + client1ConnectionClosed.get(); + + Protos.TwoWayChannelMessage msg3 = Protos.TwoWayChannelMessage.newBuilder().setType(Protos.TwoWayChannelMessage.MessageType.CLOSE).build(); + client3Handler.write(msg3); + assertEquals(msg3, client3MessageReceived.get()); + + // Try to create a race condition by triggering handlerTread closing and client3 closing at the same time + // This often triggers ClosedByInterruptException in handleKey + server.handlerThread.interrupt(); + client3.closeConnection(); + client3ConnectionClosed.get(); + serverConnectionClosed3.get(); + + server.handlerThread.join(); + client2ConnectionClosed.get(); + serverConnectionClosed2.get(); + + server.stop(); + } +}