module mars.pgsql;

import std.algorithm;
import std.conv;
import std.format;
import std.string;
import std.range;
import std.typecons;
import std.variant;

import mars.defs;
import mars.msg : AuthoriseError, InsertError, DeleteError, RequestState;
version(unittest) import mars.starwars;

import ddb.postgres;
import ddb.db;
import vibe.core.log;

/**
 * Builds the parameterised `insert into ... values (...) returning *` statement for `table`.
 *
 * Every `serial` / `smallserial` column is rendered as the keyword `default`; every other
 * column gets the next free placeholder (`$1`, `$2`, ...), numbered consecutively.
 * The set of skipped column types matches what `addParameters` skips when called with
 * `noSerials == true`, so the number of placeholders equals the number of bound values.
 *
 * The string is built with plain concatenation because the function is evaluated at
 * compile time by `Database.executeInsert` (`enum sql = ...`).
 */
string insertIntoReturningParameter(const(Table) table)
{
    string values;
    int placeholder = 1;
    foreach(column; table.columns){
        if( values.length ) values ~= ", ";
        if( column.type == Type.serial || column.type == Type.smallserial ){
            values ~= "default";
        }
        else {
            values ~= "$" ~ placeholder.to!string;
            ++placeholder;
        }
    }
    return "insert into %s values (%s) returning *".format(table.name, values);
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("baz", Type.text, false)],[],[]).insertIntoReturningParameter();
    assert( sql == "insert into bar values ($1, $2) returning *", sql );
    auto sql2 = Table("bar", [Col("w_id", Type.serial), Col("w", Type.text)], [0], []).insertIntoReturningParameter();
    assert( sql2 == "insert into bar values (default, $1) returning *", sql2);
    auto sql3 = Table("bar", [Col("w_id", Type.smallserial), Col("w", Type.text)], [0], []).insertIntoReturningParameter();
    assert( sql3 == "insert into bar values (default, $1) returning *", sql3);
}

/**
 * Builds the parameterised `delete from ... where ...` statement for `table`.
 *
 * The where clause compares each primary key column with its own placeholder
 * (`$1` for the first pk column, `$2` for the second, ...), joined with `and`.
 * NOTE(review): a table without primary key columns yields an empty where clause.
 */
string deleteFromParameter(const(Table) table)
{
    string conditions;
    foreach(i, col; table.pkCols){
        if( i > 0 ) conditions ~= " and ";
        conditions ~= col.name ~ " = $" ~ (i + 1).to!string;
    }
    return "delete from %s where %s".format(table.name, conditions);
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("bar", Type.text, false), Col("baz", Type.text, false)], [0, 1], []).deleteFromParameter();
    assert( sql == "delete from bar where foo = $1 and bar = $2", sql);
}

/**
 * Builds the parameterised `update ... set ... where ...` statement for `table`.
 *
 * All the columns are assigned, using placeholders `$1 .. $n`; the where clause
 * continues the numbering (`$n+1`, ...) over the primary key columns, or over all
 * the columns when the table has no primary key.
 */
string updateFromParameters(const(Table) table)
{
    immutable(Col)[] whereCols = table.pkCols.length > 0 ? table.pkCols : table.columns;
    int dollarIndex = 1;

    string assignments;
    foreach(col; table.columns){
        if( assignments.length ) assignments ~= ", ";
        assignments ~= col.name ~ " = $" ~ dollarIndex.to!string;
        ++dollarIndex;
    }

    string conditions;
    foreach(col; whereCols){
        if( conditions.length ) conditions ~= " and ";
        conditions ~= col.name ~ " = $" ~ dollarIndex.to!string;
        ++dollarIndex;
    }

    return "update %s set %s where %s".format(table.name, assignments, conditions);
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("bar", Type.text, false), Col("baz", Type.text, false)], [0], []).updateFromParameters();
    assert( sql == "update bar set foo = $1, bar = $2, baz = $3 where foo = $4", sql );
}

/// Coordinates of a PostgreSQL server, used to open authenticated `Database` sessions.
struct DatabaseService {
    string host;
    ushort port;
    string database;

    /**
     * Opens a session on the configured server with the given credentials.
     *
     * The configured `port` is forwarded to the driver; a port of 0 (the field
     * default) leaves the choice of the port to the driver.
     *
     * Params:
     *  user = the PostgreSQL role, must be non-null
     *  password = the role password, must be non-null
     *  err = receives the outcome: `authorised`, `wrongUsernameOrPassword` or `unknownError`
     *
     * Returns: an instance of `Database` or null if can't connect or authenticate. Errors details in 'err'
     */
    Database connect(string user, string password, ref AuthoriseError err) in {
        assert(user && password);
    } body {
        Database db;
        try {
            db = new Database(host, database, user, password, port);
            err = AuthoriseError.authorised;
        }
        catch(ServerErrorException e){
            switch(e.code){
                case "28000": // role "user" does not exist
                    logInfo("PostgreSQL role does not exist");
                    err = AuthoriseError.wrongUsernameOrPassword;
                    break;
                case "28P01": // password authentication failed for user "user"
                    logInfo("PostgreSQL password authentication failed for user");
                    err = AuthoriseError.wrongUsernameOrPassword;
                    break;
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during connection!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = AuthoriseError.unknownError;
            }
        }
        catch(Exception e){
            logWarn("S --- C | exception connecting to the PostgreSQL!");
            logWarn("S --- C | %s", e);
            err = AuthoriseError.unknownError;
        }
        // every path above must have assigned a real outcome
        assert( err != AuthoriseError.assertCheck);
        return db;
    }
}

/// A session on a PostgreSQL database: owns a connection and runs statements on it.
class Database
{
    /**
     * Connects to the server and locks a connection for this instance.
     *
     * Params:
     *  host = server host
     *  database = database name
     *  user = PostgreSQL role
     *  password = role password
     *  port = server port; 0 means "do not pass a port, let the driver use its default"
     */
    private this(string host, string database, string user, string password, ushort port = 0){
        if( db is null ){
            auto params = ["host": host, "database": database, "user": user, "password": password];
            if( port != 0 ) params["port"] = port.to!string;
            db = new PostgresDB(params);
        }
        conn = db.lockConnection();
    }

    /**
     * Runs a select of the first column of the first table of `select`, discarding the rows.
     * The result set is closed before returning, as done everywhere else in this module.
     */
    void execute(const Select select)
    {
        string s = `select %s from %s`.format(select.cols[0].name, select.tables[0].name);
        auto q = conn.executeQuery(s);
        // NOTE(review): ddb appears to require the result set to be closed before the
        // connection can run another command -- confirm against the driver.
        scope(exit) q.close();
    }

    /// Runs the raw `sql` and prints every returned row on stdout (debug helper).
    void executeUnsafe(string sql){
        auto q = conn.executeQuery(sql);
        scope(exit) q.close();
        foreach(v; q){
            import std.stdio; writeln("-->", v);
        }
    }

    /// Runs the raw `sql` and returns its scalar result converted to `T`.
    T executeScalarUnsafe(T)(string sql){
        return conn.executeScalar!T(sql);
    }

    /// Used by sync for the subscription of complex queries. The caller owns the returned result set.
    auto executeQueryUnsafe(string sql){
        return conn.executeQuery(sql);
    }

    /**
     * Used by sync for the subscription of complex queries, with parameters.
     *
     * `sql` references the parameters by name (`$name`); names are mapped to positional
     * placeholders by `xxx`. Only `string`, `short` and `int` values are supported,
     * anything else trips an assert. The caller owns the returned result set.
     */
    auto executeQueryUnsafe(string sql, Variant[string] parameters){
        // ... sort param names, transform names into a sequence of $1, $2
        auto pgargs = xxx(sql, parameters);
        // ... prepare the statement
        auto cmd = new PGCommand(conn, pgargs[0]);
        foreach(j, param; pgargs[1]){
            // ... try to guess the PGType from the Variant typeinfo ...
            auto pgType = toPGType(param.type);
            switch(pgType) with (PGType){
                case TEXT:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!string;
                    break;
                case INT2:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!short;
                    break;
                case INT4:
                    cmd.parameters.add((j+1).to!short, pgType).value = param.get!int;
                    break;
                default:
                    assert(false, pgType.to!string);
            }
        }
        return cmd.executeQuery();
    }
    version(unittest_starwars){ unittest {
        auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
        auto recordSet = db.executeQueryUnsafe("select * from planets where name = $name", ["name": Variant("Tatooine")]);
        scope(exit) recordSet.close();
        assert(recordSet.front[1].get!long == 120_000);
    }}

    /// Runs the raw `sql`, returning the rows typed as `Row`. The caller owns the returned result set.
    auto executeQueryUnsafe(Row)(string sql){
        return conn.executeQuery!Row(sql);
    }

    /**
     * Inserts `record` in `table` and returns the row as stored by the server
     * (`returning *`). Serial columns of `record` are not sent.
     *
     * Params:
     *  record = the values to insert
     *  err = receives `inserted`, `duplicateKeyViolations` (SQLSTATE 23505) or `unknownError`
     *
     * Returns: the inserted row, or `Row.init` when the server rejected the insert.
     */
    auto executeInsert(immutable(Table) table, Row, )(Row record, ref InsertError err){
        enum sql = insertIntoReturningParameter(table);
        auto cmd = new PGCommand(conn, sql);
        addParameters!(table, Row, true)(cmd, record); // skip serial parameters
        Row result;
        try {
            auto querySet = cmd.executeQuery!Row();
            scope(exit) querySet.close();
            result = querySet.front;
            err = InsertError.inserted;
        }
        catch(ServerErrorException e){
            switch(e.code){
                case "23505": //  duplicate key value violates unique constraint "<constraintname>" (for example in primary keys)
                    err = InsertError.duplicateKeyViolations;
                    break;
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during insertion!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = InsertError.unknownError;
            }
        }
        return result;
    }

    /**
     * Deletes from `table` the row identified by the primary key `pk`.
     *
     * Params:
     *  pk = the primary key values, in primary key column order
     *  err = receives `deleted`, or `unknownError` on any server error
     */
    void executeDelete(immutable(Table) table, Pk)(Pk pk, ref DeleteError err){
        enum sql = deleteFromParameter(table);
        auto cmd = new PGCommand(conn, sql);

        addParameters!table(cmd, pk);
        try {
            cmd.executeNonQuery();
            err = DeleteError.deleted;
        }
        catch(ServerErrorException e){
            switch(e.code){
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during deletion!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = DeleteError.unknownError;
            }
        }
    }

    /**
     * Updates the row of `table` identified by `pk` with the values of `record`.
     *
     * Params:
     *  pk = the primary key values selecting the row
     *  record = the new values for all the columns
     *  state = set to `rejectedAsForeignKeyViolation` (SQLSTATE 23503) or
     *          `rejectedAsPGSqlError` on failure; left untouched on success
     */
    void executeUpdate(immutable(Table) table, Pk, Row)(Pk pk, Row record, ref RequestState state){
        enum sql = updateFromParameters(table);
        auto cmd = new PGCommand(conn, sql);
        addParameters!(table)(cmd, record);
        // the where-clause placeholders continue after the ones used by the set clause
        short i = record.tupleof.length +1;
        addParameters!table(cmd, pk, i);
        try {
            cmd.executeNonQuery();
        }
        catch(ServerErrorException e){
            switch(e.code){
                case "23503":
                    logInfo("S --- C | PostgreSQL can't update the primary key as still referenced (maybe add an update cascade?).");
                    state = RequestState.rejectedAsForeignKeyViolation;
                    break;
                default:
                    logWarn("S --- C | Unhandled PostgreSQL server error during update!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    state = RequestState.rejectedAsPGSqlError;
            }
        }
    }

    //private {
        private PostgresDB db;
        public PGConnection conn;
    //}
}


private {
    import mars.lexer;
    import mars.sqldb;

    /// Maps the runtime type held by a `Variant` to the PostgreSQL wire type; asserts on unsupported types.
    PGType toPGType(TypeInfo t){
        if(t == typeid(bool)) return PGType.BOOLEAN;
        if(t == typeid(int)) return PGType.INT4;
        if(t == typeid(short)) return PGType.INT2;
        if(t == typeid(string)) return PGType.TEXT;
        if(t == typeid(float)) return PGType.FLOAT4;
        if(t == typeid(double)) return PGType.FLOAT8;
        if(t == typeid(ubyte[])) return PGType.BYTEA;
        assert(false, t.to!string);
    }

    /// Maps a mars column `Type` to the PostgreSQL wire type; asserts on the types not implemented yet.
    PGType toPGType(Type t){
        final switch(t) with(Type) {
            case boolean: return PGType.BOOLEAN;
            case integer: return PGType.INT4; // XXX check
            case bigint: return PGType.INT8;
            case smallint: return PGType.INT2; // XXX check
            case text: return PGType.TEXT;
            case real_: return PGType.FLOAT4;
            case doublePrecision: return PGType.FLOAT8;
            case bytea: return PGType.BYTEA;
            case smallserial: return PGType.INT2; // XXX check
            case serial: return PGType.INT4; // there's not really a serial type in postgres

            case unknown:
            case date:
            case varchar: // varchar(n), tbd as column
                assert(false, t.to!string); // not implemented right now, catch at CT
        }
    }

    /**
     * Binds the fields of `s` to `cmd`, one template instantiation per field.
     *
     * `s` is treated as a full row of `table` when it matches (or has as many fields as)
     * `asStruct!table`, otherwise as a primary key when it matches `asPkStruct!table`.
     *
     * Params:
     *  table = the table the struct refers to
     *  noSerials = when true, `serial` / `smallserial` fields are not bound
     *  tupleofIndex = index of the field handled by this instantiation
     *  cmd = the command receiving the parameters
     *  s = the values
     *  paramIndex = placeholder index for the next bound value; it advances only when a
     *               value is actually bound, so the indices stay consecutive like the
     *               `$n` placeholders emitted by the statement builders
     */
    void addParameters(immutable(Table) table, Struct, bool noSerials = false, short tupleofIndex =0)(PGCommand cmd, Struct s, short paramIndex =1){
        static if( is(Struct : asStruct!table) || Struct.tupleof.length == asStruct!(table).tupleof.length )
        {
            auto type = table.columns[tupleofIndex].type;
        }
        else static if( is(Struct : asPkStruct!table) || Struct.tupleof.length == asPkStruct!(table).tupleof.length )
        {
            auto type = table.pkCols[tupleofIndex].type;
        }
        else static assert(false);

        static if( noSerials ) bool mustAdd = type != Type.serial && type != Type.smallserial;
        else bool mustAdd = true;

        if( mustAdd ){
            cmd.parameters.add(paramIndex, type.toPGType).value = s.tupleof[tupleofIndex];
            ++paramIndex;
        }

        static if( s.tupleof.length > tupleofIndex+1 ) addParameters!(table, Struct, noSerials, tupleofIndex +1)(cmd, s, paramIndex);
    }


    /// Renders `stat` as `select <cols> from <tables>`, using the bare column and table names.
    string select(const(Select) stat){
        return `select %s from %s`.format(
            stat.cols.map!((c) => c.name).join(", "),
            stat.tables.map!( (t) => t.name ).join(", "),  /// XXX the schema name is needed HERE... refactoring necessary
            );
    }
    unittest {
        auto s = starwarsSchema();
        const sql = cast(Select)Parser([s], scan("select name from sw.people")).parse();
        assert(select(sql) == "select name from people", select(sql));
    }
    version(unittest_starwars){
        unittest {
            enum pub = starwarsSchema();
            enum tokens = scan("select name from sw.people");
            static const stat = Parser([pub], tokens).parse();
            auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
            db.execute(cast(Select)stat);
        }
        unittest {
            // check bigint select
            enum pub = starwarsSchema();
            enum tokens = scan("select population from sw.planets");
            static const stat = Parser([pub], tokens).parse();
            auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
            db.execute(cast(Select)stat);
        }
    }
    else {
        version(unittest){
            pragma(msg, "compile with version 'unittest_starwars' to activate postgresql starwars tests.");
        }
    }

    /**
     * Turns named placeholders into positional ones.
     *
     * The parameter names are sorted; the i-th name (1-based) becomes `$i` and its value
     * is stored at position i-1 of the returned array.
     *
     * Returns: a tuple with the rewritten sql and the values in placeholder order.
     */
    auto xxx(string sql, Variant[string] parameters){
        auto names = parameters.keys;
        sort(names);

        Variant[] pgparam;
        foreach(name; names){
            pgparam ~= parameters[name];
        }

        // Substitute the longest names first: replacing `$id` before `$id2` would
        // rewrite the head of the latter and produce `$12`.
        size_t[] byLength;
        foreach(position; 0 .. names.length) byLength ~= position;
        sort!((a, b) => names[a].length > names[b].length)(byLength);

        foreach(position; byLength){
            sql = sql.replace("$" ~ names[position], "$" ~ (position + 1).to!string); // they are starting from $1, and not from $0
        }
        return tuple(sql, pgparam);
    }
    unittest {
        auto r = xxx("select * from planets where name=$name", ["name": Variant("Tatooine")]);
        assert(r[0] == "select * from planets where name=$1", r[0]);
        assert(r[1].length == 1 && r[1][0].get!string == "Tatooine");

        // a name that is a prefix of another one must not clobber it
        auto p = xxx("select * from t where a=$id and b=$id2", ["id": Variant(1), "id2": Variant(2)]);
        assert(p[0] == "select * from t where a=$1 and b=$2", p[0]);
        assert(p[1][0].get!int == 1 && p[1][1].get!int == 2);
    }
}