1 
2 module mars.pgsql;
3 
4 import std.algorithm;
5 import std.conv;
6 import std.format;
7 import std.string;
8 import std.range;
9 import std.typecons;
10 import std.variant;
11 
12 import mars.defs;
13 import mars.msg : AuthoriseError, InsertError, DeleteError, RequestState;
14 version(unittest) import mars.starwars;
15 
16 import ddb.postgres;
17 import ddb.db;
18 import vibe.core.log;
19 
/**
 * Builds the parameterised `insert ... returning *` statement for `table`.
 *
 * Values are listed in column order. `serial` and `smallserial` columns get
 * `default`, so the database generates them. Every other column gets the next
 * positional placeholder (`$1`, `$2`, ...). This matches `addParameters` with
 * `noSerials = true`, which binds no value for either serial kind.
 */
string insertIntoReturningParameter(const(Table) table)
{
    int placeholder = 0;
    string[] values;
    foreach (col; table.columns)
    {
        if (col.type == Type.serial || col.type == Type.smallserial)
            values ~= "default";
        else
            values ~= "$" ~ (++placeholder).to!string;
    }
    return "insert into %s values (%s) returning *".format(table.name, values.join(", "));
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("baz", Type.text, false)],[],[]).insertIntoReturningParameter();
    assert( sql == "insert into bar values ($1, $2) returning *", sql );
    auto sql2 = Table("bar", [Col("w_id", Type.serial), Col("w", Type.text)], [0], []).insertIntoReturningParameter();
    assert( sql2 == "insert into bar values (default, $1) returning *", sql2);
    // smallserial columns are generated by the database as well
    auto sql3 = Table("bar", [Col("s_id", Type.smallserial), Col("s", Type.text)], [0], []).insertIntoReturningParameter();
    assert( sql3 == "insert into bar values (default, $1) returning *", sql3);
}
33 
/**
 * Builds the parameterised `delete` statement for `table`.
 *
 * The where clause has one `<pk column> = $n` condition per primary-key
 * column. Conditions are numbered from `$1` in `pkCols` order and joined
 * with `and`.
 */
string deleteFromParameter(const(Table) table)
{
    string[] conditions;
    foreach (idx, pk; table.pkCols)
        conditions ~= "%s = $%s".format(pk.name, idx + 1);
    return "delete from %s where %s".format(table.name, conditions.join(" and "));
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("bar", Type.text, false), Col("baz", Type.text, false)], [0, 1], []).deleteFromParameter();
    assert( sql == "delete from bar where foo = $1 and bar = $2", sql);
}
46 
/**
 * Builds the parameterised `update` statement for `table`.
 *
 * Every column is assigned from a placeholder (`$1` .. `$n`, in column order).
 * The where clause then matches the primary-key columns, or every column when
 * the table has no primary key, using the placeholders that follow
 * (`$n+1`, ...).
 */
string updateFromParameters(const(Table) table)
{
    immutable(Col)[] keyCols = table.pkCols.length > 0 ? table.pkCols : table.columns;
    int next = 0;

    string[] assignments;
    foreach (col; table.columns)
        assignments ~= "%s = $%s".format(col.name, ++next);

    string[] conditions;
    foreach (col; keyCols)
        conditions ~= "%s = $%s".format(col.name, ++next);

    return "update %s set %s where %s".format(table.name, assignments.join(", "), conditions.join(" and "));
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("bar", Type.text, false), Col("baz", Type.text, false)], [0], []).updateFromParameters();
    assert( sql == "update bar set foo = $1, bar = $2, baz = $3 where foo = $4", sql );
}
60 
/**
 * Location of a PostgreSQL database. `connect` opens a `Database` on it with
 * per-user credentials.
 */
struct DatabaseService {
    string host;
    ushort port; // NOTE(review): `connect` does not pass this on to `Database` — confirm the driver default is intended
    string database;

    /**
     * Connects to `database` on `host` as `user`.
     *
     * Params:
     *   user     = PostgreSQL role name; must be non-null (checked by the `in` contract)
     *   password = password of `user`; must be non-null
     *   err      = set to `authorised` on success, `wrongUsernameOrPassword`
     *              for server error codes 28000 / 28P01, and `unknownError` for
     *              any other failure
     * Returns: the connected `Database`, or null if connecting or
     *          authenticating failed (the reason is in `err`).
     */
    Database connect(string user, string password, ref AuthoriseError err) in {
        assert(user && password);
    } body {
        Database connected = null;
        try {
            connected = new Database(host, database, user, password);
            err = AuthoriseError.authorised;
        }
        catch(ServerErrorException e){
            auto sqlState = e.code;
            if (sqlState == "28000") {
                // role "user" does not exist
                logInfo("PostgreSQL role does not exist");
                err = AuthoriseError.wrongUsernameOrPassword;
            }
            else if (sqlState == "28P01") {
                // password authentication failed for user "user"
                logInfo("PostgreSQL password authentication failed for user");
                err = AuthoriseError.wrongUsernameOrPassword;
            }
            else {
                logWarn("S --- C | Unhandled PostgreSQL server error during connection!");
                logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                err = AuthoriseError.unknownError;
            }
        }
        catch(Exception e){
            logWarn("S --- C | exception connecting to the PostgreSQL!");
            logWarn("S --- C | %s", e);
            err = AuthoriseError.unknownError;
        }
        // Sanity check: every path above assigns `err`, so it can no longer be `assertCheck`.
        assert( err != AuthoriseError.assertCheck);
        return connected;
    }
}
101 
/**
 * Access to one PostgreSQL database through a single ddb `PGConnection`
 * (`conn_`).
 *
 * The credentials given to the constructor are kept so that
 * `lockOpenConnection` can reopen the connection when it is no longer usable.
 *
 * NOTE(review): every public method closes the shared connection on exit
 * (`scope(exit) conn.close()`), so each later call goes through the reopen
 * path of `lockOpenConnection`. The `executeQueryUnsafe` overloads return a
 * result set after that close has already run. Confirm that the driver has
 * read the rows the caller goes on to iterate.
 */
class Database
{
    private this(string host, string database, string user, string password){
        //if( db is null ){
        //    db = new PostgresDB(["host": host, "database": database, "user": user, "password": password]);
        //}
        this.host = host; this.database = database; this.user = user; this.password = password;
        conn_ = openConnection();
    }
    private string host, database, user, password;

    /// Opens a brand-new connection with the stored credentials.
    private PGConnection openConnection(){
        return new PGConnection([ "host" : host, "database" : database, "user" : user, "password" : password ]);
    }

    /**
     * Returns `conn_` after probing it with `select true`.
     *
     * If the probe throws (for example because a previous call closed the
     * connection), a new connection is first opened and stored in `conn_`.
     */
    private auto lockOpenConnection(){
        try {
            conn_.executeNonQuery("select true");
        }
        catch(Exception e){
            logInfo("PostgreSQL connection coming from the pool seems closed: reopening it");
            conn_ = openConnection();
        }
        return conn_;
    }

    /**
     * Runs the SQL rendered from `select`, covering all of its columns and
     * tables via the module-level `select` helper. The returned rows are
     * discarded.
     */
    void execute(const Select select)
    {
        auto conn = lockOpenConnection();
        scope(exit) conn.close();
        // Leading dot: the parameter `select` shadows the module-level helper of the same name.
        string s = .select(select);
        auto q = conn.executeQuery(s);
    }

    /// Runs arbitrary `sql` and prints every returned row to stdout (debugging aid).
    void executeUnsafe(string sql){
        auto conn = lockOpenConnection();
        scope(exit) conn.close();
        auto q = conn.executeQuery(sql);
        foreach(v; q){
            import std.stdio; writeln("-->", v);
        }
    }

    /// Runs arbitrary `sql` and returns the driver's `executeScalar!T` result.
    T executeScalarUnsafe(T)(string sql){
        auto conn = lockOpenConnection();
        scope(exit) conn.close();
        return conn.executeScalar!T(sql);
    }

    // Used by sync to subscribe to complex queries.
    auto executeQueryUnsafe(string sql){
        auto conn = lockOpenConnection();
        scope(exit) conn.close();
        return conn.executeQuery(sql);
    }

    /**
     * Used by sync to subscribe to complex queries, with parameters.
     *
     * `$name` placeholders in `sql` are rewritten to positional ones by `xxx`.
     * Each value is then bound with the PostgreSQL type chosen by `toPGType`
     * from the `Variant`'s runtime type. A type without a binding case below
     * fails with `assert(false)`.
     */
    auto executeQueryUnsafe(string sql, Variant[string] parameters){
        auto conn = lockOpenConnection();
        scope(exit) conn.close();
        // ... sort param names, transform names into a sequence of $1, $2
        auto pgargs = xxx(sql, parameters);
        // ... prepare the statement
        auto cmd = new PGCommand(conn, pgargs[0]);
        foreach(j, param; pgargs[1]){
            // ... try to guess the PGType from the Variant typeinfo ...
            auto pgType = toPGType(param.type);
            switch(pgType) with (PGType){
                case BOOLEAN:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!bool;
                    break;
                case TEXT:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!string;
                    break;
                case INT2:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!short;
                    break;
                case INT4:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!int;
                    break;
                case INT8:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!long;
                    break;
                case FLOAT4:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!float;
                    break;
                case FLOAT8:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!double;
                    break;
                case BYTEA:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!(ubyte[]);
                    break;
                default:
                    assert(false, pgType.to!string);
            }
        }
        return cmd.executeQuery();
    }
    version(unittest_starwars){ unittest {
        auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
        auto recordSet = db.executeQueryUnsafe("select * from planets where name = $name", ["name": Variant("Tatooine")]);
        scope(exit) recordSet.close();
        assert(recordSet.front[1].get!long == 120_000);
    }}

    /// Runs arbitrary `sql`, returning the driver's result typed by `Row`.
    auto executeQueryUnsafe(Row)(string sql){
        auto conn = lockOpenConnection();
        scope(exit) conn.close();
        return conn.executeQuery!Row(sql);
    }

    /**
     * Inserts `record` into `table` and returns the row read back through
     * `returning *`.
     *
     * Serial columns are not bound; the SQL uses `default` for them. On a
     * server error, `err` reports the reason and a default-initialised `Row`
     * is returned. Other exceptions propagate.
     */
    auto executeInsert(immutable(Table) table, Row, )(Row record, ref InsertError err){
        auto conn = lockOpenConnection();
        scope(exit) conn.close();
        enum sql = insertIntoReturningParameter(table);
        auto cmd = new PGCommand(conn, sql);
        addParameters!(table, Row, true)(cmd, record); // skip serial parameters
        Row result;
        try {
            auto querySet = cmd.executeQuery!Row();
            scope(exit) querySet.close();
            result = querySet.front;
            err = InsertError.inserted;
        }
        catch(ServerErrorException e){
            switch(e.code){
                case "23505": //  duplicate key value violates unique constraint "<constraintname>" (for example in primary keys)
                    err = InsertError.duplicateKeyViolations;
                    break;
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during insertion!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = InsertError.unknownError;
            }
        }
        return result;
    }

    /**
     * Deletes the row of `table` identified by the primary-key values in `pk`.
     *
     * `err` is `deleted` whenever the statement ran without a server error.
     * The affected-row count is not checked.
     */
    void executeDelete(immutable(Table) table, Pk)(Pk pk, ref DeleteError err){
        auto conn = lockOpenConnection();
        scope(exit) conn.close();
        enum sql = deleteFromParameter(table);
        auto cmd = new PGCommand(conn, sql);

        addParameters!table(cmd, pk);
        try {
            cmd.executeNonQuery();
            err = DeleteError.deleted;
        }
        catch(ServerErrorException e){
            switch(e.code){
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during deletion!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = DeleteError.unknownError;
            }
        }
    }

    /**
     * Overwrites every column of the row identified by `pk` with the fields
     * of `record`.
     *
     * Record values are bound to `$1..$n`, and the key values to the
     * placeholders that follow. `state` is written only on a server error; on
     * success it keeps the caller's value.
     */
    void executeUpdate(immutable(Table) table, Pk, Row)(Pk pk, Row record, ref RequestState state){
        auto conn = lockOpenConnection();
        scope(exit) conn.close();
        enum sql = updateFromParameters(table);
        auto cmd = new PGCommand(conn, sql);
        addParameters!(table)(cmd, record);
        short i = record.tupleof.length +1;
        addParameters!table(cmd, pk, i);
        try {
            cmd.executeNonQuery();
        }
        catch(ServerErrorException e){
            switch(e.code){
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during update!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    state = RequestState.rejectedAsPGSqlError;
            }
        }
    }

    private {
        public PGConnection conn_;
    }
}
266 
267 
268 private {
269     import mars.lexer;
270     import mars.sqldb;
271 
272     PGType toPGType(TypeInfo t){
273         if(t == typeid(bool)) return PGType.BOOLEAN;
274         if(t == typeid(int)) return PGType.INT4;
275         if(t == typeid(short)) return PGType.INT2;
276         if(t == typeid(string)) return PGType.TEXT;
277         if(t == typeid(float)) return PGType.FLOAT4;
278         if(t == typeid(double)) return PGType.FLOAT8;
279         if(t == typeid(ubyte[])) return PGType.BYTEA;
280         assert(false, t.to!string);
281     }
282 
283     PGType toPGType(Type t){
284         final switch(t) with(Type) {
285             case boolean: return PGType.BOOLEAN;
286             case integer: return PGType.INT4; // XXX check
287             case bigint: return PGType.INT8;
288             case smallint: return PGType.INT2; // XXX check 
289             case text: return PGType.TEXT;
290             case real_: return PGType.FLOAT4;
291             case doublePrecision: return PGType.FLOAT8;
292             case bytea: return PGType.BYTEA;
293             case smallserial: return PGType.INT2; // XXX check
294             case serial: return PGType.INT4; // there's not really a serial type in postgres
295 
296             case unknown:
297             case date:
298             case varchar: // varchar(n), tbd as column
299                 assert(false, t.to!string); // not implemented right now, catch at CT
300         }
301     }
302 
303     void addParameters(immutable(Table) table, Struct, bool noSerials = false, short tupleofIndex =0)(PGCommand cmd, Struct s, short paramIndex =1){
304         static if( is(Struct : asStruct!table) || Struct.tupleof.length == asStruct!(table).tupleof.length )
305         {
306             auto type =  table.columns[tupleofIndex].type;
307             static if( noSerials ) auto mustAdd = type != Type.serial && type != Type.smallserial;
308             else bool mustAdd = true;
309             if( mustAdd ) cmd.parameters.add(paramIndex, table.columns[tupleofIndex].type.toPGType).value = s.tupleof[tupleofIndex];
310         }
311         else static if( is(Struct : asPkStruct!table) || Struct.tupleof.length == asPkStruct!(table).tupleof.length )
312         {
313             auto type =  table.columns[tupleofIndex].type;
314             static if( noSerials ) auto mustAdd = type != Type.serial && type != Type.smallserial;
315             else bool mustAdd = true;
316             if( mustAdd ) cmd.parameters.add(paramIndex, table.pkCols[tupleofIndex].type.toPGType).value = s.tupleof[tupleofIndex];
317         }
318         else static assert(false);
319 
320         static if( s.tupleof.length > tupleofIndex+1 ) addParameters!(table, Struct, noSerials, tupleofIndex +1)(cmd, s, ++paramIndex);
321     }
322 
323 
324     string select(const(Select) stat){
325         return `select %s from %s`.format(
326             stat.cols.map!((c) => c.name).join(", "),
327             stat.tables.map!( (t) => t.name ).join(", "), /// XXX ho bisogno del nome dello schema QUA... refactory necessario
328             );
329     }
330     unittest {
331         auto s = starwarsSchema();
332         const sql = cast(Select)Parser([s], scan("select name from sw.people")).parse();
333         assert(select(sql) == "select name from people", select(sql));
334     }
    // Integration tests against a live PostgreSQL "starwars" database on
    // 127.0.0.1 (user "jedi"); compiled only with -version=unittest_starwars.
    version(unittest_starwars){
        unittest {
            // The statement is scanned (enum) and parsed (static const initializer)
            // at compile time, then run through Database.execute.
            enum pub = starwarsSchema();
            enum tokens = scan("select name from sw.people");
            static const stat = Parser([pub], tokens).parse();
            auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
            db.execute(cast(Select)stat);
        }
        unittest {
            // check bigint select
            enum pub = starwarsSchema();
            enum tokens = scan("select population from sw.planets");
            static const stat = Parser([pub], tokens).parse();
            auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
            db.execute(cast(Select)stat);
        }
    }
    else {
        version(unittest){
            // Compile-time reminder that the database-backed tests above are disabled.
            pragma(msg, "compile with version 'unittest_starwars' to activate postgresql starwars tests.");
        }
    }
357 
358     auto xxx(string sql, Variant[string] parameters){
359         auto names = sort(parameters.keys);
360         Variant[] pgparam;
361         foreach(name; names){
362             pgparam ~= parameters[name];
363             sql = sql.replace("$"~name, "$"~(pgparam.length).to!string); // they are starting from $1, and not from $0
364         }
365         return tuple(sql, pgparam);
366     }
367     unittest {
368         auto r = xxx("select * from planets where name=$name", ["name": Variant("Tatooine")]);
369     }
370 }
371