Linux epoll 与 C++ 协程

简介

本文使用 C++20 引入的协程来编写一个 Linux epoll 程序。在此实现中,用户使用异步操作时再也无需提供自己的回调函数。以此处实现的 asyncRead() 为例:

  • 使用 asyncRead() 所需的参数和 read() 大致相同,无需传入回调;
  • co_await asyncRead() 时,asyncRead() 返回的 Awaitable 会在 await_suspend() 中向 epoll 注册要监听的文件描述符、感兴趣的事件和要执行的回调(回调由实现提供,而无需使用者传入);
  • 由于 await_ready() 恒返回 false,co_await asyncRead() 总会先挂起当前协程,等待事件就绪;
  • 当事件就绪时,epoll 循环会把该回调提交到 I/O 线程池中执行:回调先完成具体的 I/O 操作,再在同一线程上恢复协程的运行。

1. ThreadPool

此处使用了两个线程池:

  • I/O 线程池:用于执行 I/O 操作;
  • 任务线程池:用于处理客户端连接(此处以 tcp 回显程序为例)。

此处使用的是自己实现的线程池,具体实现见 https://segmentfault.com/a/11...

2. IOContext

IOContext 类对 Linux epoll 做了简单的封装。

io_context.h:

#ifndef IOCONTEXT_H
#define IOCONTEXT_H

#include <cerrno>
#include <functional>
#include <mutex>
#include <sys/epoll.h>
#include <unordered_map>
#include <vector>
#include "thread_pool.h"

// Callback that IOContext runs (on an I/O-pool thread) once a watched fd is ready.
using callback_t = std::function<void()>;

// Registration data that IOContext::post() stores per fd and run() consumes.
struct Args
{
    callback_t m_cb{};  // work to perform once the fd is ready
};

class IOContext
{
public:
    IOContext(int nIOThreads=2, int nJobThreads=2);

    // 监听文件描述符 fd,感兴趣事件为 events,args 中包含要执行的回调
    bool post(int fd, int events, const Args& args);

    // 提交任务至任务线程池
    bool post(const Task& task);

    // 不再关注文件描述符 fd,并移除相应的回调
    void remove(int fd);

    // 持续监听、等待事件就绪
    void run();

private:
    int m_fd;
    std::unordered_map m_args;
    std::mutex m_lock;
    ThreadPool m_ioPool;     // I/O 线程池
    ThreadPool m_jobPool;    // 任务线程池
};

#endif

io_context.cpp:

#include "io_context.h"
#include 
#include 
#include 

// Create the epoll instance and start both thread pools.
// nIOThreads: threads that run readiness callbacks; nJobThreads: threads for post(task).
IOContext::IOContext(int nIOThreads, int nJobThreads)
    : m_ioPool(nIOThreads), m_jobPool(nJobThreads)
{
    // epoll_create's size hint is obsolete; epoll_create1 with EPOLL_CLOEXEC also
    // keeps the epoll fd from leaking into exec'd child processes.
    m_fd = epoll_create1(EPOLL_CLOEXEC);
    if (m_fd < 0)
    {
        // Without an epoll fd every later post()/run() fails; report it up front
        // instead of failing silently.
        std::cout << "IOContext: epoll_create1: " << strerror(errno) << "\n";
    }
}

bool IOContext::post(int fd, int events, const Args& args)
{
    struct epoll_event event;
    event.events = events;
    event.data.fd = fd;

    std::lock_guard lock(m_lock);

    int err = epoll_ctl(m_fd, EPOLL_CTL_ADD, fd, &event);
    if (err == 0)
    {
        m_args[fd] = args;
    }

    return err == 0;
}

// Hand a task to the job thread pool and report whether the pool accepted it.
bool IOContext::post(const Task& task)
{
    const bool accepted = m_jobPool.submitTask(task);
    return accepted;
}

void IOContext::remove(int fd)
{
    std::lock_guard lock(m_lock);
    int err = epoll_ctl(m_fd, EPOLL_CTL_DEL, fd, nullptr);
    if (err == 0)
    {
        m_args.erase(fd);
    }
    else
    {
        std::cout << "remove: " << strerror(errno) << "\n";
    }
}

void IOContext::run()
{
    int timeout = -1;
    size_t nEvents = 32;
    struct epoll_event* eventList = new struct epoll_event[nEvents];

    while (true)
    {
        int nReady = epoll_wait(m_fd, eventList, nEvents, timeout);
        if (nReady < 0)
        {
            delete []eventList;
            return;
        }

        for (int i = 0; i < nReady; i++)
        {
            int fd = eventList[i].data.fd;

            m_lock.lock();
            auto cb = m_args[fd].m_cb;
            m_lock.unlock();

            remove(fd);

            m_ioPool.submitTask([=]()
            {
                cb();
            });
        }
    }
}

3. Awaitable

实现 C++ 协程所需的类型,详细解释见 https://segmentfault.com/a/11...

awaitable.h:

#ifndef AWAITABLE_H
#define AWAITABLE_H

#include 
#include 
#include 
#include 
#include "io_context.h"

// Operation the readiness callback performs: read, write, or accept a client connection.
enum class HandlerType { Read, Write, Accept };

// Awaiter produced by asyncRead/asyncWrite/asyncAccept.
//
// co_await always suspends. await_suspend() registers m_fd with the IOContext;
// the registered callback performs the operation selected by m_ht, stores its
// return value in m_result and resumes the coroutine, which then receives that
// value from await_resume().
class Awaitable
{
public:
    // ctx: event loop to register with; fd/events: what to wait for;
    // buf/n: I/O buffer and length (unused for Accept); ht: operation to run.
    Awaitable(IOContext* ctx, int fd, int events, void* buf, size_t n, HandlerType ht);

    bool await_ready();
    void await_suspend(std::coroutine_handle<> handle);
    // Return value of read()/write()/accept(); -1 signals failure.
    int await_resume();

private:
    IOContext* m_ctx;
    int m_fd;
    int m_events;
    void* m_buf;
    size_t m_n;
    int m_result = -1;   // well-defined even if no operation ever stored a result
    HandlerType m_ht;
};


// Return type for fire-and-forget coroutines. The promise never suspends at the
// start or at the end, so the coroutine runs eagerly when called and its frame
// is destroyed automatically when the body finishes; callers get no handle.
struct CoroRetType
{
    struct promise_type
    {
        CoroRetType get_return_object();
        std::suspend_never initial_suspend();
        std::suspend_never final_suspend() noexcept;
        void return_void();
        void unhandled_exception();
    };
};
#endif

awaitable.cpp:

#include 
#include "awaitable.h"

Awaitable::Awaitable(IOContext *ctx, int fd, int events, void* buf, size_t n, HandlerType ht)
        : m_ctx(ctx), m_fd(fd), m_events(events), m_buf(buf), m_n(n), m_ht(ht)
{}

bool Awaitable::await_ready()
{
    return false;
}

int Awaitable::await_resume()
{
    return m_result;
}

// Called after the coroutine has suspended. Builds the readiness callback and
// registers m_fd, the events and that callback with the IOContext. The callback
// runs on an I/O-pool thread: it performs the operation selected by m_ht,
// stores the result for await_resume(), and resumes the coroutine on that thread.
void Awaitable::await_suspend(std::coroutine_handle<> handle)
{
    auto cb = [handle, this]() mutable
    {
        switch (m_ht)
        {
            case HandlerType::Read:
                m_result = read(m_fd, m_buf, m_n);
                break;
            case HandlerType::Write:
                m_result = write(m_fd, m_buf, m_n);
                break;
            case HandlerType::Accept:
                m_result = accept(m_fd, nullptr, nullptr);
                break;
        }

        handle.resume();
    };

    Args args{cb};
    if (!m_ctx->post(m_fd, m_events, args))
    {
        // Registration failed, so the callback will never run. Instead of leaving
        // the coroutine suspended forever, resume it now; the awaiting code sees
        // -1 (errno should still hold epoll_ctl's error).
        m_result = -1;
        handle.resume();
        // *this lives in the coroutine frame, which may already be destroyed:
        // return without touching any member.
        return;
    }
    // On success the callback may already be running on another thread and may
    // even have destroyed the frame (and *this): do not touch members here.
}


// Fire-and-forget: the returned object carries no coroutine handle.
CoroRetType CoroRetType::promise_type::get_return_object()
{
    return CoroRetType{};
}

// Start running the body immediately when the coroutine is called.
std::suspend_never CoroRetType::promise_type::initial_suspend()
{
    return {};
}

// Do not park at the end; the frame is freed as soon as the body completes.
std::suspend_never CoroRetType::promise_type::final_suspend() noexcept
{
    return {};
}

// Nothing to record when the body finishes.
void CoroRetType::promise_type::return_void()
{
}

// No caller can observe an exception from a detached coroutine: abort.
void CoroRetType::promise_type::unhandled_exception()
{
    std::terminate();
}

4. 异步操作

使用协程来封装异步操作。

io_util.h:

#ifndef IO_UTIL_H
#define IO_UTIL_H

#include "io_context.h"
#include "awaitable.h"

// Awaiter factories: co_await the result inside a coroutine. When awaited, each
// registers fd with ctx and yields the return value of the underlying call.

// Wait until fd is readable (EPOLLIN), then read() up to n bytes into buf.
Awaitable asyncRead(IOContext* ctx, int fd, void* buf, size_t n);
// Wait until fd is writable (EPOLLOUT), then write() up to n bytes from buf.
Awaitable asyncWrite(IOContext* ctx, int fd, void* buf, size_t n);
// Wait until the listening socket fd is readable (EPOLLIN), then accept() a connection.
Awaitable asyncAccept(IOContext* ctx, int fd);

#endif

io_util.cpp:

#include "io_util.h"

// Awaiter that waits for EPOLLIN on fd, then read()s up to n bytes into buf.
Awaitable asyncRead(IOContext* ctx, int fd, void* buf, size_t n)
{
    return {ctx, fd, EPOLLIN, buf, n, HandlerType::Read};
}

// Awaiter that waits for EPOLLOUT on fd, then write()s up to n bytes from buf.
Awaitable asyncWrite(IOContext* ctx, int fd, void* buf, size_t n)
{
    return {ctx, fd, EPOLLOUT, buf, n, HandlerType::Write};
}

// Awaiter that waits for EPOLLIN on the listening socket fd, then accept()s a
// connection. No buffer is involved.
Awaitable asyncAccept(IOContext* ctx, int fd)
{
    return {ctx, fd, EPOLLIN, nullptr, 0, HandlerType::Accept};
}

5. 例子

main.cpp:

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include "io_util.h"

// Serializes console output from the client threads.
static std::mutex ioLock;
// Port the echo server listens on and the clients connect to.
static constexpr uint16_t port = 6666;
// Pending-connection queue length passed to listen().
static constexpr int backlog = 32;
// Message each client sends. MsgLen is derived from it so the two cannot drift apart.
static constexpr char Msg[] = "hello, cpp!";
static constexpr size_t MsgLen = sizeof(Msg) - 1;
// Process-wide event loop used by the server coroutines; main() drives it with run().
static IOContext ioContext;

// Echo one message on the connected socket fd: read up to MsgLen bytes, write
// back whatever was read, then close fd. Runs as a fire-and-forget coroutine.
CoroRetType handleConnection(int fd)
{
    char buf[MsgLen + 1] = {0};

    int n = co_await asyncRead(&ioContext, fd, buf, MsgLen);

    // n <= 0 means EOF or an error: there is nothing to echo, and passing a
    // negative n on would become a huge size_t length.
    if (n > 0)
    {
        buf[n] = '\0';   // n <= MsgLen, so this index stays inside buf
        co_await asyncWrite(&ioContext, fd, buf, static_cast<size_t>(n));
    }

    close(fd);
}

// Accept loop of the echo server. Listens on 0.0.0.0:port, then awaits
// connections and hands each accepted socket to handleConnection() on the job
// pool. Gives up (closing the listening socket) if setup fails or accept
// reports a non-transient error.
CoroRetType serverThread()
{
    int listenSock = socket(AF_INET, SOCK_STREAM, 0);
    if (listenSock < 0)
    {
        std::cerr << "serverThread: socket: " << strerror(errno) << '\n';
        co_return;
    }

    // SO_REUSEADDR lets the demo be restarted right away on the same port.
    int value = 1;
    setsockopt(listenSock, SOL_SOCKET, SO_REUSEADDR, &value, sizeof(int));

    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_port = htons(port);
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = htonl(INADDR_ANY);

    if (bind(listenSock, reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)) < 0 ||
        listen(listenSock, backlog) < 0)
    {
        std::cerr << "serverThread: bind/listen: " << strerror(errno) << '\n';
        close(listenSock);
        co_return;
    }

    while (true)
    {
        int clientSock = co_await asyncAccept(&ioContext, listenSock);
        if (clientSock < 0)
        {
            // The callback calls accept() and resumes us on the same thread, so
            // errno still describes the failure.
            if (errno == EINTR || errno == ECONNABORTED ||
                errno == EAGAIN || errno == EWOULDBLOCK)
                continue;   // transient: keep accepting
            std::cerr << "serverThread: accept: " << strerror(errno) << '\n';
            break;
        }

        auto h = [=]()
        {
            handleConnection(clientSock);
        };

        // If the job pool refuses the task, nobody will close the socket.
        if (!ioContext.post(h))
            close(clientSock);
    }

    close(listenSock);
}

void clientThread()
{
    using namespace std::literals;

    std::this_thread::sleep_for(1s);

    int sock = socket(AF_INET, SOCK_STREAM, 0);

    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_port = htons(port);
    addr.sin_family = AF_INET;
    inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr);

    connect(sock, (const struct sockaddr*)&addr, sizeof(addr));

    char buf[MsgLen+1] = {0};

    ssize_t n = write(sock, Msg, MsgLen);

    read(sock, buf, n);
    buf[n+1] = '\0';

    std::lock_guard lock(ioLock);
    std::cout << "clientThread: " << buf << '\n';

    close(sock);
}

// Start the accept loop, launch N detached test clients, then drive the event
// loop on this thread. run() returns only if epoll_wait fails.
int main()
{
    // serverThread() runs eagerly until it first suspends (awaiting a
    // connection), then control returns here.
    serverThread();

    constexpr int N = 10;
    for (int launched = 0; launched < N; ++launched)
        std::thread(clientThread).detach();

    ioContext.run();
}
运行结果:

clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!

你可能感兴趣的