概述
在作者给的聊天室程序代码中有着这样一段:
else if( fds[i].revents & POLLIN )
{
if( ret < 0 )
{
...
}
else if( ret == 0 )
{
...
}
else {
for( int j = 1; j <= user_counter; ++j )
{
if( fds[j].fd == connfd )
{
continue;
}
fds[j].events |= ~POLLIN; // 取消该套接字POLLIN,但是你觉得真的是这么写吗?
fds[j].events |= POLLOUT;
users[fds[j].fd].write_buf = users[connfd].buf;
}
}
可以看到作者是想临时取消掉除信息发送者的以外的套接字上的读事件然后并注册写事件以便将信息转发给其他的套接字,但是这首先会造成一个问题:在转发过程中如果其他套接字有信息传来,因为POLLIN被取消掉了,所以该消息会被忽略,这是问题1。
接下来的问题才是我想讲的:fds[j].events |= ~POLLIN;
是正确的取消事件的方式吗?
下面来逐步讲解:
在有新的连接到来时,该连接套接字的事件会被设置成这样:
fds[user_counter].events = POLLIN | POLLRDHUP | POLLERR;
十进制:8201 二进制:00000000000000000010000000001001
我在代码中列出了它的二进制,那向上面的那段代码一样,如果我们要取消掉POLLIN
事件呢,是下面这样吗?
fds[j].events |= ~POLLIN; // 是这样吗?
我们可以知道POLLIN
是#define POLLIN 0x001
,因此我们可以使用8201 | ~1
来模拟fds[j].events |= ~POLLIN;
于是按照作者的代码试着取消掉POLLIN
事件:
printf("%dn", 8201 | ~1 ); // 竟然是-1!
进行按位或运算|:
00000000000000000010000000001001 ( a = 8201)
11111111111111111111111111111110 (~1 = -2 )
——————————————————————————————————————
11111111111111111111111111111111 (-1)
这肯定不是我们想要的结果
我们想要的是:
00000000000000000010000000001000
也就是说fds[j].events |= ~POLLIN;
后fds[j].events
为-1,其二进制是全1,可想而知接下来fds[j].events |= POLLOUT;
也必然的得到-1,因此我们发现fds[j].events |= ~POLLIN;
这一句是错误的写法。
因此在取消某个事件中,我们需要将|
换成&
:
fds[j].events &= ~POLLIN; // 取消该套接字POLLIN,这才是正确的写法
fds[j].events |= POLLOUT;
相应地在另一处代码应改成
else if( fds[i].revents & POLLOUT )
{
int connfd = fds[i].fd;
if( users[connfd].write_buf == NULL ) {
continue;
}
ret = send( connfd, users[connfd].write_buf, strlen( users[connfd].write_buf ), 0 );
users[connfd].write_buf = NULL;
fds[i].events &= ~POLLOUT;
fds[i].events |= POLLIN;
}
将读事件的else if判断前置:
#define _GNU_SOURCE 1
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <assert.h>
#include <stdio.h>
#include <unistd.h>
#include <errno.h>
#include <string.h>
#include <fcntl.h>
#include <stdlib.h>
#include <poll.h>
#include <libgen.h> // basename
#define USER_LIMIT 5
#define BUFFER_SIZE 64
#define FD_LIMIT 65535
struct client_data
{
sockaddr_in address;
char* write_buf;
char buf[ BUFFER_SIZE ];
};
int setnonblocking( int fd )
{
int old_option = fcntl( fd, F_GETFL );
int new_option = old_option | O_NONBLOCK;
fcntl( fd, F_SETFL, new_option );
return old_option;
}
int main( int argc, char* argv[] )
{
int ret = 0;
struct sockaddr_in address;
bzero( &address, sizeof( address ) );
address.sin_family = AF_INET;
address.sin_addr.s_addr=htonl(INADDR_ANY);
address.sin_port = htons( atoi( "9190" ) );
int listenfd = socket( PF_INET, SOCK_STREAM, 0 ); assert( listenfd != -1 );
ret = bind( listenfd, ( struct sockaddr* )&address, sizeof( address ) ); assert( ret != -1 );
ret = listen( listenfd, 5 ); assert( ret != -1 );
// 初始化pollfd类型的fds
pollfd fds[USER_LIMIT+1];
for( int i = 1; i <= USER_LIMIT; ++i ) {
fds[i].fd = -1;
fds[i].events = 0;
}
int user_counter = 0;
// 添加监听套接字的事件
fds[user_counter].fd = listenfd;
fds[user_counter].events = POLLIN | POLLERR;
fds[user_counter].revents = 0;
client_data* users = new client_data[FD_LIMIT];
while(1)
{
ret = poll( fds, user_counter+1, -1 );
if ( ret == -1 ) {
printf( "poll failuren" );
break;
}
// 遍历所有套接字(包括监听套接字,所有要+1)
for( int i = 0; i < user_counter+1; ++i )
{
//! 有新的连接
if( ( fds[i].fd == listenfd ) && ( fds[i].revents & POLLIN ) )
{
struct sockaddr_in client_address;
socklen_t client_addrlength = sizeof( client_address );
int connfd = accept( listenfd, ( struct sockaddr* )&client_address, &client_addrlength );
if ( connfd == -1 )
{
printf( "errno is: %dn", errno );
continue;
}
// 已连接的用户数大于最大用户数,则不允许连接
if( user_counter >= USER_LIMIT )
{
const char* info = "too many usersn";
printf( "%s", info );
send( connfd, info, strlen( info ), 0 );
close( connfd );
continue;
}
user_counter++;
users[connfd].address = client_address;
setnonblocking( connfd ); // 套接字使用非阻塞IO
fds[user_counter].fd = connfd;
fds[user_counter].events = POLLIN | POLLRDHUP | POLLERR;
fds[user_counter].revents = 0;
printf( "comes a new user, now have %d usersn", user_counter );
}
//! 错误事件
else if( fds[i].revents & POLLERR )
{
printf( "get an error from %dn", fds[i].fd );
char errors[ 100 ];
memset( errors, '