对RSA模块进行修改重构工作

引入OpenSSL库,对于其中的RSA及其相关的 C API进行简单的C++封装。
This commit is contained in:
saturneric 2019-12-14 23:23:09 +08:00
parent 067e6f3ffc
commit dcff6f7726
21 changed files with 379 additions and 210 deletions

View File

@ -7,20 +7,21 @@ include_directories(include/)
include_directories(utils/)
find_package(sqlite3 REQUIRED)
find_package(boost COMPONENTS
find_package(Boost COMPONENTS
program_options REQUIRED)
find_package(SQLiteCpp REQUIRED)
find_package(gtest REQUIRED)
find_package(OpenSSL REQUIRED)
message(STATUS SSL ${OPENSSL_LIBRARIES})
include(GoogleTest)
set(OPENSSL_INCLUDE /usr/local/opt/openssl/include)
set(OPENSSL_LIB /usr/local/opt/openssl/lib)
set(OPENSSL_LIBS /usr/local/lib/libssl.dylib /usr/local/lib/libcrypto.dylib)
set(GTEST_LIB /usr/local/lib/)
set(GTEST_LIBS gtest pthread dl)
include_directories(${Boost_INCLUDE_DIRS} ${OPENSSL_INCLUDE} ${SQLiteCpp_INCLUDE_DIRS} )
link_directories(${OPENSSL_LIB})
include_directories(${Boost_INCLUDE_DIRS} ${OPENSSL_INCLUDE_DIR} ${SQLiteCpp_INCLUDE_DIRS})
link_directories(${GTEST_LIB})
aux_source_directory(src SOURCE_ALL)
@ -33,10 +34,19 @@ gtest_add_tests(TARGET NetRSATest
TEST_SUFFIX .noArgs
TEST_LIST noArgsTests)
add_library(m_error STATIC src/error.cpp)
add_library(m_rsa STATIC src/rsacpp.cpp)
add_library(test_main test/test_main.cpp)
add_executable(NetOptsPceTest test/opts_processor_test.cpp)
target_link_libraries(NetRSATest ${GTEST_LIBS} test_main crypto ssl m_error m_rsa )
gtest_add_tests(TARGET NetOptsPceTest
TEST_SUFFIX .noArgs
TEST_LIST noArgsTests)
add_library(m_error STATIC src/error.cpp)
add_library(m_rsa STATIC src/rsa_cpp_binding.cpp)
add_library(test_main STATIC test/test_main.cpp)
target_link_libraries(NetRSATest test_main m_rsa m_error ${GTEST_LIBS} ${OPENSSL_LIBS})
target_link_libraries(NetOptsPceTest test_main m_error Boost::program_options)
set_tests_properties(${noArgsTests} PROPERTIES TIMEOUT 10)
# target_link_libraries(Net SQLiteCpp sqlite3 gtest ${Boost_LIBRARIES} pthread dl ssl )

5
include/bignumber.cpp Normal file
View File

@ -0,0 +1,5 @@
//
// Created by Eric Saturn on 2019/12/13.
//
#include "bignumber.h"

77
include/bignumber.h Normal file
View File

@ -0,0 +1,77 @@
//
// Created by Eric Saturn on 2019/12/13.
//
#ifndef NET_BIGNUMBER_H
#define NET_BIGNUMBER_H
#include <memory>
#include <openssl/bn.h>
using namespace std;
namespace Net {
// 对BIGNUM进行简单封装
class BigNumber {
public:
BigNumber() : bn(BN_new(), ::BN_free) {}
BigNumber(BIGNUM *t_bn) : bn(t_bn, ::BN_free) {}
// 临时取用
BIGNUM * get() const { return bn.get(); }
BIGNUM *getCopy() const {
BIGNUM *n_bn = BN_new();
BN_copy(n_bn, bn.get());
return n_bn;
}
// 获得智能指针
shared_ptr<BIGNUM> getSharedPtr(){
return bn;
}
// 获得一份拷贝
BigNumber copy(){
BigNumber n;
BN_copy(n.get(), bn.get());
return n;
}
void copyFrom(const BIGNUM * t){
BN_copy(bn.get(), t);
}
// 向该类移交BIGNUM结构的的控制权
void set(BIGNUM *t_bn) { bn = shared_ptr<BIGNUM>(t_bn, ::BN_free); }
// 该类移交所管辖的BIGNUM结构的控制权
BIGNUM *getControl() {
if(bn != nullptr){
BIGNUM *n_bn = bn.get();
bn = nullptr;
return n_bn;
}
else return nullptr;
}
// 得到BIGNUM的Hex字符串
string getDataHex() const {
void *hex_data_str = BN_bn2hex(bn.get());
string hex_string((const char *)hex_data_str);
OPENSSL_free(hex_data_str);
return hex_string;
}
string getDataHexHash() const{
return string();
}
private:
shared_ptr<BIGNUM> bn;
};
}
#endif //NET_BIGNUMBER_H

View File

@ -11,7 +11,7 @@
#include "type.h"
#include "cpart.h"
#include "sha256generator.h"
#include "sha256_cpp_binding.h"
#include "sql.h"
//计算模块管理对象间的依赖关系管理结构

View File

@ -11,7 +11,7 @@
#include "type.h"
#include "cpart.h"
#include "sha256generator.h"
#include "sha256_cpp_binding.h"
#include "sql.h"
class Proj;

View File

@ -14,6 +14,7 @@ public:
SSL_load_error_strings();
ERR_load_BIO_strings();
OpenSSL_add_all_algorithms();
RSA_meth_get_init();
}
};

47
include/opts_processor.h Normal file
View File

@ -0,0 +1,47 @@
//
// Created by Eric Saturn on 2019/12/14.
//
#ifndef NET_OPTS_PROCESSOR_H
#define NET_OPTS_PROCESSOR_H
#include "type.h"
#include <boost/program_options.hpp>
namespace po = boost::program_options;
class OptsProcessor {
public:
OptsProcessor(){
desc.add_options()
("help", "help list")
("init", "set up the environment")
("construct", "construct new project")
("update", "update changes done to a project")
("server", "start a server daemon")
("client", "start a client daemon")
("set", "change an option");
po::variables_map vm;
po::store(po::parse_command_line(ac, av, desc), vm);
po::notify(vm);
if (vm.count("help")) {
cout << desc << "\n";
return 1;
}
if (vm.count("compression")) {
cout << "Compression level was set to "
<< vm["compression"].as<int>() << ".\n";
} else {
cout << "Compression level was not set.\n";
}
}
private:
po::options_description desc("General Net Tools (0.0.1) By Saturn&Eric");
};
#endif //NET_OPTS_PROCESSOR_H

View File

@ -1,36 +0,0 @@
#ifndef __RSA_H__
#define __RSA_H__
#include <stdint.h>
// This is the header file for the library librsaencrypt.a
struct public_key_class{
long long modulus;
long long exponent;
};
struct private_key_class{
long long modulus;
long long exponent;
};
// This function generates public and private keys, then stores them in the structures you
// provide pointers to. The 3rd argument should be the text PRIME_SOURCE_FILE to have it use
// the location specified above in this header.
void rsa_gen_keys(struct public_key_class *pub, struct private_key_class *priv, string PRIME_SOURCE_FILE);
// This function will encrypt the data pointed to by message. It returns a pointer to a heap
// array containing the encrypted data, or NULL upon failure. This pointer should be freed when
// you are finished. The encrypted data will be 8 times as large as the original data.
uint64_t *rsa_encrypt(const unsigned char *message, const unsigned long message_size, const struct public_key_class *pub);
// This function will decrypt the data pointed to by message. It returns a pointer to a heap
// array containing the decrypted data, or NULL upon failure. This pointer should be freed when
// you are finished. The variable message_size is the size in bytes of the encrypted message.
// The decrypted data will be 1/8th the size of the encrypted data.
unsigned char *rsa_decrypt(const uint64_t *message, const unsigned long message_size, const struct private_key_class *pub);
#endif

135
include/rsa_cpp_binding.h Normal file
View File

@ -0,0 +1,135 @@
//
// Created by Eric Saturn on 2019/12/10.
//
#ifndef NET_RSA_CPP_BINDING_H
#define NET_RSA_CPP_BINDING_H
#include "error.h"
#include "bignumber.cpp"
#include <openssl/rsa.h>
#include <openssl/pem.h>
#include <openssl/err.h>
#include <memory>
using namespace std;
namespace Net {
class RSAPubKey{
public:
explicit RSAPubKey(const RSA *rsa){
const BIGNUM *n = nullptr, *e = nullptr, *d = nullptr;
RSA_get0_key(rsa, &n, &e, &d);
this->n.copyFrom(n);
this->e.copyFrom(e);
}
void printInfo(){
error::printInfoFormal("RSAPubKey Info", {
{"n", this->n.getDataHex()},
{"e", this->e.getDataHex()}
});
}
BigNumber n;
BigNumber e;
};
class RSAPrvKey{
public:
explicit RSAPrvKey(const RSA *rsa) {
const BIGNUM *n = nullptr, *e = nullptr, *d = nullptr;
const BIGNUM *p = nullptr, *q = nullptr;
RSA_get0_key(rsa, &n, &e, &d);
this->n.copyFrom(n);
this->e.copyFrom(e);
this->d.copyFrom(d);
RSA_get0_factors(rsa, &p, &q);
this->p.copyFrom(p);
this->q.copyFrom(q);
}
void printInfo(){
error::printInfoFormal("RSAPrvKey Info", {
{"n", this->n.getDataHex()},
{"e", this->e.getDataHex()},
{"d", this->d.getDataHex()},
{"p", this->p.getDataHex()},
{"q", this->q.getDataHex()}
});
}
BigNumber n, e, d;
BigNumber p, q;
};
class RSAKeyChain {
public:
RSAKeyChain() {
key_pair = RSA_new();
}
RSAKeyChain(RSAKeyChain &&t) noexcept {
this->key_pair = t.key_pair;
t.key_pair = nullptr;
this->buffer_size = t.buffer_size;
this->if_prv_key = t.if_prv_key;
this->if_pub_key = t.if_pub_key;
}
explicit RSAKeyChain(const RSAPubKey& pubKey){
key_pair = RSA_new();
RSA_set0_key(key_pair, pubKey.n.getCopy(), pubKey.e.getCopy(), nullptr);
this->if_pub_key = true;
}
explicit RSAKeyChain(const RSAPrvKey& prvKey){
key_pair = RSA_new();
RSA_set0_key(this->key_pair, prvKey.n.getCopy(), prvKey.e.getCopy(), prvKey.d.getCopy());
RSA_set0_factors(key_pair, prvKey.p.getCopy(), prvKey.q.getCopy());
this->if_prv_key = true;
}
void generateKeyPair();
// 检查私钥是否合法
bool checkKey();
void publicKeyEncrypt(const string &data, string &encrypted_data);
void privateKeyDecrypt(string &data, const string& encrypted_data);
uint32_t getBufferSize() const {
return this->buffer_size;
}
const RSA *getRSA(){
return key_pair;
}
RSAPubKey getPubKey(){
RSAPubKey pubKey(key_pair);
return pubKey;
}
RSAPrvKey getPrvKey(){
RSAPrvKey prvKey(key_pair);
return prvKey;
}
~RSAKeyChain() {
if (key_pair != nullptr) RSA_free(key_pair);
}
private:
RSA *key_pair{};
uint32_t buffer_size = 0;
bool if_prv_key = false, if_pub_key = false;
};
}
#endif //NET_RSA_CPP_BINDING_H

View File

@ -1,58 +0,0 @@
//
// Created by Eric Saturn on 2019/12/10.
//
#ifndef NET_RSACPP_H
#define NET_RSACPP_H
#include "error.h"
#include <openssl/rsa.h>
#include <openssl/pem.h>
#include <openssl/err.h>
#include <memory>
using namespace std;
namespace Net {
class RSAKeyChain {
public:
RSAKeyChain() {
key_pair = RSA_new();
}
RSAKeyChain(RSAKeyChain &&t) noexcept {
this->key_pair = t.key_pair;
t.key_pair = nullptr;
this->buffer_size = t.buffer_size;
}
void generateKeyPair();
void checkKey(){
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
RSA_check_key(this->key_pair);
}
void publicKeyEncrypt(const string &data, string &encrypted_data);
void privateKeyDecrypt(string &data, const string& encrypted_data);
uint32_t getBufferSize() const {
return this->buffer_size;
}
static void getDefaultRSAMethod();
~RSAKeyChain() {
if (key_pair != nullptr) RSA_free(key_pair);
}
private:
RSA *key_pair;
uint32_t buffer_size = 0;
};
}
#endif //NET_RSACPP_H

View File

@ -1,47 +0,0 @@
#ifndef SHA1_H
#define SHA1_H
/*
SHA-1 in C
By Steve Reid <steve@edmweb.com>
100% Public Domain
*/
#include "type.h"
#include "stdint.h"
typedef struct
{
uint32_t state[5];
uint32_t count[2];
unsigned char buffer[64];
} SHA1_CTX;
void SHA1Transform(
uint32_t state[5],
const unsigned char buffer[64]
);
void SHA1Init(
SHA1_CTX * context
);
void SHA1Update(
SHA1_CTX * context,
const unsigned char *data,
uint32_t len
);
void SHA1Final(
unsigned char digest[20],
SHA1_CTX * context
);
void SHA1(
char *hash_out,
const char *str,
int len);
void SHA1_Easy(string &hexresult, string &str);
#endif /* SHA1_H */

5
src/opts_processor.cpp Normal file
View File

@ -0,0 +1,5 @@
//
// Created by Eric Saturn on 2019/12/14.
//
#include "opts_processor.h"

62
src/rsa_cpp_binding.cpp Normal file
View File

@ -0,0 +1,62 @@
//
// Created by Eric Saturn on 2019/12/10.
//
#include "rsa_cpp_binding.h"
void Net::RSAKeyChain::generateKeyPair() {
BigNumber e;
// 生成一个4bit质数
BN_generate_prime_ex(e.get(), 3, 1, nullptr, nullptr, nullptr);
// 生成一对秘钥
RSA_generate_key_ex(key_pair, 2048, e.get(), nullptr);
if(this->key_pair == nullptr) throw runtime_error("key pair generation failed");
buffer_size = RSA_size(key_pair);
this->if_prv_key = true;
this->if_pub_key = true;
}
void Net::RSAKeyChain::privateKeyDecrypt(string &data, const string &encrypted_data) {
if(!this->if_prv_key) throw runtime_error("illegal call of privateKeyDecrypt");
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
assert(buffer_size > 0);
if(encrypted_data.size() != buffer_size) throw runtime_error("encrypt data's size is abnormal");
// 使用私钥解密
int decrypted_size = -1;
unique_ptr<unsigned char[]>p_buffer (new unsigned char[buffer_size]);
if((decrypted_size = RSA_private_decrypt(encrypted_data.size(),
reinterpret_cast<const unsigned char *>(&encrypted_data[0]),
p_buffer.get(),
key_pair,
RSA_PKCS1_OAEP_PADDING)) == -1)
throw runtime_error(ERR_error_string(ERR_get_error(), nullptr));
else data = string(reinterpret_cast<const char *>(p_buffer.get()));
}
void Net::RSAKeyChain::publicKeyEncrypt(const string &data, string &encrypted_data) {
if(!this->if_pub_key) throw runtime_error("illegal call of publicKeyEncrypt");
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
assert(buffer_size > 0);
if(data.size() >= buffer_size - 42) throw runtime_error("string data is too long");
// 预分配储存空间
encrypted_data.resize(buffer_size);
// 加密数据转移
string tmp_data = data;
tmp_data.resize(buffer_size - 42);
// 使用公钥加密
int encrypted_size = -1;
if((encrypted_size = RSA_public_encrypt(tmp_data.size(),
reinterpret_cast<const unsigned char *>(&tmp_data[0]),
reinterpret_cast<unsigned char *>(&encrypted_data[0]),
key_pair,
RSA_PKCS1_OAEP_PADDING)) == -1)
throw runtime_error(ERR_error_string(ERR_get_error(), nullptr));
}
bool Net::RSAKeyChain::checkKey() {
if(!this->if_prv_key) throw runtime_error("illegal call of checkKey");
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
int return_code = RSA_check_key(this->key_pair);
if(return_code == -1) throw runtime_error("error occur when rsa check key");
else return return_code == 1;
}

View File

@ -1,53 +0,0 @@
//
// Created by Eric Saturn on 2019/12/10.
//
#include "rsacpp.h"
void Net::RSAKeyChain::getDefaultRSAMethod() {
const RSA_METHOD *rsaMethod = RSA_get_default_method();
error::printInfoFormal("Default RSA Method", {
{"name", rsaMethod->name},
});
}
void Net::RSAKeyChain::generateKeyPair() {
BIGNUM *e = BN_new();
// 生成一个4bit质数
BN_generate_prime_ex(e, 3, 1, nullptr, nullptr, nullptr);
// 生成一对秘钥
RSA_generate_key_ex(key_pair, 2048, e, nullptr);
BN_free(e);
if(this->key_pair == nullptr) throw runtime_error("key pair generation failed");
buffer_size = RSA_size(key_pair);
}
void Net::RSAKeyChain::privateKeyDecrypt(string &data, const string &encrypted_data) {
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
assert(buffer_size > 0);
if(encrypted_data.size() != buffer_size) throw runtime_error("encrypt data's size is abnormal");
// 使用私钥解密
if(RSA_private_decrypt(encrypted_data.size(), reinterpret_cast<const unsigned char *>(&encrypted_data[0]),
reinterpret_cast<unsigned char *>(&data[0]),
key_pair,
RSA_NO_PADDING) == -1)
throw runtime_error(ERR_error_string(ERR_get_error(), nullptr));
}
void Net::RSAKeyChain::publicKeyEncrypt(const string &data, string &encrypted_data) {
if(this->key_pair == nullptr) throw runtime_error("key pair is invalid");
assert(buffer_size > 0);
if(data.size() >= this->getBufferSize()) throw runtime_error("string data is too long");
// 预分配储存空间
encrypted_data.resize(buffer_size);
// 加密数据转移
string tmp_data = data;
tmp_data.resize(buffer_size);
// 使用公钥加密
if(RSA_public_encrypt(tmp_data.size(), reinterpret_cast<const unsigned char *>(&tmp_data[0]),
reinterpret_cast<unsigned char *>(&encrypted_data[0]),
key_pair,
RSA_NO_PADDING) == -1)
throw runtime_error(ERR_error_string(ERR_get_error(), nullptr));
}

View File

@ -1,5 +1,5 @@
#include "type.h"
#include "sha256generator.h"
#include "sha256_cpp_binding.h"
void SHA256Generator::generate() {
unsigned char hash[SHA256_DIGEST_LENGTH];

View File

@ -5,13 +5,15 @@
#ifndef NET_ENV_H
#define NET_ENV_H
#include "rsacpp.h"
#include "rsa_cpp_binding.h"
class GlobalTestEnv : public testing::Environment{
public:
unique_ptr<Net::RSAKeyChain> rsa{new Net::RSAKeyChain()};
string rsa_test_data = "hello world";
string rsa_encrypt_data;
shared_ptr<Net::RSAPrvKey> prvKey = nullptr;
shared_ptr<Net::RSAPubKey> pubKey = nullptr;
};
#endif //NET_ENV_H

View File

@ -0,0 +1,4 @@
//
// Created by Eric Saturn on 2019/12/14.
//

View File

@ -3,7 +3,7 @@
//
#include <gtest/gtest.h>
#include <rsacpp.h>
#include <rsa_cpp_binding.h>
#include "env.h"
@ -13,8 +13,6 @@ using namespace std;
GlobalTestEnv *_env;
TEST(RSATest, init_test_1) {
_env->rsa->getDefaultRSAMethod();
}
@ -34,4 +32,21 @@ TEST(RSATest, prv_decrypt_test_1){
string data;
_env->rsa->privateKeyDecrypt(data, _env->rsa_encrypt_data);
error::printInfo(data, "Decrypt Data");
ASSERT_EQ(data, _env->rsa_test_data);
}
TEST(RSATest, pub_key_get_test_1){
_env->pubKey = shared_ptr<RSAPubKey>(new RSAPubKey(_env->rsa->getRSA()));
}
TEST(RSATest, prv_key_get_test_1){
_env->prvKey = shared_ptr<RSAPrvKey>(new RSAPrvKey(_env->rsa->getRSA()));
ASSERT_EQ(_env->rsa->checkKey(), true);
_env->prvKey->printInfo();
}
TEST(RSATest, prv_key_build_key_chain){
RSAKeyChain keyChain(*_env->prvKey);
ASSERT_EQ(keyChain.checkKey(), true);
}