001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.activemq.transport.nio; 019 020import java.io.DataOutputStream; 021import java.io.EOFException; 022import java.io.IOException; 023import java.net.Socket; 024import java.net.URI; 025import java.net.UnknownHostException; 026import java.nio.ByteBuffer; 027import java.util.concurrent.atomic.AtomicInteger; 028 029import javax.net.SocketFactory; 030import javax.net.ssl.SSLContext; 031import javax.net.ssl.SSLEngine; 032import javax.net.ssl.SSLEngineResult; 033import javax.net.ssl.SSLParameters; 034 035import org.apache.activemq.thread.TaskRunnerFactory; 036import org.apache.activemq.util.IOExceptionSupport; 037import org.apache.activemq.util.ServiceStopper; 038import org.apache.activemq.wireformat.WireFormat; 039 040/** 041 * This transport initializes the SSLEngine and reads the first command before 042 * handing off to the detected transport. 043 * 044 */ 045public class AutoInitNioSSLTransport extends NIOSSLTransport { 046 047 public AutoInitNioSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { 048 super(wireFormat, socketFactory, remoteLocation, localLocation); 049 } 050 051 public AutoInitNioSSLTransport(WireFormat wireFormat, Socket socket) throws IOException { 052 super(wireFormat, socket, null, null, null); 053 } 054 055 @Override 056 public void setSslContext(SSLContext sslContext) { 057 this.sslContext = sslContext; 058 } 059 060 public ByteBuffer getInputBuffer() { 061 return this.inputBuffer; 062 } 063 064 @Override 065 protected void initializeStreams() throws IOException { 066 NIOOutputStream outputStream = null; 067 try { 068 channel = socket.getChannel(); 069 channel.configureBlocking(false); 070 071 if (sslContext == null) { 072 sslContext = SSLContext.getDefault(); 073 } 074 075 String remoteHost = null; 076 int remotePort = -1; 077 078 try { 079 URI remoteAddress = new URI(this.getRemoteAddress()); 080 remoteHost = remoteAddress.getHost(); 081 remotePort = remoteAddress.getPort(); 082 } catch (Exception e) { 083 } 084 085 // initialize engine, the initial sslSession we get will need to be 086 // updated once the ssl handshake process is completed. 087 if (remoteHost != null && remotePort != -1) { 088 sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); 089 } else { 090 sslEngine = sslContext.createSSLEngine(); 091 } 092 093 if (verifyHostName) { 094 SSLParameters sslParams = new SSLParameters(); 095 sslParams.setEndpointIdentificationAlgorithm("HTTPS"); 096 sslEngine.setSSLParameters(sslParams); 097 } 098 099 sslEngine.setUseClientMode(false); 100 if (enabledCipherSuites != null) { 101 sslEngine.setEnabledCipherSuites(enabledCipherSuites); 102 } 103 104 if (enabledProtocols != null) { 105 sslEngine.setEnabledProtocols(enabledProtocols); 106 } 107 108 if (wantClientAuth) { 109 sslEngine.setWantClientAuth(wantClientAuth); 110 } 111 112 if (needClientAuth) { 113 sslEngine.setNeedClientAuth(needClientAuth); 114 } 115 116 sslSession = sslEngine.getSession(); 117 118 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); 119 inputBuffer.clear(); 120 121 outputStream = new NIOOutputStream(channel); 122 outputStream.setEngine(sslEngine); 123 this.dataOut = new DataOutputStream(outputStream); 124 this.buffOut = outputStream; 125 sslEngine.beginHandshake(); 126 handshakeStatus = sslEngine.getHandshakeStatus(); 127 doHandshake(); 128 129 } catch (Exception e) { 130 try { 131 if(outputStream != null) { 132 outputStream.close(); 133 } 134 super.closeStreams(); 135 } catch (Exception ex) {} 136 throw new IOException(e); 137 } 138 } 139 140 @Override 141 protected void doOpenWireInit() throws Exception { 142 143 } 144 145 public SSLEngine getSslSession() { 146 return this.sslEngine; 147 } 148 149 private volatile byte[] readData; 150 151 private final AtomicInteger readSize = new AtomicInteger(); 152 153 public byte[] getReadData() { 154 return readData != null ? readData : new byte[0]; 155 } 156 157 public AtomicInteger getReadSize() { 158 return readSize; 159 } 160 161 //Prevent concurrent access to SSLEngine 162 @Override 163 public synchronized void serviceRead() { 164 try { 165 if (handshakeInProgress) { 166 doHandshake(); 167 } 168 169 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); 170 plain.position(plain.limit()); 171 172 while (true) { 173 //If the transport was already stopped then break 174 if (this.isStopped()) { 175 return; 176 } 177 178 if (!plain.hasRemaining()) { 179 int readCount = secureRead(plain); 180 181 if (readCount == 0) { 182 break; 183 } 184 185 // channel is closed, cleanup 186 if (readCount == -1) { 187 onException(new EOFException()); 188 break; 189 } 190 191 receiveCounter += readCount; 192 readSize.addAndGet(readCount); 193 } 194 195 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { 196 processCommand(plain); 197 //we have received enough bytes to detect the protocol 198 if (receiveCounter >= 8) { 199 break; 200 } 201 } 202 } 203 } catch (IOException e) { 204 onException(e); 205 } catch (Throwable e) { 206 onException(IOExceptionSupport.create(e)); 207 } 208 } 209 210 @Override 211 protected void processCommand(ByteBuffer plain) throws Exception { 212 ByteBuffer newBuffer = ByteBuffer.allocate(receiveCounter); 213 if (readData != null) { 214 newBuffer.put(readData); 215 } 216 newBuffer.put(plain); 217 newBuffer.flip(); 218 readData = newBuffer.array(); 219 } 220 221 222 @Override 223 public void doStart() throws Exception { 224 taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task"); 225 // no need to init as we can delay that until demand (eg in doHandshake) 226 connect(); 227 } 228 229 230 @Override 231 protected void doStop(ServiceStopper stopper) throws Exception { 232 if (taskRunnerFactory != null) { 233 taskRunnerFactory.shutdownNow(); 234 taskRunnerFactory = null; 235 } 236 } 237 238 239}