001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
020import java.io.DataOutputStream;
021import java.io.EOFException;
022import java.io.IOException;
023import java.net.Socket;
024import java.net.URI;
025import java.net.UnknownHostException;
026import java.nio.ByteBuffer;
027import java.util.concurrent.atomic.AtomicInteger;
028
029import javax.net.SocketFactory;
030import javax.net.ssl.SSLContext;
031import javax.net.ssl.SSLEngine;
032import javax.net.ssl.SSLEngineResult;
033import javax.net.ssl.SSLParameters;
034
035import org.apache.activemq.thread.TaskRunnerFactory;
036import org.apache.activemq.util.IOExceptionSupport;
037import org.apache.activemq.util.ServiceStopper;
038import org.apache.activemq.wireformat.WireFormat;
039
040/**
041 * This transport initializes the SSLEngine and reads the first command before
042 * handing off to the detected transport.
043 *
044 */
045public class AutoInitNioSSLTransport extends NIOSSLTransport {
046
047    public AutoInitNioSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
048        super(wireFormat, socketFactory, remoteLocation, localLocation);
049    }
050
051    public AutoInitNioSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
052        super(wireFormat, socket, null, null, null);
053    }
054
055    @Override
056    public void setSslContext(SSLContext sslContext) {
057        this.sslContext = sslContext;
058    }
059
060    public ByteBuffer getInputBuffer() {
061        return this.inputBuffer;
062    }
063
064    @Override
065    protected void initializeStreams() throws IOException {
066        NIOOutputStream outputStream = null;
067        try {
068            channel = socket.getChannel();
069            channel.configureBlocking(false);
070
071            if (sslContext == null) {
072                sslContext = SSLContext.getDefault();
073            }
074
075            String remoteHost = null;
076            int remotePort = -1;
077
078            try {
079                URI remoteAddress = new URI(this.getRemoteAddress());
080                remoteHost = remoteAddress.getHost();
081                remotePort = remoteAddress.getPort();
082            } catch (Exception e) {
083            }
084
085            // initialize engine, the initial sslSession we get will need to be
086            // updated once the ssl handshake process is completed.
087            if (remoteHost != null && remotePort != -1) {
088                sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
089            } else {
090                sslEngine = sslContext.createSSLEngine();
091            }
092
093            if (verifyHostName) {
094                SSLParameters sslParams = new SSLParameters();
095                sslParams.setEndpointIdentificationAlgorithm("HTTPS");
096                sslEngine.setSSLParameters(sslParams);
097            }
098
099            sslEngine.setUseClientMode(false);
100            if (enabledCipherSuites != null) {
101                sslEngine.setEnabledCipherSuites(enabledCipherSuites);
102            }
103
104            if (enabledProtocols != null) {
105                sslEngine.setEnabledProtocols(enabledProtocols);
106            }
107
108            if (wantClientAuth) {
109                sslEngine.setWantClientAuth(wantClientAuth);
110            }
111
112            if (needClientAuth) {
113                sslEngine.setNeedClientAuth(needClientAuth);
114            }
115
116            sslSession = sslEngine.getSession();
117
118            inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
119            inputBuffer.clear();
120
121            outputStream = new NIOOutputStream(channel);
122            outputStream.setEngine(sslEngine);
123            this.dataOut = new DataOutputStream(outputStream);
124            this.buffOut = outputStream;
125            sslEngine.beginHandshake();
126            handshakeStatus = sslEngine.getHandshakeStatus();
127            doHandshake();
128
129        } catch (Exception e) {
130            try {
131                if(outputStream != null) {
132                    outputStream.close();
133                }
134                super.closeStreams();
135            } catch (Exception ex) {}
136            throw new IOException(e);
137        }
138    }
139
140    @Override
141    protected void doOpenWireInit() throws Exception {
142
143    }
144
145    public SSLEngine getSslSession() {
146        return this.sslEngine;
147    }
148
149    private volatile byte[] readData;
150
151    private final AtomicInteger readSize = new AtomicInteger();
152
153    public byte[] getReadData() {
154        return readData != null ? readData : new byte[0];
155    }
156
157    public AtomicInteger getReadSize() {
158        return readSize;
159    }
160
161    //Prevent concurrent access to SSLEngine
162    @Override
163    public synchronized void serviceRead() {
164        try {
165            if (handshakeInProgress) {
166                doHandshake();
167            }
168
169            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
170            plain.position(plain.limit());
171
172            while (true) {
173                //If the transport was already stopped then break
174                if (this.isStopped()) {
175                    return;
176                }
177
178                if (!plain.hasRemaining()) {
179                    int readCount = secureRead(plain);
180
181                    if (readCount == 0) {
182                        break;
183                    }
184
185                    // channel is closed, cleanup
186                    if (readCount == -1) {
187                        onException(new EOFException());
188                        break;
189                    }
190
191                    receiveCounter += readCount;
192                    readSize.addAndGet(readCount);
193                }
194
195                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
196                    processCommand(plain);
197                    //we have received enough bytes to detect the protocol
198                    if (receiveCounter >= 8) {
199                        break;
200                    }
201                }
202            }
203        } catch (IOException e) {
204            onException(e);
205        } catch (Throwable e) {
206            onException(IOExceptionSupport.create(e));
207        }
208    }
209
210    @Override
211    protected void processCommand(ByteBuffer plain) throws Exception {
212        ByteBuffer newBuffer = ByteBuffer.allocate(receiveCounter);
213        if (readData != null) {
214            newBuffer.put(readData);
215        }
216        newBuffer.put(plain);
217        newBuffer.flip();
218        readData = newBuffer.array();
219    }
220
221
222    @Override
223    public void doStart() throws Exception {
224        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
225        // no need to init as we can delay that until demand (eg in doHandshake)
226        connect();
227    }
228
229
230    @Override
231    protected void doStop(ServiceStopper stopper) throws Exception {
232        if (taskRunnerFactory != null) {
233            taskRunnerFactory.shutdownNow();
234            taskRunnerFactory = null;
235        }
236    }
237
238
239}