/*
 * @(#)GssKerberosV5.java	1.2 00/02/16
 *
 * Copyright 2000 Sun Microsystems, Inc. All Rights Reserved.
 *
 * Sun grants you ("Licensee") a non-exclusive, royalty free,
 * license to use, modify and redistribute this software in source and
 * binary code form, provided that i) this copyright notice and license
 * appear on all copies of the software; and ii) Licensee does not
 * utilize the software in a manner which is disparaging to Sun.
 *
 * This software is provided "AS IS," without a warranty of any
 * kind. ALL EXPRESS OR IMPLIED CONDITIONS, REPRESENTATIONS AND
 * WARRANTIES, INCLUDING ANY IMPLIED WARRANTY OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE OR NON-INFRINGEMENT, ARE
 * HEREBY EXCLUDED.  SUN AND ITS LICENSORS SHALL NOT BE LIABLE
 * FOR ANY DAMAGES SUFFERED BY LICENSEE AS A RESULT OF USING,
 * MODIFYING OR DISTRIBUTING THE SOFTWARE OR ITS DERIVATIVES. IN
 * NO EVENT WILL SUN OR ITS LICENSORS BE LIABLE FOR ANY LOST
 * REVENUE, PROFIT OR DATA, OR FOR DIRECT, INDIRECT, SPECIAL,
 * CONSEQUENTIAL, INCIDENTAL OR PUNITIVE DAMAGES, HOWEVER
 * CAUSED AND REGARDLESS OF THE THEORY OF LIABILITY, ARISING OUT
 * OF THE USE OF OR INABILITY TO USE SOFTWARE, EVEN IF SUN HAS
 * BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
 *
 * This software is not designed or intended for use in on-line
 * control of aircraft, air traffic, aircraft navigation or aircraft
 * communications; or in the design, construction, operation or
 * maintenance of any nuclear facility. Licensee represents and warrants
 * that it will not use or redistribute the Software for such purposes.
 */

package nk;

//import com.sun.security.sasl.preview.*;
import com.netscape.sasl.*;
import java.io.*;
import java.util.Hashtable;

// JAAS
import javax.security.auth.callback.CallbackHandler;

// JGSS from DSTC; implements draft -00 of JGSS API
import org.ietf.jgss.*;

/**
  * Implements the GSSAPI SASL mechanism for Kerberos V5.
  * (<A HREF="ftp://ftp.isi.edu/in-notes/rfc2222.txt">RFC 2222</A>,
  * <a HREF="http://www.ietf.org/internet-drafts/draft-ietf-cat-sasl-gssapi-00.txt">draft-ietf-cat-sasl-gssapi-00.txt</a>).
  *
  * <p>Once authentication has completed, {@link #getInputStream} and
  * {@link #getOutputStream} return streams that apply the negotiated
  * security layer (integrity and/or privacy) using GSSContext unwrap/wrap.
  *
  * @author Rosanna Lee
  * @author modified for use with Netscape LDAPJDK by Norbert Klasen
  */
public class GssKerberosV5 implements SaslClient {
    protected static final  String PKGNAME = "nk";
    // Set from the system property PKGNAME + ".debug" each time an instance
    // is constructed (note: static, so shared by all instances).
    protected static boolean debug = false;

    protected boolean completed = false;
    protected GSSContext secCtx = null;
    protected boolean privacy = false;
    protected boolean integrity = false;
    protected int qop = 0;	  // 0 means default quality of protection
    protected int minKeyLen = 0;  // 0=none; 1=integrity only; >1=integ + priv
    protected int maxKeyLen = 256;
    protected int sendMaxBufSize = 0;    // specified by peer but can override
    protected int recvMaxBufSize = 4096;
    protected int rawSendSize;
    protected byte protections;   // security layers we support

    /**
     * Creates OIDs without throwing GSSException.
     * Note: Needed to define public static final Oids in GSS-API classes.
     *
     * @param oid dotted-decimal OID string
     * @return the Oid, or null if the string could not be parsed
     */
    private static Oid createOid(String oid)
    {
        try {
            return new Oid(oid);
        }
        catch (GSSException gsse) {
            return null;
        }
    }

    private boolean finalHandshake = false;
    private boolean mutual = false;       // default false
    static final private String MUTUAL_AUTH =
    	"javax.security.sasl.server.authentication";


    /**
     * Creates a SASL mechanism with client credentials that it needs
     * to participate in GSS-API/Kerberos v5 authentication exchange
     * with the server.
     *
     * @param authzID authorization id (not used by this implementation)
     * @param protocol service protocol, e.g. "ldap"; upper-cased to form the
     *   host-based service name PROTOCOL@serverName
     * @param serverName host name of the server
     * @param props optional properties (key lengths, buffer sizes, mutual
     *   authentication); may be null, in which case defaults are used
     * @param cbh callback handler (not used by this implementation)
     * @throws SaslException if a numeric property is not an integer or the
     *   GSS security context cannot be created
     */
    public GssKerberosV5(String authzID, String protocol, String serverName,
	    Hashtable props, CallbackHandler cbh) throws SaslException
    {
        debug = System.getProperty(PKGNAME+".debug", "false").equalsIgnoreCase("true");

        // Parse properties to set desired context options
        if (props != null) {
            minKeyLen = getIntProperty(props, MIN_KEY, minKeyLen);

            if (props.get(MAX_KEY) != null) {
                maxKeyLen = getIntProperty(props, MAX_KEY, maxKeyLen);
                // set quality of protection if user specifies max key length
                // otherwise, defaults to mechanism default
                qop = maxKeyLen;
            }

            sendMaxBufSize = getIntProperty(props, MAX_SEND_BUF, sendMaxBufSize);
            recvMaxBufSize = getIntProperty(props, MAX_RECV_BUF, recvMaxBufSize);
        }

        // Must be computed even without properties: otherwise protections
        // stays 0 and the final handshake can never agree on a layer.
        protections = computeProtections(minKeyLen, maxKeyLen);

        if (debug) {
            System.err.println("client protections: " + protections);
        }

        // Locale-independent upper-casing (avoids e.g. Turkish dotless i)
        String service = protocol.toUpperCase(java.util.Locale.ENGLISH) +
            "@" + serverName;

        try {
            GSSManager mgr = GSSManager.getInstance();

            // Create the name for the requested service entity for Krb5 mech
            GSSName acceptorName = mgr.createName(service,
                GSSName.NT_HOSTBASED_SERVICE, KRB5_OID);

            // Create a context using default credentials for Krb5 mech
            secCtx = mgr.createContext(acceptorName,
                KRB5_OID, /* mechanism */
                null,       /* default credentials */
                GSSContext.INDEFINITE_LIFETIME);

            // Mutual authentication
            if (props != null) {
                String prop = (String)props.get(MUTUAL_AUTH);
                if (prop != null) {
                    mutual = "true".equalsIgnoreCase(prop);
                }
            }
            secCtx.requestMutualAuth(mutual);

            // %%% The final handshake breaks in Active Directory
            // if we don't set confidentiality and integrity here, so both
            // are always requested regardless of the negotiated protections.
            secCtx.requestConf(true);
            secCtx.requestInteg(true);

        } catch (GSSException e) {
            throw new SaslException("Failure to initialize security context: " +
                e.getMessage(), e);
        }
    }

    /**
     * Reads an integer-valued property.
     *
     * @param props property table (non-null)
     * @param name property name
     * @param defaultValue value returned when the property is absent
     * @return the parsed value, or defaultValue if absent
     * @throws SaslException if the property is present but not an integer
     */
    private static int getIntProperty(Hashtable props, String name,
        int defaultValue) throws SaslException
    {
        String prop = (String)props.get(name);
        if (prop == null) {
            return defaultValue;
        }
        try {
            return Integer.parseInt(prop);
        } catch (NumberFormatException e) {
            throw new SaslException(
                "Property must be string representation of integer: " +
                name);
        }
    }

    /**
     * Computes the bit-mask of security layers the client accepts.
     *
     * @param minKeyLen 0 = no protection acceptable, 1 = integrity minimum,
     *   larger = privacy required
     * @param maxKeyLen 0 = no protection, 1 = at most integrity,
     *   larger = privacy allowed
     * @return combination of NO_PROTECTION, INTEGRITY_ONLY_PROTECTION and
     *   PRIVACY_PROTECTION
     */
    private static byte computeProtections(int minKeyLen, int maxKeyLen) {
        byte result;
        if (maxKeyLen == 0) {
            result = NO_PROTECTION;
        } else if (maxKeyLen == 1) {
            result = INTEGRITY_ONLY_PROTECTION;
            if (minKeyLen == 0) {
                // no protection also acceptable
                result |= NO_PROTECTION;
            }
        } else {
            result = PRIVACY_PROTECTION;
            if (minKeyLen == 0) {
                result |= NO_PROTECTION|INTEGRITY_ONLY_PROTECTION;
            } else if (minKeyLen == 1) {
                result |= INTEGRITY_ONLY_PROTECTION;
            }
        }
        return result;
    }

//    Netscape LDAPPJDK uses createInitialResponse() instead
//    public boolean hasInitialResponse() {
//    	return true;
//    }

    /**
     * Retrieves this mechanism's name.
     *
     * @return  The string "GSSAPI".
     */
    public String getMechanismName() {
    	return "GSSAPI";
    }

    /**
     * Determines whether this mechanism has completed.
     * GSSAPI completes when the final security-layer handshake is done.
     *
     * @return true if has completed; false otherwise;
     */
    public boolean isComplete() {
	    return completed;
    }

    /**
     * Netscape LDAPJDK calls this function to produce initial response
     *
     * @return byte array containing the initial response
     */
    public byte[] createInitialResponse() throws SaslException {
        return evaluateChallenge(new byte[0]);
    }

    /**
     * Processes the challenge data.
     *
     * The server sends a challenge data using which the client must
     * process using GSS_Init_sec_context.
     * As per RFC 2222, when GSS_S_COMPLETE is returned, we do
     * an extra handshake to determine the negotiated security protection
     * and buffer sizes.
     *
     * @param challengeData A byte array containing the challenge data from
     * the server; null is treated as an empty challenge.
     * <i>Note: Netscape LDAPJDK calls evaluateChallenge with output from
     * LDAP_SUCCESS bind response, i.e.: null</i>
     * @return A byte array containing the response to be sent to the
     * server.
     * @throws SaslException if authentication already completed or the
     * GSS exchange fails
     */
    public byte[] evaluateChallenge(byte[] challengeData) throws SaslException {
        if ( completed ) {
            if ( challengeData == null ) {
                // evaluateChallenge is called with output of LDAP_SUCCESS
                // bind response, i.e.: null
                return new byte[0];
            }
            throw new SaslException("SASL authentication already complete");
        }

        if (challengeData == null) {
            // A bind response without server credentials
            challengeData = new byte[0];
        }

        if (finalHandshake) {
            return doFinalHandshake(challengeData);
        } else {
            // Security context not established yet; continue with init
            try {
                byte[] gssOutToken = secCtx.initSecContext(challengeData,
                    0, challengeData.length);

                if (secCtx.isEstablished()) {
                    finalHandshake = true;
                    if (gssOutToken == null) {
                        // RFC 2222 7.2.1:  Client responds with no data
                        return new byte[0];
                    }
                }

                return gssOutToken;
            } catch (GSSException e) {
                if (debug) {
                    System.err.println(e.getMessage());
                }
                throw new SaslException("GSS initiate failed: "+
                    e.getMessage(), e);
            }
        }
    }

    /**
     * Security context already established. challengeData
     * should contain security layers and server's maximum buffer size
     * (a wrapped 4-octet token); an empty challenge gets an empty response.
     *
     * @param challengeData wrapped server token; null is treated as empty
     * @return responseData wrapped client token containing the selected
     *   security layer and the client's maximum receive buffer size
     * @throws SaslException if unwrap/wrap fails, the token is malformed,
     *   or no common security layer exists
     */
    public byte[] doFinalHandshake(byte[] challengeData) throws SaslException {
        try {
            if (challengeData == null) {
                challengeData = new byte[0];
            }

            if (debug) {
                System.err.println("size: " + challengeData.length);
            }

            // %%% Active Directory expects an extra empty exchange
            if (challengeData.length == 0) {
                return new byte[0];
            }

            byte[] gssOutToken = secCtx.unwrap(challengeData, 0,
                challengeData.length, new MessageProp(0, false));

            // The token must hold 1 octet of layers + 3 octets of size
            if (gssOutToken == null || gssOutToken.length < 4) {
                throw new SaslException(
                    "Invalid security layer token from server: expected 4 " +
                    "octets, got " +
                    (gssOutToken == null ? 0 : gssOutToken.length));
            }

            // First octet is a bit-mask specifying the protections
            // supported by the server
            // Client selects highest available protection requested
            byte commonProtections = (byte)(protections&gssOutToken[0]);
            byte selectedProtection;

            if (debug) {
                System.err.println("Server protections: " + gssOutToken[0]);
            }

            if ((commonProtections&PRIVACY_PROTECTION) != 0) {
                selectedProtection = PRIVACY_PROTECTION;
                privacy = true;
                integrity = true;
            } else if ((commonProtections&INTEGRITY_ONLY_PROTECTION) != 0) {
                selectedProtection = INTEGRITY_ONLY_PROTECTION;
                privacy = false;
                integrity = true;
            } else if ((commonProtections&NO_PROTECTION) != 0) {
                selectedProtection = NO_PROTECTION;
                privacy = false;
                integrity = false;
            } else {
                throw new SaslException(
                    "No common protection layer between client and server");
            }

            // 2nd-4th octets specifies maximum buffer size expected by
            // server (in network byte order)
            int srvMaxBufSize = networkByteOrderToInt(gssOutToken, 1, 3);

            // Determine the max send buffer size based on what the
            // server is able to receive and our specified max
            sendMaxBufSize = (sendMaxBufSize == 0) ? srvMaxBufSize :
                Math.min(sendMaxBufSize, srvMaxBufSize);

            // Update context to limit size of returned buffer
            rawSendSize = secCtx.getWrapSizeLimit(qop, privacy, sendMaxBufSize);

            // %%% getWrapSizeLimit() always returns 0 in DSTC code
            if (rawSendSize <= 0 || rawSendSize > 0xFFFFFF) {
                if (debug) {
                    System.err.println("getWrapSizeLimit returned: " +
                        rawSendSize + ", setting rawSendSize to: 4096");
                }
                rawSendSize = 4096;
            }

            if (debug) {
                System.err.println("server max recv size: " + srvMaxBufSize);
                System.err.println("rawSendSize: " + rawSendSize);
            }

            // Construct negotiated security layers and client's max
            // receive buffer size
            byte[] gssInToken = new byte[4];
            gssInToken[0] = selectedProtection;

            if (debug) {
                System.err.println("selected protection: " + selectedProtection);
                System.err.println("privacy: " + privacy);
                System.err.println("integrity: " + integrity);
            }

            intToNetworkByteOrder(recvMaxBufSize, gssInToken, 1, 3);

            gssOutToken = secCtx.wrap(gssInToken,
                0, gssInToken.length,
                new MessageProp(0 /* qop */, false /* privacy */));

            completed = true;  // server authenticated

            return gssOutToken;
        } catch (GSSException e) {
            if (debug) {
                System.err.println(e.getMessage());
            }
            throw new SaslException("Final handshake failed: " +
                e.getMessage(), e);
        }
    }

    /**
      * Returns the input stream from which to read SASL buffers.
      * If neither privacy nor integrity is needed, this is the identity function.
      * Otherwise, return a stream that reads SASL buffers from <tt>src</tt>
      * and decodes them with GSSContext.unwrap().
      *
      * @param src the raw connection input stream
      * @return <tt>src</tt> or a stream that unwraps data read from it
      * @throws IOException If this method is called before the client has
      *   completed.
      */
    public InputStream getInputStream(InputStream src) throws IOException {
        if (completed) {
            if (!privacy && !integrity) {
                return src;    // Don't need a different stream
            } else {
                // Create stream that does secCtx.unwrap()
                return new GssInputStream(src);
            }
        } else {
            throw new SaslException("Not completed");
        }
    }

    /**
      * Returns the output stream to which to write data to be encapsulated
      * inside a SASL buffer for transmission to the server.
      * If neither privacy nor integrity is needed, this is the identity function.
      * Otherwise, return a stream that encodes data with GSSContext.wrap().
      *
      * @param dest the raw connection output stream
      * @return <tt>dest</tt> or a stream that wraps data written to it
      * @throws IOException If this method is called before the client has
      *   completed.
      */
    public OutputStream getOutputStream(OutputStream dest) throws IOException {
        if (completed) {
            if (!privacy && !integrity) {
                return dest;   // Don't need a different stream
            } else {
                // Create stream that does secCtx.wrap()
                return new GssOutputStream(dest, rawSendSize);
            }
        } else {
            throw new SaslException("Not completed");
        }
    }

    /**
     * Converts an int value to octets in network byte order (big endian).
     * Only the low-order <tt>count</tt> octets of <tt>num</tt> are written.
     *
     * @param num int value to convert
     * @param buf buffer to write octets to
     * @param start start position in buf
     * @param count number of octets (max 4) to write to buf
     */
    protected static void intToNetworkByteOrder(int num, byte[] buf,
        int start, int count)
    {
        if (count > 4) {
            throw new IllegalArgumentException("Cannot handle more than 4 bytes");
        }

        for (int i = count-1; i >= 0; i--) {
            buf[start+i] = (byte)(num & 0xff);
            num >>>= 8;
        }
    }

    /**
     * Converts octets in network byte order (big endian) to an int.
     *
     * @param buf buffer containing octets
     * @param start start position in buf
     * @param count number of octets (max 4) to read from buf
     * @return the int value of buf (negative if count is 4 and the top bit
     *   is set)
     */
    protected static int networkByteOrderToInt(byte[] buf,
        int start, int count)
    {
        if (count > 4) {
            throw new IllegalArgumentException("Cannot handle more than 4 bytes");
        }

        int answer = 0;
        for (int i = 0; i < count; i++) {
            answer <<= 8;
            answer |= ((int)buf[start+i] & 0xff);
        }
        return answer;
    }

    // ---------------- property names -----------------
    // default 0 (no protection); 1 (integrity only)
    static final protected String MIN_KEY = "javax.security.sasl.encryption.minimum";

    // default 256
    static final protected String MAX_KEY = "javax.security.sasl.encryption.maximum";

    // maximum send / receive SASL buffer sizes
    static final protected String MAX_SEND_BUF = "javax.security.sasl.buffer.send";
    static final protected String MAX_RECV_BUF = "javax.security.sasl.maxbuffer";

    static final protected String KRB5_OID_STR = "1.2.840.113554.1.2.2";
    static final protected Oid KRB5_OID = createOid(KRB5_OID_STR);

    static final protected byte NO_PROTECTION = (byte)1;
    static final protected byte INTEGRITY_ONLY_PROTECTION = (byte)2;
    static final protected byte PRIVACY_PROTECTION = (byte)4;

    // --------------- utility classes ---------------

    /**
     * Reads length-prefixed SASL buffers from the underlying stream and
     * returns their GSS-unwrapped contents.
     */
    class GssInputStream extends InputStream {
        private MessageProp msgProp;    // QOP and privacy for unwrap
        private byte[] buf;	        // buffer for storing processed bytes
        private int bufPos;		// read position in buf
        private byte[] lenBuf = new byte[4];  // buffer for storing length
        private InputStream in;		// underlying input stream

        GssInputStream(InputStream in) {
            super();
            this.in = in;
            msgProp = new MessageProp(qop, privacy);
            // Start empty so the first read triggers fill() instead of
            // returning uninitialized bytes.
            buf = new byte[0];
            bufPos = 0;
        }

        /**
         * Reads one byte.
         *
         * @return the byte as 0-255, or -1 at end of stream
         */
        public int read() throws IOException {
            byte[] inBuf = new byte[1];
            int count = read(inBuf, 0, 1);
            if (count > 0) {
                return inBuf[0] & 0xff;
            } else {
                return -1;
            }
        }

        /**
         * Reads up to <tt>count</tt> unwrapped bytes, blocking until at
         * least one byte is available.
         *
         * @return number of bytes read, or -1 at end of stream
         */
        public int read(byte[] inBuf, int start, int count) throws IOException {
            if (start < 0 || count < 0 || count > inBuf.length - start) {
                throw new IndexOutOfBoundsException();
            }
            if (count == 0) {
                return 0;
            }

            // An unwrapped SASL buffer may be empty, so keep reading
            // until there is data or the stream ends.
            while (bufPos >= buf.length) {
                if (!fill()) {
                    return -1;
                }
            }

            // Return what was requested, or all that is stored if less;
            // the next invocation will trigger fill()
            int n = Math.min(count, buf.length - bufPos);
            System.arraycopy(buf, bufPos, inBuf, start, n);
            bufPos += n;
            return n;
        }

        /**
         * Fills the buf with more data by reading a SASL buffer, unwrapping it,
         * and leaving the bytes in buf for read() to return.
         *
         * @return false if the underlying stream ended cleanly before a
         *   new SASL buffer started; true otherwise
         */
        private boolean fill() throws IOException {
            // Read in length of buffer (4 octets, network byte order)
            if (!readFully(lenBuf, 4, true)) {
                return false;
            }
            int len = networkByteOrderToInt(lenBuf, 0, 4);

            if (debug) {
                System.err.println("reading " + len + " bytes from network");
            }

            if (len < 0) {
                throw new SaslException("Invalid SASL buffer length: " + len);
            }

            // Read SASL buffer
            byte[] saslBuffer = new byte[len];
            readFully(saslBuffer, len, false);

            // Unwrap
            byte[] unwrapped;
            try {
                unwrapped = secCtx.unwrap(saslBuffer, 0, len, msgProp);
            } catch (GSSException e) {
                throw new SaslException("Problems unwrapping SASL buffer. " +
                    e.getMessage(), e);
            }
            if (unwrapped == null) {
                unwrapped = new byte[0];
            }

            if (unwrapped.length > recvMaxBufSize) {
                throw new SaslException(
                    "GSS unwrap returned a buffer size (" + unwrapped.length +
                    ") that exceeds the negotiated limit:" + recvMaxBufSize);
            }
            buf = unwrapped;
            bufPos = 0;
            return true;
        }

        /**
         * Read requested number of bytes before returning.
         *
         * @param inBuf destination buffer
         * @param total number of bytes to read
         * @param eofOk if true, end of stream before any byte is read is
         *   reported by returning false instead of throwing
         * @return true if all bytes were read; false on clean EOF (eofOk only)
         * @throws EOFException if the stream ends part-way through
         */
        private boolean readFully(byte[] inBuf, int total, boolean eofOk)
            throws IOException
        {
            int pos = 0;

            if (debug) {
                System.err.println("readFully " + total + " from " + in);
            }

            while (pos < total) {
                int count = in.read(inBuf, pos, total - pos);

                if (debug) {
                    System.err.println("readFully read" + count);
                }

                if (count == -1 ) {
                    if (eofOk && pos == 0) {
                        return false;
                    }
                    throw new EOFException();
                }
                pos += count;
            }
            return true;
        }

        public int available() throws IOException {
            return buf.length - bufPos;
        }

        public void close() throws IOException {
            in.close();
        }
    }

    /**
     * Buffers written data and sends it as length-prefixed, GSS-wrapped
     * SASL buffers of at most the buffer size given at construction.
     */
    class GssOutputStream extends BufferedOutputStream {
        private MessageProp msgProp;    // QOP and privacy for wrap
        private byte[] lenBuf = new byte[4];  // buffer for storing length

        GssOutputStream(OutputStream out, int size) {
            super(out, size);

            if (debug) {
                System.err.println("GssOutputStream: " + out + ", bufsize: " + size );
            }
            msgProp = new MessageProp(qop, privacy);
        }

        public void write(int b) throws IOException {
            byte[] buffer = new byte[1];
            buffer[0] = (byte)b;
            write(buffer, 0, 1);
        }

        /**
         * "Packetizes" the data so that each wrapped SASL buffer holds at
         * most buf.length raw bytes.
         */
        public void write(byte[] b, int off, int len)
            throws IOException {

            if (debug) {
                System.err.println("Total size: " + len);
            }
            if (off < 0 || len < 0 || len > b.length - off) {
                throw new IndexOutOfBoundsException();
            }

            while ( len > 0 ) {
                int n = Math.min(len, buf.length - count);
                System.arraycopy( b, off, buf, count, n );
                off += n;   // advance through the caller's data
                len -= n;
                count += n;
                if ( count == buf.length ) {
                    sendBuffer();
                }
            }
        }

        /**
         * Sends any buffered data as a SASL buffer and flushes the
         * underlying stream.
         */
        public void flush() throws IOException {
            sendBuffer();
            out.flush();
        }

        /**
         * Wraps the buffered bytes and writes them as one SASL buffer
         * (4-octet length followed by the GSS token). Does nothing when
         * the buffer is empty.
         */
        private void sendBuffer() throws IOException {
            if (count == 0) {
                return;
            }

            byte[] gssOutToken;

            try {
                gssOutToken = secCtx.wrap(buf, 0, count, msgProp);
            } catch (GSSException e) {
                if ( debug ) {
                    System.err.println( "Problem performing GSS wrap: " +
                        e.getMessage() );
                }
                throw new SaslException("Problem performing GSS wrap: " +
                    e.getMessage(), e);
            }

            // Write out length
            intToNetworkByteOrder(gssOutToken.length, lenBuf, 0, 4);

            if (debug) {
                System.err.println("sending size: " + gssOutToken.length);
            }
            out.write(lenBuf, 0, 4);

            // Write out GSS token
            out.write(gssOutToken, 0, gssOutToken.length);
            count = 0;
        }

    }

}

