/**
 ****************************************************************************************
 *
 * @file me_mic.c
 *
 * @brief The MIC Calculator utility generic implementation.
 *
 * Copyright (C) RivieraWaves 2011-2019
 *
 ****************************************************************************************
 */

/** @addtogroup MIC_CALC
 * @{
 */

/*
 * INCLUDE FILES
 ****************************************************************************************
 */

#include "me_mic.h"
#include "co_utils.h"
#include "dbg_assert.h"
#if NX_HSU
#include "hsu.h"
#endif

/*
 * DEFINES
 ****************************************************************************************
 */

/**
 * Signature byte appended to the message when the MIC computation ends.
 * It is followed by 4 to 7 zero bytes, chosen so that the total padded
 * message length is a multiple of 4 bytes.
 */
#define MIC_END_SIGNATURE   0x5A

/**
 * Mask keeping only the three least significant bits of the received TID value.
 */
#define MIC_TID_MASK        0x07

// If NX_HSU == 2, don't include software implementation
#if NX_HSU < 2

/*
 * MACROS
 ****************************************************************************************
 */

/**
 ****************************************************************************************
 * @brief A macro for the right rotation instruction.
 *        Rotates L by Q bits to the right. Both shift counts are reduced modulo 32,
 *        so a rotation by 0 (or 32) returns L unchanged instead of performing an
 *        undefined shift by the full width of the type.
 *
 * @param[in]  L   The value to be rotated (converted to uint32_t, evaluated twice).
 * @param[in]  Q   The number of bits to rotate L with (evaluated twice).
 *
 * @return          The value of L after being rotated Q bits to the right.
 ****************************************************************************************
 */
#define ROR32(L,Q)  ( ( (uint32_t)(L) >> ((Q) & 31) ) | ( (uint32_t)(L) << ((32 - (Q)) & 31) ) )

/**
 ****************************************************************************************
 * @brief A macro for the left rotation instruction.
 *        Rotates L by Q bits to the left. Both shift counts are reduced modulo 32,
 *        so a rotation by 0 (or 32) returns L unchanged instead of performing an
 *        undefined shift by the full width of the type.
 *
 * @param[in]  L   The value to be rotated (converted to uint32_t, evaluated twice).
 * @param[in]  Q   The number of bits to rotate L with (evaluated twice).
 *
 * @return          The value of L after being rotated Q bits to the left.
 ****************************************************************************************
 */
#define ROL32(L,Q)  ( ( (uint32_t)(L) << ((Q) & 31) ) | ( (uint32_t)(L) >> ((32 - (Q)) & 31) ) )

/**
 ****************************************************************************************
 * @brief A macro for swap operation.
 *        Swaps the 1st and 2nd bytes, and swaps the 3rd and 4th bytes, i.e. exchanges
 *        the two bytes inside each 16-bit half of the word.
 *        The argument is parenthesized and converted to uint32_t so that expressions
 *        such as XSWAP32(a + b) expand correctly.
 *
 * @param[in]  L   The value to apply swap on (evaluated twice)
 *
 * @return          The value of L after swapping as described above
 ****************************************************************************************
 */
#define XSWAP32(L)    ( ( ( (uint32_t)(L) & 0xFF00FF00U ) >> 8 ) | \
                        ( ( (uint32_t)(L) & 0x00FF00FFU ) << 8 ) )

/**
 ****************************************************************************************
 * @brief This macro is used to convert lengths from the byte unit to the bit unit
 *        (multiplication by 8).
 *
 * @param[in]  X   The number of bytes to be converted in bits
 *
 * @return          The provided value converted in bits (i.e. X << 3)
 ****************************************************************************************
 */
#define LEN_IN_BITS(X) ( (X) << 3 )

/**
 ****************************************************************************************
 * @brief This macro is used to right shift a 32-bit value by another value
 *
 * A plain C shift by 32 bits or more is undefined behavior for a 32-bit operand.
 * This macro instead yields 0 for any shift count of 32 or more, which is the
 * mathematically expected result (all bits shifted out).
 *
 * @param[in]  X   The 32-bit value to be right shifted
 * @param[in]  S   The number by which X has to be right shifted (evaluated twice)
 *
 * @return         The right shifted value, or 0 when S >= 32
 ****************************************************************************************
 */
#define SHIFTR(X, S) (((S) >= 32)? 0 : ((X) >> (S)) )

/*
 * PRIVATE FUNCTION IMPLEMENTATIONS
 ****************************************************************************************
 */

/**
 ****************************************************************************************
 * @brief Michael block function b.
 *        Applies one round of the Feistel-type Michael block function to the MIC
 *        state, as defined in the IEEE standard 802.11-2012 (section 11.4.2.3 -
 *        Figure 11-10 and 11-11):
 *
 *           l <- l XOR M(i)
 *           r <- r XOR (l <<< 17);   l <- (l + r) mod 2^32
 *           r <- r XOR XSWAP(l);     l <- (l + r) mod 2^32
 *           r <- r XOR (l <<< 3);    l <- (l + r) mod 2^32
 *           r <- r XOR (l >>> 2);    l <- (l + r) mod 2^32
 *
 *        <<< and >>> denote 32-bit left and right rotations. XSWAP exchanges the
 *        two bytes inside each 16-bit half of the word (the 2 least significant
 *        octets are swapped, and so are the 2 most significant ones).
 *
 * @param[in,out] mic_calc_ptr  MIC state: l is read from / written to mic_key_least,
 *                              r is read from / written to mic_key_most.
 * @param[in]     block         The 32-bit message word M(i) to absorb.
 ****************************************************************************************
 */
static void michael_block(struct mic_calc *mic_calc_ptr, uint32_t block)
{
    // Fold the message word into the left half straight away
    uint32_t left  = mic_calc_ptr->mic_key_least ^ block;
    uint32_t right = mic_calc_ptr->mic_key_most;

    // Step 1: rotate-left by 17
    right ^= ROL32(left, 17);
    left  += right;

    // Step 2: byte swap inside each 16-bit half
    right ^= XSWAP32(left);
    left  += right;

    // Step 3: rotate-left by 3
    right ^= ROL32(left, 3);
    left  += right;

    // Step 4: rotate-right by 2
    right ^= ROR32(left, 2);
    left  += right;

    // Store the updated (l, r) pair back into the state
    mic_calc_ptr->mic_key_least = left;
    mic_calc_ptr->mic_key_most  = right;
}

/**
 ******************************************************************************
 * @brief Initializes TKIP MIC computation (Software implementation)
 *
 * Loads the two key words into the Michael state, clears the buffer of
 * pending (not yet block-aligned) bytes, then absorbs the four AAD words.
 *
 * @param[out] mic_calc_ptr Pointer to the Mic structure to be initialized
 * @param[in]  mic_key_ptr  Key to use for MIC computation (two 32-bit words:
 *                          [0] is loaded as l, [1] as r)
 * @param[in]  aad          Additional Authentication Data vector
 *                          (CPU address, 16 bytes long)
 ******************************************************************************
 */
static void michael_init(struct mic_calc *mic_calc_ptr, uint32_t *mic_key_ptr,
                         uint32_t *aad)
{
    unsigned int word_idx;

    // The key is the initial (l, r) state
    mic_calc_ptr->mic_key_least = mic_key_ptr[0];
    mic_calc_ptr->mic_key_most  = mic_key_ptr[1];

    // No partial block is pending yet
    mic_calc_ptr->last_m_i_len  = 0;
    mic_calc_ptr->last_m_i      = 0;

    // Absorb the 4 words of the AAD vector, in order
    for (word_idx = 0; word_idx < 4; word_idx++)
    {
        michael_block(mic_calc_ptr, aad[word_idx]);
    }
}

/**
 ******************************************************************************
 * @brief Continues TKIP MIC computation (Software implementation)
 *
 * Feeds data_len bytes starting at start_ptr into the Michael state. The
 * buffer may start at any byte offset: memory is read as aligned 32-bit words,
 * and the bytes that do not complete an M(i) block are kept in
 * mic_calc_ptr->last_m_i / last_m_i_len so that the next call (or the end
 * of the computation) can continue from them.
 *
 * NOTE: the byte extraction (right shifts drop the leading bytes, the first
 * bytes are taken from the low-order end of each word) assumes a
 * little-endian memory layout.
 *
 * @param[in,out] mic_calc_ptr Pointer to Mic structure
 * @param[in]     start_ptr    Address of the data buffer (HW address)
 * @param[in]     data_len     Length, in bytes, of the data buffer. A zero
 *                             length leaves the state untouched and does not
 *                             access memory.
 ******************************************************************************
 */
static void michael_calc(struct mic_calc *mic_calc_ptr, uint32_t start_ptr,
                         uint32_t data_len)
{
    // M(i) block
    uint32_t m_i;
    // Remaining length to proceed
    uint32_t rem_len = data_len;
    // Number of 32-bit blocks
    uint32_t nb_blocks;
    // Pointer to the payload (aligned down to a word boundary)
    uint32_t *u32_ptr;
    // Current memory word
    uint32_t val;
    // Byte offset of the first payload byte inside its aligned word
    uint8_t cut = start_ptr & 0x03;
    // Number of payload bytes available in the first aligned word
    uint8_t valid = 4 - cut;
    // Bytes carried over from previous calls that do not form a full block yet
    uint32_t last_m_i_len = mic_calc_ptr->last_m_i_len;
    uint32_t last_m_i = mic_calc_ptr->last_m_i;

    // Nothing to process. Returning here also avoids reading the first word
    // and building the partial-word mask below with a 32-bit shift, which is
    // undefined behavior in C (and yields a wrong mask on some CPUs).
    if (data_len == 0)
        return;

    u32_ptr = HW2CPU(start_ptr & ~0x03);
    val = *u32_ptr++;

    // Drop the bytes that precede the payload in the first word
    val >>= LEN_IN_BITS(cut);
    if (data_len < valid)
    {
        // The whole payload fits in the first word: keep only its data_len
        // bytes (data_len is 1..3 here, so the shift is 8..24 bits)
        val &= 0xFFFFFFFF >> LEN_IN_BITS(4 - data_len);
        valid = data_len;
        rem_len = 0;
    }
    else
    {
        rem_len -= valid;
    }

    if ((last_m_i_len + valid) < 4)
    {
        // Still not enough bytes for a full block: just accumulate them
        last_m_i |= val << LEN_IN_BITS(last_m_i_len);
        last_m_i_len += valid;
    }
    else
    {
        // Complete the pending block with the low bytes of val
        m_i = last_m_i | (val << LEN_IN_BITS(last_m_i_len));

        // Bytes of val that did not fit become the new pending bytes
        last_m_i = SHIFTR(val, LEN_IN_BITS(4 - last_m_i_len));
        last_m_i_len += valid - 4;

        // Apply Michael block function
        michael_block(mic_calc_ptr, m_i);
    }

    // Compute the number of remaining full words
    nb_blocks = (rem_len >> 2);

    for (uint32_t block_cnt = 0; block_cnt < nb_blocks; block_cnt++)
    {
        // Read memory
        val = *u32_ptr++;

        // Extract the block value (pending bytes first, then the new ones)
        m_i = last_m_i | (val << LEN_IN_BITS(last_m_i_len));

        // Save last MI: the pending length is unchanged by a full word
        last_m_i = SHIFTR(val, LEN_IN_BITS(4 - last_m_i_len));

        // Apply Michael block function
        michael_block(mic_calc_ptr, m_i);
    }

    // If pending bytes store them inside the mic_calc structure
    if (rem_len > (nb_blocks << 2))
    {
        // 1..3 trailing bytes, located at the start of the next aligned word
        uint32_t add_bytes = rem_len - (nb_blocks << 2);

        val = (*u32_ptr) & (SHIFTR(0xFFFFFFFF, LEN_IN_BITS(4 - add_bytes)));
        if ((last_m_i_len + add_bytes) > 3)
        {
            // Extract the block value
            m_i = last_m_i | (val << LEN_IN_BITS(last_m_i_len));

            // Save last MI
            last_m_i = SHIFTR(val, LEN_IN_BITS(4 - last_m_i_len));
            last_m_i_len += add_bytes - 4;

            // Apply Michael block function
            michael_block(mic_calc_ptr, m_i);
        }
        else
        {
            last_m_i |= val << LEN_IN_BITS(last_m_i_len);
            last_m_i_len += add_bytes;
        }
    }

    // Persist the pending bytes for the next call / the end of computation
    mic_calc_ptr->last_m_i     = last_m_i;
    mic_calc_ptr->last_m_i_len = last_m_i_len;
}

/**
 ******************************************************************************
 * @brief Ends TKIP MIC computation (Software implementation)
 *
 * Appends the Michael padding: the signature byte 0x5A right after the
 * pending bytes, then zero bytes up to the end of that word plus one extra
 * all-zero word (4 to 7 zero bytes in total). The final MIC is left in
 * mic_key_least / mic_key_most.
 *
 * @param[in,out] mic_calc_ptr Pointer to the Mic structure being finalized
 ******************************************************************************
 */
static void michael_end(struct mic_calc *mic_calc_ptr)
{
    // Padded M(n-2): pending bytes, then the signature byte, then zeros
    uint32_t m_n_2;

    // At most 3 bytes can be pending, otherwise a block would have been emitted
    ASSERT_ERR(mic_calc_ptr->last_m_i_len < 4);

    m_n_2 = (uint32_t)mic_calc_ptr->last_m_i
          | (MIC_END_SIGNATURE << LEN_IN_BITS(mic_calc_ptr->last_m_i_len));

    // Last two blocks: M(n-2) as built above, M(n-1) made only of zeros
    michael_block(mic_calc_ptr, m_n_2);
    michael_block(mic_calc_ptr, 0);
}
#endif /* NX_HSU < 2 */

/*
 * PUBLIC FUNCTION IMPLEMENTATIONS
 ******************************************************************************
 */

#if NX_HSU
/// Additional Authentication Data vector for MIC computation
/// (DA, SA, priority and padding words, filled by me_mic_init()).
/// Must be in SHARED RAM when HSU is used (presumably so that the HSU can
/// read it directly - confirm against the HSU driver).
/// NOTE(review): this is a single global buffer, so concurrent calls to
/// me_mic_init() would overwrite each other's AAD.
uint32_t mic_aad[4] __SHAREDRAM;
#endif

void me_mic_init(struct mic_calc *mic_calc_ptr, uint32_t *mic_key_ptr,
                 struct mac_addr *da, struct mac_addr *sa, uint8_t tid)
{
    #if !NX_HSU
    // Without HSU the AAD vector does not need to be in shared RAM
    uint32_t mic_aad[4];
    #endif

    // AAD vector = DA | SA | Priority | Padding.
    // Each address is read as three array elements packed two per 32-bit word
    // (presumably 16-bit elements - matches the 16-bit shifts below).
    mic_aad[0] = ((uint32_t)da->array[1] << 16) | (uint32_t)da->array[0];
    mic_aad[1] = ((uint32_t)sa->array[0] << 16) | (uint32_t)da->array[2];
    mic_aad[2] = ((uint32_t)sa->array[2] << 16) | (uint32_t)sa->array[1];
    // A TID of 0xFF encodes "no priority": use 0, otherwise the 3-bit TID
    mic_aad[3] = (tid != 0xFF) ? ((uint32_t)tid & MIC_TID_MASK) : 0;

    #if NX_HSU
    // Hardware path first: a non-zero return skips the software path
    if (!hsu_michael_init(mic_calc_ptr, mic_key_ptr, mic_aad))
    #endif
    {
        #if NX_HSU < 2
        // Software path (not built when NX_HSU == 2)
        michael_init(mic_calc_ptr, mic_key_ptr, mic_aad);
        #endif
    }
}

void me_mic_calc(struct mic_calc *mic_calc_ptr, uint32_t start_ptr,
                 uint32_t data_len)
{
    #if NX_HSU
    // Hardware path first: a non-zero return skips the software path
    if (!hsu_michael_calc(mic_calc_ptr, start_ptr, data_len))
    #endif
    {
        #if NX_HSU < 2
        // Software path (not built when NX_HSU == 2)
        michael_calc(mic_calc_ptr, start_ptr, data_len);
        #endif
    }
}

void me_mic_end(struct mic_calc *mic_calc_ptr)
{
    #if NX_HSU
    // Hardware path first: a non-zero return skips the software path
    if (!hsu_michael_end(mic_calc_ptr))
    #endif
    {
        #if NX_HSU < 2
        // Software path (not built when NX_HSU == 2)
        michael_end(mic_calc_ptr);
        #endif
    }
}

/// @}
