(ns clojure-sql.compiler
  (:refer-clojure :exclude [compile sequence])
  (:require [clojure.string :as string]
            [clojure-sql.query :refer [query?]]
            [clojure-sql.util :as u :refer [named?]]
            [clojure-sql.writer :as w :refer [return lift p-lift sequence do-m tell >>]]))

(defn add-parentheses
  "Wraps the string form of s in a pair of parentheses."
  [s]
  (str \( s \)))

;; ==============================================================
;; DB specific escaping methods
;; ==============================================================

;; Each multimethod dispatches on the db value alone, so a database can
;; override only the quoting rules that differ from the defaults.
;; NOTE(review): the identifier is wrapped in quote characters but any quote
;; character inside the name itself is not escaped.

(defmulti field-name
  "Returns the quoted SQL identifier for a field. Dispatches on db."
  (fn [db _] db))
(defmethod field-name :default [_ field]
  (str \" (name field) \"))

(defmulti table-name
  "Returns the quoted SQL identifier for a table. Dispatches on db."
  (fn [db _] db))
(defmethod table-name :default [_ table]
  (str \" (name table) \"))

(defmulti function-name
  "Returns the quoted SQL identifier for a function. Dispatches on db."
  (fn [db _] db))
(defmethod function-name :default [_ function]
  (str \" (name function) \"))

;; we use the $ prefix to denote a lifted function
(def $add-parentheses (lift add-parentheses))
(def $str (lift str))

;; ==============================================================
;; Utility functions for the compile-* functions
;; ==============================================================

;; Each predicate takes the head of an expression form (anything `name`
;; accepts) and reports whether its name belongs to the given operator class.
(def quote? (comp boolean '#{"quote"} name))
(def unary? (comp boolean '#{"not" "exists"} name))
(def binary? (comp boolean '#{"=" "<" ">" "<=" ">=" "is" "in"} name))
(def operator? (comp boolean '#{"and" "or" "+" "-" "/" "*"} name))

;; ==============================================================
;; compile-* functions (turning a map into a query string)
;; ==============================================================

;; compile-* multimethods are of the signature:
;;  (db, expr) -> [args] -> [sql & args]
;; i.e. each returns a writer value; `compile` below runs the top-level
;; writer against an empty argument vector.

(declare compile-query compile-expression)

(defmulti compile-expression-list
  "Compiles a collection of expressions into a parenthesised,
  comma-separated list."
  (fn [db _] db))
(defmethod compile-expression-list :default [db ex]
  (->> (map (partial compile-expression db) ex)
       (apply sequence)
       ((p-lift string/join ","))
       $add-parentheses))

(defmulti compile-expression-sequential
  "Compiles a sequential (operator-first) expression form. The head of the
  form selects the shape: quote, unary operator, binary operator, variadic
  infix operator, or (the fallback) a function call. The result is always
  wrapped in parentheses."
  (fn [db _] db))
(defmethod compile-expression-sequential :default [db ex]
  (let [compile-exprs #(map (partial compile-expression db) %)
        op (name (first ex))
        num-args (dec (count ex))]
    (-> (condp u/funcall (first ex)
          ;; NOTE(review): a quoted list is parenthesised by
          ;; compile-expression-list and again by the outer
          ;; $add-parentheses -- confirm the doubled parentheses are intended.
          quote? (do (assert (= num-args 1) "`quote` must only take one argument")
                     (if (sequential? (second ex))
                       (compile-expression-list db (second ex))
                       (>> (tell (second ex)) (return "?"))))
          ;; The operator and its operand are joined with a space: joining
          ;; with the empty string fused them into a single token for bare
          ;; operands, e.g. (not true) became NOTtrue.
          unary? (do (assert (= num-args 1) (str "Unary operator `" op "` must take one argument"))
                     (do-m :let [exprs (compile-exprs ex)]
                           vals <- (apply sequence exprs)
                           (return (string/join " " vals))))
          ;; The compiled operator is moved from prefix to infix position.
          binary? (do (assert (= num-args 2) (str "Binary operator `" op "` must take two arguments"))
                      (do-m :let [[op left right] (compile-exprs ex)]
                            vals <- (sequence left op right)
                            (return (string/join " " vals))))
          ;; Variadic operators: the operator is interposed between operands.
          operator? (do-m :let [[op & exprs] (compile-exprs ex)]
                          vals <- (apply sequence (interpose op exprs))
                          (return (string/join " " vals)))
          ;; Anything else is compiled as a function call: name(arg,arg,...)
          (do-m :let [fn-name (function-name db (first ex))
                      exprs (compile-exprs (rest ex))]
                vals <- (apply sequence exprs)
                (return (str fn-name (add-parentheses (string/join "," vals))))))
        $add-parentheses)))

(defmulti compile-expression
  "Compiles a single expression. Clauses are tried in order:
  a query compiles to a parenthesised subquery, nil to NULL, a vector to a
  qualified table.field name, a keyword to a field name, a string to a ?
  placeholder (the string itself is passed to tell), a symbol to its
  upper-cased name, any other sequential to an operator/function form, and
  everything else is returned as-is."
  (fn [db _] db))
(defmethod compile-expression :default [db ex]
  (condp u/funcall ex
    query? ($add-parentheses (compile-query db ex))
    nil? (return "NULL")
    vector? (return (str (table-name db (first ex)) \. (field-name db (second ex))))
    keyword? (return (field-name db ex))
    string? (>> (tell ex) (return "?")) ;;(sql-string db ex)
    symbol? (return (string/upper-case (name ex)))
    sequential? (compile-expression-sequential db ex)
    (return ex)))

(defn ^:private make-table-name
  "Compiles a table reference. When no alias is given, or the alias equals
  the table, only the quoted table name is emitted; otherwise the table
  (a subquery, a named table, or an arbitrary expression) is followed by
  AS and the quoted alias."
  [db table & [alias]]
  (if (or (= table alias) (nil? alias))
    (return (table-name db table))
    ($str (condp #(%1 %2) table
            query? ($add-parentheses (compile-query db table))
            named? (return (table-name db table))
            (compile-expression db table))
          (return " AS ")
          (return (table-name db alias)))))

(defn ^:private make-field-name
  "Compiles a field reference. A vector field with no alias (or an alias
  equal to the field) is emitted bare; every other case is emitted as the
  compiled field followed by AS and the quoted alias."
  [db field & [alias]]
  (if (and (vector? field) (or (= field alias) (nil? alias)))
    (compile-expression db field)
    ($str (compile-expression db field)
          (return " AS ")
          (return (field-name db alias)))))

(defmulti compile-fields
  "Compiles the select list from a map of alias -> field. An empty or nil
  map compiles to *."
  (fn [db _] db))
(defmethod compile-fields :default [db fields-map]
  (if (seq fields-map)
    (->> (for [[alias field] fields-map]
           (make-field-name db field alias))
         (apply sequence)
         ((p-lift string/join ", ")))
    (return "*")))

;; Join types with a dedicated SQL spelling; any other type falls back to
;; its own name in compile-tables.
(def ^:private join-type-names
  {:inner "INNER"
   :outer "LEFT OUTER"
   :full-outer "FULL OUTER"
   :cross "CROSS"})

(defmulti compile-tables
  "Compiles the FROM clause body. join is either a vector of table aliases
  (looked up in tables-map and emitted comma-separated) or a map with
  :left, :right, :type and :on keys describing a join tree, compiled
  recursively."
  (fn [db _ _] db))
(defmethod compile-tables :default [db join tables-map]
  (if (vector? join)
    (->> (for [table-alias join]
           (make-table-name db (get tables-map table-alias) table-alias))
         (apply sequence)
         ((p-lift string/join ", ")))
    (let [{:keys [left right type on]} join]
      (if (= type :cross)
        ;; A cross join takes no ON clause.
        ($str (return "(")
              (compile-tables db left tables-map)
              (return " CROSS JOIN ")
              (compile-tables db right tables-map)
              (return ")"))
        ($str (return "(")
              (compile-tables db left tables-map)
              (return (str " " (get join-type-names type (name type)) " JOIN "))
              (compile-tables db right tables-map)
              (return " ON ")
              ;; A join without a condition is joined ON TRUE.
              (if on
                (compile-expression db on)
                (return "TRUE"))
              (return ")"))))))

(defmulti compile-where
  "Compiles the WHERE clause, or yields nil when there is no expression."
  (fn [db _] db))
(defmethod compile-where :default [db expr]
  (if expr
    ($str (return " WHERE ") (compile-expression db expr))
    (return nil)))

(defmulti compile-sort
  "Compiles the ORDER BY clause from a collection of [[table field] dir]
  entries. Yields nil when fields is nil or empty, so that no dangling
  ORDER BY is emitted."
  (fn [db _] db))
(defmethod compile-sort :default [db fields]
  (if (seq fields)
    (->> (for [[[table field] dir] fields]
           ($str (make-field-name db [table field])
                 (return (str \space (string/upper-case (name dir))))))
         (apply sequence)
         ((p-lift string/join ","))
         ($str (return " ORDER BY ")))
    (return nil)))

(defmulti compile-group
  "Compiles the GROUP BY clause from a collection of [table field] pairs.
  Yields nil when fields is nil or empty, so that no dangling GROUP BY is
  emitted."
  (fn [db _] db))
(defmethod compile-group :default [db fields]
  (if (seq fields)
    (->> (for [[table field] fields]
           (make-field-name db [table field]))
         (apply sequence)
         ((p-lift string/join ","))
         ($str (return " GROUP BY ")))
    (return nil)))

(defmulti compile-having
  "Compiles the HAVING clause, or yields nil when there is no expression."
  (fn [db _] db))
(defmethod compile-having :default [db expr]
  (if expr
    ($str (return " HAVING ") (compile-expression db expr))
    (return nil)))

(defmulti compile-query
  "Compiles a query map. A query with a :union key compiles each member
  and joins them with UNION; otherwise a SELECT statement is assembled
  from the fields, tables/joins, where, group, having and sort parts."
  (fn [db _] db))
(defmethod compile-query :default [db {:keys [tables fields joins where sort group having union]}]
  (if union
    (->> union
         (map (partial compile-query db))
         (apply sequence)
         ((p-lift string/join " UNION ")))
    ($str (return "SELECT ")
          (compile-fields db fields)
          (if tables
            ($str (return " FROM ") (compile-tables db joins tables))
            ;; Every argument to $str must be a writer value, so the empty
            ;; FROM clause is (return "") rather than a bare string.
            (return ""))
          (compile-where db where)
          (compile-group db group)
          (compile-having db having)
          (compile-sort db sort))))

(defn compile
  "Compiles query for db and returns a vector whose first element is the
  SQL string and whose remaining elements are the collected arguments."
  [db query]
  ;; The writer is run against an empty argument vector.
  (let [[sql vars] ((compile-query db query) [])]
    (into [sql] vars)))

;; ==============================================================
;; A few DB specific overrides
;; ==============================================================

;; SQL SERVER
(defmethod field-name :sql-server [_ field]
  (str \[ (name field) \]))

(defmethod table-name :sql-server [_ table]
  (str \[ (name table) \]))

;; mySQL
(defmethod field-name :mysql [_ field]
  (str \` (name field) \`))

(defmethod table-name :mysql [_ table]
  (str \` (name table) \`))

;;(compile nil {:table {:u :u}, :fields {[:v :x] :w}})

;; Utility functions

;; NOTE(review): these three functions are unimplemented stubs. Because the
;; {:pre ...} map is the only form in each body, Clojure evaluates it as the
;; return value rather than as a condition map, so the precondition is not
;; checked yet. It becomes a real precondition once a body form follows it.

(defn insert! [db query & records]
  {:pre [(empty? (:joins query))]}
  ;; some code here
  )

(defn update! [db query & partial-records]
  {:pre [(empty? (:joins query))]}
  ;; some code here
  )

(defn delete! [db query]
  {:pre [(empty? (:joins query))]}
  ;; some code here
  )