001/*
002 * $HeadURL: file:///opt/dev/not-yet-commons-ssl-SVN-repo/tags/commons-ssl-0.3.17/src/java/org/apache/commons/ssl/RMISocketFactoryImpl.java $
003 * $Revision: 166 $
004 * $Date: 2014-04-28 11:40:25 -0700 (Mon, 28 Apr 2014) $
005 *
006 * ====================================================================
007 * Licensed to the Apache Software Foundation (ASF) under one
008 * or more contributor license agreements.  See the NOTICE file
009 * distributed with this work for additional information
010 * regarding copyright ownership.  The ASF licenses this file
011 * to you under the Apache License, Version 2.0 (the
012 * "License"); you may not use this file except in compliance
013 * with the License.  You may obtain a copy of the License at
014 *
015 *   http://www.apache.org/licenses/LICENSE-2.0
016 *
017 * Unless required by applicable law or agreed to in writing,
018 * software distributed under the License is distributed on an
019 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
020 * KIND, either express or implied.  See the License for the
021 * specific language governing permissions and limitations
022 * under the License.
023 * ====================================================================
024 *
025 * This software consists of voluntary contributions made by many
026 * individuals on behalf of the Apache Software Foundation.  For more
027 * information on the Apache Software Foundation, please see
028 * <http://www.apache.org/>.
029 *
030 */
031
032package org.apache.commons.ssl;
033
034import javax.net.ServerSocketFactory;
035import javax.net.SocketFactory;
036import javax.net.ssl.SSLException;
037import javax.net.ssl.SSLPeerUnverifiedException;
038import javax.net.ssl.SSLProtocolException;
039import javax.net.ssl.SSLSocket;
040import java.io.EOFException;
041import java.io.IOException;
042import java.io.InterruptedIOException;
043import java.net.DatagramSocket;
044import java.net.InetAddress;
045import java.net.NetworkInterface;
046import java.net.ServerSocket;
047import java.net.Socket;
048import java.net.SocketException;
049import java.net.UnknownHostException;
050import java.rmi.server.RMISocketFactory;
051import java.security.GeneralSecurityException;
052import java.security.cert.X509Certificate;
053import java.util.Arrays;
054import java.util.Collections;
055import java.util.Enumeration;
056import java.util.HashMap;
057import java.util.Iterator;
058import java.util.LinkedList;
059import java.util.Map;
060import java.util.Set;
061import java.util.SortedSet;
062import java.util.TreeMap;
063import java.util.TreeSet;
064
065
066/**
067 * An RMISocketFactory ideal for using RMI over SSL.  The server secures both
068 * the registry and the remote objects.  The client assumes that either both
069 * the registry and the remote objects will use SSL, or both will use
070 * plain-socket.  The client is able to auto detect plain-socket registries
071 * and downgrades itself to accommodate those.
072 * <p/>
073 * Unlike most existing RMI over SSL solutions in use (including Java 5's
074 * javax.rmi.ssl.SslRMIClientSocketFactory), this one does proper SSL hostname
075 * verification.  From the client perspective this is straightforward.  From
076 * the server perspective we introduce a clever trick:  we perform an initial
077 * "hostname verification" by trying the current value of
078 * "java.rmi.server.hostname" against our server certificate.  If the
079 * "java.rmi.server.hostname" System Property isn't set, we set it ourselves
080 * using the CN value we extract from our server certificate!  (Some
081 * complications arise should a wildcard certificate show up, but we try our
082 * best to deal with those).
083 * <p/>
084 * An SSL server cannot be started without a private key.  We have defined some
085 * default behaviour for trying to find a private key to use that we believe
086 * is convenient and sensible:
087 * <p/>
088 * If running from inside Tomcat, we try to re-use Tomcat's private key and
089 * certificate chain (assuming Tomcat-SSL on port 8443 is enabled).  If this
090 * isn't available, we look for the "javax.net.ssl.keyStore" System property.
091 * Finally, if that isn't available, we look for "~/.keystore" and assume
092 * a password of "changeit".
093 * <p/>
094 * If after all these attempts we still failed to find a private key, the
095 * RMISocketFactoryImpl() constructor will throw an SSLException.
096 *
097 * @author Credit Union Central of British Columbia
098 * @author <a href="http://www.cucbc.com/">www.cucbc.com</a>
099 * @author <a href="mailto:juliusdavies@cucbc.com">juliusdavies@cucbc.com</a>
100 * @since 22-Apr-2005
101 */
102public class RMISocketFactoryImpl extends RMISocketFactory {
103    public final static String RMI_HOSTNAME_KEY = "java.rmi.server.hostname";
104    private final static LogWrapper log = LogWrapper.getLogger(RMISocketFactoryImpl.class);
105
106    private volatile SocketFactory defaultClient;
107    private volatile ServerSocketFactory sslServer;
108    private volatile String localBindAddress = null;
109    private volatile int anonymousPort = 31099;
110    private Map clientMap = new TreeMap();
111    private Map serverSockets = new HashMap();
112    private final SocketFactory plainClient = SocketFactory.getDefault();
113
114    public RMISocketFactoryImpl() throws GeneralSecurityException, IOException {
115        this(true);
116    }
117
118    /**
119     * @param createDefaultServer If false, then we only set the default
120     *                            client, and the default server is set to null.
121     *                            If true, then a default server is also created.
122     * @throws GeneralSecurityException bad things
123     * @throws IOException              bad things
124     */
125    public RMISocketFactoryImpl(boolean createDefaultServer)
126        throws GeneralSecurityException, IOException {
127        SSLServer defaultServer = createDefaultServer ? new SSLServer() : null;
128        SSLClient defaultClient = new SSLClient();
129
130        // RMI calls to localhost will not check that host matches CN in
131        // certificate.  Hopefully this is acceptable.  (The registry server
132        // will followup the registry lookup with the proper DNS name to get
133        // the remote object, anyway).
134        HostnameVerifier verifier = HostnameVerifier.DEFAULT_AND_LOCALHOST;
135        defaultClient.setHostnameVerifier(verifier);
136        if (defaultServer != null) {
137            defaultServer.setHostnameVerifier(verifier);
138            // The RMI server will try to re-use Tomcat's "port 8443" SSL
139            // Certificate if possible.
140            defaultServer.useTomcatSSLMaterial();
141            X509Certificate[] x509 = defaultServer.getAssociatedCertificateChain();
142            if (x509 == null || x509.length < 1) {
143                throw new SSLException("Cannot initialize RMI-SSL Server: no KeyMaterial!");
144            }
145            setServer(defaultServer);
146        }
147        setDefaultClient(defaultClient);
148    }
149
150    public void setServer(ServerSocketFactory f)
151        throws GeneralSecurityException, IOException {
152        this.sslServer = f;
153        if (f instanceof SSLServer) {
154            final HostnameVerifier VERIFIER;
155            VERIFIER = HostnameVerifier.DEFAULT_AND_LOCALHOST;
156
157            final SSLServer ssl = (SSLServer) f;
158            final X509Certificate[] chain = ssl.getAssociatedCertificateChain();
159            String[] cns = Certificates.getCNs(chain[0]);
160            String[] subjectAlts = Certificates.getDNSSubjectAlts(chain[0]);
161            LinkedList names = new LinkedList();
162            if (cns != null && cns.length > 0) {
163                // Only first CN is used.  Not going to get into the IE6 nonsense
164                // where all CN values are used.
165                names.add(cns[0]);
166            }
167            if (subjectAlts != null && subjectAlts.length > 0) {
168                names.addAll(Arrays.asList(subjectAlts));
169            }
170
171            String rmiHostName = System.getProperty(RMI_HOSTNAME_KEY);
172            // If "java.rmi.server.hostname" is already set, don't mess with it.
173            // But blowup if it's not going to work with our SSL Server
174            // Certificate!
175            if (rmiHostName != null) {
176                try {
177                    VERIFIER.check(rmiHostName, cns, subjectAlts);
178                }
179                catch (SSLException ssle) {
180                    String s = ssle.toString();
181                    throw new SSLException(RMI_HOSTNAME_KEY + " of " + rmiHostName + " conflicts with SSL Server Certificate: " + s);
182                }
183            } else {
184                // If SSL Cert only contains one non-wild name, just use that and
185                // hope for the best.
186                boolean hopingForBest = false;
187                if (names.size() == 1) {
188                    String name = (String) names.get(0);
189                    if (!name.startsWith("*")) {
190                        System.setProperty(RMI_HOSTNAME_KEY, name);
191                        log.warn("commons-ssl '" + RMI_HOSTNAME_KEY + "' set to '" + name + "' as found in my SSL Server Certificate.");
192                        hopingForBest = true;
193                    }
194                }
195                if (!hopingForBest) {
196                    // Help me, Obi-Wan Kenobi; you're my only hope.  All we can
197                    // do now is grab our internet-facing addresses, reverse-lookup
198                    // on them, and hope that one of them validates against our
199                    // server cert.
200                    Set s = getMyInternetFacingIPs();
201                    Iterator it = s.iterator();
202                    while (it.hasNext()) {
203                        String name = (String) it.next();
204                        try {
205                            VERIFIER.check(name, cns, subjectAlts);
206                            System.setProperty(RMI_HOSTNAME_KEY, name);
207                            log.warn("commons-ssl '" + RMI_HOSTNAME_KEY + "' set to '" + name + "' as found by reverse-dns against my own IP.");
208                            hopingForBest = true;
209                            break;
210                        }
211                        catch (SSLException ssle) {
212                            // next!
213                        }
214                    }
215                }
216                if (!hopingForBest) {
217                    throw new SSLException("'" + RMI_HOSTNAME_KEY + "' not present.  Must work with my SSL Server Certificate's CN field: " + names);
218                }
219            }
220        }
221        trustOurself();
222    }
223
224    public void setLocalBindAddress(String localBindAddress) {
225        this.localBindAddress = localBindAddress;
226    }
227
228    public void setAnonymousPort(int port) {
229        this.anonymousPort = port;
230    }
231
232    public void setDefaultClient(SocketFactory f)
233        throws GeneralSecurityException, IOException {
234        this.defaultClient = f;
235        trustOurself();
236    }
237
238    public void setClient(String host, SocketFactory f)
239        throws GeneralSecurityException, IOException {
240        if (f != null && sslServer != null) {
241            boolean clientIsCommonsSSL = f instanceof SSLClient;
242            boolean serverIsCommonsSSL = sslServer instanceof SSLServer;
243            if (clientIsCommonsSSL && serverIsCommonsSSL) {
244                SSLClient c = (SSLClient) f;
245                SSLServer s = (SSLServer) sslServer;
246                trustEachOther(c, s);
247            }
248        }
249        Set names = hostnamePossibilities(host);
250        Iterator it = names.iterator();
251        synchronized (this) {
252            while (it.hasNext()) {
253                clientMap.put(it.next(), f);
254            }
255        }
256    }
257
258    public void removeClient(String host) {
259        Set names = hostnamePossibilities(host);
260        Iterator it = names.iterator();
261        synchronized (this) {
262            while (it.hasNext()) {
263                clientMap.remove(it.next());
264            }
265        }
266    }
267
268    public synchronized void removeClient(SocketFactory sf) {
269        Iterator it = clientMap.entrySet().iterator();
270        while (it.hasNext()) {
271            Map.Entry entry = (Map.Entry) it.next();
272            Object o = entry.getValue();
273            if (sf.equals(o)) {
274                it.remove();
275            }
276        }
277    }
278
279    private Set hostnamePossibilities(String host) {
280        host = host != null ? host.toLowerCase().trim() : "";
281        if ("".equals(host)) {
282            return Collections.EMPTY_SET;
283        }
284        TreeSet names = new TreeSet();
285        names.add(host);
286        InetAddress[] addresses;
287        try {
288            // If they gave us "hostname.com", this will give us the various
289            // IP addresses:
290            addresses = InetAddress.getAllByName(host);
291            for (int i = 0; i < addresses.length; i++) {
292                String name1 = addresses[i].getHostName();
293                String name2 = addresses[i].getHostAddress();
294                names.add(name1.trim().toLowerCase());
295                names.add(name2.trim().toLowerCase());
296            }
297        }
298        catch (UnknownHostException uhe) {
299            /* oh well, nothing found, nothing to add for this client */
300        }
301
302        try {
303            host = InetAddress.getByName(host).getHostAddress();
304
305            // If they gave us "1.2.3.4", this will hopefully give us
306            // "hostname.com" so that we can then try and find any other
307            // IP addresses associated with that name.
308            host = InetAddress.getByName(host).getHostName();
309            names.add(host.trim().toLowerCase());
310            addresses = InetAddress.getAllByName(host);
311            for (int i = 0; i < addresses.length; i++) {
312                String name1 = addresses[i].getHostName();
313                String name2 = addresses[i].getHostAddress();
314                names.add(name1.trim().toLowerCase());
315                names.add(name2.trim().toLowerCase());
316            }
317        }
318        catch (UnknownHostException uhe) {
319            /* oh well, nothing found, nothing to add for this client */
320        }
321        return names;
322    }
323
324    private void trustOurself()
325        throws GeneralSecurityException, IOException {
326        if (defaultClient == null || sslServer == null) {
327            return;
328        }
329        boolean clientIsCommonsSSL = defaultClient instanceof SSLClient;
330        boolean serverIsCommonsSSL = sslServer instanceof SSLServer;
331        if (clientIsCommonsSSL && serverIsCommonsSSL) {
332            SSLClient c = (SSLClient) defaultClient;
333            SSLServer s = (SSLServer) sslServer;
334            trustEachOther(c, s);
335        }
336    }
337
338    private void trustEachOther(SSLClient client, SSLServer server)
339        throws GeneralSecurityException, IOException {
340        if (client != null && server != null) {
341            // Our own client should trust our own server.
342            X509Certificate[] certs = server.getAssociatedCertificateChain();
343            if (certs != null && certs[0] != null) {
344                TrustMaterial tm = new TrustMaterial(certs[0]);
345                client.addTrustMaterial(tm);
346            }
347
348            // Our own server should trust our own client.
349            certs = client.getAssociatedCertificateChain();
350            if (certs != null && certs[0] != null) {
351                TrustMaterial tm = new TrustMaterial(certs[0]);
352                server.addTrustMaterial(tm);
353            }
354        }
355    }
356
357    public ServerSocketFactory getServer() { return sslServer; }
358
359    public SocketFactory getDefaultClient() { return defaultClient; }
360
361    public synchronized SocketFactory getClient(String host) {
362        host = host != null ? host.trim().toLowerCase() : "";
363        return (SocketFactory) clientMap.get(host);
364    }
365
366    public synchronized ServerSocket createServerSocket(int port)
367        throws IOException {
368        // Re-use existing ServerSocket if possible.
369        if (port == 0) {
370            port = anonymousPort;
371        }
372        Integer key = new Integer(port);
373        ServerSocket ss = (ServerSocket) serverSockets.get(key);
374        if (ss == null || ss.isClosed()) {
375            if (ss != null && ss.isClosed()) {
376                System.out.println("found closed server on port: " + port);
377            }
378            log.debug("commons-ssl RMI server-socket: listening on port " + port);
379            ss = sslServer.createServerSocket(port);
380            serverSockets.put(key, ss);
381        }
382        return ss;
383    }
384
385    public Socket createSocket(String host, int port)
386        throws IOException {
387        host = host != null ? host.trim().toLowerCase() : "";
388        InetAddress local = null;
389        String bindAddress = localBindAddress;
390        if (bindAddress == null) {
391            bindAddress = System.getProperty(RMI_HOSTNAME_KEY);
392            if (bindAddress != null) {
393                local = InetAddress.getByName(bindAddress);
394                if (!local.isLoopbackAddress()) {
395                    String ip = local.getHostAddress();
396                    Set myInternetIps = getMyInternetFacingIPs();
397                    if (!myInternetIps.contains(ip)) {
398                        log.warn("Cannot bind to " + ip + " since it doesn't exist on this machine.");
399                        // Not going to be able to bind as this.  Our RMI_HOSTNAME_KEY
400                        // must be set to some kind of proxy in front of us.  So we
401                        // still want to use it, but we can't bind to it.
402                        local = null;
403                        bindAddress = null;
404                    }
405                }
406            }
407        }
408        if (bindAddress == null) {
409            // Our last resort - let's make sure we at least use something that's
410            // internet facing!
411            bindAddress = getMyDefaultIP();
412        }
413        if (local == null && bindAddress != null) {
414            local = InetAddress.getByName(bindAddress);
415            localBindAddress = local.getHostName();
416        }
417
418        SocketFactory sf;
419        synchronized (this) {
420            sf = (SocketFactory) clientMap.get(host);
421        }
422        if (sf == null) {
423            sf = defaultClient;
424        }
425
426        Socket s = null;
427        SSLSocket ssl = null;
428        int soTimeout = Integer.MIN_VALUE;
429        IOException reasonForPlainSocket = null;
430        boolean tryPlain = false;
431        try {
432            s = sf.createSocket(host, port, local, 0);
433            soTimeout = s.getSoTimeout();
434            if (!(s instanceof SSLSocket)) {
435                // Someone called setClient() or setDefaultClient() and passed in
436                // a plain socket factory.  Okay, nothing to see, move along.
437                return s;
438            } else {
439                ssl = (SSLSocket) s;
440            }
441
442            // If we don't get the peer certs in 15 seconds, revert to plain
443            // socket.
444            ssl.setSoTimeout(15000);
445            ssl.getSession().getPeerCertificates();
446
447            // Everything worked out okay, so go back to original soTimeout.
448            ssl.setSoTimeout(soTimeout);
449            return ssl;
450        }
451        catch (IOException ioe) {
452            // SSL didn't work.  Let's analyze the IOException to see if maybe
453            // we're accidentally attempting to talk to a plain-socket RMI
454            // server.
455            Throwable t = ioe;
456            while (!tryPlain && t != null) {
457                tryPlain = tryPlain || t instanceof EOFException;
458                tryPlain = tryPlain || t instanceof InterruptedIOException;
459                tryPlain = tryPlain || t instanceof SSLProtocolException;
460                t = t.getCause();
461            }
462            if (!tryPlain && ioe instanceof SSLPeerUnverifiedException) {
463                try {
464                    if (ssl != null) {
465                        ssl.startHandshake();
466                    }
467                }
468                catch (IOException ioe2) {
469                    // Stacktrace from startHandshake() will be more descriptive
470                    // then the one we got from getPeerCertificates().
471                    ioe = ioe2;
472                    t = ioe2;
473                    while (!tryPlain && t != null) {
474                        tryPlain = tryPlain || t instanceof EOFException;
475                        tryPlain = tryPlain || t instanceof InterruptedIOException;
476                        tryPlain = tryPlain || t instanceof SSLProtocolException;
477                        t = t.getCause();
478                    }
479                }
480            }
481            if (!tryPlain) {
482                log.debug("commons-ssl RMI-SSL failed: " + ioe);
483                throw ioe;
484            } else {
485                reasonForPlainSocket = ioe;
486            }
487        }
488        finally {
489            // Some debug logging:
490            boolean isPlain = tryPlain || (s != null && ssl == null);
491            String socket = isPlain ? "RMI plain-socket " : "RMI ssl-socket ";
492            String localIP = local != null ? local.getHostAddress() : "ANY";
493            StringBuffer buf = new StringBuffer(64);
494            buf.append(socket);
495            buf.append(localIP);
496            buf.append(" --> ");
497            buf.append(host);
498            buf.append(":");
499            buf.append(port);
500            log.debug(buf.toString());
501        }
502
503        // SSL didn't work.  Remote server either timed out, or sent EOF, or
504        // there was some kind of SSLProtocolException.  (Any other problem
505        // would have caused an IOException to be thrown, so execution wouldn't
506        // have made it this far).  Maybe plain socket will work in these three
507        // cases.
508        sf = plainClient;
509        s = JavaImpl.connect(null, sf, host, port, local, 0, 15000, null);
510        if (soTimeout != Integer.MIN_VALUE) {
511            s.setSoTimeout(soTimeout);
512        }
513
514        try {
515            // Plain socket worked!  Let's remember that for next time an RMI call
516            // against this host happens.
517            setClient(host, plainClient);
518            String msg = "RMI downgrading from SSL to plain-socket for " + host + " because of " + reasonForPlainSocket;
519            log.warn(msg, reasonForPlainSocket);
520        }
521        catch (GeneralSecurityException gse) {
522            throw new RuntimeException("can't happen because we're using plain socket", gse);
523            // won't happen because we're using plain socket, not SSL.
524        }
525
526        return s;
527    }
528
529
530    public static String getMyDefaultIP() {
531        String anInternetIP = "64.111.122.211";
532        String ip = null;
533        try {
534            DatagramSocket dg = new DatagramSocket();
535            dg.setSoTimeout(250);
536            // 64.111.122.211 is juliusdavies.ca.
537            // This code doesn't actually send any packets (so no firewalls can
538            // get in the way).  It's just a neat trick for getting our
539            // internet-facing interface card.
540            InetAddress addr = Util.toInetAddress(anInternetIP);
541            dg.connect(addr, 12345);
542            InetAddress localAddr = dg.getLocalAddress();
543            ip = localAddr.getHostAddress();
544            // log.debug( "Using bogus UDP socket (" + anInternetIP + ":12345), I think my IP address is: " + ip );
545            dg.close();
546            if (localAddr.isLoopbackAddress() || "0.0.0.0".equals(ip)) {
547                ip = null;
548            }
549        }
550        catch (IOException ioe) {
551            log.debug("Bogus UDP didn't work: " + ioe);
552        }
553        return ip;
554    }
555
556    public static SortedSet getMyInternetFacingIPs() throws SocketException {
557        TreeSet set = new TreeSet();
558        Enumeration en = NetworkInterface.getNetworkInterfaces();
559        while (en.hasMoreElements()) {
560            NetworkInterface ni = (NetworkInterface) en.nextElement();
561            Enumeration en2 = ni.getInetAddresses();
562            while (en2.hasMoreElements()) {
563                InetAddress addr = (InetAddress) en2.nextElement();
564                if (!addr.isLoopbackAddress()) {
565                    String ip = addr.getHostAddress();
566                    String reverse = addr.getHostName();
567                    // IP:
568                    set.add(ip);
569                    // Reverse-Lookup:
570                    set.add(reverse);
571
572                }
573            }
574        }
575        return set;
576    }
577
578}