summaryrefslogtreecommitdiff
path: root/src/clojure_sql/core.clj
diff options
context:
space:
mode:
authorCarlo Zancanaro <carlo@clearboxsystems.com.au>2013-05-14 12:21:50 +1000
committerCarlo Zancanaro <carlo@clearboxsystems.com.au>2013-05-14 12:21:50 +1000
commitd37dc87a15767fc48a251539875ef28df372a8cd (patch)
tree7ea76847d4cb22886ba2ed4f77b95990da77a2e0 /src/clojure_sql/core.clj
parentd70e99185025eeef545248321c04d885aa6a38c2 (diff)
Fix ordering issue, split out string parameters for jdbc stuff
The string parameters are now put in the query as a '?' and the string which should go in their place is now placed in an auxiliary list when the query is constructed. This should make it easier to avoid SQL injection stuff. (Although table/column names are still vulnerable to SQL injection, they should not be dynamic so the issue should be minimal.) There was also another issue where some things were used before they were declared (as a result of repl development) which has now been corrected.
Diffstat (limited to 'src/clojure_sql/core.clj')
-rw-r--r--src/clojure_sql/core.clj217
1 files changed, 123 insertions, 94 deletions
diff --git a/src/clojure_sql/core.clj b/src/clojure_sql/core.clj
index 2000595..319d281 100644
--- a/src/clojure_sql/core.clj
+++ b/src/clojure_sql/core.clj
@@ -2,7 +2,8 @@
(:refer-clojure :exclude [sort-by])
(:require [clojure.set :as set]
[clojure.string :as string]
- [clojure-sql.util :as u]))
+ [clojure-sql.util :as u]
+ [clojure.walk]))
(defn add-parentheses [s]
(str \( s \)))
@@ -19,55 +20,44 @@
(defmethod sql-string :default [_ string]
(str \' (string/replace string "'" "''") \'))
+
+
;; compile-* multimethods are of the signature:
;; (db, expr) -> (SQL, replacements)
(def is-unary? (comp boolean '#{not}))
(def is-predicate? (comp boolean '#{= < > <= >= is in}))
-(defn c-lift [f]
- (fn [& args]
- (let [seconds (map second args)])
- (apply f (map first args))))
-(defn c-str [elements]
- (reduce (fn [[string args] [new-string new-args]]
- [(str string new-string)
- (vec concat args new-args)])
- ["" nil] elements))
+(defn c-return [val] [val nil])
+(defn c-lift [f & orig-args]
+ (fn [& c-args]
+ [(apply f (concat orig-args
+ (map first c-args)))
+ (apply concat (map second c-args))]))
+(def c-str (c-lift str))
+(def c-join (fn [sep args]
+ [(string/join sep (map first args))
+ (apply concat (map second args))]))
+(def c-add-parentheses (c-lift add-parentheses))
(defmulti compile-expression (fn [db _] db))
(defmethod compile-expression :default [db ex]
(condp #(%1 %2) ex
- nil? "NULL"
- vector? (str (table-name db (first ex)) \. (field-name db (second ex)))
- keyword? (field-name db ex)
- string? (sql-string db ex)
- symbol? (string/upper-case (name ex))
- sequential? (if (= (count ex) 2)
- (->> (second ex)
- (compile-expression db)
- (str (compile-expression db (first ex)) " ")
- add-parentheses)
- (->> (rest ex)
- (map (partial compile-expression db))
- (interpose (compile-expression db (first ex)))
- (string/join " ")
- add-parentheses))
- ex))
-
-(defmulti compile-join (fn [db _] db))
-(defmethod compile-join :default [db [type table-map on]]
- (str (case type
- :left " LEFT OUTER "
- :right " RIGHT OUTER "
- "")
- " JOIN " (compile-tables db table-map) " ON " (compile-expression db on)))
-
-(defmulti compile-joins (fn [db _] db))
-(defmethod compile-joins :default [db joins]
- (->> joins
- (map (partial compile-join db))
- (string/join " ")))
+ nil? (c-return "NULL")
+ vector? (c-return (str (table-name db (first ex)) \. (field-name db (second ex))))
+ keyword? (c-return (field-name db ex))
+ string? ["?" [ex]] ;;(sql-string db ex)
+ symbol? (c-return (string/upper-case (name ex)))
+ sequential? (-> (if (= (count ex) 2)
+ (->> ex
+ (map (partial compile-expression db))
+ (c-join " "))
+ (->> (rest ex)
+ (map (partial compile-expression db))
+ (interpose (compile-expression db (first ex)))
+ (c-join " ")))
+ c-add-parentheses)
+ (c-return ex)))
(defn ^:private make-table-name [db table & [alias]]
(if (or (= table alias) (nil? alias))
@@ -84,19 +74,38 @@
(if (seq fields-map)
(->> (for [[[table field] alias] fields-map]
(make-field-name db table field alias))
- (string/join ", "))
- "*"))
+ (string/join ", ")
+ c-return)
+ (c-return "*")))
(defmulti compile-tables (fn [db _] db))
(defmethod compile-tables :default [db tables-map]
(->> (for [[table alias] tables-map]
(make-table-name db table alias))
- (string/join ", ")))
+ (string/join ", ")
+ c-return))
+
+(defmulti compile-join (fn [db _] db))
+(defmethod compile-join :default [db [type table-map on]]
+ (c-str (c-return (case type
+ :left " LEFT OUTER "
+ :right " RIGHT OUTER "
+ ""))
+ (c-return " JOIN ")
+ (compile-tables db table-map)
+ (c-return " ON ")
+ (compile-expression db on)))
+
+(defmulti compile-joins (fn [db _] db))
+(defmethod compile-joins :default [db joins]
+ (->> joins
+ (map (partial compile-join db))
+ (c-join " ")))
(defmulti compile-where (fn [db _] db))
(defmethod compile-where :default [db expr]
(if expr
- (str " WHERE " (compile-expression db expr))))
+ (c-str (c-return " WHERE ") (compile-expression db expr))))
(defmulti compile-sort-by (fn [db _] db))
(defmethod compile-sort-by :default [db fields]
@@ -104,17 +113,18 @@
(->> (for [[[table field] dir] fields]
(str (make-field-name db table field) \space (string/upper-case (name dir))))
(string/join ",")
- (apply str " ORDER BY "))))
+ (apply str " ORDER BY ")
+ c-return)))
(defmulti compile-query (fn [db _] db))
(defmethod compile-query :default [db {:keys [table fields joins where sort-by]}]
- (str "SELECT "
- (compile-fields db fields)
- " FROM "
- (compile-tables db table)
- (compile-joins db joins)
- (compile-where db where)
- (compile-sort-by db sort-by)))
+ (c-str (c-return "SELECT ")
+ (compile-fields db fields)
+ (c-return " FROM ")
+ (compile-tables db table)
+ (compile-joins db joins)
+ (compile-where db where)
+ (compile-sort-by db sort-by)))
@@ -168,14 +178,18 @@
{:table {arg arg}})))
(defn project [query fields]
- (let [table (-> query :table first val)
+ (let [table (if-not (:joins query)
+ (-> query :table first val))
alias-lookup (->> query :fields
(map (comp vec reverse))
(into {}))
original-name (fn [field]
(if (vector? field)
field
- (get alias-lookup field [table field])))]
+ (or (get alias-lookup field nil)
+ (if table
+ [table field]
+ (throw (RuntimeException. (str "Ambiguous field " field " cannot be resolved")))))))]
(assoc query
:fields (->> (for [[key val] (if (map? fields)
fields
@@ -205,7 +219,8 @@
(let [field-alias-lookup (u/flip-map aliases)]
(or (field-alias-lookup field)
(if table
- [table field]))))
+ [table field]
+ (throw (RuntimeException. (str "Ambiguous field " field " cannot be resolved")))))))
(defn ^:private resolve-fields [table aliases expression]
(clojure.walk/postwalk (fn [expr]
@@ -238,7 +253,8 @@
(assoc :where (combine-wheres (:where left) (:where right))))))
(defn select [query expression]
- (let [table-name (-> query :table first val)
+ (let [table-name (if-not (:joins query)
+ (-> query :table first val))
old-where (:where query)
resolved-expression (resolve-fields table-name (:fields query) expression)
new-where (combine-wheres old-where resolved-expression)]
@@ -249,7 +265,8 @@
field-names-seq (map (fn [x] (if (vector? x) (first x) x))
(if (sequential? fields) fields [fields]))]
(every? flipped-query-fields field-names-seq))]}
- (let [table-name (-> query :table first val)
+ (let [table-name (if-not (:joins query)
+ (-> query :table first val))
fields-seq (if (sequential? fields)
fields
[fields])]
@@ -260,40 +277,52 @@
[(resolve-field table-name (:fields query) field) :asc])))))
-(binding [*database-type* :mysql]
- (let [users (-> (table :users)
- (project [:id :username :password])
- (select '(= :deleted false)))
- people (-> (table :people)
- (project [:id :fname :sname])
- (select '(= :deleted false)))
- uid-pid-match '(= :uid :pid)
- is-carlo `(= :username "Carlo")
- query (-> (join (-> users
- (rename {:id :uid}))
- (join (-> people
- (rename {:id :pid}))
- (-> (table {:others :o})
- (project {:id :oid}))
- '(= :pid :oid))
- uid-pid-match)
- (select is-carlo)
- (project [:fname :sname :oid]))]
- @query))
-
-(-> (table {:nodes :child})
- (project [:parent-id, :name])
- (rename {:name :child.name})
- (join (-> (table {:nodes :parent})
- (project [:id, :name])
- (rename {:name :parent.name}))
- '(= :parent-id :id))
- (project [:child.name :parent.name])
- deref #_println)
-
-
-(deref (-> (table :anotherStack)
- (project [:anotherNumber])
- (join (-> (table :collection)
- (project [:number]))
- true)))
+
+(comment
+
+ (binding [*database-type* :mysql]
+ (let [users (-> (table :users)
+ (project [:id :username :password])
+ (select '(= :deleted false)))
+ people (-> (table :people)
+ (project [:id :fname :sname])
+ (select '(= :deleted false)))
+ uid-pid-match '(= :uid :pid)
+ is-carlo `(= :fname "Carlo")
+ query (-> (join (-> users
+ (rename {:id :uid}))
+ (join (-> people
+ (rename {:id :pid}))
+ (-> (table {:others :o})
+ (project {:id :oid}))
+ '(= :pid :oid))
+ uid-pid-match)
+ (select is-carlo)
+ (project [:fname :sname :oid]))]
+ @query))
+
+ (-> (table :users)
+ (project [:username])
+ (join (table :something-else-with-a-username)
+ true)
+ (select '(or (= :username "john")
+ (not (= :username "carlo"))))
+ deref)
+
+
+ (-> (table {:nodes :child})
+ (project [:parent-id, :name])
+ (rename {:name :child.name})
+ (join (-> (table {:nodes :parent})
+ (project [:id, :name])
+ (rename {:name :parent.name}))
+ '(= :parent-id :id))
+ (project [:child.name :parent.name])
+ deref #_println)
+
+
+ (deref (-> (table :anotherStack)
+ (project [:anotherNumber])
+ (join (-> (table :collection)
+ (project [:number]))
+ true))))