001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.activemq.transport.nio; 019 020import java.io.DataInputStream; 021import java.io.DataOutputStream; 022import java.io.EOFException; 023import java.io.IOException; 024import java.net.Socket; 025import java.net.SocketTimeoutException; 026import java.net.URI; 027import java.net.UnknownHostException; 028import java.nio.ByteBuffer; 029import java.nio.channels.SelectionKey; 030import java.nio.channels.Selector; 031import java.security.cert.X509Certificate; 032import java.util.concurrent.CountDownLatch; 033 034import javax.net.SocketFactory; 035import javax.net.ssl.SSLContext; 036import javax.net.ssl.SSLEngine; 037import javax.net.ssl.SSLEngineResult; 038import javax.net.ssl.SSLEngineResult.HandshakeStatus; 039import javax.net.ssl.SSLParameters; 040import javax.net.ssl.SSLPeerUnverifiedException; 041import javax.net.ssl.SSLSession; 042 043import org.apache.activemq.MaxFrameSizeExceededException; 044import org.apache.activemq.command.ConnectionInfo; 045import org.apache.activemq.openwire.OpenWireFormat; 046import org.apache.activemq.thread.TaskRunnerFactory; 047import org.apache.activemq.util.IOExceptionSupport; 048import org.apache.activemq.util.ServiceStopper; 049import org.apache.activemq.wireformat.WireFormat; 050import org.slf4j.Logger; 051import org.slf4j.LoggerFactory; 052 053public class NIOSSLTransport extends NIOTransport { 054 055 private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class); 056 057 protected boolean needClientAuth; 058 protected boolean wantClientAuth; 059 protected String[] enabledCipherSuites; 060 protected String[] enabledProtocols; 061 protected boolean verifyHostName = false; 062 063 protected SSLContext sslContext; 064 protected SSLEngine sslEngine; 065 protected SSLSession sslSession; 066 067 protected volatile boolean handshakeInProgress = false; 068 protected SSLEngineResult.Status status = null; 069 protected SSLEngineResult.HandshakeStatus handshakeStatus = null; 070 protected TaskRunnerFactory taskRunnerFactory; 071 072 public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { 073 super(wireFormat, socketFactory, remoteLocation, localLocation); 074 } 075 076 public NIOSSLTransport(WireFormat wireFormat, Socket socket, SSLEngine engine, InitBuffer initBuffer, 077 ByteBuffer inputBuffer) throws IOException { 078 super(wireFormat, socket, initBuffer); 079 this.sslEngine = engine; 080 if (engine != null) { 081 this.sslSession = engine.getSession(); 082 } 083 this.inputBuffer = inputBuffer; 084 } 085 086 public void setSslContext(SSLContext sslContext) { 087 this.sslContext = sslContext; 088 } 089 090 volatile boolean hasSslEngine = false; 091 092 @Override 093 protected void initializeStreams() throws IOException { 094 if (sslEngine != null) { 095 hasSslEngine = true; 096 } 097 NIOOutputStream outputStream = null; 098 try { 099 channel = socket.getChannel(); 100 channel.configureBlocking(false); 101 102 if (sslContext == null) { 103 sslContext = SSLContext.getDefault(); 104 } 105 106 String remoteHost = null; 107 int remotePort = -1; 108 109 try { 110 URI remoteAddress = new URI(this.getRemoteAddress()); 111 remoteHost = remoteAddress.getHost(); 112 remotePort = remoteAddress.getPort(); 113 } catch (Exception e) { 114 } 115 116 // initialize engine, the initial sslSession we get will need to be 117 // updated once the ssl handshake process is completed. 118 if (!hasSslEngine) { 119 if (remoteHost != null && remotePort != -1) { 120 sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); 121 } else { 122 sslEngine = sslContext.createSSLEngine(); 123 } 124 125 if (verifyHostName) { 126 SSLParameters sslParams = new SSLParameters(); 127 sslParams.setEndpointIdentificationAlgorithm("HTTPS"); 128 sslEngine.setSSLParameters(sslParams); 129 } 130 131 sslEngine.setUseClientMode(false); 132 if (enabledCipherSuites != null) { 133 sslEngine.setEnabledCipherSuites(enabledCipherSuites); 134 } 135 136 if (enabledProtocols != null) { 137 sslEngine.setEnabledProtocols(enabledProtocols); 138 } 139 140 if (wantClientAuth) { 141 sslEngine.setWantClientAuth(wantClientAuth); 142 } 143 144 if (needClientAuth) { 145 sslEngine.setNeedClientAuth(needClientAuth); 146 } 147 148 sslSession = sslEngine.getSession(); 149 150 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); 151 inputBuffer.clear(); 152 } 153 154 outputStream = new NIOOutputStream(channel); 155 outputStream.setEngine(sslEngine); 156 this.dataOut = new DataOutputStream(outputStream); 157 this.buffOut = outputStream; 158 159 //If the sslEngine was not passed in, then handshake 160 if (!hasSslEngine) { 161 sslEngine.beginHandshake(); 162 } 163 handshakeStatus = sslEngine.getHandshakeStatus(); 164 if (!hasSslEngine) { 165 doHandshake(); 166 } 167 168 selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { 169 @Override 170 public void onSelect(SelectorSelection selection) { 171 try { 172 initialized.await(); 173 } catch (InterruptedException error) { 174 onException(IOExceptionSupport.create(error)); 175 } 176 serviceRead(); 177 } 178 179 @Override 180 public void onError(SelectorSelection selection, Throwable error) { 181 if (error instanceof IOException) { 182 onException((IOException) error); 183 } else { 184 onException(IOExceptionSupport.create(error)); 185 } 186 } 187 }); 188 doInit(); 189 190 } catch (Exception e) { 191 try { 192 if(outputStream != null) { 193 outputStream.close(); 194 } 195 super.closeStreams(); 196 } catch (Exception ex) {} 197 throw new IOException(e); 198 } 199 } 200 201 final protected CountDownLatch initialized = new CountDownLatch(1); 202 203 protected void doInit() throws Exception { 204 taskRunnerFactory.execute(new Runnable() { 205 206 @Override 207 public void run() { 208 //Need to start in new thread to let startup finish first 209 //We can trigger a read because we know the channel is ready since the SSL handshake 210 //already happened 211 serviceRead(); 212 initialized.countDown(); 213 } 214 }); 215 } 216 217 //Only used for the auto transport to abort the openwire init method early if already initialized 218 boolean openWireInititialized = false; 219 220 protected void doOpenWireInit() throws Exception { 221 //Do this later to let wire format negotiation happen 222 if (initBuffer != null && !openWireInititialized && this.wireFormat instanceof OpenWireFormat) { 223 initBuffer.buffer.flip(); 224 if (initBuffer.buffer.hasRemaining()) { 225 nextFrameSize = -1; 226 receiveCounter += initBuffer.readSize; 227 processCommand(initBuffer.buffer); 228 processCommand(initBuffer.buffer); 229 initBuffer.buffer.clear(); 230 openWireInititialized = true; 231 } 232 } 233 } 234 235 protected void finishHandshake() throws Exception { 236 if (handshakeInProgress) { 237 handshakeInProgress = false; 238 nextFrameSize = -1; 239 240 // Once handshake completes we need to ask for the now real sslSession 241 // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the 242 // cipher suite. 243 sslSession = sslEngine.getSession(); 244 } 245 } 246 247 @Override 248 public void serviceRead() { 249 try { 250 if (handshakeInProgress) { 251 doHandshake(); 252 } 253 254 doOpenWireInit(); 255 256 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); 257 plain.position(plain.limit()); 258 259 while (true) { 260 //If the transport was already stopped then break 261 if (this.isStopped()) { 262 return; 263 } 264 265 if (!plain.hasRemaining()) { 266 267 int readCount = secureRead(plain); 268 269 if (readCount == 0) { 270 break; 271 } 272 273 // channel is closed, cleanup 274 if (readCount == -1) { 275 onException(new EOFException()); 276 selection.close(); 277 break; 278 } 279 280 receiveCounter += readCount; 281 } 282 283 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { 284 processCommand(plain); 285 } 286 } 287 } catch (IOException e) { 288 onException(e); 289 } catch (Throwable e) { 290 onException(IOExceptionSupport.create(e)); 291 } 292 } 293 294 protected void processCommand(ByteBuffer plain) throws Exception { 295 296 // Are we waiting for the next Command or are we building on the current one 297 if (nextFrameSize == -1) { 298 299 // We can get small packets that don't give us enough for the frame size 300 // so allocate enough for the initial size value and 301 if (plain.remaining() < Integer.SIZE) { 302 if (currentBuffer == null) { 303 currentBuffer = ByteBuffer.allocate(4); 304 } 305 306 // Go until we fill the integer sized current buffer. 307 while (currentBuffer.hasRemaining() && plain.hasRemaining()) { 308 currentBuffer.put(plain.get()); 309 } 310 311 // Didn't we get enough yet to figure out next frame size. 312 if (currentBuffer.hasRemaining()) { 313 return; 314 } else { 315 currentBuffer.flip(); 316 nextFrameSize = currentBuffer.getInt(); 317 } 318 319 } else { 320 321 // Either we are completing a previous read of the next frame size or its 322 // fully contained in plain already. 323 if (currentBuffer != null) { 324 325 // Finish the frame size integer read and get from the current buffer. 326 while (currentBuffer.hasRemaining()) { 327 currentBuffer.put(plain.get()); 328 } 329 330 currentBuffer.flip(); 331 nextFrameSize = currentBuffer.getInt(); 332 333 } else { 334 nextFrameSize = plain.getInt(); 335 } 336 } 337 338 if (wireFormat instanceof OpenWireFormat) { 339 OpenWireFormat openWireFormat = (OpenWireFormat) wireFormat; 340 long maxFrameSize = openWireFormat.getMaxFrameSize(); 341 342 if (openWireFormat.isMaxFrameSizeEnabled() && nextFrameSize > maxFrameSize) { 343 throw new MaxFrameSizeExceededException("Frame size of " + (nextFrameSize / (1024 * 1024)) + 344 " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB"); 345 } 346 } 347 348 // now we got the data, lets reallocate and store the size for the marshaler. 349 // if there's more data in plain, then the next call will start processing it. 350 currentBuffer = ByteBuffer.allocate(nextFrameSize + 4); 351 currentBuffer.putInt(nextFrameSize); 352 353 } else { 354 // If its all in one read then we can just take it all, otherwise take only 355 // the current frame size and the next iteration starts a new command. 356 if (currentBuffer != null) { 357 if (currentBuffer.remaining() >= plain.remaining()) { 358 currentBuffer.put(plain); 359 } else { 360 byte[] fill = new byte[currentBuffer.remaining()]; 361 plain.get(fill); 362 currentBuffer.put(fill); 363 } 364 365 // Either we have enough data for a new command or we have to wait for some more. 366 if (currentBuffer.hasRemaining()) { 367 return; 368 } else { 369 currentBuffer.flip(); 370 Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer))); 371 doConsume(command); 372 nextFrameSize = -1; 373 currentBuffer = null; 374 } 375 } 376 } 377 } 378 379 //Prevent concurrent access while reading from the channel 380 protected synchronized int secureRead(ByteBuffer plain) throws Exception { 381 382 if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 383 int bytesRead = channel.read(inputBuffer); 384 385 if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) { 386 return 0; 387 } 388 389 if (bytesRead == -1) { 390 sslEngine.closeInbound(); 391 if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 392 return -1; 393 } 394 } 395 } 396 397 plain.clear(); 398 399 inputBuffer.flip(); 400 SSLEngineResult res; 401 do { 402 res = sslEngine.unwrap(inputBuffer, plain); 403 } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP 404 && res.bytesProduced() == 0); 405 406 if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) { 407 finishHandshake(); 408 } 409 410 status = res.getStatus(); 411 handshakeStatus = res.getHandshakeStatus(); 412 413 // TODO deal with BUFFER_OVERFLOW 414 415 if (status == SSLEngineResult.Status.CLOSED) { 416 sslEngine.closeInbound(); 417 return -1; 418 } 419 420 inputBuffer.compact(); 421 plain.flip(); 422 423 return plain.remaining(); 424 } 425 426 protected void doHandshake() throws Exception { 427 handshakeInProgress = true; 428 Selector selector = null; 429 SelectionKey key = null; 430 boolean readable = true; 431 try { 432 while (true) { 433 HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); 434 switch (handshakeStatus) { 435 case NEED_UNWRAP: 436 if (readable) { 437 secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize())); 438 } 439 if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 440 long now = System.currentTimeMillis(); 441 if (selector == null) { 442 selector = Selector.open(); 443 key = channel.register(selector, SelectionKey.OP_READ); 444 } else { 445 key.interestOps(SelectionKey.OP_READ); 446 } 447 int keyCount = selector.select(this.getSoTimeout()); 448 if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) { 449 throw new SocketTimeoutException("Timeout during handshake"); 450 } 451 readable = key.isReadable(); 452 } 453 break; 454 case NEED_TASK: 455 Runnable task; 456 while ((task = sslEngine.getDelegatedTask()) != null) { 457 task.run(); 458 } 459 break; 460 case NEED_WRAP: 461 ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0)); 462 break; 463 case FINISHED: 464 case NOT_HANDSHAKING: 465 finishHandshake(); 466 return; 467 } 468 } 469 } finally { 470 if (key!=null) try {key.cancel();} catch (Exception ignore) {} 471 if (selector!=null) try {selector.close();} catch (Exception ignore) {} 472 } 473 } 474 475 @Override 476 protected void doStart() throws Exception { 477 taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task"); 478 // no need to init as we can delay that until demand (eg in doHandshake) 479 super.doStart(); 480 } 481 482 @Override 483 protected void doStop(ServiceStopper stopper) throws Exception { 484 initialized.countDown(); 485 486 if (taskRunnerFactory != null) { 487 taskRunnerFactory.shutdownNow(); 488 taskRunnerFactory = null; 489 } 490 if (channel != null) { 491 channel.close(); 492 channel = null; 493 } 494 super.doStop(stopper); 495 } 496 497 /** 498 * Overriding in order to add the client's certificates to ConnectionInfo Commands. 499 * 500 * @param command 501 * The Command coming in. 502 */ 503 @Override 504 public void doConsume(Object command) { 505 if (command instanceof ConnectionInfo) { 506 ConnectionInfo connectionInfo = (ConnectionInfo) command; 507 connectionInfo.setTransportContext(getPeerCertificates()); 508 } 509 super.doConsume(command); 510 } 511 512 /** 513 * @return peer certificate chain associated with the ssl socket 514 */ 515 @Override 516 public X509Certificate[] getPeerCertificates() { 517 518 X509Certificate[] clientCertChain = null; 519 try { 520 if (sslEngine.getSession() != null) { 521 clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates(); 522 } 523 } catch (SSLPeerUnverifiedException e) { 524 if (LOG.isTraceEnabled()) { 525 LOG.trace("Failed to get peer certificates.", e); 526 } 527 } 528 529 return clientCertChain; 530 } 531 532 public boolean isNeedClientAuth() { 533 return needClientAuth; 534 } 535 536 public void setNeedClientAuth(boolean needClientAuth) { 537 this.needClientAuth = needClientAuth; 538 } 539 540 public boolean isWantClientAuth() { 541 return wantClientAuth; 542 } 543 544 public void setWantClientAuth(boolean wantClientAuth) { 545 this.wantClientAuth = wantClientAuth; 546 } 547 548 public String[] getEnabledCipherSuites() { 549 return enabledCipherSuites; 550 } 551 552 public void setEnabledCipherSuites(String[] enabledCipherSuites) { 553 this.enabledCipherSuites = enabledCipherSuites; 554 } 555 556 public String[] getEnabledProtocols() { 557 return enabledProtocols; 558 } 559 560 public void setEnabledProtocols(String[] enabledProtocols) { 561 this.enabledProtocols = enabledProtocols; 562 } 563 564 public boolean isVerifyHostName() { 565 return verifyHostName; 566 } 567 568 public void setVerifyHostName(boolean verifyHostName) { 569 this.verifyHostName = verifyHostName; 570 } 571}