// NOTE: I haven't even tried to use this for a test yet!
// It's probably godawful, if it works at all.
/++
	ODBC based implementation of the arsd.database interfaces, intended for
	Microsoft SQL Server. Windows only; links against odbc32.
+/
module arsd.mssql;

version(Windows):

pragma(lib, "odbc32");

public import arsd.database;

import std.string;
import std.exception;

import core.sys.windows.sql;
import core.sys.windows.sqlext;

/// True when an ODBC return code reports success, with or without attached diagnostics.
private bool odbcOk(SQLRETURN code) {
	return code == SQL_SUCCESS || code == SQL_SUCCESS_WITH_INFO;
}

/++
	A connection to a database reached through an ODBC driver.

	The environment and connection handles are allocated in the constructor and
	released in the destructor.
+/
class MsSql : Database {
	/++
		Opens the connection.

		Params:
			connectionString = ODBC connection string handed to SQLDriverConnect
				(dbname = name is probably the most common form).

		Throws:
			DatabaseException if the driver refuses the connection;
			Exception (via enforce) if a handle cannot be allocated or the
			connection string is too long for the ODBC length parameter.
	+/
	this(string connectionString) {
		SQLAllocHandle(SQL_HANDLE_ENV, cast(void*)SQL_NULL_HANDLE, &env);
		enforce(env !is null);
		// The destructor may still run after a failed constructor, so every
		// failure path frees the handle AND nulls it to prevent a double free.
		scope(failure) {
			SQLFreeHandle(SQL_HANDLE_ENV, env);
			env = null;
		}

		SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, cast(void *) SQL_OV_ODBC3, 0);

		SQLAllocHandle(SQL_HANDLE_DBC, env, &conn);
		enforce(conn !is null);
		// registered only once conn is known to be valid
		scope(failure) {
			SQLFreeHandle(SQL_HANDLE_DBC, conn);
			conn = null;
		}

		// D strings are not NUL terminated, so the explicit length is passed
		// instead of SQL_NTS.
		enforce(connectionString.length <= SQLSMALLINT.max, "ODBC connection string is too long");

		auto ret = SQLDriverConnect(
			conn, null, cast(ubyte*)connectionString.ptr, cast(SQLSMALLINT) connectionString.length,
			null, 0, null,
			SQL_DRIVER_NOPROMPT );

		if (!odbcOk(ret))
			throw new DatabaseException("Unable to connect to ODBC object: " ~ getSQLError(SQL_HANDLE_DBC, conn));

		//query("SET NAMES 'utf8'"); // D does everything with utf8
	}

	/// Disconnects and releases whichever handles are still held.
	~this() {
		if(conn !is null) {
			SQLDisconnect(conn);
			SQLFreeHandle(SQL_HANDLE_DBC, conn);
			conn = null;
		}
		if(env !is null) {
			SQLFreeHandle(SQL_HANDLE_ENV, env);
			env = null;
		}
	}

	/// Begins a transaction. T-SQL spells this BEGIN TRANSACTION; it has no START TRANSACTION statement.
	override void startTransaction() {
		query("BEGIN TRANSACTION");
	}

	// possible fixme, idk if this is right
	/// Renders a SysTime as a quoted ISO extended string literal.
	override string sysTimeToValue(SysTime s) {
		return "'" ~ escape(s.toISOExtString()) ~ "'";
	}

	/++
		Substitutes args into sql with escapedVariants, executes the statement
		and wraps the statement handle in a MsSqlResult.

		Throws:
			DatabaseException if execution fails; Exception (via enforce) if no
			statement handle could be allocated.
	+/
	ResultSet queryImpl(string sql, Variant[] args...) {
		sql = escapedVariants(this, sql, args);

		SQLHSTMT statement;
		auto returned = SQLAllocHandle(SQL_HANDLE_STMT, conn, &statement);
		enforce(odbcOk(returned));

		returned = SQLExecDirect(statement, cast(ubyte*)sql.ptr, cast(SQLINTEGER) sql.length);

		// SQL_NO_DATA is not a failure: ODBC 3.x drivers return it for a
		// searched UPDATE/DELETE that affected no rows.
		if(!odbcOk(returned) && returned != SQL_NO_DATA) {
			// read the diagnostics before the handle that carries them is freed
			auto message = getSQLError(SQL_HANDLE_STMT, statement);
			SQLFreeHandle(SQL_HANDLE_STMT, statement);
			throw new DatabaseException(message);
		}

		// MsSqlResult owns the handle from here on, including when its
		// constructor throws.
		return new MsSqlResult(statement);
	}

	/// Escapes text for use inside a single-quoted SQL literal by doubling every single quote.
	string escape(string sqlData) {
		import std.array : replace;
		return sqlData.replace("'", "''");
	}


	/// Not implemented; always returns null. FIXME
	string error() {
		return null; // FIXME
	}

	private:
		SQLHENV env;
		SQLHDBC conn;
}

/++
	Forward-only result set over an executed ODBC statement handle.

	Rows are fetched one at a time; every column is read as text.
+/
class MsSqlResult : ResultSet {
	/// Maps a column name to its zero-based index, or -1 if there is no such column.
	int getFieldIndex(string field) {
		if(mapping is null)
			makeFieldMapping();
		if(auto index = field in mapping)
			return *index;
		return -1;
	}


	/// The column names in result order.
	string[] fieldNames() {
		if(mapping is null)
			makeFieldMapping();
		return columnNames;
	}

	// this is a range that can offer other ranges to access it
	/// True once a fetch has failed to produce another row.
	bool empty() {
		return isEmpty;
	}

	/// The most recently fetched row.
	Row front() {
		return row;
	}

	/// Advances to the next row, if any.
	void popFront() {
		if(!isEmpty)
			fetchNext();
	}

	/// Placeholder: always reports 1 regardless of the real row count. FIX ME
	override size_t length()
	{
		return 1; //FIX ME
	}

	/++
		Takes ownership of statement and fetches the first row.

		The handle is released by the destructor, or right away if this
		constructor throws.
	+/
	this(SQLHSTMT statement) {
		this.statement = statement;
		scope(failure)
			releaseStatement();

		SQLSMALLINT info;
		SQLNumResultCols(statement, &info);
		numFields = info;

		fetchNext();
	}

	~this() {
		releaseStatement();
	}

	private:
		// size in bytes of the scratch buffers handed to the driver
		enum int bufferSize = 1024;

		SQLHSTMT statement;
		int[string] mapping;
		string[] columnNames;
		int numFields;

		bool isEmpty;

		Row row;

		// Frees the statement handle at most once.
		void releaseStatement() {
			if(statement !is null) {
				SQLFreeHandle(SQL_HANDLE_STMT, statement);
				statement = null;
			}
		}

		// Fetches the next row into `row`, or marks the set as empty.
		void fetchNext() {
			if(isEmpty)
				return;

			// Anything other than success ends iteration. Failures are
			// deliberately not thrown: statements that produce no result set
			// (INSERT etc.) rely on simply ending up empty here.
			if(!odbcOk(SQLFetch(statement))) {
				isEmpty = true;
				return;
			}

			string[] values;
			foreach(i; 0 .. numFields)
				values ~= readColumn(cast(ushort)(i + 1));

			Row r;
			r.resultSet = this;
			r.row = values;
			this.row = r;
		}

		// Reads one column (1-based) of the current row as text; null for SQL NULL.
		string readColumn(ushort column) {
			// a character chunk holds at most bufferSize - 1 bytes; the last
			// byte of the buffer is the driver's NUL terminator
			enum SQLLEN capacity = bufferSize - 1;

			string value;
			SQLCHAR[bufferSize] buf;

			for(;;) {
				SQLLEN indicator;
				auto ret = SQLGetData(statement, column, SQL_CHAR, buf.ptr, bufferSize, &indicator);

				// every part of the value has already been returned
				if(ret == SQL_NO_DATA)
					break;
				if(!odbcOk(ret))
					throw new DatabaseException("get data: " ~ getSQLError(SQL_HANDLE_STMT, statement));
				if(indicator == SQL_NULL_DATA)
					return null;

				// The indicator is the number of bytes that were still
				// available before this call (or SQL_NO_TOTAL when the driver
				// cannot tell), so a full buffer is anything above capacity.
				const chunk = (indicator < 0 || indicator > capacity) ? capacity : indicator;
				value ~= cast(string) buf[0 .. cast(size_t) chunk].idup;

				// SQL_SUCCESS_WITH_INFO signals truncation: more data follows
				if(ret == SQL_SUCCESS)
					break;
			}

			return value;
		}

		// Builds columnNames and the name -> index mapping from the statement metadata.
		void makeFieldMapping() {
			// built in locals so a failure half way leaves no partial state behind
			string[] names;
			int[string] indexByName;

			foreach(i; 0 .. numFields) {
				SQLSMALLINT len;
				SQLCHAR[bufferSize] buf;
				auto ret = SQLDescribeCol(statement,
					cast(ushort)(i+1),
					cast(ubyte*)buf.ptr,
					bufferSize,
					&len,
					null, null, null, null);
				if (!odbcOk(ret))
					throw new DatabaseException("Field mapping error: " ~ getSQLError(SQL_HANDLE_STMT, statement));

				// len is the full name length, which can exceed what fit in the buffer
				size_t nameLength = 0;
				if(len > 0)
					nameLength = len > bufferSize - 1 ? bufferSize - 1 : len;

				string name = cast(string) buf[0 .. nameLength].idup;

				names ~= name;
				indexByName[name] = i;
			}

			columnNames = names;
			mapping = indexByName;
		}
}

/++
	Returns the message text of the first diagnostic record attached to handle,
	or a fixed fallback message when no record can be read.
+/
private string getSQLError(short handletype, SQLHANDLE handle)
{
	char[32] sqlstate;
	char[256] message;
	SQLINTEGER nativeerror=0;
	SQLSMALLINT textlen=0;
	auto ret = SQLGetDiagRec(handletype, handle, 1,
			cast(ubyte*)sqlstate.ptr,
			cast(int*)&nativeerror,
			cast(ubyte*)message.ptr,
			256,
			&textlen);

	if(!odbcOk(ret))
		return "unknown ODBC error (no diagnostic record available)";

	// textlen is the full message length; only message.length - 1 bytes plus
	// a NUL terminator can actually have been written to the buffer
	size_t used = 0;
	if(textlen > 0)
		used = textlen > message.length - 1 ? message.length - 1 : textlen;

	return message[0 .. used].idup;
}

/*
import std.stdio;
void main() {
	//auto db = new MsSql("Driver={SQL Server};Server=<host>[\\<optional-instance-name>]>;Database=dbtest;Trusted_Connection=Yes");
	auto db = new MsSql("Driver={SQL Server Native Client 10.0};Server=<host>[\\<optional-instance-name>];Database=dbtest;Trusted_Connection=Yes");

	db.query("INSERT INTO users (id, name) values (30, 'hello mang')");

	foreach(line; db.query("SELECT * FROM users")) {
		writeln(line[0], line["name"]);
	}
}
*/