001/* 002 * $HeadURL: file:///opt/dev/not-yet-commons-ssl-SVN-repo/tags/commons-ssl-0.3.17/src/java/org/apache/commons/ssl/RMISocketFactoryImpl.java $ 003 * $Revision: 166 $ 004 * $Date: 2014-04-28 11:40:25 -0700 (Mon, 28 Apr 2014) $ 005 * 006 * ==================================================================== 007 * Licensed to the Apache Software Foundation (ASF) under one 008 * or more contributor license agreements. See the NOTICE file 009 * distributed with this work for additional information 010 * regarding copyright ownership. The ASF licenses this file 011 * to you under the Apache License, Version 2.0 (the 012 * "License"); you may not use this file except in compliance 013 * with the License. You may obtain a copy of the License at 014 * 015 * http://www.apache.org/licenses/LICENSE-2.0 016 * 017 * Unless required by applicable law or agreed to in writing, 018 * software distributed under the License is distributed on an 019 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 020 * KIND, either express or implied. See the License for the 021 * specific language governing permissions and limitations 022 * under the License. 023 * ==================================================================== 024 * 025 * This software consists of voluntary contributions made by many 026 * individuals on behalf of the Apache Software Foundation. For more 027 * information on the Apache Software Foundation, please see 028 * <http://www.apache.org/>. 029 * 030 */ 031 032package org.apache.commons.ssl; 033 034import javax.net.ServerSocketFactory; 035import javax.net.SocketFactory; 036import javax.net.ssl.SSLException; 037import javax.net.ssl.SSLPeerUnverifiedException; 038import javax.net.ssl.SSLProtocolException; 039import javax.net.ssl.SSLSocket; 040import java.io.EOFException; 041import java.io.IOException; 042import java.io.InterruptedIOException; 043import java.net.DatagramSocket; 044import java.net.InetAddress; 045import java.net.NetworkInterface; 046import java.net.ServerSocket; 047import java.net.Socket; 048import java.net.SocketException; 049import java.net.UnknownHostException; 050import java.rmi.server.RMISocketFactory; 051import java.security.GeneralSecurityException; 052import java.security.cert.X509Certificate; 053import java.util.Arrays; 054import java.util.Collections; 055import java.util.Enumeration; 056import java.util.HashMap; 057import java.util.Iterator; 058import java.util.LinkedList; 059import java.util.Map; 060import java.util.Set; 061import java.util.SortedSet; 062import java.util.TreeMap; 063import java.util.TreeSet; 064 065 066/** 067 * An RMISocketFactory ideal for using RMI over SSL. The server secures both 068 * the registry and the remote objects. The client assumes that either both 069 * the registry and the remote objects will use SSL, or both will use 070 * plain-socket. The client is able to auto detect plain-socket registries 071 * and downgrades itself to accomodate those. 072 * <p/> 073 * Unlike most existing RMI over SSL solutions in use (including Java 5's 074 * javax.rmi.ssl.SslRMIClientSocketFactory), this one does proper SSL hostname 075 * verification. From the client perspective this is straighforward. From 076 * the server perspective we introduce a clever trick: we perform an initial 077 * "hostname verification" by trying the current value of 078 * "java.rmi.server.hostname" against our server certificate. If the 079 * "java.rmi.server.hostname" System Property isn't set, we set it ourselves 080 * using the CN value we extract from our server certificate! (Some 081 * complications arise should a wildcard certificate show up, but we try our 082 * best to deal with those). 083 * <p/> 084 * An SSL server cannot be started without a private key. We have defined some 085 * default behaviour for trying to find a private key to use that we believe 086 * is convenient and sensible: 087 * <p/> 088 * If running from inside Tomcat, we try to re-use Tomcat's private key and 089 * certificate chain (assuming Tomcat-SSL on port 8443 is enabled). If this 090 * isn't available, we look for the "javax.net.ssl.keyStore" System property. 091 * Finally, if that isn't available, we look for "~/.keystore" and assume 092 * a password of "changeit". 093 * <p/> 094 * If after all these attempts we still failed to find a private key, the 095 * RMISocketFactoryImpl() constructor will throw an SSLException. 096 * 097 * @author Credit Union Central of British Columbia 098 * @author <a href="http://www.cucbc.com/">www.cucbc.com</a> 099 * @author <a href="mailto:juliusdavies@cucbc.com">juliusdavies@cucbc.com</a> 100 * @since 22-Apr-2005 101 */ 102public class RMISocketFactoryImpl extends RMISocketFactory { 103 public final static String RMI_HOSTNAME_KEY = "java.rmi.server.hostname"; 104 private final static LogWrapper log = LogWrapper.getLogger(RMISocketFactoryImpl.class); 105 106 private volatile SocketFactory defaultClient; 107 private volatile ServerSocketFactory sslServer; 108 private volatile String localBindAddress = null; 109 private volatile int anonymousPort = 31099; 110 private Map clientMap = new TreeMap(); 111 private Map serverSockets = new HashMap(); 112 private final SocketFactory plainClient = SocketFactory.getDefault(); 113 114 public RMISocketFactoryImpl() throws GeneralSecurityException, IOException { 115 this(true); 116 } 117 118 /** 119 * @param createDefaultServer If false, then we only set the default 120 * client, and the default server is set to null. 121 * If true, then a default server is also created. 122 * @throws GeneralSecurityException bad things 123 * @throws IOException bad things 124 */ 125 public RMISocketFactoryImpl(boolean createDefaultServer) 126 throws GeneralSecurityException, IOException { 127 SSLServer defaultServer = createDefaultServer ? new SSLServer() : null; 128 SSLClient defaultClient = new SSLClient(); 129 130 // RMI calls to localhost will not check that host matches CN in 131 // certificate. Hopefully this is acceptable. (The registry server 132 // will followup the registry lookup with the proper DNS name to get 133 // the remote object, anyway). 134 HostnameVerifier verifier = HostnameVerifier.DEFAULT_AND_LOCALHOST; 135 defaultClient.setHostnameVerifier(verifier); 136 if (defaultServer != null) { 137 defaultServer.setHostnameVerifier(verifier); 138 // The RMI server will try to re-use Tomcat's "port 8443" SSL 139 // Certificate if possible. 140 defaultServer.useTomcatSSLMaterial(); 141 X509Certificate[] x509 = defaultServer.getAssociatedCertificateChain(); 142 if (x509 == null || x509.length < 1) { 143 throw new SSLException("Cannot initialize RMI-SSL Server: no KeyMaterial!"); 144 } 145 setServer(defaultServer); 146 } 147 setDefaultClient(defaultClient); 148 } 149 150 public void setServer(ServerSocketFactory f) 151 throws GeneralSecurityException, IOException { 152 this.sslServer = f; 153 if (f instanceof SSLServer) { 154 final HostnameVerifier VERIFIER; 155 VERIFIER = HostnameVerifier.DEFAULT_AND_LOCALHOST; 156 157 final SSLServer ssl = (SSLServer) f; 158 final X509Certificate[] chain = ssl.getAssociatedCertificateChain(); 159 String[] cns = Certificates.getCNs(chain[0]); 160 String[] subjectAlts = Certificates.getDNSSubjectAlts(chain[0]); 161 LinkedList names = new LinkedList(); 162 if (cns != null && cns.length > 0) { 163 // Only first CN is used. Not going to get into the IE6 nonsense 164 // where all CN values are used. 165 names.add(cns[0]); 166 } 167 if (subjectAlts != null && subjectAlts.length > 0) { 168 names.addAll(Arrays.asList(subjectAlts)); 169 } 170 171 String rmiHostName = System.getProperty(RMI_HOSTNAME_KEY); 172 // If "java.rmi.server.hostname" is already set, don't mess with it. 173 // But blowup if it's not going to work with our SSL Server 174 // Certificate! 175 if (rmiHostName != null) { 176 try { 177 VERIFIER.check(rmiHostName, cns, subjectAlts); 178 } 179 catch (SSLException ssle) { 180 String s = ssle.toString(); 181 throw new SSLException(RMI_HOSTNAME_KEY + " of " + rmiHostName + " conflicts with SSL Server Certificate: " + s); 182 } 183 } else { 184 // If SSL Cert only contains one non-wild name, just use that and 185 // hope for the best. 186 boolean hopingForBest = false; 187 if (names.size() == 1) { 188 String name = (String) names.get(0); 189 if (!name.startsWith("*")) { 190 System.setProperty(RMI_HOSTNAME_KEY, name); 191 log.warn("commons-ssl '" + RMI_HOSTNAME_KEY + "' set to '" + name + "' as found in my SSL Server Certificate."); 192 hopingForBest = true; 193 } 194 } 195 if (!hopingForBest) { 196 // Help me, Obi-Wan Kenobi; you're my only hope. All we can 197 // do now is grab our internet-facing addresses, reverse-lookup 198 // on them, and hope that one of them validates against our 199 // server cert. 200 Set s = getMyInternetFacingIPs(); 201 Iterator it = s.iterator(); 202 while (it.hasNext()) { 203 String name = (String) it.next(); 204 try { 205 VERIFIER.check(name, cns, subjectAlts); 206 System.setProperty(RMI_HOSTNAME_KEY, name); 207 log.warn("commons-ssl '" + RMI_HOSTNAME_KEY + "' set to '" + name + "' as found by reverse-dns against my own IP."); 208 hopingForBest = true; 209 break; 210 } 211 catch (SSLException ssle) { 212 // next! 213 } 214 } 215 } 216 if (!hopingForBest) { 217 throw new SSLException("'" + RMI_HOSTNAME_KEY + "' not present. Must work with my SSL Server Certificate's CN field: " + names); 218 } 219 } 220 } 221 trustOurself(); 222 } 223 224 public void setLocalBindAddress(String localBindAddress) { 225 this.localBindAddress = localBindAddress; 226 } 227 228 public void setAnonymousPort(int port) { 229 this.anonymousPort = port; 230 } 231 232 public void setDefaultClient(SocketFactory f) 233 throws GeneralSecurityException, IOException { 234 this.defaultClient = f; 235 trustOurself(); 236 } 237 238 public void setClient(String host, SocketFactory f) 239 throws GeneralSecurityException, IOException { 240 if (f != null && sslServer != null) { 241 boolean clientIsCommonsSSL = f instanceof SSLClient; 242 boolean serverIsCommonsSSL = sslServer instanceof SSLServer; 243 if (clientIsCommonsSSL && serverIsCommonsSSL) { 244 SSLClient c = (SSLClient) f; 245 SSLServer s = (SSLServer) sslServer; 246 trustEachOther(c, s); 247 } 248 } 249 Set names = hostnamePossibilities(host); 250 Iterator it = names.iterator(); 251 synchronized (this) { 252 while (it.hasNext()) { 253 clientMap.put(it.next(), f); 254 } 255 } 256 } 257 258 public void removeClient(String host) { 259 Set names = hostnamePossibilities(host); 260 Iterator it = names.iterator(); 261 synchronized (this) { 262 while (it.hasNext()) { 263 clientMap.remove(it.next()); 264 } 265 } 266 } 267 268 public synchronized void removeClient(SocketFactory sf) { 269 Iterator it = clientMap.entrySet().iterator(); 270 while (it.hasNext()) { 271 Map.Entry entry = (Map.Entry) it.next(); 272 Object o = entry.getValue(); 273 if (sf.equals(o)) { 274 it.remove(); 275 } 276 } 277 } 278 279 private Set hostnamePossibilities(String host) { 280 host = host != null ? host.toLowerCase().trim() : ""; 281 if ("".equals(host)) { 282 return Collections.EMPTY_SET; 283 } 284 TreeSet names = new TreeSet(); 285 names.add(host); 286 InetAddress[] addresses; 287 try { 288 // If they gave us "hostname.com", this will give us the various 289 // IP addresses: 290 addresses = InetAddress.getAllByName(host); 291 for (int i = 0; i < addresses.length; i++) { 292 String name1 = addresses[i].getHostName(); 293 String name2 = addresses[i].getHostAddress(); 294 names.add(name1.trim().toLowerCase()); 295 names.add(name2.trim().toLowerCase()); 296 } 297 } 298 catch (UnknownHostException uhe) { 299 /* oh well, nothing found, nothing to add for this client */ 300 } 301 302 try { 303 host = InetAddress.getByName(host).getHostAddress(); 304 305 // If they gave us "1.2.3.4", this will hopefully give us 306 // "hostname.com" so that we can then try and find any other 307 // IP addresses associated with that name. 308 host = InetAddress.getByName(host).getHostName(); 309 names.add(host.trim().toLowerCase()); 310 addresses = InetAddress.getAllByName(host); 311 for (int i = 0; i < addresses.length; i++) { 312 String name1 = addresses[i].getHostName(); 313 String name2 = addresses[i].getHostAddress(); 314 names.add(name1.trim().toLowerCase()); 315 names.add(name2.trim().toLowerCase()); 316 } 317 } 318 catch (UnknownHostException uhe) { 319 /* oh well, nothing found, nothing to add for this client */ 320 } 321 return names; 322 } 323 324 private void trustOurself() 325 throws GeneralSecurityException, IOException { 326 if (defaultClient == null || sslServer == null) { 327 return; 328 } 329 boolean clientIsCommonsSSL = defaultClient instanceof SSLClient; 330 boolean serverIsCommonsSSL = sslServer instanceof SSLServer; 331 if (clientIsCommonsSSL && serverIsCommonsSSL) { 332 SSLClient c = (SSLClient) defaultClient; 333 SSLServer s = (SSLServer) sslServer; 334 trustEachOther(c, s); 335 } 336 } 337 338 private void trustEachOther(SSLClient client, SSLServer server) 339 throws GeneralSecurityException, IOException { 340 if (client != null && server != null) { 341 // Our own client should trust our own server. 342 X509Certificate[] certs = server.getAssociatedCertificateChain(); 343 if (certs != null && certs[0] != null) { 344 TrustMaterial tm = new TrustMaterial(certs[0]); 345 client.addTrustMaterial(tm); 346 } 347 348 // Our own server should trust our own client. 349 certs = client.getAssociatedCertificateChain(); 350 if (certs != null && certs[0] != null) { 351 TrustMaterial tm = new TrustMaterial(certs[0]); 352 server.addTrustMaterial(tm); 353 } 354 } 355 } 356 357 public ServerSocketFactory getServer() { return sslServer; } 358 359 public SocketFactory getDefaultClient() { return defaultClient; } 360 361 public synchronized SocketFactory getClient(String host) { 362 host = host != null ? host.trim().toLowerCase() : ""; 363 return (SocketFactory) clientMap.get(host); 364 } 365 366 public synchronized ServerSocket createServerSocket(int port) 367 throws IOException { 368 // Re-use existing ServerSocket if possible. 369 if (port == 0) { 370 port = anonymousPort; 371 } 372 Integer key = new Integer(port); 373 ServerSocket ss = (ServerSocket) serverSockets.get(key); 374 if (ss == null || ss.isClosed()) { 375 if (ss != null && ss.isClosed()) { 376 System.out.println("found closed server on port: " + port); 377 } 378 log.debug("commons-ssl RMI server-socket: listening on port " + port); 379 ss = sslServer.createServerSocket(port); 380 serverSockets.put(key, ss); 381 } 382 return ss; 383 } 384 385 public Socket createSocket(String host, int port) 386 throws IOException { 387 host = host != null ? host.trim().toLowerCase() : ""; 388 InetAddress local = null; 389 String bindAddress = localBindAddress; 390 if (bindAddress == null) { 391 bindAddress = System.getProperty(RMI_HOSTNAME_KEY); 392 if (bindAddress != null) { 393 local = InetAddress.getByName(bindAddress); 394 if (!local.isLoopbackAddress()) { 395 String ip = local.getHostAddress(); 396 Set myInternetIps = getMyInternetFacingIPs(); 397 if (!myInternetIps.contains(ip)) { 398 log.warn("Cannot bind to " + ip + " since it doesn't exist on this machine."); 399 // Not going to be able to bind as this. Our RMI_HOSTNAME_KEY 400 // must be set to some kind of proxy in front of us. So we 401 // still want to use it, but we can't bind to it. 402 local = null; 403 bindAddress = null; 404 } 405 } 406 } 407 } 408 if (bindAddress == null) { 409 // Our last resort - let's make sure we at least use something that's 410 // internet facing! 411 bindAddress = getMyDefaultIP(); 412 } 413 if (local == null && bindAddress != null) { 414 local = InetAddress.getByName(bindAddress); 415 localBindAddress = local.getHostName(); 416 } 417 418 SocketFactory sf; 419 synchronized (this) { 420 sf = (SocketFactory) clientMap.get(host); 421 } 422 if (sf == null) { 423 sf = defaultClient; 424 } 425 426 Socket s = null; 427 SSLSocket ssl = null; 428 int soTimeout = Integer.MIN_VALUE; 429 IOException reasonForPlainSocket = null; 430 boolean tryPlain = false; 431 try { 432 s = sf.createSocket(host, port, local, 0); 433 soTimeout = s.getSoTimeout(); 434 if (!(s instanceof SSLSocket)) { 435 // Someone called setClient() or setDefaultClient() and passed in 436 // a plain socket factory. Okay, nothing to see, move along. 437 return s; 438 } else { 439 ssl = (SSLSocket) s; 440 } 441 442 // If we don't get the peer certs in 15 seconds, revert to plain 443 // socket. 444 ssl.setSoTimeout(15000); 445 ssl.getSession().getPeerCertificates(); 446 447 // Everything worked out okay, so go back to original soTimeout. 448 ssl.setSoTimeout(soTimeout); 449 return ssl; 450 } 451 catch (IOException ioe) { 452 // SSL didn't work. Let's analyze the IOException to see if maybe 453 // we're accidentally attempting to talk to a plain-socket RMI 454 // server. 455 Throwable t = ioe; 456 while (!tryPlain && t != null) { 457 tryPlain = tryPlain || t instanceof EOFException; 458 tryPlain = tryPlain || t instanceof InterruptedIOException; 459 tryPlain = tryPlain || t instanceof SSLProtocolException; 460 t = t.getCause(); 461 } 462 if (!tryPlain && ioe instanceof SSLPeerUnverifiedException) { 463 try { 464 if (ssl != null) { 465 ssl.startHandshake(); 466 } 467 } 468 catch (IOException ioe2) { 469 // Stacktrace from startHandshake() will be more descriptive 470 // then the one we got from getPeerCertificates(). 471 ioe = ioe2; 472 t = ioe2; 473 while (!tryPlain && t != null) { 474 tryPlain = tryPlain || t instanceof EOFException; 475 tryPlain = tryPlain || t instanceof InterruptedIOException; 476 tryPlain = tryPlain || t instanceof SSLProtocolException; 477 t = t.getCause(); 478 } 479 } 480 } 481 if (!tryPlain) { 482 log.debug("commons-ssl RMI-SSL failed: " + ioe); 483 throw ioe; 484 } else { 485 reasonForPlainSocket = ioe; 486 } 487 } 488 finally { 489 // Some debug logging: 490 boolean isPlain = tryPlain || (s != null && ssl == null); 491 String socket = isPlain ? "RMI plain-socket " : "RMI ssl-socket "; 492 String localIP = local != null ? local.getHostAddress() : "ANY"; 493 StringBuffer buf = new StringBuffer(64); 494 buf.append(socket); 495 buf.append(localIP); 496 buf.append(" --> "); 497 buf.append(host); 498 buf.append(":"); 499 buf.append(port); 500 log.debug(buf.toString()); 501 } 502 503 // SSL didn't work. Remote server either timed out, or sent EOF, or 504 // there was some kind of SSLProtocolException. (Any other problem 505 // would have caused an IOException to be thrown, so execution wouldn't 506 // have made it this far). Maybe plain socket will work in these three 507 // cases. 508 sf = plainClient; 509 s = JavaImpl.connect(null, sf, host, port, local, 0, 15000, null); 510 if (soTimeout != Integer.MIN_VALUE) { 511 s.setSoTimeout(soTimeout); 512 } 513 514 try { 515 // Plain socket worked! Let's remember that for next time an RMI call 516 // against this host happens. 517 setClient(host, plainClient); 518 String msg = "RMI downgrading from SSL to plain-socket for " + host + " because of " + reasonForPlainSocket; 519 log.warn(msg, reasonForPlainSocket); 520 } 521 catch (GeneralSecurityException gse) { 522 throw new RuntimeException("can't happen because we're using plain socket", gse); 523 // won't happen because we're using plain socket, not SSL. 524 } 525 526 return s; 527 } 528 529 530 public static String getMyDefaultIP() { 531 String anInternetIP = "64.111.122.211"; 532 String ip = null; 533 try { 534 DatagramSocket dg = new DatagramSocket(); 535 dg.setSoTimeout(250); 536 // 64.111.122.211 is juliusdavies.ca. 537 // This code doesn't actually send any packets (so no firewalls can 538 // get in the way). It's just a neat trick for getting our 539 // internet-facing interface card. 540 InetAddress addr = Util.toInetAddress(anInternetIP); 541 dg.connect(addr, 12345); 542 InetAddress localAddr = dg.getLocalAddress(); 543 ip = localAddr.getHostAddress(); 544 // log.debug( "Using bogus UDP socket (" + anInternetIP + ":12345), I think my IP address is: " + ip ); 545 dg.close(); 546 if (localAddr.isLoopbackAddress() || "0.0.0.0".equals(ip)) { 547 ip = null; 548 } 549 } 550 catch (IOException ioe) { 551 log.debug("Bogus UDP didn't work: " + ioe); 552 } 553 return ip; 554 } 555 556 public static SortedSet getMyInternetFacingIPs() throws SocketException { 557 TreeSet set = new TreeSet(); 558 Enumeration en = NetworkInterface.getNetworkInterfaces(); 559 while (en.hasMoreElements()) { 560 NetworkInterface ni = (NetworkInterface) en.nextElement(); 561 Enumeration en2 = ni.getInetAddresses(); 562 while (en2.hasMoreElements()) { 563 InetAddress addr = (InetAddress) en2.nextElement(); 564 if (!addr.isLoopbackAddress()) { 565 String ip = addr.getHostAddress(); 566 String reverse = addr.getHostName(); 567 // IP: 568 set.add(ip); 569 // Reverse-Lookup: 570 set.add(reverse); 571 572 } 573 } 574 } 575 return set; 576 } 577 578}