From 067e6f3ffccacf78f6ddfe97b026f9bded9d990b Mon Sep 17 00:00:00 2001 From: saturneric Date: Thu, 12 Dec 2019 23:25:55 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E5=86=99RSA=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加RSA私钥解密函数 --- include/error.h | 2 ++ include/rsacpp.h | 24 +++++------------------- src/error.cpp | 16 ++++++++++++++++ src/rsacpp.cpp | 41 +++++++++++++++++++++++++++++++++++++++++ test/env.h | 2 ++ test/rsa_test.cpp | 15 ++++++++++++--- 6 files changed, 78 insertions(+), 22 deletions(-) diff --git a/include/error.h b/include/error.h index 8e930c4..d20088d 100644 --- a/include/error.h +++ b/include/error.h @@ -28,6 +28,8 @@ namespace Net { void printInfo(const string& info, string tag = ""); + void printInfoBuffer(const string& info, string tag = ""); + void printInfoFormal(const string& title, initializer_list body); } } diff --git a/include/rsacpp.h b/include/rsacpp.h index f10affb..2829a0e 100644 --- a/include/rsacpp.h +++ b/include/rsacpp.h @@ -9,6 +9,7 @@ #include #include +#include #include using namespace std; @@ -26,31 +27,16 @@ namespace Net { this->buffer_size = t.buffer_size; } - void generateKeyPair(){ - BIGNUM *e = BN_new(); -// 生成一个4bit质数 - BN_generate_prime_ex(e, 3, 1, nullptr, nullptr, nullptr); -// 生成一对秘钥 - RSA_generate_key_ex(key_pair, 2048, e, nullptr); - BN_free(e); - if(this->key_pair == nullptr) throw runtime_error("key pair generation failed"); - buffer_size = RSA_size(key_pair); - } + void generateKeyPair(); void checkKey(){ if(this->key_pair == nullptr) throw runtime_error("key pair is invalid"); RSA_check_key(this->key_pair); } - void publicKeyEncrypt(string &data, string &encrypted_data){ - if(this->key_pair == nullptr) throw runtime_error("key pair is invalid"); - if(data.size() >= this->getBufferSize()) throw runtime_error("string data is too long"); -// 预分配储存空间 - encrypted_data.resize(buffer_size); -// 使用公钥加密 - RSA_public_encrypt(data.size(), reinterpret_cast(data.c_str()), - reinterpret_cast(&data[0]), key_pair, RSA_NO_PADDING); - } + void publicKeyEncrypt(const string &data, string &encrypted_data); + + void privateKeyDecrypt(string &data, const string& encrypted_data); uint32_t getBufferSize() const { return this->buffer_size; diff --git a/src/error.cpp b/src/error.cpp index 7b741aa..d6ba7f1 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -38,5 +38,21 @@ namespace Net { } printf("----------------------------------<\n<<<\n\n"); } + + void printInfoBuffer(const string &info, string tag) { + printf("\n[DEBUG INFO (BUFFER)]\n"); + printf(">----------------------------------------------\n"); + uint8_t *p_i = (uint8_t *) &info[0]; + uint8_t *p_e = (uint8_t *) &info[info.size()-1]; + for(int c = 0;p_i < p_e; ++p_i, ++c){ + if(!(c % 16) && c) printf("\n"); + printf("%02x ",*p_i); + + } + printf("\n"); + printf("----------------------------------------------<\n"); + if(tag.size()) + printf("{ %s }\n\n",tag.data()); + } } } \ No newline at end of file diff --git a/src/rsacpp.cpp b/src/rsacpp.cpp index 6fee3db..d37dc8d 100644 --- a/src/rsacpp.cpp +++ b/src/rsacpp.cpp @@ -10,3 +10,44 @@ void Net::RSAKeyChain::getDefaultRSAMethod() { {"name", rsaMethod->name}, }); } + +void Net::RSAKeyChain::generateKeyPair() { + BIGNUM *e = BN_new(); +// 生成一个4bit质数 + BN_generate_prime_ex(e, 3, 1, nullptr, nullptr, nullptr); +// 生成一对秘钥 + RSA_generate_key_ex(key_pair, 2048, e, nullptr); + BN_free(e); + if(this->key_pair == nullptr) throw runtime_error("key pair generation failed"); + buffer_size = RSA_size(key_pair); +} + +void Net::RSAKeyChain::privateKeyDecrypt(string &data, const string &encrypted_data) { + if(this->key_pair == nullptr) throw runtime_error("key pair is invalid"); + assert(buffer_size > 0); + if(encrypted_data.size() != buffer_size) throw runtime_error("encrypt data's size is abnormal"); +// 使用私钥解密 + if(RSA_private_decrypt(encrypted_data.size(), reinterpret_cast(&encrypted_data[0]), + reinterpret_cast(&data[0]), + key_pair, + RSA_NO_PADDING) == -1) + throw runtime_error(ERR_error_string(ERR_get_error(), nullptr)); + +} + +void Net::RSAKeyChain::publicKeyEncrypt(const string &data, string &encrypted_data) { + if(this->key_pair == nullptr) throw runtime_error("key pair is invalid"); + assert(buffer_size > 0); + if(data.size() >= this->getBufferSize()) throw runtime_error("string data is too long"); +// 预分配储存空间 + encrypted_data.resize(buffer_size); +// 加密数据转移 + string tmp_data = data; + tmp_data.resize(buffer_size); +// 使用公钥加密 + if(RSA_public_encrypt(tmp_data.size(), reinterpret_cast(&tmp_data[0]), + reinterpret_cast(&encrypted_data[0]), + key_pair, + RSA_NO_PADDING) == -1) + throw runtime_error(ERR_error_string(ERR_get_error(), nullptr)); +} diff --git a/test/env.h b/test/env.h index 94e6ee2..45ae376 100644 --- a/test/env.h +++ b/test/env.h @@ -10,6 +10,8 @@ class GlobalTestEnv : public testing::Environment{ public: unique_ptr rsa{new Net::RSAKeyChain()}; + string rsa_test_data = "hello world"; + string rsa_encrypt_data; }; #endif //NET_ENV_H diff --git a/test/rsa_test.cpp b/test/rsa_test.cpp index 6aea923..0139ec3 100644 --- a/test/rsa_test.cpp +++ b/test/rsa_test.cpp @@ -14,15 +14,24 @@ GlobalTestEnv *_env; TEST(RSATest, init_test_1) { _env->rsa->getDefaultRSAMethod(); - error::printInfo(to_string(_env->rsa->getBufferSize()), string("Buffer Size")); + } TEST(RSATest, generate_test_1) { _env->rsa->generateKeyPair(); + error::printInfo(to_string(_env->rsa->getBufferSize()), string("Buffer Size")); } TEST(RSATest, pub_encrypt_test_1) { - RSAKeyChain rsa; - rsa.generateKeyPair(); + string encrypted_data; + _env->rsa->publicKeyEncrypt(_env->rsa_test_data, encrypted_data); + error::printInfoBuffer(encrypted_data, "Encrypted Data"); + _env->rsa_encrypt_data = encrypted_data; +} + +TEST(RSATest, prv_decrypt_test_1){ + string data; + _env->rsa->privateKeyDecrypt(data, _env->rsa_encrypt_data); + error::printInfo(data, "Decrypt Data"); } \ No newline at end of file