#include "co.h"
#include <setjmp.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Lifecycle of a coroutine as driven by co_start / co_yield / co_wait. */
enum state {
    CREATED = 0,  /* set by co_start(); has never executed */
    RUNNING = 1,  /* the coroutine that is currently executing */
    HALT = 2,     /* suspended; the scheduler may resume it */
    WAIT = 3,     /* blocked inside co_wait() */
    FINISHED = 4, /* entry function has returned */
};
/* Size in bytes of every coroutine's private stack (4 MiB). The whole
 * expansion is parenthesized so it stays a single operand inside larger
 * expressions such as `x % STACK_SIZE`. */
#define STACK_SIZE ((size_t)4 * 1024 * 1024)
/*
 * Per-coroutine control block.
 * NOTE(review): `stack` is the first member (offset 0), so the stack top that
 * co_start() stores in RSP is as aligned as the heap allocation itself; the
 * raw stack switch in change_thread() relies on that — confirm before
 * reordering fields.
 */
struct co {
// Private stack; co_start() points RSP at its end.
char stack[STACK_SIZE];
// Caller-supplied name, stored by pointer (not copied).
const char* name;
// Argument handed to the entry function on its first run.
void* arg;
// Current lifecycle state (see enum state).
enum state s;
// Saved execution context, filled by setjmp() in co_yield() and resumed
// with longjmp() in change_thread().
jmp_buf context;
// Address of the entry function (set by co_start()).
uint64_t RIP;
// Not referenced by the code in this file.
uint64_t RBP;
// Initial stack pointer: the highest address of `stack`.
uint64_t RSP;
// The coroutine this one is blocked on in co_wait(), or NULL.
struct co* wait;
// The coroutine that is waiting on this one (set by co_wait()).
struct co* waited;
};
/*
 * Scheduler state shared by every co_* function.
 */
struct {
    /* Slot 0 always holds the context of the caller of the first
     * co_start() (normally main); slots 1..count hold created coroutines. */
    struct co *coroutine[100];
    int count;            /* highest occupied slot index */
    int cur_idx;          /* slot of the coroutine now executing */
    int cnt;              /* slots whose coroutine has not finished yet */
    uint64_t yield_func;  /* address of finish_co(), set by co_start() */
} coroutines;
/*
 * Create a coroutine that will execute func(arg) on its own private stack.
 *
 * On the first call, a control block for the calling context (normally
 * main) is also allocated and stored in slot 0 so the scheduler can switch
 * back to it. The new coroutine is registered in the next slot in the
 * CREATED state; it first runs when co_yield() (directly, or via co_wait())
 * selects it.
 *
 * name: stored by pointer, not copied.
 * func: entry function; must not be NULL.
 * arg:  passed unchanged to func.
 *
 * Returns the new coroutine, or NULL if func is NULL, an allocation fails,
 * or every slot of coroutines.coroutine[] is already in use.
 */
struct co *co_start(const char *name, void (*func)(void *), void *arg) {
    /* Defined later in this file; declared here so its address can be taken. */
    void finish_co();
    const int capacity =
        (int)(sizeof coroutines.coroutine / sizeof coroutines.coroutine[0]);

    if (func == NULL) return NULL;

    /* First call: create a control block for the caller (main) in slot 0. */
    if (coroutines.coroutine[0] == NULL) {
        /* calloc so wait/waited start out NULL instead of indeterminate. */
        struct co *main_co = calloc(1, sizeof *main_co);
        if (main_co == NULL) return NULL;
        main_co->s = HALT;
        coroutines.yield_func = (uint64_t)finish_co;
        coroutines.coroutine[0] = main_co;
        coroutines.cnt = 1;
    }

    /* Slot 0 belongs to main; user coroutines use slots 1..capacity-1. */
    if (coroutines.count + 1 >= capacity) return NULL;

    struct co *co = calloc(1, sizeof *co);
    if (co == NULL) return NULL;
    co->name = name;
    co->arg = arg;
    co->s = CREATED;
    co->RIP = (uint64_t)func;
    /* x86-64 stacks grow downwards, so execution starts at the top address. */
    co->RSP = (uint64_t)(co->stack + sizeof co->stack);

    coroutines.coroutine[++coroutines.count] = co;
    ++coroutines.cnt;
    return co;
}
/*
 * Block the calling coroutine until `co` has finished.
 *
 * Returns immediately if `co` is NULL, already FINISHED, or is the caller
 * itself (a coroutine waiting on itself could never be woken). Otherwise the
 * caller is marked WAIT and yields; finish_co() sets it back to HALT when
 * `co` ends, after which the scheduler resumes it here.
 *
 * NOTE(review): nothing in this file frees a finished coroutine.
 */
void co_wait(struct co *co) {
    if (co == NULL) return;
    struct co *cur = coroutines.coroutine[coroutines.cur_idx];
    if (co == cur) return;

    /* Re-check after every wake-up so the caller only proceeds once `co`
     * has really finished. */
    while (co->s != FINISHED) {
        co->waited = cur;
        cur->wait = co;
        cur->s = WAIT;
        co_yield();
    }
}
void finish_co(){
struct co *co = coroutines.coroutine[coroutines.cur_idx];
co->s = FINISHED;
--coroutines.cnt;
if(co->waited != NULL){
co->waited->wait = NULL;
co->waited->s = HALT;
}
co_yield();
}
void change_thread(int next_idx){
struct co* current = coroutines.coroutine[coroutines.cur_idx];
struct co* next = coroutines.coroutine[next_idx];
coroutines.cur_idx = next_idx;
if(current->s == RUNNING) current->s = HALT;
int flag = next->s;
next->s = RUNNING;
if(flag == CREATED){
asm volatile(
"movq %0, %%rsp\n\t"
"pushq %3\n\t"
"movq %1, %%rdi\n\t"
"jmp *%2\n\t"
:
: "m"(next->RSP),
"m"(next->arg),
"m"(next->RIP),
"m"(coroutines.yield_func)
);
} else {
longjmp(next->context, 1);
}
}
void co_yield() {
if(!coroutines.cnt) return;
int nxt = coroutines.cur_idx;
struct co* cur_co = coroutines.coroutine[nxt];
if(!setjmp(cur_co->context))
{
// jump to other program
for(nxt = (nxt+1) % (coroutines.count + 1);
nxt <= coroutines.count;
nxt=(nxt+1)%(coroutines.count+1))
{
if(coroutines.coroutine[nxt]->s == HALT
|| coroutines.coroutine[nxt]->s == CREATED)
{
change_thread(nxt);
break;
}
}
}
}
/*
 * Tags: co, cur, coroutine, next, coroutines, 协程 (coroutine), struct
 * Source: https://www.cnblogs.com/INnoVationv2/p/17531836.html
 */