(ns clojure-sql.core
  (:refer-clojure :exclude [sort-by])
  (:require [clojure.set :as set]
            [clojure.string :as string]
            [clojure-sql.util :as u]
            [clojure.walk]))

(declare compile-query)

(def ^:private ^:dynamic *database-type* nil)

(defn set-database-type!
  "Sets the root value of *database-type*, the dispatch value that the default
  deref behaviour passes to compile-query."
  [new-type]
  (alter-var-root #'*database-type* (constantly new-type)))

(def ^:private ^:dynamic *query-deref-behaviour*
  #(compile-query *database-type* %))

(defn set-query-deref-behaviour!
  "Sets the root value of the function that is called with the query when a
  Query is dereferenced; @query returns whatever f returns."
  [f]
  (alter-var-root #'*query-deref-behaviour* (constantly f)))

;; Query maps are records so that dereferencing one (@query) runs
;; *query-deref-behaviour* on it.
(defrecord ^:private Query []
  clojure.lang.IDeref
  (deref [this] (*query-deref-behaviour* this)))

;; A Query prints as its compiled [SQL & replacements] vector, always using the
;; :default dialect (dispatch value nil) regardless of *database-type*.
(defmethod print-method Query [query writer]
  (binding [*out* writer]
    (pr (compile-query nil query))))

(defn add-parentheses
  "Returns the string form of s wrapped in parentheses."
  [s]
  (str \( s \)))

;; ==============================================================
;; DB specific escaping methods
;; ==============================================================

(defn ^:private quote-identifier
  "Wraps (name ident) in the open/close delimiters. Every occurrence of the
  closing delimiter inside the name is doubled so that the name cannot end
  the quoted identifier early."
  [open close ident]
  (str open
       (string/replace (name ident) (str close) (str close close))
       close))

(defmulti field-name
  "Quotes a field name (or alias) for the database type db."
  (fn [db _] db))

(defmethod field-name :default [_ field]
  (quote-identifier \" \" field))

(defmulti table-name
  "Quotes a table name (or table alias) for the database type db."
  (fn [db _] db))

(defmethod table-name :default [_ table]
  (quote-identifier \" \" table))

(defmulti function-name
  "Quotes a SQL function name for the database type db."
  (fn [db _] db))

(defmethod function-name :default [_ function]
  (quote-identifier \" \" function))

;; ==============================================================
;; Utility functions for the compile-* functions
;; ==============================================================

;; compile-* multimethods are of the signature:
;;   (db, expr) -> [SQL & replacements]
;; i.e. a vector whose first element is a SQL fragment and whose remaining
;; elements are the values for that fragment's "?" placeholders, in order.

(def is-unary?
  "True when the operator's name is a unary SQL operator."
  (comp boolean #{"not"} name))

(def is-binary?
  "True when the operator's name is a binary SQL operator."
  (comp boolean #{"=" "<" ">" "<=" ">=" "is" "in" "+" "-" "/" "*"} name))

(def is-operator?
  "True when the operator's name is a variadic infix operator (and/or)."
  (comp boolean #{"and" "or"} name))

(defn c-return
  "Lifts a plain SQL fragment into a compiled value with no replacements."
  [val]
  [val])

(defn c-lift
  "Turns f, a function over SQL fragments, into a function over compiled
  values: f is applied to orig-args followed by the fragment of each compiled
  argument, and the replacements of all arguments are concatenated in order.
  A nil argument contributes a nil fragment and no replacements."
  [f & orig-args]
  (fn [& c-args]
    (into [(apply f (concat orig-args (map first c-args)))]
          (mapcat rest c-args))))

(def c-str
  "Concatenates compiled values."
  (c-lift str))

(defn c-join
  "Joins the fragments of the compiled values in args with sep, concatenating
  their replacements in order."
  [sep args]
  (into [(string/join sep (map first args))]
        (mapcat rest args)))

(def c-add-parentheses
  "Wraps a compiled value's fragment in parentheses."
  (c-lift add-parentheses))

;; ==============================================================
;; compile-* functions (turning a map into a query string)
;; ==============================================================

(defmulti compile-expression
  "Compiles an expression to [SQL & replacements]."
  (fn [db _] db))

(defmethod compile-expression :default [db ex]
  (condp #(%1 %2) ex
    nil? (c-return "NULL")
    ;; [table field] -> qualified field reference
    vector? (c-return (str (table-name db (first ex)) \. (field-name db (second ex))))
    keyword? (c-return (field-name db ex))
    ;; strings are never inlined: they become bind parameters
    string? ["?" ex]
    ;; symbols are operators / keywords, emitted upper-cased
    symbol? (c-return (string/upper-case (name ex)))
    sequential?
    (let [[op & args] ex
          compile-all #(map (partial compile-expression db) %)]
      (c-add-parentheses
        (condp #(%1 %2) op
          ;; the operator itself is compiled too, e.g. (not x) -> NOT x
          is-unary? (if (= (count ex) 2)
                      (c-join " " (compile-all ex))
                      (throw (ex-info "Unary operators can only take one argument."
                                      {:operator op :arguments (rest ex)})))
          is-binary? (if (= (count ex) 3)
                       (c-join " " (interpose (compile-expression db op) (compile-all args)))
                       (throw (ex-info "Binary operators must take two arguments."
                                       {:operator op :arguments (rest ex)})))
          is-operator? (c-join " " (interpose (compile-expression db op) (compile-all args)))
          ;; anything else is a function call: name(arg1, arg2, ...)
          (c-str (c-return (function-name db op))
                 (c-add-parentheses (c-join ", " (compile-all args)))))))
    ;; numbers, booleans and other literals are inlined via str
    (c-return ex)))

(defn ^:private make-table-name
  "Quoted table name, followed by \" AS alias\" when alias is given and differs."
  [db table & [alias]]
  (if (or (nil? alias) (= table alias))
    (table-name db table)
    (str (table-name db table) " AS " (field-name db alias))))

(defn ^:private make-field-name
  "Quoted table.field, followed by \" AS alias\" when alias is given and differs."
  [db table field & [alias]]
  (let [qualified (str (table-name db table) \. (field-name db field))]
    (if (or (nil? alias) (= field alias))
      qualified
      (str qualified " AS " (field-name db alias)))))

(defmulti compile-fields
  "Compiles the {[table field] alias} map to a select list; \"*\" when empty."
  (fn [db _] db))

(defmethod compile-fields :default [db fields-map]
  (if (seq fields-map)
    (->> (for [[[table field] alias] fields-map]
           (make-field-name db table field alias))
         (string/join ", ")
         c-return)
    (c-return "*")))

(defmulti compile-tables
  "Compiles a {table alias} map to a comma separated table list."
  (fn [db _] db))

(defmethod compile-tables :default [db tables-map]
  (->> (for [[table alias] tables-map]
         (make-table-name db table alias))
       (string/join ", ")
       c-return))

(defmulti compile-join
  "Compiles one [type table-map on] join entry; the result starts with a space."
  (fn [db _] db))

(defmethod compile-join :default [db [type table-map on]]
  (c-str (c-return (case type
                     :left " LEFT OUTER"
                     :right " RIGHT OUTER"
                     " INNER"))
         (c-return " JOIN ")
         (compile-tables db table-map)
         (c-return " ON ")
         (compile-expression db on)))

(defmulti compile-joins
  "Compiles a sequence of join entries."
  (fn [db _] db))

(defmethod compile-joins :default [db joins]
  ;; each compiled join already begins with a space, so no separator is needed
  (c-join "" (map (partial compile-join db) joins)))

(defmulti compile-where
  "Compiles the where expression to \" WHERE ...\"; nil when there is none."
  (fn [db _] db))

(defmethod compile-where :default [db expr]
  (when expr
    (c-str (c-return " WHERE ") (compile-expression db expr))))

(defmulti compile-sort-by
  "Compiles [[[table field] direction] ...] to \" ORDER BY ...\"; nil when empty."
  (fn [db _] db))

(defmethod compile-sort-by :default [db fields]
  ;; seq, not a truthiness check: an empty sort list must not emit a bare ORDER BY
  (when (seq fields)
    (->> (for [[[table field] dir] fields]
           (str (make-field-name db table field) \space (string/upper-case (name dir))))
         (string/join ", ")
         (str " ORDER BY ")
         c-return)))

(defmulti compile-query
  "Compiles a query map to [SQL & replacements]."
  (fn [db _] db))

(defmethod compile-query :default [db {:keys [table fields joins where sort-by]}]
  (c-str (c-return "SELECT ")
         (compile-fields db fields)
         (if table (c-return " FROM "))
         (compile-tables db table)
         (compile-joins db joins)
         (compile-where db where)
         (compile-sort-by db sort-by)))

;; SQL SERVER
(defmethod field-name :sql-server [_ field]
  (quote-identifier \[ \] field))

(defmethod table-name :sql-server [_ table]
  (quote-identifier \[ \] table))

;; mySQL
(defmethod field-name :mysql [_ field]
  (quote-identifier \` \` field))

(defmethod table-name :mysql [_ table]
  (quote-identifier \` \` table))

;; ==============================================================
;; The DSL for making query maps
;; ==============================================================

;; important sections (those marked -LIKE-THIS- have DSL functions below):
;;  -PROJECTION-     (project, rename)
;;  -TABLES-         (table, join)
;;  -FILTERS-        (select)
;;  GROUPING
;;  GROUPED FILTERS
;;  -SORTING-        (sort-by)

;; Shape of a query map:
;; {
;;   :table   => {tablename table_alias}
;;   :fields  => {[table_alias fieldname] field_alias}
;;   :joins   => [[type {tablename table_alias} on-expression] ...]
;;   :where   => expression
;;   :sort-by => [[[table_alias fieldname] direction] ...]
;; }

(defn table
  "Creates a query over a single table. arg is either a table name, which is
  also used as its alias, or a {tablename alias} map."
  [arg]
  (into (->Query) {:table (if (map? arg) arg {arg arg})}))

(defn ambiguous-error
  "Throws an ex-info reporting that field cannot be attributed to a single
  table. Never returns."
  [field & [query]]
  (throw (ex-info (str "Ambiguous field " field " in query with more than one table")
                  {:field field :query query})))

(defn ^:private single-table-alias
  "The alias of the query's table when the query has no joins, otherwise nil
  (in which case an unqualified, unknown field is ambiguous)."
  [query]
  (when (empty? (:joins query))
    (-> query :table first val)))

(defn project
  "Replaces the query's projected fields. fields is either a seq of fields,
  each kept under its own name as alias, or a {field alias} map. A [table
  field] vector is used as is; a keyword is first looked up among the query's
  current aliases, then qualified with the table alias when the query has no
  joins; otherwise ambiguous-error is thrown."
  [query fields]
  (let [table-alias (single-table-alias query)
        alias-lookup (u/flip-map (:fields query))
        original-name (fn [field]
                        (if (vector? field)
                          field
                          (or (get alias-lookup field)
                              (if table-alias
                                [table-alias field]
                                (ambiguous-error field query)))))
        field-map (if (map? fields) fields (zipmap fields fields))]
    (assoc query :fields (into {} (for [[field alias] field-map]
                                    [(original-name field) alias])))))

(defn rename
  "Renames projected fields. field-renames maps an existing alias (or a
  [table field] vector) to its new alias. Throws ex-info when a keyword is not
  an alias in the query. Precondition: the new aliases must not clash with
  aliases that this call leaves untouched."
  [query field-renames]
  {:pre [(map? field-renames)
         ;; the intersection of the new aliases with the old aliases NOT
         ;; renamed by this operation must be empty
         (empty? (set/intersection (set (vals field-renames))
                                   (set/difference (set (vals (:fields query)))
                                                   (set (keys field-renames)))))]}
  (let [alias-lookup (u/flip-map (:fields query))
        original-name (fn [field]
                        (cond
                          (vector? field) field
                          (contains? alias-lookup field) (get alias-lookup field)
                          :else (throw (ex-info (str "Cannot rename field " (pr-str field) ". Field does not exist in query.")
                                                {:field field :query query :renames field-renames}))))]
    ;; fnil keeps :fields a map even when the query has no projection yet
    (update-in query [:fields] (fnil into {})
               (for [[field alias] field-renames]
                 [(original-name field) alias]))))

(defn ^:private resolve-field
  "Resolves field to a [table_alias fieldname] vector. A vector is returned
  unchanged; a keyword is looked up among aliases (a {[table field] alias}
  map) and otherwise qualified with table. Throws via ambiguous-error when the
  keyword is unknown and table is nil."
  [table aliases field]
  (if (vector? field)
    field
    (or (get (u/flip-map aliases) field)
        (if table [table field] (ambiguous-error field)))))

(defn ^:private resolve-fields
  "Resolves every keyword inside expression with resolve-field. Vectors are
  already-qualified [table field] references and are left untouched."
  [table aliases expression]
  (cond
    (keyword? expression) (resolve-field table aliases expression)
    (vector? expression) expression
    (coll? expression) (clojure.walk/walk #(resolve-fields table aliases %) identity expression)
    :else expression))

(defn ^:private and-form?
  "True when expr is a sequential form whose head is named \"and\"."
  [expr]
  (and (sequential? expr)
       (instance? clojure.lang.Named (first expr))
       (= (name (first expr)) "and")))

(defn ^:private combine-wheres
  "Combines where expressions with `and`, skipping nils and splicing the
  arguments of expressions that already are `and` forms."
  [& wheres]
  (reduce (fn [acc where]
            (cond
              (nil? acc) where
              (nil? where) acc
              (and (and-form? acc) (and-form? where)) `(and ~@(next acc) ~@(next where))
              (and-form? where) `(and ~acc ~@(next where))
              (and-form? acc) `(and ~@(next acc) ~where)
              :else `(and ~acc ~where)))
          nil
          wheres))

(defn join
  "Joins right onto left.
  Fields whose alias appears in both queries are equated in the join condition
  and kept only once: left's copy wins unless :type is :right.
  :on   - an extra join condition, resolved against the combined fields.
  :type - :inner (default), :left or :right.
  The joins of right are appended after the new join, and the where clauses of
  both queries are and-ed together."
  [left right & {:keys [on type]}]
  (let [joins-vector (or (:joins left) [])
        common-fields (set/intersection (set (vals (:fields left)))
                                        (set (vals (:fields right))))
        without-common (fn [fields] (into {} (remove (comp common-fields val) fields)))
        joined-fields (if (= type :right)
                        (merge (without-common (:fields left)) (:fields right))
                        (merge (:fields left) (without-common (:fields right))))
        ;; common aliases are always present in both :fields maps, so no table
        ;; fallback is needed when resolving them
        implicit-on (for [field common-fields]
                      `(= ~(resolve-field nil (:fields left) field)
                          ~(resolve-field nil (:fields right) field)))
        explicit-on (when on [(resolve-fields nil joined-fields on)])
        join-condition (if-let [conditions (seq (concat implicit-on explicit-on))]
                         `(and ~@conditions)
                         true)]
    (-> left
        (assoc :fields joined-fields)
        (assoc :joins (into (conj joins-vector [(or type :inner) (:table right) join-condition])
                            (:joins right)))
        (assoc :where (combine-wheres (:where left) (:where right))))))

(defn select
  "Adds expression to the query's where clause (and-ed with any existing one),
  after resolving the field keywords it mentions."
  [query expression]
  (let [resolved (resolve-fields (single-table-alias query) (:fields query) expression)]
    (assoc query :where (combine-wheres (:where query) resolved))))

(defn sort-by
  "Sets the query's sort order. fields is a single field or a seq whose
  elements are either a field (sorted ascending) or a [field direction] pair,
  e.g. [:name [:age :desc]]. Fields are resolved eagerly, so an ambiguous
  field throws here rather than at compile time."
  [query fields]
  (let [table-alias (single-table-alias query)
        resolve #(resolve-field table-alias (:fields query) %)
        fields-seq (if (sequential? fields) fields [fields])]
    (assoc query :sort-by (vec (for [field fields-seq]
                                 (if (vector? field)
                                   (let [[f dir] field]
                                     [(resolve f) (or dir :asc)])
                                   [(resolve field) :asc]))))))

(defn insert!
  "Not implemented yet; returns nil. Requires a query without joins."
  [query & records]
  {:pre [(empty? (:joins query))]}
  ;; TODO: some code here. The explicit nil body is what makes the map above a
  ;; precondition instead of the function's return value.
  nil)

(defn update!
  "Not implemented yet; returns nil."
  [query & partial-records]
  ;; TODO: some code here
  )

(defn delete!
  "Not implemented yet; returns nil."
  [query]
  ;; TODO: some code here
  )

(comment
  (binding [*database-type* :mysql]
    (let [users (-> (table :users)
                    (project [:id :username :password])
                    (select '(= :deleted false)))
          people (-> (table :people)
                     (project [:id :fname :sname])
                     (select '(= :deleted false)))
          uid-pid-match '(= :uid :pid)
          is-carlo `(= :fname "Carlo'; SELECT * FROM users --")]
      (-> (join (-> users (rename {:id :uid}))
                (join (-> people (rename {:id :pid}))
                      (-> (table {:others :o}) (project {:id :oid}))
                      :on '(= :pid :oid))
                :on uid-pid-match)
          (select is-carlo)
          (project [:fname :sname :oid]))))

  (-> (table :users)
      (join (table :something-else-with-a-username) :on true)
      (select '(or (= :username "john") (not (= :username "carlo"))))
      (project [:username]))

  (-> (table {:nodes :child})
      (project [:parent-id, :name])
      (rename {:name :child.name})
      (join (-> (table {:nodes :parent})
                (project [:id, :name])
                (rename {:name :parent.name}))
            :on '(= :parent-id :id))
      (project [:child.name :parent.name]))

  (-> (table :users)
      (project [:id])
      (join (-> (table :people) (project [:id])) :on true))

  (-> (table :users)
      (project [:id :name])
      (rename {:id :name :name :id}))

  (-> (table :users)
      (project {:id :name :name :id}))

  (-> (table :anotherStack)
      (project [:anotherNumber])
      (join (-> (table :collection) (project [:number]))))

  (-> (table :users)
      (select '(= (left :username 1) "bloo"))))