diff --git a/include/instruct.h b/include/instruct.h index 28a92b2..1936d7d 100644 --- a/include/instruct.h +++ b/include/instruct.h @@ -22,11 +22,12 @@ #include "rng.hpp" - +//提示信息打印类函数 namespace error{ void printError(string error_info); void printWarning(string warning_info); void printSuccess(string succes_info); + void printRed(string red_info); } struct instructions{ diff --git a/src/aes.cpp b/src/aes.cpp index 18e3f38..0e25f6c 100755 --- a/src/aes.cpp +++ b/src/aes.cpp @@ -356,10 +356,10 @@ static void InvMixColumns(state_t* state) c = (*state)[i][2]; d = (*state)[i][3]; - (*state)[i][0] = Multiply(a, 0x0e) ^ Multiply(b, 0x0b) ^ Multiply(c, 0x0d) ^ Multiply(d, 0x09); - (*state)[i][1] = Multiply(a, 0x09) ^ Multiply(b, 0x0e) ^ Multiply(c, 0x0b) ^ Multiply(d, 0x0d); - (*state)[i][2] = Multiply(a, 0x0d) ^ Multiply(b, 0x09) ^ Multiply(c, 0x0e) ^ Multiply(d, 0x0b); - (*state)[i][3] = Multiply(a, 0x0b) ^ Multiply(b, 0x0d) ^ Multiply(c, 0x09) ^ Multiply(d, 0x0e); + (*state)[i][0] = (uint8_t)(Multiply(a, 0x0e) ^ Multiply(b, 0x0b) ^ Multiply(c, 0x0d) ^ Multiply(d, 0x09)); + (*state)[i][1] = (uint8_t)(Multiply(a, 0x09) ^ Multiply(b, 0x0e) ^ Multiply(c, 0x0b) ^ Multiply(d, 0x0d)); + (*state)[i][2] = (uint8_t)(Multiply(a, 0x0d) ^ Multiply(b, 0x09) ^ Multiply(c, 0x0e) ^ Multiply(d, 0x0b)); + (*state)[i][3] = (uint8_t)(Multiply(a, 0x0b) ^ Multiply(b, 0x0d) ^ Multiply(c, 0x09) ^ Multiply(d, 0x0e)); } } @@ -558,7 +558,7 @@ void AES_CTR_xcrypt_buffer(struct AES_ctx* ctx, uint8_t* buf, uint32_t length) ctx->Iv[bi] = 0; continue; } - ctx->Iv[bi] += 1; + ctx->Iv[bi] += (uint8_t) 1; break; } bi = 0; diff --git a/src/controller.cpp b/src/controller.cpp index 97977c6..b137fd5 100644 --- a/src/controller.cpp +++ b/src/controller.cpp @@ -15,43 +15,121 @@ extern string PRIME_SOURCE_FILE; //线程阻塞开关 int if_wait = 1; +//工具组初始化 int init(string instruct, vector &configs, vector &lconfigs, vector &targets){ sqlite3 *psql; sqlite3_stmt *psqlsmt; + + //连接数据库 sqlite3_open("info.db", &psql); const char *pzTail; + //对于服务器的初始化 if(targets[0] == "server"){ - sql::table_create(psql, "server_info", { - {"sqes_public","NONE"}, - {"sqes_private","NONE"}, - {"key_sha1","TEXT"} - }); - sql::insert_info(psql, &psqlsmt, "server_info", { - {"sqes_public","?1"}, - {"sqes_private","?2"}, - {"key_sha1","?3"}, - }); + if (targets.size() < 3) { + error::printError("Illegal Args.\nFromat: init server [server_name] [key]"); + return -1; + } + + //检查名字是否合乎规范 + if (!setting_file::if_name_illegal(targets[0].data())) { + error::printError("Illegal Arg server_name."); + return -1; + } + + try { + //创建数据库服务器信息描述数据表 + sql::table_create(psql, "server_info", { + {"name","TEXT"}, + {"sqes_public","NONE"}, + {"sqes_private","NONE"}, + {"key_sha1","TEXT"} + }); + } + catch (const char * errinfo) { + string errstr = errinfo; + if (errstr == "fail to create table") { + if (!config_search(configs, "-f")) { + error::printWarning("Have already init server information.\nUse arg \"-f\" to continue."); + return 0; + } + else{ + string sql_quote = "DELETE FROM server_info;"; + sqlite3_prepare(psql, sql_quote.data(), -1, &psqlsmt, &pzTail); + int rtn = sqlite3_step(psqlsmt); + if (rtn == SQLITE_DONE) { + + } + else { + const char *error = sqlite3_errmsg(psql); + int errorcode = sqlite3_extended_errcode(psql); + printf("\033[31mSQL Error: [%d]%s\n\033[0m", errorcode, error); + throw error; + } + sqlite3_finalize(psqlsmt); + } + } + } + + //构建数据库插入命令 + sql::insert_info(psql, &psqlsmt, "server_info", { + {"sqes_public","?1"}, + {"sqes_private","?2"}, + {"key_sha1","?3"}, + {"name","?4"}, + }); + struct public_key_class npbkc; struct private_key_class nprkc; + + //生成RSA钥匙串 rsa_gen_keys(&npbkc, &nprkc, PRIME_SOURCE_FILE); + + //填写数据库数据表 sqlite3_bind_blob(psqlsmt, 1, &npbkc, sizeof(public_key_class), SQLITE_TRANSIENT); sqlite3_bind_blob(psqlsmt, 2, &nprkc, sizeof(private_key_class), SQLITE_TRANSIENT); - if(targets[1].size() < 6) error::printWarning("Key is too weak."); + sqlite3_bind_blob(psqlsmt, 4, targets[1].data(), targets[1].size(), SQLITE_TRANSIENT); + + //生成服务器访问口令哈希码(SHA1) + if(targets[2].size() < 6) error::printWarning("Key is too weak."); string sha1_hex; SHA1_Easy(sha1_hex, targets[1]); sqlite3_bind_text(psqlsmt, 3, sha1_hex.data(), -1, SQLITE_TRANSIENT); + //执行数据库写入命令 if(sqlite3_step(psqlsmt) != SQLITE_DONE){ sql::printError(psql); } sqlite3_finalize(psqlsmt); + + + //输出成功信息 error::printSuccess("Succeed."); + sqlite3_close(psql); return 0; } else{ + //对于客户端的初始化 + if(targets.size() < 2) { + error::printError("Illegal Args.\nFromat: init [client_name] [client_tag]"); + return -1; + } + + //检测名字与标签是否符合规范 + if (setting_file::if_name_illegal(targets[0])); + else { + error::printError("Illegal Arg client_name."); + return -1; + } + if (setting_file::if_name_illegal(targets[1])); + else { + error::printError("Illegal Arg client_tag."); + return -1; + } + try { + //创建客户端描述信息数据表 sql::table_create(psql, "client_info", { {"name","TEXT"}, {"tag","TEXT"}, @@ -69,11 +147,13 @@ int init(string instruct, vector &configs, vector &lconfigs, vec }); } catch (const char *error_info) { if(!strcmp(error_info, "fail to create table")){ + //检测强制参数 if(!config_search(configs, "-f")){ printf("\033[33mWarning: Have Already run init process.Try configure -f to continue.\n\033[0m"); return 0; } else{ + // 清空已存在的数据表 string sql_quote = "DELETE FROM client_info;"; sqlite3_prepare(psql, sql_quote.data(), -1, &psqlsmt, &pzTail); int rtn = sqlite3_step(psqlsmt); @@ -93,19 +173,11 @@ int init(string instruct, vector &configs, vector &lconfigs, vec } - + //构建数据库插入命令 sql::insert_info(psql, &psqlsmt, "client_info", { {"name","?1"}, {"tag","?2"} }); - if(setting_file::if_name_illegal(targets[0])); - else{ - error::printError("Args(name) abnormal."); - } - if(setting_file::if_name_illegal(targets[1])); - else{ - error::printError("Args(tag) abnormal."); - } sqlite3_bind_text(psqlsmt, 1, targets[0].data(), -1, SQLITE_TRANSIENT); sqlite3_bind_text(psqlsmt, 2, targets[1].data(), -1, SQLITE_TRANSIENT); int rtn = sqlite3_step(psqlsmt); @@ -115,17 +187,24 @@ int init(string instruct, vector &configs, vector &lconfigs, vec else throw "sql writes error"; sqlite3_finalize(psqlsmt); sqlite3_close(psql); + + //成功执行 + error::printSuccess("Succeed."); return 0; } +//修改工具组配置信息 int set(string instruct, vector &configs, vector &lconfigs, vector &targets){ if(targets.size() < 2){ - error::printError("Args error."); + error::printError("Illegal Args.\nUse help to get more information."); return -1; } + sqlite3 *psql; sqlite3_stmt *psqlsmt; const char *pzTail; + + //连接数据库 if(sqlite3_open("info.db", &psql) == SQLITE_ERROR){ sql::printError(psql); return -1; @@ -136,16 +215,16 @@ int set(string instruct, vector &configs, vector &lconfigs, vect int if_find = sqlite3_column_int(psqlsmt, 0); if(if_find); else{ - error::printError("Couldn't do set before init process."); + error::printError("Couldn't SET before INIT process."); return -1; } sqlite3_finalize(psqlsmt); - if(targets[0] == "square"){ + if(targets[0] == "server"){ sql_quote = "UPDATE client_info SET msqes_ip = ?1, msqes_port = ?2 WHERE rowid = 1;"; sqlite3_prepare(psql, sql_quote.data(), -1, &psqlsmt, &pzTail); if(!Addr::checkValidIP(targets[1])){ - error::printError("Args(ipaddr) is abnomal."); + error::printError("Arg(ipaddr) is abnomal."); sqlite3_finalize(psqlsmt); sqlite3_close(psql); return -1; @@ -158,7 +237,7 @@ int set(string instruct, vector &configs, vector &lconfigs, vect ss>>port; if(port > 0 && port <= 65535); else{ - error::printError("Args(port) is abnomal."); + error::printError("Arg(port) is abnomal."); sqlite3_finalize(psqlsmt); sqlite3_close(psql); return -1; @@ -185,7 +264,7 @@ int set(string instruct, vector &configs, vector &lconfigs, vect } sqlite3_finalize(psqlsmt); } - else if(targets[1] == "square"){ + else if(targets[1] == "server"){ sql_quote = "UPDATE client_info SET msqes_key = ?1 WHERE rowid = 1;"; sqlite3_prepare(psql, sql_quote.data(), -1, &psqlsmt, &pzTail); sqlite3_bind_text(psqlsmt, 1, targets[2].data(), -1, SQLITE_TRANSIENT); diff --git a/src/main.cpp b/src/main.cpp index 1ff427d..cd942f5 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -15,13 +15,14 @@ rng::rng128 rand128({rng::tsc_seed{}(),rng::tsc_seed{}()}); int main(int argc, const char *argv[]){ // 命令 string instruct; -// 设置 +// 参数设置 vector config; -// 长设置 +// 参数长设置 vector long_config; -// 目标 +// 参数目标功能 vector target; -// 注册函数 + +// 注册参数函数 struct instructions istns; istns.construct = construct; istns.update = update; @@ -50,34 +51,45 @@ int main(int argc, const char *argv[]){ } } } -// 处理命令 - if(instruct == "construct"){ - if(istns.construct != nullptr) istns.construct(instruct,config,long_config,target); - else error::printError("Function not found."); - } - else if (instruct == "update"){ - if(istns.update != nullptr) istns.update(instruct,config,long_config,target); - else error::printError("Function not found."); - } - else if (instruct == "server"){ - if(istns.update != nullptr) istns.server(instruct,config,long_config,target); - else error::printError("Function not found."); - } - else if (instruct == "init"){ - if(istns.update != nullptr) istns.init(instruct,config,long_config,target); - else error::printError("Function not found."); - } - else if (instruct == "set"){ - if(istns.update != nullptr) istns.set(instruct,config,long_config,target); - else error::printError("Function not found."); - } - else if (instruct == "client"){ - if(istns.update != nullptr) istns.client(instruct,config,long_config,target); - else error::printError("Function not found."); - } - else{ - printf("\033[33mInstruction \"%s\" doesn't make sense.\n\033[0m",instruct.data()); - } + + int rtn = 0; +// 处理解析命令 + try { + + if (instruct == "construct") { + if (istns.construct != nullptr) rtn = istns.construct(instruct, config, long_config, target); + else error::printError("Function not found."); + } + else if (instruct == "update") { + if (istns.update != nullptr) rtn = istns.update(instruct, config, long_config, target); + else error::printError("Function not found."); + } + else if (instruct == "server") { + if (istns.update != nullptr) rtn = istns.server(instruct, config, long_config, target); + else error::printError("Function not found."); + } + else if (instruct == "init") { + if (istns.update != nullptr) rtn = istns.init(instruct, config, long_config, target); + else error::printError("Function not found."); + } + else if (instruct == "set") { + if (istns.update != nullptr) rtn = istns.set(instruct, config, long_config, target); + else error::printError("Function not found."); + } + else if (instruct == "client") { + if (istns.update != nullptr) rtn = istns.client(instruct, config, long_config, target); + else error::printError("Function not found."); + } + else { + printf("\033[33mInstruction \"%s\" doesn't make sense.\n\033[0m", instruct.data()); + } + } + catch (const char *errorinfo) { + string errstr = errorinfo; + error::printError(errstr); + if (rtn < 0) error::printRed("Abort."); + } + return 0; } diff --git a/src/model.cpp b/src/model.cpp index 9507309..0b777aa 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -20,6 +20,9 @@ namespace error { void printSuccess(string succes_info){ printf("\033[32m%s\n\033[0m",succes_info.data()); } + void printRed(string red_info) { + printf("\033[31m%s\n\033[0m", red_info.data()); + } } bool config_search(vector &configs,string tfg){ diff --git a/src/rsa.cpp b/src/rsa.cpp index 0b6ba83..b00ce4f 100755 --- a/src/rsa.cpp +++ b/src/rsa.cpp @@ -129,7 +129,7 @@ void rsa_gen_keys(struct public_key_class *pub, struct private_key_class *priv, d = d+phi_max; } - printf("primes are %lld and %lld\n",(long long)p, (long long )q); + //printf("primes are %lld and %lld\n",(long long)p, (long long )q); // We now store the public / private keys in the appropriate structs pub->modulus = max; pub->exponent = e; diff --git a/src/server.cpp b/src/server.cpp index d2890c5..27873c6 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -792,7 +792,7 @@ void SQEServer::Respond2Packet(packet &pkt, respond &res){ pkt.type = RESPOND_TYPE; pkt.address = *res.t_addr.Obj(); pkt.AddBuff((void *) &res.r_id, sizeof(rng::rng64)); - pkt.AddBuff((void *) res.t_addr.Obj(), sizeof(sockaddr_in)); + pkt.AddBuff((void *) res.t_addr.Obj(), sizeof(struct sockaddr_in)); pkt.AddBuff((void *) res.type.data(), (uint32_t)res.type.size()); pkt.AddBuff((void *) res.buff, res.buff_size); } @@ -899,6 +899,7 @@ void SQEServer::Packet2Post(packet &pkt, encrypt_post &pst, aes_key256 &key){ TMD5 = string((const char *)pkt.buffs[4].second,32); uint8_t *t_data = (uint8_t *)malloc(pst.buff_size); memcpy(t_data, pkt.buffs[3].second, pst.buff_size); + // 解密数据 struct AES_ctx naes; key.MakeIV(); @@ -1037,6 +1038,7 @@ void *clientChecker(void *args){ while (1) { if(pcltl->if_connected == false) break; if(nstcpc.SendRAW(pnrwd->msg, pnrwd->msg_size) < 0){ + //如果心跳包未被成功发送 printf("Lose Connection %s[%s]\n",pcltl->pcltr->name.data(),pcltl->pcltr->tag.data()); pcltl->if_connected = false; break;