From d37dc87a15767fc48a251539875ef28df372a8cd Mon Sep 17 00:00:00 2001 From: Carlo Zancanaro Date: Tue, 14 May 2013 12:21:50 +1000 Subject: Fix ordering issue, split out string parameters for jdbc stuff The string parameters are now put in the query as a '?' and the string which should go in their place is now placed in an auxiliary list when the query is constructed. This should make it easier to avoid SQL injection stuff. (Although table/column names are still vulnerable to SQL injection, they should not be dynamic so the issue should be minimal.) There was also another issue where some things were used before they were declared (as a result of repl development) which has now been corrected. --- src/clojure_sql/core.clj | 217 +++++++++++++++++++++++++++-------------------- 1 file changed, 123 insertions(+), 94 deletions(-) (limited to 'src/clojure_sql/core.clj') diff --git a/src/clojure_sql/core.clj b/src/clojure_sql/core.clj index 2000595..319d281 100644 --- a/src/clojure_sql/core.clj +++ b/src/clojure_sql/core.clj @@ -2,7 +2,8 @@ (:refer-clojure :exclude [sort-by]) (:require [clojure.set :as set] [clojure.string :as string] - [clojure-sql.util :as u])) + [clojure-sql.util :as u] + [clojure.walk])) (defn add-parentheses [s] (str \( s \))) @@ -19,55 +20,44 @@ (defmethod sql-string :default [_ string] (str \' (string/replace string "'" "''") \')) + + ;; compile-* multimethods are of the signature: ;; (db, expr) -> (SQL, replacements) (def is-unary? (comp boolean '#{not})) (def is-predicate? (comp boolean '#{= < > <= >= is in})) -(defn c-lift [f] - (fn [& args] - (let [seconds (map second args)]) - (apply f (map first args)))) -(defn c-str [elements] - (reduce (fn [[string args] [new-string new-args]] - [(str string new-string) - (vec concat args new-args)]) - ["" nil] elements)) +(defn c-return [val] [val nil]) +(defn c-lift [f & orig-args] + (fn [& c-args] + [(apply f (concat orig-args + (map first c-args))) + (apply concat (map second c-args))])) +(def c-str (c-lift str)) +(def c-join (fn [sep args] + [(string/join sep (map first args)) + (apply concat (map second args))])) +(def c-add-parentheses (c-lift add-parentheses)) (defmulti compile-expression (fn [db _] db)) (defmethod compile-expression :default [db ex] (condp #(%1 %2) ex - nil? "NULL" - vector? (str (table-name db (first ex)) \. (field-name db (second ex))) - keyword? (field-name db ex) - string? (sql-string db ex) - symbol? (string/upper-case (name ex)) - sequential? (if (= (count ex) 2) - (->> (second ex) - (compile-expression db) - (str (compile-expression db (first ex)) " ") - add-parentheses) - (->> (rest ex) - (map (partial compile-expression db)) - (interpose (compile-expression db (first ex))) - (string/join " ") - add-parentheses)) - ex)) - -(defmulti compile-join (fn [db _] db)) -(defmethod compile-join :default [db [type table-map on]] - (str (case type - :left " LEFT OUTER " - :right " RIGHT OUTER " - "") - " JOIN " (compile-tables db table-map) " ON " (compile-expression db on))) - -(defmulti compile-joins (fn [db _] db)) -(defmethod compile-joins :default [db joins] - (->> joins - (map (partial compile-join db)) - (string/join " "))) + nil? (c-return "NULL") + vector? (c-return (str (table-name db (first ex)) \. (field-name db (second ex)))) + keyword? (c-return (field-name db ex)) + string? ["?" [ex]] ;;(sql-string db ex) + symbol? (c-return (string/upper-case (name ex))) + sequential? (-> (if (= (count ex) 2) + (->> ex + (map (partial compile-expression db)) + (c-join " ")) + (->> (rest ex) + (map (partial compile-expression db)) + (interpose (compile-expression db (first ex))) + (c-join " "))) + c-add-parentheses) + (c-return ex))) (defn ^:private make-table-name [db table & [alias]] (if (or (= table alias) (nil? alias)) @@ -84,19 +74,38 @@ (if (seq fields-map) (->> (for [[[table field] alias] fields-map] (make-field-name db table field alias)) - (string/join ", ")) - "*")) + (string/join ", ") + c-return) + (c-return "*"))) (defmulti compile-tables (fn [db _] db)) (defmethod compile-tables :default [db tables-map] (->> (for [[table alias] tables-map] (make-table-name db table alias)) - (string/join ", "))) + (string/join ", ") + c-return)) + +(defmulti compile-join (fn [db _] db)) +(defmethod compile-join :default [db [type table-map on]] + (c-str (c-return (case type + :left " LEFT OUTER " + :right " RIGHT OUTER " + "")) + (c-return " JOIN ") + (compile-tables db table-map) + (c-return " ON ") + (compile-expression db on))) + +(defmulti compile-joins (fn [db _] db)) +(defmethod compile-joins :default [db joins] + (->> joins + (map (partial compile-join db)) + (c-join " "))) (defmulti compile-where (fn [db _] db)) (defmethod compile-where :default [db expr] (if expr - (str " WHERE " (compile-expression db expr)))) + (c-str (c-return " WHERE ") (compile-expression db expr)))) (defmulti compile-sort-by (fn [db _] db)) (defmethod compile-sort-by :default [db fields] @@ -104,17 +113,18 @@ (->> (for [[[table field] dir] fields] (str (make-field-name db table field) \space (string/upper-case (name dir)))) (string/join ",") - (apply str " ORDER BY ")))) + (apply str " ORDER BY ") + c-return))) (defmulti compile-query (fn [db _] db)) (defmethod compile-query :default [db {:keys [table fields joins where sort-by]}] - (str "SELECT " - (compile-fields db fields) - " FROM " - (compile-tables db table) - (compile-joins db joins) - (compile-where db where) - (compile-sort-by db sort-by))) + (c-str (c-return "SELECT ") + (compile-fields db fields) + (c-return " FROM ") + (compile-tables db table) + (compile-joins db joins) + (compile-where db where) + (compile-sort-by db sort-by))) @@ -168,14 +178,18 @@ {:table {arg arg}}))) (defn project [query fields] - (let [table (-> query :table first val) + (let [table (if-not (:joins query) + (-> query :table first val)) alias-lookup (->> query :fields (map (comp vec reverse)) (into {})) original-name (fn [field] (if (vector? field) field - (get alias-lookup field [table field])))] + (or (get alias-lookup field nil) + (if table + [table field] + (throw (RuntimeException. (str "Ambiguous field " field " cannot be resolved")))))))] (assoc query :fields (->> (for [[key val] (if (map? fields) fields @@ -205,7 +219,8 @@ (let [field-alias-lookup (u/flip-map aliases)] (or (field-alias-lookup field) (if table - [table field])))) + [table field] + (throw (RuntimeException. (str "Ambiguous field " field " cannot be resolved"))))))) (defn ^:private resolve-fields [table aliases expression] (clojure.walk/postwalk (fn [expr] @@ -238,7 +253,8 @@ (assoc :where (combine-wheres (:where left) (:where right)))))) (defn select [query expression] - (let [table-name (-> query :table first val) + (let [table-name (if-not (:joins query) + (-> query :table first val)) old-where (:where query) resolved-expression (resolve-fields table-name (:fields query) expression) new-where (combine-wheres old-where resolved-expression)] @@ -249,7 +265,8 @@ field-names-seq (map (fn [x] (if (vector? x) (first x) x)) (if (sequential? fields) fields [fields]))] (every? flipped-query-fields field-names-seq))]} - (let [table-name (-> query :table first val) + (let [table-name (if-not (:joins query) + (-> query :table first val)) fields-seq (if (sequential? fields) fields [fields])] @@ -260,40 +277,52 @@ [(resolve-field table-name (:fields query) field) :asc]))))) -(binding [*database-type* :mysql] - (let [users (-> (table :users) - (project [:id :username :password]) - (select '(= :deleted false))) - people (-> (table :people) - (project [:id :fname :sname]) - (select '(= :deleted false))) - uid-pid-match '(= :uid :pid) - is-carlo `(= :username "Carlo") - query (-> (join (-> users - (rename {:id :uid})) - (join (-> people - (rename {:id :pid})) - (-> (table {:others :o}) - (project {:id :oid})) - '(= :pid :oid)) - uid-pid-match) - (select is-carlo) - (project [:fname :sname :oid]))] - @query)) - -(-> (table {:nodes :child}) - (project [:parent-id, :name]) - (rename {:name :child.name}) - (join (-> (table {:nodes :parent}) - (project [:id, :name]) - (rename {:name :parent.name})) - '(= :parent-id :id)) - (project [:child.name :parent.name]) - deref #_println) - - -(deref (-> (table :anotherStack) - (project [:anotherNumber]) - (join (-> (table :collection) - (project [:number])) - true))) + +(comment + + (binding [*database-type* :mysql] + (let [users (-> (table :users) + (project [:id :username :password]) + (select '(= :deleted false))) + people (-> (table :people) + (project [:id :fname :sname]) + (select '(= :deleted false))) + uid-pid-match '(= :uid :pid) + is-carlo `(= :fname "Carlo") + query (-> (join (-> users + (rename {:id :uid})) + (join (-> people + (rename {:id :pid})) + (-> (table {:others :o}) + (project {:id :oid})) + '(= :pid :oid)) + uid-pid-match) + (select is-carlo) + (project [:fname :sname :oid]))] + @query)) + + (-> (table :users) + (project [:username]) + (join (table :something-else-with-a-username) + true) + (select '(or (= :username "john") + (not (= :username "carlo")))) + deref) + + + (-> (table {:nodes :child}) + (project [:parent-id, :name]) + (rename {:name :child.name}) + (join (-> (table {:nodes :parent}) + (project [:id, :name]) + (rename {:name :parent.name})) + '(= :parent-id :id)) + (project [:child.name :parent.name]) + deref #_println) + + + (deref (-> (table :anotherStack) + (project [:anotherNumber]) + (join (-> (table :collection) + (project [:number])) + true)))) -- cgit v1.2.3