1 /++
	mvd stands for Multiple Virtual Dispatch. It lets you
	write functions that take any number of object arguments
	and have the call dispatched based on the dynamic (runtime)
	type of each of them.
6 
7 	---
8 	void foo(Object a, Object b) {} // 1
	void foo(MyClass a, Object b) {} // 2
10 	void foo(DerivedClass a, MyClass b) {} // 3
11 
12 	Object a = new MyClass();
13 	Object b = new Object();
14 
15 	mvd!foo(a, b); // will call overload #2
16 	---
17 
	The return types must be compatible: [mvd] returns the common
	type of all the overloads' return types, as computed by
	`std.traits.CommonType` (for class return types, most likely
	their shared base class), or `void` if there isn't one.
22 
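	Arguments for parameters that are not classes or interfaces are
	not dispatched on; they are assigned straight to each overload's
	parameter, so their types should be compatible among all overloads.
	Otherwise you are liable to get compile errors. (Or it might work;
	that's up to the compiler's discretion.)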
26 +/
27 module arsd.mvd;
28 
29 import std.traits;
30 
/++
	The return type of [mvd]: the common type (per `std.traits.CommonType`)
	of the return types of every overload of `fn`, or `void` if there is
	none. This exists just to make the documentation of [mvd] nicer looking.
+/
alias CommonReturnOfOverloads(alias fn) = CommonType!(staticMap!(ReturnType, __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn))));

/++
	Calls the overload of `fn` that best matches the dynamic types of `args`.
	See details on the [arsd.mvd] page.

	Every overload of `fn` (found with `__traits(getOverloads)` on the parent
	scope of `fn`) that takes exactly `args.length` parameters is considered.
	For each class or interface parameter, the argument is dynamically cast
	to the parameter type; if a non-null argument fails that cast, the
	overload is rejected. A `null` argument is accepted by any class or
	interface parameter. Other parameters are assigned directly, so their
	types must be compatible with the arguments at compile time.

	Each accepted overload gets a specificity score. A class parameter
	contributes its inheritance depth (`Object` counts 1, and each derived
	level adds 1). An interface parameter contributes 2 plus the number of
	interfaces it inherits, so it ranks above `Object`. The overload with
	the highest total score is called.

	Params:
		fn = any one overload of the function set to dispatch over
		args = the arguments to dispatch on and forward to the chosen overload

	Returns: the chosen overload's return value, converted to
	[CommonReturnOfOverloads] (nothing when that type is `void`).

	Throws: `Exception` if no overload accepts the arguments, or if two or
	more overloads share the highest score.
+/
CommonReturnOfOverloads!fn mvd(alias fn, T...)(T args) {
	// Named at this level so the dispatch delegate below can be typed
	// explicitly; inside a literal, `typeof(return)` would refer to the
	// literal itself rather than to mvd.
	alias Ret = typeof(return);

	Ret delegate() bestMatch;
	// Start below any reachable score so that an overload scoring 0 (one
	// with no class/interface parameters) can still be selected.
	int bestScore = -1;
	// Set when an overload ties the current best; cleared whenever a
	// strictly better overload appears, so only a tie at the top is an error.
	bool ambiguous = false;

	ov: foreach(overload; __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn))) {
		// An overload of a different arity can never be called with exactly
		// these arguments, so skip it at compile time instead of truncating
		// the argument list or failing to compile.
		static if(Parameters!overload.length == T.length) {
			Parameters!overload pargs;
			int score = 0;
			foreach(idx, parg; pargs) {
				alias t = typeof(parg);
				static if(is(t == interface) || is(t == class)) {
					pargs[idx] = cast(t) args[idx];
					if(args[idx] !is null && pargs[idx] is null)
						continue ov; // failed cast, forget it
					static if(is(t == class))
						score += BaseClassesTuple!t.length + 1;
					else
						// BaseClassesTuple only accepts classes; rank an
						// interface above Object, deeper interfaces higher.
						score += InterfacesTuple!t.length + 2;
				} else
					pargs[idx] = args[idx];
			}

			if(score > bestScore) {
				bestMatch = delegate Ret() {
					static if(is(Ret == void))
						overload(pargs);
					else
						return overload(pargs);
				};
				bestScore = score;
				ambiguous = false;
			} else if(score == bestScore) {
				ambiguous = true;
			}
		}
	}

	if(bestMatch is null)
		throw new Exception("no match existed");
	if(ambiguous)
		throw new Exception("ambiguous overload selection with args"); // FIXME: show the things

	return bestMatch();
}
71 
///
unittest {

	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	static struct Wrapper {
		static: // a namespace for the overload set, since functions declared directly inside a unittest cannot be overloaded
		int foo(Object a, Object b) { return 1; }
		int foo(MyClass a, Object b) { return 2; }
		int foo(DerivedClass a, MyClass b) { return 3; }
	}

	with(Wrapper) {
		// Only the (Object, Object) overload accepts two plain Objects.
		assert(mvd!foo(new Object, new Object) == 1);
		// #3 is rejected because a MyClass is not a DerivedClass; #2 is more specific than #1.
		assert(mvd!foo(new MyClass, new DerivedClass) == 2);
		// All three overloads accept these arguments; #3 is the most derived.
		assert(mvd!foo(new DerivedClass, new DerivedClass) == 3);
		// Classes unrelated to MyClass only fit the Object overload.
		assert(mvd!foo(new OtherClass, new OtherClass) == 1);
		assert(mvd!foo(new OtherClass, new MyClass) == 1);
	}
}