1 
2 module mars.pgsql;
3 
4 import std.algorithm;
5 import std.conv;
6 import std.format;
7 import std.string;
8 import std.range;
9 import std.typecons;
10 import std.variant;
11 
12 import mars.defs;
13 import mars.msg : AuthoriseError, InsertError, DeleteError, RequestState;
14 version(unittest) import mars.starwars;
15 
16 import ddb.postgres;
17 import ddb.db;
18 import vibe.core.log;
19 
/**
 * Builds a parameterised `insert ... returning *` statement for `table`.
 *
 * Columns of type `serial` or `smallserial` get the `default` keyword, so
 * PostgreSQL generates their value; every other column gets the next positional
 * placeholder (`$1`, `$2`, ...) in column order. This mirrors `addParameters`
 * with `noSerials = true`, which binds no value for those two column types.
 *
 * The placeholders are produced by an eager loop on purpose: a counter bumped
 * inside a lazy `map` lambda is re-run whenever the range is iterated again
 * (and `join` iterates forward ranges with a length twice).
 */
string insertIntoReturningParameter(const(Table) table)
{
    string[] slots;
    int placeholder = 0;
    foreach (c; table.columns)
    {
        if (c.type == Type.serial || c.type == Type.smallserial)
            slots ~= "default";
        else
            slots ~= "$" ~ (++placeholder).to!string;
    }
    return "insert into %s values (%s) returning *".format(table.name, slots.join(", "));
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("baz", Type.text, false)],[],[]).insertIntoReturningParameter();
    assert( sql == "insert into bar values ($1, $2) returning *", sql );
    auto sql2 = Table("bar", [Col("w_id", Type.serial), Col("w", Type.text)], [0], []).insertIntoReturningParameter();
    assert( sql2 == "insert into bar values (default, $1) returning *", sql2);
    auto sql3 = Table("bar", [Col("s_id", Type.smallserial), Col("w", Type.text)], [0], []).insertIntoReturningParameter();
    assert( sql3 == "insert into bar values (default, $1) returning *", sql3);
}
33 
/**
 * Builds a parameterised `delete` statement for `table`: every primary-key
 * column is compared with a positional placeholder, `$1` for the first pk
 * column, `$2` for the second, and so on, joined with `and`.
 *
 * NOTE(review): a table without primary-key columns yields an empty `where`
 * clause ("delete from t where "), which PostgreSQL will reject.
 */
string deleteFromParameter(const(Table) table)
{
    string[] conditions;
    foreach (n, pk; table.pkCols)
        conditions ~= pk.name ~ " = $" ~ (n + 1).to!string;
    return "delete from %s where %s".format(table.name, conditions.join(" and "));
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("bar", Type.text, false), Col("baz", Type.text, false)], [0, 1], []).deleteFromParameter();
    assert( sql == "delete from bar where foo = $1 and bar = $2", sql);
}
46 
/**
 * Builds a parameterised `update` statement for `table`.
 *
 * Every column is assigned a placeholder, `$1` .. `$n` in column order. The
 * `where` clause then matches the primary-key columns, or every column when the
 * table has no primary key, continuing the numbering from `$n+1`.
 *
 * The numbering is done in eager loops: incrementing a counter inside a lazy
 * `map` lambda breaks as soon as the range is iterated more than once (which
 * `join` does for forward ranges with a length).
 */
string updateFromParameters(const(Table) table)
{
    immutable(Col)[] whereCols = table.pkCols.length >0? table.pkCols : table.columns;
    int dollarIndex = 1;

    string[] assignments;
    foreach (c; table.columns)
        assignments ~= c.name ~ " = $" ~ (dollarIndex++).to!string;

    string[] conditions;
    foreach (c; whereCols)
        conditions ~= c.name ~ " = $" ~ (dollarIndex++).to!string;

    return "update %s set %s where %s".format(
        table.name,
        assignments.join(", "),
        conditions.join(" and "));
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("bar", Type.text, false), Col("baz", Type.text, false)], [0], []).updateFromParameters();
    assert( sql == "update bar set foo = $1, bar = $2, baz = $3 where foo = $4", sql );
}
60 
/// Connection settings of a PostgreSQL server, used to open authenticated `Database` sessions.
struct DatabaseService {
    string host;
    ushort port;     /// Server port; 0 (the field's default) means "not configured" and no port is passed to the driver.
    string database;

    /**
     * Opens a session on `database` at `host`:`port` as `user`.
     *
     * Params:
     *   user     = PostgreSQL role to log in as; must not be null.
     *   password = password of `user`; must not be null.
     *   err      = receives `AuthoriseError.authorised` on success, otherwise the reason of the failure
     *              (`wrongUsernameOrPassword` for SQLSTATE 28000/28P01, `unknownError` for anything else).
     *
     * Returns: an instance of `Database`, or null if it can't connect or authenticate. Error details are in `err`.
     */
    Database connect(string user, string password, ref AuthoriseError err) in {
        assert(user && password);
    } body {
        Database db;
        try {
            db = new Database(host, database, user, password, port);
            err = AuthoriseError.authorised;
        }
        catch(ServerErrorException e){
            switch(e.code){
                case "28000": // role "user" does not exist
                    logInfo("PostgreSQL role does not exist");
                    err = AuthoriseError.wrongUsernameOrPassword;
                    break;
                case "28P01": // password authentication failed for user "user"
                    logInfo("PostgreSQL password authentication failed for user");
                    err = AuthoriseError.wrongUsernameOrPassword;
                    break;
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during connection!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = AuthoriseError.unknownError;
            }
        }
        catch(Exception e){
            logWarn("S --- C | exception connecting to the PostgreSQL!");
            logWarn("S --- C | %s", e);
            err = AuthoriseError.unknownError;
        }
        assert( err != AuthoriseError.assertCheck);
        return db;
    }
}

/**
 * An authenticated session on a PostgreSQL database, holding one connection
 * locked from a `PostgresDB` created for this instance.
 *
 * Methods whose name ends in `Unsafe` run caller-supplied SQL text without validating it.
 */
class Database
{
    /**
     * Creates the `PostgresDB` for the given credentials and locks one connection from it.
     *
     * Params:
     *   port = server port; when 0 (the default) no `port` setting is passed to the driver,
     *          exactly as before this parameter existed.
     */
    private this(string host, string database, string user, string password, ushort port = 0){
        string[string] params = ["host": host, "database": database, "user": user, "password": password];
        if( port != 0 ) params["port"] = port.to!string;
        db = new PostgresDB(params);
        conn = db.lockConnection();
    }

    /**
     * Runs `select <first column> from <first table>` built from `select`, discarding the rows.
     * NOTE(review): only `cols[0]` and `tables[0]` are used; further columns/tables are ignored.
     */
    void execute(const Select select)
    {
        string s = `select %s from %s`.format(select.cols[0].name, select.tables[0].name);
        auto q = conn.executeQuery(s);
        // Release the result set, as the other query methods of this class do.
        q.close();
    }

    /// Runs `sql` and writes every returned row to stdout, prefixed by "-->".
    void executeUnsafe(string sql){
        auto q = conn.executeQuery(sql);
        scope(exit) q.close();
        foreach(v; q){
            import std.stdio; writeln("-->", v);
        }
    }

    /// Runs `sql` through the connection's `executeScalar!T` and returns the value it yields.
    T executeScalarUnsafe(T)(string sql){
        return conn.executeScalar!T(sql);
    }

    // used by sync to subscribe to complex queries
    /// Runs `sql` and returns its result set; the caller closes it when done.
    auto executeQueryUnsafe(string sql){
        return conn.executeQuery(sql);
    }

    // used by sync to subscribe to complex queries, with parameters
    /**
     * Runs `sql` after turning its named placeholders (`$name`) into positional ones.
     *
     * Params:
     *   sql        = query text using `$name` placeholders.
     *   parameters = placeholder values keyed by name (without the `$`); only
     *                `string`, `short` and `int` payloads are supported by the switch below.
     *
     * Returns: the result set of the query; the caller closes it when done (see the unittest below).
     */
    auto executeQueryUnsafe(string sql, Variant[string] parameters){
        // ... sort param names, transform names into a sequence of $1, $2
        auto pgargs = xxx(sql, parameters);
        // ... prepare the statement
        auto cmd = new PGCommand(conn, pgargs[0]);
        foreach(j, param; pgargs[1]){
            // ... try to guess the PGType from the Variant typeinfo ...
            auto pgType = toPGType(param.type);
            switch(pgType) with (PGType){
                case TEXT:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!string;
                    break;
                case INT2:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!short;
                    break;
                case INT4:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!int;
                    break;
                default:
                    assert(false, pgType.to!string);
            }
        }
        return cmd.executeQuery();
    }
    version(unittest_starwars){ unittest {
        auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
        auto recordSet = db.executeQueryUnsafe("select * from planets where name = $name", ["name": Variant("Tatooine")]);
        scope(exit) recordSet.close();
        assert(recordSet.front[1].get!long == 120_000);
    }}

    /// Runs `sql` and returns its result set typed as `Row`.
    auto executeQueryUnsafe(Row)(string sql){
        return conn.executeQuery!Row(sql);
    }

    /**
     * Inserts `record` into `table` and returns the stored row (`returning *`).
     *
     * Serial columns are not bound (the statement uses `default` for them).
     * On success `err` is `InsertError.inserted`. On a PostgreSQL server error
     * `err` is `duplicateKeyViolations` (SQLSTATE 23505) or `unknownError`, and
     * `Row.init` is returned. Other exceptions propagate.
     */
    auto executeInsert(immutable(Table) table, Row, )(Row record, ref InsertError err){
        enum sql = insertIntoReturningParameter(table);
        auto cmd = new PGCommand(conn, sql);
        addParameters!(table, Row, true)(cmd, record); // skip serial parameters
        Row result;
        try {
            auto querySet = cmd.executeQuery!Row();
            scope(exit) querySet.close();
            result = querySet.front;
            err = InsertError.inserted;
        }
        catch(ServerErrorException e){
            switch(e.code){
                case "23505": //  duplicate key value violates unique constraint "<constraintname>" (for example in primary keys)
                    err = InsertError.duplicateKeyViolations;
                    break;
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during insertion!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = InsertError.unknownError;
            }
        }
        return result;
    }

    /**
     * Deletes the row of `table` identified by the primary key `pk`.
     *
     * `err` is `DeleteError.deleted` when the statement runs, `unknownError` on a
     * PostgreSQL server error. Other exceptions propagate.
     */
    void executeDelete(immutable(Table) table, Pk)(Pk pk, ref DeleteError err){
        enum sql = deleteFromParameter(table);
        auto cmd = new PGCommand(conn, sql);

        addParameters!table(cmd, pk);
        try {
            cmd.executeNonQuery();
            err = DeleteError.deleted;
        }
        catch(ServerErrorException e){
            switch(e.code){
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during deletion!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = DeleteError.unknownError;
            }
        }
    }

    /**
     * Overwrites every column of the row identified by `pk` with the fields of `record`.
     *
     * The record fields are bound first, then the pk fields continue the numbering.
     * On a PostgreSQL server error `state` becomes `rejectedAsForeignKeyViolation`
     * (SQLSTATE 23503) or `rejectedAsPGSqlError`. On success `state` is left untouched.
     */
    void executeUpdate(immutable(Table) table, Pk, Row)(Pk pk, Row record, ref RequestState state){
        enum sql = updateFromParameters(table);
        auto cmd = new PGCommand(conn, sql);
        addParameters!(table)(cmd, record);
        short i = record.tupleof.length +1;
        addParameters!table(cmd, pk, i);
        try {
            cmd.executeNonQuery();
        }
        catch(ServerErrorException e){
            switch(e.code){
                case "23503":
                    logInfo("S --- C | PostgreSQL can't update the primary key as still referenced (maybe add an update cascade?).");
                    state = RequestState.rejectedAsForeignKeyViolation;
                    break;
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during update!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    state = RequestState.rejectedAsPGSqlError;
            }
        }
    }

    //private {
        private PostgresDB db;
        public PGConnection conn;
    //}
}
240 
241 
242 private {
243     import mars.lexer;
244     import mars.sqldb;
245 
246     PGType toPGType(TypeInfo t){
247         if(t == typeid(bool)) return PGType.BOOLEAN;
248         if(t == typeid(int)) return PGType.INT4;
249         if(t == typeid(short)) return PGType.INT2;
250         if(t == typeid(string)) return PGType.TEXT;
251         if(t == typeid(float)) return PGType.FLOAT4;
252         if(t == typeid(double)) return PGType.FLOAT8;
253         if(t == typeid(ubyte[])) return PGType.BYTEA;
254         assert(false, t.to!string);
255     }
256 
257     PGType toPGType(Type t){
258         final switch(t) with(Type) {
259             case boolean: return PGType.BOOLEAN;
260             case integer: return PGType.INT4; // XXX check
261             case bigint: return PGType.INT8;
262             case smallint: return PGType.INT2; // XXX check 
263             case text: return PGType.TEXT;
264             case real_: return PGType.FLOAT4;
265             case doublePrecision: return PGType.FLOAT8;
266             case bytea: return PGType.BYTEA;
267             case smallserial: return PGType.INT2; // XXX check
268             case serial: return PGType.INT4; // there's not really a serial type in postgres
269 
270             case unknown:
271             case date:
272             case varchar: // varchar(n), tbd as column
273                 assert(false, t.to!string); // not implemented right now, catch at CT
274         }
275     }
276 
277     void addParameters(immutable(Table) table, Struct, bool noSerials = false, short tupleofIndex =0)(PGCommand cmd, Struct s, short paramIndex =1){
278         static if( is(Struct : asStruct!table) || Struct.tupleof.length == asStruct!(table).tupleof.length )
279         {
280             auto type =  table.columns[tupleofIndex].type;
281             static if( noSerials ) auto mustAdd = type != Type.serial && type != Type.smallserial;
282             else bool mustAdd = true;
283             if( mustAdd ) cmd.parameters.add(paramIndex, table.columns[tupleofIndex].type.toPGType).value = s.tupleof[tupleofIndex];
284         }
285         else static if( is(Struct : asPkStruct!table) || Struct.tupleof.length == asPkStruct!(table).tupleof.length )
286         {
287             auto type =  table.columns[tupleofIndex].type;
288             static if( noSerials ) auto mustAdd = type != Type.serial && type != Type.smallserial;
289             else bool mustAdd = true;
290             if( mustAdd ) cmd.parameters.add(paramIndex, table.pkCols[tupleofIndex].type.toPGType).value = s.tupleof[tupleofIndex];
291         }
292         else static assert(false);
293 
294         static if( s.tupleof.length > tupleofIndex+1 ) addParameters!(table, Struct, noSerials, tupleofIndex +1)(cmd, s, ++paramIndex);
295     }
296 
297 
298     string select(const(Select) stat){
299         return `select %s from %s`.format(
300             stat.cols.map!((c) => c.name).join(", "),
301             stat.tables.map!( (t) => t.name ).join(", "), /// XXX ho bisogno del nome dello schema QUA... refactory necessario
302             );
303     }
304     unittest {
305         auto s = starwarsSchema();
306         const sql = cast(Select)Parser([s], scan("select name from sw.people")).parse();
307         assert(select(sql) == "select name from people", select(sql));
308     }
    // Integration tests that need a running PostgreSQL server with the "starwars"
    // database reachable at 127.0.0.1, logging in as role "jedi" with password "force".
    // They are compiled only with -version=unittest_starwars.
    version(unittest_starwars){
        unittest {
            // Parses a simple select at compile time (static const initialiser)
            // and runs it through Database.execute.
            enum pub = starwarsSchema();
            enum tokens = scan("select name from sw.people");
            static const stat = Parser([pub], tokens).parse();
            auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
            db.execute(cast(Select)stat);
        }
        unittest {
            // check bigint select
            enum pub = starwarsSchema();
            enum tokens = scan("select population from sw.planets");
            static const stat = Parser([pub], tokens).parse();
            auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
            db.execute(cast(Select)stat);
        }
    }
    else {
        version(unittest){
            // Compile-time reminder that the database-backed tests are disabled in this build.
            pragma(msg, "compile with version 'unittest_starwars' to activate postgresql starwars tests.");
        }
    }
331 
332     auto xxx(string sql, Variant[string] parameters){
333         auto names = sort(parameters.keys);
334         Variant[] pgparam;
335         foreach(name; names){
336             pgparam ~= parameters[name];
337             sql = sql.replace("$"~name, "$"~(pgparam.length).to!string); // they are starting from $1, and not from $0
338         }
339         return tuple(sql, pgparam);
340     }
341     unittest {
342         auto r = xxx("select * from planets where name=$name", ["name": Variant("Tatooine")]);
343     }
344 }
345