summaryrefslogtreecommitdiff
path: root/src/clojure_sql/compiler.clj
blob: 50c14af68be031a2a7196692d133d133bcfe73c7 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
(ns clojure-sql.compiler
  (:refer-clojure :exclude [compile sequence])
  (:require [clojure.set :as set]
            [clojure.string :as string]
            [clojure.walk :as walk]
            [clojure-sql.query :refer [query?]]
            [clojure-sql.util :as u :refer [named?]]
            [clojure-sql.writer :as w :refer [return lift p-lift sequence do-m tell >>]]))

(defn no-args-operator-error
  "Throws an ex-info reporting that operator `op` was called with no arguments
   but has no identity value to substitute for the empty argument list.
   The ex-data is {:operator op}."
  [op]
  (throw (ex-info (str "Operator called with no arguments has no identity value: " op)
                  {:operator op})))



(defn add-parentheses
  "Returns the string form of `s` enclosed in a pair of round brackets."
  [s]
  (str "(" s ")"))


;; ==============================================================
;; DB specific escaping methods
;; ==============================================================

(defmulti field-name
  "Quotes a field (column) name for the dialect `db`; unknown dialects use
   the :postgres method."
  (fn [db _field] db)
  :default :postgres)
(defmethod field-name :postgres [_ field]
  (format "\"%s\"" (name field)))

(defmulti table-name
  "Quotes a table name for the dialect `db`; unknown dialects use the
   :postgres method."
  (fn [db _table] db)
  :default :postgres)
(defmethod table-name :postgres [_ table]
  (format "\"%s\"" (name table)))

(defmulti function-name
  "Quotes a SQL function name for the dialect `db`; unknown dialects use the
   :postgres method."
  (fn [db _function] db)
  :default :postgres)
(defmethod function-name :postgres [_ function]
  (format "\"%s\"" (name function)))




;; we use the $ prefix to denote a lifted function:
;; `lift` (from clojure-sql.writer) adapts a plain function so that it is
;; applied to writer values (such as those built with `return`) and yields a
;; writer, e.g. ($str (return "a") (return "b")).
;; NOTE(review): the exact lifting semantics live in clojure-sql.writer --
;; confirm there before relying on details.
(def $add-parentheses (lift add-parentheses))
(def $str (lift str))




;; ==============================================================
;; Utility functions for the compile-* functions
;; ==============================================================

(defn- boolean?
  "True only for the literal values true and false."
  [x]
  (or (true? x) (false? x)))

(defn- regex?
  "True when `x` is a compiled regular expression."
  [x]
  (instance? java.util.regex.Pattern x))

(defn- operator-name
  "Name of operator `op` as a string, translating the `$` alias to `~`."
  [op]
  (let [n (name op)]
    (get {"$" "~"} n n)))

(defn quote?
  "True when `op` names the `quote` form."
  [op]
  (= "quote" (name op)))

(defn unary?
  "True for operators that can take a single operand."
  [op]
  (contains? #{"not" "exists" "-"} (name op)))

(defn binary?
  "True for operators that take exactly two operands."
  [op]
  (contains? #{"=" "<" ">" "<=" ">=" "is" "in" "like" "~"} (name op)))

(defn n-ary?
  "True for operators that take any number of operands."
  [op]
  (contains? #{"and" "or" "+" "-" "/" "*"} (name op)))

(defn operator-identity
  "Identity value of an n-ary operator, used when it is called with no
   operands. Returns nil when the operator has none; note that the identity
   of `or` is false, so callers must test for nil rather than truthiness."
  [op]
  (get {"and" true
        "or"  false
        "+"   0
        "*"   1}
       (name op)))


;; ==============================================================
;; compile-* functions (turning a map into a query string)
;; ==============================================================

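;; compile-* multimethods are of the signature:
;;  (db, expr ...) -> ([args] -> [sql args])
;; i.e. they return a writer: a function that takes the vector of bound
;; arguments collected so far and returns the SQL fragment paired with the
;; updated argument vector (compile-select runs one with [] and destructures
;; the result as [sql vars]).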

(declare compile-query compile-expression)

(defmulti compile-expression-list (fn [db _] db) :default :postgres)
;; Compiles every element of `ex` and renders the results as a parenthesised,
;; comma-separated list (used for quoted sequential literals).
(defmethod compile-expression-list :postgres [db ex]
  (let [items (for [item ex] (compile-expression db item))
        joined ((p-lift string/join ",") (apply sequence items))]
    ($add-parentheses joined)))

(defmulti compile-expression-sequential (fn [db _] db) :default :postgres)
;; Compiles a list expression `(op arg ...)`:
;;  - `(quote x)`  -> a bound parameter, or a parenthesised list when x is sequential
;;  - n-ary ops    -> operands joined by the operator (identity value when empty)
;;  - unary ops    -> operator prefixed to its single operand
;;  - binary ops   -> `left op right`
;;  - anything else is treated as a SQL function call `"fn"(args,...)`.
;; The whole result is wrapped in parentheses.
(defmethod compile-expression-sequential :postgres [db ex]
  (let [compile-exprs #(map (partial compile-expression db) %)
        op-name (operator-name (first ex))
        num-args (dec (count ex))]
    (-> (condp u/funcall op-name
          quote?    (do (assert (= num-args 1) "`quote` must only take one argument")
                        (if (sequential? (second ex))
                          (compile-expression-list db (second ex))
                          (>> (tell (second ex)) (return "?"))))
          n-ary?    (do-m :let [[op & exprs] (compile-exprs ex)]
                          vals <- (apply sequence (interpose op exprs))
                          (condp = (count vals)
                            ;; No operands: substitute the operator's identity.
                            ;; `false` is a legitimate identity (for `or`), so
                            ;; test for nil rather than truthiness -- an
                            ;; `if-let` here would wrongly reject `(or)`.
                            0 (let [id (operator-identity op-name)]
                                (if (nil? id)
                                  (no-args-operator-error (name (first ex)))
                                  (compile-expression db id)))
                            ;; One operand: operators that are also unary
                            ;; (e.g. `-`) become a prefix; others yield the
                            ;; operand unchanged.
                            1 (if (unary? op-name)
                                (do-m compiled-op <- op
                                      (return (str compiled-op (first vals))))
                                (return (first vals)))
                            (return (string/join " " vals))))
          unary?    (do (assert (= num-args 1) (str "Unary operator `" (first ex) "` must take one argument"))
                        (do-m :let [exprs (compile-exprs ex)]
                              vals <- (apply sequence exprs)
                              (return (string/join "" vals))))
          binary?   (do (assert (= num-args 2) (str "Binary operator `" (first ex) "` must take two arguments"))
                        (do-m :let [[op left right] (compile-exprs ex)]
                              vals <- (sequence left op right)
                              (return (string/join " " vals))))
          (do-m :let [fn-name (function-name db (first ex))
                      exprs (compile-exprs (rest ex))]
                vals <- (apply sequence exprs)
                (return (str fn-name
                             (add-parentheses (string/join "," vals))))))
        $add-parentheses)))

(defmulti compile-expression (fn [db _] db) :default :postgres)
;; Compiles a single expression into a writer. Literal values that are not
;; otherwise recognised (and regexes, as strings) become `?` placeholders
;; whose values are recorded with `tell`; vectors are [table field]
;; references and keywords are bare field names.
(defmethod compile-expression :postgres [db ex]
  (cond
    (boolean? ex)    (return (string/upper-case (str ex)))
    (query? ex)      ($add-parentheses (compile-query db ex))
    (nil? ex)        (return "NULL")
    (vector? ex)     (let [[tbl fld] [(first ex) (second ex)]]
                       (return (str (table-name db tbl) \. (field-name db fld))))
    (keyword? ex)    (return (field-name db ex))
    (regex? ex)      (>> (tell (str ex)) (return "?"))
    (symbol? ex)     (return (string/upper-case (operator-name ex)))
    (sequential? ex) (compile-expression-sequential db ex)
    :else            (>> (tell ex) (return "?"))))

(defn ^:private make-table-name
  "Renders a FROM-clause table: the quoted table name alone when `alias` is
   nil or equal to `table`, otherwise `<table-expr> AS <alias>`, where a
   query table is compiled as a parenthesised sub-select."
  [db table & [alias]]
  (if (and (not (nil? alias)) (not= table alias))
    ($str (cond
            (query? table) ($add-parentheses (compile-query db table))
            (named? table) (return (table-name db table))
            :else          (compile-expression db table))
          (return " AS ")
          (return (table-name db alias)))
    (return (table-name db table))))

(defn ^:private make-field-name
  "Compiles a select-list or group-by field expression. Appends
   ` AS <alias>` unless `alias` is nil, or `field` is a [table field]
   vector equal to `alias`."
  [db field & [alias]]
  ;; A nil alias must never reach `field-name`: (name nil) would throw for
  ;; non-vector field expressions called without an alias.
  (if (or (nil? alias)
          (and (vector? field) (= field alias)))
    (compile-expression db field)
    ($str (compile-expression db field)
          (return " AS ")
          (return (field-name db alias)))))

(defmulti compile-fields (fn [db _] db) :default :postgres)
;; Compiles the SELECT list from an alias->field map, ordered by alias;
;; a nil or empty map selects "*".
(defmethod compile-fields :postgres [db fields-map]
  (if (empty? fields-map)
    (return "*")
    (let [ordered (sort-by first fields-map)
          columns (map (fn [[alias field]] (make-field-name db field alias)) ordered)]
      ((p-lift string/join ", ") (apply sequence columns)))))

;; SQL keyword(s) emitted before JOIN for each join :type.
(def ^:private join-type-names
  {:cross      "CROSS"
   :full-outer "FULL OUTER"
   :inner      "INNER"
   :outer      "LEFT OUTER"})

(defmulti compile-tables (fn [db _ _] db) :default :postgres)
;; Compiles the FROM-clause table expression. `join` is either a vector of
;; table aliases (rendered comma-separated) or a join map with :left, :right,
;; :type and optional :on, rendered recursively as a parenthesised JOIN.
;; CROSS joins take no ON clause; other joins without :on use `ON TRUE`.
;; `tables-map` maps each alias to its table.
(defmethod compile-tables :postgres [db join tables-map]
  (if-not (vector? join)
    (let [{:keys [left right type on]} join
          cross? (= type :cross)]
      ($str (return "(")
            (compile-tables db left tables-map)
            (return (if cross?
                      " CROSS JOIN "
                      (str " " (get join-type-names type (name type)) " JOIN ")))
            (compile-tables db right tables-map)
            (cond
              cross? (return "")
              on     ($str (return " ON ") (compile-expression db on))
              :else  (return " ON TRUE"))
            (return ")")))
    (->> join
         (map (fn [alias] (make-table-name db (get tables-map alias) alias)))
         (apply sequence)
         ((p-lift string/join ", ")))))

(defmulti compile-where (fn [db _] db) :default :postgres)
;; Compiles ` WHERE <expr>`, or a nil-valued writer when there is no filter.
(defmethod compile-where :postgres [db expr]
  (if-not expr
    (return nil)
    ($str (return " WHERE ")
          (compile-expression db expr))))

(defmulti compile-sort (fn [db _] db) :default :postgres)
;; Compiles ` ORDER BY expr DIR,...` from a seq of [expr direction] pairs,
;; where direction is a keyword/symbol such as :asc or :desc. Returns a
;; nil-valued writer when there is nothing to sort by. `seq` (rather than
;; plain truthiness) ensures an empty collection does not emit a dangling
;; " ORDER BY ".
(defmethod compile-sort :postgres [db fields]
  (if (seq fields)
    (->> (for [[expr dir] fields]
           ($str (compile-expression db expr)
                 (return (str \space (string/upper-case (name dir))))))
         (apply sequence)
         ((p-lift string/join ","))
         ($str (return " ORDER BY ")))
    (return nil)))

(defmulti compile-group (fn [db _] db) :default :postgres)
;; Compiles ` GROUP BY ...` from a seq of [table field] pairs. Returns a
;; nil-valued writer when there is nothing to group by. `seq` (rather than
;; plain truthiness) ensures an empty collection does not emit a dangling
;; " GROUP BY ".
(defmethod compile-group :postgres [db fields]
  (if (seq fields)
    (->> (for [[table field] fields]
           (make-field-name db [table field]))
         (apply sequence)
         ((p-lift string/join ","))
         ($str (return " GROUP BY ")))
    (return nil)))

(defmulti compile-having (fn [db _] db) :default :postgres)
;; Compiles ` HAVING <expr>`, or a nil-valued writer when there is no
;; post-aggregation filter.
(defmethod compile-having :postgres [db expr]
  (if-not expr
    (return nil)
    ($str (return " HAVING ")
          (compile-expression db expr))))

(defmulti compile-limit (fn [db _ _] db) :default :postgres)
;; Compiles the optional ` LIMIT n` and ` OFFSET m` suffixes; yields an
;; empty string when neither is given.
(defmethod compile-limit :postgres [db take drop]
  (let [limit-part  (when take (str " LIMIT " take))
        offset-part (when drop (str " OFFSET " drop))]
    (return (str limit-part offset-part))))

;; SQL keyword placed between the sub-queries of a set-operation query.
(def ^:private set-operations
  {:intersect "INTERSECT"
   :union     "UNION"})

(defmulti compile-query (fn [db _] db) :default :postgres)
;; Compiles a query map into a writer producing SQL text.
;; A query with :set-operation is rendered as `(q1) UNION (q2) ...` (or
;; INTERSECT) over its :queries; otherwise a SELECT is assembled from the
;; fields, tables/joins, where, group, having, sort and take/drop parts.
(defmethod compile-query :postgres [db {:keys [tables fields joins where sort group having take drop set-operation queries]}]
  (if set-operation
    (let [op-str (str ") " (get set-operations set-operation) " (")]
      (->> queries
           (map (partial compile-query db))
           (apply sequence)
           ((p-lift string/join op-str))
           $add-parentheses))
    ($str (return "SELECT ")
          (compile-fields db fields)
          ;; Every $str argument must be a writer value; a bare "" here would
          ;; break table-less selects. `seq` also avoids a dangling " FROM "
          ;; for an empty tables map.
          (if (seq tables)
            ($str (return " FROM ")
                  (compile-tables db joins tables))
            (return ""))
          (compile-where db where)
          (compile-group db group)
          (compile-having db having)
          (compile-sort db sort)
          (compile-limit db take drop))))




(defn compile-select
  "Compiles `query` into a vector whose first element is the SQL string and
   whose remaining elements are the parameters collected during compilation."
  [db query]
  (let [run-writer (compile-query db query)
        [sql params] (run-writer [])]
    (into [sql] params)))



(defn compile-insert
  "Compiles a multi-row INSERT for a single-table query. Columns come from
   the query's alias->[table field] map; each record supplies values keyed by
   alias (missing aliases compile as NULL). Returns [sql & params]."
  [db {:keys [fields tables joins]} records]
  (assert (= (count tables) 1) "Cannot insert into a multiple-table query")
  (let [aliases (map key fields)
        render-row (fn [record]
                     ((p-lift string/join ",")
                      (apply sequence (map #(compile-expression db (get record %)) aliases))))
        all-rows (apply sequence (map render-row records))
        build-sql (fn [rows]
                    (str "INSERT INTO "
                         (table-name db (val (first tables)))
                         " ("
                         (string/join "," (map #(field-name db (second (val %))) fields))
                         ") VALUES (" (string/join "),(" rows) ")"))
        [sql params] (u/funcall ((p-lift build-sql) all-rows) [])]
    (into [sql] params)))

;; (insert! nil
;;          (-> (table :users)
;;              (project {:id :uid, :username :name}))
;;          {:uid 10, :name :carlo, :other :stuff}
;;          {:uid 1, :name :carl, :other :stuf})

(defn compile-update
  "Compiles an UPDATE for a single-table query. Only keys of `partial-record`
   that are aliases in the query's field map are set. [table field] vectors
   in the WHERE clause and in the new values are reduced to the bare field,
   since UPDATE names a single table; `quote` forms are left untouched so
   literal vectors inside them keep their elements. Returns [sql & params]."
  [db {:keys [tables fields where]} partial-record]
  (assert (= (count tables) 1) "Cannot update a multiple-table query")
  (assert (seq (set/intersection (set (keys partial-record))
                                 (set (keys fields)))) "At least one field must be being updated")
  (letfn [;; Strip table qualification from field references. Unlike a blind
          ;; prewalk, this does not descend into quoted literals (where
          ;; `'[1 2 3]` would otherwise collapse to `2`), nor into
          ;; maps/records, whose entries are also vectors.
          (unqualify [x]
            (cond
              (vector? x) (second x)
              (and (seq? x)
                   (symbol? (first x))
                   (= "quote" (name (first x)))) x
              (seq? x) (apply list (map unqualify x))
              :else x))]
    (let [where-expression (if where
                             ($str (return " WHERE ")
                                   (compile-expression db (unqualify where)))
                             (return nil))
          updates (->> (for [[alias value] partial-record
                             :when (fields alias)]
                         ($str (compile-expression db (second (fields alias)))
                               (return " = ")
                               (compile-expression db (unqualify value))))
                       (apply sequence))
          ;; Escape through the dialect's table-name, like INSERT and DELETE do.
          target (table-name db (val (first tables)))
          combined ($str (return (str "UPDATE " target " SET "))
                         ((p-lift string/join ", ") updates)
                         where-expression)
          [sql vars] (combined [])]
      (vec (cons sql vars)))))

;; (update! nil
;;          (-> (table :users)
;;              (project {:username :blah})
;;              (select `(= :id 10)))
;;          {:username "carlozancnaro"
;;           :blah "blah"})

(defn compile-delete
  "Compiles a DELETE for a single-table query, including its WHERE clause
   when present. Returns [sql & params]."
  [db {:keys [tables where joins]}]
  (assert (= (count tables) 1) "Cannot delete from a multiple-table query")
  (let [from-clause (compile-tables db joins tables)
        where-clause (if-not where
                       (return nil)
                       ($str (return " WHERE ")
                             (compile-expression db where)))
        statement ($str (return "DELETE FROM ") from-clause where-clause)
        [sql params] (statement [])]
    (into [sql] params)))

;; (delete! nil (-> (table :users)
;;                  (select `(= :id 10))))




;; ==============================================================
;; A few DB specific overrides
;; ==============================================================


;; SQL SERVER: identifiers are quoted with square brackets.
(defmethod field-name :sql-server [_ field]
  (format "[%s]" (name field)))

(defmethod table-name :sql-server [_ table]
  (format "[%s]" (name table)))

;; mySQL: identifiers are quoted with backticks.
(defmethod field-name :mysql [_ field]
  (format "`%s`" (name field)))

(defmethod table-name :mysql [_ table]
  (format "`%s`" (name table)))