1 
2 module mars.pgsql;
3 
4 import std.algorithm;
5 import std.conv;
6 import std.format;
7 import std.string;
8 import std.range;
9 
10 import mars.defs;
11 import mars.msg : AuthoriseError, InsertError, DeleteError;
12 
13 import ddb.postgres;
14 import ddb.db;
15 import vibe.core.log;
16 
/**
 * Builds an `INSERT ... RETURNING *` statement for `table` that binds every
 * column positionally as `$1 .. $n`, in column order.
 */
string insertIntoReturningParameter(const(Table) table)
{
    string[] placeholders;
    foreach (n; 1 .. table.columns.length + 1)
        placeholders ~= "$" ~ n.to!string;

    return format("insert into %s values (%s) returning *", table.name, placeholders.join(", "));
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("baz", Type.text, false)],[],[]).insertIntoReturningParameter();
    assert( sql == "insert into bar values ($1, $2) returning *", sql );
}
26 
/**
 * Builds a parameterised `DELETE` statement for `table`.
 *
 * Rows are matched on the primary-key columns, bound as `$1 .. $k` in key
 * order. When the table has no primary key every column is matched instead
 * (the same fallback `updateFromParameters` uses), so the generated `where`
 * clause is never empty.
 */
string deleteFromParameter(const(Table) table)
{
    immutable(Col)[] whereCols = table.pkCols.length > 0 ? table.pkCols : table.columns;
    return "delete from %s where %s".format(
            table.name,
            zip(iota(1, whereCols.length + 1), whereCols)
                .map!( (t) => t[1].name ~ " = $" ~ t[0].to!string)
                .join(" and "));
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("bar", Type.text, false), Col("baz", Type.text, false)], [0, 1], []).deleteFromParameter();
    assert( sql == "delete from bar where foo = $1 and bar = $2", sql);
}
unittest {
    // Without a primary key every column takes part in the match.
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("baz", Type.text, false)], [], []).deleteFromParameter();
    assert( sql == "delete from bar where foo = $1 and baz = $2", sql);
}
39 
/**
 * Builds a parameterised `UPDATE` statement for `table`.
 *
 * Every column is assigned in the `set` clause from `$1 .. $n` (n = number of
 * columns). The `where` clause matches the primary-key columns, or every
 * column when there is no primary key, using `$n+1 ...`.
 *
 * Placeholder numbers are derived from each element's position rather than
 * from a counter mutated inside `map`: `join` may evaluate a lazy range's
 * `front` more than once (e.g. a length pass followed by a copy pass), which
 * would make a side-effecting counter skip numbers.
 */
string updateFromParameters(const(Table) table)
{
    immutable(Col)[] whereCols = table.pkCols.length > 0 ? table.pkCols : table.columns;
    size_t firstWhereParam = table.columns.length + 1;

    auto setClause = zip(iota(1, firstWhereParam), table.columns)
        .map!( (t) => t[1].name ~ " = $" ~ t[0].to!string)
        .join(", ");
    auto whereClause = zip(iota(firstWhereParam, firstWhereParam + whereCols.length), whereCols)
        .map!( (t) => t[1].name ~ " = $" ~ t[0].to!string)
        .join(" and ");

    return "update %s set %s where %s".format(table.name, setClause, whereClause);
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("bar", Type.text, false), Col("baz", Type.text, false)], [0], []).updateFromParameters();
    assert( sql == "update bar set foo = $1, bar = $2, baz = $3 where foo = $4", sql );
}
unittest {
    // Without a primary key the where clause matches on every column.
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("baz", Type.text, false)], [], []).updateFromParameters();
    assert( sql == "update bar set foo = $1, baz = $2 where foo = $3 and baz = $4", sql );
}
53 
/// Connection settings for a PostgreSQL database; `connect` opens a session.
struct DatabaseService {
    string host;      /// Server host name or address.
    ushort port;      /// Server port; 0 leaves it unset so the driver uses its own default.
    string database;  /// Name of the database to open.

    /**
     * Opens a session on the configured database as `user`.
     *
     * Returns: an instance of `Database`, or null if it can't connect or
     *          authenticate. Error details are stored in `err`.
     */
    Database connect(string user, string password, ref AuthoriseError err) in {
        assert(user && password);
    } body {
        Database db;
        try {
            db = new Database(host, database, user, password, port);
            err = AuthoriseError.authorised;
        }
        catch(ServerErrorException e){
            switch(e.code){
                case "28P01": // password authentication failed for user "user"
                    logInfo("PostgreSQL password authentication failed for user");
                    err = AuthoriseError.wrongUsernameOrPassword;
                    break;
                default:
                    logWarn("S -- C | Unhandled PostgreSQL server error during connection!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = AuthoriseError.unknownError;
            }
        }
        catch(Exception e){
            logWarn("S --- C | exception connecting to the PostgreSQL!");
            logWarn("S --- C | %s", e);
            err = AuthoriseError.unknownError;
        }
        assert( err != AuthoriseError.assertCheck);
        return db;
    }
}

/**
 * A session on a PostgreSQL database: owns a `PostgresDB` and one connection
 * locked from it when the object is constructed.
 *
 * NOTE(review): nothing in this class releases the locked connection; confirm
 * whether an explicit close/unlock is needed.
 */
class Database
{
    /**
     * Creates a `PostgresDB` for the given credentials and locks a connection.
     *
     * Params:
     *   host     = server host name or address
     *   database = database name
     *   user     = login role
     *   password = password for `user`
     *   port     = server port; 0 (the default) does not pass a port to the
     *              driver, so it uses its own default
     */
    private this(string host, string database, string user, string password, ushort port = 0){
        auto params = ["host": host, "database": database, "user": user, "password": password];
        if( port != 0 ){
            params["port"] = port.to!string;
        }
        db = new PostgresDB(params);
        conn = db.lockConnection();
    }

    /**
     * Runs `select` and discards the rows.
     *
     * Only the first selected column and the first table are used.
     */
    void execute(const Select select)
    {
        string s = `select %s from %s`.format(select.cols[0].name, select.tables[0].name);
        auto q = conn.executeQuery(s);
        // The rows are never read: close the result set (as executeInsert
        // does) so the connection is not left with an open query.
        scope(exit) q.close();
    }

    /// Runs raw `sql` and prints every returned row to stdout (debug helper).
    void executeUnsafe(string sql){
        auto q = conn.executeQuery(sql);
        foreach(v; q){
            import std.stdio; writeln("-->", v);
        }
    }

    /// Runs raw `sql` and returns its scalar result converted to `T`.
    T executeScalarUnsafe(T)(string sql){
        return conn.executeScalar!T(sql);
    }

    /// Runs raw `sql` and returns the result set; the caller is responsible for closing it.
    auto executeQueryUnsafe(string sql){
        return conn.executeQuery(sql);
    }

    /// ditto, with rows typed as `Row`.
    auto executeQueryUnsafe(Row)(string sql){
        return conn.executeQuery!Row(sql);
    }

    /**
     * Inserts `record` into `table` and returns the stored row (`returning *`).
     *
     * On a PostgreSQL server error `err` is set to `duplicateKeyViolations`
     * (SQLSTATE 23505) or `unknownError`, and `Row.init` is returned.
     * Other exceptions propagate to the caller.
     */
    auto executeInsert(immutable(Table) table, Row, )(Row record, ref InsertError err){
        enum sql = insertIntoReturningParameter(table);
        auto cmd = new PGCommand(conn, sql);

        addParameters!table(cmd, record);
        Row result;
        try {
            auto querySet = cmd.executeQuery!Row();
            scope(exit) querySet.close();
            result = querySet.front;
            err = InsertError.inserted;
        }
        catch(ServerErrorException e){
            switch(e.code){
                case "23505": //  duplicate key value violates unique constraint "<constraintname>" (for example in primary keys)
                    err = InsertError.duplicateKeyViolations;
                    break;
                default:
                    logWarn("S -- C | Unhandled PostgreSQL server error during insertion!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = InsertError.unknownError;
            }
        }
        return result;
    }

    /**
     * Deletes the row of `table` matching `pk`.
     *
     * On a PostgreSQL server error `err` is set to `unknownError`; other
     * exceptions propagate to the caller.
     */
    void executeDelete(immutable(Table) table, Pk)(Pk pk, ref DeleteError err){
        enum sql = deleteFromParameter(table);
        auto cmd = new PGCommand(conn, sql);

        addParameters!table(cmd, pk);
        try {
            cmd.executeNonQuery();
            err = DeleteError.deleted;
        }
        catch(ServerErrorException e){
            switch(e.code){
                default:
                    logWarn("S -- C | Unhandled PostgreSQL server error during deletion!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = DeleteError.unknownError;
            }
        }
    }

    /**
     * Updates the row of `table` identified by `pk`, assigning every column from `record`.
     *
     * The `set` placeholders are bound from `record` starting at $1. The
     * `where` placeholders are bound from `pk` immediately after them, which
     * matches the numbering of `updateFromParameters`, assuming `Row` has one
     * field per table column. Exceptions propagate to the caller.
     */
    void executeUpdate(immutable(Table) table, Pk, Row)(Pk pk, Row record){
        enum sql = updateFromParameters(table);
        auto cmd = new PGCommand(conn, sql);
        addParameters!(table)(cmd, record);
        short firstPkParam = record.tupleof.length +1;
        addParameters!table(cmd, pk, firstPkParam);
        cmd.executeNonQuery();
    }

    private {
        PostgresDB db;
        PGConnection conn;
    }
}
201 
202 
203 private {
204     import mars.lexer;
205     import mars.sqldb;
206 
207 
208     PGType toPGType(Type t){
209         final switch(t) with(Type) {
210             case boolean: return PGType.BOOLEAN;
211             case integer: return PGType.INT4; // XXX check
212             case bigint: return PGType.INT8;
213             case smallint: return PGType.INT2; // XXX check 
214             case text: return PGType.TEXT;
215             case real_: return PGType.FLOAT4;
216             case doublePrecision: return PGType.FLOAT8;
217             case bytea: return PGType.BYTEA;
218             case smallserial: return PGType.INT2; // XXX check
219 
220             case unknown:
221             case date:
222             case serial:
223             case varchar: // varchar(n), tbd as column
224                               assert(false, t.to!string); // not implemented right now, catch at CT
225         }
226     }
227 
228     void addParameters(immutable(Table) table, Struct, short tupleofIndex =0)(PGCommand cmd, Struct s, short paramIndex =1){
229         static if( is(Struct : asStruct!table) || Struct.tupleof.length == asStruct!(table).tupleof.length ){
230             cmd.parameters.add(paramIndex, table.columns[tupleofIndex].type.toPGType).value = s.tupleof[tupleofIndex];
231         }
232         else static if( is(Struct : asPkStruct!table) || Struct.tupleof.length == asPkStruct!(table).tupleof.length ){
233             cmd.parameters.add(paramIndex, table.pkCols[tupleofIndex].type.toPGType).value = s.tupleof[tupleofIndex];
234         }
235         else static assert(false);
236 
237         static if( s.tupleof.length > tupleofIndex+1 ) addParameters!(table, Struct, tupleofIndex +1)(cmd, s, ++paramIndex);
238     }
239 
240     version(unittest){
241         /+auto starwarSchema() pure {
242             return immutable(Schema)("sw", [
243                 immutable(Table)("people", [Col("name", Type.text), Col("gender", Type.text)], [0], []),
244                 immutable(Table)("species", [Col("name", Type.text)], [0], []),
245         ]);
246         }+/
247         import mars.starwars;
248     }
249     string select(const(Select) stat){
250         return `select %s from %s`.format(
251             stat.cols.map!((c) => c.name).join(", "),
252             stat.tables.map!( (t) => t.name ).join(", "), /// XXX ho bisogno del nome dello schema QUA... refactory necessario
253             );
254     }
255     unittest {
256         auto s = starwarsSchema();
257         const sql = cast(Select)Parser([s], scan("select name from sw.people")).parse();
258         assert(select(sql) == "select name from people", select(sql));
259     }
260 
261     unittest {
262         enum pub = starwarsSchema();
263         enum tokens = scan("select name from sw.people");
264         static const stat = Parser([pub], tokens).parse();
265         auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
266         db.execute(cast(Select)stat);
267     }
268     unittest {
269         // check bigint select
270         enum pub = starwarsSchema();
271         enum tokens = scan("select population from sw.planets");
272         static const stat = Parser([pub], tokens).parse();
273         auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
274         db.execute(cast(Select)stat);
275     }
276 }
277