1 /++
2 	mvd stands for Multiple Virtual Dispatch. It lets you
3 	write functions that take any number of arguments of
4 	objects and match based on the dynamic type of each
5 	of them.
6 
7 	---
8 	void foo(Object a, Object b) {} // 1
9 	void foo(MyClass b, Object b) {} // 2
10 	void foo(DerivedClass a, MyClass b) {} // 3
11 
12 	Object a = new MyClass();
13 	Object b = new Object();
14 
15 	mvd!foo(a, b); // will call overload #2
16 	---
17 
18 	The return values must be compatible; [mvd] will return
19 	the least specialized static type of the return values
20 	(most likely the shared base class type of all return types,
21 	or `void` if there isn't one).
22 
23 	All non-class/interface types should be compatible among overloads.
24 	Otherwise you are liable to get compile errors. (Or it might work,
25 	that's up to the compiler's discretion.)
26 +/
27 module arsd.mvd;
28 
29 import std.traits;
30 
31 /// This exists just to make the documentation of [mvd] nicer looking.
32 template CommonReturnOfOverloads(alias fn) {
33 	alias overloads = __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn));
34 	static if (overloads.length == 1) {
35 		alias CommonReturnOfOverloads = ReturnType!(overloads[0]);
36 	}
37 	else {
38 		alias CommonReturnOfOverloads = CommonType!(staticMap!(ReturnType, overloads));
39 	}
40 }
41 
42 /// See details on the [arsd.mvd] page.
43 CommonReturnOfOverloads!fn mvd(alias fn, T...)(T args) {
44 	return mvdObj!fn(null, args);
45 }
46 
47 CommonReturnOfOverloads!fn mvdObj(alias fn, This, T...)(This this_, T args) {
48 	typeof(return) delegate() bestMatch;
49 	int bestScore;
50 
51 	string argsStr() {
52 		string s;
53 		foreach(arg; args) {
54 			if(s.length)
55 				s ~= ", ";
56 			static if (is(typeof(arg) == class)) {
57 				if (arg is null) {
58 					s ~= "null " ~ typeof(arg).stringof;
59 				} else {
60 					s ~= typeid(arg).name;
61 				}
62 			} else {
63 				s ~= typeof(arg).stringof;
64 			}
65 		}
66 		return s;
67 	}
68 
69 	ov: foreach(overload; __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn))) {
70 		Parameters!overload pargs;
71 		int score = 0;
72 		foreach(idx, parg; pargs) {
73 			alias t = typeof(parg);
74 			static if(is(t == interface) || is(t == class)) {
75 				t value = cast(t) args[idx];
76 				// HACK: cast to Object* so we can set the value even if it's an immutable class
77 				*cast(Object*) &pargs[idx] = cast(Object) value;
78 				if(args[idx] !is null && pargs[idx] is null)
79 					continue ov; // failed cast, forget it
80 				else
81 					score += BaseClassesTuple!t.length + 1;
82 			} else
83 				pargs[idx] = args[idx];
84 		}
85 		if(score == bestScore)
86 			throw new Exception("ambiguous overload selection with args (" ~ argsStr ~ ")");
87 		if(score > bestScore) {
88 			bestMatch = () {
89 				static if(is(typeof(return) == void))
90 					__traits(child, this_, overload)(pargs);
91 				else
92 					return __traits(child, this_, overload)(pargs);
93 			};
94 			bestScore = score;
95 		}
96 	}
97 
98 	if(bestMatch is null)
99 		throw new Exception("no match existed with args (" ~ argsStr ~ ")");
100 
101 	return bestMatch();
102 }
103 
104 ///
105 unittest {
106 
107 	class MyClass {}
108 	class DerivedClass : MyClass {}
109 	class OtherClass {}
110 
111 	static struct Wrapper {
112 		static: // this is just a namespace cuz D doesn't allow overloading inside unittest
113 		int foo(Object a, Object b) { return 1; }
114 		int foo(MyClass a, Object b) { return 2; }
115 		int foo(DerivedClass a, MyClass b) { return 3; }
116 
117 		int bar(MyClass a) { return 4; }
118 	}
119 
120 	with(Wrapper) {
121 		assert(mvd!foo(new Object, new Object) == 1);
122 		assert(mvd!foo(new MyClass, new DerivedClass) == 2);
123 		assert(mvd!foo(new DerivedClass, new DerivedClass) == 3);
124 		assert(mvd!foo(new OtherClass, new OtherClass) == 1);
125 		assert(mvd!foo(new OtherClass, new MyClass) == 1);
126 		assert(mvd!foo(new DerivedClass, new DerivedClass) == 3);
127 		assert(mvd!foo(new OtherClass, new MyClass) == 1);
128 
129 		//mvd!bar(new OtherClass);
130 	}
131 }
132 
133 ///
134 unittest {
135 
136 	class MyClass {}
137 	class DerivedClass : MyClass {}
138 	class OtherClass {}
139 
140 	class Wrapper {
141 		int x;
142 
143 		int foo(Object a, Object b) { return x + 1; }
144 		int foo(MyClass a, Object b) { return x + 2; }
145 		int foo(DerivedClass a, MyClass b) { return x + 3; }
146 
147 		int bar(MyClass a) { return x + 4; }
148 	}
149 
150 	Wrapper wrapper = new Wrapper;
151 	wrapper.x = 20;
152 	assert(wrapper.mvdObj!(wrapper.foo)(new Object, new Object) == 21);
153 	assert(wrapper.mvdObj!(wrapper.foo)(new MyClass, new DerivedClass) == 22);
154 	assert(wrapper.mvdObj!(wrapper.foo)(new DerivedClass, new DerivedClass) == 23);
155 	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new OtherClass) == 21);
156 	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new MyClass) == 21);
157 	assert(wrapper.mvdObj!(wrapper.foo)(new DerivedClass, new DerivedClass) == 23);
158 	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new MyClass) == 21);
159 
160 	//mvd!bar(new OtherClass);
161 }
162 
163 ///
164 unittest {
165 	class MyClass {}
166 
167 	static bool success = false;
168 
169 	static struct Wrapper {
170 		static:
171 		void foo(MyClass a) { success = true; }
172 	}
173 
174 	with(Wrapper) {
175 		mvd!foo(new MyClass);
176 		assert(success);
177 	}
178 }
179 
180 ///
181 unittest {
182 	immutable class Foo {}
183 
184 	immutable class Bar : Foo {
185 		int x;
186 
187 		this(int x) {
188 			this.x = x;
189 		}
190 	}
191 
192 	immutable class Baz : Foo {
193 		int x, y;
194 
195 		this(int x, int y) {
196 			this.x = x;
197 			this.y = y;
198 		}
199 	}
200 
201 	static struct Wrapper {
202 		static:
203 
204 		int foo(Bar b) { return b.x; }
205 		int foo(Baz b) { return b.x + b.y; }
206 	}
207 
208 	with(Wrapper) {
209 		Foo x = new Bar(3);
210 		Foo y = new Baz(5, 7);
211 		assert(mvd!foo(x) == 3);
212 		assert(mvd!foo(y) == 12);
213 	}
214 }