c++ websocket 协议分析与实现

2023-12-13 23:44:34

前言
网上有很多第三方 WebSocket 库,如 nopoll、uWebSockets、libwebsockets,但它们大多基于回调或过于复杂;我只需要在后端使用,所以自己手写了一个。

1:环境
ubuntu18
g++(支持c++11即可)
第三方库:jsoncpp,openssl
2:安装
jsoncpp:用于读取 JSON 配置文件,可以用包管理器自动安装(如 sudo apt-get install libjsoncpp-dev),网上教程很多
openssl:如果系统没有自带,需要安装:sudo apt-get install openssl libssl-dev(编译时需要 libssl-dev 提供的头文件和库),一般是 1.1 版本,够用了
3:websocket server
1> 网络模型主要使用 epoll(io_uring 性能更好,但需要较高的内核版本)。共 3 个进程:主进程作为监控进程,另外 2 个子进程,一个是 network 进程,一个是 logic 进程
2> 子进程之间主要通过共享内存交换数据,并通过 socketpair 互相通知
在这里插入图片描述
在这里插入图片描述
3> WebSocket 握手协议,先看例子
在这里插入图片描述
前端 HTML 代码如下:

<!DOCTYPE HTML>  
<html>  
<head>  
    <meta http-equiv="content-type" content="text/html" />  
    <meta name="author" content="https://github.com/" />  
    <title>websocket test</title>  
    <script>
		// Minimal manual test page: connect / send / close one WebSocket and
		// report the events with alert() and console.log().
		var socket;
		function Connect(){
			// Refuse to open a second connection while one is open or opening;
			// the previous socket would otherwise be orphaned.
			if (socket && (socket.readyState === WebSocket.CONNECTING || socket.readyState === WebSocket.OPEN)) {
				alert('already connected');
				return;
			}
			try{
				socket=new WebSocket('ws://192.168.1.131:9000'); //'ws://192.168.1.131:9000');
			}catch(e){
				alert('error catch'+e);
				return;
			}
			socket.onopen = sOpen;
			socket.onerror = sError;
			socket.onmessage= sMessage;
			socket.onclose= sClose;
		}
		function sOpen(){
			alert('connect success!');
		}
		function sError(e){
			alert("[error] " + e);
			//writeObj(e);
		}
		function sMessage(msg){
			if(typeof(msg) == 'object'){
				// msg is a MessageEvent; the payload is in msg.data
				if(msg.data){  //msg.hasOwnProperty('data')
					console.log('server says'+msg.data);
				}else{
					writeObj(msg);
				}
			}else{
				alert('server says:' + msg);
			}
		}
		function sClose(e){
			alert("connect closed:" + e.code);
		}
		function Send(){
			// socket is undefined before Connect(), and send() throws unless the
			// connection is OPEN.
			if (!socket || socket.readyState !== WebSocket.OPEN) {
				alert('not connected');
				return;
			}
			socket.send(document.getElementById("msg").value);
		}
		function Close(){
			// Calling close() on an undefined socket would throw.
			if (socket) {
				socket.close();
			}
		}
		// Debug helper: dump all enumerable properties of obj in one alert.
		function writeObj(obj){
			var description = "";
			for(var i in obj){
				var property=obj[i];
				description+=i+" = "+property+"\n";
			}
			alert(description);
		}
	</script>
</head>  
   
<body>  
<input id="msg" type="text">  
<button id="connect" onclick="Connect();">Connect</button>  
<button id="send" onclick="Send();">Send</button>  
<button id="close" onclick="Close();">Close</button>

</body>  
   
</html>  

在Microsoft Edge 运行结果
在这里插入图片描述
Go 客户端代码如下(注意这里的 url 用的是 wss,需要服务端开启 SSL,见第 5 节):

package main

import (
	"fmt"
	"golang.org/x/net/websocket"
	"log"
	"strings"
)


var origin = "http://192.168.1.131:9000"  
//var url = "ws://192.168.1.131:7077/websocket"
var url = "wss://192.168.1.131:9000/websocket"
func main() {
	ws, err := websocket.Dial(url, "", origin)
	if err != nil {
		log.Fatal(err)
	}

	// send text frame
	var message2 = "hello"
	websocket.Message.Send(ws, message2)
	fmt.Printf("Send: %s\n", message2)
	// receive text frame
	var message string
	websocket.Message.Receive(ws, &message)
	fmt.Printf("Receive: %s\n", message)

	for true {
		fmt.Printf("please input string:")
		var inputstr string
		fmt.Scan(&inputstr)
		if(strings.Compare(inputstr,"quit") == 0){
			break
		}else{
			websocket.Message.Send(ws, inputstr)
			fmt.Printf("Send: %s\n", inputstr)

			var output string
			websocket.Message.Receive(ws, &output)
			fmt.Printf("Receive: %s\n", output)
		}

	}
	ws.Close()//关闭连接
	fmt.Printf("client exit\n")
}

测试结果
在这里插入图片描述

服务端握手代码:
在这里插入图片描述

int  c_WebSocket::recv_handshake() {
    int n, len, ret;
    uint32_t pos = 0;
    uint16_t u16msglen = 0;
    const bool bssl = isSsl();
    if (bssl) {
        n = SSL_read(m_ssl, m_recv_buf + m_recv_pos, m_recv_buf_size - m_recv_pos);
    }
    else {
        n = recv(m_fd, m_recv_buf + m_recv_pos, m_recv_buf_size - m_recv_pos, 0);
    }
   
    if (n > 0) {
        if (m_is_closed) {
            m_recv_pos = 0;
            return true;
        }
        m_recv_pos += n;
        m_recv_buf[m_recv_pos] = 0;
        printf("recv %d handshake %s len=%d recvlen=%d", m_id, m_recv_buf,n,m_recv_pos);
        // goto READ;
      //  int32_t pos = 0;
        for (;;) {
            if (m_recv_pos >= c_u16MinHandShakeSize)  //消息头
            {
                //  \r 0x0D \n  0xA
                const  int nRet = fetch_http_info((char*)m_recv_buf, m_recv_pos);
                if (1 == nRet) {  //ok

                    //  f(strcasecmp(header, "Sec-Websocket-Protocol") == 0)
                    //      conn->accepted_protocol = value;
                    // 
                    std::map<std::string, std::string>::iterator it1 = m_map_header.find("Sec-WebSocket-Key"); //一般固定24个字节
                    std::map<std::string, std::string>::iterator it2 = m_map_header.find("Sec-WebSocket-Protocol");//
                    int map_size = m_map_header.size();
                    if (it1 != m_map_header.end()) {
                        printf("key=%s value=%s %d \n", it1->first.c_str(), it1->second.c_str(), map_size);
                    }
                    else {
                        return -1;
                    }

                    char acceptvalue[1024] = { 0, };
                    uint32_t  value_len = it1->second.length();
                    memcpy(acceptvalue, it1->second.c_str(), value_len);
                    //  memcpy(accept_key, websocket_key, key_length);
#define MAGIC_KEY "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
                    memcpy(acceptvalue + value_len, MAGIC_KEY, 36);
                    acceptvalue[value_len + 36] = 0;

                    unsigned char md[SHA_DIGEST_LENGTH];
                    SHA1((unsigned char*)acceptvalue, strlen(acceptvalue), md);
                    std::string server_key = base64_encode(reinterpret_cast<const unsigned char*>(md), SHA_DIGEST_LENGTH);
                   
                    char rep_handshake[1024] = { 0, };
                    memset(rep_handshake, 0, sizeof(rep_handshake));
                    if (it2 != m_map_header.end()) {
                        //子协议
                        char szsub_protocol[512] = { 0, };
                        std::size_t pos_t = it2->second.find(",");
                        if (pos_t != std::string::npos && pos_t < 512) {

                            memcpy(szsub_protocol, it2->second.c_str(), pos_t);
                            sprintf(rep_handshake, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: %s\r\nSec-WebSocket-Protocol: %s\r\n\r\n",
                                server_key.c_str(), szsub_protocol);
                        }
                        else {
                            return -1;
                        }
                    }
                    else {
                        sprintf(rep_handshake, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: %s\r\n\r\n",
                            server_key.c_str());
                    }
                    m_recv_pos = 0;
                    set_handshake_ok(); //握手完毕
                    send_pkg((uint8_t*)rep_handshake, strlen(rep_handshake));
                    break;
                }
                else {
                    printf("fetch_http_info nRet=%d \n ", nRet);
                    return -2;
                }
            }
            else {
                break;
            }
        }//end for

        if (pos != 0 && m_recv_pos > 0) {
            memcpy(m_recv_buf, m_recv_buf + pos, m_recv_pos);
        }
    }
    else {

        if (bssl) {
            //EAGAIN或EWOULDBLOCK二者应该是一样的,对应的错误码是11
            //ret = SSL_get_error(m_ssl, n);//int ssl_error = SSL_get_verify_result(ssl);
            //if (SSL_ERROR_WANT_READ == ret || SSL_ERROR_WANT_WRITE == ret) return true;
            SSL_ERROR_NONE == ret  n>0 ok  other error
            //printf("SSL_get_error(%d %d %d)\n", n, ret, errno);//SSL_get_error(-1 1,11)
            //return false;
            int ret = ssl_check_error(m_ssl, n);
            printf("SSL_get_error(%d %d %d %d)\n", n, ret, errno, m_recv_pos);//SSL_get_error(-1 -1 11)
            if (ret == -2) {
                return true;
            }
            if (errno == EAGAIN || errno == EINTR) {
                return true;
            }
            return false;
        }
        else {
            if (n == 0)
                return false;

            if (errno == EAGAIN || errno == EINTR) {
                return true;
            }
            else {
                return false;
            }
        }
        
    }
    return  true;
}

// Case-insensitive ASCII comparison of two NUL-terminated strings.
static bool http_ascii_iequals(const char* a, const char* b) {
    for (; *a != 0 && *b != 0; ++a, ++b) {
        char ca = *a;
        char cb = *b;
        if (ca >= 'A' && ca <= 'Z') ca = static_cast<char>(ca - 'A' + 'a');
        if (cb >= 'A' && cb <= 'Z') cb = static_cast<char>(cb - 'A' + 'a');
        if (ca != cb) return false;
    }
    return *a == *b;
}

// True if the comma-separated header value lists `token` (compared
// case-insensitively, surrounding spaces/tabs ignored), e.g. the value
// "keep-alive, Upgrade" contains the token "upgrade".
static bool http_header_has_token(const std::string& value, const char* token) {
    std::size_t start = 0;
    while (start <= value.size()) {
        std::size_t end = value.find(',', start);
        if (end == std::string::npos) end = value.size();
        std::size_t b = start;
        std::size_t e = end;
        while (b < e && (value[b] == ' ' || value[b] == '\t')) ++b;
        while (e > b && (value[e - 1] == ' ' || value[e - 1] == '\t')) --e;
        if (http_ascii_iequals(value.substr(b, e - b).c_str(), token)) return true;
        start = end + 1;
    }
    return false;
}

// HTTP field names are case-insensitive (RFC 7230 section 3.2). The names the
// handshake code looks up are stored under one canonical spelling so that
// m_map_header.find("Sec-WebSocket-Key") etc. work whatever case the client used.
static std::string http_canonical_header_name(const char* name) {
    static const char* const kKnownNames[] = {
        "Upgrade", "Connection", "Sec-WebSocket-Version", "Sec-WebSocket-Key",
        "Sec-WebSocket-Protocol", "Origin", "Host"
    };
    for (const char* known : kKnownNames) {
        if (http_ascii_iequals(name, known)) return known;
    }
    return name;
}

/**
 * Parses the buffered HTTP upgrade request and stores each header field in
 * m_map_header.
 *
 * @param recv_buf  request bytes (need not be NUL-terminated)
 * @param buf_len   number of valid bytes in recv_buf
 * @return  1  complete request carrying all mandatory WebSocket headers;
 *          0  header block not finished yet (no empty line seen) - read more;
 *         -1  a line is too long or the request line is not a valid GET;
 *         -3  the request path has an unsupported length;
 *         -4  malformed header line (no ':' or empty field name);
 *         -6  header block complete but mandatory headers missing/invalid.
 */
int c_WebSocket::fetch_http_info(char* recv_buf, const uint32_t buf_len) {
    const uint32_t max_len = 1024;       // longest accepted line, terminator included
    char bufline[max_len] = { 0, };
    uint32_t bufpos = 0;
    bool brequest_line = true;           // the first line is "GET <path> HTTP/1.1"
    char szsubhead[max_len] = { 0, };    // request path, e.g. "/websocket" (not used further yet)

    for (uint32_t i = 0; i < buf_len; i++) {
        bufline[bufpos++] = recv_buf[i];
        if (bufpos >= max_len) return -1;
        if (recv_buf[i] != 0x0A) continue;   // accumulate until '\n'
        bufline[bufpos] = 0;

        if (brequest_line) {
            if (!ws_strncmp(bufline, "GET ", 4)) return -1;
            // the request line has a minimum size: "GET / HTTP/1.1\r\n" is 16 (15 with a bare "\n")
            if (bufpos < 15) return -1;
            // index of the 'H' of "HTTP/1.1" (8 bytes) in front of "\r\n" or "\n"
            int32_t nhttp1_1_pos = (int32_t)(bufpos - 2 - 8);
            if (bufline[bufpos - 2] != '\r') {
                nhttp1_1_pos += 1;
            }
            // path length: minus the space before "HTTP" and the 4 bytes of "GET "
            const int32_t ucopylen = nhttp1_1_pos - 1 - 4;
            if (ucopylen <= 0 || ucopylen >= 128) return -3;
            memcpy(szsubhead, bufline + 4, ucopylen);
            szsubhead[ucopylen] = 0;
            brequest_line = false;
            bufpos = 0;
            continue;
        }

        // An empty line ("\r\n", or a bare "\n") terminates the header block.
        if (bufpos == 1 || (bufpos == 2 && bufline[0] == '\r')) {
            const auto hend = m_map_header.end();
            const auto itUpgrade = m_map_header.find("Upgrade");             // must list "websocket"
            const auto itConnection = m_map_header.find("Connection");       // must list "Upgrade"
            const auto itVersion = m_map_header.find("Sec-WebSocket-Version");
            const auto itKey = m_map_header.find("Sec-WebSocket-Key");       // normally 24 base64 bytes
            const auto itOrigin = m_map_header.find("Origin");
            const bool bcheckOrigin = false;  // browsers always send Origin, other clients may not
            const bool bok =
                itUpgrade != hend && http_header_has_token(itUpgrade->second, "websocket") &&
                // some browsers send e.g. "Connection: keep-alive, Upgrade"
                itConnection != hend && http_header_has_token(itConnection->second, "Upgrade") &&
                itVersion != hend && itVersion->second == "13" &&
                itKey != hend && itKey->second.length() > 12 &&
                (!bcheckOrigin || itOrigin != hend);   // other fields are ignored
            return bok ? 1 : -6;
        }

        // Header line "Name: value".
        char* colon = strchr(bufline, ':');
        if (colon == nullptr || colon == bufline) return -4;

        // Drop the line terminator ("\r\n" or "\n").
        uint32_t line_end = bufpos - 1;
        if (line_end > 0 && bufline[line_end - 1] == '\r') --line_end;
        bufline[line_end] = 0;

        // The value may be preceded/followed by optional whitespace (OWS).
        char* value = colon + 1;
        while (*value == ' ' || *value == '\t') ++value;
        char* value_end = bufline + line_end;
        while (value_end > value && (value_end[-1] == ' ' || value_end[-1] == '\t')) --value_end;
        *value_end = 0;
        *colon = 0;  // terminate the field name

        m_map_header[http_canonical_header_name(bufline)] = value;
        bufpos = 0;
    }

    return 0;
}

握手请求与回复
Origin: http://192.168.1.131:9000:发起连接的页面源(协议 + 主机 + 端口);浏览器一定会携带,其他客户端可能没有
Connection: Upgrade:表示要升级协议了
Upgrade: websocket:表示要升级到 WebSocket 协议;
Sec-WebSocket-Version: 13:表示 WebSocket 的版本。如果服务端不支持该版本,需要返回一个 HTTP 错误码(如 426 Upgrade Required),并带上 Sec-WebSocket-Version header,里面包含服务端支持的版本号
Sec-WebSocket-Key:与后面服务端响应首部的 Sec-WebSocket-Accept 是配套的,提供基本的防护,比如恶意的连接,或者无意的连接
服务端响应协议升级
HTTP/1.1 101 Switching Protocols
Connection: Upgrade
Upgrade: websocket
Sec-WebSocket-Accept: Oy4NRAQ13jhfONC7bP8dTKb4PTU=
HTTP/1.1 101 Switching Protocols: 状态码 101 表示协议切换
Sec-WebSocket-Accept:根据客户端请求首部的 Sec-WebSocket-Key 计算出来
将 Sec-WebSocket-Key 与固定的 GUID 258EAFA5-E914-47DA-95CA-C5AB0DC85B11 拼接;
对拼接结果计算 SHA-1 摘要(20 字节),再把摘要转成 Base64 字符串。计算公式如下:
Sec-WebSocket-Accept = Base64(SHA1(Sec-WebSocket-Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
Connection: Upgrade:表示协议升级
Upgrade: websocket:升级到 websocket 协议

4:接收数据帧
在这里插入图片描述
代码如下

int  c_WebSocket::recv_dataframe() {
    int n, len,ret;
    //   uint32_t pos = 0;
    uint16_t u16msglen = 0;
    const bool bssl = isSsl();
    if (isSsl()) {
        n = SSL_read(m_ssl, m_recv_buf + m_recv_pos, m_recv_buf_size - m_recv_pos);
    }
    else {
        n = recv(m_fd, m_recv_buf + m_recv_pos, m_recv_buf_size - m_recv_pos, 0);
    }
  //  n = recv(m_fd, m_recv_buf + m_recv_pos, m_recv_buf_size - m_recv_pos, 0);
    if (n > 0) {
        if (m_is_closed) {
            m_recv_pos = 0;
            return true;
        }
        m_recv_pos += n;

        // goto READ;
        int32_t pos = 0;
        for (;;) {
            if (m_recv_pos >= c_u16MsgHeadSize)  //消息头 2个字节
            {
                int t = parse_dataframe(m_recv_buf + pos, m_recv_pos);
                if (t < 0) return false;
                else if (0 == t) break;

                pos += t;
                m_recv_pos -= t; u16msglen + c_u16MsgHeadSize; //sub one packet len
                //  pos += u16msglen + c_u16MsgHeadSize;
            }
            else {
                break;
            }
        }//end for

        if (pos != 0 && m_recv_pos > 0) {
            memcpy(m_recv_buf, m_recv_buf + pos, m_recv_pos);
        }
        if (pos > 0) { //收到消息的时间
            m_lastrecvmsg = get_reactor().getCurSecond();
        }
    }
    else {

        if (bssl) {
            //ret = SSL_get_error(m_ssl, n);
            //if (SSL_ERROR_WANT_READ == ret || SSL_ERROR_WANT_WRITE == ret) return true;
            SSL_ERROR_NONE == ret  n>0 ok  other error
            //return false;
            int ret = ssl_check_error(m_ssl, n);
            if (ret == -2) {
                return true;
            }
            if (errno == EAGAIN || errno == EINTR) {
                return true;
            }
            return false;
        }
        else {
            if (n == 0)
                return false;

            if (errno == EAGAIN || errno == EINTR) {
                return true;
            }
            else {
                return false;
            }
        }
    }
    return  true;
}

处理数据帧:把 payload 转发到 logic 进程,由 logic 进程处理。
一帧数据的长度超过 MAX_PAYLOAD_REQ(16 位长度字段最大为 65535),或者使用了 64 位长度字段时,直接断开连接;这个长度可以根据实际需求设定

/**
 * Parses one client frame at the start of recv_buf.
 *
 * Text/binary payloads are unmasked into m_payload_ and forwarded to the logic
 * process via recv_push(); PING is answered with a PONG carrying the same data;
 * PONG clears the heartbeat state. Fragmented messages and 64-bit payload
 * lengths are not supported and close the connection.
 *
 * @param recv_buf  start of the unconsumed receive data
 * @param buf_len   number of unconsumed bytes (caller ensures >= c_u16MsgHeadSize)
 * @return >0  total length of the consumed frame (header + mask + payload);
 *          0  the frame is not complete yet;
 *         -1  protocol violation, oversized frame or CLOSE frame: close.
 */
int  c_WebSocket::parse_dataframe(uint8_t* recv_buf, const uint32_t buf_len) {
#define  FAIL_AND_CLOSE  -1         // receive failure or close
#define  NEED_CLOSE  -1             // the connection must be closed

#define  CONTINUE_RECV_MSG  0       // frame incomplete, keep receiving
#define  ONE_MSG_LENGHT(X) X        // one complete frame, total length X

#define  MASK_LEN   4               // masking key length
#define  PAYLOAD_LENGTH_126  2      // extra length bytes when the 7-bit length is 126

    // RFC 6455 section 5.2 data opcodes.
    const uint8_t kOpText = 0x1;
    const uint8_t kOpBinary = 0x2;

    const uint8_t t_fin = msg_get_bit(recv_buf[0], 7);
    if (t_fin == 0) return FAIL_AND_CLOSE;                 // fragmentation not supported
    if ((recv_buf[0] & 0x70) != 0) return FAIL_AND_CLOSE;  // RSV1-3 must be 0: no extension negotiated
    const uint8_t t_code = recv_buf[0] & 0x0F;
    const uint8_t t_masked = msg_get_bit(recv_buf[1], 7);
    uint16_t t_payload_size = recv_buf[1] & 0x7F;
    if (t_masked == 0) return FAIL_AND_CLOSE;              // client frames must be masked

    if (t_code == CLOSE_FRAME) {  // close frame
        return NEED_CLOSE;
    }
    // Only text, binary, ping and pong are accepted: a lone continuation frame
    // or a reserved opcode is a protocol error.
    if (t_code != kOpText && t_code != kOpBinary && t_code != PING_FRAME && t_code != PONG_FRAME) {
        return FAIL_AND_CLOSE;
    }
    // Control frames (opcode bit 3 set) may carry at most 125 payload bytes.
    if ((t_code & 0x08) != 0 && t_payload_size > 125) {
        return FAIL_AND_CLOSE;
    }

    uint16_t t_playload_pos = c_u16MsgHeadSize;
    if (t_payload_size == 126) {
        if (buf_len < c_u16MsgHeadSize + PAYLOAD_LENGTH_126) return CONTINUE_RECV_MSG;
        // The extended length is in network byte order (big-endian); a plain
        // memcpy into a uint16_t read it byte-swapped on little-endian hosts.
        const uint16_t length = static_cast<uint16_t>(
            (recv_buf[c_u16MsgHeadSize] << 8) | recv_buf[c_u16MsgHeadSize + 1]);
        if (length > MAX_PAYLOAD_REQ) return FAIL_AND_CLOSE;  // payload too long
        // the client must always send the 4-byte mask
        if (buf_len < c_u16MsgHeadSize + PAYLOAD_LENGTH_126 + MASK_LEN + length) return CONTINUE_RECV_MSG;
        t_payload_size = length;
        t_playload_pos += PAYLOAD_LENGTH_126;
    }
    else if (t_payload_size == 127) {
        return FAIL_AND_CLOSE;  // 64-bit lengths are not supported
    }
    else {
        if (buf_len < c_u16MsgHeadSize + MASK_LEN + t_payload_size) return CONTINUE_RECV_MSG;
    }

    memcpy(masking_key_, &recv_buf[t_playload_pos], MASK_LEN);
    t_playload_pos += MASK_LEN;
    const int frame_len = ONE_MSG_LENGHT(t_playload_pos + t_payload_size);

    if (t_code == PONG_FRAME) {
        if (m_lastsendping > 0) {
            printf("time=[%u]recv PONG_FRAME \n", g_reactor.getCurSecond());
            m_lastsendping = 0;  // our ping was answered
            m_sendpingcount = 0;
        }
        return frame_len;
    }

    // Unmask the payload: byte i is XORed with masking key byte i % 4.
    for (uint32_t i = 0; i < t_payload_size; i++) {
        m_payload_[i] = recv_buf[t_playload_pos + i] ^ masking_key_[i % MASK_LEN];
    }

    if (t_code == PING_FRAME) {
        // RFC 6455 section 5.5.3: reply with a PONG carrying the same
        // application data. Pings with a payload used to be forwarded to the
        // logic process instead of being answered.
        send_data((char*)m_payload_, t_payload_size, PONG_FRAME);
        return frame_len;
    }

    if (t_payload_size == 0) {
        // Empty text/binary message: valid, nothing to forward (this used to
        // close the connection).
        return frame_len;
    }
    m_payload_length_ = t_payload_size;

    // Hand the unmasked payload to the logic process.
    shm_block_t sb;
    sb.fd = m_fd;
    sb.id = m_id;
    sb.len = t_payload_size;
    sb.type = PROTO_BLOCK;
    sb.frametype = t_code;
    recv_push(m_u32channel, m_u32pipeindex, &sb, (uint8_t*)m_payload_, false);
    return frame_len;
}

logic 进程的处理代码如下(收到 "hello" 开头的消息时回复 "HELLO",否则把首字符替换为 '_' 后原样返回):

/**
 * Demo logic for one block received from the network process.
 *
 * For PROTO_BLOCK: a payload starting with "hello" is echoed back with those
 * five bytes upper-cased ("HELLO..."); any other non-empty payload is echoed
 * with its first byte replaced by '_'. The reply reuses the received frame type.
 *
 * @param pblock  block header (fd/id/len/type/frametype)
 * @param buf     payload bytes; only pblock->len bytes are valid and the data
 *                is not guaranteed to be NUL-terminated
 * @param brecv   unused here
 */
void c_Logic::dologic(struct shm_block_t* pblock, uint8_t *buf, bool brecv)
{
	(void)brecv;
	switch (pblock->type)
	{
	case CLOSE_BLOCK:
		// nothing to do yet
		break;
	case PROTO_BLOCK:
		{
			static const char kHello[] = "hello";
			const uint32_t khello_len = sizeof(kHello) - 1;
			if (pblock->len == 0) {
				break;  // no payload: nothing to modify (buf[0] is not ours to write)
			}
			// Compare only inside the payload; strncmp could read past a payload
			// shorter than 5 bytes and match stale buffer contents.
			if (pblock->len >= khello_len && memcmp(buf, kHello, khello_len) == 0) {
				memcpy(buf, "HELLO", khello_len);
			}
			else {
				buf[0] = '_';
			}
			send_data(pblock, buf, pblock->len, (WebSocketFrameType)pblock->frametype);
		}
		break;
	case CDUMP_BLOCK:
		break;
	default:
		break;
	}
}

发送数据(服务端发给客户端的帧不需要加掩码)

/**
 * Wraps msg in an unmasked, unfragmented server-to-client WebSocket frame and
 * queues it for the network process with send_push().
 *
 * @param pblock  source block; its fd and id select the destination connection
 * @param msg     payload bytes
 * @param msglen  payload length, at most 4 KiB
 * @param ftype   frame type; its low 4 bits become the opcode
 * @return -1 if msglen exceeds the 4 KiB limit, otherwise 0
 */
int send_data(struct shm_block_t* pblock, uint8_t* msg, const uint32_t msglen, WebSocketFrameType ftype) {
	const uint32_t   MAX_PAYLOAD_SEND = 4 * 1024;  // largest payload we send
	if (msglen > MAX_PAYLOAD_SEND) {
		return -1;
	}

	char header[14] = { 0 };
	int header_size = 2;

	// Byte 0: FIN (frames are never fragmented) plus the opcode.
	msg_set_bit(header, 7);
	if (ftype >= 0) {
		header[0] |= ftype & 0x0f;
	}
	// Byte 1: the MASK bit stays clear - only client-to-server frames are masked.

	if (msglen < 126) {
		header[1] |= msglen;
	}
	else if (msglen <= 0xFFFF) {
		// 126 marker followed by a 16-bit length
		header[1] |= 126;
		msg_set_16bit(msglen, header + 2);
		header_size += 2;
	}
	else {
		// 127 marker followed by a 64-bit big-endian length. Unreachable while
		// MAX_PAYLOAD_SEND <= 0xFFFF; the length is only 32 bits wide, so the
		// four high-order bytes are always zero.
		header[1] = 127;
		header[2] = header[3] = header[4] = header[5] = 0;
		header[6] = (msglen >> 24) & 0xFF;
		header[7] = (msglen >> 16) & 0xFF;
		header[8] = (msglen >> 8) & 0xFF;
		header[9] = msglen & 0xFF;
		header_size += 8;
	}

	// Frame = header followed by the payload.
	uint8_t frame[MAX_PAYLOAD_SEND + 14];
	memcpy(frame, header, header_size);
	memcpy(frame + header_size, msg, msglen);

	shm_block_t sb;
	sb.fd = pblock->fd;
	sb.id = pblock->id;
	sb.type = PROTO_BLOCK;
	sb.len = msglen + header_size;
	sb.frametype = (uint8_t)ftype;
	send_push(0, 1, &sb, frame, true);
	return 0;
}

5:支持 SSL
先加载证书

bool c_Accept::loadssl(const char* private_key_file, const char* server_crt_file, const char* ca_crt_file) {

	m_ctx = SSL_CTX_new(SSLv23_server_method());
	if (!m_ctx) { return false; }
	//assert(ctx);
	// 不校验客户端证书
	SSL_CTX_set_verify(m_ctx, SSL_VERIFY_NONE, nullptr);
	// 加载CA的证书  
	if (!SSL_CTX_load_verify_locations(m_ctx, ca_crt_file, nullptr)) {
		printf("SSL_CTX_load_verify_locations error!\n");
		return false;
	}
	// 加载自己的证书  
	if (SSL_CTX_use_certificate_file(m_ctx, server_crt_file, SSL_FILETYPE_PEM) <= 0) {
		printf("SSL_CTX_use_certificate_file error!\n");
		return false;
	}

	// 加载私钥
	if (SSL_CTX_use_PrivateKey_file(m_ctx, private_key_file, SSL_FILETYPE_PEM) <= 0) {
		printf("SSL_CTX_use_PrivateKey_file error!\n");
		return false;
	}

	// 判定私钥是否正确  
	if (!SSL_CTX_check_private_key(m_ctx)) {
		printf("SSL_CTX_check_private_key error!\n");
		return false;
	}
	return true;}

accept 之后,先调用 ssl = SSL_new(get_ssl_ctx()) 创建 SSL 对象,再调用 SSL_accept 完成 TLS 握手:

bool c_Accept::handle_input()
{
	sockaddr_in ip;
	socklen_t len;

	int cli_fd;
	while (1) {
		len = sizeof(ip);
		cli_fd = accept(m_fd, (sockaddr *)&ip, &len);

		if (cli_fd >= 0) {
			if ((uint32_t)cli_fd >= get_reactor().max_handler()) {
				close(cli_fd);
				continue;
			}
			if (!get_reactor().add_cur_connect(get_max_connect())) {
				printf("client max connect is over \n");
				close(cli_fd);
				return true;
			}
			SSL* ssl = nullptr;
			if (isSsl()) {
				ssl = SSL_new(get_ssl_ctx());
				if (ssl == nullptr) {
					get_reactor().sub_cur_connect();
					close(cli_fd);
					continue;
				}
				printf("accept SSL_new \n");
			}
			c_WebSocket*ts = new (std::nothrow) c_WebSocket();
			if (!ts) {
				get_reactor().sub_cur_connect();
				close(cli_fd);
				continue;
			}
			printf("accept client connect \n");
			ts->start(cli_fd, ip,m_u32channel,m_u32pipeindex,ssl);
		}
		else {
			if (errno == EAGAIN || errno == EINTR || errno == EMFILE || errno == ENFILE) {
				return true;
			}
			else {
				return false;
			}
		}
	}
}
void c_WebSocket::start(int fd, sockaddr_in& ip, uint32_t channel, uint32_t u32pipeindex,SSL* ssl)
{
    m_u32channel = channel;
    m_u32pipeindex = u32pipeindex;
    m_fd = fd;
    m_ip = ip;
    //---------------------------------------------
    m_lastrecvmsg = g_reactor.getCurSecond();
    c_heartbeat::GetInstance().handle_input_modify(fd, m_id, m_lastrecvmsg, m_lastrecvmsg);
    set_noblock(m_fd);

    m_ssl = ssl;
    if (isSsl()) {
        printf("ssl client handshake ready\n");
        SSL_set_fd(ssl, m_fd);
        int code, ret;
        int retryTimes = 0;
     //   uint64_t begin = 0;//Time::SystemTime();
        // 防止客户端连接了但不进行ssl握手, 单纯的增大循环次数无法解决问题,
        while ((code = SSL_accept(ssl)) <= 0 && retryTimes++ < 100) {
            ret = SSL_get_error(ssl, code);
            if (ret != SSL_ERROR_WANT_READ) {
                printf("ssl accept error. sslerror=%d  errno=%d \n", ret,errno); // SSL_get_error(ssl, code));
                break;
            }
            usleep(20 * 1000);//20ms //msleep(1); //这里最多会有2s的等待时间,以后一定要异步
        }

        if (code != 1) {
            handle_fini();
            return;
        }
        printf("ssl client handshake ok (%d)\n", retryTimes);
    }

    m_recv_buf_size = default_recv_buff_len;
    m_recv_buf = (uint8_t*)malloc(m_recv_buf_size);
    if (!m_recv_buf) {
        handle_fini();
        return;
    }
    //----------------------------------
    return;
}

6:心跳检查:10 秒(可以自行设定)内未收到消息就发送 ping,连续发送 2 次都没有回应则断开连接

/**
 * Heartbeat check, called with the current time t in seconds.
 *
 * Once nothing has been received for 10 s: a connection that never finished
 * the handshake is dropped; otherwise a PING is sent, repeated every 10 s,
 * and after two unanswered pings the connection is dropped. Receiving a PONG
 * (handled in the frame parser) resets the ping state.
 *
 * @return true if the caller should disconnect this client
 */
bool c_WebSocket::checcklastmsg(uint32_t t) {
    const int kIdleSeconds = 10;

    // Recent traffic: the connection is alive.
    if (t < m_lastrecvmsg + kIdleSeconds) {
        return false;
    }
    if (!is_handshake_ok()) {
        return true;
    }

    const bool ping_due = (m_lastsendping + kIdleSeconds <= t);
    if (m_lastsendping > 0 && ping_due && m_sendpingcount > 1) {
        // two pings went unanswered
        printf("[%u]ready disconnect \n", t);
        return true;
    }

    const bool never_pinged = (m_lastsendping == 0);
    if (never_pinged || (m_sendpingcount > 0 && ping_due)) {
        m_lastsendping = t;
        ++m_sendpingcount;
        send_ping_frame();
        printf("time=%u,%d send ping frame\n", t, m_sendpingcount);
    }
    return false;
}

/**
 * Sends a PING control frame with no application data.
 *
 * Such a frame is exactly two bytes: FIN plus the PING opcode, then a byte
 * whose MASK bit is clear (server frames are not masked) and whose length is 0.
 * Previously a 4 KiB stack buffer sized by MAX_PAYLOAD_SEND was set up for
 * these two bytes.
 *
 * @return number of bytes handed to send_pkg() (always 2)
 */
int c_WebSocket::send_ping_frame() {
    char header[2] = { 0, 0 };
    msg_set_bit(header, 7);          // FIN
    header[0] |= PING_FRAME & 0x0f;  // opcode; header[1] stays 0: no mask, length 0
    const int header_size = static_cast<int>(sizeof(header));

    send_pkg(reinterpret_cast<uint8_t*>(header), header_size);
    return header_size;
}
void c_WebSocket::send_pkg(uint8_t* buf, uint32_t len){
//--------------------------------------------------------------
//有上次预留的
 uint32_t p = 0;
 int n;
 if (isSsl()) {
     n = SSL_write(m_ssl, buf, len);   // 发送响应主体
 }
 else {
    n = send(m_fd, buf, len, 0);
 }
 if (n > 0) {
     if ((uint32_t)n == len) {
         //printf("send data len = %d need send len=%d \n",n,len);
         return;
     }
     else {
         p = n;
     }
 }
 else {
     if (errno != EAGAIN && errno != EINTR) {
         handle_error();
         return;
     }
 }
 //没发送完,存起来下次再发送,这里自行处理
 //----------------------------------------------------------------
}

7:json配置文件读取 jsoncpp API

// Copies at most `limit` characters of `src` (and never more than N - 1) into
// the char array `dst` and always NUL-terminates it.
template <std::size_t N>
static void json_copy_cstr(char (&dst)[N], const std::string& src, std::size_t limit)
{
	std::size_t n = src.length();
	if (n > limit) n = limit;
	if (n > N - 1) n = N - 1;
	memcpy(dst, src.data(), n);
	dst[n] = 0;
}

// Copies at most `len` bytes (and never more than N) from `src` into the byte
// array `dst`.
template <typename T, std::size_t N>
static void json_copy_bytes(T (&dst)[N], const char* src, std::size_t len)
{
	static_assert(sizeof(T) == 1, "json_copy_bytes expects a byte array");
	memcpy(dst, src, len < N ? len : N);
}

/**
 * Loads the server configuration from a JSON file into m_chatSerCfg.
 *
 * Required keys: listenip, listenport, maxconn, usessl, aes128keyhandshake,
 * aes128iv (at least 16 bytes); when usessl is true also privatekeyfile,
 * servercrtfile and cacrtfile. Optional: checkheartbeat, openblackwhiteip,
 * singleipmaxconn (each applied only when present).
 *
 * @return true on success; false if the file is missing, shorter than 10
 *         bytes, not valid JSON, or lacks/has invalid required keys.
 */
bool  c_JsonReader::read_json_file(const char* jsonfile)
{
#define  LISTENIP				"listenip"
#define  LISTENPORT				"listenport"
#define  USESSL					"usessl"
#define  PRIVATEKEYFILE			"privatekeyfile"
#define  SERVERCRTFILE			"servercrtfile"
#define  CACRTFILE				"cacrtfile"
#define  AES128KEYHANDSHAKE		"aes128keyhandshake"
#define  AES128IV				"aes128iv"
#define  MAXCONN				"maxconn"
#define  CHECKHEARTBEAT			"checkheartbeat"
#define  OPENBLACKWHITEIP		"openblackwhiteip"
#define  SINGLEIPMAXCONN		"singleipmaxconn"

#ifndef MIN
#define  MIN(A,B)  ((A)<(B)?(A):(B))
#endif
	FILE* f = fopen(jsonfile, "rb");
	if (!f) {
		printf("json file no exist or  parse json file fail \n");
		return false;
	}
	// Read the whole file. The old fixed 4 KiB buffer was not NUL-terminated
	// when completely filled, so parsing read past its end.
	std::string content;
	char chunk[4 * 1024];
	size_t n;
	while ((n = fread(chunk, sizeof(char), sizeof(chunk), f)) > 0) {
		content.append(chunk, n);
	}
	fclose(f);
	if (content.size() < 10) {
		printf("read_json_file file  length too short \n");
		return false;
	}

	Json::Reader reader;
	Json::Value root;
	if (!reader.parse(content, root)) {
		printf("json file no exist or  parse json file fail \n");
		return false;
	}

	if (root[LISTENIP].empty() || root[LISTENPORT].empty() || root[MAXCONN].empty()
		|| root[USESSL].empty() || root[AES128KEYHANDSHAKE].empty() || root[AES128IV].empty()) {
		printf("read_json_file base fail\n");
		return false;
	}

	const bool busessl = root[USESSL].asBool();
	m_chatSerCfg.buseSsl = busessl;
	if (busessl) {
		if (root[PRIVATEKEYFILE].empty() || root[SERVERCRTFILE].empty() || root[CACRTFILE].empty()) {
			printf("read_json_file ssl fail\n");
			return false;
		}
		// Bounded and always NUL-terminated (strncpy left the path unterminated
		// when it was ssl_file_len characters or longer).
		json_copy_cstr(m_chatSerCfg.szprivatekeyfile, root[PRIVATEKEYFILE].asString(), ssl_file_len);
		json_copy_cstr(m_chatSerCfg.szservercrtfile, root[SERVERCRTFILE].asString(), ssl_file_len);
		json_copy_cstr(m_chatSerCfg.szcacrtfile, root[CACRTFILE].asString(), ssl_file_len);
	}

	m_chatSerCfg.u16maxconn = (uint16_t)root[MAXCONN].asUInt();
	json_copy_cstr(m_chatSerCfg.szlistenip, root[LISTENIP].asString(), sizeof(m_chatSerCfg.szlistenip) - 1);
	m_chatSerCfg.nlistenport = (int32_t)root[LISTENPORT].asInt();

	// The key copy is bounded by the destination size (it was unbounded); the
	// IV must provide 16 bytes (a shorter one was read past its end).
	const std::string aeskey = root[AES128KEYHANDSHAKE].asString();
	json_copy_bytes(m_chatSerCfg.u8AES128keyhandshake, aeskey.data(), aeskey.length());
	const std::string aesiv = root[AES128IV].asString();
	if (aesiv.length() < 16) {
		printf("read_json_file aes128iv must be at least 16 bytes\n");
		return false;
	}
	json_copy_bytes(m_chatSerCfg.u8AES128iv, aesiv.data(), 16);

	// Optional safety settings. Previously the black/white-IP values were only
	// read when both keys were ABSENT, and from "checkheartbeat" by mistake.
	if (!root[CHECKHEARTBEAT].empty()) {
		m_chatSerCfg.bcheckheartbeat = root[CHECKHEARTBEAT].asBool();
	}
	if (!root[OPENBLACKWHITEIP].empty()) {
		m_chatSerCfg.u8openblackwhiteip = (uint8_t)root[OPENBLACKWHITEIP].asUInt();
	}
	if (!root[SINGLEIPMAXCONN].empty()) {
		m_chatSerCfg.u8singleipmaxconn = (uint8_t)root[SINGLEIPMAXCONN].asUInt();
	}

	return true;
}

8:本文只是帮助分析 WebSocket 协议
红框处的 SSL_accept 需要优化(目前是阻塞式重试,最多会等待约 2 秒),可以考虑用协程(coroutine)或线程回调来处理
在这里插入图片描述
9:后续继续优化,差不多完善后再放出 demo
如果觉得有用,麻烦点个赞,加个收藏

文章来源:https://blog.csdn.net/yunteng521/article/details/134978453
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。