1 module arsd.rpc;
2 
3 /*
4 	FIXME:
5 		1) integrate with arsd.eventloop
6 		2) make it easy to use with other processes; pipe to a process and talk to it that way. perhaps with shared memory too?
7 		3) extend the serialization capabilities
8 */
9 
10 ///+ //example usage
/// The example RPC interface shared by client and server.
/// Members are numbered on the wire by their position in
/// __traits(allMembers, ...), so both ends must declare them in this order.
interface ExampleNetworkFunctions {
	/// ExampleServer answers with "Hello, " followed by `name`.
	string sayHello(string name);

	/// ExampleServer answers with the sum of `a` and `b`.
	int add(int a, int b);

	/// ExampleServer copies the fields of its S1 argument into an S2.
	S2 structTest(S1);

	/// ExampleServer always throws here, to demonstrate error callbacks.
	void die();
}
17  
/// Example server: implements ExampleNetworkFunctions and mixes in
/// NetworkServer, which supplies the `this(ushort port)` constructor and
/// `eventLoop()` that serve the interface over TCP.
class ExampleServer : ExampleNetworkFunctions {
	mixin NetworkServer!ExampleNetworkFunctions;

	override void die() {
		// reported back to the caller as an error response
		throw new Exception("death requested");
	}

	override S2 structTest(S1 a) {
		auto reply = S2(a.name, a.number);
		return reply;
	}

	override int add(int a, int b) {
		auto sum = a + b;
		return sum;
	}

	override string sayHello(string name) {
		auto greeting = "Hello, " ~ name;
		return greeting;
	}
}
38 
/// Example argument struct. Fields are serialized in declaration order,
/// so this order is part of the wire format.
struct S1 { int number; string name; }
43 
/// Example return struct. Fields are serialized in declaration order,
/// so this order is part of the wire format.
struct S2 { string name; int number; }
48 
49 import std.stdio;
/// Example entry point. With no arguments it runs an ExampleServer on port
/// 5005; with any argument it acts as an async client of localhost:5005.
void main(string[] args) {
	if(args.length <= 1) {
		auto exampleServer = new ExampleServer(5005);
		exampleServer.eventLoop();
		return;
	}

	auto rpc = makeNetworkClient!ExampleNetworkFunctions("localhost", 5005);
	// Each call mirrors the interface, but instead of returning a value it
	// takes a success callback (given the return value) and an error
	// callback (given the Throwable).
	rpc.sayHello("whoa", (greeting) { writeln(greeting); }, null);
	rpc.add(1, 2, (sum) { writeln(sum); }, null);
	rpc.add(10, 20, (sum) { writeln(sum); }, null);
	rpc.structTest(S1(20, "cool!"), (reply) { writeln(reply.name, " -- ", reply.number); }, null);
	rpc.die(delegate () { writeln("shouldn't happen"); }, delegate(failure) { writeln(failure); });
	rpc.eventLoop();

	/*
	synchronous variant:

	auto client = makeNetworkClient!(ExampleNetworkFunctions, false)("localhost", 5005);
	writeln(client.sayHello("whoa"));
	writeln(client.add(1, 2));
	client.die();
	writeln(client.add(1, 2));
	*/
}
75 //+/
76 
/++
	Turns a class that implements `Interface` into a TCP RPC server for it.

	Mix this into the implementing class. It adds:
	  - `this(ushort port)`, which starts listening on `port`;
	  - `eventLoop()`, which accepts clients and serves their requests
	    forever on the calling thread;
	  - a private `callByNumber`, which decodes one request, calls the
	    matching interface method on `this`, and sends the reply.
+/
mixin template NetworkServer(Interface) {
	import std.socket;
	private Socket socket;

	/// Binds a listening TCP socket to `port` on all interfaces
	/// (SO_REUSEADDR enabled, backlog of 16).
	public this(ushort port) {
		socket = new TcpSocket();
		socket.setOption(SocketOptionLevel.SOCKET, SocketOption.REUSEADDR, true);
		socket.bind(new InternetAddress(port));
		socket.listen(16);
	}

	/++
		Serves requests forever. It accepts new clients, reads from readable
		ones, and passes each complete request to `callByNumber`.

		A request can be split across several reads, and one read can hold
		several requests, so each connection buffers the bytes not yet
		dispatched. Connections that close, fail a read, or send a negative
		payload length are closed and dropped. Keeping them would make
		`select` report them readable forever.
	+/
	final public void eventLoop() {
		// An accepted client plus the bytes received from it that do not
		// yet form a complete request.
		static struct Connection {
			Socket socket;
			ubyte[] pending;
		}

		enum headerSize = 12; // int payload length, int function number, int sequence number

		auto check = new SocketSet();
		Connection[] connections;
		connections.reserve(16);
		ubyte[4096] buffer;

		// Dispatches every complete request at the front of conn.pending and
		// keeps the incomplete remainder. Returns false on a malformed header.
		bool dispatchPending(ref Connection conn) {
			size_t consumed = 0;
			while(conn.pending.length - consumed >= headerSize) {
				const(ubyte)[] got = conn.pending[consumed .. $];
				int length, functionNumber, sequenceNumber;
				got = deserializeInto(got, length);
				got = deserializeInto(got, functionNumber);
				got = deserializeInto(got, sequenceNumber);

				if(length < 0)
					return false;
				if(got.length < cast(size_t) length)
					break; // the rest of this request has not arrived yet

				callByNumber(functionNumber, sequenceNumber, got[0 .. length], conn.socket);
				consumed += headerSize + length;
			}
			// Drop the dispatched bytes; a partial request stays for the next read.
			conn.pending = conn.pending[consumed .. $];
			return true;
		}

		while(true) {
			check.reset();
			check.add(socket);
			foreach(ref conn; connections)
				check.add(conn.socket);

			// select returns -1 when interrupted by a signal; just wait again
			if(Socket.select(check, null, null) <= 0)
				continue;

			size_t idx = 0;
			while(idx < connections.length) {
				auto conn = &connections[idx];
				bool keep = true;
				if(check.isSet(conn.socket)) {
					auto gotNum = conn.socket.receive(buffer[]);
					if(gotNum <= 0) {
						keep = false; // 0: peer closed; Socket.ERROR: read failed
					} else {
						conn.pending ~= buffer[0 .. gotNum];
						keep = dispatchPending(*conn);
					}
				}

				if(keep) {
					idx++;
				} else {
					conn.socket.close();
					// order is irrelevant, so fill the hole with the last entry
					connections[idx] = connections[$ - 1];
					connections = connections[0 .. $ - 1];
				}
			}

			// Accept after servicing, so the new socket is never tested
			// against a SocketSet it was not added to.
			if(check.isSet(socket))
				connections ~= Connection(socket.accept());
		}
	}

	/++
		Handles one request and sends its reply on `connection`.

		`functionNumber` is the 1-based position of the target in
		__traits(allMembers, Interface). `buffer` holds exactly the request's
		serialized arguments.

		Every request gets a reply with this 12-byte header: payload length,
		`sequenceNumber`, and a success flag. On success (flag 1) the payload
		is the serialized return value, or empty for void functions. If
		anything is thrown (flag 0), the payload is the thrown object's msg,
		file and line.
	+/
	final private void callByNumber(int functionNumber, int sequenceNumber, const(ubyte)[] buffer, Socket connection) {
		ubyte[4096] sendBuffer;
		int length = 12;
		// header: length [0 .. 4], sequence [4 .. 8], success [8 .. 12]
		serialize(sendBuffer[4 .. 8], sequenceNumber);
		// Preset "success, empty payload". Non-void cases fill in the payload
		// and the catch block overwrites the flag on failure.
		serialize(sendBuffer[0 .. 4], cast(int) 0);
		serialize(sendBuffer[8 .. 12], cast(int) 1);

		// Builds one `case` per interface member: deserialize the arguments
		// in order, call the method, and serialize any return value.
		string callCode() {
			import std.conv;
			import std.traits;
			string code;
			foreach(memIdx, member; __traits(allMembers, Interface)) {
				code ~= "\t\tcase " ~ to!string(memIdx + 1) ~ ":\n";
				alias mem = PassThrough!(__traits(getMember, Interface, member));
				string argsString;
				foreach(i, arg; ParameterTypeTuple!mem) {
					if(i)
						argsString ~= ", ";
					auto istr = to!string(i);
					code ~= "\t\t\t" ~ arg.stringof ~ " arg" ~ istr ~ ";\n";
					code ~= "\t\t\tbuffer = deserializeInto(buffer, arg" ~ istr ~ ");\n";
					argsString ~= "arg" ~ istr;
				}

				static if(is(ReturnType!mem == void)) {
					// the preset header already says "success, no payload"
					code ~= "\n\t\t\t" ~ member ~ "(" ~ argsString ~ ");\n";
				} else {
					code ~= "\n\t\t\tauto ret = " ~ member ~ "(" ~ argsString ~ ");\n";
					code ~= "\t\t\tauto serialized = serialize(sendBuffer[12 .. $], ret);\n";
					code ~= "\t\t\tserialize(sendBuffer[0 .. 4], cast(int) serialized.length);\n";
					code ~= "\t\t\tlength += serialized.length;\n";
				}
				code ~= "\t\tbreak;\n";
			}
			return code;
		}

		try {
			switch(functionNumber) {
				default:
					// a real throw (not assert) so it is reported even in -release builds
					throw new Exception("unknown function");
				//pragma(msg, callCode());
				mixin(callCode());
			}
		} catch(Throwable t) {
			// Cap the text fields so the error reply always fits in
			// sendBuffer; otherwise serialize would fail here and take down
			// the whole event loop.
			enum maxText = 1024;
			string msg = t.msg.length > maxText ? t.msg[0 .. maxText] : t.msg;
			string file = t.file.length > maxText ? t.file[0 .. maxText] : t.file;

			serialize(sendBuffer[8 .. 12], cast(int) 0); // no success

			auto place = sendBuffer[12 .. $];
			int l;
			auto s = serialize(place, msg);
			place = place[s.length .. $];
			l += s.length;
			s = serialize(place, file);
			place = place[s.length .. $];
			l += s.length;
			s = serialize(place, t.line);
			l += s.length;

			serialize(sendBuffer[0 .. 4], l);
			length = 12 + l;
		}

		// Send the whole reply, looping over partial writes. A failed send
		// means the peer went away; eventLoop notices on its next receive.
		const(ubyte)[] toSend = sendBuffer[0 .. length];
		while(toSend.length) {
			auto sent = connection.send(toSend);
			if(sent <= 0)
				break;
			toSend = toSend[sent .. $];
		}
	}
}
208 
/// Forwards an alias parameter unchanged. This lets an expression such as
/// __traits(getMember, ...) be bound with `alias x = PassThrough!(...)`.
alias PassThrough(alias a) = a;
212 
213 // general FIXME: what if we run out of buffer space?
214 
/++
	Writes `s` to the front of `buffer` and returns the slice of `buffer`
	that was used.

	Encoding, in the machine's native byte order:
	  - arrays: an int element count, then each element serialized in turn;
	  - types without indirections (ints, floats, chars, plain structs):
	    their raw bytes;
	  - other aggregates: each field of `tupleof`, serialized in turn.
	Pointers are rejected at compile time.

	Throws: Exception if `buffer` is too small, or if an array has more
	than int.max elements. These are real checks rather than asserts, so
	they still fire in -release builds, where an overflowing write would
	otherwise go unchecked.
+/
public ubyte[] serialize(T)(ubyte[] buffer, in T s) {
	import std.traits;

	auto original = buffer;
	size_t totalLength = 0;

	static if(isArray!T) {
		if(s.length > int.max) {
			import std.format : format;
			throw new Exception(format("%s elements won't fit in an int length prefix", s.length));
		}
		/* length */ {
			auto used = serialize(buffer, cast(int) s.length);
			totalLength += used.length;
			buffer = buffer[used.length .. $];
		}
		foreach(element; s) {
			auto used = serialize(buffer, element);
			totalLength += used.length;
			buffer = buffer[used.length .. $];
		}
	} else static if(isPointer!T) {
		static assert(0, "no pointers allowed");
	} else static if(!hasIndirections!T) {
		// covers int, float, char, etc. most the builtins
		if(buffer.length < T.sizeof) {
			import std.format : format;
			throw new Exception(format("%s won't fit in %s buffer", T.stringof, buffer.length));
		}
		buffer[0 .. T.sizeof] = (cast(const(ubyte)*) &s)[0 .. T.sizeof];
		totalLength += T.sizeof;
		buffer = buffer[T.sizeof .. $];
	} else {
		// structs, classes, etc.
		foreach(field; s.tupleof) {
			auto used = serialize(buffer, field);
			totalLength += used.length;
			buffer = buffer[used.length .. $];
		}
	}

	return original[0 .. totalLength];
}
251 
/++
	Reads a value of type `T`, encoded as by `serialize`, from the front of
	`buffer` into `s`, and returns the unconsumed remainder of `buffer`.

	Array length prefixes come from the network, so they are validated
	before any allocation.

	Throws: Exception if an array length prefix is negative, larger than the
	remaining data allows, or (for static arrays) differs from `T.length`.
	A buffer that ends in the middle of a scalar still fails the slice
	bounds check.
+/
public inout(ubyte)[] deserializeInto(T)(inout(ubyte)[] buffer, ref T s) {
	import std.traits;

	static if(isArray!T) {
		// serialize writes the element count as a 4-byte int, so read it back
		// as an int. Reading a size_t would consume 8 bytes on 64-bit targets
		// and misparse everything after it.
		int length;
		buffer = deserializeInto(buffer, length);
		if(length < 0)
			throw new Exception("negative array length in serialized data");
		static if(!is(typeof(s[0]) == class)) {
			// Every non-class element encodes to at least one byte, so a
			// larger count is corrupt. Reject it before allocating.
			if(cast(size_t) length > buffer.length)
				throw new Exception("array length exceeds the serialized data");
		}
		static if(isStaticArray!T) {
			if(cast(size_t) length != T.length)
				throw new Exception("static array length mismatch in serialized data");
		} else {
			s.length = cast(size_t) length;
		}
		foreach(i; 0 .. s.length)
			buffer = deserializeInto(buffer, s[i]);
	} else static if(isPointer!T) {
		static assert(0, "no pointers allowed");
	} else static if(!hasIndirections!T) {
		// covers int, float, char, etc. most the builtins
		(cast(ubyte*)(&s))[0 .. T.sizeof] = buffer[0 .. T.sizeof];
		buffer = buffer[T.sizeof .. $];
	} else {
		// structs, classes, etc.: fields in declaration order
		foreach(i, t; s.tupleof) {
			buffer = deserializeInto(buffer, s.tupleof[i]);
		}
	}

	return buffer;
}
277 
/++
	Generates the members of an RPC client for `Interface`. Mix it into a
	class; makeNetworkClient does this for you.

	For each member of `Interface`, a method with the same name and
	arguments is generated.

	With `useAsync` (the default), each method returns void and takes two
	extra callbacks:
	  - `onSuccess`, given the return value, if any;
	  - `onError`, given a Throwable rebuilt from the server's msg, file
	    and line.
	Responses are delivered by `eventLoop()`.

	Without `useAsync`, each method blocks until its response arrives, then
	returns the value or throws that Throwable.
+/
mixin template NetworkClient(Interface, bool useAsync = true) {
	private static string createClass() {
		// this doesn't actually inherit from the interface because
		// the return value needs to be handled async
		string code;
		code ~= "\n\timport std.socket;";
		code ~= "\n\tprivate Socket socket;";
		if(useAsync) {
			code ~= "\n\tprivate void delegate(const(ubyte)[] buffer)[uint] onSuccesses;";
			code ~= "\n\tprivate void delegate(const(ubyte)[] buffer)[uint] onErrors;";
		}
		code ~= "\n\tprivate uint lastSequenceNumber;";
		code ~= q{
	// bytes received from the server that do not yet form a complete response
	private ubyte[] rpcPending_;

	private this(string host, ushort port) {
		this.socket = new TcpSocket();
		this.socket.connect(new InternetAddress(host, port));
	}

	// Sends all of data, looping over partial writes. Stops on a socket
	// error; the broken connection then shows up as a failed receive.
	private void rpcSendAll_(const(ubyte)[] data) {
		while(data.length) {
			auto sent = socket.send(data);
			if(sent <= 0)
				break;
			data = data[sent .. $];
		}
	}

	// Does one receive and appends the bytes to rpcPending_. Returns false
	// if the connection was closed or the receive failed.
	private bool rpcFill_() {
		ubyte[4096] chunk;
		auto gotNum = socket.receive(chunk[]);
		if(gotNum <= 0)
			return false;
		rpcPending_ ~= chunk[0 .. gotNum];
		return true;
	}

	// Takes one complete response (12-byte header plus payload) off the
	// front of rpcPending_. Returns false, leaving rpcPending_ untouched,
	// if a whole response has not arrived yet.
	private bool rpcTakeResponse_(out uint seq, out uint success, out const(ubyte)[] payload) {
		if(rpcPending_.length < 12)
			return false;
		const(ubyte)[] got = rpcPending_;
		uint length;
		got = deserializeInto(got, length);
		got = deserializeInto(got, seq);
		got = deserializeInto(got, success);
		if(got.length < length)
			return false;
		payload = got[0 .. length];
		rpcPending_ = rpcPending_[12 + length .. $];
		return true;
	}
	};

		if(useAsync)
		code ~= q{
	// Receives responses and passes each payload to the callback registered
	// for its sequence number. Returns when the server closes the
	// connection or a receive fails.
	final public void eventLoop() {
		while(rpcFill_()) {
			uint seq, success;
			const(ubyte)[] payload;
			while(rpcTakeResponse_(seq, success, payload)) {
				auto handler = success ? (seq in onSuccesses) : (seq in onErrors);
				if(handler !is null && *handler !is null)
					(*handler)(payload);
			}
		}
	}
	};
		code ~= "\n\tpublic:\n";

		foreach(memIdx, member; __traits(allMembers, Interface)) {
			import std.traits;
			alias mem = PassThrough!(__traits(getMember, Interface, member));
			string type;
			if(useAsync)
				type = "void";
			else {
				static if(is(ReturnType!mem == void))
					type = "void";
				else
					type = (ReturnType!mem).stringof;
			}
			code ~= "\t\tfinal "~type~" " ~ member ~ "(";
			bool hadArgument = false;
			import std.conv;
			// arguments
			foreach(i, arg; ParameterTypeTuple!mem) {
				if(hadArgument)
					code ~= ", ";
				// FIXME: this is one place the arg can get unknown if we don't have all the imports
				code ~= arg.stringof ~ " arg" ~ to!string(i);
				hadArgument = true;
			}

			if(useAsync) {
				if(hadArgument)
					code ~= ", ";

				static if(is(ReturnType!mem == void))
					code ~= "void delegate() onSuccess";
				else
					code ~= "void delegate("~(ReturnType!mem).stringof~") onSuccess";
				code ~= ", ";
				code ~= "void delegate(Throwable) onError";
			}
			code ~= ") {\n";
			code ~= "auto seq = ++lastSequenceNumber;";

			// async: register one-shot callbacks keyed by sequence number
			if(useAsync)
			code ~= q{
			onSuccesses[seq] = (const(ubyte)[] buffer) {
				onSuccesses.remove(seq);
				onErrors.remove(seq);

				import std.traits;

				static if(is(ParameterTypeTuple!(typeof(onSuccess)) == void)) {
					if(onSuccess !is null)
						onSuccess();
				} else {
					ParameterTypeTuple!(typeof(onSuccess)) args;
					foreach(i, arg; args)
						buffer = deserializeInto(buffer, args[i]);
					if(onSuccess !is null)
						onSuccess(args);
				}
			};
			onErrors[seq] = (const(ubyte)[] buffer) {
				onSuccesses.remove(seq);
				onErrors.remove(seq);
				auto t = new Throwable("");
				buffer = deserializeInto(buffer, t.msg);
				buffer = deserializeInto(buffer, t.file);
				buffer = deserializeInto(buffer, t.line);

				if(onError !is null)
					onError(t);
			};
			};

			code ~= q{
			ubyte[4096] bufferBase;
			auto buffer = bufferBase[12 .. $]; // leaving room for size, func number, and seq number
			ubyte[] serialized;
			int used;
			};
			// preparing the request
			foreach(i, arg; ParameterTypeTuple!mem) {
				code ~= "\t\t\tserialized = serialize(buffer, arg" ~ to!string(i) ~ ");\n";
				code ~= "\t\t\tused += serialized.length;\n";
				code ~= "\t\t\tbuffer = buffer[serialized.length .. $];\n";
			}

			code ~= "\t\t\tserialize(bufferBase[0 .. 4], used);\n";
			code ~= "\t\t\tserialize(bufferBase[4 .. 8], " ~ to!string(memIdx + 1) ~ ");\n";
			code ~= "\t\t\tserialize(bufferBase[8 .. 12], seq);\n";
			code ~= "\t\t\trpcSendAll_(bufferBase[0 .. 12 + used]);\n";

			// sync: block until the response to this request has fully arrived
			if(!useAsync)
			code ~= q{
			uint responseSeq, success;
			const(ubyte)[] payload;
			for(;;) {
				if(!rpcTakeResponse_(responseSeq, success, payload)) {
					if(!rpcFill_())
						throw new Exception("connection closed");
					continue;
				}
				if(responseSeq == seq)
					break;
				// a response to some other request; not expected in sync mode, skip it
			}

			if(!success) {
				auto t = new Throwable("");
				payload = deserializeInto(payload, t.msg);
				payload = deserializeInto(payload, t.file);
				payload = deserializeInto(payload, t.line);
				throw t;
			}

			static if(!is(typeof(return) == void)) {
				typeof(return) returned;
				deserializeInto(payload, returned);
				return returned;
			}
			};

			code ~= "}\n";
			code ~= "\n";
		}
		return code;
	}

	//pragma(msg, createClass()); // for debugging help
	mixin(createClass());
}
517 
/++
	Connects to `host`:`port` and returns a client whose methods mirror
	`Interface`; see NetworkClient for the async and sync calling
	conventions. The client class is local to this function, so callers
	hold the result through `auto`.
+/
auto makeNetworkClient(Interface, bool useAsync = true)(string host, ushort port) {
	class Thing {
		mixin NetworkClient!(Interface, useAsync);
	}

	auto connected = new Thing(host, port);
	return connected;
}
525 
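// the protocol, as implemented by NetworkServer and NetworkClient:
/*

All integers are 4 bytes in the sender's native byte order, so both ends
must share an endianness. There is no handshake yet. A planned one would
have the client send a ulong interface hash right after connecting.

requests (client to server):

	int payload length (bytes after this 12-byte header)
	int function number: 1-based index into __traits(allMembers, Interface)
	    (0 is reserved for a future interface check)
	uint sequence number, chosen by the client
	serialized arguments....

responses (server to client):

	int payload length (bytes after this 12-byte header)
	uint sequence number of the request being answered
	int success flag: 1 == success, 0 == error
	on success: the serialized return value, if the function has one
	on error: the serialized msg, file and line of what was thrown

*/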