1 /**
2   * CruddyORM is a simple object relational mapping library for Postgresql and vibe.d.
3   *
4   * Cruddiness:
5   *  - It only supports one database.
6   *  - It assumes that you only want to persist fields, not properties.
7   *  - It assumes that the id field is a UUID named `id`.
8   *  - It uses std.experimental.logger instead of vibe's logging. Vibe doesn't offer a topic-based
9   *    logging configuration, so that wouldn't let the ORM be stricter about its logging than
10   *    application code.
11   *  - It doesn't do connection pooling.
12   */
13 module cruddyorm;
14 
15 import core.time;
16 import datefmt;
17 import dpq2.conv.time;
18 import dpq2;
19 import std.conv;
20 import std.datetime;
21 import std.experimental.logger;
22 import std..string;
23 import std.traits;
24 import std.typecons;
25 import std.uuid;
26 import vibe.db.postgresql;
27 
/** Shorthand for the locked connection handle type that the `*Conn` functions accept. */
alias Connection = LockedConnection!(__Conn);
29 
/**
  * User-defined attribute marking a field as transient.
  * Fields carrying it are left out of the generated INSERT and UPDATE statements.
  */
struct Transient
{
}
32 
/** What an INSERT should do when the new row conflicts with an existing one. */
enum Conflict
{
    /// Keep the existing row (`ON CONFLICT DO NOTHING`).
    noop,
    /// Ask the database to update the existing row (`ON CONFLICT DO UPDATE`).
    update,
    /// Add no conflict clause; the database's default behavior applies.
    error
}
39 
__gshared
{
    /**
      * Logger used by the ORM for its own messages.
      * It is public so that its level can be adjusted independently of application logging.
      * The module constructor creates it at `LogLevel.warning`.
      */
    MultiLogger dblog;

    /**
      * Connection string used to connect to postgresql.
      * `inConnection` refuses to run while this is unset.
      */
    string connectionString;
}

// Created lazily by `inConnection`. Not `__gshared`, so this is a per-thread variable.
private PostgresClient pg;
53 
shared static this()
{
    // A dedicated MultiLogger lets the ORM's log level be tuned separately,
    // while still forwarding accepted messages to sharedLog.
    auto ormLogger = new MultiLogger(LogLevel.warning);
    ormLogger.insertLogger("parent", sharedLog);
    dblog = ormLogger;
}
60 
61 
/**
  * Write the fields of `val` back to its existing row.
  *
  * The statement text is produced by `updateText!T` and stored in a `static immutable`;
  * the arguments come from `toParams` with the trailing id requested. The statement is
  * executed through `inConnection`.
  *
  * Params:
  *   val = the object to persist.
  */
void update(T)(T val)
{
    static immutable string sql = updateText!T();
    dblog.trace(sql);
    dblog.tracef("param: %s", val);

    auto boundArgs = val.toParams(true);
    boundArgs.sqlCommand = sql;
    inConnection!(c => c.execParams(boundArgs));
}
74 
/**
  * Write the fields of `val` back to its existing row, using a connection the caller
  * already holds instead of obtaining one through `inConnection`.
  *
  * Params:
  *   conn = the connection to execute the UPDATE on.
  *   val  = the object to persist.
  */
void updateConn(T)(Connection conn, T val)
{
    static immutable string sql = updateText!T();
    dblog.trace(sql);
    dblog.tracef("param: %s", val);

    auto boundArgs = val.toParams(true);
    boundArgs.sqlCommand = sql;
    conn.execParams(boundArgs);
}
87 
/**
  * Insert `val` as a new row.
  *
  * When `val.id` is a `UUID`, a fresh random id is assigned before inserting (this
  * overwrites whatever id the caller had set). The statement text is produced by
  * `insertText!(T, onConflict)` and executed through `inConnection`.
  *
  * Params:
  *   onConflict = conflict handling passed on to `insertText` (template parameter).
  *   val        = the object to insert; taken by reference so the new id is visible
  *                to the caller.
  */
void insert(T, Conflict onConflict = Conflict.error)(ref T val)
{
    static if (is(typeof(val.id) == UUID))
    {
        val.id = randomUUID();
    }

    static immutable string sql = insertText!(T, onConflict)();
    dblog.trace(sql);
    dblog.tracef("param: %s", val);

    auto boundArgs = val.toParams(false);
    boundArgs.sqlCommand = sql;
    inConnection!(c => c.execParams(boundArgs));
}
104 
/**
  * Insert `val` as a new row using a connection the caller already holds.
  *
  * If `conn` is null the problem is logged and the insert is performed through `insert`
  * (which obtains its own connection) instead. An unexpected connection status is logged
  * but the insert is still attempted on `conn`.
  *
  * When `val.id` is a `UUID`, a fresh random id is assigned before inserting.
  *
  * Params:
  *   conn = the connection to execute the INSERT on.
  *   val  = the object to insert; taken by reference so the new id is visible to the caller.
  */
void insertConn(T)(ref scope Connection conn, ref T val)
{
    if (conn is null)
    {
        dblog.errorf("querying with existing connection: connection is null");
        // Fall back to a fresh connection and stop here: continuing would insert the
        // value a second time and dereference the null connection.
        insert(val);
        return;
    }
    if (conn.status != 0)
    {
        dblog.errorf("unexpected status from connection: %s", conn.status);
    }

    static if (is (typeof(val.id) == UUID))
    {
        val.id = randomUUID();
    }
    static immutable string cmd = insertText!(T, Conflict.error);
    auto params = val.toParams(false);
    params.sqlCommand = cmd;
    conn.execParams(params);
}
128 
/**
  * Execute a query on a connection the caller already holds, parsing the results
  * automatically.
  *
  * If `conn` is null the problem is logged and the query is run through `query` (which
  * obtains its own connection) instead, mirroring the fallback in `insertConn`. An
  * unexpected connection status is logged but the query is still attempted on `conn`.
  *
  * Params:
  *   T    = element type to parse each row into; `void` (the default) discards the result.
  *   conn = the connection to execute on.
  *   cmd  = SQL text with positional placeholders.
  *   args = values bound to the placeholders, in order.
  *
  * Returns: a `T[]` with one parsed element per result row, or nothing when `T` is `void`.
  */
auto queryConn(T = void, Args...)(ref scope Connection conn, string cmd, Args args...)
{
    if (conn is null)
    {
        dblog.errorf("querying with existing connection: connection is null");
        // Executing on a null connection cannot succeed; use a fresh one instead.
        return query!T(cmd, args);
    }
    if (conn.status != 0)
    {
        dblog.errorf("unexpected status from connection: %s", conn.status);
    }
    dblog.tracef("query: [%s] args: %s", cmd, args);
    QueryParams params;
    params.argsVariadic(args);
    params.sqlCommand = cmd;
    auto result = conn.execParams(params);
    static if (!is(T == void))
    {
        auto vals = new T[result.length];
        foreach (i, ref v; vals)
        {
            v = parse!T(result[i]);
        }
        return vals;
    }
}
154 
/**
  * Insert the value if it has no id yet, otherwise update its existing row.
  *
  * "No id" means `val.id == UUID.init`, so this relies on newly created items not having
  * IDs. Objects with more complex creation steps may need an explicit call to `insert`.
  *
  * Params:
  *   val = the object to persist; taken by reference because a new id may be assigned.
  */
void saveOrUpdate(T)(ref T val)
{
    const bool alreadyHasId = val.id != UUID.init;
    if (alreadyHasId)
    {
        update(val);
        return;
    }

    val.id = randomUUID;
    insert(val);
}
173 
/**
  * Delete the row belonging to `val` from the database.
  * Forwards `val.id` to the id-based overload of `dbdelete`.
  */
void dbdelete(T)(T val)
{
    auto rowId = val.id;
    dbdelete!T(rowId);
}
181 
/**
  * Delete the row with the given id from the table that stores `T`.
  *
  * The table name is the lowercased type name plus `s`, the same convention used by
  * `fetch`, `insertText` and `updateText`.
  *
  * Params:
  *   id = id of the row to delete.
  */
void dbdelete(T)(UUID id)
{
    // PostgreSQL binds parameters positionally as $1, $2, ...; it has no `?` placeholder.
    enum cmd = `DELETE FROM ` ~ T.stringof.toLower ~ `s WHERE id = $1`;
    query!void(cmd, id);
}
189 
/**
  * Load a single `T` by id.
  *
  * Runs `SELECT * FROM <lowercased type name>s WHERE id = $1` through `inConnection`.
  *
  * Params:
  *   id = id of the row to load.
  *
  * Returns: the parsed first row, or an empty `Nullable` when no row matched.
  */
Nullable!T fetch(T)(UUID id)
{
    enum sql = `SELECT * FROM ` ~ T.stringof.toLower ~ `s WHERE id = $1`;

    QueryParams lookup;
    lookup.argsVariadic(id);
    lookup.sqlCommand = sql;

    auto rows = inConnection!(c => c.execParams(lookup));
    if (rows.length == 0)
    {
        return Nullable!T.init;
    }
    return Nullable!T(parse!T(rows[0]));
}
203 
/**
  * Run an arbitrary query through `inConnection`, parsing the results automatically.
  *
  * Params:
  *   T    = element type to parse each row into; `void` (the default) discards the result.
  *   cmd  = SQL text with positional placeholders.
  *   args = values bound to the placeholders, in order.
  *
  * Returns: a `T[]` with one parsed element per result row, or nothing when `T` is `void`.
  */
auto query(T = void, Args...)(string cmd, Args args)
{
    dblog.tracef("query: [%s] args: %s", cmd, args);

    QueryParams request;
    request.argsVariadic(args);
    request.sqlCommand = cmd;
    auto answer = inConnection!(c => c.execParams(request));
    dblog.tracef("finished query");

    static if (!is(T == void))
    {
        T[] parsed = new T[answer.length];
        foreach (idx; 0 .. parsed.length)
        {
            dblog.tracef("parsed result %s", idx);
            parsed[idx] = parse!T(answer[idx]);
        }
        return parsed;
    }
}
226 
/**
  * Parse a DB row into a class or struct instance.
  *
  * Each field of `T` is matched case-insensitively against the row's column names.
  * Fields with no matching column, and columns holding NULL, leave the field at its
  * default value. Supported field types are `UUID`, `SysTime` (read as a timestamp
  * without time zone and interpreted as UTC), `string`, `Duration` (stored as whole
  * seconds), `int`, `double` and `bool`.
  *
  * A field of any other type is a compile-time error unless it is marked `@Transient`,
  * in which case it is simply never populated.
  *
  * Params:
  *   row = the result row to read from.
  *
  * Returns: a new `T` (classes are default-constructed) with the matched fields filled in.
  */
T parse(T)(immutable Row row) if (is(T == class) || is(T == struct))
{
    import std.uni : sicmp;

    T val;
    static if (is(T == class))
    {
        val = new T();
    }

    foreach (m; FieldNameTuple!T)
    {
        alias FT = typeof(__traits(getMember, T, m));
        enum bool readable = is(FT == UUID) || is(FT == SysTime) || is(FT == string)
            || is(FT == Duration) || is(FT == int) || is(FT == double) || is(FT == bool);

        static if (isFunction!FT)
        {
            // Not a data field; nothing to read.
        }
        else static if (!readable)
        {
            // Transient fields are not stored, so an unsupported type is acceptable
            // there; for everything else keep failing at compile time.
            static assert(hasUDA!(__traits(getMember, T, m), Transient),
                    "can't deserialize " ~ FT.stringof ~ " from DB");
        }
        else
        {
            // Find the first column whose name matches the field, ignoring case.
            ptrdiff_t col = -1;
            foreach (i; 0 .. row.length)
            {
                if (sicmp(row.columnName(i), m) == 0)
                {
                    col = i;
                    break;
                }
            }

            if (col >= 0)
            {
                auto cell = row[cast(size_t) col];
                // NULL columns keep the field's default value.
                if (!cell.isNull)
                {
                    static if (is(FT == UUID))
                    {
                        __traits(getMember, val, m) = cell.as!UUID;
                    }
                    else static if (is(FT == SysTime))
                    {
                        auto sansTZ = cell.as!TimeStampWithoutTZ;
                        __traits(getMember, val, m) =
                            SysTime(sansTZ.dateTime, sansTZ.fracSec.hnsecs.hnsecs, UTC());
                    }
                    else static if (is(FT == string))
                    {
                        __traits(getMember, val, m) = cell.as!string;
                    }
                    else static if (is(FT == Duration))
                    {
                        __traits(getMember, val, m) = dur!"seconds"(cell.as!int);
                    }
                    else static if (is(FT == int))
                    {
                        __traits(getMember, val, m) = cell.as!int;
                    }
                    else static if (is(FT == double))
                    {
                        __traits(getMember, val, m) = cell.as!double;
                    }
                    else static if (is(FT == bool))
                    {
                        __traits(getMember, val, m) = cell.as!bool;
                    }
                }
            }
        }
    }
    return val;
}
313 
// Parse a DB row into a single non-aggregate value: the row's first column, converted to T.
T parse(T)(immutable Row row) if (!is(T == class) && !is(T == struct))
{
    auto firstCell = row[0];
    return firstCell.as!T;
}
319 
/**
  * Convert an object into the argument list for a generated INSERT or UPDATE.
  *
  * Members are visited in `__traits(derivedMembers, T)` order. Functions and fields
  * marked `@Transient` are skipped, which is the same selection `insertText` and
  * `updateText` use, so argument N lines up with placeholder `$N`.
  *
  * Conversions: `SysTime` is sent as an ISO 8601 string, `Duration` as whole seconds
  * (`int`), strings and numeric types as themselves, everything else via `to!string`.
  *
  * Params:
  *   val        = the object to read field values from.
  *   trailingId = when true, append `val.id` (as text) after the fields; `updateText`
  *                uses it for its WHERE clause.
  *
  * Returns: a `QueryParams` with `args` set; the caller still has to set `sqlCommand`.
  *
  * Throws: `Exception` if `trailingId` is requested for a type without an `id`.
  */
QueryParams toParams(T)(T val, bool trailingId)
{
    // One slot per member is an upper bound (functions and transient fields are
    // skipped), plus one for the optional trailing id.
    Value[__traits(derivedMembers, T).length + 1] v;
    int i = 0;
    foreach (m; __traits(derivedMembers, T))
    {
        alias FT = typeof(__traits(getMember, T, m));
        static if (!isFunction!FT && !hasUDA!(__traits(getMember, T, m), Transient))
        {
            auto fieldVal = __traits(getMember, val, m);
            static if (is(FT == SysTime))
            {
                if (fieldVal == SysTime.init)
                {
                    // TODO send default / null value?
                    // For now the slot keeps its default-initialized Value.
                }
                else
                {
                    auto fs = fieldVal.format(ISO8601FORMAT);
                    v[i] = toValue(fs);
                    dblog.infof("field %s value %s", i, fs);
                }
            }
            else static if (is(FT == Duration))
            {
                auto secs =  fieldVal.total!("seconds");
                v[i] = toValue(cast(int)secs);
            }
            else static if (is(FT == string))
            {
                v[i] = toValue(fieldVal);
            }
            else static if (isNumeric!FT)
            {
                v[i] = toValue(fieldVal);
            }
            else
            {
                v[i] = toValue(std.conv.to!string(fieldVal));
            }
            // NOTE(review): result is discarded; kept from the original in case the
            // accessor performs validation — confirm whether it can be dropped.
            v[i].data();
            i++;
        }
    }
    if (trailingId)
    {
        static if (is (typeof(val.id)))
        {
            v[i] = toValue(val.id.to!string, ValueFormat.TEXT);
            i++;
        }
        else
        {
            throw new Exception(
                    "asked for trailing id for type " ~
                    T.stringof ~ " with no trailing id");
        }
    }
    QueryParams p;
    p.args = v[0..i].dup;
    return p;
}
385 
/**
  * Build the UPDATE statement for `T`.
  *
  * The table is the lowercased type name plus `s`. Every derived member that is neither
  * a function nor marked `@Transient` becomes a `name = $N` assignment, numbered in
  * member order; `UUID` fields are wrapped in `uuid(...)` and `SysTime` fields are cast
  * to `timestamp without time zone`. The WHERE clause matches `id` against the
  * placeholder that follows the last assignment.
  *
  * Returns: the SQL text.
  */
string updateText(T)()
{
    string setClause;
    int bound = 0;
    foreach (m; __traits(derivedMembers, T))
    {
        alias FT = typeof(__traits(getMember, T, m));
        static if (!isFunction!FT && !hasUDA!(__traits(getMember, T, m), Transient))
        {
            ++bound;
            immutable string slot = `$` ~ bound.to!string;
            if (bound != 1)
            {
                setClause ~= `, `;
            }

            static if (is(FT == UUID))
            {
                setClause ~= m ~ ` = uuid(` ~ slot ~ `)`;
            }
            else static if (is(FT == SysTime))
            {
                setClause ~= m ~ ` = ` ~ slot ~ `::timestamp without time zone`;
            }
            else
            {
                setClause ~= m ~ ` = ` ~ slot;
            }
        }
    }

    // The id used for row selection is bound after all the assignments.
    return `UPDATE ` ~ T.stringof.toLower ~ `s SET ` ~ setClause
        ~ ` WHERE id = uuid($` ~ (bound + 1).to!string ~ `)`;
}
425 
/**
  * Build the INSERT statement for `T`.
  *
  * The table is the lowercased type name plus `s`. Every derived member that is neither
  * a function nor marked `@Transient` becomes a column with placeholder `$N`, numbered
  * in member order; `UUID` values are wrapped in `uuid(...)` and `SysTime` values are
  * cast to `timestamp without time zone`.
  *
  * Conflict handling:
  *  - `Conflict.error`: no clause is added (database default).
  *  - `Conflict.noop`: `ON CONFLICT DO NOTHING`.
  *  - `Conflict.update`: `ON CONFLICT (id) DO UPDATE SET col = EXCLUDED.col, ...` for
  *    every inserted column other than `id`. PostgreSQL requires both a conflict target
  *    and a SET list for DO UPDATE, so `T` must have an `id`. If `id` is the only
  *    column there is nothing to update and `ON CONFLICT (id) DO NOTHING` is emitted.
  *
  * Returns: the SQL text.
  */
string insertText(T, Conflict onConflict = Conflict.error)()
{
    string cmd = `INSERT INTO ` ~ T.stringof.toLower ~ `s (`;
    int i = 0;
    string values = ``;
    // `col = EXCLUDED.col` pairs; only filled in for Conflict.update.
    string upsertAssignments = ``;
    foreach (m; __traits(derivedMembers, T))
    {
        alias FT = typeof(__traits(getMember, T, m));
        static if (!isFunction!FT && !hasUDA!(__traits(getMember, T, m), Transient))
        {
            i++;
            if (i > 1)
            {
                cmd ~= `, `;
                values ~= `, `;
            }
            cmd ~= m;
            static if (is(FT == UUID))
            {
                values ~= `uuid($` ~ i.to!string ~ `)`;
            }
            else static if (is(FT == SysTime))
            {
                values ~= `$`;
                values ~= i.to!string;
                values ~= `::timestamp without time zone`;
            }
            else
            {
                values ~= `$`;
                values ~= i.to!string;
            }

            static if (onConflict == Conflict.update && m != "id")
            {
                if (upsertAssignments.length > 0)
                {
                    upsertAssignments ~= `, `;
                }
                upsertAssignments ~= m ~ ` = EXCLUDED.` ~ m;
            }
        }
    }
    cmd ~= `) VALUES (`;
    cmd ~= values ~ `)`;

    static if (onConflict == Conflict.update)
    {
        static assert(__traits(hasMember, T, "id"),
                "Conflict.update needs an id to use as the conflict target, but "
                ~ T.stringof ~ " has none");
        if (upsertAssignments.length > 0)
        {
            cmd ~= ` ON CONFLICT (id) DO UPDATE SET ` ~ upsertAssignments;
        }
        else
        {
            // Only the id column is inserted, so there is nothing to overwrite.
            cmd ~= ` ON CONFLICT (id) DO NOTHING`;
        }
    }
    else static if (onConflict == Conflict.noop)
    {
        cmd ~= ` ON CONFLICT DO NOTHING`;
    }
    else
    {
        // Conflict.error: default behavior, no clause.
        static assert(onConflict == Conflict.error);
    }
    return cmd;
}
476 
/**
  * Run `fn` with a connection locked from the Postgres client and release it afterwards.
  *
  * The client is created on first use from `connectionString`. `pg` is not `__gshared`,
  * so each thread that calls this creates its own client.
  *
  * Params:
  *   fn = callable taking the locked connection (template alias parameter).
  *
  * Returns: whatever `fn` returns.
  *
  * Throws: an exception from `enforce` if `connectionString` has not been set.
  */
auto inConnection(alias fn)()
{
    import std.exception : enforce;
    enforce(connectionString, "No connection string provided");
    if (pg is null)
    {
        pg = new PostgresClient(connectionString, 15);
    }

    auto conn = pg.lockConnection;
    // `delete` is deprecated in D; `destroy` is its replacement and runs the
    // handle's destructor, whether fn returns normally or throws.
    scope (exit) destroy(conn);
    return fn(conn);
}
490