summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/clojure_sql/compiler.clj63
-rw-r--r--src/clojure_sql/dsl.clj176
-rw-r--r--src/clojure_sql/util.clj3
3 files changed, 133 insertions, 109 deletions
diff --git a/src/clojure_sql/compiler.clj b/src/clojure_sql/compiler.clj
index ecd7bd6..0022256 100644
--- a/src/clojure_sql/compiler.clj
+++ b/src/clojure_sql/compiler.clj
@@ -39,9 +39,6 @@
;; Utility functions for the compile-* functions
;; ==============================================================
-
-(def ^:private named? (some-fn string? symbol? keyword?))
-
(def quote? (comp boolean '#{"quote"} name))
(def unary? (comp boolean '#{"not" "exists"} name))
(def binary? (comp boolean '#{"=" "<" ">" "<=" ">=" "is" "in"} name))
@@ -115,36 +112,27 @@
(defmulti compile-fields (fn [db _] db))
(defmethod compile-fields :default [db fields-map]
(if (seq fields-map)
- (->> (for [[field alias] fields-map]
+ (->> (for [[alias field] fields-map]
(make-field-name db field alias))
(apply sequence)
((p-lift string/join ", ")))
(return "*")))
-(defmulti compile-tables (fn [db _] db))
-(defmethod compile-tables :default [db tables-map]
- (->> (for [[table alias] tables-map]
- (make-table-name db table alias))
- (apply sequence)
- ((p-lift string/join ", "))))
-
-(defmulti compile-join (fn [db _] db))
-(defmethod compile-join :default [db [type table-map on]]
- ($str (return (case type
- :left " LEFT OUTER"
- :right " RIGHT OUTER"
- " INNER"))
- (return " JOIN ")
- (compile-tables db table-map)
- (return " ON ")
- (compile-expression db on)))
-
-(defmulti compile-joins (fn [db _] db))
-(defmethod compile-joins :default [db joins]
- (->> joins
- (map (partial compile-join db))
- (apply sequence)
- ((p-lift string/join ""))))
+(defmulti compile-tables (fn [db _ _] db))
+(defmethod compile-tables :default [db join tables-map]
+ (if (vector? join)
+ (->> (for [table-alias join]
+ (make-table-name db (get tables-map table-alias) table-alias))
+ (apply sequence)
+ ((p-lift string/join ", ")))
+ (let [{:keys [left right type on]} join]
+ ($str (return "(")
+ (compile-tables db left tables-map)
+ (return (str " " (name type) " JOIN "))
+ (compile-tables db right tables-map)
+ (return " ON ")
+ (compile-expression db on)
+ (return ")")))))
(defmulti compile-where (fn [db _] db))
(defmethod compile-where :default [db expr]
@@ -156,20 +144,21 @@
(defmethod compile-sort-by :default [db fields]
(if fields
(->> (for [[[table field] dir] fields]
- (str (make-field-name db table field) \space (string/upper-case (name dir))))
- (string/join ",")
- (apply $str (return " ORDER BY "))
- return)
+ ($str (make-field-name db [table field])
+ (return (str \space (string/upper-case (name dir))))))
+ (apply sequence)
+ ((p-lift string/join ","))
+ ($str (return " ORDER BY ")))
(return nil)))
(defmulti compile-query (fn [db _] db))
-(defmethod compile-query :default [db {:keys [table fields joins where sort-by]}]
+(defmethod compile-query :default [db {:keys [tables fields joins where sort-by]}]
($str (return "SELECT ")
(compile-fields db fields)
- (if table
- (return " FROM "))
- (compile-tables db table)
- (compile-joins db joins)
+ (if tables
+ ($str (return " FROM ")
+ (compile-tables db joins tables))
+ ($str ""))
(compile-where db where)
(compile-sort-by db sort-by)))
diff --git a/src/clojure_sql/dsl.clj b/src/clojure_sql/dsl.clj
index 68446d1..5cf7f59 100644
--- a/src/clojure_sql/dsl.clj
+++ b/src/clojure_sql/dsl.clj
@@ -1,7 +1,7 @@
(ns clojure-sql.dsl
(:refer-clojure :exclude [sort-by])
(:require [clojure.set :as set]
- [clojure.walk]
+ [clojure.walk :as walk]
[clojure-sql.query :as q]
[clojure-sql.util :as u]))
@@ -29,8 +29,10 @@
(defn table [arg]
(into (q/->Query)
(if (map? arg)
- {:tables (u/flip-map arg)}
- {:tables {arg arg}})))
+ {:tables (u/flip-map arg)
+ :joins (vec (vals arg))}
+ {:tables {arg arg}
+ :joins [arg]})))
(defn ^:private ambiguous-error [field & [query]]
(throw (ex-info (str "Ambiguous field " field " in query with more than one table")
@@ -45,8 +47,8 @@
(ambiguous-error field)))))
(defn ^:private resolve-fields [table aliases expression]
- (cond (list? expression) (map (partial resolve-fields table aliases) expression)
- (vector? expression) (mapv (partial resolve-fields table aliases) expression)
+ (cond ;;(vector? expression) (mapv (partial resolve-fields table aliases) expression)
+ (sequential? expression) (map (partial resolve-fields table aliases) expression)
(keyword? expression) (resolve-field table aliases expression)
:else expression))
@@ -54,12 +56,12 @@
(let [table (if (= (count (:tables query)) 1)
(-> query :tables first key))
alias-lookup (or (:fields query) {})
- original-name #(resolve-fields table alias-lookup %)]
+ get-real-name #(resolve-fields table alias-lookup %)]
(assoc query
- :fields (->> (for [[key val] (if (map? fields)
- fields
- (zipmap fields fields))]
- [val (original-name key)])
+ :fields (->> (for [[old-name new-name] (if (map? fields)
+ fields
+ (zipmap fields fields))]
+ [new-name (get-real-name old-name)])
(into {})))))
(defn rename [query field-renames]
@@ -89,65 +91,95 @@
(into {}))]
(assoc query :fields fields)))
-;; (defn ^:private combine-wheres [& wheres]
-;; (reduce (fn [acc where]
-;; (cond (nil? acc) where
-;; (nil? where) acc
-;; :else (or (if (and (sequential? where)
-;; (= (name (first where)) "and"))
-;; `(and ~acc ~@(next where)))
-;; (if (and (sequential? acc)
-;; (= (name (first acc)) "and"))
-;; `(and ~@(next acc) ~where))
-;; `(and ~acc ~where))))
-;; nil wheres))
-
-;; (defn join [left right & {:keys [on type]}]
-;; (let [joins-vector (or (:join left) [])
-;; common-fields (set/intersection (set (vals (:fields left)))
-;; (set (vals (:fields right))))
-;; joined-fields (if (= type :right)
-;; (merge (->> (:fields left)
-;; (filter (comp not common-fields val))
-;; (into {}))
-;; (:fields right))
-;; (merge (:fields left)
-;; (->> (:fields right)
-;; (filter (comp not common-fields val))
-;; (into {}))))
-;; implicit-on (if (seq common-fields)
-;; (map (fn [field]
-;; `(= ~(resolve-field (:table left) (:fields left) field)
-;; ~(resolve-field (:table right) (:fields right) field)))
-;; common-fields))
-;; on (if on
-;; [(resolve-fields nil joined-fields on)])
-;; join-condition (if-let [condition (seq (concat implicit-on on))]
-;; `(and ~@condition)
-;; true)]
-;; (-> left
-;; (assoc :fields joined-fields)
-;; (assoc :joins (into (conj joins-vector
-;; [(or type :inner) (:table right) join-condition])
-;; (:joins left)))
-;; (assoc :where (combine-wheres (:where left) (:where right))))))
-
-;; (defn select [query expression]
-;; (let [table-name (if-not (:joins query)
-;; (-> query :table first val))
-;; old-where (:where query)
-;; resolved-expression (resolve-fields table-name (:fields query) expression)
-;; new-where (combine-wheres old-where resolved-expression)]
-;; (assoc query :where new-where)))
-
-;; (defn sort-by [query fields]
-;; (let [table-name (if-not (:joins query)
-;; (-> query :table first val))
-;; fields-seq (if (sequential? fields)
-;; fields
-;; [fields])]
-;; (assoc query
-;; :sort-by (for [field fields-seq]
-;; (if (vector? field)
-;; (resolve-field table-name (:fields query) field)
-;; [(resolve-field table-name (:fields query) field) :asc])))))
+
+(defn ^:private combine-wheres [& wheres]
+ (reduce (fn [acc where]
+ (cond (nil? acc) where
+ (nil? where) acc
+ :else (or (if (and (sequential? where)
+ (= (name (first where)) "and"))
+ `(and ~acc ~@(next where)))
+ (if (and (sequential? acc)
+ (= (name (first acc)) "and"))
+ `(and ~@(next acc) ~where))
+ `(and ~acc ~where))))
+ nil wheres))
+
+(defn join [left right & {:keys [on type]}]
+ (let [common-tables (set/intersection (set (keys (:tables left)))
+ (set (keys (:tables right))))
+ ;;_ (assert (empty? common-tables) "Cannot join two tables with the same name")
+ merged-tables (merge (:tables left) (:tables right))
+ common-fields (set/intersection (set (keys (:fields left)))
+ (set (keys (:fields right))))
+ merged-fields (merge (:fields left) (:fields right))
+ join-condition (cond
+ (nil? on) (let [implicit (map (fn [field]
+ `(~'=
+ ~(resolve-field (:table left) (:fields left) field)
+ ~(resolve-field (:table right) (:fields right) field)))
+ common-fields)]
+ (if (next implicit)
+ (seq (cons `and implicit)) ;; more than one, so add an "and" around them
+ (first implicit)))
+ (seq common-fields) (throw (ex-info "Cannot join with common fields unless natural join"
+ {:left left
+ :right right
+ :common-fields common-fields}))
+ :else (resolve-fields nil merged-fields on))
+ type (or type
+ (if join-condition :inner)
+ :cross)]
+ (-> left
+ (assoc :fields merged-fields
+ :tables merged-tables
+ :joins {:left (:joins left)
+ :right (:joins right)
+ :type type
+ :on join-condition}
+ :where (combine-wheres (:where left)
+ (:where right))))))
+
+(defn select [query expression]
+ (let [table-name (if (= (count (:tables query)) 1)
+ (-> query :tables first key))
+ old-where (:where query)
+ resolved-expression (resolve-fields table-name (:fields query) expression)
+ new-where (combine-wheres old-where resolved-expression)]
+ (assoc query :where new-where)))
+
+(defn sort-by [query fields]
+ (let [table-name (if (= (count (:tables query)) 1)
+ (-> query :tables first key))
+ fields-seq (if (sequential? fields)
+ fields
+ [fields])]
+ (assoc query
+ :sort-by (for [field fields-seq]
+ (if (vector? field)
+ (resolve-field table-name (:fields query) field)
+ [(resolve-field table-name (:fields query) field) :asc])))))
+
+
+(let [id 10]
+ (->> (-> (table :x)
+ (project [:x])
+ (select `(and (in :x [1 2 3 :y])
+ (= :x ~id)))
+ (join (-> (table :y)
+ (project [:y]))
+ :on `(= :x :y))
+ (join (-> (table :z)
+ (project [:x])))
+ (sort-by [:x]))
+ (into {}))
+ (-> (table :x)
+ (project [:x])
+ (select `(and (in :x [1 2 3 :y])
+ (= :x ~id)))
+ (join (-> (table :y)
+ (project [:y]))
+ :on `(= :x :y))
+ (join (-> (table :z)
+ (project [:x])))
+ (sort-by [:x])))
diff --git a/src/clojure_sql/util.clj b/src/clojure_sql/util.clj
index f2edf30..3dfe948 100644
--- a/src/clojure_sql/util.clj
+++ b/src/clojure_sql/util.clj
@@ -24,3 +24,6 @@
(apply map-kv (fn [k & vs]
[k (apply f vs)])
maps))
+
+(defn named? [x]
+ (some #(% x) [keyword? string? symbol?]))