Administrator
发布于 2026-01-06 / 7 阅读
0
0

IO多路复用-select

select 服务端

#include <iostream>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/select.h>
#include <cstring>
#include <vector>
#include <cerrno>
#include <stdexcept> // 为了std::runtime_error
#include <string> // 为了std::string

/// Single-threaded TCP echo server built on select(2).
///
/// Listens on an IPv4 port (all interfaces), greets each new client, echoes
/// every received chunk back prefixed with "Echo: ", and drops a client when
/// it sends "exit\n" / "exit\r\n", closes its end, or a socket error occurs.
///
/// The object owns the listening socket and every client socket and closes
/// them in its destructor, so it is deliberately non-copyable.
class SelectServer {
private:
    int server_fd;                // listening socket
    int max_fd;                   // highest fd in master_fds; select() needs max_fd + 1
    fd_set read_fds;              // scratch set handed to select() (overwritten by each call)
    fd_set master_fds;            // authoritative set of every fd being watched
    std::vector<int> client_fds;  // connected client sockets, in accept order

public:
    /// Creates, configures (SO_REUSEADDR), binds and listens on a TCP socket.
    /// @param port  TCP port to bind on INADDR_ANY.
    /// @throws std::runtime_error if any setup step fails; the message includes
    ///         strerror(errno) and the socket is closed before throwing.
    explicit SelectServer(int port) {
        server_fd = socket(AF_INET, SOCK_STREAM, 0);
        if (server_fd < 0) {
            throw std::runtime_error(std::string("Socket creation failed: ") + strerror(errno));
        }
        // FD_SET/FD_ISSET on an fd >= FD_SETSIZE is undefined behaviour.
        if (server_fd >= FD_SETSIZE) {
            close(server_fd);
            throw std::runtime_error("Socket creation failed: fd exceeds FD_SETSIZE");
        }

        // Set SO_REUSEADDR so a restarted server can rebind without "address in use".
        int opt = 1;
        if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
            failSetup("Setsockopt failed");
        }

        struct sockaddr_in address{};
        address.sin_family = AF_INET;
        address.sin_addr.s_addr = INADDR_ANY;
        address.sin_port = htons(port);

        if (bind(server_fd, reinterpret_cast<struct sockaddr*>(&address), sizeof(address)) < 0) {
            failSetup("Bind failed");
        }

        if (listen(server_fd, 5) < 0) {
            failSetup("Listen failed");
        }

        // Initially only the listening socket is watched.
        FD_ZERO(&master_fds);
        FD_SET(server_fd, &master_fds);
        max_fd = server_fd;

        std::cout << "Server started on port " << port << std::endl;
    }

    /// Closes every client socket and the listening socket.
    ~SelectServer() {
        for (int client_fd : client_fds) {
            close(client_fd);
        }
        close(server_fd);
        std::cout << "Server stopped" << std::endl;
    }

    // Owns raw file descriptors: a copy would close them a second time.
    SelectServer(const SelectServer&) = delete;
    SelectServer& operator=(const SelectServer&) = delete;

    /// Event loop; never returns. Waits up to 5 s per select() round, then
    /// accepts pending connections and services readable clients.
    void run() {
        while (true) {
            read_fds = master_fds;  // select() overwrites the set, so re-arm it every round

            struct timeval timeout{};
            timeout.tv_sec = 5;     // 5 second timeout
            timeout.tv_usec = 0;

            int activity = select(max_fd + 1, &read_fds, nullptr, nullptr, &timeout);

            if (activity < 0) {
                // On failure select() leaves the sets unmodified (still a full
                // copy of master_fds). Falling through would treat every fd as
                // readable and block in accept()/recv(), so always retry; only
                // report errors other than a signal interruption.
                if (errno != EINTR) {
                    std::cerr << "Select error: " << strerror(errno) << std::endl;
                }
                continue;
            }

            if (activity == 0) {
                // Timeout: hook for periodic work.
                // std::cout << "Select timeout, checking other tasks..." << std::endl;
                continue;
            }

            if (FD_ISSET(server_fd, &read_fds)) {
                handleNewConnection();
            }

            checkClientData();
        }
    }

private:
    /// Captures errno, closes the listening socket and throws
    /// std::runtime_error("<what>: <strerror>"). Used only during construction.
    [[noreturn]] void failSetup(const char* what) {
        const int saved_errno = errno;  // close() may overwrite errno
        close(server_fd);
        throw std::runtime_error(std::string(what) + ": " + strerror(saved_errno));
    }

    /// Sends all `len` bytes of `data`, retrying on short writes and EINTR.
    /// Where MSG_NOSIGNAL exists, a vanished peer yields EPIPE instead of a
    /// process-killing SIGPIPE.
    /// @return true if every byte was sent; false on error (errno is set).
    static bool sendAll(int fd, const char* data, size_t len) {
#ifdef MSG_NOSIGNAL
        const int flags = MSG_NOSIGNAL;
#else
        const int flags = 0;
#endif
        while (len > 0) {
            ssize_t sent = send(fd, data, len, flags);
            if (sent < 0) {
                if (errno == EINTR) {
                    continue;
                }
                return false;
            }
            data += sent;
            len -= static_cast<size_t>(sent);
        }
        return true;
    }

    /// Accepts one pending connection, greets it and starts watching it.
    /// Connections that select() cannot track, or whose greeting fails to
    /// send, are closed immediately instead of being registered.
    void handleNewConnection() {
        struct sockaddr_in client_addr{};
        socklen_t addr_len = sizeof(client_addr);

        int client_fd = accept(server_fd, reinterpret_cast<struct sockaddr*>(&client_addr), &addr_len);

        if (client_fd < 0) {
            std::cerr << "Accept failed: " << strerror(errno) << std::endl;
            return;
        }

        // fd_set is a fixed-size bitmap; FD_SET on an fd >= FD_SETSIZE is UB.
        if (client_fd >= FD_SETSIZE) {
            std::cerr << "Rejecting fd " << client_fd << ": exceeds FD_SETSIZE" << std::endl;
            close(client_fd);
            return;
        }

#ifdef SO_NOSIGPIPE
        // Platforms without MSG_NOSIGNAL (e.g. BSD/macOS) suppress SIGPIPE per socket.
        int one = 1;
        setsockopt(client_fd, SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(one));
#endif

        // Optional: make the client socket non-blocking.
        // fcntl(client_fd, F_SETFL, O_NONBLOCK);

        std::cout << "New connection from "
                  << inet_ntoa(client_addr.sin_addr) << ":"
                  << ntohs(client_addr.sin_port)
                  << " (fd: " << client_fd << ")" << std::endl;

        const char* welcome_msg = "Welcome to select server!\n";
        if (!sendAll(client_fd, welcome_msg, strlen(welcome_msg))) {
            std::cerr << "Send error to fd " << client_fd
                      << ": " << strerror(errno) << std::endl;
            close(client_fd);
            return;
        }

        // Start watching the client.
        FD_SET(client_fd, &master_fds);
        client_fds.push_back(client_fd);
        if (client_fd > max_fd) {
            max_fd = client_fd;
        }
    }

    /// Reads from every client that select() reported readable and echoes the
    /// data back. Closes clients on EOF, hard errors, failed echo, or "exit".
    /// NOTE(review): "exit" is matched per recv() chunk; TCP may split or
    /// coalesce it with other data.
    void checkClientData() {
        // Iterator loop because closeClient() erases the current element.
        auto it = client_fds.begin();
        while (it != client_fds.end()) {
            int client_fd = *it;

            if (FD_ISSET(client_fd, &read_fds)) {
                char buffer[1024];
                ssize_t bytes_read = recv(client_fd, buffer, sizeof(buffer), 0);

                if (bytes_read > 0) {
                    // Build from the byte count, not strlen: data may contain NULs.
                    std::string payload(buffer, static_cast<size_t>(bytes_read));
                    std::cout << "Received from fd " << client_fd << ": " << payload;

                    std::string response = "Echo: " + payload;
                    if (!sendAll(client_fd, response.data(), response.size())) {
                        std::cerr << "Send error to fd " << client_fd
                                  << ": " << strerror(errno) << std::endl;
                        closeClient(client_fd, it);
                        continue;  // iterator already advanced by erase
                    }

                    if (payload == "exit\n" || payload == "exit\r\n") {
                        std::cout << "Client " << client_fd << " requested to disconnect" << std::endl;
                        closeClient(client_fd, it);
                        continue;  // iterator already advanced by erase
                    }
                }
                else if (bytes_read == 0) {
                    // Orderly shutdown by the peer.
                    std::cout << "Client " << client_fd << " disconnected" << std::endl;
                    closeClient(client_fd, it);
                    continue;  // iterator already advanced by erase
                }
                else if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
                    // Hard error; EINTR/EAGAIN are transient and retried next round.
                    std::cerr << "Recv error from fd " << client_fd
                              << ": " << strerror(errno) << std::endl;
                    closeClient(client_fd, it);
                    continue;  // iterator already advanced by erase
                }
            }
            ++it;
        }
    }

    /// Closes `client_fd`, stops watching it, erases it from client_fds
    /// (leaving `it` on the following element) and recomputes max_fd if needed.
    void closeClient(int client_fd, std::vector<int>::iterator& it) {
        close(client_fd);
        FD_CLR(client_fd, &master_fds);
        it = client_fds.erase(it);

        if (client_fd == max_fd) {
            max_fd = server_fd;
            for (int fd : client_fds) {
                if (fd > max_fd) {
                    max_fd = fd;
                }
            }
        }
    }
};

/// Entry point: `prog [port]` runs the select() echo server on `port`
/// (default 8888) until the process is killed.
/// @return 1 if the port argument is invalid or server setup fails.
int main(int argc, char* argv[]) {
    int port = 8888;

    if (argc > 1) {
        // Validate strictly:
        // - std::stoi throws on non-numeric input.
        // - std::stoi silently accepts trailing junk such as "80abc".
        // - htons() would silently truncate values outside 1-65535.
        const std::string arg = argv[1];
        std::size_t consumed = 0;
        try {
            port = std::stoi(arg, &consumed);
        } catch (const std::exception&) {
            consumed = 0;  // invalid_argument / out_of_range: rejected below
        }
        if (consumed == 0 || consumed != arg.size() || port < 1 || port > 65535) {
            std::cerr << "Error: invalid port '" << arg << "' (expected 1-65535)" << std::endl;
            return 1;
        }
    }

    try {
        SelectServer server(port);
        server.run();
    }
    catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}


评论