/*
 * Decompiled with CFR 0.152.
 */
package org.keycloak.protocol.saml;

import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Collections;
import java.util.function.Function;
import javax.crypto.Cipher;
import javax.crypto.NoSuchPaddingException;
import org.apache.xml.security.exceptions.XMLSecurityException;
import org.hamcrest.Matcher;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import org.keycloak.dom.saml.v2.assertion.AssertionType;
import org.keycloak.dom.saml.v2.assertion.NameIDType;
import org.keycloak.dom.saml.v2.protocol.ResponseType;
import org.keycloak.models.KeycloakSession;
import org.keycloak.protocol.saml.JaxrsSAML2BindingBuilder;
import org.keycloak.protocol.saml.SAMLEncryptionAlgorithms;
import org.keycloak.saml.SAML2LoginResponseBuilder;
import org.keycloak.saml.SAMLRequestParser;
import org.keycloak.saml.common.constants.JBossSAMLConstants;
import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
import org.keycloak.saml.common.util.DocumentUtil;
import org.keycloak.saml.processing.core.saml.v2.common.SAMLDocumentHolder;
import org.keycloak.saml.processing.core.saml.v2.util.AssertionUtil;
import org.keycloak.saml.processing.core.util.XMLEncryptionUtil;
import org.keycloak.services.resteasy.ResteasyKeycloakSession;
import org.keycloak.services.resteasy.ResteasyKeycloakSessionFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

/**
 * Round-trip tests for SAML assertion encryption: a login response is built, its assertion is encrypted via
 * {@link JaxrsSAML2BindingBuilder}, the document is serialized and re-parsed, and the assertion is decrypted and
 * its content checked.
 */
public class SamlEncryptionTest {
    /** RSA-2048 key pair shared by all tests; generated once when the class is initialized. */
    private static final KeyPair rsaKeyPair = generateRsaKeyPair();

    /**
     * Skips the whole class when the JCA providers lack a primitive the encryption path needs. Each primitive is
     * probed independently so the reported assumption message names the one that is actually missing.
     */
    @BeforeClass
    public static void beforeClass() {
        Assume.assumeTrue("OAEPPadding not supported", isCipherAvailable("RSA/ECB/OAEPPadding"));
        Assume.assumeTrue("SHA1PRNG required for Apache santuario xmlsec", isSecureRandomAvailable("SHA1PRNG"));
    }

    /** Returns whether {@link Cipher#getInstance(String)} succeeds for {@code transformation}. */
    private static boolean isCipherAvailable(String transformation) {
        try {
            Cipher.getInstance(transformation);
            return true;
        } catch (NoSuchAlgorithmException | NoSuchPaddingException ignored) {
            // Absence is reported to the caller as "false"; the caller turns it into a skipped test.
            return false;
        }
    }

    /** Returns whether {@link SecureRandom#getInstance(String)} succeeds for {@code algorithm}. */
    private static boolean isSecureRandomAvailable(String algorithm) {
        try {
            SecureRandom.getInstance(algorithm);
            return true;
        } catch (NoSuchAlgorithmException ignored) {
            // Absence is reported to the caller as "false"; the caller turns it into a skipped test.
            return false;
        }
    }

    /**
     * Runs {@link #testEncryption(KeyPair, String, int, String, String, String, Function)} without modifying the
     * encrypted document before it is parsed.
     */
    private void testEncryption(KeyPair pair, String alg, int keySize, String keyWrapAlg, String keyWrapHashMethod, String keyWrapMgf) throws Exception {
        testEncryption(pair, alg, keySize, keyWrapAlg, keyWrapHashMethod, keyWrapMgf, Function.identity());
    }

    /**
     * Encrypts a freshly built login response, round-trips it through serialization and parsing, decrypts it and
     * verifies the issuer and NameID of the decrypted assertion.
     *
     * @param pair              key pair: the public key encrypts, the private key is offered for decryption
     * @param alg               data encryption algorithm, or {@code null} to keep the binding builder's default
     * @param keySize           data encryption key size, or a non-positive value to keep the default
     * @param keyWrapAlg        key transport algorithm URI, or {@code null} to keep the default
     * @param keyWrapHashMethod key transport digest method URI, or {@code null} to keep the default
     * @param keyWrapMgf        key transport MGF algorithm URI, or {@code null} to keep the default
     * @param transformer       applied to the encrypted document before it is serialized and re-parsed
     * @throws Exception if building, encrypting, parsing or decrypting fails
     */
    private void testEncryption(KeyPair pair, String alg, int keySize, String keyWrapAlg, String keyWrapHashMethod, String keyWrapMgf, Function<Document, Document> transformer) throws Exception {
        SAML2LoginResponseBuilder builder = new SAML2LoginResponseBuilder();
        builder.requestID("requestId")
                .destination("http://localhost")
                .issuer("issuer")
                .assertionExpiration(300)
                .subjectExpiration(300)
                .sessionExpiration(300)
                .requestIssuer("clientId")
                .authMethod(JBossSAMLURIConstants.AC_UNSPECIFIED.get())
                .sessionIndex("sessionIndex")
                .nameIdentifier(JBossSAMLURIConstants.NAMEID_FORMAT_UNSPECIFIED.get(), "nameId");
        ResponseType samlModel = builder.buildModel();

        KeycloakSession session = new ResteasyKeycloakSession(new ResteasyKeycloakSessionFactory());
        JaxrsSAML2BindingBuilder bindingBuilder = new JaxrsSAML2BindingBuilder(session);
        // Only override the binding builder's settings the caller explicitly supplied.
        if (alg != null) {
            bindingBuilder.encryptionAlgorithm(alg);
        }
        if (keySize > 0) {
            bindingBuilder.encryptionKeySize(keySize);
        }
        if (keyWrapAlg != null) {
            bindingBuilder.keyEncryptionAlgorithm(keyWrapAlg);
        }
        if (keyWrapHashMethod != null) {
            bindingBuilder.keyEncryptionDigestMethod(keyWrapHashMethod);
        }
        if (keyWrapMgf != null) {
            bindingBuilder.keyEncryptionMgfAlgorithm(keyWrapMgf);
        }
        bindingBuilder.encrypt(pair.getPublic());

        Document samlDocument = builder.buildDocument(samlModel);
        // The return value of postBinding is ignored; the encryption check below relies on it
        // having modified samlDocument in place.
        bindingBuilder.postBinding(samlDocument);
        samlDocument = transformer.apply(samlDocument);

        // Serialize and re-parse so decryption operates on a freshly parsed response.
        String samlResponse = DocumentUtil.getDocumentAsString(samlDocument);
        SAMLDocumentHolder holder = SAMLRequestParser.parseResponseDocument(samlResponse.getBytes(StandardCharsets.UTF_8));
        ResponseType responseType = (ResponseType) holder.getSamlObject();
        Assert.assertTrue("Assertion is not encrypted", AssertionUtil.isAssertionEncrypted(responseType));

        AssertionUtil.decryptAssertion(responseType, keyLocatorFor(pair));
        AssertionType assertion = responseType.getAssertions().get(0).getAssertion();
        Assert.assertEquals("issuer", assertion.getIssuer().getValue());
        MatcherAssert.assertThat(assertion.getSubject().getSubType().getBaseID(), Matchers.instanceOf(NameIDType.class));
        NameIDType nameId = (NameIDType) assertion.getSubject().getSubType().getBaseID();
        Assert.assertEquals("nameId", nameId.getValue());
    }

    /**
     * Document transformer that detaches the first xenc:EncryptedKey from its KeyInfo, gives it an Id, re-attaches
     * it under the KeyInfo's grandparent, and leaves a RetrievalMethod element pointing at it inside the KeyInfo.
     *
     * @param doc encrypted document; modified in place
     * @return the same {@code doc}
     */
    private Document moveEncryptedKeyToRetrievalMethod(Document doc) {
        final String keyId = "encryption-key-123";
        NodeList encryptedKeys = doc.getElementsByTagNameNS(JBossSAMLURIConstants.XMLENC_NSURI.get(), JBossSAMLConstants.ENCRYPTED_KEY.get());
        Element encKey = (Element) encryptedKeys.item(0);
        Assert.assertNotNull("Encrypted document does not contain an EncryptedKey", encKey);

        Element keyInfo = (Element) encKey.getParentNode();
        keyInfo.removeChild(encKey);
        encKey.setAttribute("Id", keyId);
        // Presumably the KeyInfo's grandparent is the EncryptedAssertion, which makes the key a sibling of
        // EncryptedData. TODO confirm.
        keyInfo.getParentNode().getParentNode().appendChild(encKey);

        // NOTE(review): XML Encryption references a detached EncryptedKey through a ds:RetrievalMethod in the
        // XML-DSig namespace (http://www.w3.org/2000/09/xmldsig#) with a same-document URI such as "#<id>".
        // This element is created in the xenc namespace and its URI has no '#'. A spec-following resolver would
        // therefore not treat it as that reference. Confirm whether decryption here exercises the RetrievalMethod
        // or merely finds the relocated EncryptedKey element.
        Element retrievalMethod = doc.createElementNS(JBossSAMLURIConstants.XMLENC_NSURI.get(), "xenc:RetrievalMethod");
        retrievalMethod.setAttribute("Type", "http://www.w3.org/2001/04/xmlenc#EncryptedKey");
        retrievalMethod.setAttribute("URI", keyId);
        keyInfo.appendChild(retrievalMethod);
        return doc;
    }

    /** Encrypts with all binding builder defaults. */
    @Test
    public void testDefault() throws Exception {
        testEncryption(rsaKeyPair, null, -1, null, null, null);
    }

    /** Encrypts the data with AES using a 256-bit key. */
    @Test
    public void testAES256() throws Exception {
        testEncryption(rsaKeyPair, "AES", 256, null, null, null);
    }

    /** Exercises every key transport identifier of every {@link SAMLEncryptionAlgorithms} constant with defaults otherwise. */
    @Test
    public void testDefaultKeyWraps() throws Exception {
        for (SAMLEncryptionAlgorithms alg : SAMLEncryptionAlgorithms.values()) {
            for (String keyWrapAlg : alg.getXmlEncIdentifiers()) {
                testEncryption(rsaKeyPair, null, -1, keyWrapAlg, null, null);
            }
        }
    }

    /** Same as {@link #testDefaultKeyWraps()} but with a SHA-512 key transport digest method. */
    @Test
    public void testKeyWrapsWithSha512() throws Exception {
        for (SAMLEncryptionAlgorithms alg : SAMLEncryptionAlgorithms.values()) {
            for (String keyWrapAlg : alg.getXmlEncIdentifiers()) {
                testEncryption(rsaKeyPair, null, -1, keyWrapAlg, "http://www.w3.org/2001/04/xmlenc#sha512", null);
            }
        }
    }

    /** RSA-OAEP (xmlenc11) key transport with SHA-512 digest and MGF1-SHA512, AES-256 data encryption. */
    @Test
    public void testRsaOaep11WithSha512AndMgfSha512() throws Exception {
        testEncryption(rsaKeyPair, "AES", 256, "http://www.w3.org/2009/xmlenc11#rsa-oaep", "http://www.w3.org/2001/04/xmlenc#sha512", "http://www.w3.org/2009/xmlenc11#mgf1sha512");
    }

    /** Decrypts a document whose EncryptedKey was relocated by {@link #moveEncryptedKeyToRetrievalMethod(Document)}. */
    @Test
    public void testEncryptionWithRetrievalMethod() throws Exception {
        testEncryption(rsaKeyPair, null, -1, null, null, null, this::moveEncryptedKeyToRetrievalMethod);
    }

    /** Generates the shared RSA-2048 key pair; a missing RSA provider fails class initialization. */
    private static KeyPair generateRsaKeyPair() {
        try {
            KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA");
            generator.initialize(2048);
            return generator.generateKeyPair();
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Creates a decryption key locator that offers the private key of {@code pair}. Before doing so, it asserts that
     * the EncryptedData has a KeyInfo whose first EncryptedKey can be read.
     */
    private static XMLEncryptionUtil.DecryptionKeyLocator keyLocatorFor(KeyPair pair) {
        return data -> {
            try {
                Assert.assertNotNull("EncryptedData does not contain KeyInfo", data.getKeyInfo());
                Assert.assertNotNull("EncryptedData does not contain EncryptedKey", data.getKeyInfo().itemEncryptedKey(0));
            } catch (XMLSecurityException e) {
                throw new IllegalArgumentException("Unable to read EncryptedKey from EncryptedData KeyInfo", e);
            }
            return Collections.singletonList(pair.getPrivate());
        };
    }
}

