include "bin32.s7i";
include "bin64.s7i";
include "bigint.s7i";
include "aes.s7i";
# Table of 16 bin64 values with the indices 0 .. 15.
# The tables of this type are indexed with 4-bit values (see gcmMult).
const type: halfHTableType is array [0 .. 15] bin64;
(**
 *  Precomputed multiplication table used by gcmMult.
 *  The table is created by computeFactorH from a 16 byte string.
 *  Each 128-bit table entry is split into two 64-bit halves, which
 *  are stored at the same index in the two arrays.
 *)
const type: factorHType is new struct
    var halfHTableType: lowHalfHTable  is halfHTableType.value;  # Low 64 bits (from bytes 9 .. 16 of h).
    var halfHTableType: highHalfHTable is halfHTableType.value;  # High 64 bits (from bytes 1 .. 8 of h).
  end struct;
(**
 *  Build the 16 entry multiplication table for the 16 byte value ''h''.
 *  The first 8 bytes of ''h'' are taken as big-endian high half and the
 *  next 8 bytes as big-endian low half of a 128-bit value.
 *  Entry 0 is zero and entry 8 is ''h'' itself. The entries 4, 2 and 1
 *  are obtained by repeatedly shifting the 128-bit value one bit to
 *  the right. If a set bit is shifted out the top byte of the high half
 *  is XORed with 16#e1. All remaining entries are XOR combinations
 *  of the entries with just one bit set in the index.
 *  @param h String with (at least) 16 bytes.
 *  @return the table used as second argument of gcmMult.
 *)
const func factorHType: computeFactorH (in string: h) is func
  result
    var factorHType: factorH is factorHType.value;
  local
    var bin64: highHalf is bin64(0);
    var bin64: lowHalf is bin64(0);
    var bin64: reduction is bin64(0);
    var integer: base is 0;
    var integer: offset is 0;
  begin
    highHalf := bin64(h[1 fixLen 8], BE);
    lowHalf  := bin64(h[9 fixLen 8], BE);

    # Index 0 holds zero and index 8 (2#1000) holds h unchanged.
    factorH.highHalfHTable[0] := bin64(0);
    factorH.lowHalfHTable[0]  := bin64(0);
    factorH.highHalfHTable[8] := highHalf;
    factorH.lowHalfHTable[8]  := lowHalf;

    # Entries 4, 2 and 1: Each one is the previous one shifted right by one bit.
    for base range [](4, 2, 1) do
      # The reduction depends on the bit that is about to be shifted out.
      if (lowHalf & bin64(1)) <> bin64(0) then
        reduction := bin64(16#e1000000) << 32;
      else
        reduction := bin64(0);
      end if;
      lowHalf  := (highHalf << 63) | (lowHalf >> 1);
      highHalf := (highHalf >> 1) >< reduction;
      factorH.highHalfHTable[base] := highHalf;
      factorH.lowHalfHTable[base]  := lowHalf;
    end for;

    # Remaining entries: XOR of a power of two entry with a smaller entry.
    for base range [](2, 4, 8) do
      for offset range 1 to base - 1 do
        factorH.highHalfHTable[base + offset] :=
            factorH.highHalfHTable[base] >< factorH.highHalfHTable[offset];
        factorH.lowHalfHTable[base + offset] :=
            factorH.lowHalfHTable[base] >< factorH.lowHalfHTable[offset];
      end for;
    end for;
  end func;
# Correction values used by gcmMult.
# The table is indexed (0 .. 15) with the four bits that gcmMult shifts out
# at the low end of its 128-bit accumulator. The selected value is XORed
# into the high half of the accumulator. Only the top 16 bits are nonzero.
# The literals with a trailing underscore exceed the positive integer range.
const array bin64: last4 is [0] (
    bin64(16#0000000000000000),  bin64(16#1c20000000000000),
    bin64(16#3840000000000000),  bin64(16#2460000000000000),
    bin64(16#7080000000000000),  bin64(16#6ca0000000000000),
    bin64(16#48c0000000000000),  bin64(16#54e0000000000000),
    bin64(16#e100000000000000_), bin64(16#fd20000000000000_),
    bin64(16#d940000000000000_), bin64(16#c560000000000000_),
    bin64(16#9180000000000000_), bin64(16#8da0000000000000_),
    bin64(16#a9c0000000000000_), bin64(16#b5e0000000000000_));
(**
 *  Process one 4-bit value in the table driven multiplication of gcmMult.
 *  The 128-bit accumulator (''zh'', ''zl'') is shifted right by four bits.
 *  The four bits shifted out select a correction value from ''last4''.
 *  Afterwards the table entry selected by ''nibble'' is XORed into
 *  the accumulator.
 *)
const proc: gcmMultNibbleStep (inout bin64: zh, inout bin64: zl,
    in integer: nibble, in factorHType: factorH) is func
  local
    var integer: shiftedOut is 0;
  begin
    shiftedOut := ord(zl & bin64(16#f));
    zl := zh << 60 | zl >> 4;
    zh := zh >> 4;
    zh ><:= last4[shiftedOut];
    zh ><:= factorH.highHalfHTable[nibble];
    zl ><:= factorH.lowHalfHTable[nibble];
  end func;


(**
 *  Multiply the 128-bit value ''factor1'' with the value represented
 *  by the precomputed table ''factorH''.
 *  The bytes of ''factor1'' are processed from the last (16th) to the
 *  first one. In each byte the low 4 bits are processed before the
 *  high 4 bits.
 *  @param factor1 String of which the first 16 bytes are used.
 *  @param factorH Table created with computeFactorH.
 *  @return the product as string of 16 bytes (high half first, big-endian).
 *)
const func string: gcmMult (in string: factor1,               # 128 bits
                            in factorHType: factorH) is func  # precalculated table
  result
    var string: product is "";  # 128 bits
  local
    var integer: position is 0;
    var integer: byteValue is 0;
    var bin64: zh is bin64(0);
    var bin64: zl is bin64(0);
  begin
    # The low 4 bits of the last byte initialize the accumulator directly.
    byteValue := ord(factor1[16]);
    zh := factorH.highHalfHTable[byteValue mod 16];
    zl := factorH.lowHalfHTable[byteValue mod 16];
    gcmMultNibbleStep(zh, zl, byteValue >> 4, factorH);

    # All other bytes contribute two steps: Low 4 bits and then high 4 bits.
    for position range 15 downto 1 do
      byteValue := ord(factor1[position]);
      gcmMultNibbleStep(zh, zl, byteValue mod 16, factorH);
      gcmMultNibbleStep(zh, zl, byteValue >> 4, factorH);
    end for;

    product := bytes(zh, BE, 8) & bytes(zl, BE, 8);
  end func;
(**
 *  State of an AES GCM cipher.
 *  The state is created with setAesGcmKey. It is updated by initAead,
 *  encode and decode. After encode or decode the functions
 *  getComputedMac and getMac provide access to the MAC values.
 *)
const type: aesGcmState is new struct
    var string: salt is "";                         # Initialization vector from setAesGcmKey; first part of the nonce.
    var string: nonce is " " mult 16;               # salt & 8 explicit nonce bytes & 4 byte block counter (presumably 16 bytes in total).
    var aesState: aesCipherState is aesState.value; # AES state created by setAesKey.
    var factorHType: factorH is factorHType.value;  # Multiplication table created by computeFactorH.
    var integer: sequenceNumber is 0;               # Set by initAead; goes into the MAC as 8 bytes.
    var string: recordTypeAndVersion is "";         # Set by initAead; goes into the MAC (presumably 3 bytes).
    var string: encodedNonce1 is "";                # AES encrypted nonce with block counter 1; XORed into the final MAC.
    var string: computedMac is "\0;" mult 16;       # MAC computed by encode respectively decode.
    var string: mac is "";                          # MAC taken from the end of the message given to decode.
  end struct;
# Declare that aesGcmState can be used where the interface cipherState is expected.
type_implements_interface(aesGcmState, cipherState);
# The block size reported for the AES_GCM cipher is 0.
const integer: blockSize (AES_GCM) is 0;
(**
 *  Create the AES GCM cipher state for the given key.
 *  The multiplication table of the state is derived from the AES
 *  encryption of a block with 16 zero bytes.
 *  @param aesKey Key handed to setAesKey.
 *  @param initializationVector Stored as salt. It becomes the first
 *         part of every nonce used by encode and decode.
 *  @return the initial AES GCM cipher state.
 *)
const func aesGcmState: setAesGcmKey (in string: aesKey,
    in string: initializationVector) is func
  result
    var aesGcmState: state is aesGcmState.value;
  local
    var string: zeroBlock is "\0;" mult 16;
  begin
    state.salt := initializationVector;
    state.aesCipherState := setAesKey(aesKey, "");
    state.factorH := computeFactorH(
        encodeAesBlock(state.aesCipherState.encryptionSubKey,
                       state.aesCipherState.rounds, zeroBlock));
  end func;
(**
 *  Create a cipherState for the AES_GCM cipher.
 *  The state is created with setAesGcmKey and converted to
 *  the interface type cipherState.
 *  @param cipherKey Key forwarded to setAesGcmKey.
 *  @param initializationVector Initialization vector forwarded to setAesGcmKey.
 *  @return the AES GCM cipher state as cipherState.
 *)
const func cipherState: setCipherKey (AES_GCM, in string: cipherKey,
    in string: initializationVector) is
  return toInterface(setAesGcmKey(cipherKey, initializationVector));
(**
 *  Store the values that are authenticated together with the next message.
 *  The stored values are used by the next call of encode or decode,
 *  when the initial value of the MAC is computed.
 *  @param state AES GCM cipher state to be updated.
 *  @param recordTypeAndVersion Bytes appended to the sequence number in the MAC input.
 *  @param sequenceNumber Number that is written with 8 bytes into the MAC input.
 *)
const proc: initAead (inout aesGcmState: state, in string: recordTypeAndVersion,
    in integer: sequenceNumber) is func
  begin
    state.sequenceNumber := sequenceNumber;
    state.recordTypeAndVersion := recordTypeAndVersion;
  end func;
(**
 *  Obtain the MAC that the last encode or decode computed.
 *  @return the field computedMac of the cipher state.
 *)
const func string: getComputedMac (in aesGcmState: state) is func
  result
    var string: computedMac is "";
  begin
    computedMac := state.computedMac;
  end func;
(**
 *  Obtain the MAC that the last decode took from the end of its input.
 *  @return the field mac of the cipher state.
 *)
const func string: getMac (in aesGcmState: state) is func
  result
    var string: mac is "";
  begin
    mac := state.mac;
  end func;
(**
 *  Increment the counter stored in the characters 13 to 16 of ''nonce''.
 *  The four characters are treated as big-endian number (character 16
 *  is the least significant one). A counter where all four characters
 *  are '\255;' wraps around to four '\0;' characters.
 *  The characters 1 to 12 are never changed.
 *)
const proc: incrementGcmNonceCounter (inout string: nonce) is func
  local
    var integer: pos is 16;
    var boolean: carry is TRUE;
  begin
    while carry and pos >= 13 do
      if nonce[pos] = '\255;' then
        # Overflow of this digit: Set it to zero and carry on to the left.
        nonce @:= [pos] '\0;';
        decr(pos);
      else
        nonce @:= [pos] succ(nonce[pos]);
        carry := FALSE;
      end if;
    end while;
  end func;
(**
 *  Encrypt ''plainText'' in chunks of up to 16 characters.
 *  For every chunk the counter in state.nonce is incremented and the
 *  nonce is encrypted with AES. The encrypted nonce is XORed with the
 *  chunk to get the encrypted chunk. Every encrypted chunk is also
 *  XORed into state.computedMac, which is then multiplied with gcmMult.
 *  @param state Cipher state; nonce and computedMac are updated.
 *  @param plainText The data to be encrypted.
 *  @return the encrypted data (same length as ''plainText'').
 *)
const func string: aesGcmEncode (inout aesGcmState: state,
    in string: plainText) is func
  result
    var string: encoded is "";
  local
    var integer: remaining is 0;      # Number of characters not yet encrypted
    var integer: chunkStart is 1;     # Index of the first character of the chunk
    var integer: chunkSize is 0;      # Length of the current chunk (at most 16)
    var integer: pos is 0;
    var string: keyStream is "";      # AES encrypted nonce of the chunk
    var bin32: cipherChar is bin32(0);
  begin
    remaining := length(plainText);
    while remaining > 0 do
      if remaining < 16 then
        chunkSize := remaining;
      else
        chunkSize := 16;
      end if;
      incrementGcmNonceCounter(state.nonce);
      keyStream := encodeAesBlock(state.aesCipherState.encryptionSubKey,
                                  state.aesCipherState.rounds, state.nonce);
      for pos range 1 to chunkSize do
        cipherChar := bin32(plainText[pred(chunkStart) + pos]) >< bin32(keyStream[pos]);
        encoded &:= char(ord(cipherChar));
        # The MAC is computed over the encrypted data.
        state.computedMac @:= [pos] char(ord(cipherChar >< bin32(state.computedMac[pos])));
      end for;
      state.computedMac := gcmMult(state.computedMac, state.factorH);
      chunkStart +:= chunkSize;
      remaining -:= chunkSize;
    end while;
  end func;
(**
 *  Decrypt ''encoded'' in chunks of up to 16 characters.
 *  For every chunk the counter in state.nonce is incremented and the
 *  nonce is encrypted with AES. The encrypted nonce is XORed with the
 *  chunk to get the decrypted chunk. Every encrypted chunk is also
 *  XORed into state.computedMac, which is then multiplied with gcmMult.
 *  @param state Cipher state; nonce and computedMac are updated.
 *  @param encoded The encrypted data (without nonce and MAC).
 *  @return the decrypted data (same length as ''encoded'').
 *)
const func string: aesGcmDecode (inout aesGcmState: state,
    in string: encoded) is func
  result
    var string: plainText is "";
  local
    var integer: remaining is 0;      # Number of characters not yet decrypted
    var integer: chunkStart is 1;     # Index of the first character of the chunk
    var integer: chunkSize is 0;      # Length of the current chunk (at most 16)
    var integer: pos is 0;
    var string: keyStream is "";      # AES encrypted nonce of the chunk
    var bin32: cipherChar is bin32(0);
  begin
    remaining := length(encoded);
    while remaining > 0 do
      if remaining < 16 then
        chunkSize := remaining;
      else
        chunkSize := 16;
      end if;
      incrementGcmNonceCounter(state.nonce);
      keyStream := encodeAesBlock(state.aesCipherState.encryptionSubKey,
                                  state.aesCipherState.rounds, state.nonce);
      for pos range 1 to chunkSize do
        cipherChar := bin32(encoded[pred(chunkStart) + pos]);
        # The MAC is computed over the encrypted data.
        state.computedMac @:= [pos] char(ord(cipherChar >< bin32(state.computedMac[pos])));
        plainText &:= char(ord(cipherChar >< bin32(keyStream[pos])));
      end for;
      state.computedMac := gcmMult(state.computedMac, state.factorH);
      chunkStart +:= chunkSize;
      remaining -:= chunkSize;
    end while;
  end func;
(**
 *  Prepare the cipher state for the MAC computation of one message.
 *  The current nonce is encrypted with AES and saved in
 *  state.encodedNonce1 for finalizeComputedMac. Then state.computedMac
 *  is set to the sequence number (8 bytes), the record type and version,
 *  the plain text length (2 bytes) and zero padding. Finally
 *  state.computedMac is multiplied with gcmMult.
 *  @param state Cipher state; encodedNonce1 and computedMac are set.
 *  @param plainTextLength Length of the plain text in bytes.
 *)
const proc: initializeComputedMac (inout aesGcmState: state, in integer: plainTextLength) is func
  local
    var string: header is "";
  begin
    state.encodedNonce1 := encodeAesBlock(state.aesCipherState.encryptionSubKey,
                                          state.aesCipherState.rounds, state.nonce);
    header := bytes(state.sequenceNumber, UNSIGNED, BE, 8) &
              state.recordTypeAndVersion &
              bytes(plainTextLength, UNSIGNED, BE, 2) &
              "\0;" mult 3;
    state.computedMac := header;
    state.computedMac := gcmMult(state.computedMac, state.factorH);
  end func;
(**
 *  Finish the MAC computation of one message.
 *  A block with two lengths in bits (8 bytes each, big-endian) is XORed
 *  into state.computedMac: The length of the MAC initialization value
 *  (13 bytes) and the length of the plain text. After a multiplication
 *  with gcmMult the saved state.encodedNonce1 is XORed into the MAC.
 *  @param state Cipher state; computedMac receives the final MAC.
 *  @param plainTextLength Length of the plain text in bytes.
 *)
const proc: finalizeComputedMac (inout aesGcmState: state, in integer: plainTextLength) is func
  local
    const integer: macInitValueLength is 13;  # Bytes used to initialize the MAC
    var string: lengthBlock is "";
    var integer: pos is 0;
  begin
    lengthBlock := bytes(macInitValueLength * 8, UNSIGNED, BE, 8) &
                   bytes(plainTextLength * 8, UNSIGNED, BE, 8);

    # Mix the lengths into the MAC.
    for pos range 1 to 16 do
      state.computedMac @:= [pos] char(ord(bin32(lengthBlock[pos]) ><
                                           bin32(state.computedMac[pos])));
    end for;
    state.computedMac := gcmMult(state.computedMac, state.factorH);

    # Mask the MAC with the encrypted initial nonce.
    for pos range 1 to 16 do
      state.computedMac @:= [pos] char(ord(bin32(state.encodedNonce1[pos]) ><
                                           bin32(state.computedMac[pos])));
    end for;
  end func;
(**
 *  Encode a string with the AES GCM cipher.
 *  The result consists of 8 random nonce bytes, followed by the
 *  encrypted ''plainText'' and the 16 byte MAC. The nonce used for
 *  the encryption is the salt of the state, followed by the 8 random
 *  bytes and a 4 byte counter that starts with 1.
 *  @param state Cipher state; nonce, encodedNonce1 and computedMac are updated.
 *  @param plainText The data to be encrypted.
 *  @return the encoded message.
 *)
const func string: encode (inout aesGcmState: state, in string: plainText) is func
  result
    var string: encoded is "";
  local
    var string: explicitNonce is "";
  begin
    explicitNonce := bytes(rand(bin64), BE, 8);
    state.nonce := state.salt &
                   explicitNonce &
                   "\0;\0;\0;\1;";  # Block counter starts with 1
    initializeComputedMac(state, length(plainText));
    encoded := explicitNonce & aesGcmEncode(state, plainText);
    finalizeComputedMac(state, length(plainText));
    encoded &:= state.computedMac;
  end func;
(**
 *  Decode a string with the AES GCM cipher.
 *  The string ''encoded'' is expected to consist of 8 nonce bytes,
 *  followed by the encrypted data and a 16 byte MAC. The nonce used
 *  for the decryption is the salt of the state, followed by the
 *  8 nonce bytes and a 4 byte counter that starts with 1.
 *  After the call getComputedMac returns the MAC computed over the
 *  message and getMac returns the MAC found at the end of ''encoded''.
 *  The caller is responsible to compare the two values.
 *  A message with empty plain text (exactly 24 bytes) is accepted.
 *  If ''encoded'' is too short to contain nonce and MAC, the stored
 *  MAC is cleared. This way the values from a previously decoded
 *  message cannot make the comparison of the MACs succeed.
 *  @param state Cipher state; nonce, encodedNonce1, computedMac
 *         and mac are updated.
 *  @param encoded The message to be decoded.
 *  @return the decrypted data, or "" if ''encoded'' is too short.
 *)
const func string: decode (inout aesGcmState: state, in string: encoded) is func
  result
    var string: plainText is "";
  begin
    if length(encoded) >= 24 then
      state.nonce := state.salt &
                     encoded[.. 8] &  # Explicit nonce from the message
                     "\0;\0;\0;\1;";  # Block counter starts with 1
      initializeComputedMac(state, length(encoded) - 24);
      # With exactly 24 bytes the substring below is empty.
      plainText := aesGcmDecode(state, encoded[9 .. length(encoded) - 16]);
      finalizeComputedMac(state, length(plainText));
      state.mac := encoded[length(encoded) - 15 ..];
    else
      # No MAC present: Make sure that a stale MAC is not taken as valid.
      # The computed MAC has always 16 bytes, so an empty MAC never matches.
      state.mac := "";
    end if;
  end func;