001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.activemq.transport.nio; 019 020import java.io.DataInputStream; 021import java.io.DataOutputStream; 022import java.io.EOFException; 023import java.io.IOException; 024import java.net.Socket; 025import java.net.SocketTimeoutException; 026import java.net.URI; 027import java.net.UnknownHostException; 028import java.nio.ByteBuffer; 029import java.nio.channels.SelectionKey; 030import java.nio.channels.Selector; 031import java.security.cert.X509Certificate; 032import java.util.concurrent.CountDownLatch; 033 034import javax.net.SocketFactory; 035import javax.net.ssl.SSLContext; 036import javax.net.ssl.SSLEngine; 037import javax.net.ssl.SSLEngineResult; 038import javax.net.ssl.SSLEngineResult.HandshakeStatus; 039import javax.net.ssl.SSLParameters; 040import javax.net.ssl.SSLPeerUnverifiedException; 041import javax.net.ssl.SSLSession; 042 043import org.apache.activemq.command.ConnectionInfo; 044import org.apache.activemq.openwire.OpenWireFormat; 045import org.apache.activemq.thread.TaskRunnerFactory; 046import org.apache.activemq.util.IOExceptionSupport; 047import org.apache.activemq.util.ServiceStopper; 048import org.apache.activemq.wireformat.WireFormat; 049import org.slf4j.Logger; 050import org.slf4j.LoggerFactory; 051 052public class NIOSSLTransport extends NIOTransport { 053 054 private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class); 055 056 protected boolean needClientAuth; 057 protected boolean wantClientAuth; 058 protected String[] enabledCipherSuites; 059 protected String[] enabledProtocols; 060 protected boolean verifyHostName = false; 061 062 protected SSLContext sslContext; 063 protected SSLEngine sslEngine; 064 protected SSLSession sslSession; 065 066 protected volatile boolean handshakeInProgress = false; 067 protected SSLEngineResult.Status status = null; 068 protected SSLEngineResult.HandshakeStatus handshakeStatus = null; 069 protected TaskRunnerFactory taskRunnerFactory; 070 071 public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { 072 super(wireFormat, socketFactory, remoteLocation, localLocation); 073 } 074 075 public NIOSSLTransport(WireFormat wireFormat, Socket socket, SSLEngine engine, InitBuffer initBuffer, 076 ByteBuffer inputBuffer) throws IOException { 077 super(wireFormat, socket, initBuffer); 078 this.sslEngine = engine; 079 if (engine != null) { 080 this.sslSession = engine.getSession(); 081 } 082 this.inputBuffer = inputBuffer; 083 } 084 085 public void setSslContext(SSLContext sslContext) { 086 this.sslContext = sslContext; 087 } 088 089 volatile boolean hasSslEngine = false; 090 091 @Override 092 protected void initializeStreams() throws IOException { 093 if (sslEngine != null) { 094 hasSslEngine = true; 095 } 096 NIOOutputStream outputStream = null; 097 try { 098 channel = socket.getChannel(); 099 channel.configureBlocking(false); 100 101 if (sslContext == null) { 102 sslContext = SSLContext.getDefault(); 103 } 104 105 String remoteHost = null; 106 int remotePort = -1; 107 108 try { 109 URI remoteAddress = new URI(this.getRemoteAddress()); 110 remoteHost = remoteAddress.getHost(); 111 remotePort = remoteAddress.getPort(); 112 } catch (Exception e) { 113 } 114 115 // initialize engine, the initial sslSession we get will need to be 116 // updated once the ssl handshake process is completed. 117 if (!hasSslEngine) { 118 if (remoteHost != null && remotePort != -1) { 119 sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); 120 } else { 121 sslEngine = sslContext.createSSLEngine(); 122 } 123 124 if (verifyHostName) { 125 SSLParameters sslParams = new SSLParameters(); 126 sslParams.setEndpointIdentificationAlgorithm("HTTPS"); 127 sslEngine.setSSLParameters(sslParams); 128 } 129 130 sslEngine.setUseClientMode(false); 131 if (enabledCipherSuites != null) { 132 sslEngine.setEnabledCipherSuites(enabledCipherSuites); 133 } 134 135 if (enabledProtocols != null) { 136 sslEngine.setEnabledProtocols(enabledProtocols); 137 } 138 139 if (wantClientAuth) { 140 sslEngine.setWantClientAuth(wantClientAuth); 141 } 142 143 if (needClientAuth) { 144 sslEngine.setNeedClientAuth(needClientAuth); 145 } 146 147 sslSession = sslEngine.getSession(); 148 149 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); 150 inputBuffer.clear(); 151 } 152 153 outputStream = new NIOOutputStream(channel); 154 outputStream.setEngine(sslEngine); 155 this.dataOut = new DataOutputStream(outputStream); 156 this.buffOut = outputStream; 157 158 //If the sslEngine was not passed in, then handshake 159 if (!hasSslEngine) { 160 sslEngine.beginHandshake(); 161 } 162 handshakeStatus = sslEngine.getHandshakeStatus(); 163 if (!hasSslEngine) { 164 doHandshake(); 165 } 166 167 selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { 168 @Override 169 public void onSelect(SelectorSelection selection) { 170 try { 171 initialized.await(); 172 } catch (InterruptedException error) { 173 onException(IOExceptionSupport.create(error)); 174 } 175 serviceRead(); 176 } 177 178 @Override 179 public void onError(SelectorSelection selection, Throwable error) { 180 if (error instanceof IOException) { 181 onException((IOException) error); 182 } else { 183 onException(IOExceptionSupport.create(error)); 184 } 185 } 186 }); 187 doInit(); 188 189 } catch (Exception e) { 190 try { 191 if(outputStream != null) { 192 outputStream.close(); 193 } 194 super.closeStreams(); 195 } catch (Exception ex) {} 196 throw new IOException(e); 197 } 198 } 199 200 final protected CountDownLatch initialized = new CountDownLatch(1); 201 202 protected void doInit() throws Exception { 203 taskRunnerFactory.execute(new Runnable() { 204 205 @Override 206 public void run() { 207 //Need to start in new thread to let startup finish first 208 //We can trigger a read because we know the channel is ready since the SSL handshake 209 //already happened 210 serviceRead(); 211 initialized.countDown(); 212 } 213 }); 214 } 215 216 //Only used for the auto transport to abort the openwire init method early if already initialized 217 boolean openWireInititialized = false; 218 219 protected void doOpenWireInit() throws Exception { 220 //Do this later to let wire format negotiation happen 221 if (initBuffer != null && !openWireInititialized && this.wireFormat instanceof OpenWireFormat) { 222 initBuffer.buffer.flip(); 223 if (initBuffer.buffer.hasRemaining()) { 224 nextFrameSize = -1; 225 receiveCounter += initBuffer.readSize; 226 processCommand(initBuffer.buffer); 227 processCommand(initBuffer.buffer); 228 initBuffer.buffer.clear(); 229 openWireInititialized = true; 230 } 231 } 232 } 233 234 protected void finishHandshake() throws Exception { 235 if (handshakeInProgress) { 236 handshakeInProgress = false; 237 nextFrameSize = -1; 238 239 // Once handshake completes we need to ask for the now real sslSession 240 // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the 241 // cipher suite. 242 sslSession = sslEngine.getSession(); 243 } 244 } 245 246 //Prevent concurrent access to SSLEngine 247 @Override 248 public synchronized void serviceRead() { 249 try { 250 if (handshakeInProgress) { 251 doHandshake(); 252 } 253 254 doOpenWireInit(); 255 256 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); 257 plain.position(plain.limit()); 258 259 while (true) { 260 //If the transport was already stopped then break 261 if (this.isStopped()) { 262 return; 263 } 264 265 if (!plain.hasRemaining()) { 266 267 int readCount = secureRead(plain); 268 269 if (readCount == 0) { 270 break; 271 } 272 273 // channel is closed, cleanup 274 if (readCount == -1) { 275 onException(new EOFException()); 276 selection.close(); 277 break; 278 } 279 280 receiveCounter += readCount; 281 } 282 283 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { 284 processCommand(plain); 285 } 286 } 287 } catch (IOException e) { 288 onException(e); 289 } catch (Throwable e) { 290 onException(IOExceptionSupport.create(e)); 291 } 292 } 293 294 protected void processCommand(ByteBuffer plain) throws Exception { 295 296 // Are we waiting for the next Command or are we building on the current one 297 if (nextFrameSize == -1) { 298 299 // We can get small packets that don't give us enough for the frame size 300 // so allocate enough for the initial size value and 301 if (plain.remaining() < Integer.SIZE) { 302 if (currentBuffer == null) { 303 currentBuffer = ByteBuffer.allocate(4); 304 } 305 306 // Go until we fill the integer sized current buffer. 307 while (currentBuffer.hasRemaining() && plain.hasRemaining()) { 308 currentBuffer.put(plain.get()); 309 } 310 311 // Didn't we get enough yet to figure out next frame size. 312 if (currentBuffer.hasRemaining()) { 313 return; 314 } else { 315 currentBuffer.flip(); 316 nextFrameSize = currentBuffer.getInt(); 317 } 318 319 } else { 320 321 // Either we are completing a previous read of the next frame size or its 322 // fully contained in plain already. 323 if (currentBuffer != null) { 324 325 // Finish the frame size integer read and get from the current buffer. 326 while (currentBuffer.hasRemaining()) { 327 currentBuffer.put(plain.get()); 328 } 329 330 currentBuffer.flip(); 331 nextFrameSize = currentBuffer.getInt(); 332 333 } else { 334 nextFrameSize = plain.getInt(); 335 } 336 } 337 338 if (wireFormat instanceof OpenWireFormat) { 339 long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize(); 340 if (nextFrameSize > maxFrameSize) { 341 throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + 342 " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB"); 343 } 344 } 345 346 // now we got the data, lets reallocate and store the size for the marshaler. 347 // if there's more data in plain, then the next call will start processing it. 348 currentBuffer = ByteBuffer.allocate(nextFrameSize + 4); 349 currentBuffer.putInt(nextFrameSize); 350 351 } else { 352 // If its all in one read then we can just take it all, otherwise take only 353 // the current frame size and the next iteration starts a new command. 354 if (currentBuffer != null) { 355 if (currentBuffer.remaining() >= plain.remaining()) { 356 currentBuffer.put(plain); 357 } else { 358 byte[] fill = new byte[currentBuffer.remaining()]; 359 plain.get(fill); 360 currentBuffer.put(fill); 361 } 362 363 // Either we have enough data for a new command or we have to wait for some more. 364 if (currentBuffer.hasRemaining()) { 365 return; 366 } else { 367 currentBuffer.flip(); 368 Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer))); 369 doConsume(command); 370 nextFrameSize = -1; 371 currentBuffer = null; 372 } 373 } 374 } 375 } 376 377 protected int secureRead(ByteBuffer plain) throws Exception { 378 379 if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 380 int bytesRead = channel.read(inputBuffer); 381 382 if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) { 383 return 0; 384 } 385 386 if (bytesRead == -1) { 387 sslEngine.closeInbound(); 388 if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 389 return -1; 390 } 391 } 392 } 393 394 plain.clear(); 395 396 inputBuffer.flip(); 397 SSLEngineResult res; 398 do { 399 res = sslEngine.unwrap(inputBuffer, plain); 400 } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP 401 && res.bytesProduced() == 0); 402 403 if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) { 404 finishHandshake(); 405 } 406 407 status = res.getStatus(); 408 handshakeStatus = res.getHandshakeStatus(); 409 410 // TODO deal with BUFFER_OVERFLOW 411 412 if (status == SSLEngineResult.Status.CLOSED) { 413 sslEngine.closeInbound(); 414 return -1; 415 } 416 417 inputBuffer.compact(); 418 plain.flip(); 419 420 return plain.remaining(); 421 } 422 423 protected void doHandshake() throws Exception { 424 handshakeInProgress = true; 425 Selector selector = null; 426 SelectionKey key = null; 427 boolean readable = true; 428 try { 429 while (true) { 430 HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); 431 switch (handshakeStatus) { 432 case NEED_UNWRAP: 433 if (readable) { 434 secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize())); 435 } 436 if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 437 long now = System.currentTimeMillis(); 438 if (selector == null) { 439 selector = Selector.open(); 440 key = channel.register(selector, SelectionKey.OP_READ); 441 } else { 442 key.interestOps(SelectionKey.OP_READ); 443 } 444 int keyCount = selector.select(this.getSoTimeout()); 445 if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) { 446 throw new SocketTimeoutException("Timeout during handshake"); 447 } 448 readable = key.isReadable(); 449 } 450 break; 451 case NEED_TASK: 452 Runnable task; 453 while ((task = sslEngine.getDelegatedTask()) != null) { 454 task.run(); 455 } 456 break; 457 case NEED_WRAP: 458 ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0)); 459 break; 460 case FINISHED: 461 case NOT_HANDSHAKING: 462 finishHandshake(); 463 return; 464 } 465 } 466 } finally { 467 if (key!=null) try {key.cancel();} catch (Exception ignore) {} 468 if (selector!=null) try {selector.close();} catch (Exception ignore) {} 469 } 470 } 471 472 @Override 473 protected void doStart() throws Exception { 474 taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task"); 475 // no need to init as we can delay that until demand (eg in doHandshake) 476 super.doStart(); 477 } 478 479 @Override 480 protected void doStop(ServiceStopper stopper) throws Exception { 481 initialized.countDown(); 482 483 if (taskRunnerFactory != null) { 484 taskRunnerFactory.shutdownNow(); 485 taskRunnerFactory = null; 486 } 487 if (channel != null) { 488 channel.close(); 489 channel = null; 490 } 491 super.doStop(stopper); 492 } 493 494 /** 495 * Overriding in order to add the client's certificates to ConnectionInfo Commands. 496 * 497 * @param command 498 * The Command coming in. 499 */ 500 @Override 501 public void doConsume(Object command) { 502 if (command instanceof ConnectionInfo) { 503 ConnectionInfo connectionInfo = (ConnectionInfo) command; 504 connectionInfo.setTransportContext(getPeerCertificates()); 505 } 506 super.doConsume(command); 507 } 508 509 /** 510 * @return peer certificate chain associated with the ssl socket 511 */ 512 @Override 513 public X509Certificate[] getPeerCertificates() { 514 515 X509Certificate[] clientCertChain = null; 516 try { 517 if (sslEngine.getSession() != null) { 518 clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates(); 519 } 520 } catch (SSLPeerUnverifiedException e) { 521 if (LOG.isTraceEnabled()) { 522 LOG.trace("Failed to get peer certificates.", e); 523 } 524 } 525 526 return clientCertChain; 527 } 528 529 public boolean isNeedClientAuth() { 530 return needClientAuth; 531 } 532 533 public void setNeedClientAuth(boolean needClientAuth) { 534 this.needClientAuth = needClientAuth; 535 } 536 537 public boolean isWantClientAuth() { 538 return wantClientAuth; 539 } 540 541 public void setWantClientAuth(boolean wantClientAuth) { 542 this.wantClientAuth = wantClientAuth; 543 } 544 545 public String[] getEnabledCipherSuites() { 546 return enabledCipherSuites; 547 } 548 549 public void setEnabledCipherSuites(String[] enabledCipherSuites) { 550 this.enabledCipherSuites = enabledCipherSuites; 551 } 552 553 public String[] getEnabledProtocols() { 554 return enabledProtocols; 555 } 556 557 public void setEnabledProtocols(String[] enabledProtocols) { 558 this.enabledProtocols = enabledProtocols; 559 } 560 561 public boolean isVerifyHostName() { 562 return verifyHostName; 563 } 564 565 public void setVerifyHostName(boolean verifyHostName) { 566 this.verifyHostName = verifyHostName; 567 } 568}