Skip to Main Content

Java Security

Announcement

For appeals, questions and feedback about Oracle Forums, please email oracle-forums-moderators_us@oracle.com. Technical questions should be asked in the appropriate category. Thank you!

Non-blocking SSLEngine example

843811 — Jun 23 2005 (edited Aug 13 2013)
Since the example of using SSLEngine with non-blocking IO that comes with Java is quite limited, I have decided to release my own for anyone who wants to see how I solved the various problems that you must face. The example is designed to be a generic non-blocking server that supports SSL.

This is only meant to be an example: I wrote it mostly to learn how to use the SSLEngine, so it has certain limitations and has not been thoroughly tested.

You can download the file at: http://members.aol.com/ben77/nio_server2.tar.gz

Here is the code for SecureIO, which is roughly analogous to ChannelIOSecure in the Java example:
/**
 * 
 */
package nio_server2.internalio;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.*;

import java.util.concurrent.*;

import javax.net.ssl.*;

import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*;

/**
 * Does IO based on a <code>SocketChannel</code> with all data encrypted using
 * SSL. The underlying {@link SSLEngine} runs in server mode.
 * <p>
 * Delegated engine tasks are run on the shared {@link #getExecutor()
 * executor}. {@code doTasks()} calls {@code rereg(0)} before submitting them,
 * and the task runner calls {@code rereg} again once it knows what the engine
 * needs next.
 * 
 * @author ben
 * 
 */
public class SecureIO extends InsecureIO {
	/**
	 * SSLTasker is responsible for dealing with all long running tasks required
	 * by the SSLEngine. It runs on an executor thread, so the state it hands
	 * back to the IO thread ({@code hsStatus}, {@code isTasking}) is volatile.
	 * 
	 * @author ben
	 * 
	 */
	private class SSLTasker implements Runnable {
		/**
		 * Runs every pending delegated task, publishes the resulting handshake
		 * status, and re-registers interest for the engine's next step.
		 * {@code isTasking} is always cleared on exit, even if a task throws,
		 * so {@code checkHandshake()} can never wait on it forever.
		 */
		public void run() {
			try {
				Runnable task;
				while ((task = engine.getDelegatedTask()) != null) {
					task.run();
				}
				// Publish the new status before anything can wake the IO
				// thread, so it never acts on the stale NEED_TASK value.
				SSLEngineResult.HandshakeStatus status = engine
						.getHandshakeStatus();
				hsStatus = status;
				if (inNet.position() > 0) {
					regnow(); // we may already have read what is needed
				}
				switch (status) {
					case NOT_HANDSHAKING:
						break;
					case FINISHED:
						System.err.println("Detected FINISHED in tasker");
						Thread.dumpStack();
						break;
					case NEED_TASK:
						System.err.println("Detected NEED_TASK in tasker");
						assert false;
						break;
					case NEED_WRAP:
						rereg(SelectionKey.OP_WRITE);
						break;
					case NEED_UNWRAP:
						rereg(SelectionKey.OP_READ);
						break;
				}
			} catch (IOException e) {
				e.printStackTrace();
				try {
					shutdown();
				} catch (IOException ex) {
					ex.printStackTrace();
				}
			} finally {
				isTasking = false;
			}
		}
	}

	/** The server-mode engine that encrypts and decrypts for this channel. */
	private final SSLEngine engine;

	/**
	 * Network bytes read from the channel but not yet unwrapped. Kept in fill
	 * mode (compacted) between calls, so it may still hold a partial record,
	 * or records that arrived together with the last handshake message.
	 */
	private final ByteBuffer inNet;

	private ByteBuffer outNet; // when hasRemaining, has data to write.

	/** Empty source for wraps that only produce handshake or close data. */
	private static final ByteBuffer BLANK = ByteBuffer.allocate(0);

	private boolean initialHandshakeDone = false;

	/** True while the tasker is running delegated tasks on the executor. */
	private volatile boolean isTasking = false;

	private boolean handshaking = true;

	/** Last known handshake status; also written by the tasker thread. */
	private volatile SSLEngineResult.HandshakeStatus hsStatus = NEED_UNWRAP;

	private boolean shutdownStarted;

	private static Executor executor = getDefaultExecutor();

	private final SSLTasker tasker = new SSLTasker();

	/** Scratch buffer each unwrap decrypts into before copying to dst. */
	private final ByteBuffer temp;

	/**
	 * @return the default <code>Executor</code>: a daemon-thread pool with
	 *         3 core threads and idle extra threads reclaimed after 60s
	 */
	public static Executor getDefaultExecutor() {
		return new ThreadPoolExecutor(3, Integer.MAX_VALUE, 60L,
				TimeUnit.SECONDS, new SynchronousQueue<Runnable>(),
				new DaemonThreadFactory());
	}

	private static class DaemonThreadFactory implements ThreadFactory {
		private static final ThreadFactory defaultFactory = Executors
				.defaultThreadFactory();

		/**
		 * Creates a thread using the default factory, but sets it to be daemon
		 * before returning it
		 * 
		 * @param r
		 *            the runnable to run
		 * @return the new thread
		 */
		public Thread newThread(Runnable r) {
			Thread t = defaultFactory.newThread(r);
			t.setDaemon(true);
			return t;
		}

	}

	/**
	 * @return the executor currently being used for all long-running tasks
	 */
	public static Executor getExecutor() {
		return executor;
	}

	/**
	 * Changes the executor being used for all long-running tasks. Currently
	 * running tasks will still use the old executor
	 * 
	 * @param executor
	 *            the new Executor to use
	 */
	public static void setExecutor(Executor executor) {
		SecureIO.executor = executor;
	}

	/**
	 * Creates a new <code>SecureIO</code>
	 * 
	 * @param channel
	 *            the channel to do IO on.
	 * @param sslCtx
	 *            the <code>SSLContext</code> to use
	 */
	public SecureIO(SocketChannel channel, SSLContext sslCtx) {
		super(channel);
		engine = sslCtx.createSSLEngine();
		engine.setUseClientMode(false);
		int size = engine.getSession().getPacketBufferSize();
		inNet = ByteBuffer.allocate(size);
		outNet = ByteBuffer.allocate(size);
		outNet.limit(0); // nothing pending yet
		temp = ByteBuffer.allocate(engine.getSession()
				.getApplicationBufferSize());
	}

	/**
	 * Clears this channel's interest and hands the engine's delegated tasks to
	 * the executor; the tasker re-registers interest when they are done.
	 */
	private void doTasks() throws IOException {
		rereg(0); // don't do anything until the task is done.
		isTasking = true;
		SecureIO.executor.execute(tasker);
	}

	/**
	 * Does all handshaking required by SSL. Call it repeatedly, as the channel
	 * becomes ready, until it returns true.
	 * 
	 * @param dst
	 *            the destination for any application data that is decrypted
	 *            while handshaking
	 * @return true if all needed SSL handshaking is currently complete.
	 * @throws IOException
	 *             if there are errors in handshaking.
	 * @throws IllegalArgumentException
	 *             if the remaining space in dst is less than
	 *             {@link SecureIO#minBufferSize()}
	 */
	@Override
	public boolean doHandshake(ByteBuffer dst) throws IOException {

		if (!handshaking) {
			return true;
		}
		if (dst.remaining() < minBufferSize()) {
			throw new IllegalArgumentException("Buffer has only "
					+ dst.remaining() + " left + minBufferSize is "
					+ minBufferSize());
		}
		// Finish sending a record that an earlier call only partly wrote.
		if (outNet.hasRemaining()) {
			if (!flush()) {
				return false;
			}
			switch (hsStatus) {
				case FINISHED:
					handshaking = false;
					initialHandshakeDone = true;
					rereg(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
					return true;
				case NEED_UNWRAP:
					rereg(SelectionKey.OP_READ);
					break;
				case NEED_TASK:
					doTasks();
					return false;
				case NOT_HANDSHAKING:
					throw new RuntimeException(
							"NOT_HANDSHAKING encountered when handshaking");
				default: // NEED_WRAP is handled by the switch below
					break;
			}
		}
		SSLEngineResult res;
		switch (hsStatus) {
			case NEED_UNWRAP:
				int i;

				do {
					rereg(SelectionKey.OP_READ);
					i = super.read(inNet);
					if (i < 0) {
						engine.closeInbound();
						handshaking = false;
						shutdown();
						return true;
					}
					if (i == 0 && inNet.position() == 0) {
						return false;
					}

					inloop: do {
						inNet.flip();
						temp.clear();
						res = engine.unwrap(inNet, temp);
						inNet.compact();
						temp.flip();
						if (temp.hasRemaining()) {
							dst.put(temp);
						}
						switch (res.getStatus()) {
							case OK:
								hsStatus = res.getHandshakeStatus();
								if (hsStatus == NEED_TASK) {
									doTasks();
								}
								break;
							case BUFFER_UNDERFLOW:
								// Only part of a record is buffered. If this
								// pass read nothing new, wait for the selector
								// instead of re-unwrapping the same bytes
								// forever.
								if (i == 0) {
									return false;
								}
								break inloop;
							case BUFFER_OVERFLOW:
							case CLOSED:
								throw new RuntimeException(res.getStatus()
										+ " encountered when handshaking");
						}
					} while (hsStatus == NEED_UNWRAP
							&& dst.remaining() >= minBufferSize());
				} while (hsStatus == NEED_UNWRAP
						&& dst.remaining() >= minBufferSize());
				if (hsStatus != NEED_WRAP) {
					break;
				}
				// The engine now wants to send: fall through into NEED_WRAP.
				rereg(SelectionKey.OP_WRITE);
			case NEED_WRAP:
				do {
					outNet.clear();
					res = engine.wrap(BLANK, outNet);
					outNet.flip(); // never leave outNet in fill mode
					switch (res.getStatus()) {
						case OK:
							hsStatus = res.getHandshakeStatus();
							if (hsStatus == NEED_TASK) {
								doTasks();
								return false;
							}
							break;
						case BUFFER_OVERFLOW: {
							int size = engine.getSession()
									.getPacketBufferSize();
							if (outNet.capacity() >= size) { // shouldn't happen
								throw new RuntimeException(res.getStatus()
										+ " encountered when handshaking");
							}
							// Grow and retry the wrap. limit(0) marks the new
							// buffer empty, so the flush() in the loop test
							// does not send uninitialised bytes.
							outNet = ByteBuffer.allocate(size);
							outNet.limit(0);
							continue;
						}
						case BUFFER_UNDERFLOW: // engine shouldn't care
						case CLOSED:
							throw new RuntimeException(res.getStatus()
									+ " encountered when handshaking");
					}
				} while (flush() && hsStatus == NEED_WRAP);
				break;
			case NEED_TASK:
				doTasks();
				return false;
			case FINISHED:
				break; // checked below
			case NOT_HANDSHAKING:
				System.err.println(hsStatus + " encountered when handshaking");
				handshaking = false;
				initialHandshakeDone = true;
				rereg(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
				break;
		}
		if (hsStatus == FINISHED) {
			initialHandshakeDone = true;
			handshaking = false;
			rereg(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
		}
		return !handshaking;
	}

	/**
	 * Attempts to flush all buffered data to the channel.
	 * 
	 * @return true if all buffered data has been written.
	 * @throws IOException
	 *             if there are errors writing the data
	 */
	@Override
	public boolean flush() throws IOException {
		if (!outNet.hasRemaining()) {
			return true;
		}
		super.write(outNet);
		return !outNet.hasRemaining();
	}

	/**
	 * @return the largest amount of application data that could be read from
	 *         the channel at once.
	 */
	@Override
	public int minBufferSize() {
		return engine.getSession().getApplicationBufferSize();
	}

	/**
	 * Begins or proceeds with sending an SSL shutdown message to the client.
	 * Call repeatedly until it returns true.
	 * 
	 * @return true if all needed IO is complete
	 * @throws IOException
	 *             if there are errors sending the message.
	 */
	@Override
	public boolean shutdown() throws IOException {
		if (!shutdownStarted) {
			shutdownStarted = true;
			engine.closeOutbound();
		}
		if (outNet.hasRemaining() && !flush()) {
			return false;
		}
		SSLEngineResult result;
		do {
			outNet.clear();
			result = engine.wrap(BLANK, outNet);
			if (result.getStatus() != SSLEngineResult.Status.CLOSED) {
				throw new IOException("Unexpected result in shutdown:"
						+ result.getStatus());
			}
			outNet.flip();
			if (outNet.hasRemaining() && !flush()) {
				return false;
			}
		} while (result.getHandshakeStatus() == NEED_WRAP);
		return !outNet.hasRemaining();
	}

	/**
	 * Decrypts as much buffered and newly received data as fits into
	 * <code>dst</code>. Network bytes left over from earlier calls (for
	 * example records that arrived with the last handshake message) are
	 * processed even when the channel has nothing new to offer.
	 * 
	 * @param dst
	 *            the buffer to read into.
	 * @return the number of bytes read, or -1 if the channel or
	 *         <code>SSLEngine</code> is closed
	 * @throws IOException
	 *             if there are errors.
	 * @throws IllegalStateException
	 *             if the initial handshake isn't complete
	 * @throws IllegalArgumentException
	 *             if the remaining space in dst is less than
	 *             {@link SecureIO#minBufferSize()}
	 */
	@Override
	public int read(ByteBuffer dst) throws IOException {
		if (!initialHandshakeDone) {
			throw new IllegalStateException("Initial handshake incomplete");
		}
		if (dst.remaining() < minBufferSize()) {
			throw new IllegalArgumentException("Buffer has only "
					+ dst.remaining() + " left + minBufferSize is "
					+ minBufferSize());
		}
		int sPos = dst.position();
		while (dst.remaining() >= minBufferSize()) {
			int i = super.read(inNet);
			if (i < 0) {
				engine.closeInbound();
				shutdown();
				return -1;
			}
			if (i == 0 && inNet.position() == 0) {
				break; // nothing new and nothing buffered
			}
			if (unwrapBuffered(dst) == SSLEngineResult.Status.CLOSED) {
				return -1;
			}
			if (i == 0) {
				break; // buffered data handled; wait for more from the peer
			}
		}
		return dst.position() - sPos;
	}

	/**
	 * Unwraps records buffered in {@code inNet} into {@code dst}. It stops when
	 * {@code inNet} is empty, only a partial record remains, {@code dst} has no
	 * room for another record, the engine is closed, or an unwrap neither
	 * consumes nor produces anything (the engine is waiting on something
	 * else, such as a renegotiation wrap).
	 * 
	 * @return the status of the last unwrap, or null if none was attempted
	 */
	private SSLEngineResult.Status unwrapBuffered(ByteBuffer dst)
			throws IOException {
		SSLEngineResult.Status status = null;
		while (inNet.position() > 0 && dst.remaining() >= minBufferSize()) {
			inNet.flip();
			temp.clear();
			SSLEngineResult result = engine.unwrap(inNet, temp);
			inNet.compact();
			temp.flip();
			if (temp.hasRemaining()) {
				dst.put(temp);
			}
			status = result.getStatus();
			switch (status) {
				case BUFFER_UNDERFLOW:
					return status; // partial record: need more network data
				case BUFFER_OVERFLOW:
					// temp is sized to the application buffer size, so this
					// should be impossible.
					throw new AssertionError("BUFFER_OVERFLOW unwrapping into "
							+ temp.capacity() + "-byte buffer");
				case CLOSED:
					return status;
				case OK:
					checkHandshake();
					break;
			}
			if (result.bytesConsumed() == 0 && result.bytesProduced() == 0) {
				return status; // no progress possible right now
			}
		}
		return status;
	}

	/**
	 * Encrypts data and writes it to the channel.
	 * 
	 * @param src
	 *            the data to write
	 * @return the number of bytes consumed from src (they are either sent or
	 *         buffered for the next {@link #flush()})
	 * @throws IOException
	 *             if there are errors.
	 * @throws IllegalStateException
	 *             if the initial handshake isn't complete
	 */
	@Override
	public int write(ByteBuffer src) throws IOException {
		// A single buffer holds at most Integer.MAX_VALUE bytes, so the
		// narrowing cast is safe.
		return (int) write(new ByteBuffer[] { src });
	}

	/** @return true if any buffer in src still has bytes remaining */
	private boolean hasRemaining(ByteBuffer[] src) {
		for (ByteBuffer b : src) {
			if (b.hasRemaining()) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Encrypts data and writes it to the channel.
	 * 
	 * @param src
	 *            the data to write
	 * @return the number of bytes consumed from src (they are either sent or
	 *         buffered for the next {@link #flush()})
	 * @throws IOException
	 *             if there are errors.
	 * @throws IllegalStateException
	 *             if the initial handshake isn't complete
	 */
	@Override
	public long write(ByteBuffer[] src) throws IOException {
		if (!initialHandshakeDone) {
			throw new IllegalStateException("Initial handshake incomplete");
		}
		if (!flush()) {
			return 0;
		}
		long written = 0;
		while (hasRemaining(src)) {
			outNet.clear(); // we flushed it
			SSLEngineResult result = engine.wrap(src, outNet);
			outNet.flip();
			written += result.bytesConsumed();
			switch (result.getStatus()) {
				case BUFFER_UNDERFLOW:
					return written; // not enough left to send (prob won't
				// happen - padding)
				case BUFFER_OVERFLOW: {
					if (!flush()) {
						return written; // can't remake while still have
						// stuff to write
					}
					int size = engine.getSession().getPacketBufferSize();
					if (outNet.capacity() >= size) { // shouldn't happen
						throw new RuntimeException(result.getStatus()
								+ " with a " + outNet.capacity()
								+ "-byte buffer; packet size is " + size);
					}
					outNet = ByteBuffer.allocate(size);
					continue; // try again; the next pass clears outNet
				}
				case CLOSED:
					throw new IOException("SSLEngine closed");
				case OK:
					checkHandshake();
					break;
			}
			if (!flush()) {
				break;
			}
			if (result.bytesConsumed() == 0 && result.bytesProduced() == 0) {
				// The engine must unwrap (e.g. a peer renegotiation) before
				// it can send more; retrying now would spin forever.
				break;
			}
		}
		return written;
	}

	/**
	 * Reacts to the engine's handshake status after an application-data wrap
	 * or unwrap, which is how a renegotiation becomes visible. When the engine
	 * is not handshaking, it marks the handshake done and registers for reads
	 * and writes. Otherwise it marks this channel as handshaking, records the
	 * status in {@code hsStatus}, and registers for the operation the engine
	 * needs next. A pending delegated task is started if needed and then
	 * waited for on this thread.
	 */
	private void checkHandshake() throws IOException {
		outer: while (true) {
			switch (engine.getHandshakeStatus()) {
				case NOT_HANDSHAKING:
					initialHandshakeDone = true;
					handshaking = false;
					rereg(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
					return;
				case FINISHED:
					// this shouldn't happen, I don't think. If it does, say
					// where.
					System.err.println("Detected FINISHED in checkHandshake");
					Thread.dumpStack();
					break outer;
				case NEED_TASK:
					if (!isTasking) {
						doTasks();
					}
					awaitTasker(); // TODO: deal with by reg?
					break;
				case NEED_WRAP:
					rereg(SelectionKey.OP_WRITE);
					break outer;
				case NEED_UNWRAP:
					rereg(SelectionKey.OP_READ);
					break outer;

			}
		}
		handshaking = true;
		hsStatus = engine.getHandshakeStatus();
	}

	/**
	 * Blocks the calling thread until the tasker clears {@code isTasking}. An
	 * interrupt does not abort the wait, but the interrupt flag is restored
	 * before returning.
	 */
	private void awaitTasker() {
		boolean interrupted = false;
		while (isTasking) {
			try {
				Thread.sleep(1);
			} catch (InterruptedException ex) {
				interrupted = true;
			}
		}
		if (interrupted) {
			Thread.currentThread().interrupt();
		}
	}

	/**
	 * @return true if the channel is open and no shutdown message has been
	 *         received.
	 */
	@Override
	public boolean isOpen() {
		return super.isOpen() && !engine.isInboundDone();
	}
}
Comments
Locked Post
New comments cannot be posted to this locked post.
Post Details
Locked on Sep 10 2013
Added on Jun 23 2005
4 comments
6,356 views