#include "tra.h"

/*
 * Simple threads for C.  Each thread effectively runs on its
 * own private copy of the main system stack, so stack pointers
 * are only meaningful to the currently executing thread.
 *
 * There is only one thread run queue.  While there may be 
 * helper I/O procs, threads only run in the main proc.
 */

void
ABORT(void)
{
	*(uchar*)0 = 0;
}

/* ----------------- */
/*
 * Set by setlabel() in main; only its saved stack pointer is used.
 * It marks the top of every thread's stack image: swtch0 saves and
 * swtch2 restores the bytes from a thread's saved sp up to
 * mainlabel.sp.
 */
Label mainlabel;

/* helpers for switchctxt; none of them returns to its caller */
static void swtch0(Ctxt*, Ctxt*);
static void swtch1(Ctxt*);
static void swtch2(Ctxt*);

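/*
 * The ABORT() calls placed after calls that never return keep
 * optimizing compilers from turning those calls into tail calls,
 * which would reuse the caller's stack frame (swtch1 relies on each
 * nested call consuming more stack).
 */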

/*
 * Switch from the thread whose context is cur to the one whose
 * context is nxt.  Returns to the caller only when a later switch
 * resumes cur.
 */
static void
switchctxt(Ctxt *cur, Ctxt *nxt)
{
	/*
	 * save the current stack position and registers
	 * the first time through.  swtch0 won't return.
	 * When a later swtch2 does gotolabel(&cur->label), execution
	 * presumably resumes here with setlabel returning nonzero
	 * (setjmp-like; confirm against setlabel), on cur's restored
	 * stack, and we simply return to the caller.
	 */
	if(setlabel(&cur->label)==0){
		swtch0(cur, nxt);
		ABORT();
	}
}

/*
 * Copy cur's live stack into a freshly allocated buffer cur->stk,
 * then continue the switch in swtch1.  Does not return.
 * The image spans from cur's saved sp (lo) up to mainlabel.sp (hi),
 * which assumes a downward-growing stack.
 */
static void
swtch0(Ctxt *cur, Ctxt *nxt)
{
	uchar *lo, *hi;
	int n;

	/* save the stack for the current thread */
	lo = (uchar*)cur->label.sp;
	hi = (uchar*)mainlabel.sp;
	n = hi-lo;
	/* freed by swtch2 when cur is switched back in, or by schedthread if cur exited */
	cur->stk = emalloc(n);
	memmove(cur->stk, lo, n);

	/* switch to the next thread */
	swtch1(nxt);
	ABORT();
}

/*
 * Recurse, consuming stack 64 bytes at a time, until this frame lies
 * below nxt's saved sp.  swtch2, called from the deepest frame, then
 * overwrites only memory above the frames it is running on.
 * Does not return.  Assumes a downward-growing stack.
 */
static void
swtch1(Ctxt *nxt)
{
	uchar buf[64];	/* take up stack space */

	/* recurse until we're out of the area the new stack will occupy. */
	if(nxt->label.sp < (intptr)buf)
		swtch1(nxt);

	/* do the copy */
	swtch2(nxt);
	ABORT();
}

/*
 * Copy nxt's saved stack image back into place (from nxt's saved sp
 * up to mainlabel.sp), free the saved copy, and jump to nxt's label.
 * Does not return.  Must run below the target region (see swtch1).
 * NOTE(review): assumes nxt->stk holds an image of exactly hi-lo bytes;
 * for a never-run thread this presumably comes from initctxt -- confirm.
 */
static void
swtch2(Ctxt *nxt)
{
	uchar *lo, *hi;
	int n;

	/* put the next stack in place */
	lo = (uchar*)nxt->label.sp;
	hi = (uchar*)mainlabel.sp;
	n = hi-lo;
	memmove(lo, nxt->stk, n);

	free(nxt->stk);
	nxt->stk = nil;

	/* hop to it */
	gotolabel(&nxt->label);
	ABORT();
}

/* ----------------- */
/* One cooperative thread. */
struct Thread
{
	Ctxt ctxt;		/* saved registers/label and stack image while not running */
	void (*fn)(void*);	/* entry function, called by threadlaunch */
	void *arg;		/* argument passed to fn */
	int moribund;		/* set by threadexit; schedthread then frees the thread */
	void *chandata;		/* not referenced in this part of the file -- NOTE(review): confirm use */

	Thread *nextq;		/* link in the run queue (runq) */
	Thread *nextall;	/* link in the allthread list */
	Thread *nextchan;	/* link in a Chan's reader or writer queue */

	void *data;		/* per-thread slot returned by threaddata */
};

/* set (under runqlock) when runthread is about to block on threadfd; threadready clears it and wakes it */
static int runqwait;
/* wakeup pipe: threadready writes a byte to threadfd[1], runthread reads threadfd[0] */
static int threadfd[2];
/* threads made by allocthread, linked via nextall; eallthread is meant to be the tail link */
Thread *allthread, **eallthread;
/* thread now running in the main proc, or nil while the scheduler runs */
Thread *curthread;
/* the scheduler thread (runs schedthread) */
static Thread *tsched;
/* threads ready to run, linked via nextq; protected by runqlock */
static Thread *runq;
//static Thread **endrunq;
/* protects runq and runqwait, which helper procs may touch via threadready */
static Lock runqlock;
/* supplied by the program; main runs it on the initial thread */
void threadmain(int argc, char **argv);
static Thread *runthread(void);

static void
schedthread(void *arg)
{
	Thread **l;
	USED(arg);

	for(;;){
		curthread = runthread();
		switchctxt(&tsched->ctxt, &curthread->ctxt);
		if(curthread->moribund){
			free(curthread->ctxt.stk);
			for(l=&allthread; *l; l=&(*l)->nextall){
				if(*l == curthread){
					*l = curthread->nextall;
					if(*l == nil)
						eallthread = l;
					break;
				}
			}
			free(curthread);
		}
		curthread = nil;
	}
}

void
threadsleep(void)
{
	switchctxt(&curthread->ctxt, &tsched->ctxt);
}

static Thread*
runthread(void)
{
	char c;
	Thread *t;

//print("runthread\n");
	for(;;){
		lock(&runqlock);
		t = runq;
		if(t)
			runq = t->nextq;
		else
			runqwait = 1;
//print("runthread t %p\n", t);
		unlock(&runqlock);
		if(t)
			break;
		if(read(threadfd[0], &c, 1) != 1)
			ABORT();
	}
	return t;
}

/*
 * Make t runnable by appending it to the tail of the run queue, and
 * wake the scheduler if it is blocked waiting for work.  May be called
 * from helper procs, hence runqlock.
 *
 * The queue is FIFO so that scheduling is round-robin: yield() puts
 * the caller behind every other ready thread, and the channel code
 * depends on this ordering.  (Pushing at the head made yield() a
 * no-op, since the yielding thread was popped straight back.)
 */
void
threadready(Thread *t)
{
	/* tail link of runq; only meaningful while runq != nil */
	static Thread **endrunq;
	int dowake;

	dowake = 0;
	lock(&runqlock);
	if(runq == nil)
		endrunq = &runq;
	t->nextq = nil;
	*endrunq = t;
	endrunq = &t->nextq;
	if(runqwait){
		dowake = 1;
		runqwait = 0;
	}
	unlock(&runqlock);
	/* as in runthread, a broken wakeup pipe is fatal: the scheduler would never wake */
	if(dowake && write(threadfd[1], "c", 1) != 1)
		ABORT();
}

void
yield(void)
{
	threadready(curthread);
	threadsleep();
}

void
threadexit(void)
{
	curthread->moribund = 1;
	threadsleep();
}

void**
threaddata(void)
{
	return &curthread->data;
}

/*
 * Start routine installed by allocthread for every new thread:
 * run the thread's function on its argument, then exit the thread.
 */
static void
threadlaunch(void *arg)
{
	Thread *self = arg;

	(*self->fn)(self->arg);
	threadexit();
}

/*
 * Allocate a thread that will run fn(arg), append it to the
 * allthread list, and initialize its context to start in
 * threadlaunch.  The thread is not placed on the run queue.
 *
 * eallthread must be advanced to the new tail link after each
 * append; previously it was left pointing at the old link, so every
 * new thread overwrote the list head and earlier threads were lost.
 */
static Thread*
allocthread(void (*fn)(void*), void *arg)
{
	Thread *t;

	t = emalloc(sizeof(*t));
	t->fn = fn;
	t->arg = arg;
	t->nextall = nil;
	if(allthread == nil)
		eallthread = &allthread;
	*eallthread = t;
	eallthread = &t->nextall;
	initctxt(&t->ctxt, threadlaunch, t);
	return t;
}

void
threadcreate(void (*fn)(void*), void *arg)
{
	threadready(allocthread(fn, arg));
}

/*
 * Program entry: create the scheduler's wakeup pipe, record the base
 * of the stack in mainlabel, set up a placeholder Thread for the
 * initial thread and the scheduler thread, then run threadmain.
 * When threadmain returns, the initial thread exits; the exits/return
 * after threadexit are presumably never reached.
 */
int
main(int argc, char **argv)
{
	/* runthread blocks on this pipe; without it the scheduler cannot work */
	if(pipe(threadfd) < 0)
		panic("cannot create thread wakeup pipe");
	setlabel(&mainlabel);	/* only used for stack pointer */

	curthread = emalloc(sizeof(Thread));
	tsched = allocthread(schedthread, nil);
	threadmain(argc, argv);
	threadexit();
	exits(0);
	return 0;
}

/* ------ */

/*
 * Many-reader, many-writer channels.  This depends
 * on the round-robin nature of the thread run queue.
 */
struct Chan
{
	Thread *rd;	/* threads blocked in recv, oldest first, linked via nextchan */
	Thread **erd;	/* tail link of rd, used to append */
	Thread *wr;	/* threads blocked in send, oldest first, linked via nextchan */
	Thread **ewr;	/* tail link of wr, used to append */
	void *a;	/* n-byte transfer buffer, allocated just past the Chan by _chan */
	int n;		/* size in bytes of one message */
	int working;	/* set by chanmatch while a transfer is in progress; cleared by recv */
};

/*
 * If c has both a waiting writer and a waiting reader and no transfer
 * is already under way, start one by waking the head writer.
 */
static void
chanmatch(Chan *c)
{
	if(c->working || c->wr == nil || c->rd == nil)
		return;
	c->working = 1;
	threadready(c->wr);
}

/*
 * Send the c->n bytes at v on c, blocking until a receiver has copied
 * them out.
 *
 * The caller appends itself to c's writer queue and sleeps.  chanmatch
 * wakes it once it is the head writer and a reader is queued.  It then
 * copies v into the channel buffer c->a, wakes the head reader, and
 * sleeps again until that reader's recv wakes it.
 *
 * The data goes through c->a rather than straight from v to the
 * reader's buffer because each thread's stack (where v often lives)
 * is only in place while that thread is running.
 */
void
send(Chan *c, void *v)
{
	/* append curthread to the writer queue */
	if(c->wr == nil)
		c->ewr = &c->wr;
	*c->ewr = curthread;
	c->ewr = &curthread->nextchan;
	curthread->nextchan = nil;

	chanmatch(c);
	threadsleep();
	/* woke: there is a reader */
	memmove(c->a, v, c->n);
	threadready(c->rd);
	threadsleep();
	/* woke: reader is done with our data */
}

/*
 * Receive c->n bytes from c into v, blocking until a sender provides
 * them.
 *
 * The caller appends itself to c's reader queue and sleeps until the
 * head writer has filled c->a and woken it.  It then copies the data
 * out, wakes that writer, pops both queue heads, clears c->working,
 * and calls chanmatch to start the next pending transfer, if any.
 */
void
recv(Chan *c, void *v)
{
	/* append curthread to the reader queue */
	if(c->rd == nil)
		c->erd = &c->rd;
	*c->erd = curthread;
	c->erd = &curthread->nextchan;
	curthread->nextchan = nil;

	chanmatch(c);
	threadsleep();
	/* woke: a writer has written to c->a */
	memmove(v, c->a, c->n);
	/*
	 * The writer only becomes runnable here; it cannot run until this
	 * thread switches away, so reading c->wr->nextchan below is safe.
	 */
	threadready(c->wr);
	c->wr = c->wr->nextchan;
	c->rd = c->rd->nextchan;
	c->working = 0;
	chanmatch(c);
}

/*
 * Send the pointer a on c, which must carry pointer-sized messages.
 * sizeof yields size_t, so it is cast to int to match the %d verb
 * (passing a 64-bit size_t to %d is wrong on LP64 targets).
 */
void
sendp(Chan *c, void *a)
{
	if(c->n != sizeof(void*))
		panic("bad arg to sendp: expected chan(%d) got chan(%d)", (int)sizeof(void*), c->n);

	send(c, &a);
}

/*
 * Send the ulong u on c, which must carry ulong-sized messages.
 * sizeof yields size_t, so it is cast to int to match the %d verb.
 */
void
sendul(Chan *c, ulong u)
{
	if(c->n != sizeof(ulong))
		panic("bad arg to sendul: expected chan(%d) got chan(%d)", (int)sizeof(ulong), c->n);

	send(c, &u);
}

/*
 * Receive and return a pointer from c, which must carry pointer-sized
 * messages.  sizeof yields size_t, so it is cast to int to match %d.
 */
void*
recvp(Chan *c)
{
	void *a;

	if(c->n != sizeof(void*))
		panic("bad arg to recvp: expected chan(%d) got chan(%d)", (int)sizeof(void*), c->n);

	recv(c, &a);
	return a;
}

/*
 * Receive and return a ulong from c, which must carry ulong-sized
 * messages.  sizeof yields size_t, so it is cast to int to match %d.
 */
ulong
recvul(Chan *c)
{
	ulong u;

	if(c->n != sizeof(ulong))
		panic("bad arg to recvul: expected chan(%d) got chan(%d)", (int)sizeof(ulong), c->n);

	recv(c, &u);
	return u;
}

/*
 * Allocate a channel carrying n-byte messages.  The n-byte transfer
 * buffer lives in the same allocation, immediately after the Chan.
 * The allocation is tagged with the caller's pc.
 * NOTE(review): the queue fields start out nil only if emalloc zeroes
 * memory -- confirm.
 */
Chan*
_chan(int n)
{
	Chan *ch;

	ch = emalloc(sizeof *ch + n);
	ch->n = n;
	ch->a = ch + 1;
	setmalloctag(ch, getcallerpc(&n));
	return ch;
}

/* ------ */
#ifdef ASDF
/* Disabled rendezvous code: compiled only when ASDF is defined. */
typedef struct Rendez Rendez;
struct Rendez
{
	Thread *first;	/* thread that arrived first and is waiting, or nil */
	ulong val;	/* value being exchanged */
};

/*
 * Exchange v with another thread meeting at r: each caller gets the
 * other's value.  The first arrival stores v and sleeps; the second
 * takes that value, leaves its own, wakes the first, and yields.
 * NOTE(review): r->first is never reset to nil, so a third caller would
 * take the "second" branch and threadready the first thread again,
 * possibly while it is already queued or running; fix before enabling.
 */
static ulong
rendez(Rendez *r, ulong v)
{
	ulong ret;

	if(r->first==nil){
		/* got here first */
		r->val = v;
		r->first = curthread;
		switchctxt(&curthread->ctxt, &tsched->ctxt);
		return r->val;
	}else{
		/* got here second */
		ret = r->val;
		r->val = v;
		threadready(r->first);
		yield();
		return ret;
	}
}
#endif /* ASDF */

