From 86538b891332936412c1af7e2ecac52a046b4bce Mon Sep 17 00:00:00 2001 From: DanyLE Date: Tue, 31 Jan 2023 13:34:10 +0100 Subject: [PATCH] add sqlite query generator to the core API --- silkmvc/core/sqlite.lua | 246 ++++++++++++++++++++++++++++++++++++++-- test/test_core.lua | 36 ++++++ 2 files changed, 275 insertions(+), 7 deletions(-) diff --git a/silkmvc/core/sqlite.lua b/silkmvc/core/sqlite.lua index c0408c6..c324e69 100644 --- a/silkmvc/core/sqlite.lua +++ b/silkmvc/core/sqlite.lua @@ -17,9 +17,241 @@ sqlite.getdb = function(name) end end --- create class +--- SQL generator +SQLQueryGenerator = Object:extends{} + +function SQLQueryGenerator:initialize() +end + +function SQLQueryGenerator:parse() + local j, w, o, f + j = self:sql_joins() + if self.where then + w = self:sql_where("$and", self.where) + end + f = self:sql_fields() + o = self:sql_order() + return f, w, j, o +end + +function SQLQueryGenerator:sql_select() + local v, f, w, j, o = pcall(SQLQueryGenerator.parse, self) + if not v then + return v, f + end + local segments = {"SELECT"} + if f then + table.insert(segments, f) + else + table.insert(segments, "*") + end + table.insert(segments, "FROM") + table.insert(segments, self.table_name) + if j then + table.insert(segments, j) + end + if w then + table.insert(segments, "WHERE") + table.insert(segments, w) + end + + if o then + table.insert(segments, "ORDER BY") + table.insert(segments, o) + end + + return true, table.concat(segments, " ") +end + +function SQLQueryGenerator:sql_delete() + local v, f, w, j, o = pcall(SQLQueryGenerator.parse, self) + if not v then + return v, f + end + local segments = {"DELETE"} + table.insert(segments, "FROM") + table.insert(segments, self.table_name) + if j then + table.insert(segments, j) + end + if w then + table.insert(segments, "WHERE") + table.insert(segments, w) + end + return true, table.concat(segments, " ") +end + +function SQLQueryGenerator:error(msg, ...) + local emsg = string.format(msg or "ERROR", ...) + LOG_ERROR(msg, ...) + error(emsg) +end + +function SQLQueryGenerator:infer_field(k) + if not self.table_name then + self:error("Unknown input table (specified by `table_name` field)") + end + if not self.joins then + return k + end + if k:match("%.") then + return k + end + return string.format("%s.%s", self.table_name, k) +end + +function SQLQueryGenerator:sql_joins() + if not self.joins then + return nil + end + local joins = {} + for k, v in pairs(self.joins) do + local arr = explode(v, ".") + if not arr[2] then + self:error("SQL JOIN: Other table name parsing error: " .. v) + end + table.insert(joins, string.format("INNER JOIN %s ON %s = %s", arr[1], self:infer_field(k), v)) + end + return table.concat(joins, " ") +end + +function SQLQueryGenerator:sql_fields() + if not self.fields then + return nil + end + local arr = {} + for k, v in ipairs(self.fields) do + arr[k] = self:infer_field(v) + end + return string.format("(%s)", table.concat(arr, ",")) +end + +function SQLQueryGenerator:sql_order() + local tb = {} + for k, v in ipairs(self.order) do + local arr = explode(v, "$") + if #arr ~= 2 then + self:error("Invalid field order format %s", v) + end + if arr[2] == "asc" then + table.insert(tb, self:infer_field(arr[1]) .. " ASC") + elseif arr[2] == "desc" then + table.insert(tb, self:infer_field(arr[1]) .. " DESC") + else + self:error("Unknown order %s", arr[2]) + end + end + return table.concat(tb, ",") +end + +function SQLQueryGenerator:sql_where(cond, obj) + if not obj then + self:error("%s condition is nil", cond) + end + local conds = {} + local op = " AND " + if cond == "$or" then + op = " OR " + end + if type(obj) ~= 'table' then + self:error("Invalid input data for operator " .. cond) + end + for k, v in pairs(obj) do + if k == "$and" or k == "$or" then + table.insert(conds, self:sql_where(k, v)) + else + table.insert(conds, self:binary(k, v)) + end + end + + return string.format("(%s)", table.concat(conds, op)) +end + +function SQLQueryGenerator:parse_value(v, types) + if not types[type(v)] then + self:error("Type error: unexpected type %d", type(v)) + end + if type(v) == "number" then + return tostring(v) + end + if type(v) == "string" then + return string.format("'%s'", v:gsub("'", "''")) + end +end +function SQLQueryGenerator:binary(k, v) + local arr = explode(k, "$"); + if #arr > 2 then + self:error("Invalid left hand side format: %s", k) + end + if #arr == 2 then + if arr[2] == "gt" then + return string.format("(%s > %s)", self:infer_field(arr[1]), self:parse_value(v, { + ['number'] = true + })) + elseif arr[2] == "gte" then + return string.format("(%s >= %s)", self:infer_field(arr[1]), self:parse_value(v, { + ['number'] = true + })) + elseif arr[2] == "lt" then + return string.format("(%s < %s)", self:infer_field(arr[1]), self:parse_value(v, { + ['number'] = true + })) + elseif arr[2] == "lte" then + return string.format("(%s <= %s)", self:infer_field(arr[1]), self:parse_value(v, { + ['number'] = true + })) + elseif arr[2] == "ne" then + return string.format("(%s != %s)", self:infer_field(arr[1]), self:parse_value(v, { + ['number'] = true, + ['string'] = true + })) + elseif arr[2] == "between" then + return string.format("(%s BETWEEN %s AND %s)", self:infer_field(arr[1]), self:parse_value(v[1], { + ['number'] = true + }), self:parse_value(v[2], { + ['number'] = true + })) + elseif arr[2] == "not_between" then + return string.format("(%s NOT BETWEEN %s AND %s)", self:infer_field(arr[1]), self:parse_value(v[1], { + ['number'] = true + }), self:parse_value(v[2], { + ['number'] = true + })) + elseif arr[2] == "in" then + return string.format("(%s IN [%s,%s])", self:infer_field(arr[1]), self:parse_value(v[1], { + ['number'] = true + }), self:parse_value(v[2], { + ['number'] = true + })) + elseif arr[2] == "not_in" then + return string.format("(%s NOT IN [%s,%s])", self:infer_field(arr[1]), self:parse_value(v[1], { + ['number'] = true + }), self:parse_value(v[2], { + ['number'] = true + })) + elseif arr[2] == "like" then + return string.format("(%s LIKE %s)", self:infer_field(arr[1]), self:parse_value(v, { + ['string'] = true + })) + elseif arr[2] == "not_like" then + return string.format("(%s NOT LIKE %s)", self:infer_field(arr[1]), self:parse_value(v, { + ['string'] = true + })) + else + self:error("Unsupported operator `%s`", arr[2]) + end + else + return string.format("(%s=%s)", self:infer_field(arr[1]), self:parse_value(v, { + ['number'] = true, + ['string'] = true + })) + end +end + +--- create class DBModel +--- TODO: This class shall use the SQLQueryGenerator to create the query DBModel = Object:inherit{ - db = nil, + db = nil } function DBModel:createTable(name, m) @@ -58,7 +290,7 @@ function DBModel:insert(name, m) end function DBModel:get(name, id) - local records = self:query( string.format("SELECT * FROM %s WHERE id=%d", name, id)) + local records = self:query(string.format("SELECT * FROM %s WHERE id=%d", name, id)) if records and #records == 1 then return records[1] end @@ -66,7 +298,7 @@ function DBModel:get(name, id) end function DBModel:getAll(name) - local data = self:query( "SELECT * FROM " .. name) + local data = self:query("SELECT * FROM " .. name) if not data then return nil end @@ -102,7 +334,7 @@ function DBModel:find(name, cond) -- print(sel) end -- print(cnd) - local data = self:query( string.format("SELECT %s FROM %s WHERE %s", sel, name, cnd)) + local data = self:query(string.format("SELECT %s FROM %s WHERE %s", sel, name, cnd)) if data == nil then return nil end @@ -116,7 +348,7 @@ end function DBModel:query(sql) local data, error = sqlite.query(self.db, sql) - --LOG_DEBUG(sql) + -- LOG_DEBUG(sql) if not data then LOG_ERROR("Error querying recorda SQL[%s]: %s", sql, error or "") return nil @@ -125,7 +357,7 @@ function DBModel:query(sql) end function DBModel:exec(sql) - --LOG_DEBUG(sql) + -- LOG_DEBUG(sql) local ret, err = sqlite.exec(self.db, sql) if not ret then LOG_ERROR("Error execute [%s]: %s", sql, err or "") diff --git a/test/test_core.lua b/test/test_core.lua index d73365a..305e0d3 100644 --- a/test/test_core.lua +++ b/test/test_core.lua @@ -392,5 +392,41 @@ end) test("sha1 encode", function() expect(enc.sha1("this is a test"), "fa26be19de6bff93f70bc2308434e4a440bbad02") end) + +test("SQL Generator", function() + local o = { + table_name = "database", + where = { + ["id$gte"] = 10, + user = "dany'", + ["$or"] = { + email = "test@mail.com", + ["age$ne"] = 30, + ["$and"] = { + ["birth$ne"] = 1986, + ["age$between"] = {10,20}, + ["age$not_in"] = {20,30}, + ["name$like"] = "%LE" + } + } + }, + fields = {'user.name', 'id', 'email'}, + order = {'name$asc', "id$desc"}, + joins = { + cid = 'Category.id', + did ='Country.id' + } + } + local generator = SQLQueryGenerator:new(o) + + local r,sql = generator:sql_select() + assert(r == true, sql) + print(sql) + + r,sql = generator:sql_delete() + assert(r == true, sql) + print(sql) + +end) --- run all unit tests run() \ No newline at end of file