/++
	mvd stands for Multiple Virtual Dispatch. It lets you
	write functions that take any number of object arguments
	and are selected based on the dynamic type of each
	of them.

	---
	void foo(Object a, Object b) {} // 1
	void foo(MyClass a, Object b) {} // 2
	void foo(DerivedClass a, MyClass b) {} // 3

	Object a = new MyClass();
	Object b = new Object();

	mvd!foo(a, b); // will call overload #2
	---

	The return types must be compatible; [mvd] returns their common
	type as computed by `std.traits.CommonType` (typically the shared
	base class of all the return types, or `void` if there isn't one).

	All non-class/interface parameter types should be compatible among
	overloads. Otherwise you are liable to get compile errors. (Or it might
	work; that's up to the compiler's discretion.)

	If no overload accepts the arguments, or the best match is ambiguous,
	an `Exception` is thrown at run time.
+/
module arsd.mvd;

import std.traits;

/// This exists just to make the documentation of [mvd] nicer looking.
/// It is the `CommonType` of the return types of every overload of `fn`.
alias CommonReturnOfOverloads(alias fn) = CommonType!(staticMap!(ReturnType,
	__traits(getOverloads, __traits(parent, fn), __traits(identifier, fn))));

/// See details on the [arsd.mvd] page.
/++
	Calls whichever overload of `fn` best matches the dynamic types of
	`args`; see [mvdObj] for how the match is chosen.

	Throws:
		`Exception` if no overload accepts the arguments, or if two or more
		overloads tie for the best match.
+/
CommonReturnOfOverloads!fn mvd(alias fn, T...)(T args) {
	return mvdObj!fn(null, args);
}

/++
	Like [mvd], but invokes the chosen overload through `this_`, so `fn` may
	name a member-function overload set (see the second unittest below).
	[mvd] forwards here with `this_` set to `null`.

	Every overload of `fn` that takes exactly `T.length` parameters is a
	candidate. Each class or interface parameter receives `cast` of the
	matching argument; a non-null argument whose cast yields `null`
	disqualifies that candidate. A surviving candidate scores, per class
	parameter, its number of base classes plus one, and per interface
	parameter, its number of inherited interfaces plus one. More-derived
	parameter types therefore score higher. Other parameters are assigned
	directly from the arguments and add nothing to the score. The highest
	score wins.

	Throws:
		`Exception` if no overload accepts the arguments, or if two or more
		overloads tie for the highest score.
+/
CommonReturnOfOverloads!fn mvdObj(alias fn, This, T...)(This this_, T args) {
	alias R = typeof(return);

	// Renders the argument list for error messages, using the dynamic type
	// of class arguments where available.
	string argsStr() {
		string s;
		foreach(arg; args) {
			if(s.length)
				s ~= ", ";
			static if (is(typeof(arg) == class)) {
				if (arg is null) {
					s ~= "null " ~ typeof(arg).stringof;
				} else {
					s ~= typeid(arg).name;
				}
			} else {
				s ~= typeof(arg).stringof;
			}
		}
		return s;
	}

	R delegate() bestMatch;
	// Start below zero so an overload with no class parameters (score 0)
	// can still be selected.
	int bestScore = -1;
	// Set when an overload ties the current best and cleared when a strictly
	// better one appears; it is only reported once every overload has been
	// considered, so a later, more specific overload can still resolve a tie.
	bool ambiguous = false;

	ov: foreach(overload; __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn))) {
		// Overloads of another arity can never match. Filtering them at
		// compile time also avoids indexing `args` out of range.
		static if(Parameters!overload.length == T.length) {
			Parameters!overload pargs;
			int score = 0;
			foreach(idx, parg; pargs) {
				alias t = typeof(parg);
				static if(is(t == interface) || is(t == class)) {
					pargs[idx] = cast(t) args[idx];
					if(args[idx] !is null && pargs[idx] is null)
						continue ov; // failed cast, forget it
					// BaseClassesTuple only accepts classes, so interfaces
					// are ranked by how many interfaces they inherit.
					static if(is(t == interface))
						score += InterfacesTuple!t.length + 1;
					else
						score += BaseClassesTuple!t.length + 1;
				} else
					pargs[idx] = args[idx];
			}

			if(score == bestScore) {
				ambiguous = true;
			} else if(score > bestScore) {
				bestMatch = delegate R() {
					static if(is(R == void))
						__traits(child, this_, overload)(pargs);
					else
						return __traits(child, this_, overload)(pargs);
				};
				bestScore = score;
				ambiguous = false;
			}
		}
	}

	if(bestMatch is null)
		throw new Exception("no match existed with args (" ~ argsStr ~ ")");
	if(ambiguous)
		throw new Exception("ambiguous overload selection with args (" ~ argsStr ~ ")");

	return bestMatch();
}

///
unittest {

	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	static struct Wrapper {
		static: // this is just a namespace cuz D doesn't allow overloading inside unittest
		int foo(Object a, Object b) { return 1; }
		int foo(MyClass a, Object b) { return 2; }
		int foo(DerivedClass a, MyClass b) { return 3; }

		int bar(MyClass a) { return 4; }
	}

	with(Wrapper) {
		assert(mvd!foo(new Object, new Object) == 1);
		assert(mvd!foo(new MyClass, new DerivedClass) == 2);
		assert(mvd!foo(new DerivedClass, new DerivedClass) == 3);
		assert(mvd!foo(new OtherClass, new OtherClass) == 1);
		assert(mvd!foo(new OtherClass, new MyClass) == 1);
		assert(mvd!foo(new DerivedClass, new DerivedClass) == 3);
		assert(mvd!foo(new OtherClass, new MyClass) == 1);

		//mvd!bar(new OtherClass);
	}
}

///
unittest {

	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	class Wrapper {
		int x;

		int foo(Object a, Object b) { return x + 1; }
		int foo(MyClass a, Object b) { return x + 2; }
		int foo(DerivedClass a, MyClass b) { return x + 3; }

		int bar(MyClass a) { return x + 4; }
	}

	Wrapper wrapper = new Wrapper;
	wrapper.x = 20;
	assert(wrapper.mvdObj!(wrapper.foo)(new Object, new Object) == 21);
	assert(wrapper.mvdObj!(wrapper.foo)(new MyClass, new DerivedClass) == 22);
	assert(wrapper.mvdObj!(wrapper.foo)(new DerivedClass, new DerivedClass) == 23);
	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new OtherClass) == 21);
	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new MyClass) == 21);
	assert(wrapper.mvdObj!(wrapper.foo)(new DerivedClass, new DerivedClass) == 23);
	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new MyClass) == 21);

	//mvd!bar(new OtherClass);
}

/// Tie resolution, score-0 overloads, mixed arity and interface parameters.
unittest {
	import std.exception : assertThrown;

	class MyClass {}
	class OtherClass {}

	interface IBase {}
	interface IDerived : IBase {}
	class Impl : IDerived {}

	static struct Overloads {
		static:
		// With (MyClass, MyClass) the first two tie, but the third is
		// strictly better and must win.
		int baz(MyClass a, Object b) { return 1; }
		int baz(Object a, MyClass b) { return 2; }
		int baz(MyClass a, MyClass b) { return 3; }

		// A genuine tie with nothing better is still an error.
		int qux(MyClass a, Object b) { return 1; }
		int qux(Object a, MyClass b) { return 2; }

		// No class parameters: score 0, still selectable.
		int plain(int x) { return x; }

		// Overloads of a different arity are ignored.
		int arity(Object a) { return 1; }
		int arity(Object a, Object b) { return 2; }

		// Interface parameters: the more-derived interface wins.
		int q(IBase a) { return 1; }
		int q(IDerived a) { return 2; }
	}

	assert(mvd!(Overloads.baz)(new MyClass, new MyClass) == 3);
	assert(mvd!(Overloads.baz)(new MyClass, new OtherClass) == 1);
	assertThrown(mvd!(Overloads.qux)(new MyClass, new MyClass));
	assertThrown(mvd!(Overloads.qux)(new OtherClass, new OtherClass));
	assert(mvd!(Overloads.plain)(5) == 5);
	assert(mvd!(Overloads.arity)(new Object) == 1);
	assert(mvd!(Overloads.arity)(new Object, new Object) == 2);
	assert(mvd!(Overloads.q)(new Impl) == 2);
	assertThrown(mvd!(Overloads.q)(new Object));
}