改写RSA模块

添加RSA私钥解密函数
This commit is contained in:
saturneric 2019-12-12 23:25:55 +08:00
parent 157333a6d9
commit 067e6f3ffc
6 changed files with 78 additions and 22 deletions

View File

@ -28,6 +28,8 @@ namespace Net {
void printInfo(const string& info, string tag = ""); void printInfo(const string& info, string tag = "");
void printInfoBuffer(const string& info, string tag = "");
void printInfoFormal(const string& title, initializer_list<FormalItem> body); void printInfoFormal(const string& title, initializer_list<FormalItem> body);
} }
} }

View File

@ -9,6 +9,7 @@
#include <openssl/rsa.h> #include <openssl/rsa.h>
#include <openssl/pem.h> #include <openssl/pem.h>
#include <openssl/err.h>
#include <memory> #include <memory>
using namespace std; using namespace std;
@ -26,31 +27,16 @@ namespace Net {
this->buffer_size = t.buffer_size; this->buffer_size = t.buffer_size;
} }
void generateKeyPair(){ void generateKeyPair();
BIGNUM *e = BN_new();
// 生成一个4bit质数
BN_generate_prime_ex(e, 3, 1, nullptr, nullptr, nullptr);
// 生成一对秘钥
RSA_generate_key_ex(key_pair, 2048, e, nullptr);
BN_free(e);
if(this->key_pair == nullptr) throw runtime_error("key pair generation failed");
buffer_size = RSA_size(key_pair);
}
void checkKey(){ void checkKey(){
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid"); if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
RSA_check_key(this->key_pair); RSA_check_key(this->key_pair);
} }
void publicKeyEncrypt(string &data, string &encrypted_data){ void publicKeyEncrypt(const string &data, string &encrypted_data);
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
if(data.size() >= this->getBufferSize()) throw runtime_error("string data is too long"); void privateKeyDecrypt(string &data, const string& encrypted_data);
// 预分配储存空间
encrypted_data.resize(buffer_size);
// 使用公钥加密
RSA_public_encrypt(data.size(), reinterpret_cast<const unsigned char *>(data.c_str()),
reinterpret_cast<unsigned char *>(&data[0]), key_pair, RSA_NO_PADDING);
}
uint32_t getBufferSize() const { uint32_t getBufferSize() const {
return this->buffer_size; return this->buffer_size;

View File

@ -38,5 +38,21 @@ namespace Net {
} }
printf("----------------------------------<\n<<<\n\n"); printf("----------------------------------<\n<<<\n\n");
} }
void printInfoBuffer(const string &info, string tag) {
printf("\n[DEBUG INFO (BUFFER)]\n");
printf(">----------------------------------------------\n");
uint8_t *p_i = (uint8_t *) &info[0];
uint8_t *p_e = (uint8_t *) &info[info.size()-1];
for(int c = 0;p_i < p_e; ++p_i, ++c){
if(!(c % 16) && c) printf("\n");
printf("%02x ",*p_i);
}
printf("\n");
printf("----------------------------------------------<\n");
if(tag.size())
printf("{ %s }\n\n",tag.data());
}
} }
} }

View File

@ -10,3 +10,44 @@ void Net::RSAKeyChain::getDefaultRSAMethod() {
{"name", rsaMethod->name}, {"name", rsaMethod->name},
}); });
} }
void Net::RSAKeyChain::generateKeyPair() {
BIGNUM *e = BN_new();
// 生成一个4bit质数
BN_generate_prime_ex(e, 3, 1, nullptr, nullptr, nullptr);
// 生成一对秘钥
RSA_generate_key_ex(key_pair, 2048, e, nullptr);
BN_free(e);
if(this->key_pair == nullptr) throw runtime_error("key pair generation failed");
buffer_size = RSA_size(key_pair);
}
void Net::RSAKeyChain::privateKeyDecrypt(string &data, const string &encrypted_data) {
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
assert(buffer_size > 0);
if(encrypted_data.size() != buffer_size) throw runtime_error("encrypt data's size is abnormal");
// 使用私钥解密
if(RSA_private_decrypt(encrypted_data.size(), reinterpret_cast<const unsigned char *>(&encrypted_data[0]),
reinterpret_cast<unsigned char *>(&data[0]),
key_pair,
RSA_NO_PADDING) == -1)
throw runtime_error(ERR_error_string(ERR_get_error(), nullptr));
}
void Net::RSAKeyChain::publicKeyEncrypt(const string &data, string &encrypted_data) {
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
assert(buffer_size > 0);
if(data.size() >= this->getBufferSize()) throw runtime_error("string data is too long");
// 预分配储存空间
encrypted_data.resize(buffer_size);
// 加密数据转移
string tmp_data = data;
tmp_data.resize(buffer_size);
// 使用公钥加密
if(RSA_public_encrypt(tmp_data.size(), reinterpret_cast<const unsigned char *>(&tmp_data[0]),
reinterpret_cast<unsigned char *>(&encrypted_data[0]),
key_pair,
RSA_NO_PADDING) == -1)
throw runtime_error(ERR_error_string(ERR_get_error(), nullptr));
}

View File

@ -10,6 +10,8 @@
class GlobalTestEnv : public testing::Environment{ class GlobalTestEnv : public testing::Environment{
public: public:
unique_ptr<Net::RSAKeyChain> rsa{new Net::RSAKeyChain()}; unique_ptr<Net::RSAKeyChain> rsa{new Net::RSAKeyChain()};
string rsa_test_data = "hello world";
string rsa_encrypt_data;
}; };
#endif //NET_ENV_H #endif //NET_ENV_H

View File

@ -14,15 +14,24 @@ GlobalTestEnv *_env;
TEST(RSATest, init_test_1) { TEST(RSATest, init_test_1) {
_env->rsa->getDefaultRSAMethod(); _env->rsa->getDefaultRSAMethod();
error::printInfo(to_string(_env->rsa->getBufferSize()), string("Buffer Size"));
} }
TEST(RSATest, generate_test_1) { TEST(RSATest, generate_test_1) {
_env->rsa->generateKeyPair(); _env->rsa->generateKeyPair();
error::printInfo(to_string(_env->rsa->getBufferSize()), string("Buffer Size"));
} }
TEST(RSATest, pub_encrypt_test_1) { TEST(RSATest, pub_encrypt_test_1) {
RSAKeyChain rsa; string encrypted_data;
rsa.generateKeyPair(); _env->rsa->publicKeyEncrypt(_env->rsa_test_data, encrypted_data);
error::printInfoBuffer(encrypted_data, "Encrypted Data");
_env->rsa_encrypt_data = encrypted_data;
}
TEST(RSATest, prv_decrypt_test_1){
string data;
_env->rsa->privateKeyDecrypt(data, _env->rsa_encrypt_data);
error::printInfo(data, "Decrypt Data");
} }