diff options
-rw-r--r-- | src/clojure_sql/core.clj | 217 |
1 files changed, 123 insertions, 94 deletions
diff --git a/src/clojure_sql/core.clj b/src/clojure_sql/core.clj index 2000595..319d281 100644 --- a/src/clojure_sql/core.clj +++ b/src/clojure_sql/core.clj @@ -2,7 +2,8 @@ (:refer-clojure :exclude [sort-by]) (:require [clojure.set :as set] [clojure.string :as string] - [clojure-sql.util :as u])) + [clojure-sql.util :as u] + [clojure.walk])) (defn add-parentheses [s] (str \( s \))) @@ -19,55 +20,44 @@ (defmethod sql-string :default [_ string] (str \' (string/replace string "'" "''") \')) + + ;; compile-* multimethods are of the signature: ;; (db, expr) -> (SQL, replacements) (def is-unary? (comp boolean '#{not})) (def is-predicate? (comp boolean '#{= < > <= >= is in})) -(defn c-lift [f] - (fn [& args] - (let [seconds (map second args)]) - (apply f (map first args)))) -(defn c-str [elements] - (reduce (fn [[string args] [new-string new-args]] - [(str string new-string) - (vec concat args new-args)]) - ["" nil] elements)) +(defn c-return [val] [val nil]) +(defn c-lift [f & orig-args] + (fn [& c-args] + [(apply f (concat orig-args + (map first c-args))) + (apply concat (map second c-args))])) +(def c-str (c-lift str)) +(def c-join (fn [sep args] + [(string/join sep (map first args)) + (apply concat (map second args))])) +(def c-add-parentheses (c-lift add-parentheses)) (defmulti compile-expression (fn [db _] db)) (defmethod compile-expression :default [db ex] (condp #(%1 %2) ex - nil? "NULL" - vector? (str (table-name db (first ex)) \. (field-name db (second ex))) - keyword? (field-name db ex) - string? (sql-string db ex) - symbol? (string/upper-case (name ex)) - sequential? (if (= (count ex) 2) - (->> (second ex) - (compile-expression db) - (str (compile-expression db (first ex)) " ") - add-parentheses) - (->> (rest ex) - (map (partial compile-expression db)) - (interpose (compile-expression db (first ex))) - (string/join " ") - add-parentheses)) - ex)) - -(defmulti compile-join (fn [db _] db)) -(defmethod compile-join :default [db [type table-map on]] - (str (case type - :left " LEFT OUTER " - :right " RIGHT OUTER " - "") - " JOIN " (compile-tables db table-map) " ON " (compile-expression db on))) - -(defmulti compile-joins (fn [db _] db)) -(defmethod compile-joins :default [db joins] - (->> joins - (map (partial compile-join db)) - (string/join " "))) + nil? (c-return "NULL") + vector? (c-return (str (table-name db (first ex)) \. (field-name db (second ex)))) + keyword? (c-return (field-name db ex)) + string? ["?" [ex]] ;;(sql-string db ex) + symbol? (c-return (string/upper-case (name ex))) + sequential? (-> (if (= (count ex) 2) + (->> ex + (map (partial compile-expression db)) + (c-join " ")) + (->> (rest ex) + (map (partial compile-expression db)) + (interpose (compile-expression db (first ex))) + (c-join " "))) + c-add-parentheses) + (c-return ex))) (defn ^:private make-table-name [db table & [alias]] (if (or (= table alias) (nil? alias)) @@ -84,19 +74,38 @@ (if (seq fields-map) (->> (for [[[table field] alias] fields-map] (make-field-name db table field alias)) - (string/join ", ")) - "*")) + (string/join ", ") + c-return) + (c-return "*"))) (defmulti compile-tables (fn [db _] db)) (defmethod compile-tables :default [db tables-map] (->> (for [[table alias] tables-map] (make-table-name db table alias)) - (string/join ", "))) + (string/join ", ") + c-return)) + +(defmulti compile-join (fn [db _] db)) +(defmethod compile-join :default [db [type table-map on]] + (c-str (c-return (case type + :left " LEFT OUTER " + :right " RIGHT OUTER " + "")) + (c-return " JOIN ") + (compile-tables db table-map) + (c-return " ON ") + (compile-expression db on))) + +(defmulti compile-joins (fn [db _] db)) +(defmethod compile-joins :default [db joins] + (->> joins + (map (partial compile-join db)) + (c-join " "))) (defmulti compile-where (fn [db _] db)) (defmethod compile-where :default [db expr] (if expr - (str " WHERE " (compile-expression db expr)))) + (c-str (c-return " WHERE ") (compile-expression db expr)))) (defmulti compile-sort-by (fn [db _] db)) (defmethod compile-sort-by :default [db fields] @@ -104,17 +113,18 @@ (->> (for [[[table field] dir] fields] (str (make-field-name db table field) \space (string/upper-case (name dir)))) (string/join ",") - (apply str " ORDER BY ")))) + (apply str " ORDER BY ") + c-return))) (defmulti compile-query (fn [db _] db)) (defmethod compile-query :default [db {:keys [table fields joins where sort-by]}] - (str "SELECT " - (compile-fields db fields) - " FROM " - (compile-tables db table) - (compile-joins db joins) - (compile-where db where) - (compile-sort-by db sort-by))) + (c-str (c-return "SELECT ") + (compile-fields db fields) + (c-return " FROM ") + (compile-tables db table) + (compile-joins db joins) + (compile-where db where) + (compile-sort-by db sort-by))) @@ -168,14 +178,18 @@ {:table {arg arg}}))) (defn project [query fields] - (let [table (-> query :table first val) + (let [table (if-not (:joins query) + (-> query :table first val)) alias-lookup (->> query :fields (map (comp vec reverse)) (into {})) original-name (fn [field] (if (vector? field) field - (get alias-lookup field [table field])))] + (or (get alias-lookup field nil) + (if table + [table field] + (throw (RuntimeException. (str "Ambiguous field " field " cannot be resolved")))))))] (assoc query :fields (->> (for [[key val] (if (map? fields) fields @@ -205,7 +219,8 @@ (let [field-alias-lookup (u/flip-map aliases)] (or (field-alias-lookup field) (if table - [table field])))) + [table field] + (throw (RuntimeException. (str "Ambiguous field " field " cannot be resolved"))))))) (defn ^:private resolve-fields [table aliases expression] (clojure.walk/postwalk (fn [expr] @@ -238,7 +253,8 @@ (assoc :where (combine-wheres (:where left) (:where right)))))) (defn select [query expression] - (let [table-name (-> query :table first val) + (let [table-name (if-not (:joins query) + (-> query :table first val)) old-where (:where query) resolved-expression (resolve-fields table-name (:fields query) expression) new-where (combine-wheres old-where resolved-expression)] @@ -249,7 +265,8 @@ field-names-seq (map (fn [x] (if (vector? x) (first x) x)) (if (sequential? fields) fields [fields]))] (every? flipped-query-fields field-names-seq))]} - (let [table-name (-> query :table first val) + (let [table-name (if-not (:joins query) + (-> query :table first val)) fields-seq (if (sequential? fields) fields [fields])] @@ -260,40 +277,52 @@ [(resolve-field table-name (:fields query) field) :asc]))))) -(binding [*database-type* :mysql] - (let [users (-> (table :users) - (project [:id :username :password]) - (select '(= :deleted false))) - people (-> (table :people) - (project [:id :fname :sname]) - (select '(= :deleted false))) - uid-pid-match '(= :uid :pid) - is-carlo `(= :username "Carlo") - query (-> (join (-> users - (rename {:id :uid})) - (join (-> people - (rename {:id :pid})) - (-> (table {:others :o}) - (project {:id :oid})) - '(= :pid :oid)) - uid-pid-match) - (select is-carlo) - (project [:fname :sname :oid]))] - @query)) - -(-> (table {:nodes :child}) - (project [:parent-id, :name]) - (rename {:name :child.name}) - (join (-> (table {:nodes :parent}) - (project [:id, :name]) - (rename {:name :parent.name})) - '(= :parent-id :id)) - (project [:child.name :parent.name]) - deref #_println) - - -(deref (-> (table :anotherStack) - (project [:anotherNumber]) - (join (-> (table :collection) - (project [:number])) - true))) + +(comment + + (binding [*database-type* :mysql] + (let [users (-> (table :users) + (project [:id :username :password]) + (select '(= :deleted false))) + people (-> (table :people) + (project [:id :fname :sname]) + (select '(= :deleted false))) + uid-pid-match '(= :uid :pid) + is-carlo `(= :fname "Carlo") + query (-> (join (-> users + (rename {:id :uid})) + (join (-> people + (rename {:id :pid})) + (-> (table {:others :o}) + (project {:id :oid})) + '(= :pid :oid)) + uid-pid-match) + (select is-carlo) + (project [:fname :sname :oid]))] + @query)) + + (-> (table :users) + (project [:username]) + (join (table :something-else-with-a-username) + true) + (select '(or (= :username "john") + (not (= :username "carlo")))) + deref) + + + (-> (table {:nodes :child}) + (project [:parent-id, :name]) + (rename {:name :child.name}) + (join (-> (table {:nodes :parent}) + (project [:id, :name]) + (rename {:name :parent.name})) + '(= :parent-id :id)) + (project [:child.name :parent.name]) + deref #_println) + + + (deref (-> (table :anotherStack) + (project [:anotherNumber]) + (join (-> (table :collection) + (project [:number])) + true)))) |