/++
	Helper functions for generating database stuff.

	Structs describe tables: each field becomes a column, user-defined
	attributes ([PrimaryKey], [Default], [Unique], [ForeignKey]) and marker
	members ([Constraint], [Index], [UniqueIndex]) add constraints and indexes.

	Note: this is heavily biased toward Postgres. Compile with
	`version=dbgenerate_sqlite` to switch the generated SQL to the SQLite flavor.
+/
module arsd.database_generation;

/*

	FIXME: support partial indexes and maybe "using"
	FIXME: support views

	Let's put indexes in there too and make index functions be the preferred way of doing a query
	by making them convenient af.
*/

private enum UDA;

/// Attribute: marks the field as the table's primary key (`@PrimaryKey`).
/// The `sql` member is not read by the generator in this module.
@UDA struct PrimaryKey {
	string sql;
}

/// Attribute: `@Default("expr")` appends `DEFAULT expr` to the column definition.
@UDA struct Default {
	string sql;
}

/// Attribute: `@Unique` appends `UNIQUE` to the column definition.
@UDA struct Unique { }

/// Attribute: `@ForeignKey!(OtherTable.field, CASCADE)` emits a `FOREIGN KEY` clause
/// referencing `field` in the table generated for its parent struct, followed by `behavior`.
@UDA struct ForeignKey(alias toWhat, string behavior) {
	alias ReferencedTable = __traits(parent, toWhat);
}

/// Ready-made `behavior` strings for [ForeignKey].
enum CASCADE = "ON UPDATE CASCADE ON DELETE CASCADE";
/// ditto
enum NULLIFY = "ON UPDATE CASCADE ON DELETE SET NULL";
/// ditto
enum RESTRICT = "ON UPDATE CASCADE ON DELETE RESTRICT";

/// Attribute carrying an alternative database name. Not consulted by the code in this module.
@UDA struct DBName { string name; }

/// A column value that may be SQL NULL. Starts out null; assigning a `T` clears the
/// null flag, assigning `null` sets it again.
struct Nullable(T) {
	bool isNull = true;
	T value;

	void opAssign(typeof(null)) {
		isNull = true;
	}

	void opAssign(T v) {
		isNull = false;
		value = v;
	}

	T toArsdJsvar() { return value; }
}

/// A timestamp column, kept as the string representation used by the database.
struct Timestamp {
	string value;
	string toArsdJsvar() { return value; }

	/// Builds a Timestamp from separate date (`YYYY-MM-DD`) and time (`HH:MM` or `HH:MM:SS`) strings.
	// FIXME: timezone
	static Timestamp fromStrings(string date, string time) {
		// "HH:MM" lacks seconds, which the ISO extended parser requires
		if(time.length < 6)
			time ~= ":00";
		import std.datetime;
		return Timestamp(SysTime.fromISOExtString(date ~ "T" ~ time).toISOExtString());
	}
}

/// Parses a database timestamp (`YYYY-MM-DD HH:MM:SS[.frac][zone]`) into a `SysTime`.
/// An empty string yields `SysTime.init`.
SysTime parseDbTimestamp(Timestamp s) {
	return parseDbTimestamp(s.value);
}

/// ditto
SysTime parseDbTimestamp(string s) {
	if(s.length == 0) return SysTime.init;
	// Too short to hold a date, a separator and a time: hand it to Phobos unchanged
	// so the caller gets its regular parse exception instead of a RangeError.
	if(s.length < 11)
		return SysTime.fromISOExtString(s);
	// The database separates date and time with a space; ISO extended format wants 'T'.
	// Everything after the separator (time, fraction, zone) is passed through as is, so
	// values without a zone suffix (19 characters) work too.
	return SysTime.fromISOExtString(s[0 .. 10] ~ "T" ~ s[11 .. $]);
}

/// Marker member: a field `Constraint!"CHECK (...)" name;` emits `CONSTRAINT name CHECK (...)`
/// inside the CREATE TABLE statement (skipped under `version=dbgenerate_sqlite`).
struct Constraint(string sql) {}

/// Marker member: a field `Index!(a, b) name;` emits `CREATE INDEX table_name ON table(a, b);`
/// after the CREATE TABLE statement. [UniqueIndex] does the same with `CREATE UNIQUE INDEX`.
struct Index(Fields...) {}
/// ditto
struct UniqueIndex(Fields...) {}

/// An auto-incrementing integer column. Converts implicitly to `int`.
struct Serial {
	int value;
	int toArsdJsvar() { return value; }
	int getValue() { return value; }
	alias getValue this;
}


/++
	Generates the SQL that creates the table (plus its indexes) described by struct `O`.

	The table name comes from [toTableName]. Every member with a type becomes a column,
	except [Constraint], [Index] and [UniqueIndex] members which become constraints and
	`CREATE [UNIQUE] INDEX` statements. Primary and foreign keys are emitted after the
	column list. Each generated statement is terminated with `;`.

	Returns: the complete SQL text; usable at compile time.
+/
string generateCreateTableFor(alias O)() {
	enum tableName = toTableName(O.stringof);
	string sql = "CREATE TABLE " ~ tableName ~ " (";
	string postSql;
	bool outputtedPostSql = false;

	string afterTableSql;

	// statements that run after CREATE TABLE; each one needs its own terminator
	// or two indexes in a row would form an invalid batch
	void addAfterTableSql(string s) {
		afterTableSql ~= s;
		afterTableSql ~= ";\n";
	}

	// table-level clauses (keys) that go after all the column definitions
	void addPostSql(string s) {
		if(outputtedPostSql) {
			postSql ~= ",";
		}
		postSql ~= "\n";
		postSql ~= "\t" ~ s;
		outputtedPostSql = true;
	}

	bool outputted = false;
	static foreach(memberName; __traits(allMembers, O)) {{
		alias member = __traits(getMember, O, memberName);
		static if(is(typeof(member) == Constraint!constraintSql, string constraintSql)) {
		version(dbgenerate_sqlite) {} else { // FIXME: make it work here too, it is the specifics of the constraint strings
			if(outputted) {
				sql ~= ",";
			}
			sql ~= "\n";
			sql ~= "\tCONSTRAINT " ~ memberName;
			sql ~= " ";
			sql ~= constraintSql;
			outputted = true;
		}
		} else static if(is(typeof(member) == Index!Fields, Fields...)) {
			string fields = "";
			static foreach(field; Fields) {
				if(fields.length)
					fields ~= ", ";
				fields ~= __traits(identifier, field);
			}
			addAfterTableSql("CREATE INDEX " ~ tableName ~ "_" ~ memberName ~ " ON " ~ tableName ~ "("~fields~")");
		} else static if(is(typeof(member) == UniqueIndex!Fields, Fields...)) {
			string fields = "";
			static foreach(field; Fields) {
				if(fields.length)
					fields ~= ", ";
				fields ~= __traits(identifier, field);
			}
			addAfterTableSql("CREATE UNIQUE INDEX " ~ tableName ~ "_" ~ memberName ~ " ON " ~ tableName ~ "("~fields~")");
		} else static if(is(typeof(member) T)) {
			if(outputted) {
				sql ~= ",";
			}
			sql ~= "\n";
			sql ~= "\t" ~ memberName;

			static if(is(T == Nullable!P, P)) {
				static if(is(P == int))
					sql ~= " INTEGER NULL";
				else static if(is(P == string))
					sql ~= " TEXT NULL";
				else static if(is(P == double))
					sql ~= " FLOAT NULL";
				else static if(is(P == Timestamp))
					sql ~= " TIMESTAMPTZ NULL";
				else static assert(0, P.stringof);
			} else static if(is(T == int))
				sql ~= " INTEGER NOT NULL";
			else static if(is(T == Serial)) {
				version(dbgenerate_sqlite)
					sql ~= " INTEGER PRIMARY KEY AUTOINCREMENT";
				else
					sql ~= " SERIAL"; // FIXME postgresism
			} else static if(is(T == string))
				sql ~= " TEXT NOT NULL";
			else static if(is(T == double))
				sql ~= " FLOAT NOT NULL";
			else static if(is(T == bool))
				sql ~= " BOOLEAN NOT NULL";
			else static if(is(T == Timestamp)) {
				version(dbgenerate_sqlite)
					sql ~= " TEXT NOT NULL";
				else
					sql ~= " TIMESTAMPTZ NOT NULL"; // FIXME: postgresism
			} else static if(is(T == enum))
				sql ~= " INTEGER NOT NULL"; // potentially crap but meh

			static foreach(attr; __traits(getAttributes, member)) {
				static if(is(typeof(attr) == Default)) {
					// FIXME: postgresism there, try current_timestamp in sqlite
					version(dbgenerate_sqlite) {
						import std.string;
						sql ~= " DEFAULT " ~ std.string.replace(attr.sql, "now()", "current_timestamp");
					} else
						sql ~= " DEFAULT " ~ attr.sql;
				} else static if(is(attr == Unique)) {
					sql ~= " UNIQUE";
				} else static if(is(attr == PrimaryKey)) {
					version(dbgenerate_sqlite) {
						static if(is(T == Serial)) {} // skip, it is done above
						else
						addPostSql("PRIMARY KEY(" ~ memberName ~ ")");
					} else
						addPostSql("PRIMARY KEY(" ~ memberName ~ ")");
				} else static if(is(attr == ForeignKey!(to, sqlPolicy), alias to, string sqlPolicy)) {
					string refTable = toTableName(__traits(parent, to).stringof);
					string refField = to.stringof;
					addPostSql("FOREIGN KEY(" ~ memberName ~ ") REFERENCES "~refTable~"("~refField~(sqlPolicy.length ? ") " : ")") ~ sqlPolicy);
				}
			}

			outputted = true;
		}
	}}

	if(postSql.length && outputted)
		sql ~= ",\n";

	sql ~= postSql;
	sql ~= "\n);\n";
	sql ~= afterTableSql;

	return sql;
}

/// Maps a struct name to its table name: lower-cased, words separated by `_`, pluralized.
/// For example `UserAccount` becomes `user_accounts`.
string toTableName(string t) {
	// any count other than 1 selects the plural form
	return plural(2, beautify(t, '_', true));
}

// copy/pasted from english.d
// Returns `word` unchanged when count is 1 (or word is empty), otherwise `pluralWord`
// if given, otherwise a plural guessed from the last letter.
private string plural(int count, string word, string pluralWord = null) {
	if(count == 1 || word.length == 0)
		return word; // it isn't actually plural

	if(pluralWord !is null)
		return pluralWord;

	switch(word[$ - 1]) {
		case 's':
			return word ~ "es";
		case 'f':
			return word[0 .. $-1] ~ "ves";
		case 'y':
			return word[0 .. $-1] ~ "ies";
		case 'a', 'e', 'i', 'o', 'u':
		default:
			return word ~ "s";
	}
}

// copy/pasted from cgi
// Splits an identifier into words at capital letters and underscores, inserting `space`
// before each new word. The first letter is capitalized unless allLowerCase is set, in
// which case every A-Z is lowered. Returns `name` unchanged if the result would not fit
// in the 160 byte scratch buffer.
private string beautify(string name, char space = ' ', bool allLowerCase = false) {
	if(name == "id")
		return allLowerCase ? name : "ID";

	char[160] buffer;
	int bufferIndex = 0;
	bool shouldCap = true;
	bool shouldSpace;
	bool lastWasCap;
	foreach(idx, char ch; name) {
		if(bufferIndex == buffer.length) return name; // out of space, just give up, not that important

		if((ch >= 'A' && ch <= 'Z') || ch == '_') {
			if(lastWasCap) {
				// two caps in a row, don't change. Prolly acronym.
			} else {
				if(idx)
					shouldSpace = true; // new word, add space
			}

			lastWasCap = true;
		} else {
			lastWasCap = false;
		}

		if(shouldSpace) {
			buffer[bufferIndex++] = space;
			if(bufferIndex == buffer.length) return name; // out of space, just give up, not that important
			shouldSpace = false;
		}
		if(shouldCap) {
			if(ch >= 'a' && ch <= 'z')
				ch -= 32;
			shouldCap = false;
		}
		if(allLowerCase && ch >= 'A' && ch <= 'Z')
			ch += 32;
		buffer[bufferIndex++] = ch;
	}
	return buffer[0 .. bufferIndex].idup;
}

import arsd.database;
/++
	Stores `t` in the database. Currently this always forwards to [insert].
+/
void save(O)(ref O t, Database db) {
	t.insert(db);
}

/++
	Inserts `t` as a new row into the table named by `toTableName(O.stringof)`.

	Null [Nullable] fields are written as SQL `NULL`, a [Serial] of zero and an empty
	[Timestamp] are left out so the database can fill them in, enums are stored as
	their integer value.

	If `O` has an `id` member it is updated with the id of the new row. Under
	`version=dbgenerate_sqlite` the id is read back with `SELECT max(id)`, so `O` must
	have an `id` there.
+/
void insert(O)(ref O t, Database db) {
	auto builder = new InsertBuilder;
	builder.setTable(toTableName(O.stringof));

	static foreach(memberName; __traits(allMembers, O)) {{
		alias member = __traits(getMember, O, memberName);
		static if(is(typeof(member) T)) {

			static if(is(T == Nullable!P, P)) {
				auto v = __traits(getMember, t, memberName);
				if(v.isNull)
					builder.addFieldWithSql(memberName, "NULL");
				else
					builder.addVariable(memberName, v.value);
			} else static if(is(T == int))
				builder.addVariable(memberName, __traits(getMember, t, memberName));
			else static if(is(T == Serial)) {
				auto v = __traits(getMember, t, memberName).value;
				if(v) {
					builder.addVariable(memberName, v);
				} else {
					// skip and let it auto-fill
				}
			} else static if(is(T == string)) {
				builder.addVariable(memberName, __traits(getMember, t, memberName));
			} else static if(is(T == double))
				builder.addVariable(memberName, __traits(getMember, t, memberName));
			else static if(is(T == bool))
				builder.addVariable(memberName, __traits(getMember, t, memberName));
			else static if(is(T == Timestamp)) {
				auto v = __traits(getMember, t, memberName).value;
				if(v.length)
					builder.addVariable(memberName, v);
			} else static if(is(T == enum))
				builder.addVariable(memberName, cast(int) __traits(getMember, t, memberName));
		}
	}}

	import std.conv;
	version(dbgenerate_sqlite) {
		builder.execute(db);
		foreach(row; db.query("SELECT max(id) FROM " ~ toTableName(O.stringof)))
			t.id.value = to!int(row[0]);
	} else {
		static if (__traits(hasMember, O, "id"))
		{
			foreach(row; builder.execute(db, "RETURNING id")) // FIXME: postgres-ism
				t.id.value = to!int(row[0]);
		}
		else
		{
			builder.execute(db);
		}
	}
}

// Check that insert doesn't require an `id`
unittest
{
	static struct NoPK
	{
		int a;
	}

	alias test = insert!NoPK;
}
/// Thrown by [find] when no row has the requested id.
class RecordNotFoundException : Exception {
	this() { super("RecordNotFoundException"); }
}

/++
	Returns a given struct populated from the database. Assumes types known to this module.

		MyItem item = db.find!(MyItem.id)(3);

	If you just give a type, it assumes the relevant index is "id".

	Throws: [RecordNotFoundException] if the query returns no rows.
+/
auto find(alias T)(Database db, int id) {

	// FIXME: if T is an index, search by it.
	// if it is unique, return an individual item.
	// if not, return the array

	foreach(record; db.query("SELECT * FROM " ~ toTableName(T.stringof) ~ " WHERE id = ?", id)) {
		T t;
		populateFromDbRow(t, record);

		return t;
		// if there is ever a second record, that's a wtf, but meh.
	}
	throw new RecordNotFoundException();
}

// Copies every column of `record` whose name matches a field of `t` into that field.
// Columns without a matching field are ignored.
private void populateFromDbRow(T)(ref T t, Row record) {
	foreach(field, value; record) {
		sw: switch(field) {
			static foreach(const idx, alias mem; T.tupleof) {
				case __traits(identifier, mem):
					populateFromDbVal(t.tupleof[idx], value);
				break sw;
			}
			default:
				// intentionally blank
		}
	}
}

// Converts the database's string representation into the field type V.
// Types not listed here are left untouched.
private void populateFromDbVal(V)(ref V val, string value) {
	import std.conv;
	static if(is(V == Constraint!constraintSql, string constraintSql)) {

	} else static if(is(V == Nullable!P, P)) {
		// FIXME
		// an empty string or the text "null" leaves the field in its null state
		if(value.length && value != "null") {
			val.isNull = false;
			val.value = to!P(value);
		}
	} else static if(is(V == bool)) {
		val = value == "t" || value == "1" || value == "true";
	} else static if(is(V == int) || is(V == string) || is(V == double)) {
		val = to!V(value);
	} else static if(is(V == enum)) {
		val = cast(V) to!int(value);
	} else static if(is(V == Timestamp)) {
		val.value = value;
	} else static if(is(V == Serial)) {
		val.value = to!int(value);
	}
}

unittest
{
	static struct SomeStruct
	{
		int a;
		void foo() {}
		int b;
	}

	auto rs = new PredefinedResultSet(
		[ "a", "b" ],
		[ Row([ "1", "2" ]) ]
	);

	SomeStruct s;
	populateFromDbRow(s, rs.front);

	assert(s.a == 1);
	assert(s.b == 2);
}
/++
	Gets all the children of that type. Specifically, it looks in T for a ForeignKey referencing B and queries on that.

	To do a join through a many-to-many relationship, you could get the children of the join table, then get the children of that...
	Or better yet, use real sql. This is more intended to get info where there is one parent row and then many child
	rows, not for a combined thing.

	`T` must have exactly one foreign key whose referenced table is `B`; this is checked at compile time.
+/
QueryBuilderHelper!(T[]) children(T, B)(B base) {
	int countOfAssociations() {
		int count = 0;
		static foreach(memberName; __traits(allMembers, T))
		static foreach(attr; __traits(getAttributes, __traits(getMember, T, memberName))) {{
			static if(is(attr == ForeignKey!(K, policy), alias K, string policy)) {
				static if(is(attr.ReferencedTable == B))
					count++;
			}
		}}
		return count;
	}
	static assert(countOfAssociations() == 1, T.stringof ~ " does not have exactly one foreign key of type " ~ B.stringof);
	string keyName() {
		static foreach(memberName; __traits(allMembers, T))
		static foreach(attr; __traits(getAttributes, __traits(getMember, T, memberName))) {{
			static if(is(attr == ForeignKey!(K, policy), alias K, string policy)) {
				static if(is(attr.ReferencedTable == B))
					return memberName;
			}
		}}
	}

	// return QueryBuilderHelper!(T[])(toTableName(T.stringof)).where!(mixin(keyName ~ " => base.id"));

	// changing mixin cuz of regression in dmd 2.088
	mixin("return QueryBuilderHelper!(T[])(toTableName(T.stringof)).where!("~keyName ~ " => base.id);");
}

/++
	Finds the single row associated with a foreign key in `base`.

	`T` is used to find the key, unless ambiguous, in which case you must pass `key`.

	To do a join through a many-to-many relationship, go to [children] or use real sql.

	Throws: [RecordNotFoundException] (from [find]) if the referenced row does not exist.
+/
T associated(B, T, string key = null)(B base, Database db) {
	// number of foreign keys in B that reference T (restricted to `key` when one is given)
	int countOfAssociations() {
		int count = 0;
		static foreach(memberName; __traits(allMembers, B))
		static foreach(attr; __traits(getAttributes, __traits(getMember, B, memberName))) {{
			static if(is(attr == ForeignKey!(K, policy), alias K, string policy)) {
				static if(is(attr.ReferencedTable == T))
					static if(key is null || key == memberName)
						count++;
			}
		}}
		return count;
	}

	static if(key is null) {
		enum coa = countOfAssociations();
		static assert(coa != 0, B.stringof ~ " has no association of type " ~ T.stringof);
		static assert(coa == 1, B.stringof ~ " has multiple associations of type " ~ T.stringof ~ "; please specify the key you want");
		static foreach(memberName; __traits(allMembers, B))
		static foreach(attr; __traits(getAttributes, __traits(getMember, B, memberName))) {{
			static if(is(attr == ForeignKey!(K, policy), alias K, string policy)) {
				static if(is(attr.ReferencedTable == T))
					return db.find!T(__traits(getMember, base, memberName));
			}
		}}
	} else {
		static assert(countOfAssociations() == 1, B.stringof ~ " does not have a key named " ~ key ~ " of type " ~ T.stringof);
		// only the attributes of the explicitly requested member matter here
		static foreach(attr; __traits(getAttributes, __traits(getMember, B, key))) {{
			static if(is(attr == ForeignKey!(K, policy), alias K, string policy)) {
				static if(is(attr.ReferencedTable == T)) {
					return db.find!T(__traits(getMember, base, key));
				}
			}
		}}
		assert(0);
	}
}


/++
	It will return an aggregate row with a member of type of each table in the join.

	Could do an anonymous object for other things in the sql...

	Not implemented yet: the body is empty.
+/
auto join(TableA, TableB, ThroughTable = void)() {}

/++
	A typed wrapper around a `SelectBuilder` selecting `*` from one table.

	`T` is either a row struct (at most one row is fetched) or an array of row
	structs (all matching rows are fetched). Refine the query with [where] and
	`orderBy`, then run it with `execute`.
+/
struct QueryBuilderHelper(T) {
	static if(is(T == R[], R))
		alias TType = R;
	else
		alias TType = T;

	SelectBuilder selectBuilder;

	this(string tableName) {
		selectBuilder = new SelectBuilder();
		selectBuilder.table = tableName;
		selectBuilder.fields = ["*"];
	}

	/// Runs the query on `db` and returns the populated row(s).
	/// For a non-array `T` the query is limited to one row; with no match `T.init` is returned.
	T execute(Database db) {
		selectBuilder.db = db;
		static if(is(T == R[], R)) {

		} else {
			selectBuilder.limit = 1;
		}

		T ret;
		bool first = true;
		foreach(row; db.query(selectBuilder.toString())) {
			TType t;
			populateFromDbRow(t, row);

			static if(is(T == R[], R))
				ret ~= t;
			else {
				if(first) {
					ret = t;
					first = false;
				} else {
					// limit 1 was requested, a second row should be impossible
					assert(0);
				}
			}
		}
		return ret;
	}

	/// Adds an ORDER BY criterion of the form `"field"`, `"field ASC"` or `"field DESC"`.
	/// The field name and direction are validated at compile time.
	typeof(this) orderBy(string criterion)() {
		string name() {
			int idx = 0;
			while(idx < criterion.length && criterion[idx] != ' ')
				idx++;
			return criterion[0 .. idx];
		}

		string direction() {
			int idx = 0;
			while(idx < criterion.length && criterion[idx] != ' ')
				idx++;
			import std.string;
			return criterion[idx .. $].strip;
		}

		static assert(is(typeof(__traits(getMember, TType, name()))), TType.stringof ~ " has no field " ~ name());
		static assert(direction().length == 0 || direction() == "ASC" || direction() == "DESC", "sort direction must be empty, ASC, or DESC");

		selectBuilder.orderBys ~= criterion;
		return this;
	}
}

/// Starts a query returning all rows of the table for `T`.
QueryBuilderHelper!(T[]) from(T)() {
	return QueryBuilderHelper!(T[])(toTableName(T.stringof));
}

/++
	Adds equality conditions to a [QueryBuilderHelper].

	Each of `conditions` is a lambda of the form `field => value`; the parameter name
	selects the column and is checked against the row struct at compile time. For
	integer-like columns (`int`, `Nullable!int`, [Serial]) an `int[]` value produces
	`field IN (...)`. Extra raw SQL conditions can be passed as `sqlCondition`.
+/
template where(conditions...) {
	Qbh where(Qbh)(Qbh this_, string[] sqlCondition...) {
		assert(this_.selectBuilder !is null);

		// pulls the parameter name out of the stringified lambda type, e.g. "... (int name)"
		static string extractName(string s) {
			if(s.length == 0) assert(0);
			auto i = s.length - 1;
			while(i) {
				if(s[i] == ')') {
					// got to close paren, now backward to non-identifier char to get name
					auto end = i;
					while(i) {
						if(s[i] == ' ')
							return s[i + 1 .. end];
						i--;
					}
					assert(0);
				}
				i--;
			}
			assert(0);
		}

		static foreach(idx, cond; conditions) {{
			// I hate this but __parameters doesn't work here for some reason
			// see my old thread: https://forum.dlang.org/post/awjuoemsnmxbfgzhgkgx@forum.dlang.org
			enum name = extractName(typeof(cond!int).stringof);
			auto value = cond(null);

			// FIXME: convert the value as necessary
			static if(is(typeof(value) == Serial))
				auto dbvalue = value.value;
			else static if(is(typeof(value) == enum))
				auto dbvalue = cast(int) value;
			else
				auto dbvalue = value;

			import std.conv;

			static assert(is(typeof(__traits(getMember, Qbh.TType, name))), Qbh.TType.stringof ~ " has no member " ~ name);
			alias FieldType = typeof(__traits(getMember, Qbh.TType, name));

			// all integer-like key columns are handled the same way
			static if(is(FieldType == int) || is(FieldType == Nullable!int) || is(FieldType == Serial)) {
				static if(is(typeof(value) : const(int)[])) {
					string s;
					foreach(v; value) {
						if(s.length) s ~= ", ";
						s ~= to!string(v);
					}
					this_.selectBuilder.wheres ~= name ~ " IN (" ~ s ~ ")";
				} else {
					static assert(is(typeof(value) : const(int)) || is(typeof(value) == Serial), Qbh.TType.stringof ~ " is a integer key, but you passed an incompatible " ~ typeof(value).stringof);

					auto placeholder = "?_internal" ~ to!string(idx);
					this_.selectBuilder.wheres ~= name ~ " = " ~ placeholder;
					this_.selectBuilder.setVariable(placeholder, dbvalue);
				}
			} else {
				static assert(is(FieldType == typeof(value)), Qbh.TType.stringof ~ "." ~ name ~ " is not of type " ~ typeof(value).stringof);

				auto placeholder = "?_internal" ~ to!string(idx);
				this_.selectBuilder.wheres ~= name ~ " = " ~ placeholder;
				this_.selectBuilder.setVariable(placeholder, dbvalue);
			}
		}}

		this_.selectBuilder.wheres ~= sqlCondition;
		return this_;
	}
}