1 /++
	mvd stands for Multiple Virtual Dispatch. It lets you
	write a set of overloaded functions that take any number
	of object arguments, and dispatches each call to the
	overload that best matches the dynamic (run-time) type
	of every argument.
6 
7 	---
8 	void foo(Object a, Object b) {} // 1
	void foo(MyClass a, Object b) {} // 2
10 	void foo(DerivedClass a, MyClass b) {} // 3
11 
12 	Object a = new MyClass();
13 	Object b = new Object();
14 
15 	mvd!foo(a, b); // will call overload #2
16 	---
17 
	The overloads' return types must be compatible: [mvd] returns
	their common static type, the one every return type converts
	to (usually the shared base class of all the return types, or
	`void` if there is no common type).
22 
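	Arguments that are not classes or interfaces are passed through
	unchanged, so their types must be accepted by the corresponding
	parameter of every overload; otherwise you are liable to get
	compile errors. (Whether a particular mismatch compiles is up to
	D's implicit conversion rules.)

	If no overload accepts the arguments, or the choice between
	overloads is ambiguous, an `Exception` is thrown at run time.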
26 +/
27 module arsd.mvd;
28 
29 import std.traits;
30 
/++
	Yields the static return type of [mvd] for the overload set `fn` belongs to.

	All overloads of `fn`'s name in its parent scope are collected. With exactly
	one overload, its return type is used as-is; otherwise the result is
	`CommonType` of every overload's return type (`void` when there is none).

	This exists just to make the documentation of [mvd] nicer looking.
+/
template CommonReturnOfOverloads(alias fn) {
	alias overloadSet = __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn));

	static if (overloadSet.length != 1)
		alias CommonReturnOfOverloads = CommonType!(staticMap!(ReturnType, overloadSet));
	else
		alias CommonReturnOfOverloads = ReturnType!(overloadSet[0]);
}
41 
/++
	Dispatches a call to the overload of `fn` that best matches the dynamic
	types of `args`.

	This forwards to [mvdObj] with a `null` `this` object. It therefore suits
	overloads that need no object to be called on (free functions or `static`
	members); for member functions of an instance, call [mvdObj] directly.

	See details on the [arsd.mvd] page.
+/
CommonReturnOfOverloads!fn mvd(alias fn, T...)(T args) {
	return mvdObj!(fn, typeof(null))(null, args);
}
46 
/++
	Calls the overload of `fn` that best matches the dynamic (run-time) types of
	`args`, using `this_` as the `this` context of the call.

	Every overload of `fn`'s name in its parent scope is considered. Overloads
	whose parameter count differs from the number of arguments are skipped.
	Class and interface arguments are cast to each overload's parameter types.
	An overload is rejected if a non-null argument fails that cast. Other
	arguments are assigned to the parameter unchanged, so their types must be
	compatible with every overload of the same arity (otherwise compilation
	fails).

	Each accepted overload gets a score:
	- every class parameter contributes its depth in the class hierarchy
	  (`Object` counts 1, and each further base class adds one);
	- every interface parameter contributes 2 plus the number of interfaces
	  it inherits;
	- other parameters contribute nothing.
	The highest score wins.

	Params:
		this_ = object whose member overload is called; [mvd] passes `null`
			here for free functions and `static` members.
		args = the arguments to dispatch on.

	Returns: the selected overload's result, converted to
		`CommonReturnOfOverloads!fn`.

	Throws: `Exception` if no overload accepts the arguments, or if two or more
		overloads tie for the highest score.
+/
CommonReturnOfOverloads!fn mvdObj(alias fn, This, T...)(This this_, T args) {
	alias R = typeof(return);

	R delegate() bestMatch;
	// Start below any reachable score, so an overload scoring 0 (one without
	// class/interface parameters) can still be selected.
	int bestScore = -1;
	// Set when another overload reaches the current best score, and cleared
	// when a strictly better one appears. This keeps the outcome independent
	// of the order in which the overloads are declared.
	bool ambiguous = false;

	// Describes the arguments for error messages: dynamic class names where
	// available, static type names otherwise.
	string argsStr() {
		string s;
		foreach(arg; args) {
			if(s.length)
				s ~= ", ";
			static if (is(typeof(arg) == class)) {
				if (arg is null) {
					s ~= "null " ~ typeof(arg).stringof;
				} else {
					s ~= typeid(arg).name;
				}
			} else {
				s ~= typeof(arg).stringof;
			}
		}
		return s;
	}

	ov: foreach(overload; __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn))) {
		// An overload with a different arity can never accept these arguments.
		// Skipping it at compile time also avoids indexing `args` out of bounds.
		static if(Parameters!overload.length == T.length) {
			Parameters!overload pargs;
			int score = 0;
			foreach(idx, parg; pargs) {
				alias t = typeof(parg);
				static if(is(t == interface) || is(t == class)) {
					pargs[idx] = cast(t) args[idx];
					if(args[idx] !is null && pargs[idx] is null)
						continue ov; // failed cast, forget it

					static if(is(t == class)) {
						score += BaseClassesTuple!t.length + 1;
					} else {
						// BaseClassesTuple only accepts classes. Rank an interface
						// above Object (which scores 1), one more per inherited interface.
						score += InterfacesTuple!t.length + 2;
					}
				} else
					pargs[idx] = args[idx];
			}

			if(score > bestScore) {
				// Typed explicitly as returning R, so each overload's result
				// converts implicitly to the common return type.
				bestMatch = delegate R() {
					static if(is(R == void))
						__traits(child, this_, overload)(pargs);
					else
						return __traits(child, this_, overload)(pargs);
				};
				bestScore = score;
				ambiguous = false;
			} else if(score == bestScore) {
				ambiguous = true;
			}
		}
	}

	if(bestMatch is null)
		throw new Exception("no match existed with args (" ~ argsStr ~ ")");
	if(ambiguous)
		throw new Exception("ambiguous overload selection with args (" ~ argsStr ~ ")");

	return bestMatch();
}
101 
///
unittest {
	import std.exception : assertThrown;

	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	static struct Wrapper {
		static: // this is just a namespace cuz D doesn't allow overloading inside unittest
		int foo(Object a, Object b) { return 1; }
		int foo(MyClass a, Object b) { return 2; }
		int foo(DerivedClass a, MyClass b) { return 3; }

		int bar(MyClass a) { return 4; }
	}

	with(Wrapper) {
		// The accepting overload with the most specific parameters wins.
		assert(mvd!foo(new Object, new Object) == 1);
		assert(mvd!foo(new MyClass, new DerivedClass) == 2);
		assert(mvd!foo(new DerivedClass, new DerivedClass) == 3);

		// Unrelated classes only fit the Object overload.
		assert(mvd!foo(new OtherClass, new OtherClass) == 1);
		assert(mvd!foo(new OtherClass, new MyClass) == 1);

		// bar's only overload needs a MyClass, so there is no match at run time.
		assertThrown!Exception(mvd!bar(new OtherClass));
	}
}
130 
///
unittest {
	import std.exception : assertThrown;

	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	class Wrapper {
		int x;

		int foo(Object a, Object b) { return x + 1; }
		int foo(MyClass a, Object b) { return x + 2; }
		int foo(DerivedClass a, MyClass b) { return x + 3; }

		int bar(MyClass a) { return x + 4; }
	}

	Wrapper wrapper = new Wrapper;
	wrapper.x = 20;

	// Member overloads run on `wrapper`, so they see its field `x`.
	assert(wrapper.mvdObj!(wrapper.foo)(new Object, new Object) == 21);
	assert(wrapper.mvdObj!(wrapper.foo)(new MyClass, new DerivedClass) == 22);
	assert(wrapper.mvdObj!(wrapper.foo)(new DerivedClass, new DerivedClass) == 23);
	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new OtherClass) == 21);
	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new MyClass) == 21);

	// bar's only overload needs a MyClass, so there is no match at run time.
	assertThrown!Exception(wrapper.mvdObj!(wrapper.bar)(new OtherClass));
}
160 
///
unittest {
	class MyClass {}

	// Flipped by the overload below so the test can see that it actually ran.
	static bool called;

	static struct Wrapper {
		static:
		void foo(MyClass a) { called = true; }
	}

	// An overload set returning `void` makes mvd itself return `void`.
	mvd!(Wrapper.foo)(new MyClass);
	assert(called);
}