最近总结了一些 OS I/O 多路复用的知识。之前对 I/O Multiplexer 的认知还停留在 select 系统调用,现在是时候扩展一下视野了。

1. 从 Socket 模型开始

Socket 作为一个应用层和传输层间的的抽象,支持网络层 IPv4 / IPv6,以及传输层 TCP / UDP。

双方要进行网络通信前,各自需要创建一个 Socket。

如果是基于 UDP 的套接字:

如果是基于 TCP 的套接字:

以基于 TCP 的套接字为例,首先使用 socket() 创建一个网络协议为 IPv4,以及传输协议为 TCP 的 Socket 结构体,然后使用 bind() 绑定 Server IP 和进程服务端口 port,并监听 listen() 在该端口上(listen 仅改变状态);

之所以需要指定 Server IP,是因为一台机器是可以有多个网卡的,每个网卡都有对应的 IP 地址。Socket 允许指定监听的网卡。0.0.0.0 表示监听所有的 network interfaces;

port 即为传输层信息,对应指定线程的服务。

Server 端 socket 进入监听状态后,调用阻塞函数 accept(),来从内核获取客户端的连接,如果没有客户端连接,则会阻塞等待客户端连接的到来。

如果客户端使用 connect() 发起连接后,双方会进行 TCP 3 次握手。在连接过程中,server 端 OS kernel 会为每个 socket 都维护两个队列:

  • 一个是 “还没完全建立” 连接的队列,称为 TCP 半连接队列(服务端 socket 处于 syn_rcvd 状态);
  • 另一个是 “已经建立” 连接的队列,称为 TCP 全连接队列,这个队列都是完成了三次握手的连接(此时服务端处于 established 状态);

当全连接队列不为空时,内核会拿出一个已连接的 socket(称为 已连接 socket)并响应任意一个阻塞在 accept() 上的服务端线程,此时该服务线程会使用这个已连接的 socket 来响应客户端(一般会新开一个进程/线程/使用其他方案来处理)。

注意,accept() 第一参数是监听 socket 的文件描述符,返回的是已连接 socket 的文件描述符,它们不一样。因为考虑多线程情况,accept 放在一个循环里,这个监听 socket 专门用于接收连接请求。

2. 提升 Socket 服务能力:为什么选择 I/O 多路复用

那么一般情况服务端应该怎么做来处理大量的 socket 连接请求?

前面说过,服务端可以新开一个进程来处理客户端连接,但每次 fork() 创建新进程(包括完整的虚拟内存空间、CPU 寄存器、内核数据结构如文件描述符等等)、进程间上下文切换的时间开销非常大。并且父进程需要通过 wait/waitpid 来回收子进程资源。更重要的是,内存空间资源也不一定足够,这在大量快速并发的场景并不切实际。

如果服务端新开一个线程来处理客户端连接,性能和其他资源紧张情况应该好于多进程实现。同一进程的线程间会共享文件描述符表、页表、所在进程的所有信息、全部的用户态空间等等,因此同进程间线程上下文开销大大减小。

为了应对线程频繁创建和销毁的情况,我们还可以通过维护线程池来缓解这个情况。

但本质上,过多的进程 / 线程最终会把压力交给操作系统。OS 想要同时管理、调度上万个进程/线程,势必会导致 OS 不堪重负(考虑调度)。

在这种场景下我们就需要使用 I/O 多路复用技术,让一个进程能够维护多个 socket。

select & poll

我们最先了解,也是最简单的 I/O 多路复用是 select 方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include <sys/select.h>
/**
* 该函数会阻塞监听指定的 readfds 列表中所有文件,直至列表中任一个文件能够被 read(不是 EOF)
*
* @param[in] maxfd 最大监听的数量(`readfds` 有几个 bit 位)
* @param[in/out] readfds 所有要监听的 file descriptor 列表(每个 bit 代表对对应的 fd 进行监听),最大默认 `FD_SET_SIZE` bits;
* 也是结束监听时能读 fd 的返回值。
* 只需要了解前两个参数,后面 3 个配置参数一般填 NULL
*/
int select(int nfds, fd_set *_Nullable restrict readfds,
fd_set *_Nullable restrict writefds,
fd_set *_Nullable restrict exceptfds,
struct timeval *_Nullable restrict timeout);
/* clear all bits in fdset. */
void FD_ZERO(fd_set *fdset);
/* clear bit fd in fdset */
void FD_CLR(int fd, fd_set *fdset);
/* turn on bit fd in fdset */
void FD_SET(int fd, fd_set *fdset);
/* Is bit fd in fdset on? */
int FD_ISSET(int fd, *fdset);

处理已连接 socket 的线程,将已连接的并且感兴趣的 socket 放到文件描述符集合(FD set,也就是上面的 bitmap)中,然后调用 select 函数将文件描述符集合复制到到内核里,让内核来检查是否有网络事件产生。

内核检查的方式很 naive,就是遍历这个文件描述符集合,当内核发现有网络事件发生后(例如客户端回复),在将对应的 socket 改为可读/可写,把更新状态的文件描述符表再次复制回用户态,用户态再通过遍历方式找到可读/可写的 socket 再进行对应操作。

我们发现 select 有几个问题:

  • 整个过程比较低效(两次遍历、两次复制),涉及多次 kernel 和 user 间的 memory copy 以及上下文切换;
  • 并且访问文件描述符表的时间复杂度是线性的($O(n)$);
  • 由于使用固定大小的 bitmap,受到内核中的 FD_SETSIZE 限制, 默认最大值为 1024,只能监听 0~1023 的文件描述符。

那么 poll 函数呢?

1
2
3
4
5
6
7
#include <poll.h>
struct pollfd {
int fd; /* file descriptor */
short events; /* requested events */
short revents; /* returned events */
};
int poll(struct pollfd *fds, nfds_t nfds, int timeout);

同样是在 fds 中任意一个文件描述符准备完成 / 超时 / 信号打断。只不过 poll 支持精确到事件类别的控制(events/revents)。

它和 select 一样访问模式类似,但是不一样的是,poll 不再用 bitmap 来存储所关注的文件描述符,取而代之用动态数组(以链表形式)来组织,突破了 select 的文件描述符个数限制,当然还会受到系统文件描述符限制。不过仍然是线性访问时间、低效的检查过程。

因此在高并发的情况 selectpoll 的性能还是不足够的。

epoll

这个系统调用是 Unix 专属的,一般情况下它的使用涉及接口:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <sys/epoll.h>
struct epoll_event {
uint32_t events; /* Epoll events */
epoll_data_t data; /* User data variable */
};
union epoll_data {
void *ptr;
int fd;
uint32_t u32;
uint64_t u64;
};
typedef union epoll_data epoll_data_t;

/**
* 创建一个新的 epoll 实例
* @param[in] size 原本用作给内核一个分配数据结构大小的提示。现在已不需要,主要是保持兼容性
* @return 属于新的 epoll 实例的文件描述符。是接下来对 epoll 接口操作指代该 epoll 的符号(epfd)
* @warning 所有 epoll_create 返回的 epfd 都需要手动回收(close())!
*/
int epoll_create(int size);
/**
* epoll 实例的感兴趣的 socket fd 列表维护在 Kernel 中。用户态需要这个控制函数来增添/修改/删除对指定 socket 文件的监听。
* @param[in] epfd 当前 epoll 实例对应的文件描述符
* @param[in] op 可选操作:EPOLL_CTL_ADD / EPOLL_CTL_MOD / EPOLL_CTL_DEL
* @param[in] fd 需要被操作的 socket 文件描述符
* @param[in] event.events 可选事件:EPOLLIN (readable) / EPOLLOUT (writable) / EPOLLRDHUP (peer close conn)
* / EPOLLET (边缘触发。不指定则默认水平触发)
* @return 0 if sucess (otherwise -1 + set errno)
*/
int epoll_ctl(int epfd, int op, int fd,
struct epoll_event *_Nullable event);
/**
* 等待 epoll 实例中指定发生事件类型的可用文件描述符
* @param[out] events 返回当前事件信息和对应的文件描述符列表
* @param[in] maxevents 传入 events buffer 能盛放的最大 epoll_event 结构体的个数
* @return 返回 ready 的文件描述符数量
*/
int epoll_wait(int epfd, struct epoll_event events[.maxevents],
int maxevents, int timeout);

一般使用方法如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
int s = socket(AF_INET, SOCK_STREAM, 0);
bind(s, ...);
listen(s, ...);
int epfd = epoll_create(...);
// 将所有需要监听的 socket 添加到 epfd 中
epoll_ctl(epfd, ...);

while(1) {
int n = epoll_wait(...);
for (接收到数据的 socket) {
// do something
}
}

epoll 相较于 select 和 poll 有重要的优势:

  • epoll 在内核中使用红黑树来跟踪进程所有已注册(通过 epoll_ctl)的文件描述字。

    • 少两次文件描述符 copy,减少内存分配:不需要整体对文件描述符表进行复制(放在内核管理);
    • 管理性能增强:增删改复杂度 $O(\log n)$,一般不需要查找、取出平均复杂度可以到达常数时间(因为事件驱动)!
  • 使用事件驱动机制:

    • 在内核中维护链表记录就绪事件。有网络事件就把 ready sockets 放到 kernel space 的链表中(因此是常数时间的);

    • 调用 epoll_wait 时,如果链表非空,直接复制给 user space 提供的 buffer(抱歉,没法使用共享内存,还是需要 copy);

      1
      2
      3
      4
      5
      6
      7
      // epoll wait 部分源码
      if (revents) {
      if (__put_user(revents, &uevent->events)
      || __put_user(epi->event.data, &uevent->data)) {
      // ...
      }
      }

这大大增强了 epoll API 的并发能力。

epoll 的边缘触发和水平触发(ET & LT)

在学习数字电子电路的时候老师一定和你说过,某些电子元件的触发方式,其中就讨论过边缘触发和水平触发。

  • 水平触发的意思是只要满足事件的条件,比如内核中有数据需要读,就一直不断地把这个事件通知用户(例如保持某个全局 flag 一直有效);
  • 边缘触发的意思是只有第一次满足条件的时候才触发,之后就不会再传递同样的事件了。

epoll_ctl 可以默认使用水平触发,向 event.events 追加 EPOLLET 则表示使用边缘触发。在 epoll 中,考虑一个场景:一个文件描述符上有数据可读(EPOLLIN 触发),线程开始处理数据,而在处理过程中又有新数据加入。那么:

  • 如果是边缘触发:当旧数据开始处理时,文件描述符仍然保持在就绪状态。但当有新的数据写入时,文件描述符会从就绪状态变为未就绪状态,然后再次变为就绪状态,触发一次新的 EPOLLIN 事件

    这种模式下我们应该:使用循环 read 这个 fd 中的内容直至这个 read 返回错误 (errno == EAGAIN) || (errno == EWOULDBLOCK)

    这样可以确保即使在旧数据处理过程中有新的数据写入,应用程序也能及时地得到通知,并读取新的数据。

    考虑一个问题,多线程场景下,使用边缘触发可能有问题:因为存在唤醒多个线程的问题。如果不希望多个线程同时操作 socket,就应该使用 EPOLLONESHOT,表示 one-shot,即特定的 socket fd 事件只会触发一次,然后立即移除。如果获得消息的线程以后还想接收这个 socket 的事件,需要使用 epoll_ctlEPOLL_CTL_MOD 重新注册。

    如果使用边缘触发,则不能使用阻塞 I/O,并且一个信号必须读到不能再读为止(EAGAIN/EWOULDBLOCK),因为:

    如果没有读完所有内容,则会导致下次调用 epoll_wait 时不会再收到之前消息的通知,通知信息会丢失!

    如果使用了阻塞 I/O,那么在没有通知的情况下会永远等待下去!

  • 如果是水平触发:当某个文件描述符上有数据可读,应用程序可以不立即处理完毕该事件。这样,因为当程序下一次调用 epoll_wait 时,epoll_wait 还会向应用程序通知此事件,直到事件被处理完毕。即:如果文件描述符上有数据可读,它的状态码会一直保持就绪状态,直到所有的数据都被读取完毕才会变为未就绪

    这种模式性能会略差于边缘触发。

Example

尝试一下,用 C 写一个简单的 epoll 驱动的 server:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#include <sys/epoll.h>
#include <sys/socket.h>
#include <fcntl.h>
#include <unistd.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <netinet/in.h>

#define MAX_EVENTS 64
#define PORT 8888
#define BUFFER_SIZE 1024

// 设置文件描述符为非阻塞模式
static void set_nonblocking(int fd) {
int flags = fcntl(fd, F_GETFL, 0);
fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}

static void die(const char* msg) {
perror(msg);
exit(EXIT_FAILURE);
}

static size_t mread(int client_fd, char *buf, size_t n) {
size_t total_read = 0;
while (1) {
ssize_t count = recv(client_fd, buf, n, 0);
if (count == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// 数据已读完
if (total_read > 0) {
printf("Received %zd bytes from fd %d: %.*s\n",
total_read, client_fd, (int)total_read, buf);
send(client_fd, "fsck\n", 5, 0);
}
break;
} else {
perror("recv");
close(client_fd);
break;
}
} else if (count == 0) {
printf("Connection closed by client: fd %d\n", client_fd);
close(client_fd);
break;
}
total_read += count;
memcpy(bufferi, buf, count);
}
return total_read;
}

int main() {
int listen_sock = socket(AF_INET, SOCK_STREAM, 0);
if (listen_sock == -1) die("socket");

// 设置地址复用
int optval = 1;
setsockopt(listen_sock, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval));

// 绑定地址
struct sockaddr_in addr = {
.sin_family = AF_INET,
.sin_port = htons(PORT),
.sin_addr.s_addr = INADDR_ANY
};

if (bind(listen_sock, (struct sockaddr*)&addr, sizeof(addr))) die("bind");

// 设置为非阻塞模式
set_nonblocking(listen_sock);

if (listen(listen_sock, SOMAXCONN)) die("listen");

int epoll_fd = epoll_create1(0);
if (epoll_fd == -1) die("epoll_create1");

struct epoll_event event = {
.events = EPOLLIN | EPOLLET, // 边缘触发模式
.data.fd = listen_sock
};
if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, listen_sock, &event)) die("epoll_ctl");

// 事件循环
struct epoll_event events[MAX_EVENTS];
printf("Server started on port %d\n", PORT);


while (1) {
int n = epoll_wait(epoll_fd, events, MAX_EVENTS, -1);
if (n == -1) die("epoll_wait");

for (int i = 0; i < n; i++) {
// 处理新连接
if (events[i].data.fd == listen_sock) {
while (1) { // 必须循环accept直到EAGAIN
struct sockaddr_in client_addr;
socklen_t addrlen = sizeof(client_addr);
int client_fd = accept(listen_sock,
(struct sockaddr*)&client_addr,
&addrlen);
if (client_fd == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
break; // 已接受所有连接
} else {
die("accept");
}
}

printf("New connection: fd %d\n", client_fd);
set_nonblocking(client_fd); // 必须设置为非阻塞

// 注册客户端socket到epoll(边缘触发)
event.events = EPOLLIN | EPOLLET | EPOLLRDHUP;
event.data.fd = client_fd;
if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_fd, &event)) {
close(client_fd);
die("epoll_ctl client");
}
}
}
else {
// 处理客户端事件
int client_fd = events[i].data.fd;

// 处理连接关闭或错误
if (events[i].events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP)) {
printf("Connection closed: fd %d\n", client_fd);
close(client_fd);
continue;
}

// 处理可读事件
if (events[i].events & EPOLLIN) {
char buf[BUFFER_SIZE];
ssize_t total_read = 0;

// 必须循环读取直到EAGAIN
mread(client_fd, buf, BUFFER_SIZE);
}
}
}
}

close(listen_sock);
return 0;
}

以及一个向指定 Server 发送指定二进制数据的 Go 程序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
package main

import (
"bufio"
"bytes"
"encoding/hex"
"fmt"
"net"
"os"
"os/signal"
"strings"
"syscall"
)

func main() {
if len(os.Args) < 3 {
fmt.Println("Usage: go run main.go <host> <port>")
os.Exit(1)
}

host := os.Args[1]
port := os.Args[2]
address := net.JoinHostPort(host, port)

// Connect to the server
conn, err := net.Dial("tcp", address)
if err != nil {
fmt.Printf("Error connecting to server: %v\n", err)
os.Exit(1)
}
defer conn.Close()

fmt.Printf("Connected to %s\n", address)

// Channel to handle graceful shutdown
exitChan := make(chan os.Signal, 1)
signal.Notify(exitChan, os.Interrupt, syscall.SIGTERM)

// Goroutine to listen for server messages
go func() {
reader := bufio.NewReader(conn)
for {
message, err := reader.ReadString('\n')
if err != nil {
fmt.Println("Connection closed by server.")
exitChan <- syscall.SIGTERM
return
}
fmt.Printf("Server: %s", message)
}
}()

// Goroutine to handle user input and send data to server
go func() {
scanner := bufio.NewScanner(os.Stdin)
for {
fmt.Print("Enter data to send (\\x?? for hex bytes): ")
if scanner.Scan() {
input := scanner.Text()
data, err := parseInput(input)
if err != nil {
fmt.Printf("Invalid input: %v\n", err)
continue
}
_, err = conn.Write(data)
if err != nil {
fmt.Printf("Error sending data: %v\n", err)
exitChan <- syscall.SIGTERM
return
}
} else {
fmt.Println("Input closed.")
exitChan <- syscall.SIGTERM
return
}
}
}()

// Wait for interrupt signal
<-exitChan
fmt.Println("Exiting program.")
}

// parseInput converts user input with \x?? hex sequences into a byte slice
func parseInput(input string) ([]byte, error) {
var buffer bytes.Buffer

parts := strings.Split(input, "\\x")
buffer.WriteString(parts[0]) // Add any text before the first \x

for i := 1; i < len(parts); i++ {
if len(parts[i]) < 2 {
return nil, fmt.Errorf("incomplete hex byte: \\x%s", parts[i])
}
hexByte := parts[i][:2]
rest := parts[i][2:]

b, err := hex.DecodeString(hexByte)
if err != nil {
return nil, fmt.Errorf("invalid hex byte: \\x%s", hexByte)
}

buffer.Write(b)
buffer.WriteString(rest) // Add any remaining text after the hex byte
}

return buffer.Bytes(), nil
}