(ns clojure-sql.core
  (:refer-clojure :exclude [sort-by])
  (:require [clojure.set :as set]
            [clojure.string :as string]
            [clojure-sql.util :as u]
            [clojure.walk]))

(declare compile-query)

(def ^:dynamic *database-type*
  "Dispatch value selecting the SQL dialect used when a query is compiled.
  nil falls through to the :default methods (double-quoted identifiers);
  :mysql and :sql-server have their own identifier-quoting methods below."
  nil)

(defn set-database-type!
  "Sets the root binding of *database-type* to new-type."
  [new-type]
  (alter-var-root #'*database-type* (constantly new-type)))

(def ^:dynamic *query-deref-behaviour*
  "Function called with a Query when it is dereferenced with @. By default it
  compiles the query for the current *database-type*, returning
  [sql & replacements]."
  #(compile-query *database-type* %))

(defn set-query-deref-behaviour!
  "Sets the root binding of *query-deref-behaviour* to f."
  [f]
  (alter-var-root #'*query-deref-behaviour* (constantly f)))

;; A query is a plain record (so assoc/update-in work on it); dereferencing it
;; hands it to *query-deref-behaviour*.
(defrecord ^:private Query []
  clojure.lang.IDeref
  (deref [this] (*query-deref-behaviour* this)))

(defn add-parentheses
  "Wraps the string s in parentheses."
  [s]
  (str \( s \)))

(defmulti field-name
  "Quotes a field (column) identifier for the dialect db."
  (fn [db _] db))
(defmethod field-name :default [_ field] (str \" (name field) \"))

(defmulti table-name
  "Quotes a table identifier for the dialect db."
  (fn [db _] db))
(defmethod table-name :default [_ table] (str \" (name table) \"))

;; compile-* multimethods have the signature:
;;   (db, expr) -> [SQL & replacements]
;; i.e. they return a vector whose first element is an SQL fragment and whose
;; remaining elements are the values for its "?" placeholders, in order.

;; NOTE(review): neither predicate is referenced anywhere in this file.
(def is-unary? (comp boolean '#{not}))
(def is-predicate? (comp boolean '#{= < > <= >= is in}))

(defn c-return
  "Lifts a plain SQL fragment into a compiled value with no replacements."
  [val]
  [val])

(defn c-lift
  "Lifts f, a function over SQL strings, into a function over compiled values.
  f receives orig-args followed by the SQL fragments of the compiled
  arguments; the replacements of all compiled arguments are concatenated in
  argument order."
  [f & orig-args]
  (fn [& c-args]
    (apply vector
           (apply f (concat orig-args (map first c-args)))
           (apply concat (map rest c-args)))))

(def c-str
  "Concatenates compiled values like str, keeping their replacements in order."
  (c-lift str))

(defn c-join
  "Joins the SQL fragments of the compiled values in args with sep (like
  clojure.string/join), keeping their replacements in order."
  [sep args]
  (apply vector
         (string/join sep (map first args))
         (apply concat (map rest args))))

(def c-add-parentheses
  "Wraps the SQL fragment of a compiled value in parentheses."
  (c-lift add-parentheses))

(defmulti compile-expression
  "Compiles an expression to [SQL & replacements]."
  (fn [db _] db))

;; Expression forms understood by the default method:
;;   nil             -> NULL
;;   [table field]   -> "table"."field"
;;   :field          -> "field"
;;   "string"        -> ?   (the string becomes a replacement value)
;;   symbol          -> upper-cased symbol name (operators such as = or AND)
;;   (op x)          -> (OP x)
;;   (op x y ...)    -> (x OP y OP ...)
;;   anything else   -> inlined via str (e.g. numbers, booleans)
(defmethod compile-expression :default [db ex]
  (condp #(%1 %2) ex
    nil? (c-return "NULL")
    vector? (c-return (str (table-name db (first ex)) \. (field-name db (second ex))))
    keyword? (c-return (field-name db ex))
    ;; strings are always passed as replacements, never inlined into the SQL
    string? ["?" ex]
    symbol? (c-return (string/upper-case (name ex)))
    sequential? (c-add-parentheses
                 (if (= (count ex) 2)
                   ;; prefix form, e.g. (not x) -> (NOT x)
                   (c-join " " (map (partial compile-expression db) ex))
                   ;; infix form, e.g. (= a b) -> (a = b)
                   (c-join " " (interpose (compile-expression db (first ex))
                                          (map (partial compile-expression db) (rest ex))))))
    (c-return ex)))

(defn ^:private make-table-name
  "Quoted table name, followed by \" AS alias\" when alias is given and differs."
  [db table & [alias]]
  (if (or (= table alias) (nil? alias))
    (table-name db table)
    (str (table-name db table) " AS " (field-name db alias))))

(defn ^:private make-field-name
  "Quoted table.field, followed by \" AS alias\" when alias is given and differs."
  [db table field & [alias]]
  (let [qualified (str (table-name db table) \. (field-name db field))]
    (if (or (= field alias) (nil? alias))
      qualified
      (str qualified " AS " (field-name db alias)))))

(defmulti compile-fields
  "Compiles the :fields map ({[table-alias field] field-alias}) to a select list."
  (fn [db _] db))
(defmethod compile-fields :default [db fields-map]
  (if (seq fields-map)
    (->> (for [[[table field] alias] fields-map]
           (make-field-name db table field alias))
         (string/join ", ")
         c-return)
    ;; no projection selects every column
    (c-return "*")))

(defmulti compile-tables
  "Compiles a {table-name table-alias} map to a comma-separated table list."
  (fn [db _] db))
(defmethod compile-tables :default [db tables-map]
  (->> (for [[table alias] tables-map]
         (make-table-name db table alias))
       (string/join ", ")
       c-return))

(defmulti compile-join
  "Compiles one [type table-map on] join entry; the result starts with a space."
  (fn [db _] db))
(defmethod compile-join :default [db [type table-map on]]
  (c-str (c-return (case type
                     :left " LEFT OUTER"
                     :right " RIGHT OUTER"
                     ""))
         (c-return " JOIN ")
         (compile-tables db table-map)
         (c-return " ON ")
         (compile-expression db on)))

(defmulti compile-joins
  "Compiles the :joins vector."
  (fn [db _] db))
(defmethod compile-joins :default [db joins]
  ;; each compiled join already begins with a space, so no extra separator
  (->> joins
       (map (partial compile-join db))
       (c-join "")))

(defmulti compile-where
  "Compiles the :where expression to a \" WHERE ...\" clause, or nil if absent."
  (fn [db _] db))
(defmethod compile-where :default [db expr]
  ;; only nil means "no filter": a literal false must still produce WHERE false
  (when-not (nil? expr)
    (c-str (c-return " WHERE ") (compile-expression db expr))))

(defmulti compile-sort-by
  "Compiles the :sort-by entries ([[table-alias field] direction]) to an
  \" ORDER BY ...\" clause, or nil when there is nothing to sort by."
  (fn [db _] db))
(defmethod compile-sort-by :default [db fields]
  (when (seq fields)
    (c-return
     (str " ORDER BY "
          (string/join ","
                       (for [[[table field] dir] fields]
                         (str (make-field-name db table field)
                              \space
                              (string/upper-case (name dir)))))))))

(defmulti compile-query
  "Compiles a Query to [sql & replacements] for the dialect db."
  (fn [db _] db))
(defmethod compile-query :default [db {:keys [table fields joins where sort-by]}]
  (c-str (c-return "SELECT ")
         (compile-fields db fields)
         (c-return " FROM ")
         (compile-tables db table)
         (compile-joins db joins)
         (compile-where db where)
         (compile-sort-by db sort-by)))

;; SQL SERVER
(defmethod field-name :sql-server [_ field] (str \[ (name field) \]))
(defmethod table-name :sql-server [_ table] (str \[ (name table) \]))

;; mySQL
(defmethod field-name :mysql [_ field] (str \` (name field) \`))
(defmethod table-name :mysql [_ table] (str \` (name table) \`))

;; Query-building sections:
;;   -PROJECTION-     project, rename
;;   -TABLES-         table, join
;;   -FILTERS-        select
;;   GROUPING         (not implemented in this file)
;;   GROUPED FILTERS  (not implemented in this file)
;;   -SORTING-        sort-by

;; Query record keys:
;;   :table    {table-name table-alias}
;;   :fields   {[table-alias field-name] field-alias}
;;   :joins    [[type {table-name table-alias} on-expression]], type :inner/:left/:right
;;   :where    expression (see compile-expression)
;;   :sort-by  [[[table-alias field-name] direction]]

(defn table
  "Creates a query over a single table. arg is either a table name (used as
  its own alias) or a map {table-name alias}."
  [arg]
  (into (->Query) {:table (if (map? arg) arg {arg arg})}))

(defn ambiguous-error
  "Throws an ex-info saying field cannot be attributed to a single table."
  [field]
  (throw (ex-info (str "Ambiguous field " field) {:field field})))

(defn ^:private sole-table-alias
  "Alias of the query's table when it has no joins (so bare field names are
  unambiguous), otherwise nil."
  [query]
  (if-not (:joins query) (-> query :table first val)))

(defn project
  "Replaces the projection of query. fields is either a sequence of fields
  (each kept under its own name) or a map {field alias}. A field is an
  existing alias, a qualified [table-alias field] vector, or - when the query
  has no joins - a bare field of its table; anything else is reported through
  ambiguous-error."
  [query fields]
  (let [table-alias (sole-table-alias query)
        ;; assumes u/flip-map inverts the map: alias -> [table-alias field]
        alias-lookup (u/flip-map (:fields query))
        original-name (fn [field]
                        (if (vector? field)
                          field
                          (or (get alias-lookup field nil)
                              (if table-alias
                                [table-alias field]
                                (ambiguous-error field)))))
        field-map (if (map? fields) fields (zipmap fields fields))]
    (assoc query
      :fields (into {} (for [[field alias] field-map]
                         [(original-name field) alias])))))

(defn rename
  "Renames projected fields. field-renames maps a current alias (or a
  qualified [table-alias field] vector) to its new alias. The new aliases
  must not collide with aliases already in the projection (:pre). Throws
  ex-info for a non-vector field that is not a current alias."
  [query field-renames]
  {:pre [(map? field-renames)
         (empty? (set/intersection (set (vals field-renames))
                                   (set (vals (:fields query)))))]}
  (let [alias-lookup (u/flip-map (:fields query))
        original-name (fn [field]
                        (cond
                          (vector? field) field
                          (contains? alias-lookup field) (get alias-lookup field)
                          :else (throw (ex-info (str "Invalid field in rename: " (pr-str field))
                                                {:field field
                                                 :query query
                                                 :renames field-renames}))))]
    (update-in query [:fields]
               (fn [fields]
                 ;; keys are [table-alias field], so re-adding a key replaces
                 ;; that field's old alias; (or ... {}) keeps :fields a map
                 (into (or fields {})
                       (for [[field new-alias] field-renames]
                         [(original-name field) new-alias]))))))

(defn ^:private resolve-field
  "Resolves field to [table-alias field]: first through the alias map, then
  against table when it is non-nil, otherwise via ambiguous-error."
  [table aliases field]
  (or (get (u/flip-map aliases) field)
      (if table
        [table field]
        (ambiguous-error field))))

(defn ^:private resolve-fields
  "Replaces every field keyword in expression with its [table-alias field]
  reference. Vectors are already-qualified references and are left untouched."
  [table aliases expression]
  (cond
    (keyword? expression) (resolve-field table aliases expression)
    (vector? expression) expression
    (coll? expression) (clojure.walk/walk #(resolve-fields table aliases %)
                                          identity
                                          expression)
    :else expression))

(defn ^:private and-form?
  "True when expr is a list/seq headed by a symbol named \"and\"."
  [expr]
  (and (sequential? expr)
       (symbol? (first expr))
       (= "and" (name (first expr)))))

(defn ^:private combine-wheres
  "ANDs the non-nil where expressions together, flattening nested ANDs.
  Returns nil when every argument is nil."
  [& wheres]
  (letfn [(conjuncts [expr] (if (and-form? expr) (rest expr) [expr]))]
    (reduce (fn [acc where]
              (cond
                (nil? acc) where
                (nil? where) acc
                :else `(and ~@(conjuncts acc) ~@(conjuncts where))))
            nil
            wheres)))

(defn join
  "Joins right onto left. on (default true) is an expression whose field
  keywords are resolved against the combined aliases of both queries; type
  is :inner (default), :left or :right. The result keeps left's table, has
  left's joins followed by the new join and then right's joins, merges both
  projections and ANDs both where clauses. :pre requires the two projections
  to use disjoint aliases."
  [left right & [on type]]
  {:pre [(empty? (set/intersection (set (vals (:fields left)))
                                   (set (vals (:fields right)))))]}
  (let [joined-fields (merge (:fields left) (:fields right))
        new-join [(or type :inner)
                  (:table right)
                  (resolve-fields nil joined-fields (or on true))]]
    ;; NOTE(review): right's where ends up in the outer WHERE, which filters
    ;; out the NULL-extended rows of :left/:right joins.
    (assoc left
      :fields joined-fields
      :joins (-> (vec (:joins left))
                 (conj new-join)
                 (into (:joins right)))
      :where (combine-wheres (:where left) (:where right)))))

(defn select
  "Adds expression as a filter, ANDed with any existing where clause. Field
  keywords are resolved like in project."
  [query expression]
  (let [resolved (resolve-fields (sole-table-alias query) (:fields query) expression)]
    (assoc query :where (combine-wheres (:where query) resolved))))

(defn sort-by
  "Replaces the ordering of query. fields is a single element or a sequence of
  elements; each element is a field (sorted :asc) or a [field direction] pair
  with direction :asc or :desc. A field is resolved like in project; inside a
  pair it may also be a qualified [table-alias field] vector. A bare
  [field direction] is itself sequential and is read as two fields, so wrap
  it: [[field direction]]."
  [query fields]
  (let [table-alias (sole-table-alias query)
        aliases (:fields query)
        qualify (fn [field]
                  (if (vector? field)
                    field
                    (resolve-field table-alias aliases field)))
        fields-seq (if (sequential? fields) fields [fields])]
    (assoc query
      ;; eager, so resolution errors surface here rather than at deref time
      :sort-by (mapv (fn [field]
                       (if (vector? field)
                         [(qualify (first field)) (second field)]
                         [(qualify field) :asc]))
                     fields-seq))))

(comment
  (binding [*database-type* :mysql]
    (let [users (-> (table :users)
                    (project [:id :username :password])
                    (select '(= :deleted false)))
          people (-> (table :people)
                     (project [:id :fname :sname])
                     (select '(= :deleted false)))
          uid-pid-match '(= :uid :pid)
          is-carlo `(= :fname "Carlo'; SELECT * FROM users --")
          query (-> (join (-> users (rename {:id :uid}))
                          (join (-> people (rename {:id :pid}))
                                (-> (table {:others :o})
                                    (project {:id :oid}))
                                '(= :pid :oid))
                          uid-pid-match)
                    (select is-carlo)
                    (project [:fname :sname :oid]))]
      @query))

  (-> (table :users)
      (join (table :something-else-with-a-username) true)
      (select '(or (= :username "john")
                   (not (= :username "carlo"))))
      (project [:username])
      deref)

  (-> (table {:nodes :child})
      (project [:parent-id, :name])
      (rename {:name :child.name})
      (join (-> (table {:nodes :parent})
                (project [:id, :name])
                (rename {:name :parent.name}))
            '(= :parent-id :id))
      (project [:child.name :parent.name])
      deref)

  (-> (table :users)
      (project [:id])
      (join (-> (table :people)
                (project [:id]))
            true)
      deref)

  (-> (table :users)
      (project [:id :name])
      (rename {:id :name})
      deref)

  (-> (table :anotherStack)
      (project [:anotherNumber])
      (join (-> (table :collection)
                (project [:number])))
      deref))