改写RSA模块
添加RSA私钥解密函数
This commit is contained in:
parent
157333a6d9
commit
067e6f3ffc
@ -28,6 +28,8 @@ namespace Net {
|
||||
|
||||
void printInfo(const string& info, string tag = "");
|
||||
|
||||
void printInfoBuffer(const string& info, string tag = "");
|
||||
|
||||
void printInfoFormal(const string& title, initializer_list<FormalItem> body);
|
||||
}
|
||||
}
|
||||
|
@ -9,6 +9,7 @@
|
||||
|
||||
#include <openssl/rsa.h>
|
||||
#include <openssl/pem.h>
|
||||
#include <openssl/err.h>
|
||||
#include <memory>
|
||||
|
||||
using namespace std;
|
||||
@ -26,31 +27,16 @@ namespace Net {
|
||||
this->buffer_size = t.buffer_size;
|
||||
}
|
||||
|
||||
void generateKeyPair(){
|
||||
BIGNUM *e = BN_new();
|
||||
// 生成一个4bit质数
|
||||
BN_generate_prime_ex(e, 3, 1, nullptr, nullptr, nullptr);
|
||||
// 生成一对秘钥
|
||||
RSA_generate_key_ex(key_pair, 2048, e, nullptr);
|
||||
BN_free(e);
|
||||
if(this->key_pair == nullptr) throw runtime_error("key pair generation failed");
|
||||
buffer_size = RSA_size(key_pair);
|
||||
}
|
||||
void generateKeyPair();
|
||||
|
||||
void checkKey(){
|
||||
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
|
||||
RSA_check_key(this->key_pair);
|
||||
}
|
||||
|
||||
void publicKeyEncrypt(string &data, string &encrypted_data){
|
||||
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
|
||||
if(data.size() >= this->getBufferSize()) throw runtime_error("string data is too long");
|
||||
// 预分配储存空间
|
||||
encrypted_data.resize(buffer_size);
|
||||
// 使用公钥加密
|
||||
RSA_public_encrypt(data.size(), reinterpret_cast<const unsigned char *>(data.c_str()),
|
||||
reinterpret_cast<unsigned char *>(&data[0]), key_pair, RSA_NO_PADDING);
|
||||
}
|
||||
void publicKeyEncrypt(const string &data, string &encrypted_data);
|
||||
|
||||
void privateKeyDecrypt(string &data, const string& encrypted_data);
|
||||
|
||||
uint32_t getBufferSize() const {
|
||||
return this->buffer_size;
|
||||
|
@ -38,5 +38,21 @@ namespace Net {
|
||||
}
|
||||
printf("----------------------------------<\n<<<\n\n");
|
||||
}
|
||||
|
||||
void printInfoBuffer(const string &info, string tag) {
|
||||
printf("\n[DEBUG INFO (BUFFER)]\n");
|
||||
printf(">----------------------------------------------\n");
|
||||
uint8_t *p_i = (uint8_t *) &info[0];
|
||||
uint8_t *p_e = (uint8_t *) &info[info.size()-1];
|
||||
for(int c = 0;p_i < p_e; ++p_i, ++c){
|
||||
if(!(c % 16) && c) printf("\n");
|
||||
printf("%02x ",*p_i);
|
||||
|
||||
}
|
||||
printf("\n");
|
||||
printf("----------------------------------------------<\n");
|
||||
if(tag.size())
|
||||
printf("{ %s }\n\n",tag.data());
|
||||
}
|
||||
}
|
||||
}
|
@ -10,3 +10,44 @@ void Net::RSAKeyChain::getDefaultRSAMethod() {
|
||||
{"name", rsaMethod->name},
|
||||
});
|
||||
}
|
||||
|
||||
void Net::RSAKeyChain::generateKeyPair() {
|
||||
BIGNUM *e = BN_new();
|
||||
// 生成一个4bit质数
|
||||
BN_generate_prime_ex(e, 3, 1, nullptr, nullptr, nullptr);
|
||||
// 生成一对秘钥
|
||||
RSA_generate_key_ex(key_pair, 2048, e, nullptr);
|
||||
BN_free(e);
|
||||
if(this->key_pair == nullptr) throw runtime_error("key pair generation failed");
|
||||
buffer_size = RSA_size(key_pair);
|
||||
}
|
||||
|
||||
void Net::RSAKeyChain::privateKeyDecrypt(string &data, const string &encrypted_data) {
|
||||
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
|
||||
assert(buffer_size > 0);
|
||||
if(encrypted_data.size() != buffer_size) throw runtime_error("encrypt data's size is abnormal");
|
||||
// 使用私钥解密
|
||||
if(RSA_private_decrypt(encrypted_data.size(), reinterpret_cast<const unsigned char *>(&encrypted_data[0]),
|
||||
reinterpret_cast<unsigned char *>(&data[0]),
|
||||
key_pair,
|
||||
RSA_NO_PADDING) == -1)
|
||||
throw runtime_error(ERR_error_string(ERR_get_error(), nullptr));
|
||||
|
||||
}
|
||||
|
||||
void Net::RSAKeyChain::publicKeyEncrypt(const string &data, string &encrypted_data) {
|
||||
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
|
||||
assert(buffer_size > 0);
|
||||
if(data.size() >= this->getBufferSize()) throw runtime_error("string data is too long");
|
||||
// 预分配储存空间
|
||||
encrypted_data.resize(buffer_size);
|
||||
// 加密数据转移
|
||||
string tmp_data = data;
|
||||
tmp_data.resize(buffer_size);
|
||||
// 使用公钥加密
|
||||
if(RSA_public_encrypt(tmp_data.size(), reinterpret_cast<const unsigned char *>(&tmp_data[0]),
|
||||
reinterpret_cast<unsigned char *>(&encrypted_data[0]),
|
||||
key_pair,
|
||||
RSA_NO_PADDING) == -1)
|
||||
throw runtime_error(ERR_error_string(ERR_get_error(), nullptr));
|
||||
}
|
||||
|
@ -10,6 +10,8 @@
|
||||
class GlobalTestEnv : public testing::Environment{
|
||||
public:
|
||||
unique_ptr<Net::RSAKeyChain> rsa{new Net::RSAKeyChain()};
|
||||
string rsa_test_data = "hello world";
|
||||
string rsa_encrypt_data;
|
||||
};
|
||||
|
||||
#endif //NET_ENV_H
|
||||
|
@ -14,15 +14,24 @@ GlobalTestEnv *_env;
|
||||
|
||||
TEST(RSATest, init_test_1) {
|
||||
_env->rsa->getDefaultRSAMethod();
|
||||
error::printInfo(to_string(_env->rsa->getBufferSize()), string("Buffer Size"));
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST(RSATest, generate_test_1) {
|
||||
_env->rsa->generateKeyPair();
|
||||
error::printInfo(to_string(_env->rsa->getBufferSize()), string("Buffer Size"));
|
||||
}
|
||||
|
||||
TEST(RSATest, pub_encrypt_test_1) {
|
||||
RSAKeyChain rsa;
|
||||
rsa.generateKeyPair();
|
||||
string encrypted_data;
|
||||
_env->rsa->publicKeyEncrypt(_env->rsa_test_data, encrypted_data);
|
||||
error::printInfoBuffer(encrypted_data, "Encrypted Data");
|
||||
_env->rsa_encrypt_data = encrypted_data;
|
||||
}
|
||||
|
||||
TEST(RSATest, prv_decrypt_test_1){
|
||||
string data;
|
||||
_env->rsa->privateKeyDecrypt(data, _env->rsa_encrypt_data);
|
||||
error::printInfo(data, "Decrypt Data");
|
||||
}
|
Loading…
Reference in New Issue
Block a user