(ns clojure-sql.compiler
  (:refer-clojure :exclude [compile sequence])
  (:require [clojure.set :as set]
            [clojure.string :as string]
            [clojure.walk :as walk]
            [clojure-sql.query :refer [query?]]
            [clojure-sql.util :as u :refer [named?]]
            [clojure-sql.writer :as w :refer [return lift p-lift sequence do-m tell >>]]))

(defn no-args-operator-error
  "Throws an ex-info saying that operator `op` was applied to no arguments
  and has no identity value to compile to instead. The ex-data is
  `{:operator op}`."
  [op]
  (throw (ex-info (str "Argument called with no args doesn't have identity value: " op)
                  {:operator op})))

(defn add-parentheses
  "Returns the string form of `s` wrapped in parentheses."
  [s]
  (str \( s \)))

;; ==============================================================
;; DB specific escaping methods
;; ==============================================================

(defmulti field-name
  "Quotes a field (column) name for the SQL dialect `db`. Dispatches on
  `db`; dialects without their own method fall back to :postgres."
  (fn [db _] db)
  :default :postgres)

;; Postgres wraps identifiers in double quotes.
;; NOTE(review): a `\"` inside the name is not doubled, so names are assumed
;; to come from code rather than user input.
(defmethod field-name :postgres [_ field]
  (str \" (name field) \"))

(defmulti table-name
  "Quotes a table name for the SQL dialect `db`. Dispatches on `db`;
  dialects without their own method fall back to :postgres."
  (fn [db _] db)
  :default :postgres)

(defmethod table-name :postgres [_ table]
  (str \" (name table) \"))

(defmulti function-name
  "Quotes a function name for the SQL dialect `db`. Dispatches on `db`;
  dialects without their own method fall back to :postgres."
  (fn [db _] db)
  :default :postgres)

(defmethod function-name :postgres [_ function]
  (str \" (name function) \"))

;; We use the $ prefix to denote a function lifted with clojure-sql.writer/lift;
;; every call site below applies these to writer values rather than strings.
(def $add-parentheses (lift add-parentheses))
(def $str (lift str))

;; ==============================================================
;; Utility functions for the compile-* functions
;; ==============================================================

(def ^:private boolean? (some-fn true? false?))
(def ^:private regex? (partial instance? (class #"")))

;; Maps the `$` operator to SQL's `~`; any other operator keeps its own name.
(def ^:private operator-name (some-fn (comp {"$" "~"} name) name))

;; Operator classification predicates: each compares the `name` of its
;; argument against a fixed set of operator names.
(def quote? (comp boolean '#{"quote"} name))
(def unary? (comp boolean '#{"not" "exists" "-"} name))
(def binary? (comp boolean '#{"=" "<" ">" "<=" ">=" "is" "in" "like" "~"} name))
(def n-ary? (comp boolean '#{"and" "or" "+" "-" "/" "*"} name))

;; Value an n-ary operator compiles to when given no arguments. "or" maps to
;; `false`, so callers must test the lookup result for nil, not truthiness.
(def operator-identity (comp {"and" true "or" false "+" 0 "*" 1} name))

;; ==============================================================
;; compile-* functions (turning a map into a query string)
;; ==============================================================

;; compile-* functions take (db, expr) and return a writer: a function from
;; the accumulated argument vector to [sql args] (compile-select runs one
;; with an initial `[]`). `tell` records a bound parameter, and the SQL
;; text uses `?` in its place.

(declare compile-query compile-expression)

(defmulti compile-expression-list
  "Compiles a sequence of expressions into a parenthesised, comma-separated
  list. Used when the argument of a `quote` form is sequential."
  (fn [db _] db)
  :default :postgres)

(defmethod compile-expression-list :postgres [db ex]
  (->> (map (partial compile-expression db) ex)
       (apply sequence)
       ((p-lift string/join ","))
       $add-parentheses))

(defmulti compile-expression-sequential
  "Compiles a list form `(op arg ...)`. The result is always wrapped in
  parentheses."
  (fn [db _] db)
  :default :postgres)

;; Dispatch on the operator's name:
;;   quote  - `(quote x)`: a sequential x becomes a list, anything else a `?` parameter
;;   n-ary  - args joined infix by the operator; a single arg of an operator that
;;            is also unary (i.e. `-`) gets a prefix; no args compile to the
;;            operator's identity value (or throw when it has none)
;;   unary  - `(not x)` / `(exists x)`: the operator followed by its argument
;;   binary - `left op right`
;;   other  - a function call: "fn"(arg,arg,...)
(defmethod compile-expression-sequential :postgres [db ex]
  (let [compile-exprs #(map (partial compile-expression db) %)
        op-name (operator-name (first ex))
        num-args (dec (count ex))]
    (-> (condp u/funcall op-name
          quote?
          (do (assert (= num-args 1) "`quote` must only take one argument")
              (if (sequential? (second ex))
                (compile-expression-list db (second ex))
                (>> (tell (second ex)) (return "?"))))

          n-ary?
          (do-m :let [[op & exprs] (compile-exprs ex)]
                vals <- (apply sequence (interpose op exprs))
                (condp = (count vals)
                  0 (let [id (operator-identity op-name)]
                      ;; `or` has identity `false`, so test for nil explicitly
                      ;; instead of relying on truthiness.
                      (if (nil? id)
                        (no-args-operator-error (name (first ex)))
                        (compile-expression db id)))
                  1 (if (unary? op-name)
                      ;; The space keeps e.g. (- -1) from becoming `--1`,
                      ;; which SQL would read as the start of a comment.
                      (do-m compiled-op <- op
                            (return (str compiled-op " " (first vals))))
                      (return (first vals)))
                  (return (string/join " " vals))))

          unary?
          (do (assert (= num-args 1) (str "Unary operator `" (first ex) "` must take one argument"))
              (do-m :let [exprs (compile-exprs ex)]
                    vals <- (apply sequence exprs)
                    ;; The space keeps keyword operators apart from their operand
                    ;; (otherwise (not true) would compile to the identifier NOTTRUE).
                    (return (string/join " " vals))))

          binary?
          (do (assert (= num-args 2) (str "Binary operator `" (first ex) "` must take two arguments"))
              (do-m :let [[op left right] (compile-exprs ex)]
                    vals <- (sequence left op right)
                    (return (string/join " " vals))))

          (do-m :let [fn-name (function-name db (first ex))
                      exprs (compile-exprs (rest ex))]
                vals <- (apply sequence exprs)
                (return (str fn-name (add-parentheses (string/join "," vals))))))
        $add-parentheses)))

(defmulti compile-expression
  "Compiles a single expression to a writer producing its SQL text."
  (fn [db _] db)
  :default :postgres)

;; Booleans and nil become SQL literals, [table field] vectors qualified names,
;; keywords quoted field names, strings and regexes `?` parameters, symbols
;; upper-cased operator names, queries parenthesised subqueries, and other
;; sequential forms go through compile-expression-sequential. Anything else
;; (e.g. numbers) is returned unchanged.
(defmethod compile-expression :postgres [db ex]
  (condp u/funcall ex
    boolean? (return (string/upper-case (str ex)))
    query? ($add-parentheses (compile-query db ex))
    nil? (return "NULL")
    vector? (return (str (table-name db (first ex)) \. (field-name db (second ex))))
    keyword? (return (field-name db ex))
    regex? (>> (tell (str ex)) (return "?"))
    string? (>> (tell ex) (return "?")) ;;(sql-string db ex)
    symbol? (return (string/upper-case (operator-name ex)))
    sequential? (compile-expression-sequential db ex)
    (return ex)))

(defn ^:private make-table-name
  "Compiles a table reference. With no alias, or an alias equal to `table`,
  this is just the quoted table name. Otherwise `table` (a subquery, a named
  table or any other expression) is followed by ` AS ` and the quoted alias."
  [db table & [alias]]
  (if (or (= table alias) (nil? alias))
    (return (table-name db table))
    ($str (condp #(%1 %2) table
            query? ($add-parentheses (compile-query db table))
            named? (return (table-name db table))
            (compile-expression db table))
          (return " AS ")
          (return (table-name db alias)))))

(defn ^:private make-field-name
  "Compiles a field reference. A [table field] vector with no alias, or whose
  alias equals the vector itself, compiles to the qualified name alone.
  Otherwise the compiled `field` is followed by ` AS ` and the quoted alias."
  [db field & [alias]]
  (if (and (vector? field) (or (= field alias) (nil? alias)))
    (compile-expression db field)
    ($str (compile-expression db field)
          (return " AS ")
          (return (field-name db alias)))))

(defmulti compile-fields
  "Compiles the SELECT list from a map of alias -> field expression, or `*`
  when the map is empty."
  (fn [db _] db)
  :default :postgres)

;; Fields are emitted in alias order (sort-by first).
(defmethod compile-fields :postgres [db fields-map]
  (if (seq fields-map)
    (->> (for [[alias field] (sort-by first fields-map)]
           (make-field-name db field alias))
         (apply sequence)
         ((p-lift string/join ", ")))
    (return "*")))

;; SQL keywords for join :type values. A type not listed here is emitted as
;; (name type); :cross is compiled by its own branch in compile-tables.
(def ^:private join-type-names
  {:inner "INNER"
   :outer "LEFT OUTER"
   :full-outer "FULL OUTER"
   :cross "CROSS"})

(defmulti compile-tables
  "Compiles the table/join tree `join` against `tables-map` (alias -> table)."
  (fn [db _ _] db)
  :default :postgres)

;; `join` is either a vector of aliases (comma-separated tables) or a map
;; {:left :right :type :on} whose sides are compiled recursively. A non-cross
;; join without :on uses `ON TRUE`.
(defmethod compile-tables :postgres [db join tables-map]
  (if (vector? join)
    (->> (for [table-alias join]
           (make-table-name db (get tables-map table-alias) table-alias))
         (apply sequence)
         ((p-lift string/join ", ")))
    (let [{:keys [left right type on]} join]
      (if (= type :cross)
        ($str (return "(")
              (compile-tables db left tables-map)
              (return " CROSS JOIN ")
              (compile-tables db right tables-map)
              (return ")"))
        ($str (return "(")
              (compile-tables db left tables-map)
              (return (str " " (get join-type-names type (name type)) " JOIN "))
              (compile-tables db right tables-map)
              (return " ON ")
              (if on
                (compile-expression db on)
                (return "TRUE"))
              (return ")"))))))

(defmulti compile-where
  "Compiles ` WHERE expr`, or a nil fragment when `expr` is nil."
  (fn [db _] db)
  :default :postgres)

(defmethod compile-where :postgres [db expr]
  (if expr
    ($str (return " WHERE ") (compile-expression db expr))
    (return nil)))

(defmulti compile-sort
  "Compiles ` ORDER BY ...` from a sequence of [[table field] direction]
  pairs (direction's name is upper-cased), or a nil fragment when `fields`
  is nil."
  (fn [db _] db)
  :default :postgres)

(defmethod compile-sort :postgres [db fields]
  (if fields
    (->> (for [[[table field] dir] fields]
           ($str (make-field-name db [table field])
                 (return (str \space (string/upper-case (name dir))))))
         (apply sequence)
         ((p-lift string/join ","))
         ($str (return " ORDER BY ")))
    (return nil)))

(defmulti compile-group
  "Compiles ` GROUP BY ...` from a sequence of [table field] pairs, or a nil
  fragment when `fields` is nil."
  (fn [db _] db)
  :default :postgres)

(defmethod compile-group :postgres [db fields]
  (if fields
    (->> (for [[table field] fields]
           (make-field-name db [table field]))
         (apply sequence)
         ((p-lift string/join ","))
         ($str (return " GROUP BY ")))
    (return nil)))

(defmulti compile-having
  "Compiles ` HAVING expr`, or a nil fragment when `expr` is nil."
  (fn [db _] db)
  :default :postgres)

(defmethod compile-having :postgres [db expr]
  (if expr
    ($str (return " HAVING ") (compile-expression db expr))
    (return nil)))

(defmulti compile-limit
  "Compiles ` LIMIT take` and/or ` OFFSET drop`, each only when present."
  (fn [db _ _] db)
  :default :postgres)

(defmethod compile-limit :postgres [db take drop]
  (return (str (if take (str " LIMIT " take))
               (if drop (str " OFFSET " drop)))))

(def ^:private set-operations
  {:union "UNION", :intersect "INTERSECT"})

(defmulti compile-query
  "Compiles a query map to a writer producing its SELECT statement."
  (fn [db _] db)
  :default :postgres)

;; A query with :set-operation compiles each of its :queries and combines them
;; as (q1) UNION (q2) ...; otherwise a SELECT is assembled clause by clause.
(defmethod compile-query :postgres [db {:keys [tables fields joins where sort group having take drop set-operation queries]}]
  (if set-operation
    (let [op-str (str ") " (get set-operations set-operation) " (")]
      (->> queries
           (map (partial compile-query db))
           (apply sequence)
           ((p-lift string/join op-str))
           $add-parentheses))
    ($str (return "SELECT ")
          (compile-fields db fields)
          ;; $str combines writer values, so the "no FROM" case must be a
          ;; writer too (a nil fragment), not a bare string.
          (if tables
            ($str (return " FROM ") (compile-tables db joins tables))
            (return nil))
          (compile-where db where)
          (compile-group db group)
          (compile-having db having)
          (compile-sort db sort)
          (compile-limit db take drop))))

(defn compile-select
  "Compiles `query` for `db` into a vector of the SQL string followed by its
  bound parameters, e.g. [\"SELECT ...\" param1 param2]."
  [db query]
  (let [[sql vars] ((compile-query db query) [])]
    (vec (cons sql vars))))

(defn compile-insert
  "Compiles an INSERT of `records` into the single table of the query.
  Each record is a map keyed by the query's field aliases; its values are
  compiled as expressions (so strings become `?` parameters). Returns a
  vector of the SQL string followed by its bound parameters.

  Asserts that the query has exactly one table and that at least one
  record is given."
  [db {:keys [fields tables]} & records]
  (assert (= (count tables) 1) "Cannot insert into a multiple-table query")
  ;; With no records the statement would end in an invalid `VALUES ()`.
  (assert (seq records) "Must insert at least one record")
  (let [fields-order (map key fields)
        ;; Column list and per-record value lists both iterate `fields` in the
        ;; same order, so values line up with columns.
        wrap #(str "INSERT INTO " (table-name db (val (first tables)))
                   " (" (->> fields
                             (map (comp (partial field-name db) second val))
                             (string/join ","))
                   ") VALUES (" (string/join "),(" %) ")")
        build-insertion #(->> (for [field fields-order]
                                (compile-expression db (get % field)))
                              (apply sequence)
                              ((p-lift string/join ",")))
        insertions (->> (for [record records]
                          (build-insertion record))
                        (apply sequence))
        [sql vars] (-> ((p-lift wrap) insertions)
                       (u/funcall []))]
    (vec (cons sql vars))))

;; (insert!
;;          nil
;;          (-> (table :users)
;;              (project {:id :uid, :username :name}))
;;          {:uid 10, :name :carlo, :other :stuff}
;;          {:uid 1, :name :carl, :other :stuf})

(defn compile-update
  "Compiles an UPDATE of the single table of the query. `partial-record` maps
  field aliases to new value expressions; entries whose alias is not a field
  of the query are ignored. Table qualifiers ([table field] vectors) are
  stripped from the WHERE clause and the values, leaving bare field names.
  Returns a vector of the SQL string followed by its bound parameters.

  Asserts that the query has exactly one table and that `partial-record`
  updates at least one of its fields."
  [db {:keys [tables fields where]} partial-record]
  (assert (= (count tables) 1) "Cannot update a multiple-table query")
  (assert (seq (set/intersection (set (keys partial-record)) (set (keys fields))))
          "At least one field must be being updated")
  (let [strip-table (partial walk/prewalk (fn [x] (if (vector? x) (second x) x)))
        where-expression (if where
                           ($str (return " WHERE ") (compile-expression db (strip-table where)))
                           (return nil))
        updates (->> (for [[alias value] partial-record
                           :when (fields alias)]
                       ($str (compile-expression db (second (fields alias)))
                             (return " = ")
                             (compile-expression db (strip-table value))))
                     (apply sequence))
        ;; Quote the table through the dialect, like the INSERT and DELETE
        ;; compilers do; a bare (name ...) would break reserved or
        ;; mixed-case table names.
        table (table-name db (val (first tables)))
        combined ($str (return (str "UPDATE " table " SET "))
                       ((p-lift string/join ", ") updates)
                       where-expression)
        [sql vars] (combined [])]
    (vec (cons sql vars))))

;; (update! nil
;;          (-> (table :users)
;;              (project {:username :blah})
;;              (select `(= :id 10)))
;;          {:username "carlozancnaro"
;;           :blah "blah"})

(defn compile-delete
  "Compiles a DELETE from the single table of the query, restricted by its
  WHERE clause when present. Returns a vector of the SQL string followed by
  its bound parameters. Asserts that the query has exactly one table."
  [db {:keys [tables where joins]}]
  (assert (= (count tables) 1) "Cannot delete from a multiple-table query")
  (let [table (compile-tables db joins tables)
        where-expression (if where
                           ($str (return " WHERE ") (compile-expression db where))
                           (return nil))
        combined ($str (return "DELETE FROM ") table where-expression)
        [sql vars] (combined [])]
    (vec (cons sql vars))))

;; (delete!
;;          nil (-> (table :users)
;;                  (select `(= :id 10))))

;; ==============================================================
;; A few DB specific overrides
;; ==============================================================

;; SQL SERVER: identifiers are wrapped in square brackets.
(defmethod field-name :sql-server [_ field]
  (format "[%s]" (name field)))

(defmethod table-name :sql-server [_ table]
  (format "[%s]" (name table)))

;; mySQL: identifiers are wrapped in backticks.
(defmethod field-name :mysql [_ field]
  (format "`%s`" (name field)))

(defmethod table-name :mysql [_ table]
  (format "`%s`" (name table)))