//////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
//  Copyright (C) by RivieraWaves.
//  This module is a confidential and proprietary property of RivieraWaves
//  and a possession or use of this module requires written permission
//  from RivieraWaves.
//----------------------------------------------------------------------------
// $RCSmodulefile   :
// $Author          :
// Company          : RivieraWaves
//----------------------------------------------------------------------------
// $Revision:
// $Date:
// $State:
// $Locker:
// ---------------------------------------------------------------------------
// Dependencies     :
// Description      : WAPI Class For TB
//
// Simulation Notes :
// Synthesis Notes  :
// Application Note :
// Simulator        :
// Parameters       :
// Terms & concepts :
// Bugs             :
// Open issues and future enhancements :
// References       :
// Revision History :
// ---------------------------------------------------------------------------
//
//
//
//////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
`ifndef WAPI_SV
`define WAPI_SV

// Byte-lane masks for a 32-bit word: `SMS4MASKn keeps byte n and clears the
// rest (n = 0 is the least-significant byte, n = 3 the most-significant).
`define SMS4MASK3 32'hFF000000
`define SMS4MASK2 32'h00FF0000
`define SMS4MASK1 32'h0000FF00
`define SMS4MASK0 32'h000000FF

class WPIHeader;
  // ---- MAC header fields -------------------------------------------------
  bit [15:0]   frameControl;
  bit [47:0]   addr1, addr2, addr3, addr4;
  bit          address4Pres;     // set when addr4 is carried by the frame
  bit          qosFrame;         // set when qosCF is carried by the frame
  bit [15:0]   qosCF;
  bit [15:0]   seqControl;

  // ---- WPI fields --------------------------------------------------------
  bit [7:0]    keyIdx;
  bit [127:0]  pn;               // presumably the WPI packet number; used as cipher IV

  // ---- Session keys (encryption: OFB stage, integrity: CBC-MIC stage) ----
  bit [127:0]  wpiEncryptionKey, wpiIntegrityKey;

  bit [15:0]   pduLength;        // payload length; payload vectors are sized pduLength*8 bits
endclass

class WPIPDUVector;
  // Payload as a flat, randomizable bit vector (4-state elements).
  rand logic vector[];

  // Allocate 'size' elements; contents start out as X until assigned.
  function new (int size);
    this.vector = new[size];
  endfunction
endclass

class WPI;
  // Header fields and session keys.
  WPIHeader     wpiHeader;
  // Plaintext payload bits (input) and encrypted payload bits (output).
  WPIPDUVector  pduPData;
  bit           pduEData[];
  // Plaintext MIC (from the CBC stage) and its encrypted form (OFB stage).
  bit [127:0]   pMIC, eMIC;
endclass

// SMS4 Algorithm model
class sms4_beh;

  // ------------------------------------------------------------------------
  // Behavioural model of the SMS4 128-bit block cipher (encryption only,
  // 32 rounds, 128-bit key).
  //
  // Usage:  KeyExpand(key);  LoadPt(pt);  SMS4();  GetCt(ct);
  //
  // 128-bit values are handled as four 32-bit words, word 0 being the most
  // significant (bits [127:96]) -- see LoadPt() / GetCt().
  // ------------------------------------------------------------------------

  // Key-schedule system parameters FK and round constants CK.
  const int unsigned FK[4] = {32'ha3b1bac6,32'h56aa3350,32'h677d9197,32'hb27022dc};

  const int unsigned CK[32] =
  {
    32'h00070e15,32'h1c232a31,32'h383f464d,32'h545b6269,
    32'h70777e85,32'h8c939aa1,32'ha8afb6bd,32'hc4cbd2d9,
    32'he0e7eef5,32'hfc030a11,32'h181f262d,32'h343b4249,
    32'h50575e65,32'h6c737a81,32'h888f969d,32'ha4abb2b9,
    32'hc0c7ced5,32'hdce3eaf1,32'hf8ff060d,32'h141b2229,
    32'h30373e45,32'h4c535a61,32'h686f767d,32'h848b9299,
    32'ha0a7aeb5,32'hbcc3cad1,32'hd8dfe6ed,32'hf4fb0209,
    32'h10171e25,32'h2c333a41,32'h484f565d,32'h646b7279
  };

  // Round (encryption) keys rk[0..31], produced by KeyExpand().
  int unsigned ENRK[32];

  // Key-schedule working words K[0..35]; K[i+4] is round key i.
  int unsigned k[36];

  // Plaintext block (word 0 = bits [127:96]); overwritten by LoadPt().
  int unsigned pData[4] =
  {
    32'h01234567,
    32'h89abcdef,
    32'hfedcba98,
    32'h76543210
  };

  // Ciphertext block produced by SMS4Encrypt(); same word order as pData.
  int unsigned cData[4] =
  {
    32'h00000000,
    32'h00000000,
    32'h00000000,
    32'h00000000
  };

  // Constructor
  function new();
  endfunction

  // S-box lookup for one byte.
  function byte unsigned SMS4Sbox(byte unsigned index);
    // 'static' so the table is initialised once instead of being rebuilt on
    // every call (class methods have automatic lifetime by default).
    const static byte unsigned sbox[0:255] = {
      8'hd6,8'h90,8'he9,8'hfe,8'hcc,8'he1,8'h3d,8'hb7,8'h16,8'hb6,8'h14,8'hc2,8'h28,8'hfb,8'h2c,8'h05,
      8'h2b,8'h67,8'h9a,8'h76,8'h2a,8'hbe,8'h04,8'hc3,8'haa,8'h44,8'h13,8'h26,8'h49,8'h86,8'h06,8'h99,
      8'h9c,8'h42,8'h50,8'hf4,8'h91,8'hef,8'h98,8'h7a,8'h33,8'h54,8'h0b,8'h43,8'hed,8'hcf,8'hac,8'h62,
      8'he4,8'hb3,8'h1c,8'ha9,8'hc9,8'h08,8'he8,8'h95,8'h80,8'hdf,8'h94,8'hfa,8'h75,8'h8f,8'h3f,8'ha6,
      8'h47,8'h07,8'ha7,8'hfc,8'hf3,8'h73,8'h17,8'hba,8'h83,8'h59,8'h3c,8'h19,8'he6,8'h85,8'h4f,8'ha8,
      8'h68,8'h6b,8'h81,8'hb2,8'h71,8'h64,8'hda,8'h8b,8'hf8,8'heb,8'h0f,8'h4b,8'h70,8'h56,8'h9d,8'h35,
      8'h1e,8'h24,8'h0e,8'h5e,8'h63,8'h58,8'hd1,8'ha2,8'h25,8'h22,8'h7c,8'h3b,8'h01,8'h21,8'h78,8'h87,
      8'hd4,8'h00,8'h46,8'h57,8'h9f,8'hd3,8'h27,8'h52,8'h4c,8'h36,8'h02,8'he7,8'ha0,8'hc4,8'hc8,8'h9e,
      8'hea,8'hbf,8'h8a,8'hd2,8'h40,8'hc7,8'h38,8'hb5,8'ha3,8'hf7,8'hf2,8'hce,8'hf9,8'h61,8'h15,8'ha1,
      8'he0,8'hae,8'h5d,8'ha4,8'h9b,8'h34,8'h1a,8'h55,8'had,8'h93,8'h32,8'h30,8'hf5,8'h8c,8'hb1,8'he3,
      8'h1d,8'hf6,8'he2,8'h2e,8'h82,8'h66,8'hca,8'h60,8'hc0,8'h29,8'h23,8'hab,8'h0d,8'h53,8'h4e,8'h6f,
      8'hd5,8'hdb,8'h37,8'h45,8'hde,8'hfd,8'h8e,8'h2f,8'h03,8'hff,8'h6a,8'h72,8'h6d,8'h6c,8'h5b,8'h51,
      8'h8d,8'h1b,8'haf,8'h92,8'hbb,8'hdd,8'hbc,8'h7f,8'h11,8'hd9,8'h5c,8'h41,8'h1f,8'h10,8'h5a,8'hd8,
      8'h0a,8'hc1,8'h31,8'h88,8'ha5,8'hcd,8'h7b,8'hbd,8'h2d,8'h74,8'hd0,8'h12,8'hb8,8'he5,8'hb4,8'hb0,
      8'h89,8'h69,8'h97,8'h4a,8'h0c,8'h96,8'h77,8'h7e,8'h65,8'hb9,8'hf1,8'h09,8'hc5,8'h6e,8'hc6,8'h84,
      8'h18,8'hf0,8'h7d,8'hec,8'h3a,8'hdc,8'h4d,8'h20,8'h79,8'hee,8'h5f,8'h3e,8'hd7,8'hcb,8'h39,8'h48
      };
    return sbox[index];
  endfunction

  // 32-bit circular left rotation of uval by 'bits' positions.
  function int unsigned SMS4CROL(int unsigned uval, int unsigned bits);
    return (uval << bits) | (uval >> (32 - bits));
  endfunction

  // Non-linear transform tau: substitute every byte of the word through the
  // S-box, keeping byte positions. Shared by the round function and the key
  // schedule.
  local function int unsigned SMS4Tau(int unsigned a);
    return {SMS4Sbox(a[31:24]), SMS4Sbox(a[23:16]), SMS4Sbox(a[15:8]), SMS4Sbox(a[7:0])};
  endfunction

  // Round-function mixer T = L(tau(a)),
  // L(B) = B ^ (B<<<2) ^ (B<<<10) ^ (B<<<18) ^ (B<<<24).
  function int unsigned SMS4Lt(int unsigned a);
    int unsigned b = SMS4Tau(a);
    return b ^ SMS4CROL(b, 2) ^ SMS4CROL(b, 10) ^ SMS4CROL(b, 18) ^ SMS4CROL(b, 24);
  endfunction

  // Key-schedule mixer T' = L'(tau(a)), L'(B) = B ^ (B<<<13) ^ (B<<<23).
  function int unsigned SMS4CalciRK(int unsigned a);
    int unsigned b = SMS4Tau(a);
    return b ^ SMS4CROL(b, 13) ^ SMS4CROL(b, 23);
  endfunction

  // Alias of SMS4Lt(), kept for callers that use this name.
  function  int unsigned SMS4T(int unsigned a);
    SMS4T = SMS4Lt(a);
  endfunction

  // One round: X[i+4] = X[i] ^ T(X[i+1] ^ X[i+2] ^ X[i+3] ^ rk[i]).
  function int unsigned SMS4F(int unsigned x0, int unsigned x1, int unsigned x2, int unsigned x3, int unsigned rk);
    SMS4F = (x0^SMS4Lt(x1^x2^x3^rk));
  endfunction

  // Encrypt pData into cData using the round keys in ENRK.
  function void SMS4Encrypt();
    int unsigned x[36];

    foreach (pData[w])
      x[w] = pData[w];

    for (int i = 0; i < 32; i++)
      x[i+4] = SMS4F(x[i], x[i+1], x[i+2], x[i+3], ENRK[i]);

    // Output transform: reverse order of the last four words (X35..X32).
    foreach (cData[w])
      cData[w] = x[35 - w];
  endfunction

  // Derive the 32 round keys (ENRK) from a 128-bit key.
  function  void KeyExpand(bit [127:0] key);
    k[0] = key[127:96] ^ FK[0];
    k[1] = key[95:64]  ^ FK[1];
    k[2] = key[63:32]  ^ FK[2];
    k[3] = key[31:0]   ^ FK[3];

    for (int i = 0; i < 32; i++) begin
      k[i+4]  = k[i] ^ SMS4CalciRK(k[i+1] ^ k[i+2] ^ k[i+3] ^ CK[i]);
      ENRK[i] = k[i+4];
    end
  endfunction

  // Load a 128-bit plaintext block into pData (bits [127:96] -> word 0).
  function void LoadPt(bit [127:0] pt);
    pData[0] = pt[127:96];
    pData[1] = pt[95:64];
    pData[2] = pt[63:32];
    pData[3] = pt[31:0];
  endfunction

  // Pack cData into a 128-bit ciphertext (word 0 -> bits [127:96]).
  function void GetCt(ref bit [127:0] ct);
    ct[127:96] = cData[0];
    ct[95:64]  = cData[1];
    ct[63:32]  = cData[2];
    ct[31:0]   = cData[3];
  endfunction

  // Encrypt the currently loaded block.
  function void SMS4();
    SMS4Encrypt();
  endfunction

endclass


// SMS4-CBC mode model
class sms4_cbc_beh;

  // ------------------------------------------------------------------------
  // SMS4 CBC-MAC model for the WPI integrity code.
  //   CBCInit(): chain = E(Kint, PN), then absorbs the header blocks.
  //   CBCPDU() : absorbs the payload (zero-padded to 128 bits) and writes
  //              the final chaining value to wpi_c.pMIC.
  // ------------------------------------------------------------------------

  sms4_beh      sms4_cbc_c;   // underlying block cipher
  bit[127:0]    ctIn;         // last cipher input (chain ^ data block)
  bit[127:0]    ctOut;        // running chaining value

  // Constructor
  function new();
    sms4_cbc_c = new;
  endfunction

  // Byte-reverse a 48-bit address so that addr[7:0] becomes the most
  // significant byte of the result (first byte in the MIC block).
  local function bit [47:0] AddrBytes(bit [47:0] addr);
    return {addr[7:0], addr[15:8], addr[23:16], addr[31:24], addr[39:32], addr[47:40]};
  endfunction

  // One CBC step: ctIn = ctOut ^ blk; ctOut = E(K, ctIn).
  local function void CBCStep(bit [127:0] blk);
    ctIn = ctOut ^ blk;
    sms4_cbc_c.LoadPt(ctIn);
    sms4_cbc_c.SMS4();
    sms4_cbc_c.GetCt(ctOut);
  endfunction

  // Key the cipher with the integrity key, seed the chain with E(K, PN) and
  // absorb the header blocks (2 blocks for non-QoS frames, 3 for QoS frames).
  function void CBCInit(ref WPI wpi_c);
    WPIHeader   hdr;
    bit [47:0]  addr4Field;

    hdr = wpi_c.wpiHeader;

    sms4_cbc_c.KeyExpand(hdr.wpiIntegrityKey);
    sms4_cbc_c.LoadPt(hdr.pn);
    sms4_cbc_c.SMS4();
    sms4_cbc_c.GetCt(ctOut);

    // Block 1: Frame Control (bits 6:4 and 13:11 cleared, bit 14 forced to 1),
    // Addr1, Addr2, Sequence Control (only bits 3:0 kept).
    CBCStep({hdr.frameControl[7], 3'h0, hdr.frameControl[3:0],
             hdr.frameControl[15], 4'b1000, hdr.frameControl[10:8],
             AddrBytes(hdr.addr1),
             AddrBytes(hdr.addr2),
             4'b0000, hdr.seqControl[3:0],
             8'h0});

    // The Addr4 slot is all-zero when the frame carries no fourth address.
    addr4Field = hdr.address4Pres ? AddrBytes(hdr.addr4) : 48'h0;

    if (hdr.qosFrame == 0) begin
      // Block 2: Addr3, Addr4, KeyIdx, reserved byte, PDU length.
      CBCStep({AddrBytes(hdr.addr3), addr4Field, hdr.keyIdx, 8'h0, hdr.pduLength});
    end
    else begin
      // Block 2: Addr3, Addr4, QoS Control (low byte first), KeyIdx, reserved byte.
      CBCStep({AddrBytes(hdr.addr3), addr4Field,
               hdr.qosCF[7:0], hdr.qosCF[15:8], hdr.keyIdx, 8'h0});
      // Block 3: PDU length, zero-padded to 128 bits.
      CBCStep({hdr.pduLength, 112'h0});
    end
  endfunction

  // Absorb the payload bits and store the MIC in wpi_c.pMIC.
  // The highest vector index is treated as the first bit; the payload is
  // zero-padded at the low-index end up to a multiple of 128 bits, and
  // X/Z payload bits are taken as 0.
  function void CBCPDU(ref WPI wpi_c);
    int unsigned pduSize   = wpi_c.pduPData.vector.size();
    int unsigned numBlocks = (pduSize + 127) / 128;
    int unsigned padBits   = numBlocks * 128 - pduSize;
    bit [127:0]  pduInSegment;

    // Walk the padded payload from the most significant 128-bit block down.
    for (int unsigned blk = numBlocks; blk > 0; blk--) begin
      for (int unsigned j = 0; j < 128; j++) begin
        int unsigned pos = (blk - 1) * 128 + j;   // index in the padded payload
        pduInSegment[j] = (pos >= padBits) ? wpi_c.pduPData.vector[pos - padBits] : 1'b0;
      end
      CBCStep(pduInSegment);
    end

    wpi_c.pMIC = ctOut;
  endfunction

endclass


// SMS4-OFB mode model
class sms4_ofb_beh;

  // ------------------------------------------------------------------------
  // SMS4 OFB-mode model for WPI payload encryption.
  //   OFBInit(): keystream block 0 = E(Kenc, PN).
  //   OFBPDU() : XORs the stream (payload || pMIC) with successive keystream
  //              blocks; results go to wpi_c.pduEData and wpi_c.eMIC.
  // ------------------------------------------------------------------------

  sms4_beh      sms4_ofb_c;   // underlying block cipher
  bit[127:0]    ctIn;         // last cipher input
  bit[127:0]    ctOut;        // current keystream block

  function new();    // Constructor
    sms4_ofb_c = new;
  endfunction

  // Advance the keystream: ctIn = ctOut; ctOut = E(K, ctIn).
  local function void OFBStep();
    ctIn = ctOut;
    sms4_ofb_c.LoadPt(ctIn);
    sms4_ofb_c.SMS4();
    sms4_ofb_c.GetCt(ctOut);
  endfunction

  // Key the cipher with the encryption key and produce keystream block 0.
  function void OFBInit(ref WPI wpi_c);
    sms4_ofb_c.KeyExpand(wpi_c.wpiHeader.wpiEncryptionKey);
    sms4_ofb_c.LoadPt(wpi_c.wpiHeader.pn);
    sms4_ofb_c.SMS4();
    sms4_ofb_c.GetCt(ctOut);
  endfunction

  // Encrypt payload || MIC.
  // Stream position p counts from the first bit: for p < pduSize it is payload
  // bit vector[pduSize-1-p]; the next 128 positions are pMIC[127] .. pMIC[0].
  // Position p uses bit 127-(p%128) of keystream block p/128.
  // X/Z payload bits are taken as 0.
  function void OFBPDU(ref WPI wpi_c);
    int unsigned pduSize   = wpi_c.pduPData.vector.size();
    // Keystream blocks: payload rounded up to 128 bits, plus one for the MIC.
    int unsigned numBlocks = (pduSize + 127) / 128 + 1;
    int unsigned streamLen = pduSize + 128;

    for (int unsigned blk = 0; blk < numBlocks; blk++) begin
      // Block 0 is the keystream left in ctOut by OFBInit().
      if (blk != 0)
        OFBStep();

      for (int unsigned b = 0; b < 128; b++) begin
        int unsigned pos = blk * 128 + b;
        bit          ks  = ctOut[127 - b];

        if (pos < pduSize) begin
          // Assigning to a 2-state bit maps X/Z payload bits to 0 before the XOR.
          bit d = wpi_c.pduPData.vector[pduSize - 1 - pos];
          wpi_c.pduEData[pduSize - 1 - pos] = d ^ ks;
        end
        else if (pos < streamLen) begin
          int unsigned m = 127 - (pos - pduSize);
          wpi_c.eMIC[m] = wpi_c.pMIC[m] ^ ks;
        end
        // Positions beyond streamLen are padding and are discarded.
      end
    end
  endfunction

endclass

class wapi;

  // Top-level WPI encapsulation model: the CBC stage produces wpi_c.pMIC from
  // header + payload, then the OFB stage produces wpi_c.pduEData / wpi_c.eMIC.
  sms4_ofb_beh  sms4_ofb_beh_c;
  sms4_cbc_beh  sms4_cbc_beh_c;
  WPI           wpi_c;

  // Constructor: build the frame container (with its header) and both cipher
  // mode engines.
  function new();
    wpi_c           = new();
    wpi_c.wpiHeader = new();
    sms4_cbc_beh_c  = new();
    sms4_ofb_beh_c  = new();
  endfunction

  // Size the plaintext and ciphertext payload vectors from the header's
  // pduLength (8 bits per unit of pduLength). Call before WAPIEncrypt().
  function void InitWPIStruct();
    int unsigned numBits = wpi_c.wpiHeader.pduLength * 8;
    wpi_c.pduPData = new(numBits);
    wpi_c.pduEData = new[numBits];
  endfunction

  // Run the MIC (CBC) stage followed by the encryption (OFB) stage.
  function void WAPIEncrypt();
    // Integrity: header blocks, then payload -> wpi_c.pMIC
    sms4_cbc_beh_c.CBCInit(wpi_c);
    sms4_cbc_beh_c.CBCPDU(wpi_c);
    // Confidentiality: payload and MIC -> wpi_c.pduEData / wpi_c.eMIC
    sms4_ofb_beh_c.OFBInit(wpi_c);
    sms4_ofb_beh_c.OFBPDU(wpi_c);
  endfunction

endclass : wapi

`endif// WAPI_SV
