module mars.pgsql;

import std.algorithm;
import std.conv;
import std.format;
import std.string;
import std.range;

import mars.defs;
import mars.msg : AuthoriseError, InsertError, DeleteError;

import ddb.postgres;
import ddb.db;
import vibe.core.log;

/**
 * Builds a parameterised `insert` statement for `table`: one `$n` placeholder
 * per column, in column order, followed by a `returning *` clause.
 */
string insertIntoReturningParameter(const(Table) table)
{
    return "insert into %s values (%s) returning *"
        .format(table.name, iota(0, table.columns.length).map!( (c) => "$" ~ (c+1).to!string).join(", "));
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("baz", Type.text, false)],[],[]).insertIntoReturningParameter();
    assert( sql == "insert into bar values ($1, $2) returning *", sql );
}

/**
 * Builds a parameterised `delete` statement for `table`. The `where` clause
 * compares every primary-key column with `$1`, `$2`, ... in `pkCols` order,
 * joined with `and`.
 */
string deleteFromParameter(const(Table) table)
{
    return "delete from %s where %s".format(
        table.name,
        zip(iota(0, table.pkCols.length), table.pkCols)
            .map!( (t) => t[1].name ~ " = $" ~ (t[0]+1).to!string)
            .join(" and "));
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("bar", Type.text, false), Col("baz", Type.text, false)], [0, 1], []).deleteFromParameter();
    assert( sql == "delete from bar where foo = $1 and bar = $2", sql);
}

/**
 * Builds a parameterised `update` statement for `table`.
 *
 * The `set` clause assigns every column using placeholders `$1` .. `$n`,
 * where n is the number of columns. The `where` clause compares the
 * primary-key columns (or every column, when the table has no primary key)
 * with the placeholders that follow: `$n+1`, `$n+2`, ...
 */
string updateFromParameters(const(Table) table)
{
    immutable(Col)[] whereCols = table.pkCols.length > 0 ? table.pkCols : table.columns;
    // Placeholder numbers are derived from positions instead of a shared counter
    // mutated inside lazy `map` closures, so the result no longer depends on the
    // order in which the arguments of `format` are evaluated.
    size_t firstWhereParam = table.columns.length + 1;
    return "update %s set %s where %s".format(
        table.name,
        table.columns.enumerate(1).map!( (t) => t.value.name ~ " = $" ~ t.index.to!string).join(", "),
        whereCols.enumerate(firstWhereParam).map!( (t) => t.value.name ~ " = $" ~ t.index.to!string).join(" and "));
}
unittest {
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("bar", Type.text, false), Col("baz", Type.text, false)], [0], []).updateFromParameters();
    assert( sql == "update bar set foo = $1, bar = $2, baz = $3 where foo = $4", sql );
}
unittest {
    // Without a primary key every column takes part in the where clause.
    auto sql = Table("bar", [Col("foo", Type.text, false), Col("baz", Type.text, false)], [], []).updateFromParameters();
    assert( sql == "update bar set foo = $1, baz = $2 where foo = $3 and baz = $4", sql );
}

/// Location of a PostgreSQL database; `connect` opens an authenticated `Database` on it.
struct DatabaseService {
    string host;
    ushort port;     /// forwarded to the driver when non-zero; 0 lets the driver use its default
    string database;

    /**
     * Connects to `database` on `host`:`port` as `user`.
     *
     * Params:
     *   user     = role name used for authentication
     *   password = password of `user`
     *   err      = set to `AuthoriseError.authorised` on success, otherwise to the failure reason
     *
     * Returns: an instance of `Database`, or null if it can't connect or authenticate. Error details in `err`.
     */
    Database connect(string user, string password, ref AuthoriseError err) in {
        assert(user && password);
    } body {
        Database db;
        try {
            // `port` used to be dropped here, so a non-default port was silently ignored.
            db = new Database(host, database, user, password, port);
            err = AuthoriseError.authorised;
        }
        catch(ServerErrorException e){
            switch(e.code){
                case "28P01": // password authentication failed for user "user"
                    logInfo("PostgreSQL password authentication failed for user");
                    err = AuthoriseError.wrongUsernameOrPassword;
                    break;
                default:
                    logWarn("S -- C | Unhandled PostgreSQL server error during connection!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = AuthoriseError.unknownError;
            }
        }
        catch(Exception e){
            logWarn("S --- C | exception connecting to the PostgreSQL!");
            logWarn("S --- C | %s", e);
            err = AuthoriseError.unknownError;
        }
        assert( err != AuthoriseError.assertCheck);
        return db;
    }
}

/// A session on a PostgreSQL database, owning a connection pool and one connection locked from it.
class Database
{
    /**
     * Creates the connection pool and locks a connection from it.
     * `port` is added to the connection parameters only when non-zero.
     */
    private this(string host, string database, string user, string password, ushort port = 0){
        string[string] params = ["host": host, "database": database, "user": user, "password": password];
        if( port != 0 ) params["port"] = port.to!string;
        db = new PostgresDB(params);
        conn = db.lockConnection();
    }

    /// Runs the `select` statement; the returned rows are discarded.
    void execute(const Select select)
    {
        // The leading dot selects the module-level renderer instead of the parameter
        // of the same name; it includes every column and table, not only the first.
        string s = .select(select);
        auto q = conn.executeQuery(s);
        // Close the result set (as executeInsert does) so the connection is not
        // left with a pending, unread result.
        scope(exit) q.close();
    }

    /// Runs raw `sql` and prints every returned row to stdout (debug helper; no escaping).
    void executeUnsafe(string sql){
        auto q = conn.executeQuery(sql);
        scope(exit) q.close();
        foreach(v; q){
            import std.stdio; writeln("-->", v);
        }
    }

    /// Runs raw `sql` and returns its single scalar result as `T` (no escaping).
    T executeScalarUnsafe(T)(string sql){
        return conn.executeScalar!T(sql);
    }

    /// Runs raw `sql` and returns the result set; the caller is responsible for closing it.
    auto executeQueryUnsafe(string sql){
        return conn.executeQuery(sql);
    }

    /// Runs raw `sql` and returns a result set typed as `Row`; the caller is responsible for closing it.
    auto executeQueryUnsafe(Row)(string sql){
        return conn.executeQuery!Row(sql);
    }

    /**
     * Inserts `record` into `table` and returns the inserted row (`returning *`).
     * `err` is `InsertError.inserted` on success, `duplicateKeyViolations` on a
     * unique-constraint violation (SQLSTATE 23505), `unknownError` on any other
     * server error; in the error cases the returned row is `Row.init`.
     */
    auto executeInsert(immutable(Table) table, Row, )(Row record, ref InsertError err){
        enum sql = insertIntoReturningParameter(table);
        auto cmd = new PGCommand(conn, sql);

        addParameters!table(cmd, record);
        Row result;
        try {
            auto querySet = cmd.executeQuery!Row();
            scope(exit) querySet.close();
            result = querySet.front;
            err = InsertError.inserted;
        }
        catch(ServerErrorException e){
            switch(e.code){
                case "23505": // duplicate key value violates unique constraint "<constraintname>" (for example in primary keys)
                    err = InsertError.duplicateKeyViolations;
                    break;
                default:
                    logWarn("S -- C | Unhandled PostgreSQL server error during insertion!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = InsertError.unknownError;
            }
        }
        return result;
    }

    /**
     * Deletes the row of `table` identified by the primary-key values in `pk`.
     * `err` is `DeleteError.deleted` on success, `unknownError` on any server error.
     */
    void executeDelete(immutable(Table) table, Pk)(Pk pk, ref DeleteError err){
        enum sql = deleteFromParameter(table);
        auto cmd = new PGCommand(conn, sql);

        addParameters!table(cmd, pk);
        try {
            cmd.executeNonQuery();
            err = DeleteError.deleted;
        }
        catch(ServerErrorException e){
            switch(e.code){
                default:
                    logWarn("S -- C | Unhandled PostgreSQL server error during deletion!");
                    logInfo("S --- C | PostgreSQL server error: %s", e.toString);
                    err = DeleteError.unknownError;
            }
        }
    }

    /**
     * Overwrites the row of `table` identified by `pk` with the values of `record`.
     * The fields of `record` bind to `$1`..`$n` and the fields of `pk` to the
     * placeholders that follow, matching the statement from `updateFromParameters`.
     * Server errors propagate to the caller.
     */
    void executeUpdate(immutable(Table) table, Pk, Row)(Pk pk, Row record){
        enum sql = updateFromParameters(table);
        auto cmd = new PGCommand(conn, sql);
        addParameters!(table)(cmd, record);
        short i = record.tupleof.length +1;
        addParameters!table(cmd, pk, i);
        cmd.executeNonQuery();
    }

    private {
        PostgresDB db;
        PGConnection conn;
    }
}


private {
    import mars.lexer;
    import mars.sqldb;


    /// Maps a schema column type to the ddb parameter type used when binding values.
    PGType toPGType(Type t){
        final switch(t) with(Type) {
            case boolean: return PGType.BOOLEAN;
            case integer: return PGType.INT4; // XXX check
            case bigint: return PGType.INT8;
            case smallint: return PGType.INT2; // XXX check
            case text: return PGType.TEXT;
            case real_: return PGType.FLOAT4;
            case doublePrecision: return PGType.FLOAT8;
            case bytea: return PGType.BYTEA;
            case smallserial: return PGType.INT2; // XXX check

            case unknown:
            case date:
            case serial:
            case varchar: // varchar(n), tbd as column
                assert(false, t.to!string); // not implemented right now, catch at CT
        }
    }

    /**
     * Binds the fields of `s` to `cmd` as consecutive parameters starting at
     * `paramIndex`, one field per recursion step starting at field `tupleofIndex`.
     *
     * `Struct` must match either a whole row of `table` (parameter types taken
     * from `table.columns`) or its primary key (types from `table.pkCols`),
     * by implicit conversion or by equal field count.
     */
    void addParameters(immutable(Table) table, Struct, short tupleofIndex =0)(PGCommand cmd, Struct s, short paramIndex =1){
        static if( is(Struct : asStruct!table) || Struct.tupleof.length == asStruct!(table).tupleof.length ){
            cmd.parameters.add(paramIndex, table.columns[tupleofIndex].type.toPGType).value = s.tupleof[tupleofIndex];
        }
        else static if( is(Struct : asPkStruct!table) || Struct.tupleof.length == asPkStruct!(table).tupleof.length ){
            cmd.parameters.add(paramIndex, table.pkCols[tupleofIndex].type.toPGType).value = s.tupleof[tupleofIndex];
        }
        else static assert(false, Struct.stringof ~ " matches neither a row nor the primary key of table " ~ table.name);

        static if( s.tupleof.length > tupleofIndex+1 ) addParameters!(table, Struct, tupleofIndex +1)(cmd, s, ++paramIndex);
    }

    version(unittest){
        /+auto starwarSchema() pure {
            return immutable(Schema)("sw", [
                immutable(Table)("people", [Col("name", Type.text), Col("gender", Type.text)], [0], []),
                immutable(Table)("species", [Col("name", Type.text)], [0], []),
            ]);
        }+/
        import mars.starwars;
    }

    /// Renders `stat` as SQL text: `select <cols> from <tables>`, both lists comma-separated.
    string select(const(Select) stat){
        return `select %s from %s`.format(
            stat.cols.map!((c) => c.name).join(", "),
            stat.tables.map!( (t) => t.name ).join(", "), /// XXX the schema name is needed here... refactoring required
        );
    }
    unittest {
        auto s = starwarsSchema();
        const sql = cast(Select)Parser([s], scan("select name from sw.people")).parse();
        assert(select(sql) == "select name from people", select(sql));
    }

    // The next two unittests need a PostgreSQL server on 127.0.0.1 with a
    // `starwars` database reachable as user `jedi`.
    unittest {
        enum pub = starwarsSchema();
        enum tokens = scan("select name from sw.people");
        static const stat = Parser([pub], tokens).parse();
        auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
        db.execute(cast(Select)stat);
    }
    unittest {
        // check bigint select
        enum pub = starwarsSchema();
        enum tokens = scan("select population from sw.planets");
        static const stat = Parser([pub], tokens).parse();
        auto db = new Database("127.0.0.1", "starwars", "jedi", "force");
        db.execute(cast(Select)stat);
    }
}