/++
	mvd stands for Multiple Virtual Dispatch. It lets you
	write functions that take any number of object arguments
	and select an overload based on the dynamic (run-time) type
	of each of them, rather than only their static types.

	---
	void foo(Object a, Object b) {} // 1
	void foo(MyClass a, Object b) {} // 2
	void foo(DerivedClass a, MyClass b) {} // 3

	Object a = new MyClass();
	Object b = new Object();

	mvd!foo(a, b); // will call overload #2
	---

	The return types must be compatible; [mvd] returns their
	common type as computed by `std.traits.CommonType` (most
	likely the shared base class of all the return types, or
	`void` if there isn't one).

	Arguments of non-class/interface types do not take part in
	dispatch, so their types should be compatible with the
	corresponding parameter of every overload. Otherwise you are
	liable to get compile errors. (Or it might work; that's up to
	the compiler's discretion.)

	If no overload accepts the arguments, or the best match is
	ambiguous, [mvd] throws an `Exception`.
+/
module arsd.mvd;

import std.traits;

/// This exists just to make the documentation of [mvd] nicer looking.
/// It evaluates to the common return type of all overloads of `fn`.
template CommonReturnOfOverloads(alias fn) {
	// staticMap is defined in std.meta; import it from there explicitly
	// rather than relying on std.traits making it visible.
	import std.meta : staticMap;

	alias overloads = __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn));
	static if (overloads.length == 1) {
		alias CommonReturnOfOverloads = ReturnType!(overloads[0]);
	}
	else {
		alias CommonReturnOfOverloads = CommonType!(staticMap!(ReturnType, overloads));
	}
}

/// Calls whichever overload of `fn` best matches the dynamic types of `args`.
/// See details on the [arsd.mvd] page.
CommonReturnOfOverloads!fn mvd(alias fn, T...)(T args) {
	// Free and static functions need no object, so dispatch with a null context.
	return mvdObj!fn(null, args);
}

/++
	Like [mvd], but calls the selected overload through `this_`
	(using `__traits(child, this_, overload)`), so `fn` can name a
	non-static member function of `this_`'s type.

	Only overloads taking exactly `args.length` parameters are considered.
	For each class or interface parameter, the corresponding argument is
	dynamically cast to the parameter type. A failed cast rules the overload
	out, while a `null` argument is accepted by any class or interface
	parameter. Each such parameter adds to the overload's score:
	- a class adds one more than its number of base classes (`Object` adds 1);
	- an interface adds two plus the number of interfaces it inherits.
	Other parameters are assigned the argument directly and add nothing.
	The overload with the highest score is called.

	Throws:
		`Exception` if no overload accepts the arguments, or if more than
		one overload shares the highest score.
+/
CommonReturnOfOverloads!fn mvdObj(alias fn, This, T...)(This this_, T args) {
	alias R = typeof(return);

	R delegate() bestMatch;
	// Start below any reachable score so an overload without class
	// parameters (score 0) can still be selected.
	int bestScore = -1;
	// Set when another overload reaches the current bestScore. Only a tie
	// for the *final* best score is an ambiguity, so this is checked after
	// all overloads have been examined.
	bool tied = false;

	// Describes the argument types for error messages.
	string argsStr() {
		string s;
		foreach(arg; args) {
			if(s.length)
				s ~= ", ";
			static if (is(typeof(arg) == class) || is(typeof(arg) == interface)) {
				if (arg is null) {
					s ~= "null " ~ typeof(arg).stringof;
				} else static if (is(typeof(arg) == interface)) {
					// go through Object to report the dynamic class
					s ~= typeid(cast(Object) arg).name;
				} else {
					s ~= typeid(arg).name;
				}
			} else {
				s ~= typeof(arg).stringof;
			}
		}
		return s;
	}

	ov: foreach(overload; __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn))) {
		// An overload with a different number of parameters can never match,
		// and indexing `args` with its parameter indices would not compile.
		static if (Parameters!overload.length == T.length) {
			Parameters!overload pargs;
			int score = 0;
			foreach(idx, parg; pargs) {
				alias t = typeof(parg);
				static if(is(t == interface) || is(t == class)) {
					t value = cast(t) args[idx];
					if(args[idx] !is null && value is null)
						continue ov; // failed cast, forget it

					static if(is(t == interface)) {
						// An interface reference is not the same pointer as the
						// object reference, so it must not be stored through
						// Object*. Copy the reference bits directly instead,
						// which also works when t is qualified (e.g. immutable).
						*cast(void**) &pargs[idx] = *cast(void**) &value;
						score += InterfacesTuple!t.length + 2;
					} else {
						// HACK: cast to Object* so we can set the value even if it's an immutable class
						*cast(Object*) &pargs[idx] = cast(Object) value;
						score += BaseClassesTuple!t.length + 1;
					}
				} else
					pargs[idx] = args[idx];
			}

			if(score > bestScore) {
				bestMatch = delegate R() {
					static if(is(R == void))
						__traits(child, this_, overload)(pargs);
					else
						return __traits(child, this_, overload)(pargs);
				};
				bestScore = score;
				tied = false;
			} else if(score == bestScore) {
				tied = true;
			}
		}
	}

	if(bestMatch is null)
		throw new Exception("no match existed with args (" ~ argsStr ~ ")");
	if(tied)
		throw new Exception("ambiguous overload selection with args (" ~ argsStr ~ ")");

	return bestMatch();
}

///
unittest {

	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	static struct Wrapper {
		static: // this is just a namespace cuz D doesn't allow overloading inside unittest
		int foo(Object a, Object b) { return 1; }
		int foo(MyClass a, Object b) { return 2; }
		int foo(DerivedClass a, MyClass b) { return 3; }

		int bar(MyClass a) { return 4; }
	}

	with(Wrapper) {
		assert(mvd!foo(new Object, new Object) == 1);
		assert(mvd!foo(new MyClass, new DerivedClass) == 2);
		assert(mvd!foo(new DerivedClass, new DerivedClass) == 3);
		assert(mvd!foo(new OtherClass, new OtherClass) == 1);
		assert(mvd!foo(new OtherClass, new MyClass) == 1);
		assert(mvd!foo(new DerivedClass, new DerivedClass) == 3);
		assert(mvd!foo(new OtherClass, new MyClass) == 1);

		//mvd!bar(new OtherClass);
	}
}

///
unittest {

	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	class Wrapper {
		int x;

		int foo(Object a, Object b) { return x + 1; }
		int foo(MyClass a, Object b) { return x + 2; }
		int foo(DerivedClass a, MyClass b) { return x + 3; }

		int bar(MyClass a) { return x + 4; }
	}

	Wrapper wrapper = new Wrapper;
	wrapper.x = 20;
	assert(wrapper.mvdObj!(wrapper.foo)(new Object, new Object) == 21);
	assert(wrapper.mvdObj!(wrapper.foo)(new MyClass, new DerivedClass) == 22);
	assert(wrapper.mvdObj!(wrapper.foo)(new DerivedClass, new DerivedClass) == 23);
	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new OtherClass) == 21);
	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new MyClass) == 21);
	assert(wrapper.mvdObj!(wrapper.foo)(new DerivedClass, new DerivedClass) == 23);
	assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new MyClass) == 21);

	//mvd!bar(new OtherClass);
}

///
unittest {
	class MyClass {}

	static bool success = false;

	static struct Wrapper {
		static:
		void foo(MyClass a) { success = true; }
	}

	with(Wrapper) {
		mvd!foo(new MyClass);
		assert(success);
	}
}

///
unittest {
	immutable class Foo {}

	immutable class Bar : Foo {
		int x;

		this(int x) {
			this.x = x;
		}
	}

	immutable class Baz : Foo {
		int x, y;

		this(int x, int y) {
			this.x = x;
			this.y = y;
		}
	}

	static struct Wrapper {
		static:

		int foo(Bar b) { return b.x; }
		int foo(Baz b) { return b.x + b.y; }
	}

	with(Wrapper) {
		Foo x = new Bar(3);
		Foo y = new Baz(5, 7);
		assert(mvd!foo(x) == 3);
		assert(mvd!foo(y) == 12);
	}
}

///
unittest {
	class MyClass {}

	static struct Wrapper {
		static:
		int foo(Object a, MyClass b) { return 1; }
		int foo(MyClass a, Object b) { return 2; }
		int foo(MyClass a, MyClass b) { return 3; }

		int bar(Object a, MyClass b) { return 4; }
		int bar(MyClass a, Object b) { return 5; }

		int baz(Object a) { return 6; }
		int baz(Object a, Object b) { return 7; }

		int twice(int a) { return a * 2; }
	}

	with(Wrapper) {
		// #1 and #2 tie with each other, but #3 beats both, so it is chosen
		assert(mvd!foo(new MyClass, new MyClass) == 3);
		assert(mvd!foo(new Object, new MyClass) == 1);

		// a genuine tie for the best score is reported
		bool threw = false;
		try {
			mvd!bar(new MyClass, new MyClass);
		} catch(Exception e) {
			threw = true;
		}
		assert(threw);

		// overloads with a different number of parameters are skipped
		assert(mvd!baz(new Object) == 6);
		assert(mvd!baz(new Object, new Object) == 7);

		// overloads without class parameters can be selected too
		assert(mvd!twice(21) == 42);
	}
}

///
unittest {
	interface Shape {}
	class Circle : Shape {}

	static struct Wrapper {
		static:
		int area(Object o) { return 0; }
		int area(Shape s) { return 1; }
	}

	with(Wrapper) {
		Object c = new Circle;
		assert(mvd!area(c) == 1);
		assert(mvd!area(new Object) == 0);
	}
}