1 // NOTE: I haven't even tried to use this for a test yet!
2 // It's probably godawful, if it works at all.
3 ///
4 module arsd.mssql;
5 
6 version(Windows):
7 
8 pragma(lib, "odbc32");
9 
10 public import arsd.database;
11 
12 import std.string;
13 import std.exception;
14 
15 import core.sys.windows.sql;
16 import core.sys.windows.sqlext;
17 
/// Database implementation that talks to Microsoft SQL Server through ODBC.
class MsSql : Database {
	/**
	 * Opens an ODBC connection.
	 *
	 * "dbname = name" is probably the most common connection string.
	 *
	 * Params:
	 *   connectionString = ODBC connection string passed to SQLDriverConnect
	 *
	 * Throws:
	 *   DatabaseException if the driver refuses the connection;
	 *   an enforce failure if a handle cannot be allocated or the
	 *   connection string is too long for the ODBC length parameter.
	 */
	this(string connectionString) {
		// The length is passed explicitly below, so it has to fit SQLSMALLINT.
		enforce(connectionString.length <= SQLSMALLINT.max, "ODBC connection string is too long");

		SQLAllocHandle(SQL_HANDLE_ENV, cast(void*)SQL_NULL_HANDLE, &env);
		enforce(env !is null);
		// Null the field after freeing so the destructor cannot free the
		// handle a second time if this half-built object is finalized later.
		scope(failure) {
			SQLFreeHandle(SQL_HANDLE_ENV, env);
			env = null;
		}
		SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, cast(void *) SQL_OV_ODBC3, 0);
		SQLAllocHandle(SQL_HANDLE_DBC, env, &conn);
		enforce(conn !is null);
		scope(failure) {
			SQLFreeHandle(SQL_HANDLE_DBC, conn);
			conn = null;
		}

		// D strings are not guaranteed to be NUL-terminated, so pass the
		// real length instead of SQL_NTS.
		auto ret = SQLDriverConnect(
			conn, null, cast(ubyte*)connectionString.ptr, cast(SQLSMALLINT) connectionString.length,
			null, 0, null,
			SQL_DRIVER_NOPROMPT );

		if ((ret != SQL_SUCCESS_WITH_INFO) && (ret != SQL_SUCCESS))
			throw new DatabaseException("Unable to connect to ODBC object: " ~ getSQLError(SQL_HANDLE_DBC, conn));

		//query("SET NAMES 'utf8'"); // D does everything with utf8
	}

	/// Disconnects and releases whichever ODBC handles are still held.
	~this() {
		if(conn !is null) {
			SQLDisconnect(conn);
			SQLFreeHandle(SQL_HANDLE_DBC, conn);
			conn = null;
		}
		if(env !is null) {
			SQLFreeHandle(SQL_HANDLE_ENV, env);
			env = null;
		}
	}

	/// Begins an explicit transaction.
	override void startTransaction() {
		// SQL Server's dialect spells this BEGIN TRANSACTION; it does not
		// accept START TRANSACTION.
		query("BEGIN TRANSACTION");
	}

	/// Formats a SysTime as a quoted ISO extended string literal.
	// possible fixme, idk if this is right
	override string sysTimeToValue(SysTime s) {
		return "'" ~ escape(s.toISOExtString()) ~ "'";
	}

	/**
	 * Substitutes `args` into `sql` via escapedVariants and executes the
	 * resulting statement directly.
	 *
	 * Returns: a MsSqlResult that owns the statement handle.
	 * Throws: DatabaseException if execution fails.
	 */
	ResultSet queryImpl(string sql, Variant[] args...) {
		sql = escapedVariants(this, sql, args);

		// this is passed to MsSqlResult to control
		SQLHSTMT statement;
		auto returned = SQLAllocHandle(SQL_HANDLE_STMT, conn, &statement);

		enforce(returned == SQL_SUCCESS);

		returned = SQLExecDirect(statement, cast(ubyte*)sql.ptr, cast(SQLINTEGER) sql.length);

		// SQL_SUCCESS_WITH_INFO only carries warnings, and SQL_NO_DATA is
		// what an update/delete that matched no rows reports; neither is
		// a failure.
		immutable executed = returned == SQL_SUCCESS
			|| returned == SQL_SUCCESS_WITH_INFO
			|| returned == SQL_NO_DATA;
		if(!executed) {
			// Read the diagnostic before releasing the handle it belongs to,
			// and release the handle so it does not leak.
			auto message = getSQLError(SQL_HANDLE_STMT, statement);
			SQLFreeHandle(SQL_HANDLE_STMT, statement);
			throw new DatabaseException(message);
		}

		return new MsSqlResult(statement);
	}

	/**
	 * Escapes a string for embedding inside a single-quoted SQL literal by
	 * doubling every single quote.
	 */
	string escape(string sqlData) {
		import std.array : replace;
		return sqlData.replace("'", "''");
	}


	/// Not implemented yet: always returns null.
	string error() {
		return null; // FIXME
	}

	private:
		SQLHENV env;
		SQLHDBC conn;
}
87 
/// Forward-only result set that reads its rows from an ODBC statement handle.
class MsSqlResult : ResultSet {
	/// Returns the zero-based index of the column named `field`, or -1 if
	/// there is no such column.
	int getFieldIndex(string field) {
		if(mapping is null)
			makeFieldMapping();
		if(auto index = field in mapping)
			return *index;
		return -1;
	}


	/// Returns the column names in result order.
	string[] fieldNames() {
		if(mapping is null)
			makeFieldMapping();
		return columnNames;
	}

	// this is a range that can offer other ranges to access it
	/// True once no further row could be fetched.
	bool empty() {
		return isEmpty;
	}

	/// The most recently fetched row.
	Row front() {
		return row;
	}

	/// Advances to the next row, if any.
	void popFront() {
		if(!isEmpty)
			fetchNext();
	}

	/// Placeholder: the real row count is not tracked, so this always
	/// returns 1.
	override size_t length()
	{
		return 1; //FIX ME
	}

	/// Takes ownership of `statement`, records its column count and fetches
	/// the first row.
	this(SQLHSTMT statement) {
		this.statement = statement;

		SQLSMALLINT info;
		SQLNumResultCols(statement, &info);
		numFields = info;

		fetchNext();
	}

	/// Releases the statement handle.
	~this() {
		if(statement !is null) {
			SQLFreeHandle(SQL_HANDLE_STMT, statement);
			statement = null;
		}
	}

	private:
		// Size of the scratch buffers handed to the driver. Kept as a
		// signed int so comparisons with (signed) ODBC length indicators
		// are never promoted to unsigned.
		enum int bufferSize = 1024;

		SQLHSTMT statement;
		int[string] mapping;
		string[] columnNames;
		int numFields;

		bool isEmpty;

		Row row;

		// Fetches the next row into `row`, or marks the set as empty when
		// the driver reports anything other than a successful fetch.
		void fetchNext() {
			if(isEmpty)
				return;

			immutable fetched = SQLFetch(statement);
			// SQL_SUCCESS_WITH_INFO still delivers a row; it just has
			// warnings attached.
			if(fetched != SQL_SUCCESS && fetched != SQL_SUCCESS_WITH_INFO) {
				isEmpty = true;
				return;
			}

			string[] values;
			foreach(i; 0 .. numFields)
				values ~= readColumn(cast(ushort)(i + 1));

			Row r;
			r.resultSet = this;
			r.row = values;
			this.row = r;
		}

		// Reads one column (1-based) of the current row as character data,
		// pulling it in chunks when it does not fit the buffer. Returns
		// null for SQL NULL.
		string readColumn(ushort column) {
			string value;
			SQLCHAR[bufferSize] buf;

			for(;;) {
				SQLLEN indicator;
				immutable rc = SQLGetData(statement, column, SQL_CHAR, buf.ptr, bufferSize, &indicator);

				// Everything was already consumed by earlier chunks.
				if(rc == SQL_NO_DATA)
					break;
				// Truncation is reported as SQL_SUCCESS_WITH_INFO, so it
				// must not be treated as an error.
				if(rc != SQL_SUCCESS && rc != SQL_SUCCESS_WITH_INFO)
					throw new DatabaseException("get data: " ~ getSQLError(SQL_HANDLE_STMT, statement));

				if(indicator == SQL_NULL_DATA)
					return null;

				// The indicator holds the bytes that were still pending
				// before this call (or SQL_NO_TOTAL when unknown). The
				// driver NUL-terminates character data, so a full buffer
				// carries bufferSize - 1 bytes of payload.
				immutable truncated = indicator == SQL_NO_TOTAL || indicator >= bufferSize;
				immutable size_t got = truncated ? bufferSize - 1 : cast(size_t) indicator;

				value ~= cast(string) buf[0 .. got].idup;

				if(!truncated)
					break;
			}

			return value;
		}

		// Builds `columnNames` and the name -> index `mapping` from the
		// statement's column descriptions.
		void makeFieldMapping() {
			// Collect into locals first so a failure half-way through does
			// not leave partially filled members behind.
			int[string] indexes;
			string[] names;

			foreach(i; 0 .. numFields) {
				SQLSMALLINT len;
				SQLCHAR[bufferSize] buf;
				auto ret = SQLDescribeCol(statement,
					cast(ushort)(i+1),
					cast(ubyte*)buf.ptr,
					bufferSize,
					&len,
					null, null, null, null);
				// SQL_SUCCESS_WITH_INFO means the name was truncated to
				// fit the buffer, which is still usable.
				if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO)
					throw new DatabaseException("Field mapping error: " ~ getSQLError(SQL_HANDLE_STMT, statement));

				// `len` is the full name length and may exceed what was
				// actually written; clamp before slicing.
				size_t used = len < 0 ? 0 : len;
				if(used > bufferSize - 1)
					used = bufferSize - 1;

				string a = cast(string) buf[0 .. used].idup;

				names ~= a;
				indexes[a] = i;
			}

			columnNames = names;
			mapping = indexes;
		}
}
207 
/**
 * Fetches the text of the first ODBC diagnostic record attached to `handle`.
 *
 * Params:
 *   handletype = SQL_HANDLE_* constant describing `handle`
 *   handle     = the handle whose diagnostics are read
 *
 * Returns: the diagnostic message, or a fixed placeholder when the driver
 *   has no record to report.
 */
private string getSQLError(short handletype, SQLHANDLE handle)
{
	char[32] sqlstate;
	char[256] message;
	SQLINTEGER nativeerror=0;
	SQLSMALLINT textlen=0;
	auto ret = SQLGetDiagRec(handletype, handle, 1,
			cast(ubyte*)sqlstate.ptr,
			cast(int*)&nativeerror,
			cast(ubyte*)message.ptr,
			cast(SQLSMALLINT) message.length,
			&textlen);

	// Without a record the buffer was never written; returning it would
	// hand back 256 bytes of char.init, which is not valid UTF-8.
	if(ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO)
		return "no ODBC diagnostic record available";

	// textlen is the full message length and can exceed the buffer when
	// the text was truncated; one byte is reserved for the terminator.
	size_t used = textlen < 0 ? 0 : textlen;
	if(used > message.length - 1)
		used = message.length - 1;

	return message[0 .. used].idup;
}
223 
224 /*
225 import std.stdio;
226 void main() {
	//auto db = new MsSql("Driver={SQL Server};Server=<host>[\\<optional-instance-name>];Database=dbtest;Trusted_Connection=Yes");
	auto db = new MsSql("Driver={SQL Server Native Client 10.0};Server=<host>[\\<optional-instance-name>];Database=dbtest;Trusted_Connection=Yes");
229 
230 	db.query("INSERT INTO users (id, name) values (30, 'hello mang')");
231 
232 	foreach(line; db.query("SELECT * FROM users")) {
233 		writeln(line[0], line["name"]);
234 	}
235 }
236 */