(ns clojure-sql.compiler
  (:refer-clojure :exclude [compile sequence])
  (:require [clojure.string :as string]
            [clojure-sql.query :refer [query?]]
            [clojure-sql.util :as u :refer [named?]]
            [clojure-sql.writer :as w :refer [return lift p-lift sequence do-m tell >>]]))

(defn no-args-operator-error
  "Throws an ex-info reporting that operator `op` was applied to zero
  arguments and has no identity value to fall back on. The operator is
  also attached to the exception data under :operator."
  [op]
  (throw (ex-info (str "Operator called with no args doesn't have identity value: " op)
                  {:operator op})))

(defn add-parentheses
  "Wraps the string form of `s` in a pair of round brackets."
  [s]
  (str \( s \)))

;; ==============================================================
;; DB specific escaping methods
;; ==============================================================

;; Each of these multimethods dispatches on the `db` argument and returns
;; the quoted form of an identifier. The :default implementations wrap the
;; name in double quotes.

(defmulti field-name (fn [db _] db))
(defmethod field-name :default [_ field]
  (str \" (name field) \"))

(defmulti table-name (fn [db _] db))
(defmethod table-name :default [_ table]
  (str \" (name table) \"))

(defmulti function-name (fn [db _] db))
(defmethod function-name :default [_ function]
  (str \" (name function) \"))

;; we use the $ prefix to denote a lifted function
(def $add-parentheses (lift add-parentheses))
(def $str (lift str))

;; ==============================================================
;; Utility functions for the compile-* functions
;; ==============================================================

;; True for exactly the values `true` and `false`.
(def ^:private boolean? (some-fn true? false?))

;; True for instances of the class of a regex literal.
(def ^:private regex? (partial instance? (class #"")))

;; Maps an operator to its SQL name: "$" is rewritten to "~", anything
;; else is passed through `name` unchanged.
(def ^:private operator-name (some-fn (comp {"$" "~"} name) name))

;; Operator classification predicates; each takes anything `name` accepts.
(def quote? (comp boolean '#{"quote"} name))
(def unary? (comp boolean '#{"not" "exists" "-"} name))
(def binary? (comp boolean '#{"=" "<" ">" "<=" ">=" "is" "in" "like" "~"} name))
(def n-ary? (comp boolean '#{"and" "or" "+" "-" "/" "*"} name))

;; Identity value for n-ary operators that have one; nil for the others.
;; NOTE: the identity for "or" is `false`, so callers must test the result
;; with `nil?` rather than plain truthiness.
(def operator-identity (comp {"and" true "or" false "+" 0 "*" 1} name))

;; ==============================================================
;; compile-* functions (turning a map into a query string)
;; ==============================================================

;; compile-* multimethods are of the signature:
;;  (db, expr) -> [args] -> [sql & args]

(declare compile-query compile-expression)

;; Compiles every element of `ex` and renders them as a parenthesised,
;; comma-separated list.
(defmulti compile-expression-list (fn [db _] db))
(defmethod compile-expression-list :default [db ex]
  (->> (map (partial compile-expression db) ex)
       (apply sequence)
       ((p-lift string/join ","))
       $add-parentheses))

;; Compiles a sequential expression `(op & args)`. The operator is
;; classified (quote / n-ary / unary / binary / anything else is treated as
;; a function call) and the result is always wrapped in parentheses.
(defmulti compile-expression-sequential (fn [db _] db))
(defmethod compile-expression-sequential :default [db ex]
  (let [compile-exprs #(map (partial compile-expression db) %)
        op-name (operator-name (first ex))
        num-args (dec (count ex))]
    (-> (condp u/funcall op-name
          quote?
          (do (assert (= num-args 1) "`quote` must only take one argument")
              (if (sequential? (second ex))
                (compile-expression-list db (second ex))
                ;; a quoted scalar is passed as a query parameter
                (>> (tell (second ex)) (return "?"))))

          n-ary?
          (do-m :let [[op & exprs] (compile-exprs ex)]
                vals <- (apply sequence (interpose op exprs))
                (condp = (count vals)
                  ;; No operands: fall back to the operator's identity value.
                  ;; The identity may legitimately be `false` (for "or"), so
                  ;; test for nil explicitly instead of using `if-let`.
                  0 (let [id (operator-identity op-name)]
                      (if (nil? id)
                        (no-args-operator-error (name (first ex)))
                        (compile-expression db id)))
                  ;; One operand: an operator that is also unary (i.e. "-")
                  ;; is emitted as a prefix operator. The space keeps a
                  ;; negated negative literal from turning into the SQL
                  ;; comment marker "--". Any other operator applied to a
                  ;; single operand is just that operand.
                  1 (if (unary? op-name)
                      (do-m compiled-op <- op
                            (return (str compiled-op " " (first vals))))
                      (return (first vals)))
                  (return (string/join " " vals))))

          unary?
          (do (assert (= num-args 1)
                      (str "Unary operator `" (first ex) "` must take one argument"))
              (do-m :let [exprs (compile-exprs ex)]
                    vals <- (apply sequence exprs)
                    ;; Separate operator and operand with a space so that
                    ;; e.g. NOT followed by TRUE or by a "?" placeholder
                    ;; does not fuse into a single token.
                    (return (string/join " " vals))))

          binary?
          (do (assert (= num-args 2)
                      (str "Binary operator `" (first ex) "` must take two arguments"))
              (do-m :let [[op left right] (compile-exprs ex)]
                    vals <- (sequence left op right)
                    (return (string/join " " vals))))

          ;; default: treat the head as a function name
          (do-m :let [fn-name (function-name db (first ex))
                      exprs (compile-exprs (rest ex))]
                vals <- (apply sequence exprs)
                (return (str fn-name (add-parentheses (string/join "," vals))))))
        $add-parentheses)))

;; Compiles a single expression, chosen by the first predicate that
;; matches `ex`:
;;   boolean     -> upper-cased "TRUE"/"FALSE"
;;   query       -> parenthesised subquery
;;   nil         -> NULL
;;   vector      -> qualified field built from its first two elements
;;   keyword     -> field name
;;   regex       -> "?" placeholder, string form recorded with `tell`
;;   string      -> "?" placeholder, value recorded with `tell`
;;   symbol      -> upper-cased operator name
;;   sequential  -> compile-expression-sequential
;;   otherwise   -> the value itself, unchanged
(defmulti compile-expression (fn [db _] db))
(defmethod compile-expression :default [db ex]
  (condp u/funcall ex
    boolean? (return (string/upper-case (str ex)))
    query? ($add-parentheses (compile-query db ex))
    nil? (return "NULL")
    vector? (return (str (table-name db (first ex)) \. (field-name db (second ex))))
    keyword? (return (field-name db ex))
    regex? (>> (tell (str ex)) (return "?"))
    string? (>> (tell ex) (return "?")) ;;(sql-string db ex)
    symbol? (return (string/upper-case (operator-name ex)))
    sequential? (compile-expression-sequential db ex)
    (return ex)))

(defn ^:private make-table-name
  "Renders `table`, followed by \" AS <alias>\" unless `alias` is nil or
  equal to `table`. When aliased, a query is rendered as a parenthesised
  subquery, a named value as a quoted table name, and anything else is
  compiled as an expression."
  [db table & [alias]]
  (if (or (= table alias) (nil? alias))
    (return (table-name db table))
    ($str (condp #(%1 %2) table
            query? ($add-parentheses (compile-query db table))
            named? (return (table-name db table))
            (compile-expression db table))
          (return " AS ")
          (return (table-name db alias)))))

(defn ^:private make-field-name
  "Renders `field`, followed by \" AS <alias>\". The alias is omitted only
  when `field` is a vector and `alias` is nil or equal to `field`."
  [db field & [alias]]
  (if (and (vector? field) (or (= field alias) (nil? alias)))
    (compile-expression db field)
    ($str (compile-expression db field)
          (return " AS ")
          (return (field-name db alias)))))

;; Renders the select list from a map of alias -> field; an empty or nil
;; map renders as "*".
(defmulti compile-fields (fn [db _] db))
(defmethod compile-fields :default [db fields-map]
  (if (seq fields-map)
    (->> (for [[alias field] fields-map]
           (make-field-name db field alias))
         (apply sequence)
         ((p-lift string/join ", ")))
    (return "*")))

(def ^:private join-type-names
  {:inner "INNER"
   :outer "LEFT OUTER"
   :full-outer "FULL OUTER"
   :cross "CROSS"})

;; Renders the FROM clause contents. `join` is either a vector of table
;; aliases (rendered comma-separated, looked up in `tables-map`) or a map
;; with :left, :right, :type and :on describing a join, which is rendered
;; recursively inside parentheses. A join without :on uses "TRUE" as its
;; condition; cross joins have no ON clause.
(defmulti compile-tables (fn [db _ _] db))
(defmethod compile-tables :default [db join tables-map]
  (if (vector? join)
    (->> (for [table-alias join]
           (make-table-name db (get tables-map table-alias) table-alias))
         (apply sequence)
         ((p-lift string/join ", ")))
    (let [{:keys [left right type on]} join]
      (if (= type :cross)
        ($str (return "(")
              (compile-tables db left tables-map)
              (return " CROSS JOIN ")
              (compile-tables db right tables-map)
              (return ")"))
        ($str (return "(")
              (compile-tables db left tables-map)
              ;; unknown join types fall back to the type's own name
              (return (str " " (get join-type-names type (name type)) " JOIN "))
              (compile-tables db right tables-map)
              (return " ON ")
              (if on
                (compile-expression db on)
                (return "TRUE"))
              (return ")"))))))

;; Renders " WHERE <expr>", or nothing when `expr` is nil/false.
(defmulti compile-where (fn [db _] db))
(defmethod compile-where :default [db expr]
  (if expr
    ($str (return " WHERE ") (compile-expression db expr))
    (return nil)))

;; Renders " ORDER BY ..." from a collection of [[table field] direction]
;; entries, or nothing when `fields` is nil.
(defmulti compile-sort (fn [db _] db))
(defmethod compile-sort :default [db fields]
  (if fields
    (->> (for [[[table field] dir] fields]
           ($str (make-field-name db [table field])
                 (return (str \space (string/upper-case (name dir))))))
         (apply sequence)
         ((p-lift string/join ","))
         ($str (return " ORDER BY ")))
    (return nil)))

;; Renders " GROUP BY ..." from a collection of [table field] pairs, or
;; nothing when `fields` is nil.
(defmulti compile-group (fn [db _] db))
(defmethod compile-group :default [db fields]
  (if fields
    (->> (for [[table field] fields]
           (make-field-name db [table field]))
         (apply sequence)
         ((p-lift string/join ","))
         ($str (return " GROUP BY ")))
    (return nil)))

;; Renders " HAVING <expr>", or nothing when `expr` is nil/false.
(defmulti compile-having (fn [db _] db))
(defmethod compile-having :default [db expr]
  (if expr
    ($str (return " HAVING ") (compile-expression db expr))
    (return nil)))

;; Renders a whole query map. When :union is present its sub-queries are
;; compiled and joined with " UNION "; otherwise a SELECT statement is
;; assembled from the individual clauses.
(defmulti compile-query (fn [db _] db))
(defmethod compile-query :default [db {:keys [tables fields joins where sort group having union]}]
  (if union
    (->> union
         (map (partial compile-query db))
         (apply sequence)
         ((p-lift string/join " UNION ")))
    ($str (return "SELECT ")
          (compile-fields db fields)
          (if tables
            ($str (return " FROM ") (compile-tables db joins tables))
            ;; Every other argument handed to $str in this file is a
            ;; writer value built with `return`; do the same here instead
            ;; of passing a raw string.
            (return ""))
          (compile-where db where)
          (compile-group db group)
          (compile-having db having)
          (compile-sort db sort))))

(defn compile
  "Compiles `query` for `db` and returns a vector whose first element is
  the SQL string and whose remaining elements are the values collected
  while compiling. The compiled query is run with an empty vector as its
  starting state."
  [db query]
  (let [[sql vars] ((compile-query db query) [])]
    (vec (cons sql vars))))

;; ==============================================================
;; A few DB specific overrides
;; ==============================================================

;; SQL SERVER
(defmethod field-name :sql-server [_ field]
  (str \[ (name field) \]))

(defmethod table-name :sql-server [_ table]
  (str \[ (name table) \]))

;; mySQL
(defmethod field-name :mysql [_ field]
  (str \` (name field) \`))

(defmethod table-name :mysql [_ table]
  (str \` (name table) \`))

;;(compile nil {:table {:u :u}, :fields {[:v :x] :w}})

;; Utility functions

;; NOTE(review): the three functions below are unimplemented stubs. Because
;; no body form follows the {:pre ...} map, Clojure treats the map as the
;; function's return value rather than as a condition map, so the
;; precondition is not enforced until a real body is added.

(defn insert! [db query & records]
  {:pre [(empty? (:joins query))]}
  ;; some code here
  )

(defn update! [db query & partial-records]
  {:pre [(empty? (:joins query))]}
  ;; some code here
  )

(defn delete! [db query]
  {:pre [(empty? (:joins query))]}
  ;; some code here
  )