Add aes cbc module

This commit is contained in:
eric 2020-07-07 03:28:29 +08:00
parent 7b01638c9f
commit 48b9f1b65e
12 changed files with 201 additions and 15 deletions

View File

@ -15,7 +15,7 @@ using std::pair;
//提示信息打印类函数 //提示信息打印类函数
namespace Net { namespace Net {
namespace printTools { namespace PrintTools {
using FormalItem = pair<string, string>; using FormalItem = pair<string, string>;

View File

@ -0,0 +1,50 @@
//
// Created by 胡宇 on 2020/7/7.
//
#ifndef NET_AES_CBC_ENCRYPTOR_H
#define NET_AES_CBC_ENCRYPTOR_H
#include <cassert>
#include <cstdint>
#include <string>
#include <openssl/aes.h>
#include <openssl/evp.h>
#include "debug_tools/print_tools.h"
#include "random_generator.h"
namespace Net {
class AESCBCEncryptor {
public:
AESCBCEncryptor() {
generate_random_key_data();
aes_init(key_data);
}
string getKeyData() const{
return key_data;
}
void encrypt(const std::string &data, std::string &encrypted_data);
void decrypt(std::string &data, const std::string &encrypt_data);
private:
const int nrounds = 8;
uint8_t key[32], iv[32];
EVP_CIPHER_CTX *e_ctx = EVP_CIPHER_CTX_new();
std::string key_data;
void generate_random_key_data();
void aes_init(std::string &key_data);
};
}
#endif //NET_AES_CBC_ENCRYPTOR_H

View File

@ -0,0 +1,30 @@
//
// Created by 胡宇 on 2020/7/7.
//
#ifndef NET_RANDOM_GENERATOR_H
#define NET_RANDOM_GENERATOR_H
#include <boost/random.hpp>
#include <boost/random/random_device.hpp>
namespace Net{
namespace Rand{
// 范围均匀分布无符号32位整数
class UniformUInt {
public:
UniformUInt(uint32_t min, uint32_t max) : uniformInt(min, max){
}
int generate() const;
private:
boost::uniform_int<uint32_t> uniformInt;
};
}
}
#endif //NET_RANDOM_GENERATOR_H

View File

@ -33,7 +33,7 @@ namespace Net {
} }
void printInfo(){ void printInfo(){
printTools::printInfoFormal("RSAPubKey Info", { PrintTools::printInfoFormal("RSAPubKey Info", {
{"n", this->n.getDataHex()}, {"n", this->n.getDataHex()},
{"e", this->e.getDataHex()} {"e", this->e.getDataHex()}
}); });
@ -62,7 +62,7 @@ namespace Net {
} }
void printInfo() const{ void printInfo() const{
printTools::printInfoFormal("RSAPrvKey Info", { PrintTools::printInfoFormal("RSAPrvKey Info", {
{"n", this->n.getDataHex()}, {"n", this->n.getDataHex()},
{"e", this->e.getDataHex()}, {"e", this->e.getDataHex()},
{"d", this->d.getDataHex()}, {"d", this->d.getDataHex()},
@ -102,13 +102,16 @@ namespace Net {
this->if_prv_key = true; this->if_prv_key = true;
} }
// 生成一对公私钥
void generateKeyPair(); void generateKeyPair();
// 检查私钥是否合法 // 检查私钥是否合法
bool checkKey(); bool checkKey();
// 公钥加密
void publicKeyEncrypt(const string &data, string &encrypted_data); void publicKeyEncrypt(const string &data, string &encrypted_data);
// 私钥解密
void privateKeyDecrypt(string &data, const string& encrypted_data); void privateKeyDecrypt(string &data, const string& encrypted_data);
uint32_t getBufferSize() const { uint32_t getBufferSize() const {

View File

@ -32,6 +32,7 @@ public:
if(!if_generate) generate(); if(!if_generate) generate();
return this->sha256_data; return this->sha256_data;
} }
private: private:
bool if_generate = false; bool if_generate = false;
string raw_data; string raw_data;

View File

@ -7,7 +7,7 @@
using std::string; using std::string;
namespace Net { namespace Net {
namespace printTools { namespace PrintTools {
void printError(const string &error_info) { void printError(const string &error_info) {
printf("\033[31mError: %s\033[0m\n", error_info.data()); printf("\033[31mError: %s\033[0m\n", error_info.data());
} }

View File

@ -1,3 +1,3 @@
add_library(utils STATIC rsa_key_chain.cpp) add_library(utils STATIC rsa_key_chain.cpp aes_cbc_encryptor.cpp random_generator.cpp)
target_link_libraries(utils debugTools ssl crypto) target_link_libraries(utils debugTools ssl crypto)

View File

@ -0,0 +1,68 @@
//
// Created by 胡宇 on 2020/7/7.
//
#include "utils/aes_cbc_encryptor.h"
void Net::AESCBCEncryptor::encrypt(const string &data, string &encrypted_data) {
int c_len = data.length() + AES_BLOCK_SIZE, f_len = 0;
auto *encrypt_buffer = reinterpret_cast<uint8_t *>(malloc(c_len));
EVP_EncryptInit_ex(e_ctx, nullptr, nullptr, nullptr, nullptr);
EVP_EncryptUpdate(e_ctx, encrypt_buffer, &c_len,
reinterpret_cast<const unsigned char *>(data.data()), data.length());
EVP_EncryptFinal_ex(e_ctx, encrypt_buffer + c_len, &f_len);
int len = c_len + f_len;
if(!encrypted_data.empty()) encrypted_data.clear();
encrypted_data.append(reinterpret_cast<const char *>(encrypt_buffer), len);
free(encrypt_buffer);
}
void Net::AESCBCEncryptor::decrypt(string &data, const string &encrypt_data) {
int p_len = encrypt_data.length(), f_len = 0;
auto *plain_buffer = static_cast<uint8_t *>(malloc(p_len));
EVP_DecryptInit_ex(e_ctx, nullptr, nullptr, nullptr, nullptr);
EVP_DecryptUpdate(e_ctx, plain_buffer, &p_len,
reinterpret_cast<const unsigned char *>(encrypt_data.data()), encrypt_data.length());
EVP_DecryptFinal_ex(e_ctx, plain_buffer + p_len, &f_len);
int len = p_len + f_len;
if(!data.empty()) data.clear();
data.append(reinterpret_cast<const char *>(plain_buffer), len);
free(plain_buffer);
}
void Net::AESCBCEncryptor::generate_random_key_data() {
Rand::UniformUInt rand(0, UINT32_MAX);
uint32_t p_data[8];
for(unsigned int & i : p_data){
i = rand.generate();
}
key_data.append(reinterpret_cast<const char *>(p_data), 32);
}
void Net::AESCBCEncryptor::aes_init(string &key_data) {
int i = EVP_BytesToKey(EVP_aes_256_cbc(), EVP_sha256(), nullptr,
reinterpret_cast<const unsigned char *>(key_data.c_str()), key_data.length(),
nrounds, key, iv);
if (i != 32) {
throw std::runtime_error("key data must equal 256 bits.");
}
EVP_CIPHER_CTX_init(e_ctx);
EVP_EncryptInit_ex(e_ctx, EVP_aes_256_cbc(), nullptr, key, iv);
}

View File

@ -0,0 +1,11 @@
//
// Created by 胡宇 on 2020/7/7.
//
#include "utils/random_generator.h"
boost::random::mt19937 rand_seed;
int Net::Rand::UniformUInt::generate() const {
return uniformInt(rand_seed);
}

View File

@ -24,13 +24,12 @@ void Net::RSAKeyChain::privateKeyDecrypt(string &data, const string &encrypted_d
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid"); if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
if(encrypted_data.size() != buffer_size) throw runtime_error("encrypt data's size is abnormal"); if(encrypted_data.size() != buffer_size) throw runtime_error("encrypt data's size is abnormal");
// 使用私钥解密 // 使用私钥解密
int decrypted_size = -1;
unique_ptr<unsigned char[]>p_buffer (new unsigned char[buffer_size]); unique_ptr<unsigned char[]>p_buffer (new unsigned char[buffer_size]);
if((decrypted_size = RSA_private_decrypt(encrypted_data.size(), if(RSA_private_decrypt(encrypted_data.size(),
reinterpret_cast<const unsigned char *>(&encrypted_data[0]), reinterpret_cast<const unsigned char *>(&encrypted_data[0]),
p_buffer.get(), p_buffer.get(),
key_pair, key_pair,
RSA_PKCS1_OAEP_PADDING)) == -1) RSA_PKCS1_OAEP_PADDING) == -1)
throw runtime_error(ERR_error_string(ERR_get_error(), nullptr)); throw runtime_error(ERR_error_string(ERR_get_error(), nullptr));
else data = string(reinterpret_cast<const char *>(p_buffer.get())); else data = string(reinterpret_cast<const char *>(p_buffer.get()));
} }
@ -46,12 +45,11 @@ void Net::RSAKeyChain::publicKeyEncrypt(const string &data, string &encrypted_da
string tmp_data = data; string tmp_data = data;
tmp_data.resize(buffer_size - 42); tmp_data.resize(buffer_size - 42);
// 使用公钥加密 // 使用公钥加密
int encrypted_size = -1; if(RSA_public_encrypt(tmp_data.size(),
if((encrypted_size = RSA_public_encrypt(tmp_data.size(),
reinterpret_cast<const unsigned char *>(&tmp_data[0]), reinterpret_cast<const unsigned char *>(&tmp_data[0]),
reinterpret_cast<unsigned char *>(&encrypted_data[0]), reinterpret_cast<unsigned char *>(&encrypted_data[0]),
key_pair, key_pair,
RSA_PKCS1_OAEP_PADDING)) == -1) RSA_PKCS1_OAEP_PADDING) == -1)
throw runtime_error(ERR_error_string(ERR_get_error(), nullptr)); throw runtime_error(ERR_error_string(ERR_get_error(), nullptr));
} }
@ -59,6 +57,6 @@ bool Net::RSAKeyChain::checkKey() {
if(!this->if_prv_key) throw runtime_error("illegal call of checkKey"); if(!this->if_prv_key) throw runtime_error("illegal call of checkKey");
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid"); if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
int return_code = RSA_check_key(this->key_pair); int return_code = RSA_check_key(this->key_pair);
if(return_code == -1) throw runtime_error("printTools occur when rsa check key"); if(return_code == -1) throw runtime_error("PrintTools occur when rsa check key");
else return return_code == 1; else return return_code == 1;
} }

View File

@ -0,0 +1,25 @@
//
// Created by 胡宇 on 2020/7/7.
//
#include <gtest/gtest.h>
#include "utils/aes_cbc_encryptor.h"
using namespace Net;
TEST(AES_Test, base_test_1){
AESCBCEncryptor encryptor;
PrintTools::printInfoBuffer(encryptor.getKeyData(), "Key Data");
std::string data, encrypt_data;
encryptor.encrypt("Hello World", encrypt_data);
PrintTools::printInfoBuffer(encrypt_data, "Encrypt Data");
encryptor.decrypt(data, encrypt_data);
ASSERT_EQ(data, std::string("Hello World"));
}

View File

@ -22,20 +22,20 @@ TEST(RSATest, init_test_1) {
TEST(RSATest, generate_test_1) { TEST(RSATest, generate_test_1) {
env->rsa->generateKeyPair(); env->rsa->generateKeyPair();
string data = std::to_string(env->rsa->getBufferSize()); string data = std::to_string(env->rsa->getBufferSize());
printTools::printInfo(data, "Buffer Size"); PrintTools::printInfo(data, "Buffer Size");
} }
TEST(RSATest, pub_encrypt_test_1) { TEST(RSATest, pub_encrypt_test_1) {
string encrypted_data; string encrypted_data;
env->rsa->publicKeyEncrypt(env->rsa_test_data, encrypted_data); env->rsa->publicKeyEncrypt(env->rsa_test_data, encrypted_data);
printTools::printInfoBuffer(encrypted_data, "Encrypted Data"); PrintTools::printInfoBuffer(encrypted_data, "Encrypted Data");
env->rsa_encrypt_data = encrypted_data; env->rsa_encrypt_data = encrypted_data;
} }
TEST(RSATest, prv_decrypt_test_1){ TEST(RSATest, prv_decrypt_test_1){
string data; string data;
env->rsa->privateKeyDecrypt(data, env->rsa_encrypt_data); env->rsa->privateKeyDecrypt(data, env->rsa_encrypt_data);
printTools::printInfo(data, "Decrypt Data"); PrintTools::printInfo(data, "Decrypt Data");
ASSERT_EQ(data, env->rsa_test_data); ASSERT_EQ(data, env->rsa_test_data);
} }