diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/clojure_sql/compiler.clj | 63 | ||||
-rw-r--r-- | src/clojure_sql/dsl.clj | 176 | ||||
-rw-r--r-- | src/clojure_sql/util.clj | 3 |
3 files changed, 133 insertions, 109 deletions
diff --git a/src/clojure_sql/compiler.clj b/src/clojure_sql/compiler.clj index ecd7bd6..0022256 100644 --- a/src/clojure_sql/compiler.clj +++ b/src/clojure_sql/compiler.clj @@ -39,9 +39,6 @@ ;; Utility functions for the compile-* functions ;; ============================================================== - -(def ^:private named? (some-fn string? symbol? keyword?)) - (def quote? (comp boolean '#{"quote"} name)) (def unary? (comp boolean '#{"not" "exists"} name)) (def binary? (comp boolean '#{"=" "<" ">" "<=" ">=" "is" "in"} name)) @@ -115,36 +112,27 @@ (defmulti compile-fields (fn [db _] db)) (defmethod compile-fields :default [db fields-map] (if (seq fields-map) - (->> (for [[field alias] fields-map] + (->> (for [[alias field] fields-map] (make-field-name db field alias)) (apply sequence) ((p-lift string/join ", "))) (return "*"))) -(defmulti compile-tables (fn [db _] db)) -(defmethod compile-tables :default [db tables-map] - (->> (for [[table alias] tables-map] - (make-table-name db table alias)) - (apply sequence) - ((p-lift string/join ", ")))) - -(defmulti compile-join (fn [db _] db)) -(defmethod compile-join :default [db [type table-map on]] - ($str (return (case type - :left " LEFT OUTER" - :right " RIGHT OUTER" - " INNER")) - (return " JOIN ") - (compile-tables db table-map) - (return " ON ") - (compile-expression db on))) - -(defmulti compile-joins (fn [db _] db)) -(defmethod compile-joins :default [db joins] - (->> joins - (map (partial compile-join db)) - (apply sequence) - ((p-lift string/join "")))) +(defmulti compile-tables (fn [db _ _] db)) +(defmethod compile-tables :default [db join tables-map] + (if (vector? join) + (->> (for [table-alias join] + (make-table-name db (get tables-map table-alias) table-alias)) + (apply sequence) + ((p-lift string/join ", "))) + (let [{:keys [left right type on]} join] + ($str (return "(") + (compile-tables db left tables-map) + (return (str " " (name type) " JOIN ")) + (compile-tables db right tables-map) + (return " ON ") + (compile-expression db on) + (return ")"))))) (defmulti compile-where (fn [db _] db)) (defmethod compile-where :default [db expr] @@ -156,20 +144,21 @@ (defmethod compile-sort-by :default [db fields] (if fields (->> (for [[[table field] dir] fields] - (str (make-field-name db table field) \space (string/upper-case (name dir)))) - (string/join ",") - (apply $str (return " ORDER BY ")) - return) + ($str (make-field-name db [table field]) + (return (str \space (string/upper-case (name dir)))))) + (apply sequence) + ((p-lift string/join ",")) + ($str (return " ORDER BY "))) (return nil))) (defmulti compile-query (fn [db _] db)) -(defmethod compile-query :default [db {:keys [table fields joins where sort-by]}] +(defmethod compile-query :default [db {:keys [tables fields joins where sort-by]}] ($str (return "SELECT ") (compile-fields db fields) - (if table - (return " FROM ")) - (compile-tables db table) - (compile-joins db joins) + (if tables + ($str (return " FROM ") + (compile-tables db joins tables)) + ($str "")) (compile-where db where) (compile-sort-by db sort-by))) diff --git a/src/clojure_sql/dsl.clj b/src/clojure_sql/dsl.clj index 68446d1..5cf7f59 100644 --- a/src/clojure_sql/dsl.clj +++ b/src/clojure_sql/dsl.clj @@ -1,7 +1,7 @@ (ns clojure-sql.dsl (:refer-clojure :exclude [sort-by]) (:require [clojure.set :as set] - [clojure.walk] + [clojure.walk :as walk] [clojure-sql.query :as q] [clojure-sql.util :as u])) @@ -29,8 +29,10 @@ (defn table [arg] (into (q/->Query) (if (map? arg) - {:tables (u/flip-map arg)} - {:tables {arg arg}}))) + {:tables (u/flip-map arg) + :joins (vec (vals arg))} + {:tables {arg arg} + :joins [arg]}))) (defn ^:private ambiguous-error [field & [query]] (throw (ex-info (str "Ambiguous field " field " in query with more than one table") @@ -45,8 +47,8 @@ (ambiguous-error field))))) (defn ^:private resolve-fields [table aliases expression] - (cond (list? expression) (map (partial resolve-fields table aliases) expression) - (vector? expression) (mapv (partial resolve-fields table aliases) expression) + (cond ;;(vector? expression) (mapv (partial resolve-fields table aliases) expression) + (sequential? expression) (map (partial resolve-fields table aliases) expression) (keyword? expression) (resolve-field table aliases expression) :else expression)) @@ -54,12 +56,12 @@ (let [table (if (= (count (:tables query)) 1) (-> query :tables first key)) alias-lookup (or (:fields query) {}) - original-name #(resolve-fields table alias-lookup %)] + get-real-name #(resolve-fields table alias-lookup %)] (assoc query - :fields (->> (for [[key val] (if (map? fields) - fields - (zipmap fields fields))] - [val (original-name key)]) + :fields (->> (for [[old-name new-name] (if (map? fields) + fields + (zipmap fields fields))] + [new-name (get-real-name old-name)]) (into {}))))) (defn rename [query field-renames] @@ -89,65 +91,95 @@ (into {}))] (assoc query :fields fields))) -;; (defn ^:private combine-wheres [& wheres] -;; (reduce (fn [acc where] -;; (cond (nil? acc) where -;; (nil? where) acc -;; :else (or (if (and (sequential? where) -;; (= (name (first where)) "and")) -;; `(and ~acc ~@(next where))) -;; (if (and (sequential? acc) -;; (= (name (first acc)) "and")) -;; `(and ~@(next acc) ~where)) -;; `(and ~acc ~where)))) -;; nil wheres)) - -;; (defn join [left right & {:keys [on type]}] -;; (let [joins-vector (or (:join left) []) -;; common-fields (set/intersection (set (vals (:fields left))) -;; (set (vals (:fields right)))) -;; joined-fields (if (= type :right) -;; (merge (->> (:fields left) -;; (filter (comp not common-fields val)) -;; (into {})) -;; (:fields right)) -;; (merge (:fields left) -;; (->> (:fields right) -;; (filter (comp not common-fields val)) -;; (into {})))) -;; implicit-on (if (seq common-fields) -;; (map (fn [field] -;; `(= ~(resolve-field (:table left) (:fields left) field) -;; ~(resolve-field (:table right) (:fields right) field))) -;; common-fields)) -;; on (if on -;; [(resolve-fields nil joined-fields on)]) -;; join-condition (if-let [condition (seq (concat implicit-on on))] -;; `(and ~@condition) -;; true)] -;; (-> left -;; (assoc :fields joined-fields) -;; (assoc :joins (into (conj joins-vector -;; [(or type :inner) (:table right) join-condition]) -;; (:joins left))) -;; (assoc :where (combine-wheres (:where left) (:where right)))))) - -;; (defn select [query expression] -;; (let [table-name (if-not (:joins query) -;; (-> query :table first val)) -;; old-where (:where query) -;; resolved-expression (resolve-fields table-name (:fields query) expression) -;; new-where (combine-wheres old-where resolved-expression)] -;; (assoc query :where new-where))) - -;; (defn sort-by [query fields] -;; (let [table-name (if-not (:joins query) -;; (-> query :table first val)) -;; fields-seq (if (sequential? fields) -;; fields -;; [fields])] -;; (assoc query -;; :sort-by (for [field fields-seq] -;; (if (vector? field) -;; (resolve-field table-name (:fields query) field) -;; [(resolve-field table-name (:fields query) field) :asc]))))) + +(defn ^:private combine-wheres [& wheres] + (reduce (fn [acc where] + (cond (nil? acc) where + (nil? where) acc + :else (or (if (and (sequential? where) + (= (name (first where)) "and")) + `(and ~acc ~@(next where))) + (if (and (sequential? acc) + (= (name (first acc)) "and")) + `(and ~@(next acc) ~where)) + `(and ~acc ~where)))) + nil wheres)) + +(defn join [left right & {:keys [on type]}] + (let [common-tables (set/intersection (set (keys (:tables left))) + (set (keys (:tables right)))) + ;;_ (assert (empty? common-tables) "Cannot join two tables with the same name") + merged-tables (merge (:tables left) (:tables right)) + common-fields (set/intersection (set (keys (:fields left))) + (set (keys (:fields right)))) + merged-fields (merge (:fields left) (:fields right)) + join-condition (cond + (nil? on) (let [implicit (map (fn [field] + `(~'= + ~(resolve-field (:table left) (:fields left) field) + ~(resolve-field (:table right) (:fields right) field))) + common-fields)] + (if (next implicit) + (seq (cons `and implicit)) ;; more than one, so add an "and" around them + (first implicit))) + (seq common-fields) (throw (ex-info "Cannot join with common fields unless natural join" + {:left left + :right right + :common-fields common-fields})) + :else (resolve-fields nil merged-fields on)) + type (or type + (if join-condition :inner) + :cross)] + (-> left + (assoc :fields merged-fields + :tables merged-tables + :joins {:left (:joins left) + :right (:joins right) + :type type + :on join-condition} + :where (combine-wheres (:where left) + (:where right)))))) + +(defn select [query expression] + (let [table-name (if (= (count (:tables query)) 1) + (-> query :tables first key)) + old-where (:where query) + resolved-expression (resolve-fields table-name (:fields query) expression) + new-where (combine-wheres old-where resolved-expression)] + (assoc query :where new-where))) + +(defn sort-by [query fields] + (let [table-name (if (= (count (:tables query)) 1) + (-> query :tables first key)) + fields-seq (if (sequential? fields) + fields + [fields])] + (assoc query + :sort-by (for [field fields-seq] + (if (vector? field) + (resolve-field table-name (:fields query) field) + [(resolve-field table-name (:fields query) field) :asc]))))) + + +(let [id 10] + (->> (-> (table :x) + (project [:x]) + (select `(and (in :x [1 2 3 :y]) + (= :x ~id))) + (join (-> (table :y) + (project [:y])) + :on `(= :x :y)) + (join (-> (table :z) + (project [:x]))) + (sort-by [:x])) + (into {})) + (-> (table :x) + (project [:x]) + (select `(and (in :x [1 2 3 :y]) + (= :x ~id))) + (join (-> (table :y) + (project [:y])) + :on `(= :x :y)) + (join (-> (table :z) + (project [:x]))) + (sort-by [:x]))) diff --git a/src/clojure_sql/util.clj b/src/clojure_sql/util.clj index f2edf30..3dfe948 100644 --- a/src/clojure_sql/util.clj +++ b/src/clojure_sql/util.clj @@ -24,3 +24,6 @@ (apply map-kv (fn [k & vs] [k (apply f vs)]) maps)) + +(defn named? [x] + (some #(% x) [keyword? string? symbol?])) |