1 
2 module mars.pgsql;
3 
4 import std.algorithm;
5 import std.conv;
6 import std.format;
7 import std.string;
8 import std.range;
9 import std.typecons;
10 import std.variant;
11 
12 import mars.defs;
13 import mars.msg : AuthoriseError, InsertError, DeleteError, RequestState;
14 version(unittest) import mars.starwars;
15 
16 import ddb.postgres;
17 import ddb.db;
18 import vibe.core.log;
19 
/**
 * Builds the parameterised `insert into ... returning *` statement for `table`.
 *
 * One positional placeholder (`$1`, `$2`, ...) is emitted per entry of
 * `table.columns`, in column order.
 */
string insertIntoReturningParameter(const(Table) table)
{
    string[] placeholders;
    foreach(position; 1 .. table.columns.length + 1){
        placeholders ~= "$" ~ position.to!string;
    }
    return format("insert into %s values (%s) returning *", table.name, placeholders.join(", "));
}
unittest {
    auto twoColumns = Table("bar", [Col("foo", Type.text, false), Col("baz", Type.text, false)], [], []);
    auto sql = insertIntoReturningParameter(twoColumns);
    assert( sql == "insert into bar values ($1, $2) returning *", sql );
}
29 
/**
 * Builds the parameterised `delete from ... where ...` statement for `table`.
 *
 * The where clause compares every primary key column (`table.pkCols`) with a
 * positional placeholder, numbered from `$1` in primary-key order and joined
 * with `and`.
 */
string deleteFromParameter(const(Table) table)
{
    string[] conditions;
    foreach(i, keyColumn; table.pkCols){
        conditions ~= keyColumn.name ~ " = $" ~ (i+1).to!string;
    }
    return format("delete from %s where %s", table.name, conditions.join(" and "));
}
unittest {
    auto columns = [Col("foo", Type.text, false), Col("bar", Type.text, false), Col("baz", Type.text, false)];
    auto sql = deleteFromParameter(Table("bar", columns, [0, 1], []));
    assert( sql == "delete from bar where foo = $1 and bar = $2", sql);
}
42 
/**
 * Builds the parameterised `update ... set ... where ...` statement for `table`.
 *
 * Placeholders `$1 .. $n` are used for the `set` list (every column, in column
 * order); the `where` clause continues the numbering from `$n+1`. The where
 * clause uses the primary key columns, or every column when the table has no
 * primary key. `executeUpdate` binds the record fields first and the key
 * fields after them, which matches this numbering.
 */
string updateFromParameters(const(Table) table)
{
    immutable(Col)[] whereCols = table.pkCols.length >0? table.pkCols : table.columns;

    // The placeholder numbers are derived from each column's position rather than
    // from a counter mutated inside a lazy `map` lambda: the result then does not
    // depend on how many times the range is evaluated, nor on evaluation order.
    immutable size_t firstSetIndex = 1;
    immutable size_t firstWhereIndex = table.columns.length + 1;

    auto assignments = table.columns
        .enumerate(firstSetIndex)
        .map!( (t) => t.value.name ~ " = $" ~ t.index.to!string)
        .join(", ");
    auto conditions = whereCols
        .enumerate(firstWhereIndex)
        .map!( (t) => t.value.name ~ " = $" ~ t.index.to!string)
        .join(" and ");

    return "update %s set %s where %s".format(table.name, assignments, conditions);
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("bar", Type.text, false), Col("baz", Type.text, false)], [0], []).updateFromParameters();
    assert( sql == "update bar set foo = $1, bar = $2, baz = $3 where foo = $4", sql );
}
unittest {
    // composite primary key: the where placeholders keep counting after the set list
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("bar", Type.text, false), Col("baz", Type.text, false)], [0, 1], []).updateFromParameters();
    assert( sql == "update bar set foo = $1, bar = $2, baz = $3 where foo = $4 and bar = $5", sql );
}
56 
/// Coordinates of a PostgreSQL server and factory of authenticated `Database` handles.
struct DatabaseService {
    string host;
    ushort port; // NOTE(review): `connect` does not forward this value to `Database` - confirm whether that is intended
    string database;

    /**
     * Tries to open a `Database` for the given credentials.
     *
     * Params:
     *   user = role name, must not be null
     *   password = password of the role, must not be null
     *   err = receives the outcome: `authorised` on success,
     *         `wrongUsernameOrPassword` for server codes 28000 / 28P01,
     *         `unknownError` for anything else
     *
     * Returns: the connected `Database`, or null when the connection or the
     *          authentication failed (details are in `err`).
     */
    Database connect(string user, string password, ref AuthoriseError err) in {
        assert(user && password);
    } body {
        Database connection;
        try {
            connection = new Database(host, database, user, password);
            err = AuthoriseError.authorised;
        }
        catch(ServerErrorException serverError){
            if( serverError.code == "28000" ){
                // role "user" does not exist
                logInfo("PostgreSQL role does not exist");
                err = AuthoriseError.wrongUsernameOrPassword;
            }
            else if( serverError.code == "28P01" ){
                // password authentication failed for user "user"
                logInfo("PostgreSQL password authentication failed for user");
                err = AuthoriseError.wrongUsernameOrPassword;
            }
            else {
                logWarn("S --- C | Unhandled PostgreSQL server error during connection!");
                logInfo("S --- C | PostgreSQL server error: %s", serverError.toString);
                err = AuthoriseError.unknownError;
            }
        }
        catch(Exception failure){
            logWarn("S --- C | exception connecting to the PostgreSQL!");
            logWarn("S --- C | %s", failure);
            err = AuthoriseError.unknownError;
        }
        // every path above must have produced a real outcome
        assert( err != AuthoriseError.assertCheck);
        return connection;
    }
}
97 
/**
 * A PostgreSQL connection plus helpers to run raw SQL and the table-driven
 * insert / delete / update statements generated by this module.
 *
 * Instances are created through `DatabaseService.connect` (the constructor is
 * private). The connection is obtained from a `PostgresDB` via `lockConnection()`.
 */
class Database
{
    private this(string host, string database, string user, string password){
        if( db is null ){
            db = new PostgresDB(["host": host, "database": database, "user": user, "password": password]);
        }
        conn = db.lockConnection();
    }

    /**
     * Runs a select built from the FIRST column and the FIRST table of `select`
     * and discards the rows.
     */
    void execute(const Select select)
    {
        string s = `select %s from %s`.format(select.cols[0].name, select.tables[0].name);
        auto q = conn.executeQuery(s);
        // Close the result set like the other call sites of this module do
        // (see executeInsert), so the connection is not left with an open result set.
        scope(exit) q.close();
    }

    /// Runs `sql` verbatim and prints every returned row on stdout.
    void executeUnsafe(string sql){
        auto q = conn.executeQuery(sql);
        // Release the result set even when printing a row throws.
        scope(exit) q.close();
        foreach(v; q){
            import std.stdio; writeln("-->", v);
        }
    }

    /// Runs `sql` verbatim and returns its scalar result converted to `T`.
    T executeScalarUnsafe(T)(string sql){
        return conn.executeScalar!T(sql);
    }

    /**
     * Runs `sql` verbatim and returns the result set; closing it is up to the caller.
     * Used by sync for the subscription of complex queries.
     */
    auto executeQueryUnsafe(string sql){
        return conn.executeQuery(sql);
    }

    /**
     * Runs `sql` with named parameters and returns the result set; closing it is
     * up to the caller. Used by sync for the subscription of complex queries
     * that take parameters.
     *
     * Params:
     *   sql = statement where parameters appear as `$name`
     *   parameters = parameter values by name; only `string` (TEXT) and `int`
     *                (INT4) values are bound, any other type ends in `assert(false)`
     */
    auto executeQueryUnsafe(string sql, Variant[string] parameters){
        // ... sort param names, transform names into a sequence of $1, $2
        auto pgargs = xxx(sql, parameters);
        // ... prepare the statement
        auto cmd = new PGCommand(conn, pgargs[0]);
        foreach(j, param; pgargs[1]){
            // ... try to guess the PGType from the Variant typeinfo ...
            auto pgType = toPGType(param.type);
            switch(pgType) with (PGType){
                case TEXT:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!string;
                    break;
                case INT4:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!int;
                    break;
                default:
                    assert(false, pgType.to!string);
            }
        }
        return cmd.executeQuery();
    }
    version(unittest_starwars){ unittest {
        auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
        auto recordSet = db.executeQueryUnsafe("select * from planets where name = $name", ["name": Variant("Tatooine")]);
        scope(exit) recordSet.close();
        assert(recordSet.front[1].get!long == 120_000);
    }}

    /// Runs `sql` verbatim and returns a result set typed as `Row`; closing it is up to the caller.
    auto executeQueryUnsafe(Row)(string sql){
        return conn.executeQuery!Row(sql);
    }

    /**
     * Inserts `record` into `table` and returns the row reported back by the
     * `returning *` clause.
     *
     * Params:
     *   record = values to insert, one field per table column
     *   err = set to `inserted` on success, `duplicateKeyViolations` for server
     *         code 23505, `unknownError` for any other server error
     *
     * Returns: the inserted row, or `Row.init` when the server rejected the insert.
     */
    auto executeInsert(immutable(Table) table, Row, )(Row record, ref InsertError err){
        enum sql = insertIntoReturningParameter(table);
        auto cmd = new PGCommand(conn, sql);

        addParameters!table(cmd, record);
        Row result;
        try {
            auto querySet = cmd.executeQuery!Row();
            scope(exit) querySet.close();
            result = querySet.front;
            err = InsertError.inserted;
        }
        catch(ServerErrorException e){
            switch(e.code){
                case "23505": //  duplicate key value violates unique constraint "<constraintname>" (for example in primary keys)
                    err = InsertError.duplicateKeyViolations;
                    break;
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during insertion!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = InsertError.unknownError;
            }
        }
        return result;
    }

    /**
     * Deletes from `table` the row identified by the primary key `pk`.
     *
     * Params:
     *   pk = primary key values, one field per primary key column
     *   err = set to `deleted` on success, `unknownError` on any server error
     */
    void executeDelete(immutable(Table) table, Pk)(Pk pk, ref DeleteError err){
        enum sql = deleteFromParameter(table);
        auto cmd = new PGCommand(conn, sql);

        addParameters!table(cmd, pk);
        try {
            cmd.executeNonQuery();
            err = DeleteError.deleted;
        }
        catch(ServerErrorException e){
            switch(e.code){
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during deletion!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = DeleteError.unknownError;
            }
        }
    }

    /**
     * Updates the row of `table` identified by `pk` with the values of `record`.
     *
     * Params:
     *   pk = key values used in the where clause
     *   record = new values, one field per table column
     *   state = set to `rejectedAsPGSqlError` on a server error; left untouched on success
     */
    void executeUpdate(immutable(Table) table, Pk, Row)(Pk pk, Row record, ref RequestState state){
        enum sql = updateFromParameters(table);
        auto cmd = new PGCommand(conn, sql);
        // the record fields fill the `set` placeholders $1 .. $n ...
        addParameters!(table)(cmd, record);
        // ... and the key fields continue from $n+1 for the `where` clause
        short i = record.tupleof.length +1;
        addParameters!table(cmd, pk, i);
        try {
            cmd.executeNonQuery();
        }
        catch(ServerErrorException e){
            switch(e.code){
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during update!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    state = RequestState.rejectedAsPGSqlError;
            }
        }
    }

    //private {
        private PostgresDB db;
        public PGConnection conn;
    //}
}
232 
233 
234 private {
235     import mars.lexer;
236     import mars.sqldb;
237 
238     PGType toPGType(TypeInfo t){
239         if(t == typeid(bool)) return PGType.BOOLEAN;
240         if(t == typeid(int)) return PGType.INT4;
241         if(t == typeid(short)) return PGType.INT2;
242         if(t == typeid(string)) return PGType.TEXT;
243         if(t == typeid(float)) return PGType.FLOAT4;
244         if(t == typeid(double)) return PGType.FLOAT8;
245         if(t == typeid(ubyte[])) return PGType.BYTEA;
246         assert(false, t.to!string);
247     }
248 
249     PGType toPGType(Type t){
250         final switch(t) with(Type) {
251             case boolean: return PGType.BOOLEAN;
252             case integer: return PGType.INT4; // XXX check
253             case bigint: return PGType.INT8;
254             case smallint: return PGType.INT2; // XXX check 
255             case text: return PGType.TEXT;
256             case real_: return PGType.FLOAT4;
257             case doublePrecision: return PGType.FLOAT8;
258             case bytea: return PGType.BYTEA;
259             case smallserial: return PGType.INT2; // XXX check
260 
261             case unknown:
262             case date:
263             case serial:
264             case varchar: // varchar(n), tbd as column
265                 assert(false, t.to!string); // not implemented right now, catch at CT
266         }
267     }
268 
269     void addParameters(immutable(Table) table, Struct, short tupleofIndex =0)(PGCommand cmd, Struct s, short paramIndex =1){
270         static if( is(Struct : asStruct!table) || Struct.tupleof.length == asStruct!(table).tupleof.length ){
271             cmd.parameters.add(paramIndex, table.columns[tupleofIndex].type.toPGType).value = s.tupleof[tupleofIndex];
272         }
273         else static if( is(Struct : asPkStruct!table) || Struct.tupleof.length == asPkStruct!(table).tupleof.length ){
274             cmd.parameters.add(paramIndex, table.pkCols[tupleofIndex].type.toPGType).value = s.tupleof[tupleofIndex];
275         }
276         else static assert(false);
277 
278         static if( s.tupleof.length > tupleofIndex+1 ) addParameters!(table, Struct, tupleofIndex +1)(cmd, s, ++paramIndex);
279     }
280 
281 
282     string select(const(Select) stat){
283         return `select %s from %s`.format(
284             stat.cols.map!((c) => c.name).join(", "),
285             stat.tables.map!( (t) => t.name ).join(", "), /// XXX ho bisogno del nome dello schema QUA... refactory necessario
286             );
287     }
288     unittest {
289         auto s = starwarsSchema();
290         const sql = cast(Select)Parser([s], scan("select name from sw.people")).parse();
291         assert(select(sql) == "select name from people", select(sql));
292     }
293     version(unittest_starwars){
294         unittest {
295             enum pub = starwarsSchema();
296             enum tokens = scan("select name from sw.people");
297             static const stat = Parser([pub], tokens).parse();
298             auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
299             db.execute(cast(Select)stat);
300         }
301         unittest {
302             // check bigint select
303             enum pub = starwarsSchema();
304             enum tokens = scan("select population from sw.planets");
305             static const stat = Parser([pub], tokens).parse();
306             auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
307             db.execute(cast(Select)stat);
308         }
309     }
310     else {
311         version(unittest){
312             pragma(msg, "compile with version 'unittest_starwars' to activate postgresql starwars tests.");
313         }
314     }
315 
316     auto xxx(string sql, Variant[string] parameters){
317         auto names = sort(parameters.keys);
318         Variant[] pgparam;
319         foreach(name; names){
320             pgparam ~= parameters[name];
321             sql = sql.replace("$"~name, "$"~(pgparam.length).to!string); // they are starting from $1, and not from $0
322         }
323         return tuple(sql, pgparam);
324     }
325     unittest {
326         auto r = xxx("select * from planets where name=$name", ["name": Variant("Tatooine")]);
327         import std.stdio; writeln(r);
328     }
329 }
330