/**
 *  Copyright Notice:
 *  Copyright 2021-2022 DMTF. All rights reserved.
 *  License: BSD 3-Clause License. For full text see link: https://github.com/DMTF/libspdm/blob/main/LICENSE.md
 **/

#include "spdm_unit_test.h"
#include "internal/libspdm_requester_lib.h"
#include "internal/libspdm_secured_message_lib.h"

#if LIBSPDM_ENABLE_CAPABILITY_SET_CERT_CAP

/**
 * Mock transport "send" hook for the SET_CERTIFICATE requester tests.
 *
 * Copies the outgoing transport message into a local buffer and decodes it.
 * Then, depending on the current test case id, it either validates the
 * SET_CERTIFICATE request header or returns a canned status.
 *
 * @param spdm_context  SPDM context passed through by the library.
 * @param request_size  Size in bytes of the transport-encoded request.
 * @param request       Transport-encoded request buffer.
 * @param timeout       Not used by this mock.
 *
 * @retval LIBSPDM_STATUS_SUCCESS    Request accepted (and validated where the case checks it).
 * @retval LIBSPDM_STATUS_SEND_FAIL  Simulated send failure, unknown case id, missing session,
 *                                   or a request larger than the local copy buffer.
 **/
libspdm_return_t libspdm_requester_set_certificate_test_send_message(
    void *spdm_context, size_t request_size, const void *request,
    uint64_t timeout)
{
    libspdm_test_context_t *spdm_test_context;
    uint32_t *session_id;
    libspdm_session_info_t *session_info;
    bool is_app_message;
    uint8_t *app_message;
    size_t app_message_size;
    uint8_t message_buffer[LIBSPDM_SENDER_BUFFER_SIZE];
    spdm_set_certificate_request_t *spdm_request;

    /* Guard the local copy: copying more than message_buffer holds would overflow the stack. */
    if (request_size > sizeof(message_buffer)) {
        return LIBSPDM_STATUS_SEND_FAIL;
    }

    memcpy(message_buffer, request, request_size);
    libspdm_transport_test_decode_message(spdm_context, &session_id, &is_app_message, true,
                                          request_size, message_buffer, &app_message_size,
                                          (void **)&app_message);

    spdm_test_context = libspdm_get_test_context();
    switch (spdm_test_context->case_id) {
    case 0x1:
        /* Simulated transport failure. */
        return LIBSPDM_STATUS_SEND_FAIL;
    case 0x2:
        /* SPDM 1.2 request for slot 0. */
        spdm_request = (spdm_set_certificate_request_t *)app_message;

        assert_int_equal(spdm_request->header.spdm_version, SPDM_MESSAGE_VERSION_12);
        assert_int_equal(spdm_request->header.request_response_code, SPDM_SET_CERTIFICATE);
        assert_int_equal(spdm_request->header.param1 & SPDM_SET_CERTIFICATE_REQUEST_SLOT_ID_MASK,
                         0);
        assert_int_equal(spdm_request->header.param2, 0);

        return LIBSPDM_STATUS_SUCCESS;
    case 0x3:
        /* The requester must reject the NULL cert_chain before sending anything. */
        assert_true(false);

        return LIBSPDM_STATUS_SUCCESS;
    case 0x4:
        return LIBSPDM_STATUS_SUCCESS;
    case 0x5:
        /* Secured-session case: the request must belong to session 0xFFFFFFFF. */
        session_id = NULL;
        session_info = libspdm_get_session_info_via_session_id(spdm_context, 0xFFFFFFFF);
        if (session_info == NULL) {
            return LIBSPDM_STATUS_SEND_FAIL;
        }
        LIBSPDM_DEBUG((LIBSPDM_DEBUG_INFO, "Request (0x%zx):\n", request_size));
        libspdm_dump_hex(request, request_size);

        libspdm_get_scratch_buffer (spdm_context, (void **)&app_message, &app_message_size);
        libspdm_transport_test_decode_message(
            spdm_context, &session_id, &is_app_message,
            false, request_size, message_buffer,
            &app_message_size, (void **)&app_message);

        /* WALKAROUND: one context both encodes and decodes the secured message here, so
         * roll back the request sequence number consumed by this decode. */
        ((libspdm_secured_message_context_t
          *)(session_info->secured_message_context))
        ->application_secret.request_data_sequence_number--;

        return LIBSPDM_STATUS_SUCCESS;
    case 0x6:
        /* SPDM 1.2 request for slot 0 (responder will answer ResetRequired). */
        spdm_request = (spdm_set_certificate_request_t *)app_message;

        assert_int_equal(spdm_request->header.spdm_version, SPDM_MESSAGE_VERSION_12);
        assert_int_equal(spdm_request->header.request_response_code, SPDM_SET_CERTIFICATE);
        assert_int_equal(spdm_request->header.param1 & SPDM_SET_CERTIFICATE_REQUEST_SLOT_ID_MASK,
                         0);
        assert_int_equal(spdm_request->header.param2, 0);

        return LIBSPDM_STATUS_SUCCESS;
    case 0x7:
        /* SPDM 1.3 erase of slot 0: cert model NONE with the Erase attribute set. */
        spdm_request = (spdm_set_certificate_request_t *)app_message;

        assert_int_equal(spdm_request->header.spdm_version, SPDM_MESSAGE_VERSION_13);
        assert_int_equal(spdm_request->header.request_response_code, SPDM_SET_CERTIFICATE);
        assert_int_equal(spdm_request->header.param1 & SPDM_SET_CERTIFICATE_REQUEST_SLOT_ID_MASK,
                         0);
        assert_int_equal((spdm_request->header.param1 &
                          SPDM_SET_CERTIFICATE_REQUEST_ATTRIBUTES_CERT_MODEL_MASK) >>
                         SPDM_SET_CERTIFICATE_REQUEST_ATTRIBUTES_CERT_MODEL_OFFSET,
                         SPDM_CERTIFICATE_INFO_CERT_MODEL_NONE);
        assert_int_equal(spdm_request->header.param1 &
                         SPDM_SET_CERTIFICATE_REQUEST_ATTRIBUTES_ERASE,
                         SPDM_SET_CERTIFICATE_REQUEST_ATTRIBUTES_ERASE);
        assert_int_equal(spdm_request->header.param2, 0);

        return LIBSPDM_STATUS_SUCCESS;
    case 0x8:
        /* The illegal parameter combination must be rejected before sending. */
        assert_true(false);
        return LIBSPDM_STATUS_SEND_FAIL;
    case 0x9:
        /* SPDM 1.3 request for slot 3, DeviceCert model, key pair id 1 in param2. */
        spdm_request = (spdm_set_certificate_request_t *)app_message;

        assert_int_equal(spdm_request->header.spdm_version, SPDM_MESSAGE_VERSION_13);
        assert_int_equal(spdm_request->header.request_response_code, SPDM_SET_CERTIFICATE);
        assert_int_equal(spdm_request->header.param1 & SPDM_SET_CERTIFICATE_REQUEST_SLOT_ID_MASK,
                         3);
        assert_int_equal((spdm_request->header.param1 &
                          SPDM_SET_CERTIFICATE_REQUEST_ATTRIBUTES_CERT_MODEL_MASK) >>
                         SPDM_SET_CERTIFICATE_REQUEST_ATTRIBUTES_CERT_MODEL_OFFSET,
                         SPDM_CERTIFICATE_INFO_CERT_MODEL_DEVICE_CERT);
        assert_int_equal(spdm_request->header.param2, 1);

        return LIBSPDM_STATUS_SUCCESS;

    default:
        return LIBSPDM_STATUS_SEND_FAIL;
    }
}

/**
 * Write a non-secured response header into the mock response buffer and
 * transport-encode it.
 *
 * The header is placed immediately after LIBSPDM_TEST_TRANSPORT_HEADER_SIZE
 * bytes of *response. param2 is always 0, and the encoded payload size is
 * sizeof(spdm_set_certificate_response_t).
 *
 * @param spdm_context           SPDM context used for encoding.
 * @param response_size          Receives the encoded transport message size.
 * @param response               In/out pointer to the transport response buffer.
 * @param spdm_version           Value for header.spdm_version.
 * @param request_response_code  Value for header.request_response_code.
 * @param param1                 Value for header.param1.
 **/
static void libspdm_set_certificate_test_encode_plain_response(
    void *spdm_context, size_t *response_size, void **response,
    uint8_t spdm_version, uint8_t request_response_code, uint8_t param1)
{
    spdm_set_certificate_response_t *rsp;

    rsp = (void *)((uint8_t *)*response + LIBSPDM_TEST_TRANSPORT_HEADER_SIZE);

    rsp->header.spdm_version = spdm_version;
    rsp->header.request_response_code = request_response_code;
    rsp->header.param1 = param1;
    rsp->header.param2 = 0;

    libspdm_transport_test_encode_message(spdm_context, NULL, false, false,
                                          sizeof(spdm_set_certificate_response_t),
                                          rsp, response_size, response);
}

/**
 * Mock transport "receive" hook for the SET_CERTIFICATE requester tests.
 *
 * Produces a canned response selected by the current test case id.
 *
 * @param spdm_context   SPDM context passed through by the library.
 * @param response_size  Receives the size of the encoded response.
 * @param response       In/out pointer to the transport response buffer.
 * @param timeout        Not used by this mock.
 *
 * @retval LIBSPDM_STATUS_SUCCESS       A response was produced.
 * @retval LIBSPDM_STATUS_RECEIVE_FAIL  Simulated receive failure, unknown case id,
 *                                      or the test session could not be found.
 **/
libspdm_return_t libspdm_requester_set_certificate_test_receive_message(
    void *spdm_context, size_t *response_size,
    void **response, uint64_t timeout)
{
    libspdm_test_context_t *spdm_test_context;

    spdm_test_context = libspdm_get_test_context();

    switch (spdm_test_context->case_id) {
    case 0x1:
        return LIBSPDM_STATUS_RECEIVE_FAIL;

    case 0x2:
    case 0x3:
        /* SPDM 1.2 acknowledgement for slot 0. */
        libspdm_set_certificate_test_encode_plain_response(
            spdm_context, response_size, response,
            SPDM_MESSAGE_VERSION_12, SPDM_SET_CERTIFICATE_RSP, 0);
        return LIBSPDM_STATUS_SUCCESS;

    case 0x4:
        /* SPDM 1.2 acknowledgement carrying slot 1. */
        libspdm_set_certificate_test_encode_plain_response(
            spdm_context, response_size, response,
            SPDM_MESSAGE_VERSION_12, SPDM_SET_CERTIFICATE_RSP, 1);
        return LIBSPDM_STATUS_SUCCESS;

    case 0x05: {
        spdm_set_certificate_response_t *secured_rsp;
        uint32_t session_id;
        libspdm_session_info_t *session_info;
        uint8_t *scratch_buffer;
        size_t scratch_buffer_size;
        const size_t header_size = LIBSPDM_TEST_TRANSPORT_HEADER_SIZE;
        const size_t rsp_size = sizeof(spdm_set_certificate_response_t);

        session_id = 0xFFFFFFFF;

        /* Compose the SPDM 1.2 acknowledgement (slot id 1) in the response buffer first. */
        secured_rsp = (void *)((uint8_t *)*response + header_size);
        secured_rsp->header.spdm_version = SPDM_MESSAGE_VERSION_12;
        secured_rsp->header.request_response_code = SPDM_SET_CERTIFICATE_RSP;
        secured_rsp->header.param1 = 1;
        secured_rsp->header.param2 = 0;

        /* A secured message is encoded from the scratch buffer, while the transport
         * message itself lives in the sender buffer, so stage a copy there. */
        libspdm_get_scratch_buffer (spdm_context, (void **)&scratch_buffer, &scratch_buffer_size);
        libspdm_copy_mem (scratch_buffer + header_size,
                          scratch_buffer_size - header_size,
                          secured_rsp, rsp_size);
        secured_rsp = (void *)(scratch_buffer + header_size);
        libspdm_transport_test_encode_message(spdm_context, &session_id, false, false,
                                              rsp_size, secured_rsp, response_size, response);

        session_info = libspdm_get_session_info_via_session_id(spdm_context, session_id);
        if (session_info == NULL) {
            return LIBSPDM_STATUS_RECEIVE_FAIL;
        }
        /* WALKAROUND: one context both encodes and decodes this message, so roll back
         * the response sequence number consumed by the encode above. */
        ((libspdm_secured_message_context_t
          *)(session_info->secured_message_context))
        ->application_secret.response_data_sequence_number--;
        return LIBSPDM_STATUS_SUCCESS;
    }

    case 0x6:
        /* SPDM 1.2 ERROR(ResetRequired). */
        libspdm_set_certificate_test_encode_plain_response(
            spdm_context, response_size, response,
            SPDM_MESSAGE_VERSION_12, SPDM_ERROR, SPDM_ERROR_CODE_RESET_REQUIRED);
        return LIBSPDM_STATUS_SUCCESS;

    case 0x7:
        /* SPDM 1.3 acknowledgement for slot 0. */
        libspdm_set_certificate_test_encode_plain_response(
            spdm_context, response_size, response,
            SPDM_MESSAGE_VERSION_13, SPDM_SET_CERTIFICATE_RSP, 0);
        return LIBSPDM_STATUS_SUCCESS;

    case 0x8:
        /* No request is expected to be sent, so no response should be requested. */
        assert_true(false);
        return LIBSPDM_STATUS_RECEIVE_FAIL;

    case 0x9:
        /* SPDM 1.3 acknowledgement for slot 3. */
        libspdm_set_certificate_test_encode_plain_response(
            spdm_context, response_size, response,
            SPDM_MESSAGE_VERSION_13, SPDM_SET_CERTIFICATE_RSP, 3);
        return LIBSPDM_STATUS_SUCCESS;

    default:
        return LIBSPDM_STATUS_RECEIVE_FAIL;
    }
}


/**
 * Test 1: the transport fails to send the SET_CERTIFICATE request.
 * Expected Behavior: libspdm_set_certificate returns LIBSPDM_STATUS_SEND_FAIL.
 **/
void libspdm_test_requester_set_certificate_case1(void **state)
{
    libspdm_return_t status;
    libspdm_test_context_t *spdm_test_context;
    libspdm_context_t *spdm_context;

    void *data;
    size_t data_size;

    spdm_test_context = *state;
    spdm_context = spdm_test_context->spdm_context;
    /* Case 0x1: the mock send hook returns LIBSPDM_STATUS_SEND_FAIL. */
    spdm_test_context->case_id = 0x1;
    spdm_context->connection_info.version = SPDM_MESSAGE_VERSION_12 <<
                                            SPDM_VERSION_NUMBER_SHIFT_BIT;

    spdm_context->connection_info.connection_state =
        LIBSPDM_CONNECTION_STATE_NEGOTIATED;
    spdm_context->connection_info.capability.flags |=
        SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_SET_CERT_CAP;

    libspdm_read_responder_public_certificate_chain(m_libspdm_use_hash_algo,
                                                    m_libspdm_use_asym_algo,
                                                    &data, &data_size, NULL, NULL);

    status = libspdm_set_certificate(spdm_context, NULL, 0, data, data_size);

    assert_int_equal(status, LIBSPDM_STATUS_SEND_FAIL);
    free(data);
}

/**
 * Test 2: the responder acknowledges SET_CERTIFICATE for slot 0 over SPDM 1.2.
 * The mock send hook also checks the request header (version, code, slot 0, param2 = 0).
 * Expected Behavior: libspdm_set_certificate returns LIBSPDM_STATUS_SUCCESS.
 **/
void libspdm_test_requester_set_certificate_case2(void **state)
{
    libspdm_test_context_t *test_ctx = *state;
    libspdm_context_t *ctx = test_ctx->spdm_context;
    void *cert_chain;
    size_t cert_chain_size;
    libspdm_return_t status;

    test_ctx->case_id = 0x2;
    ctx->connection_info.version = SPDM_MESSAGE_VERSION_12 << SPDM_VERSION_NUMBER_SHIFT_BIT;
    ctx->connection_info.connection_state = LIBSPDM_CONNECTION_STATE_NEGOTIATED;
    ctx->connection_info.capability.flags |= SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_SET_CERT_CAP;

    libspdm_read_responder_public_certificate_chain(m_libspdm_use_hash_algo,
                                                    m_libspdm_use_asym_algo,
                                                    &cert_chain, &cert_chain_size,
                                                    NULL, NULL);

    status = libspdm_set_certificate(ctx, NULL, 0, cert_chain, cert_chain_size);
    assert_int_equal(status, LIBSPDM_STATUS_SUCCESS);

    free(cert_chain);
}

/**
 * Test 3: SET_CERTIFICATE for slot 0 with a NULL cert_chain and zero size.
 * The mock send hook asserts if it is reached (case 0x3), so the request must be
 * rejected before anything goes out.
 * Expected Behavior: libspdm_set_certificate returns LIBSPDM_STATUS_INVALID_PARAMETER.
 **/
void libspdm_test_requester_set_certificate_case3(void **state)
{
    libspdm_test_context_t *test_ctx = *state;
    libspdm_context_t *ctx = test_ctx->spdm_context;

    test_ctx->case_id = 0x3;
    ctx->connection_info.version = SPDM_MESSAGE_VERSION_12 << SPDM_VERSION_NUMBER_SHIFT_BIT;
    ctx->connection_info.connection_state = LIBSPDM_CONNECTION_STATE_NEGOTIATED;
    ctx->connection_info.capability.flags |= SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_SET_CERT_CAP;

    assert_int_equal(libspdm_set_certificate(ctx, NULL, 0, NULL, 0),
                     LIBSPDM_STATUS_INVALID_PARAMETER);
}

/**
 * Test 5: Successful response to set certificate for slot 1 in secure session
 * Expected Behavior: get a RETURN_SUCCESS return code
 **/
void libspdm_test_requester_set_certificate_case5(void **state)
{
    libspdm_return_t status;
    libspdm_test_context_t *spdm_test_context;
    libspdm_context_t *spdm_context;
    uint32_t session_id;
    libspdm_session_info_t *session_info;

    void *data;
    size_t data_size;

    spdm_test_context = *state;
    spdm_context = spdm_test_context->spdm_context;
    spdm_test_context->case_id = 0x05;
    spdm_context->connection_info.connection_state =
        LIBSPDM_CONNECTION_STATE_AUTHENTICATED;

    spdm_context->connection_info.capability.flags |=
        SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_MEAS_CAP_SIG;
    spdm_context->connection_info.capability.flags |=
        SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_PSK_CAP;
    spdm_context->connection_info.capability.flags |=
        SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_ENCRYPT_CAP;
    spdm_context->connection_info.capability.flags |=
        SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_MAC_CAP;
    spdm_context->local_context.capability.flags |=
        SPDM_GET_CAPABILITIES_REQUEST_FLAGS_PSK_CAP;
    spdm_context->local_context.capability.flags |=
        SPDM_GET_CAPABILITIES_REQUEST_FLAGS_ENCRYPT_CAP;
    spdm_context->local_context.capability.flags |=
        SPDM_GET_CAPABILITIES_REQUEST_FLAGS_MAC_CAP;
    spdm_context->connection_info.capability.flags |=
        SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_SET_CERT_CAP;
    spdm_context->connection_info.algorithm.dhe_named_group =
        m_libspdm_use_dhe_algo;
    spdm_context->connection_info.algorithm.aead_cipher_suite =
        m_libspdm_use_aead_algo;

    libspdm_read_responder_public_certificate_chain(m_libspdm_use_hash_algo,
                                                    m_libspdm_use_asym_algo, &data,
                                                    &data_size, NULL, NULL);

    session_id = 0xFFFFFFFF;
    session_info = &spdm_context->session_info[0];
    libspdm_session_info_init(spdm_context, session_info, session_id, true);
    libspdm_secured_message_set_session_state(session_info->secured_message_context,
                                              LIBSPDM_SESSION_STATE_ESTABLISHED);

    /* slot id is 1*/
    status = libspdm_set_certificate(spdm_context, &session_id, 1, data, data_size);
    assert_int_equal(status, LIBSPDM_STATUS_SUCCESS);

    free(data);
}

/**
 * Test 6: SET_CERTIFICATE for slot 0 where the responder advertises CERT_INSTALL_RESET_CAP
 * and answers with ERROR(ResetRequired) (mock case 0x6).
 * Expected Behavior: libspdm_set_certificate returns LIBSPDM_STATUS_RESET_REQUIRED_PEER.
 **/
void libspdm_test_requester_set_certificate_case6(void **state)
{
    libspdm_return_t status;
    libspdm_test_context_t *spdm_test_context;
    libspdm_context_t *spdm_context;

    void *data;
    size_t data_size;

    spdm_test_context = *state;
    spdm_context = spdm_test_context->spdm_context;
    spdm_test_context->case_id = 0x6;
    spdm_context->connection_info.version = SPDM_MESSAGE_VERSION_12 <<
                                            SPDM_VERSION_NUMBER_SHIFT_BIT;

    spdm_context->connection_info.connection_state =
        LIBSPDM_CONNECTION_STATE_NEGOTIATED;
    spdm_context->connection_info.capability.flags |=
        SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_SET_CERT_CAP |
        SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_CERT_INSTALL_RESET_CAP;

    libspdm_read_responder_public_certificate_chain(m_libspdm_use_hash_algo,
                                                    m_libspdm_use_asym_algo,
                                                    &data, &data_size, NULL, NULL);

    status = libspdm_set_certificate(spdm_context, NULL, 0, data, data_size);

    assert_int_equal(status, LIBSPDM_STATUS_RESET_REQUIRED_PEER);
    free(data);
}

/**
 * Test 7: erase the certificate in slot 0 over SPDM 1.3.
 * No certificate chain is passed, and the Erase attribute is set. The mock send hook
 * checks for cert model NONE and the Erase bit in param1.
 * Expected Behavior: libspdm_set_certificate_ex returns LIBSPDM_STATUS_SUCCESS.
 **/
void libspdm_test_requester_set_certificate_case7(void **state)
{
    libspdm_test_context_t *test_ctx = *state;
    libspdm_context_t *ctx = test_ctx->spdm_context;
    libspdm_return_t status;

    test_ctx->case_id = 0x7;
    ctx->connection_info.version = SPDM_MESSAGE_VERSION_13 << SPDM_VERSION_NUMBER_SHIFT_BIT;
    ctx->connection_info.connection_state = LIBSPDM_CONNECTION_STATE_NEGOTIATED;
    ctx->connection_info.capability.flags |= SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_SET_CERT_CAP;

    status = libspdm_set_certificate_ex(ctx, NULL, 0, NULL, 0,
                                        SPDM_SET_CERTIFICATE_REQUEST_ATTRIBUTES_ERASE, 0);
    assert_int_equal(status, LIBSPDM_STATUS_SUCCESS);
}

/**
 * Test 8: Illegal combination of MULTI_KEY_CONN_RSP = true, Erase = false, and SetCertModel = 0.
 * The mock send hook asserts if it is reached (case 0x8), so the request must be
 * rejected before anything goes out.
 * Expected Behavior: function returns LIBSPDM_STATUS_INVALID_PARAMETER.
 **/
void libspdm_test_requester_set_certificate_case8(void **state)
{
    libspdm_return_t status;
    libspdm_test_context_t *spdm_test_context;
    libspdm_context_t *spdm_context;

    void *data;
    size_t data_size;

    spdm_test_context = *state;
    spdm_context = spdm_test_context->spdm_context;
    spdm_test_context->case_id = 0x8;
    spdm_context->connection_info.version = SPDM_MESSAGE_VERSION_13 <<
                                            SPDM_VERSION_NUMBER_SHIFT_BIT;

    spdm_context->connection_info.connection_state = LIBSPDM_CONNECTION_STATE_NEGOTIATED;
    spdm_context->connection_info.capability.flags |=
        SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_SET_CERT_CAP;
    spdm_context->connection_info.multi_key_conn_rsp = true;

    libspdm_read_responder_public_certificate_chain(m_libspdm_use_hash_algo,
                                                    m_libspdm_use_asym_algo,
                                                    &data, &data_size, NULL, NULL);

    status = libspdm_set_certificate_ex(spdm_context, NULL, 0, data, data_size,
                                        (SPDM_CERTIFICATE_INFO_CERT_MODEL_NONE <<
                                         SPDM_SET_CERTIFICATE_REQUEST_ATTRIBUTES_CERT_MODEL_OFFSET),
                                        1);

    assert_int_equal(status, LIBSPDM_STATUS_INVALID_PARAMETER);
    /* Release the chain allocated by libspdm_read_responder_public_certificate_chain. */
    free(data);
}

/**
 * Test 9: Set MULTI_KEY_CONN_RSP = true, Erase = false, and SetCertModel = DeviceCert.
 * The certificate goes to slot 3 with key pair id 1. The mock send hook checks the
 * slot, cert model, and param2 of the request.
 * Expected Behavior: function returns LIBSPDM_STATUS_SUCCESS.
 **/
void libspdm_test_requester_set_certificate_case9(void **state)
{
    libspdm_return_t status;
    libspdm_test_context_t *spdm_test_context;
    libspdm_context_t *spdm_context;

    void *data;
    size_t data_size;

    spdm_test_context = *state;
    spdm_context = spdm_test_context->spdm_context;
    spdm_test_context->case_id = 0x9;
    spdm_context->connection_info.version = SPDM_MESSAGE_VERSION_13 <<
                                            SPDM_VERSION_NUMBER_SHIFT_BIT;

    spdm_context->connection_info.connection_state = LIBSPDM_CONNECTION_STATE_NEGOTIATED;
    spdm_context->connection_info.capability.flags |=
        SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_SET_CERT_CAP;
    spdm_context->connection_info.multi_key_conn_rsp = true;

    libspdm_read_responder_public_certificate_chain(m_libspdm_use_hash_algo,
                                                    m_libspdm_use_asym_algo,
                                                    &data, &data_size, NULL, NULL);

    status = libspdm_set_certificate_ex(spdm_context, NULL, 3, data, data_size,
                                        (SPDM_CERTIFICATE_INFO_CERT_MODEL_DEVICE_CERT <<
                                         SPDM_SET_CERTIFICATE_REQUEST_ATTRIBUTES_CERT_MODEL_OFFSET),
                                        1);

    assert_int_equal(status, LIBSPDM_STATUS_SUCCESS);
    /* Release the chain allocated by libspdm_read_responder_public_certificate_chain. */
    free(data);
}

/**
 * Entry point for the SET_CERTIFICATE requester unit tests.
 *
 * Installs the mock send/receive hooks in the shared test context. It then runs every
 * case with the common group setup and teardown.
 *
 * @return The value returned by cmocka_run_group_tests (in cmocka, the number of
 *         failed tests; 0 means all passed).
 **/
int libspdm_requester_set_certificate_test_main(void)
{
    /* Positional initializer: version tag, a flag (presumably "is requester" — confirm
     * against libspdm_test_context_t), then the mock send and receive hooks. */
    libspdm_test_context_t test_context = {
        LIBSPDM_TEST_CONTEXT_VERSION,
        true,
        libspdm_requester_set_certificate_test_send_message,
        libspdm_requester_set_certificate_test_receive_message,
    };

    /* The array name is stringified by cmocka_run_group_tests as the group name. */
    const struct CMUnitTest spdm_requester_set_certificate_tests[] = {
        /* SendRequest failed*/
        cmocka_unit_test(libspdm_test_requester_set_certificate_case1),
        /* Successful response to set certificate*/
        cmocka_unit_test(libspdm_test_requester_set_certificate_case2),
        /* Set null cert_chain for slot 0*/
        cmocka_unit_test(libspdm_test_requester_set_certificate_case3),
        /* Successful response to set certificate for slot 1 in secure session*/
        cmocka_unit_test(libspdm_test_requester_set_certificate_case5),
        /* Successful response to set certificate with a reset required */
        cmocka_unit_test(libspdm_test_requester_set_certificate_case6),
        /* Successful response to erase certificate*/
        cmocka_unit_test(libspdm_test_requester_set_certificate_case7),
        /* Illegal multi-key parameter combination is rejected before sending */
        cmocka_unit_test(libspdm_test_requester_set_certificate_case8),
        /* Multi-key set certificate with DeviceCert model for slot 3 */
        cmocka_unit_test(libspdm_test_requester_set_certificate_case9),
    };

    libspdm_setup_test_context(&test_context);

    return cmocka_run_group_tests(spdm_requester_set_certificate_tests,
                                  libspdm_unit_test_group_setup,
                                  libspdm_unit_test_group_teardown);
}

#endif /* LIBSPDM_ENABLE_CAPABILITY_SET_CERT_CAP*/
