1 /++
	mvd stands for Multiple Virtual Dispatch. It lets you
	write functions that take any number of object arguments
	and have the overload to call selected based on the
	dynamic type of each of them.
6 
7 	---
8 	void foo(Object a, Object b) {} // 1
	void foo(MyClass a, Object b) {} // 2
10 	void foo(DerivedClass a, MyClass b) {} // 3
11 
12 	Object a = new MyClass();
13 	Object b = new Object();
14 
15 	mvd!foo(a, b); // will call overload #2
16 	---
17 
	The return types must be compatible; [mvd] returns their
	common type as computed by `std.traits.CommonType` (most
	likely the shared base class of all the return types), or
	`void` if there isn't one.
22 
	Parameters that are not classes or interfaces should have
	compatible types across all overloads. Otherwise you are
	liable to get compile errors (or it might work; that is up
	to the compiler's discretion).
26 +/
27 module arsd.mvd;
28 
29 import std.traits;
30 
/++
	The type that the return types of every overload of `fn` have in
	common, as computed by `std.traits.CommonType` over each overload's
	`ReturnType`. `CommonType` yields `void` when there is no common type.

	This exists just to make the documentation of [mvd] nicer looking.
+/
template CommonReturnOfOverloads(alias fn) {
	// all overloads sharing fn's name in fn's enclosing scope
	alias CommonReturnOfOverloads = CommonType!(staticMap!(ReturnType,
		__traits(getOverloads, __traits(parent, fn), __traits(identifier, fn))));
}
33 
/++
	Calls whichever overload of `fn` best matches the dynamic types of
	`args`, passing `null` as the object context of the call. Use
	[mvdObj] to supply a real `this` instead.

	See details on the [arsd.mvd] page.
+/
CommonReturnOfOverloads!fn mvd(alias fn, T...)(T args) {
	// same instantiation the compiler would infer from `mvdObj!fn(null, args)`
	alias dispatch = mvdObj!(fn, typeof(null), T);
	return dispatch(null, args);
}
38 
/++
	Calls the overload of `fn` that best matches the dynamic types of
	`args`, using `this_` as the `this` context of the call (via
	`__traits(child)`). [mvd] calls this with `this_ = null`.

	Only overloads taking exactly as many parameters as there are `args`
	are considered. For each class or interface parameter, the argument is
	`cast` to the parameter type. If the argument is non-null but the cast
	yields `null`, that overload is rejected. A `null` argument therefore
	matches any class or interface parameter. Every other argument is
	assigned to its parameter directly, so those types must be compatible
	or compilation fails.

	Each accepted overload gets a score. A class parameter contributes one
	plus the number of classes it inherits from, so `Object` contributes 1
	and each level of derivation adds 1. An interface parameter contributes
	one plus the number of interfaces it inherits. The highest score wins.
	Ambiguity is only judged after every overload has been scored.

	Throws:
		`Exception` if no overload accepts the arguments, or if two or more
		overloads tie for the highest score.

	Returns:
		The selected overload's result, converted to
		`CommonReturnOfOverloads!fn` (nothing when that is `void`).
+/
CommonReturnOfOverloads!fn mvdObj(alias fn, This, T...)(This this_, T args) {
	alias R = typeof(return);

	R delegate() bestMatch;
	int bestScore;
	// Set when an overload ties the current best. Cleared whenever a
	// strictly better overload shows up, so only a tie at the top counts.
	bool ambiguous;

	// describes the (dynamic, where available) types of `args` for error messages
	string argsStr() {
		string s;
		foreach(arg; args) {
			if(s.length)
				s ~= ", ";
			static if (is(typeof(arg) == class)) {
				if (arg is null) {
					s ~= "null " ~ typeof(arg).stringof;
				} else {
					s ~= typeid(arg).name;
				}
			} else {
				s ~= typeof(arg).stringof;
			}
		}
		return s;
	}

	ov: foreach(overload; __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn))) {
		// An overload with a different number of parameters can never be
		// called with `args`, so skip it at compile time instead of indexing
		// `args` out of bounds or silently dropping arguments.
		static if(Parameters!overload.length == T.length) {
			Parameters!overload pargs;
			int score = 0;
			foreach(idx, parg; pargs) {
				alias t = typeof(parg);
				static if(is(t == interface) || is(t == class)) {
					pargs[idx] = cast(t) args[idx];
					if(args[idx] !is null && pargs[idx] is null)
						continue ov; // failed cast, forget it
					// BaseClassesTuple only accepts classes, so interfaces are
					// ranked by the interfaces they inherit instead
					static if(is(t == class))
						score += BaseClassesTuple!t.length + 1;
					else
						score += InterfacesTuple!t.length + 1;
				} else
					pargs[idx] = args[idx];
			}

			// `bestMatch is null` (not `bestScore == 0`) means "nothing yet",
			// so an overload that legitimately scores 0 is not misreported
			if(bestMatch is null || score > bestScore) {
				// explicit R return type: each overload's result converts to
				// the common return type
				bestMatch = delegate R() {
					static if(is(R == void))
						__traits(child, this_, overload)(pargs);
					else
						return __traits(child, this_, overload)(pargs);
				};
				bestScore = score;
				ambiguous = false;
			} else if(score == bestScore) {
				ambiguous = true;
			}
		}
	}

	if(bestMatch is null)
		throw new Exception("no match existed with args (" ~ argsStr ~ ")");
	if(ambiguous)
		throw new Exception("ambiguous overload selection with args (" ~ argsStr ~ ")");

	return bestMatch();
}
93 
///
unittest {
	import std.exception : assertThrown;

	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	static struct Wrapper {
		static: // just a namespace: functions nested in a unittest cannot be overloaded
		int foo(Object a, Object b) { return 1; }
		int foo(MyClass a, Object b) { return 2; }
		int foo(DerivedClass a, MyClass b) { return 3; }

		int bar(MyClass a) { return 4; }
	}

	with(Wrapper) {
		assert(mvd!foo(new Object, new Object) == 1);
		assert(mvd!foo(new MyClass, new DerivedClass) == 2);
		assert(mvd!foo(new DerivedClass, new DerivedClass) == 3);
		assert(mvd!foo(new OtherClass, new OtherClass) == 1);
		assert(mvd!foo(new OtherClass, new MyClass) == 1);

		// bar's only overload needs a MyClass, so an OtherClass matches nothing
		assertThrown!Exception(mvd!bar(new OtherClass));
	}
}
122 
///
unittest {
	import std.exception : assertThrown;

	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	class Wrapper {
		int x;

		int foo(Object a, Object b) { return x + 1; }
		int foo(MyClass a, Object b) { return x + 2; }
		int foo(DerivedClass a, MyClass b) { return x + 3; }

		int bar(MyClass a) { return x + 4; }
	}

	Wrapper wrapper = new Wrapper;
	wrapper.x = 20;

	// the selected overload runs with `wrapper` as its `this`, so it sees x == 20
	assert(wrapper.mvdObj!(wrapper.foo)(new Object, new Object) == 21);
	assert(wrapper.mvdObj!(wrapper.foo)(new MyClass, new DerivedClass) == 22);
	assert(wrapper.mvdObj!(wrapper.foo)(new DerivedClass, new DerivedClass) == 23);
	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new OtherClass) == 21);
	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new MyClass) == 21);

	// bar's only overload needs a MyClass, so an OtherClass matches nothing
	assertThrown!Exception(wrapper.mvdObj!(wrapper.bar)(new OtherClass));
}