1 /++
	mvd stands for Multiple Virtual Dispatch. It lets you
	write a set of overloaded functions taking any number of
	object arguments, then call the overload that best matches
	the dynamic (runtime) type of each of those arguments.
6 
7 	---
8 	void foo(Object a, Object b) {} // 1
	void foo(MyClass a, Object b) {} // 2
10 	void foo(DerivedClass a, MyClass b) {} // 3
11 
12 	Object a = new MyClass();
13 	Object b = new Object();
14 
15 	mvd!foo(a, b); // will call overload #2
16 	---
17 
	The return types must be compatible: [mvd] returns their
	common type as computed by `std.traits.CommonType` (most likely
	the shared base class of all the return types), or `void` if
	there isn't one, in which case the return values are discarded.
22 
	Arguments of non-class/interface types should have compatible types
	across all overloads; otherwise you are liable to get compile errors.
	(Or it might work; that depends on what the compiler accepts.)
26 
27 	History:
		As of January 1, 2026, it statically rejects overloads that can
		never match the static types of the arguments, instead of only
		throwing at runtime. It may still throw at runtime if some overload
		could match in theory but none matches the actual objects.
32 +/
33 module arsd.mvd;
34 
35 import std.traits;
36 
/++
	The return type shared by [mvd] and [mvdObj] for `fn`.

	With a single overload this is simply that overload's return type.
	Otherwise it is `std.traits.CommonType` of the return types of every
	overload of `fn`, which is `void` when they have no common type.

	This exists just to make the documentation of [mvd] nicer looking.
+/
template CommonReturnOfOverloads(alias fn) {
	alias allOverloads = __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn));

	static if (allOverloads.length > 1)
		alias CommonReturnOfOverloads = CommonType!(staticMap!(ReturnType, allOverloads));
	else
		alias CommonReturnOfOverloads = ReturnType!(allOverloads[0]);
}
47 
/++
	Calls whichever overload of `fn` best matches the dynamic (runtime)
	types of `args`. See details on the [arsd.mvd] page.

	This forwards to [mvdObj] with a `null` `this`, so `fn` is presumably a
	free function or a static member function; use [mvdObj] to dispatch to
	member functions of an object.

	Params:
		args = the arguments to dispatch on

	Returns: the chosen overload's result, typed as [CommonReturnOfOverloads]`!fn`.

	Throws: whatever [mvdObj] throws, i.e. an `Exception` when no overload
	accepts the runtime types of `args` or when the selection is ambiguous.
+/
CommonReturnOfOverloads!fn mvd(alias fn, T...)(T args) {
	// no object to call through, so hand mvdObj a null context
	typeof(null) noContext = null;
	return mvdObj!fn(noContext, args);
}
52 
53 private bool isSameOrBaseOf(LookingFor, LookingAt)() {
54 	static if(is(LookingFor == LookingAt))
55 		return true;
56 	else static if(is(LookingAt Bases == super)) {
57 		foreach(base; Bases)
58 			if(isSameOrBaseOf!(LookingFor, base))
59 				return true;
60 	}
61 	return false;
62 }
63 
/++
	Compile-time check used by [mvdObj]: could an argument whose static type
	is `PassedArg` ever be passed to a parameter of type `CastingTo`?

	For class and interface types this is `true` when an implicit upcast to a
	parent works, when a dynamic downcast might succeed (`CastingTo` derives
	from `PassedArg`), or when either side is an interface (the object's
	class may implement an interface unrelated to its static type, so a
	dynamic cast might succeed). For every other combination, including
	passing `typeof(null)`, the argument must implicitly convert to the
	parameter type, because it is assigned directly.

	Returns: `false` only when no value of type `PassedArg` could ever be
	accepted, so the overload can be ruled out at compile time.
+/
private bool canMaybeEverCastTo(PassedArg, CastingTo)() {
	static if((is(PassedArg == class) || is(PassedArg == interface))
		&& (is(CastingTo == class) || is(CastingTo == interface)))
	{
		static if(is(PassedArg == interface) || is(CastingTo == interface)) {
			// some class may be related to both sides, so a dynamic cast might succeed
			return true;
		} else {
			// same type or implicit cast to a parent is fine, and if a dynamic
			// downcast might succeed that's also ok; otherwise the classes are
			// unrelated and the cast operator can never succeed
			return isSameOrBaseOf!(CastingTo, PassedArg) || isSameOrBaseOf!(PassedArg, CastingTo);
		}
	} else {
		// non-class arguments (and typeof(null)) need an implicit conversion
		return is(PassedArg : CastingTo);
	}
}
75 
/++
	Compile-time check used by [mvdObj]: `true` when `overload` takes exactly
	`T.length` parameters and every argument type in `T` could ever be passed
	to the matching parameter, either by implicit conversion or because
	`canMaybeEverCastTo` reports that a dynamic cast might succeed.
+/
private template overloadCanMaybeMatch(alias overload, T...) {
	enum overloadCanMaybeMatch = () {
		alias Params = Parameters!overload;
		static if(Params.length != T.length) {
			return false;
		} else {
			foreach(idx, Param; Params)
				if(!is(T[idx] : Param) && !canMaybeEverCastTo!(T[idx], Param))
					return false;
			return true;
		}
	}();
}

/++
	Compile-time check used by [mvdObj]: `true` when at least one overload of
	`fn` passes `overloadCanMaybeMatch` for the argument types `T`.
+/
private template anyOverloadCanMaybeMatch(alias fn, T...) {
	enum anyOverloadCanMaybeMatch = () {
		bool found = false;
		foreach(overload; __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn)))
			found = found || overloadCanMaybeMatch!(overload, T);
		return found;
	}();
}

/++
	Like [mvd], but calls the selected overload through `this_`, so `fn` can
	name member functions (e.g. `obj.mvdObj!(obj.method)(args)`). [mvd]
	forwards here with a `null` `this_`.

	Overloads whose parameter count differs from `args`, or whose parameter
	types can never accept the arguments' static types, are skipped at
	compile time. If every overload is skipped, compilation fails.

	At runtime, each remaining overload is tried. Class and interface
	arguments are dynamically cast to the parameter type, and a failed cast
	of a non-null argument disqualifies the overload. Each class parameter
	adds its number of base classes plus one to the overload's score, and
	each interface parameter adds its number of inherited interfaces plus
	one. The overload with the highest score is called.

	Params:
		this_ = the object the selected overload is called on (`null` for free/static functions)
		args = the arguments to dispatch on

	Returns: the selected overload's return value, converted to
	[CommonReturnOfOverloads]`!fn` (nothing when that is `void`).

	Throws: `Exception` if no overload accepts the runtime types of `args`,
	or if two or more overloads tie for the highest score.
+/
CommonReturnOfOverloads!fn mvdObj(alias fn, This, T...)(This this_, T args) {
	alias Ret = typeof(return);
	alias overloads = __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn));

	static assert(anyOverloadCanMaybeMatch!(fn, T),
		"No overload of " ~ __traits(identifier, fn) ~ " accepts argument types " ~ T.stringof ~ "; no overload could possibly match.");

	Ret delegate() bestMatch;
	int bestScore;
	// true while two or more overloads share the current best score
	bool ambiguous;

	string argsStr() {
		string s;
		foreach(arg; args) {
			if(s.length)
				s ~= ", ";
			static if (is(typeof(arg) == class) || is(typeof(arg) == interface)) {
				if (arg is null) {
					s ~= "null " ~ typeof(arg).stringof;
				} else {
					// via Object so interfaces also report the object's dynamic class
					s ~= typeid(cast(Object) arg).name;
				}
			} else {
				s ~= typeof(arg).stringof;
			}
		}
		return s;
	}

	ov: foreach(overload; overloads) {
		// overloads that can never match are skipped rather than failing the build
		static if(overloadCanMaybeMatch!(overload, T)) {
			Parameters!overload pargs;
			int score = 0;
			foreach(idx, parg; pargs) {
				alias t = typeof(parg);

				static if(is(t == interface) || is(t == class)) {
					t value = cast(t) args[idx];
					if(args[idx] !is null && value is null)
						continue ov; // failed cast, forget it

					static if(is(t == class)) {
						// HACK: cast to Object* so we can set the value even if it's an immutable class
						*cast(Object*) &pargs[idx] = cast(Object) value;
						score += BaseClassesTuple!t.length + 1;
					} else {
						// an interface reference points into the object rather than at
						// its start, so copy the raw reference instead of going through Object
						*cast(void**) &pargs[idx] = cast(void*) value;
						score += InterfacesTuple!t.length + 1;
					}
				} else
					pargs[idx] = args[idx];
			}

			// the first viable overload always becomes the best so far, even with score 0
			if(bestMatch is null || score > bestScore) {
				bestMatch = delegate Ret() {
					static if(is(Ret == void))
						__traits(child, this_, overload)(pargs);
					else
						return __traits(child, this_, overload)(pargs);
				};
				bestScore = score;
				ambiguous = false;
			} else if(score == bestScore)
				ambiguous = true;
		}
	}

	if(bestMatch is null)
		throw new Exception("no match existed with args (" ~ argsStr ~ ")");
	// only a tie for the final best score is ambiguous; an earlier tie may be beaten later
	if(ambiguous)
		throw new Exception("ambiguous overload selection with args (" ~ argsStr ~ ")");

	return bestMatch();
}

// Regression tests for overload selection in mvdObj.
unittest {
	class A {}
	class B : A {}

	static struct Plain {
		static:
		// no class parameters, so this overload scores 0
		int twice(int x) { return x * 2; }
	}
	assert(mvd!(Plain.twice)(21) == 42);

	static struct Ordered {
		static:
		int pick(Object a, B b) { return 1; } // scores 1 + 3
		int pick(B a, Object b) { return 2; } // scores 3 + 1, ties the first
		int pick(B a, B b) { return 3; }      // scores 3 + 3, beats both
	}
	// an early tie is not ambiguous when a later overload is better
	assert(mvd!(Ordered.pick)(new B, new B) == 3);

	static struct Tied {
		static:
		int pick(Object a, A b) { return 1; } // scores 1 + 2
		int pick(A a, Object b) { return 2; } // scores 2 + 1
	}
	bool threw = false;
	try
		mvd!(Tied.pick)(new A, new A);
	catch(Exception)
		threw = true;
	assert(threw);
}
135 
///
unittest {

	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	static struct Wrapper {
		static: // a namespace, since nested functions inside a unittest can't be overloaded
		int foo(Object a, Object b) { return 1; }
		int foo(MyClass a, Object b) { return 2; }
		int foo(DerivedClass a, MyClass b) { return 3; }

		int bar(MyClass a) { return 4; }
	}

	with(Wrapper) {
		assert(mvd!foo(new Object, new Object) == 1);
		assert(mvd!foo(new MyClass, new DerivedClass) == 2);
		assert(mvd!foo(new DerivedClass, new DerivedClass) == 3);
		// #3 needs its second argument to be a MyClass, so this falls back to #2
		assert(mvd!foo(new DerivedClass, new Object) == 2);

		// OtherClass is unrelated to MyClass, so only overload #1 can take it
		assert(mvd!foo(new OtherClass, new OtherClass) == 1);
		assert(mvd!foo(new OtherClass, new MyClass) == 1);

		// arguments statically typed as a base class are matched by their dynamic type
		Object first = new DerivedClass;
		Object second = new MyClass;
		assert(mvd!foo(first, second) == 3);

		// bar only takes a MyClass, which an OtherClass can never be, so
		// `mvd!bar(new OtherClass)` is rejected at compile time.
		assert(mvd!bar(new DerivedClass) == 4);
	}
}
164 
///
unittest {

	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	// mvdObj dispatches to member functions, calling them on the given object
	class Wrapper {
		int x;

		int foo(Object a, Object b) { return x + 1; }
		int foo(MyClass a, Object b) { return x + 2; }
		int foo(DerivedClass a, MyClass b) { return x + 3; }

		int bar(MyClass a) { return x + 4; }
	}

	Wrapper wrapper = new Wrapper;
	wrapper.x = 20;
	assert(wrapper.mvdObj!(wrapper.foo)(new Object, new Object) == 21);
	assert(wrapper.mvdObj!(wrapper.foo)(new MyClass, new DerivedClass) == 22);
	assert(wrapper.mvdObj!(wrapper.foo)(new DerivedClass, new DerivedClass) == 23);
	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new OtherClass) == 21);
	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new MyClass) == 21);

	// the selected member function sees the object's current state
	wrapper.x = 40;
	assert(wrapper.mvdObj!(wrapper.foo)(new DerivedClass, new DerivedClass) == 43);

	// bar can never accept an OtherClass, so
	// `wrapper.mvdObj!(wrapper.bar)(new OtherClass)` is rejected at compile time.
	assert(wrapper.mvdObj!(wrapper.bar)(new DerivedClass) == 44);
}
194 
///
unittest {
	class Target {}

	// static, so the static member function below can set it
	static bool called;

	static struct Handlers {
		static:
		// every overload returns void, so mvd returns void as well
		void foo(Target t) { called = true; }
	}

	mvd!(Handlers.foo)(new Target);
	assert(called);
}
211 
///
unittest {
	// classes declared `immutable` can be dispatched on too
	immutable class Foo {}

	immutable class Bar : Foo {
		int x;

		this(int x) {
			this.x = x;
		}
	}

	immutable class Baz : Foo {
		int x, y;

		this(int x, int y) {
			this.x = x;
			this.y = y;
		}
	}

	static struct Wrapper {
		static:

		int foo(Bar b) { return b.x; }
		int foo(Baz b) { return b.x + b.y; }
	}

	// both values are statically typed as the shared base class
	Foo asBar = new Bar(3);
	Foo asBaz = new Baz(5, 7);

	assert(mvd!(Wrapper.foo)(asBar) == 3);
	assert(mvd!(Wrapper.foo)(asBaz) == 12);
}