1 /// I never finished this. The idea is to use CT reflection to make calling another process feel as simple as calling in-process objects. Will come eventually but no promises.
2 module arsd.rpc;
3 
4 /*
5 	FIXME:
6 		1) integrate with arsd.eventloop
7 		2) make it easy to use with other processes; pipe to a process and talk to it that way. perhaps with shared memory too?
8 		3) extend the serialization capabilities
9 
10 
11 	@Throws!(List, Of, Exceptions)
12 	classes are also RPC proxied
13 	stdin/out/err also redirected
14 */
15 
16 ///+ //example usage
/// Example RPC interface: `ExampleServer` implements it and `makeNetworkClient` builds a proxy for it.
/// NOTE: the code generators number each member by its 1-based position in
/// `__traits(allMembers, ...)` and send that number on the wire, so reordering,
/// inserting or removing members breaks compatibility between old and new builds.
interface ExampleNetworkFunctions {
	/// Returns a greeting for `name`.
	string sayHello(string name);
	/// Returns `a + b`.
	int add(int a, int b);
	/// Exercises struct serialization in both directions.
	S2 structTest(S1);
	/// Exercises the error path (the example server implementation throws).
	void die();
}
23  
24 // the server must implement the interface
/// Example server: implements the RPC interface and mixes in the listening
/// constructor, `eventLoop` and call dispatcher from `NetworkServer`.
class ExampleServer : ExampleNetworkFunctions {
	mixin NetworkServer!ExampleNetworkFunctions;

	/// Builds the greeting "Hello, <name>".
	override string sayHello(string name) {
		string greeting = "Hello, " ~ name;
		return greeting;
	}

	/// Adds the two operands.
	override int add(int a, int b) {
		const sum = a + b;
		return sum;
	}

	/// Copies the fields of an S1 into an S2 (whose fields are declared in the other order).
	override S2 structTest(S1 a) {
		S2 reply = { name: a.name, number: a.number };
		return reply;
	}

	/// Always fails, so clients can observe the error callback/exception path.
	override void die() {
		auto failure = new Exception("death requested");
		throw failure;
	}
}
44 
/// Example argument struct for `structTest`.
/// Field order matters: `serialize`/`deserializeInto` walk `tupleof` in declaration
/// order, and callers construct it positionally as `S1(number, name)`.
struct S1 {
	int number;
	string name;
}
49 
/// Example return struct for `structTest`; same data as `S1` with the fields swapped.
/// Field order matters: it is the wire layout (`tupleof` order) and the positional
/// constructor order `S2(name, number)`.
struct S2 {
	string name;
	int number;
}
54 
55 import std.stdio;
/// Demo entry point: with no command-line argument it runs the example server on
/// port 5005; with any argument it connects to localhost:5005 as an async client.
void main(string[] args) {
	const bool runAsServer = args.length <= 1;
	if(runAsServer) {
		auto server = new ExampleServer(5005);
		server.eventLoop();
		return;
	}

	auto client = makeNetworkClient!ExampleNetworkFunctions("localhost", 5005);
	// each proxy method mirrors the interface, but instead of returning the value
	// it takes a success callback (whose argument is the return value)
	// and a failure callback (whose argument is the exception)
	client.sayHello("whoa", (a) { writeln(a); }, null);
	client.add(1,2, (a) { writeln(a); }, null);
	client.add(10,20, (a) { writeln(a); }, null);
	client.structTest(S1(20, "cool!"), (a) { writeln(a.name, " -- ", a.number); }, null);
	client.die(delegate () { writeln("shouldn't happen"); }, delegate(a) { writeln(a); });
	// responses (and therefore the callbacks above) are only processed in here
	client.eventLoop();

	/*
	synchronous variant:
	auto client = makeNetworkClient!(ExampleNetworkFunctions, false)("localhost", 5005);
	writeln(client.sayHello("whoa"));
	writeln(client.add(1, 2));
	client.die();
	writeln(client.add(1, 2));
	*/
}
81 //+/
82 
/++
	Mixin for a class implementing `Interface` that turns it into a TCP RPC server.

	Adds a `this(ushort port)` constructor that listens on `port`, a blocking
	`eventLoop` that serves every connection, and the dispatcher that decodes a
	request, calls the matching `Interface` method on `this`, and writes the reply.

	Requests carry a 12-byte header (int payload length, int function number,
	int sequence number) followed by the serialized arguments. Function numbers are
	the 1-based positions in `__traits(allMembers, Interface)`.
+/
mixin template NetworkServer(Interface) {
	import std.socket;
	private Socket socket;

	/// Creates the listening socket on `port` (all interfaces, SO_REUSEADDR, backlog 16).
	public this(ushort port) {
		socket = new TcpSocket();
		socket.setOption(SocketOptionLevel.SOCKET, SocketOption.REUSEADDR, true);
		socket.bind(new InternetAddress(port));
		socket.listen(16);
	}

	/++
		Serves clients forever: accepts connections, buffers incoming bytes per
		connection, and dispatches each request once it has fully arrived.
		Connections that close, fail, or send a malformed header are closed and dropped.
	+/
	final public void eventLoop() {
		auto check = new SocketSet();
		Socket[] connections;
		// pending[i] holds bytes from connections[i] that do not yet form a whole request
		const(ubyte)[][] pending;
		connections.reserve(16);
		pending.reserve(16);
		ubyte[4096] buffer;

		while(true) {
			check.reset();
			check.add(socket);
			foreach(connection; connections)
				check.add(connection);

			// 0 is a timeout and -1 an interrupted call; either way nothing is ready
			if(Socket.select(check, null, null) <= 0)
				continue;

			// serve existing connections first: sockets accepted below were not in the set
			for(size_t i = 0; i < connections.length; ) {
				auto connection = connections[i];
				if(!check.isSet(connection)) {
					i++;
					continue;
				}

				auto gotNum = connection.receive(buffer);
				bool keep = gotNum != 0 && gotNum != Socket.ERROR;
				if(keep) {
					pending[i] ~= buffer[0 .. gotNum];
					keep = dispatchRequests(pending[i], connection);
					// copy any partial request so the consumed prefix can be collected
					if(keep)
						pending[i] = pending[i].length ? pending[i].dup : null;
				}

				if(keep) {
					i++;
				} else {
					// closed, failed or malformed: drop it, otherwise select keeps reporting it
					connection.close();
					connections[i] = connections[$ - 1];
					connections = connections[0 .. $ - 1];
					pending[i] = pending[$ - 1];
					pending = pending[0 .. $ - 1];
				}
			}

			if(check.isSet(socket)) {
				connections ~= socket.accept();
				pending.length = connections.length;
			}
		}
	}

	/++
		Runs every complete request at the front of `data` and advances `data` past them,
		leaving any incomplete trailing request in place for the next receive.

		Returns: false if a header carries a negative length (the stream cannot be
		resynchronized, so the caller should drop the connection), true otherwise.
	+/
	final private bool dispatchRequests(ref const(ubyte)[] data, Socket connection) {
		// header: int payload length, int function number, int sequence number
		while(data.length >= 12) {
			int length, functionNumber, sequenceNumber;
			auto rest = deserializeInto(data, length);
			rest = deserializeInto(rest, functionNumber);
			rest = deserializeInto(rest, sequenceNumber);

			if(length < 0)
				return false;
			if(rest.length < cast(size_t) length)
				break; // arguments still in flight; wait for more data

			//writeln("got ", sequenceNumber, " calling ", functionNumber);
			callByNumber(functionNumber, sequenceNumber, rest[0 .. length], connection);
			data = rest[length .. $];
		}
		return true;
	}

	/++
		Decodes the arguments in `buffer`, calls method number `functionNumber`, and
		sends a reply tagged with `sequenceNumber` to `connection`.

		Reply: int payload length, int sequence number, int success flag (1/0), then
		the serialized return value on success (nothing for void methods), or the
		thrown object's msg, file and line on failure. A reply is sent for every call,
		including successful void ones, so clients can tell the call completed.
	+/
	final private void callByNumber(int functionNumber, int sequenceNumber, const(ubyte)[] buffer, Socket connection) {
		ubyte[4096] sendBuffer;
		int length = 12;
		// defaults: empty payload, success; overwritten below by a return value or an error
		serialize(sendBuffer[0 .. 4], cast(int) 0);
		serialize(sendBuffer[4 .. 8], sequenceNumber);
		serialize(sendBuffer[8 .. 12], cast(int) 1);
		string callCode() {
			import std.conv;
			import std.traits;
			string code;
			foreach(memIdx, member; __traits(allMembers, Interface)) {
				code ~= "\t\tcase " ~ to!string(memIdx + 1) ~ ":\n";
				alias mem = PassThrough!(__traits(getMember, Interface, member));
				// we need to deserialize the arguments, call the function, and send back the response (if there is one)
				string argsString;
				foreach(i, arg; ParameterTypeTuple!mem) {
					if(i)
						argsString ~= ", ";
					auto istr = to!string(i);
					code ~= "\t\t\t" ~ arg.stringof ~ " arg" ~ istr ~ ";\n";
					code ~= "\t\t\tbuffer = deserializeInto(buffer, arg" ~ istr ~ ");\n";

					argsString ~= "arg" ~ istr;
				}

				// the call
				static if(is(ReturnType!mem == void)) {
					code ~= "\n\t\t\t" ~ member ~ "(" ~ argsString ~ ");\n";
				} else {
					// call and return answer
					code ~= "\n\t\t\tauto ret = " ~ member ~ "(" ~ argsString ~ ");\n";

					code ~= "\t\t\tserialize(sendBuffer[8 .. 12], cast(int) 1);\n"; // yes success
					code ~= "\t\t\tauto serialized = serialize(sendBuffer[12 .. $], ret);\n";
					code ~= "\t\t\tserialize(sendBuffer[0 .. 4], cast(int) serialized.length);\n";
					code ~= "\t\t\tlength += serialized.length;\n";
				}
				code ~= "\t\tbreak;\n";
			}
			return code;
		}

		try {
			switch(functionNumber) {
				// a bad number comes from the peer, so report it instead of assert(0),
				// which would halt the whole server in a release build
				default: throw new Exception("unknown function");
				//pragma(msg, callCode());
				mixin(callCode());
			}
		} catch(Throwable t) {
			//writeln("thrown: ", t);
			serialize(sendBuffer[8 .. 12], cast(int) 0); // no success

			// keep the reply inside sendBuffer: an oversized message would trip the
			// serialize assert here, outside any handler, and kill the event loop
			string msg = t.msg;
			if(msg.length > 1024) {
				size_t cut = 1024;
				while(cut > 0 && (msg[cut] & 0xC0) == 0x80)
					cut--; // do not split a UTF-8 sequence
				msg = msg[0 .. cut];
			}

			auto place = sendBuffer[12 .. $];
			int l;
			auto s = serialize(place, msg);
			place = place[s.length .. $];
			l += s.length;
			s = serialize(place, t.file);
			place = place[s.length .. $];
			l += s.length;
			s = serialize(place, t.line);
			place = place[s.length .. $];
			l += s.length;

			serialize(sendBuffer[0 .. 4], l);
			length += l;
		}

		// send may accept only part of the reply; keep going until all of it is out
		const(ubyte)[] toSend = sendBuffer[0 .. length];
		while(toSend.length) {
			auto sent = connection.send(toSend);
			if(sent == Socket.ERROR || sent == 0)
				break; // peer is gone; the event loop drops it on its next receive
			toSend = toSend[sent .. $];
		}
	}
}
214 
/// Yields its `alias` argument unchanged. The code generators use it to bind a
/// `__traits(getMember, ...)` result to a local `alias` name.
alias PassThrough(alias a) = a;
218 
219 // general FIXME: what if we run out of buffer space?
220 
221 // returns the part of the buffer that was actually used
/++
	Writes the binary form of `s` at the front of `buffer`.

	Encoding, by kind of `T`:
	- arrays (including strings): an `int` element count, then each element in turn;
	- types without indirections (ints, floats, chars, plain structs): their raw
	  in-memory bytes, host byte order;
	- other structs and classes: each field of `tupleof` in declaration order.
	Pointers are rejected at compile time. If `buffer` is too small, an assert fires.

	Returns: the prefix of `buffer` that now holds the encoding.
+/
final public ubyte[] serialize(T)(ubyte[] buffer, in T s) {
	import std.traits;

	ubyte[] rest = buffer; // the part of buffer not written yet

	static if(isArray!T) {
		rest = rest[serialize(rest, cast(int) s.length).length .. $];
		foreach(element; s)
			rest = rest[serialize(rest, element).length .. $];
	} else static if(isPointer!T) {
		static assert(0, "no pointers allowed");
	} else static if(!hasIndirections!T) {
		import std.string;
		assert(rest.length >= T.sizeof, format("%s won't fit in %s buffer", T.stringof, rest.length));
		rest[0 .. T.sizeof] = (cast(ubyte*) &s)[0 .. T.sizeof];
		rest = rest[T.sizeof .. $];
	} else {
		foreach(field; s.tupleof)
			rest = rest[serialize(rest, field).length .. $];
	}

	return buffer[0 .. buffer.length - rest.length];
}
257 
258 // returns the remaining part of the buffer
/++
	Reads a value of type `T` from the front of `buffer` into `s`; the inverse of `serialize`.

	Arrays are read as an `int` element count followed by the elements. Static arrays
	must carry exactly `T.length` elements. Types without indirections are copied byte
	for byte. Other structs and classes are filled field by field; a class reference
	must already point at an instance.

	Returns: the part of `buffer` after the bytes consumed.
	Throws: `Exception` when the data is truncated, or when an encoded array length is
	negative, larger than the remaining input, or does not match a static array's length.
+/
final public inout(ubyte)[] deserializeInto(T)(inout(ubyte)[] buffer, ref T s) {
	import std.traits;

	static if(isArray!T) {
		// serialize writes the count as an int, so read exactly an int; reading a
		// size_t here consumed element bytes on 64-bit targets
		int encodedLength;
		buffer = deserializeInto(buffer, encodedLength);
		// every element encodes to at least one byte, so a larger count is corrupt;
		// rejecting it also avoids a huge allocation driven by remote data
		if(encodedLength < 0 || cast(size_t) encodedLength > buffer.length)
			throw new Exception("invalid array length in serialized data");
		size_t length = encodedLength;
		static if(isStaticArray!T) {
			if(length != T.length)
				throw new Exception("static array length mismatch in serialized data");
		} else {
			s.length = length;
		}
		foreach(i; 0 .. length)
			buffer = deserializeInto(buffer, s[i]);
	} else static if(isPointer!T) {
		static assert(0, "no pointers allowed");
	} else static if(!hasIndirections!T) {
		// covers int, float, char, etc. most the builtins
		if(buffer.length < T.sizeof)
			throw new Exception("serialized data truncated");
		(cast(ubyte*)(&s))[0 .. T.sizeof] = buffer[0 .. T.sizeof];
		buffer = buffer[T.sizeof .. $];
	} else {
		// structs, classes, etc.
		foreach(i, t; s.tupleof) {
			buffer = deserializeInto(buffer, s.tupleof[i]);
		}
	}

	return buffer;
}
283 
/++
	Mixin that generates the body of a client proxy class for `Interface`.

	The generated class owns a TCP socket connected by `this(host, port)` and has one
	`final` method per member of `Interface`, with the same name and parameter types.
	Each method serializes its arguments behind a 12-byte header (payload length,
	function number = 1-based position in `__traits(allMembers, Interface)`, sequence
	number) and sends the request.

	useAsync = true: methods return `void` and take two extra callbacks. `onSuccess`
	receives the decoded return value, if any. `onError` receives a `Throwable` rebuilt
	from the server's msg/file/line. Callbacks only run while `eventLoop()` is running.

	useAsync = false: each method blocks until the response with its sequence number
	arrives, then returns the value or throws the rebuilt `Throwable`.
+/
mixin template NetworkClient(Interface, bool useAsync = true) {
	private static string createClass() {
		// this doesn't actually inherit from the interface because
		// the return value needs to be handled async
		string code;// = `final class Class /*: ` ~ Interface.stringof ~ `*/ {`;
		code ~= "\n\timport std.socket;";
		code ~= "\n\tprivate Socket socket;";
		if(useAsync) {
			// pending callbacks keyed by sequence number; each removes both entries when run
			code ~= "\n\tprivate void delegate(const(ubyte)[] buffer)[uint] onSuccesses;";
			code ~= "\n\tprivate void delegate(const(ubyte)[] buffer)[uint] onErrors;";
		}
		code ~= "\n\tprivate uint lastSequenceNumber;";
		code ~= q{
	private this(string host, ushort port) {
		this.socket = new TcpSocket();
		this.socket.connect(new InternetAddress(host, port));
	}
	};

		if(useAsync)
		code ~= q{
	final public void eventLoop() {
		ubyte[4096] buffer;
		const(ubyte)[] pending; // received bytes that do not yet form a complete response

		while(true) {
			auto gotNum = socket.receive(buffer);
			if(gotNum == 0 || gotNum == Socket.ERROR)
				break; // connection closed or failed: nothing more will arrive

			pending ~= buffer[0 .. gotNum];

			// response: 12 byte header (length, seq, success) then length bytes of payload
			while(pending.length >= 12) {
				uint length, seq, success;
				auto got = deserializeInto(pending, length);
				got = deserializeInto(got, seq);
				got = deserializeInto(got, success);
				if(got.length < length)
					break; // payload still in flight

				auto payload = got[0 .. length];
				pending = got[length .. $];

				auto handler = success ? (seq in onSuccesses) : (seq in onErrors);
				if(handler !is null && *handler !is null) {
					auto callback = *handler; // the callback removes itself from both maps
					callback(payload);
				}
			}

			// copy a leftover partial response so the consumed prefix can be collected
			pending = pending.length ? pending.dup : null;
		}
	}
	};
		code ~= "\n\tpublic:\n";

		foreach(memIdx, member; __traits(allMembers, Interface)) {
			import std.traits;
			alias mem = PassThrough!(__traits(getMember, Interface, member));
			string type;
			if(useAsync)
				type = "void";
			else {
				static if(is(ReturnType!mem == void))
					type = "void";
				else
					type = (ReturnType!mem).stringof;
			}
			code ~= "\t\tfinal "~type~" " ~ member ~ "(";
			bool hadArgument = false;
			import std.conv;
			// arguments
			foreach(i, arg; ParameterTypeTuple!mem) {
				if(hadArgument)
					code ~= ", ";
				// FIXME: this is one place the arg can get unknown if we don't have all the imports
				code ~= arg.stringof ~ " arg" ~ to!string(i);
				hadArgument = true;
			}

			if(useAsync) {
				if(hadArgument)
					code ~= ", ";

				static if(is(ReturnType!mem == void))
					code ~= "void delegate() onSuccess";
				else
					code ~= "void delegate("~(ReturnType!mem).stringof~") onSuccess";
				code ~= ", ";
				code ~= "void delegate(Throwable) onError";
			}
			code ~= ") {\n";
			code ~= "auto seq = ++lastSequenceNumber;";
		if(useAsync)
		code ~= q{
			onSuccesses[seq] = (const(ubyte)[] buffer) {
				onSuccesses.remove(seq);
				onErrors.remove(seq);

				import std.traits;

				static if(is(ParameterTypeTuple!(typeof(onSuccess)) == void)) {
					if(onSuccess !is null)
						onSuccess();
				} else {
					ParameterTypeTuple!(typeof(onSuccess)) args;
					foreach(i, arg; args)
						buffer = deserializeInto(buffer, args[i]);
					if(onSuccess !is null)
						onSuccess(args);
				}
			};
			onErrors[seq] = (const(ubyte)[] buffer) {
				onSuccesses.remove(seq);
				onErrors.remove(seq);
				auto t = new Throwable("");
				buffer = deserializeInto(buffer, t.msg);
				buffer = deserializeInto(buffer, t.file);
				buffer = deserializeInto(buffer, t.line);

				if(onError !is null)
					onError(t);
			};
		};

		code ~= q{
			ubyte[4096] bufferBase;
			auto buffer = bufferBase[12 .. $]; // leaving room for size, func number, and seq number
			ubyte[] serialized;
			int used;
		};
			// preparing the request
			foreach(i, arg; ParameterTypeTuple!mem) {
				code ~= "\t\t\tserialized = serialize(buffer, arg" ~ to!string(i) ~ ");\n";
				code ~= "\t\t\tused += serialized.length;\n";
				code ~= "\t\t\tbuffer = buffer[serialized.length .. $];\n";
			}

			// function numbers must match NetworkServer: 1-based allMembers position
			code ~= "\t\t\tserialize(bufferBase[0 .. 4], used);\n";
			code ~= "\t\t\tserialize(bufferBase[4 .. 8], " ~ to!string(memIdx + 1) ~ ");\n";
			code ~= "\t\t\tserialize(bufferBase[8 .. 12], seq);\n";

			// send may accept only part of the request; loop until all of it is out
			code ~= q{
			{
				const(ubyte)[] toSend = bufferBase[0 .. 12 + used];
				while(toSend.length) {
					auto sent = socket.send(toSend);
					if(sent == Socket.ERROR || sent == 0)
						throw new Exception("send failed");
					toSend = toSend[sent .. $];
				}
			}
		};
			//code ~= `writeln("sending ", bufferBase[0 .. 12 + used]);`;

		if(!useAsync)
		code ~= q{
			ubyte[4096] dbuffer;
			const(ubyte)[] pending; // received bytes not yet consumed

			while(true) {
				auto gotNum = socket.receive(dbuffer);
				if(gotNum == 0 || gotNum == Socket.ERROR)
					throw new Exception("connection closed");
				pending ~= dbuffer[0 .. gotNum];

				// response: 12 byte header (length, seq, success) then length bytes of payload
				while(pending.length >= 12) {
					uint length, responseSeq, success;
					auto got = deserializeInto(pending, length);
					got = deserializeInto(got, responseSeq);
					got = deserializeInto(got, success);
					if(got.length < length)
						break; // payload still in flight; receive more

					auto payload = got[0 .. length];
					pending = got[length .. $];

					if(responseSeq != seq)
						continue; // stale response to an earlier call; skip it

					if(!success) {
						auto t = new Throwable("");
						payload = deserializeInto(payload, t.msg);
						payload = deserializeInto(payload, t.file);
						payload = deserializeInto(payload, t.line);
						throw t;
					}

					static if(is(typeof(return) == void)) {
						return;
					} else {
						typeof(return) returned;
						payload = deserializeInto(payload, returned);
						return returned;
					}
				}
			}
		};

			code ~= "}\n";
			code ~= "\n";
		}
		//code ~= `}`;
		return code;
	}

	//pragma(msg, createClass()); // for debugging help
	mixin(createClass());
}
523 
/++
	Connects to `host`:`port` and returns a proxy object for `Interface`.

	The proxy's methods are generated by `NetworkClient`. With `useAsync` they take
	success/error callbacks that run inside the proxy's `eventLoop()`. Without it they
	block and return the remote result directly.
+/
auto makeNetworkClient(Interface, bool useAsync = true)(string host, ushort port) {
	// a function-local class gives each Interface/useAsync pair its own proxy type
	class RemoteProxy {
		mixin NetworkClient!(Interface, useAsync);
	}

	auto proxy = new RemoteProxy(host, port);
	return proxy;
}
531 
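// the protocol is:
/*

planned, not yet implemented: on connect the client sends
	ulong interface hash
after which the handshake is complete.

as currently implemented, every value is written in the host's native byte
order and every message begins with a 12-byte header.

requests (client to server):

	int messageLength: number of argument bytes that follow the header
	int function number: 1-based position of the method in __traits(allMembers, Interface); 0 is reserved for the interface check
	int sequence number
	serialized arguments....

responses (server to client):

	int messageLength: number of payload bytes that follow the header
	int sequence number of the request being answered
	int success flag: 1 == success, 0 == error
	payload: on success, the serialized return value (if any);
		on error, the exception's msg, file and line

arrays (including strings) are serialized as an int element count followed by
each element; structs and classes as their fields in declaration order.
*/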