diff --git a/include/net.h b/include/net.h index 1de3196..bd100f1 100644 --- a/include/net.h +++ b/include/net.h @@ -171,6 +171,7 @@ public: ssize_t RecvRAW(char **p_rdt, Addr &taddr); ssize_t RecvRAW_SM(char **p_rdt, Addr &taddr); void SendRespond(string &str); + }; //TCP客户端套接字类 diff --git a/include/server.h b/include/server.h index 1374329..0944d30 100644 --- a/include/server.h +++ b/include/server.h @@ -115,7 +115,7 @@ public: uint32_t info; // 信息串 char *msg = NULL; - unsigned long msg_size = 0; + unsigned long msg_size = 0; // 来源ip地址 struct sockaddr_in address; // 用简单字符串直接出适合 @@ -214,6 +214,7 @@ struct client_register{ Addr t_addr; // 守护线程ID pthread_t tid; + sqlite3 *psql; }; struct client_listen{ @@ -225,6 +226,12 @@ struct client_listen{ client_register *pcltr; }; +struct connection_info { + bool if_listen = false; + bool if_beat = false; + bool if_send = false; +}; + struct connection_listener{ int data_sfd; Addr client_addr; @@ -234,6 +241,9 @@ struct connection_listener{ SocketTCPCServer *server_cnt = nullptr; bool if_active = true; bool *pif_atv = nullptr; + void *write_buff = nullptr; + struct connection_info *p_ci = nullptr; + pthread_t *beat_pid = nullptr, *listen_pid = nullptr, *send_pid = nullptr; }; //通用服务器类 @@ -338,7 +348,7 @@ public: static void Post2Packet(packet &pkt, encrypt_post &pst, aes_key256 &key); static void Packet2Post(packet &pkt, encrypt_post &pst, aes_key256 &key); static void GetPostInfo(packet &pkt, encrypt_post &pst); - static void SendConnectionInfo(SocketTCPClient *pcnt_sock, bool ifshort); + static void SendConnectionInfo(SocketTCPClient *pcnt_sock, string type); }; //通用客户端类 diff --git a/src/controller.cpp b/src/controller.cpp index 3e7d02b..808771a 100644 --- a/src/controller.cpp +++ b/src/controller.cpp @@ -303,6 +303,7 @@ int set(string instruct, vector &configs, vector &lconfigs, vect int server(string instruct, vector &configs, vector &lconfigs, vector &targets){ initClock(); setThreadsClock(); + signal(SIGPIPE, SIG_IGN); if(targets.size() == 0){ //Server nsvr; //setServerClock(&nsvr, 3); @@ -468,55 +469,72 @@ int client(string instruct, vector &configs, vector &lconfigs, v nclt.SetPublicKey(*ppbc); sqlite3_finalize(psqlsmt); - aes_key256 naeskey; - nclt.SetAESKey(naeskey); - - string reqstr = " {\"key\":null, \"name\":null, \"tag\":null, \"sqe_key\":null, \"listen_port\": null,\"listen_ip\":null}"; - - Document reqdata; - if(reqdata.Parse(reqstr.data()).HasParseError()) throw "fail to parse into json"; - -// 生成并传递端对端加密报文密钥 - reqdata["key"].SetArray(); - Value &tmp_key = reqdata["key"]; - const uint8_t *p_key = naeskey.GetKey(); - Document::AllocatorType& allocator = reqdata.GetAllocator(); - for (int idx = 0; idx <32; idx++) { - tmp_key.PushBack(p_key[idx],allocator); - } - - - reqdata["name"].SetString(nclt.name.data(),(uint32_t)nclt.name.size()); - reqdata["tag"].SetString(nclt.tag.data(),(uint32_t)nclt.tag.size()); - reqdata["sqe_key"].SetString(nclt.sqe_key.data(), (uint32_t)nclt.sqe_key.size()); - //设置TCP监听端口 - reqdata["listen_port"].SetInt(9052); - - - //如果强制指定客户端IP地址 - string ip; - if(if_setip) ip = set_ip; - else ip = "127.0.0.1"; - - reqdata["listen_ip"].SetString(ip.data(),(uint32_t)ip.size()); - + //检测本地的注册信息 + sql_quote = "select count(name) from sqlite_master where name = \"client_register_info\""; + sqlite3_prepare(psql, sql_quote.data(), -1, &psqlsmt, &pzTail); + if (sqlite3_step(psqlsmt) != SQLITE_ROW) { + sql::printError(psql); + throw "database is abnormal"; + } + int if_find = sqlite3_column_int(psqlsmt, 0); + if (if_find) { + //如果本地已经有注册信息 - //构造请求 - StringBuffer strbuff; - Writer writer(strbuff); - reqdata.Accept(writer); - string json_str = strbuff.GetString(); + } + else { + //如果本地没有注册信息 + //向主广场服务器注册 + aes_key256 naeskey; + nclt.SetAESKey(naeskey); - printf("Connecting...\n"); -// 已获得主广场服务器的密钥,进行启动客户端守护进程前的准备工作 - nclt.NewRequest(&preq, msqe_ip, msqe_port, "private request", json_str, true); - nclt.NewRequestListener(preq, 44, psql,registerSQECallback); + string reqstr = " {\"key\":null, \"name\":null, \"tag\":null, \"sqe_key\":null, \"listen_port\": null,\"listen_ip\":null}"; - //等待主广场服务器回应 - if_wait = 1; - while (if_wait == 1) { - sleep(1); - } + Document reqdata; + if (reqdata.Parse(reqstr.data()).HasParseError()) throw "fail to parse into json"; + + // 生成并传递端对端加密报文密钥 + reqdata["key"].SetArray(); + Value &tmp_key = reqdata["key"]; + const uint8_t *p_key = naeskey.GetKey(); + Document::AllocatorType& allocator = reqdata.GetAllocator(); + for (int idx = 0; idx < 32; idx++) { + tmp_key.PushBack(p_key[idx], allocator); + } + + + reqdata["name"].SetString(nclt.name.data(), (uint32_t)nclt.name.size()); + reqdata["tag"].SetString(nclt.tag.data(), (uint32_t)nclt.tag.size()); + reqdata["sqe_key"].SetString(nclt.sqe_key.data(), (uint32_t)nclt.sqe_key.size()); + //设置TCP监听端口 + reqdata["listen_port"].SetInt(9052); + + + //如果强制指定客户端IP地址 + string ip; + if (if_setip) ip = set_ip; + else ip = "127.0.0.1"; + + reqdata["listen_ip"].SetString(ip.data(), (uint32_t)ip.size()); + + + //构造请求 + StringBuffer strbuff; + Writer writer(strbuff); + reqdata.Accept(writer); + string json_str = strbuff.GetString(); + + printf("Connecting...\n"); + // 已获得主广场服务器的密钥,进行启动客户端守护进程前的准备工作 + nclt.NewRequest(&preq, msqe_ip, msqe_port, "private request", json_str, true); + nclt.NewRequestListener(preq, 44, psql, registerSQECallback); + + //等待主广场服务器回应 + if_wait = 1; + while (if_wait == 1) { + sleep(1); + } + } + //得到服务器回应 if (!if_wait) { // 成功注册 @@ -540,23 +558,45 @@ int client(string instruct, vector &configs, vector &lconfigs, v } //创建客户端连接管理线程 + pthread_t beat_pid = 0, listen_pid = 0, send_pid = 0; connection_listener *pncl = new connection_listener(); pncl->client_addr = nclt.server_cnt->GetClientAddr(); pncl->data_sfd = nclt.server_cnt->GetDataSFD(); pncl->key = nclt.post_key; pncl->father_buff = buff; pncl->server_cnt = nclt.server_cnt; + pncl->beat_pid = &beat_pid; + pncl->listen_pid = &listen_pid; + pncl->send_pid = &send_pid; + pncl->p_ci = new connection_info(); pthread_create(&pncl->pid, NULL, clientServiceDeamon, pncl); - + memset(buff, 0, sizeof(uint32_t)); while (1) { + //获得连接状态 + if (!memcmp(buff, "CIFO", sizeof(uint32_t))) { + memcpy(buff, "RCFO", sizeof(uint32_t)); + memcpy(buff+sizeof(uint32_t), pncl->p_ci, sizeof(connection_info)); + } //检测父进程信号 - if(!memcmp(buff, "Exit", sizeof(uint32_t))){ + else if(!memcmp(buff, "Exit", sizeof(uint32_t))){ pncl->if_active = false; - - pthread_join(pncl->pid, NULL); + + //注销所有主要线程 + if(pncl->p_ci->if_beat) pthread_cancel(beat_pid); + if (pncl->p_ci->if_listen) pthread_cancel(listen_pid); + if (pncl->p_ci->if_send) pthread_cancel(send_pid); + pthread_cancel(pncl->pid); + nclt.server_cnt->Close(); + //关闭所有打开的文件描述符 + int fd = 0; + int fd_limit = sysconf(_SC_OPEN_MAX); + while (fd < fd_limit) close(fd++); + + + free(pncl->p_ci); delete pncl; memcpy(buff, "SEXT", sizeof(uint32_t)); //断开共享内存连接 @@ -578,21 +618,62 @@ int client(string instruct, vector &configs, vector &lconfigs, v } usleep(1000); } - error::printSuccess("\nShell For Client: "); + error::printSuccess("\n-------------------------------\nShell For Client: \n-------------------------------\n"); string cmdstr; char cmd[1024]; while (1) { printf(">"); gets_s(cmd,1024); cmdstr = cmd; - if(cmdstr == "Exit"){ + if(cmdstr == "stop"){ error::printInfo("Start to stop service..."); memcpy(buff, "Exit", sizeof(uint32_t)); while (memcmp(buff, "SEXT", sizeof(uint32_t))) { - sleep(10000); + usleep(1000); } error::printInfo("Service stopped."); } + else if(cmdstr == "status"){ + memcpy(buff, "CIFO", sizeof(uint32_t)); + while (memcmp(buff, "RCFO", sizeof(uint32_t))) { + usleep(1000); + } + connection_info n_ci; + memcpy(&n_ci, buff + sizeof(uint32_t), sizeof(connection_info)); + memset(buff, 0, sizeof(uint32_t)); + printf("STATUS:\n"); + if (n_ci.if_beat) error::printSuccess("*Beat"); + else error::printRed("*Beat"); + if (n_ci.if_listen) error::printSuccess("*Listen"); + else error::printRed("*Listen"); + if (n_ci.if_send) error::printSuccess("*Send"); + else error::printRed("*Send"); + + + } + else if (cmdstr == "quit") { + //关闭所有打开的文件描述符 + int fd = 0; + //nclt.server_cnt->Close(); + int fd_limit = sysconf(_SC_OPEN_MAX); + while (fd < fd_limit) close(fd++); + shmdt(buff); + exit(0); + } + else if (cmdstr == "ping") { + if (memcmp(buff, "WAIT", sizeof(uint32_t))) { + raw_data nrwd; + SQEServer::BuildSmallRawData(nrwd, "PING"); + memcpy(buff, "WAIT", sizeof(uint32_t)); + memcpy(buff+sizeof(uint32_t), &nrwd.msg_size, sizeof(uint64_t)); + memcpy(buff + 3 * sizeof(uint32_t), nrwd.msg, nrwd.msg_size); + memcpy(buff + 3 * sizeof(uint32_t) + nrwd.msg_size, "TADS", sizeof(uint32_t)); + memcpy(buff, "SDAT", sizeof(uint32_t)); + Server::freeRawdataServer(nrwd); + } + + + } } } diff --git a/src/model.cpp b/src/model.cpp index 66614a0..db170e6 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -12,13 +12,13 @@ extern int if_wait; namespace error { void printError(string error_info){ - printf("\033[31mError: %s\n\033[0m",error_info.data()); + printf("\033[31mError: %s\033[0m\n",error_info.data()); } void printWarning(string warning_info){ - printf("\033[33mWarning: %s\n\033[0m",warning_info.data()); + printf("\033[33mWarning: %s\033[0m\n",warning_info.data()); } void printSuccess(string succes_info){ - printf("\033[32m%s\n\033[0m",succes_info.data()); + printf("\033[32m%s\033[0m\n",succes_info.data()); } void printRed(string red_info) { printf("\033[31m%s\n\033[0m", red_info.data()); @@ -58,11 +58,10 @@ void registerSQECallback(respond *pres,void *args){ resdoc.Parse(resjson.data()); string status = resdoc["status"].GetString(); if(status == "ok"){ - printf("Register succeed.\n"); if_wait = 0; } else{ - printf("Register Fail.\n"); + error::printError("register failed."); if_wait = -1; } } @@ -83,7 +82,6 @@ void *connectionDeamon(void *args){ SocketTCPCServer ntcps; ntcps.SetDataSFD(pcntl->data_sfd); ntcps.SetClientAddr(pcntl->client_addr); - // 获得连接的类型是长链还是断链 size = ntcps.RecvRAW_SM(&buff, t_addr); raw_data *pnrwd = new raw_data(); @@ -94,22 +92,45 @@ void *connectionDeamon(void *args){ if(Server::CheckRawMsg(buff, size)){ Server::ProcessSignedRawMsg(buff, size, *pnrwd); if(!memcmp(&pnrwd->info, "LCNT", sizeof(uint32_t))){ + //接收长连接 if_sm = false; } else if(!memcmp(&pnrwd->info, "SCNT", sizeof(uint32_t))){ + //接收短连接 if_sm = true; ntcps.SendRespond(dget); } else if(!memcmp(&pnrwd->info, "CNTL", sizeof(uint32_t))){ - if_sm = true; - //printf("Listen Connection From Server\n"); - - ntcps.CloseConnection(); - pthread_exit(NULL); + //发送长连接 + if_sm = false; + pcntl->p_ci->if_listen = true; + *pcntl->listen_pid = pcntl->pid; + pcntl->write_buff = pcntl->father_buff; + while (1) { + if (*pcntl->pif_atv == false) { + close(pcntl->data_sfd); + pcntl->p_ci->if_listen = false; + delete pcntl; + pthread_exit(NULL); + } + if (!memcmp(pcntl->write_buff, "SDAT", sizeof(uint32_t))) { + uint32_t nsrwd_size = 0; + Byte buff[BUFSIZ]; + memcpy(&nsrwd_size, pcntl->write_buff + sizeof(uint32_t), sizeof(uint32_t)); + if (!memcmp(pcntl->write_buff + 3 * sizeof(uint32_t) + nsrwd_size, "TADS", sizeof(uint32_t))) { + memcpy(buff, pcntl->write_buff + 3 * sizeof(uint32_t), nsrwd_size); + send(pcntl->data_sfd, buff, nsrwd_size, 0); + } + else error::printError("buffer error."); + memset(pcntl->write_buff, 0, sizeof(uint32_t)); + } + usleep(1000); + + } } else{ //断开无效连接 - printf("Connection Illegal.\n"); + printf("Connection Info Illegal.\n"); delete pnrwd; close(pcntl->data_sfd); delete pcntl; @@ -123,6 +144,7 @@ void *connectionDeamon(void *args){ delete pnrwd; pthread_exit(NULL); } + Server::freeRawdataServer(*pnrwd); delete pnrwd; while (1) { @@ -134,7 +156,6 @@ void *connectionDeamon(void *args){ //区分长连接与短连接 if(if_sm) size = ntcps.RecvRAW(&buff, t_addr); else size = ntcps.RecvRAW_SM(&buff, t_addr); - if(size > 0){ raw_data *pnrwd = new raw_data(); packet *nppkt = new packet(); @@ -149,8 +170,10 @@ void *connectionDeamon(void *args){ if(!memcmp(&pncryp->type, "JRES", sizeof(uint32_t))){ //自我解析 pncryp->SelfParse(); + printf("Register Status: "); if(pncryp->edoc["status"].GetString() == string("ok")){ - error::printSuccess("Register Successful."); + error::printSuccess("Succeed"); + error::printInfo("\nStart Command Line Tools...\n"); //进入客户端管理终端 memcpy(pcntl->father_buff,"D_OK", sizeof(uint32_t)); } @@ -163,19 +186,16 @@ void *connectionDeamon(void *args){ } //心跳连接 else if(!memcmp(&pnrwd->info, "BEAT", sizeof(uint32_t))){ - //printf("Connection Beated.\n"); + + if (!pcntl->p_ci->if_beat) { + pcntl->p_ci->if_beat = true; + *pcntl->beat_pid = pcntl->pid; + + } } Server::freeRawdataServer(*pnrwd); Server::freePcaketServer(*nppkt); } - else if(size < 0){ - //printf("Lost Connection From Server.\n"); - delete pnrwd; - delete pncryp; - delete nppkt; - delete pcntl; - break; - } free(buff); delete pnrwd; delete pncryp; @@ -190,18 +210,25 @@ void *connectionDeamon(void *args){ void *clientServiceDeamon(void *arg) { connection_listener *pclst = (connection_listener *)arg; + while (1) { if (pclst->if_active == false) { break; } + //接受新连接 pclst->server_cnt->Accept(); - //构造连接守护子进程 + + //构造连接守护子线程 connection_listener *pncl = new connection_listener(); pncl->client_addr = pclst->client_addr; pncl->data_sfd = pclst->server_cnt->GetDataSFD(); pncl->key = pclst->key; pncl->father_buff = pclst->father_buff; pncl->pif_atv = &pclst->if_active; + pncl->p_ci = pclst->p_ci; + pncl->beat_pid = pclst->beat_pid; + pncl->listen_pid = pclst->listen_pid; + pncl->send_pid = pclst->send_pid; pthread_attr_t attr; pthread_attr_init(&attr); diff --git a/src/server.cpp b/src/server.cpp index f3f598f..6da6d5e 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -667,6 +667,28 @@ SQEServer::SQEServer(int port):Server(port){ error::printSuccess("Server Name: "+name); error::printSuccess("Listen Port: " + std::to_string(port)); + sql_quote = "select count(name) from sqlite_master where name = \"register_info\""; + sqlite3_prepare(psql, sql_quote.data(), -1, &psqlsmt, &pzTail); + if (sqlite3_step(psqlsmt) != SQLITE_ROW) { + sql::printError(psql); + throw "database is abnormal"; + } + int if_find = sqlite3_column_int(psqlsmt,0); + if (!if_find) { + sql::table_create(psql, "register_info", { + {"name","TEXT"}, + {"tag","TEXT"}, + {"client_id","INT"}, + {"key","NONE"}, + {"ip","TEXT"}, + {"udp_port","INT"}, + {"tcp_port","INT"}, + {"passwd","INT"}, + {"status","INT"} + }); + error::printInfo("create table register_info."); + } + sqlite3_finalize(psqlsmt); } void SQEServer::Packet2Request(packet &pkt, request &req){ @@ -749,8 +771,9 @@ void SQEServer::ProcessRequset(void){ pclr->t_addr.SetIP(jdoc["listen_ip"].GetString()); pclr->t_addr.SetPort(port); pclr->passwd = rand64(); + pclr->psql = psql; // 联络线程生命周期 - pclr->click = 9999; + pclr->click = 99999; //if(pthread_mutex_lock(&mutex_cltreg) != 0) throw "lock error"; client_lst.insert({pclr->client_id,pclr}); //pthread_mutex_unlock(&mutex_cltreg); @@ -766,7 +789,33 @@ void SQEServer::ProcessRequset(void){ pncr->arg = (void *)pclr; newClock(pncr); rids.insert({preq->r_id,pclr}); - + + //写入注册信息到数据库 + sqlite3_stmt *psqlsmt; + sql::insert_info(psql, &psqlsmt, "register_info",{ + {"name","?1"}, + {"tag","?2"}, + {"client_id","?3"}, + {"key","?4"}, + {"ip","?5"}, + {"passwd","?6"}, + {"tcp_port","?7"}, + {"udp_port","?8"}, + {"status","?9"} + }); + + sqlite3_bind_text(psqlsmt, 1, pclr->name.data(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(psqlsmt, 2, pclr->tag.data(), -1, SQLITE_TRANSIENT); + sqlite3_bind_int64(psqlsmt, 3, pclr->client_id); + sqlite3_bind_blob(psqlsmt, 4, (const void *)pclr->key.key, sizeof(uint64_t) * 4, SQLITE_TRANSIENT); + sqlite3_bind_text(psqlsmt, 5, jdoc["listen_ip"].GetString(), -1, SQLITE_TRANSIENT); + sqlite3_bind_int64(psqlsmt, 6, pclr->passwd); + sqlite3_bind_int(psqlsmt, 7, port); + sqlite3_bind_int(psqlsmt, 8, preq->recv_port); + sqlite3_bind_int(psqlsmt, 9, 0); + sqlite3_step(psqlsmt); + sqlite3_finalize(psqlsmt); + } // 构建回复包 respond *pnr = new respond(); @@ -787,6 +836,84 @@ void SQEServer::ProcessRequset(void){ pthread_mutex_unlock(&mutex_sndpkt); } + else if(preq->type == "client login"){ + preq->req_doc.Parse(preq->data.data()); + Document &ldoc = preq->req_doc; + + sqlite3_stmt *psqlsmt; + const char *pzTail; + client_register *pclr = new client_register(); + pclr->client_id = ldoc["client_id"].GetInt64(); + string sql_quote = "select * from register_info where client_id = ?1;"; + sqlite3_prepare(psql, sql_quote.data(), -1, &psqlsmt, &pzTail); + sqlite3_bind_text(psqlsmt, 1, std::to_string(pclr->client_id).data(), -1, SQLITE_TRANSIENT); + if (sqlite3_step(psqlsmt) == SQLITE_DONE) { + error::printInfo("Invaild login."); + delete pclr; + sqlite3_finalize(psqlsmt); + } + else { + pclr->passwd = sqlite3_column_int64(psqlsmt, 7); + if (pclr->passwd == ldoc["passwd"].GetInt64()) { + if (sqlite3_column_int(psqlsmt, 8) == 0) { + //完善客户端登录管理结构体 + pclr->name = (const char *)sqlite3_column_text(psqlsmt, 0); + pclr->tag = (const char *)sqlite3_column_text(psqlsmt, 1); + pclr->t_addr.SetIP((const char *)sqlite3_column_text(psqlsmt, 4)); + pclr->t_addr.SetPort(sqlite3_column_int(psqlsmt, 6)); + pclr->psql = psql; + memcpy((void *)pclr->key.GetKey(), sqlite3_column_blob(psqlsmt, 3), sizeof(uint64_t) * 4); + printf("Login successfully %s[%s]:%s\n", pclr->name.data(), pclr->tag.data()); + + //注册客户端联络守护进程 + clock_register *pncr = new clock_register(); + pncr->if_thread = true; + pncr->if_reset = false; + pncr->click = 64; + pncr->rawclick = 0; + pncr->func = clientWaitDeamon; + pncr->arg = (void *)pclr; + newClock(pncr); + rids.insert({ preq->r_id,pclr }); + + sqlite3_finalize(psqlsmt); + sql_quote = "update register_info set status = 1 where client_id = ?1"; + sqlite3_prepare(psql, sql_quote.data(), -1, &psqlsmt, &pzTail); + sqlite3_bind_int64(psqlsmt, 1, pclr->client_id); + sqlite3_step(psqlsmt); + sqlite3_finalize(psqlsmt); + } + else { + sqlite3_finalize(psqlsmt); + delete pclr; + } + + //构建回复包 + respond *pnr = new respond(); + pnr->r_id = preq->r_id; + string res_data = "{\"status\":\"ok\"}"; + pnr->SetBuff((Byte *)res_data.data(), (uint32_t)res_data.size()); + pnr->type = "register respond"; + pnr->t_addr = preq->t_addr; + pnr->t_addr.SetPort(preq->recv_port); + packet *pnpkt = new packet(); + Respond2Packet(*pnpkt, *pnr); + delete pnr; + + //将标准数据包添加到发送列表 + if (pthread_mutex_lock(&mutex_sndpkt) != 0) throw "lock error"; + packets_out.push_back(pnpkt); + pthread_mutex_unlock(&mutex_sndpkt); + + } + else { + error::printInfo("Wrong password"); + delete pclr; + sqlite3_finalize(psqlsmt); + } + } + + } delete preq; preq = nullptr; } @@ -1062,7 +1189,7 @@ void *clientChecker(void *args){ break; } else{ - sleep(3); + sleep(1); } } pthread_exit(NULL); @@ -1076,54 +1203,61 @@ void encrypt_post::SelfParse(void) { void *clientListener(void *args){ client_listen *pcltl = (client_listen *)args; char *buff; - Addr taddr; - while(1){ -// 如果连接断开 - if(pcltl->if_connected == false) break; + Addr taddr; // 建立新的监听连接 - try { - pcltl->ptcps->Reconnect(); + try { + pcltl->ptcps->Reconnect(); + } + catch (const char *errinfo) { + string errstr = errinfo; + if (errstr == "fail to connect") { + pcltl->if_connected = false; + pthread_exit(NULL); } - catch (const char *errinfo) { - string errstr = errinfo; - if (errstr == "fail to connect") { - pcltl->if_connected = false; - break; - } - } -// 说明连接类型 - raw_data nsrwd; - SQEServer::BuildSmallRawData(nsrwd, "CNTL"); - pcltl->ptcps->SendRAW(nsrwd.msg, nsrwd.msg_size); - Server::freeRawdataServer(nsrwd); - ssize_t size = pcltl->ptcps->RecvRAW(&buff, taddr); - if(size > 0){ - if(Server::CheckRawMsg(buff, size)){ - raw_data nrwd; - Server::ProcessSignedRawMsg(buff, size, nrwd); + } + //发送连接属性信息 + SQEServer::SendConnectionInfo(pcltl->ptcps, "CNTL"); + + while (1) { + // 如果连接断开 + if (pcltl->if_connected == false) break; + ssize_t size = pcltl->ptcps->RecvRAW(&buff, taddr); + if (size > 0) { + if (Server::CheckRawMsg(buff, size)) { + raw_data nrwd; + Server::ProcessSignedRawMsg(buff, size, nrwd); //如果二进制串中储存端对端加密报文 - if(!memcmp(&nrwd.info,"ECYP",sizeof(uint32_t))){ + if (!memcmp(&nrwd.info, "ECYP", sizeof(uint32_t))) { encrypt_post necryp; SQEServer::SignedRawData2Post(nrwd, necryp, pcltl->pcltr->key); - if(!memcmp(&necryp.type,"JIFO",sizeof(uint32_t))){ + if (!memcmp(&necryp.type, "JIFO", sizeof(uint32_t))) { necryp.SelfParse(); - printf("Client %s[%s] Send Encrypt Post(JSON).\n",pcltl->pcltr->name.data(),pcltl->pcltr->tag.data()); - uint64_t pwd = necryp.edoc["pwdmd5"].GetInt64(); - if(pwd == pcltl->pcltr->passwd){ - printf("Password Check Passed.\n"); - } - else{ - printf("Wrong Password.\n"); - } - } - //necryp.FreeBuff(); - } - Server::freeRawdataServer(nrwd); - } - free(buff); - } - } + printf("Client %s[%s] Send Encrypt Post(JSON).\n", pcltl->pcltr->name.data(), pcltl->pcltr->tag.data()); + uint64_t pwd = necryp.edoc["pwdmd5"].GetInt64(); + if (pwd == pcltl->pcltr->passwd) { + printf("Password Check Passed.\n"); + } + else { + printf("Wrong Password.\n"); + } + } + //necryp.FreeBuff(); + } + else if (!memcmp(&nrwd.info, "PING", sizeof(uint32_t))) { + error::printInfo("client ping."); + + } + Server::freeRawdataServer(nrwd); + } + free(buff); + } + else if(size < 0){ + pcltl->if_connected == false; + break; + } + usleep(1000); + } pthread_exit(NULL); } bool resFromClient(SocketTCPClient *pcnt_sock){ @@ -1156,11 +1290,10 @@ void SQEServer::Post2SignedRawData(encrypt_post &ecyp, aes_key256 &key, raw_data Server::SignedRawdata(&rw,"ECYP"); } -void SQEServer::SendConnectionInfo(SocketTCPClient *pcnt_sock, bool ifshort) { +void SQEServer::SendConnectionInfo(SocketTCPClient *pcnt_sock, string type) { raw_data nsrwd; //说明连接类型 - if(ifshort) SQEServer::BuildSmallRawData(nsrwd, "SCNT"); - else SQEServer::BuildSmallRawData(nsrwd, "LCNT"); + SQEServer::BuildSmallRawData(nsrwd, type.data()); pcnt_sock->SendRAW(nsrwd.msg, nsrwd.msg_size); Server::freeRawdataServer(nsrwd); } @@ -1170,6 +1303,7 @@ void *clientWaitDeamon(void *pvclt){ client_register *pclr = (client_register *)pcti->args; SocketTCPClient *pcnt_sock = nullptr; bool if_success = false; + //尝试主动连接客户端 printf("Connecting client %s[%s].\n",pclr->name.data(),pclr->tag.data()); for(int i = 0; i < 8; i++){ if(tryConnection(&pcnt_sock, pclr)){ @@ -1181,6 +1315,14 @@ void *clientWaitDeamon(void *pvclt){ } if(!if_success){ printf("Fail To Get Register.\n"); + //更新登录信息 + string sql_quote = "update register_info set status = 0 where client_id = ?1"; + sqlite3_stmt *psqlsmt; + const char *pzTail; + sqlite3_prepare(pclr->psql, sql_quote.data(), -1, &psqlsmt, &pzTail); + sqlite3_bind_int64(psqlsmt, 1, pclr->client_id); + sqlite3_step(psqlsmt); + sqlite3_finalize(psqlsmt); delete pclr; clockThreadFinish(pcti->tid); pthread_exit(NULL); @@ -1208,7 +1350,7 @@ void *clientWaitDeamon(void *pvclt){ SQEServer::Post2SignedRawData(*ncryp, pclr->key, *pnrwd); //发送连接属性信息 - SQEServer::SendConnectionInfo(pcnt_sock,true); + SQEServer::SendConnectionInfo(pcnt_sock,"SCNT"); //等待反馈 if (resFromClient(pcnt_sock)) { pcnt_sock->SendRAW(pnrwd->msg, pnrwd->msg_size); @@ -1216,13 +1358,21 @@ void *clientWaitDeamon(void *pvclt){ } else { //注册连接未被识别 - error::printError("Client connection error."); + error::printError("client connection error."); delete pclr; delete pnrwd; delete ncryp; clockThreadFinish(pcti->tid); pthread_exit(NULL); } + + string sql_quote = "update register_info set status = 1 where client_id = ?1"; + sqlite3_stmt *psqlsmt; + const char *pzTail; + sqlite3_prepare(pclr->psql, sql_quote.data(), -1, &psqlsmt, &pzTail); + sqlite3_bind_int64(psqlsmt, 1, pclr->client_id); + sqlite3_step(psqlsmt); + sqlite3_finalize(psqlsmt); //建立客户端连接管理信息提供结构 client_listen *pcltl = new client_listen(); @@ -1243,6 +1393,15 @@ void *clientWaitDeamon(void *pvclt){ sleep(1); if(pcltl->if_connected == false){ printf("Register lost %s[%s]\n",pclr->name.data(),pclr->tag.data()); + //更新登录信息 + string sql_quote = "update register_info set status = 0 where client_id = ?1"; + sqlite3_stmt *psqlsmt; + const char *pzTail; + sqlite3_prepare(pclr->psql, sql_quote.data(), -1, &psqlsmt, &pzTail); + sqlite3_bind_int64(psqlsmt, 1, pclr->client_id); + sqlite3_step(psqlsmt); + sqlite3_finalize(psqlsmt); + //服务端销户? break; } } diff --git a/src/socket.cpp b/src/socket.cpp index 731ba16..8b7080f 100644 --- a/src/socket.cpp +++ b/src/socket.cpp @@ -120,12 +120,38 @@ ssize_t SocketUDPClient::SendRAW(char *buff, unsigned long size){ } ssize_t SocketTCPClient::SendRAW(char *buff, unsigned long size){ - ssize_t send_size = send(client_sfd, buff, size, 0); - /*if(send_size < 0){ - printf("Error[%u]:",errno); - perror("send"); - }*/ - return send_size; + //对于长数据进行分段发送 + if (size > 1023) { + ssize_t idx = 0, nidx = 0; + Byte vbuff[1024], gbuff[1024]; + while (idx < size-1) { + if (!idx) memcpy(vbuff, "DSAT", sizeof(uint32_t)); + else memcpy(vbuff, "DCST", sizeof(uint32_t)); + if (idx + 1000 < size - 1) { + nidx = idx + 1000; + memcpy(vbuff + sizeof(uint32_t) + nidx - idx + 1, "DCTN", sizeof(uint32_t)); + } + else { + nidx = size - 1; + memcpy(vbuff + sizeof(uint32_t) + nidx - idx + 1, "DFSH", sizeof(uint32_t)); + } + memcpy(vbuff + sizeof(uint32_t), buff + idx, nidx - idx + 1); + + + send(client_sfd, vbuff, 2 * sizeof(uint32_t) + nidx - idx + 1, 0); + int grtn = recv(client_sfd, gbuff, BUFSIZ,0); + if (grtn > 0 && !memcmp(gbuff, "DGET", sizeof(uint32_t))); + else { + return -1; + } + idx = nidx + 1; + } + return size; + } + else { + ssize_t send_size = send(client_sfd, buff, size, 0); + return send_size; + } } void SocketClient::SetSendSockAddr(struct sockaddr_in tsi){ @@ -136,15 +162,42 @@ void SocketTCPClient::Close(void){ close(client_sfd); } +//长连接数据接收 ssize_t SocketTCPCServer::RecvRAW_SM(char **p_rdt, Addr &taddr){ - ssize_t tmp_bdtas = recv(data_sfd, buff, BUFSIZ, 0); - if (tmp_bdtas > 0) { - *p_rdt = (char *)malloc(tmp_bdtas); - memcpy(*p_rdt, buff, tmp_bdtas); - } - return tmp_bdtas; + ssize_t bdtas = 0, tmp_bdtas; + *p_rdt = nullptr; + bool dsat = false, dfsh = false, if_signal = false; + while (!dfsh && (tmp_bdtas = recv(data_sfd, buff, BUFSIZ, 0)) > 0) { + if (!memcmp(buff, "NETC", sizeof(uint32_t))) { + dsat = true; + dfsh = true; + if_signal = true; + } + if (!memcmp(buff, "DSAT", sizeof(uint32_t))) dsat = true; + if (!memcmp(buff+tmp_bdtas-sizeof(uint32_t), "DFSH", sizeof(uint32_t))) dfsh = true; + if (dsat) { + send(data_sfd, "DGET", sizeof(uint32_t), 0); + if (*p_rdt == nullptr) { + if (if_signal) { + *p_rdt = (char *)malloc(tmp_bdtas); + memcpy(*p_rdt, buff, tmp_bdtas); + bdtas += tmp_bdtas; + continue; + } + *p_rdt = (char *)malloc(tmp_bdtas - 2 * sizeof(uint32_t)); + memcpy(*p_rdt, buff + sizeof(uint32_t), tmp_bdtas - 2 * sizeof(uint32_t)); + } + else { + *p_rdt = (char *)realloc(*p_rdt, bdtas + tmp_bdtas - 2 * sizeof(uint32_t)); + memcpy(*p_rdt + bdtas, buff + sizeof(uint32_t), tmp_bdtas - 2 * sizeof(uint32_t)); + } + } + bdtas += tmp_bdtas; + } + return bdtas; } +//短连接数据接收 ssize_t SocketTCPCServer::RecvRAW(char **p_rdt, Addr &taddr){ ssize_t bdtas = 0 ,tmp_bdtas; *p_rdt = nullptr; @@ -178,7 +231,7 @@ void SocketTCPCServer::CloseConnection(void){ close(data_sfd); } void SocketTCPCServer::Close(void) { - close(server_sfd); + shutdown(server_sfd, SHUT_RDWR); }