summaryrefslogtreecommitdiff
path: root/src/clojure_sql/core.clj
blob: ae0eb755cafa11a7c3f73cf3cb56a6e94d7e364b (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
(ns clojure-sql.core
  (:refer-clojure :exclude [sort-by])
  (:require [clojure.set :as set]
            [clojure.string :as string]
            [clojure-sql.util :as u]
            [clojure.walk]))

(declare compile-query)


(def ^:private ^:dynamic *database-type*
  "Dialect value handed to the compile-* multimethods when a Query is
  dereferenced (e.g. :mysql or :sql-server)."
  nil)

(defn set-database-type!
  "Replaces the root value of *database-type* with NEW-TYPE and returns it.
  Thread-local `binding`s of the var still take precedence."
  [new-type]
  (alter-var-root (var *database-type*) (fn [_] new-type)))

(def ^:private ^:dynamic *query-deref-behaviour*
  "Function applied to a Query when it is dereferenced (@query). By default it
  compiles the query for the current *database-type*."
  (fn [query] (compile-query *database-type* query)))

(defn set-query-deref-behaviour!
  "Replaces the root value of *query-deref-behaviour* with F and returns it."
  [f]
  (alter-var-root (var *query-deref-behaviour*) (fn [_] f)))

;; A query is an (initially field-less) record that accumulates keys such as
;; :table, :fields, :joins, :where and :sort-by via the DSL functions.
(defrecord ^:private Query []
  clojure.lang.IDeref
  (deref [query]
    ;; delegate to whatever behaviour is currently configured
    (*query-deref-behaviour* query)))

;; Printing a Query shows its compiled form. Note that it always compiles with
;; a nil database type, independent of *database-type*.
(defmethod print-method Query [query writer]
  (let [compiled (compile-query nil query)]
    (binding [*out* writer]
      (pr compiled))))




(defn add-parentheses
  "Returns S (stringified with `str`) wrapped in a pair of parentheses."
  [s]
  (str "(" s ")"))

;; ==============================================================
;; DB specific escaping methods
;; ==============================================================

(defmulti field-name
  "Quotes a field identifier for the given database type."
  (fn [db-type _] db-type))

;; Default: wrap the name in double quotes.
(defmethod field-name :default [_ ident]
  (str "\"" (name ident) "\""))

(defmulti table-name
  "Quotes a table identifier for the given database type."
  (fn [db-type _] db-type))

;; Default: wrap the name in double quotes.
(defmethod table-name :default [_ ident]
  (str "\"" (name ident) "\""))

(defmulti function-name
  "Renders the name of an SQL function call for the given database type."
  (fn [db-type _] db-type))

;; Default: wrap the name in double quotes, like other identifiers.
(defmethod function-name :default [_ ident]
  (str "\"" (name ident) "\""))





;; ==============================================================
;; Utility functions for the compile-* functions
;; ==============================================================

;; compile-* multimethods are of the signature:
;;  (db, expr) -> [SQL & replacements]

;; Operator classification used by compile-expression. Each predicate takes the
;; head of an expression form (anything `name` accepts: symbol, keyword,
;; string) and returns true or false.
(def is-unary?    (fn [op] (contains? #{"not"} (name op))))
(def is-binary?   (fn [op] (contains? #{"=" "<" ">" "<=" ">=" "is" "in" "+" "-" "/" "*"} (name op))))
(def is-operator? (fn [op] (contains? #{"and" "or"} (name op))))

;; A "compiled" value is a vector [sql & replacements]: the SQL text followed
;; by the values to bind to its "?" placeholders, in order.

(defn c-return
  "Wraps a plain SQL fragment as a compiled value with no replacements."
  [val]
  (vector val))

(defn c-lift
  "Lifts F, a function over SQL strings, to a function over compiled values.
  The returned fn calls F with ORIG-ARGS followed by the SQL text of each
  compiled argument, and appends all arguments' replacements in order."
  [f & orig-args]
  (fn [& c-args]
    (let [sql (apply f (concat orig-args (map first c-args)))]
      (into [sql] (mapcat rest c-args)))))

(def c-str (c-lift str))

(defn c-join
  "Joins the SQL text of the compiled values ARGS with SEP, concatenating
  their replacements in order."
  [sep args]
  (into [(string/join sep (map first args))]
        (mapcat rest args)))
;; Lifted `add-parentheses`: takes a single compiled [sql & replacements] value,
;; wraps its SQL text in parentheses and passes the replacements through.
(def c-add-parentheses (c-lift add-parentheses))



;; ==============================================================
;; compile-* functions (turning a map into a query string)
;; ==============================================================

(defmulti compile-expression (fn [db _] db))
(defmethod compile-expression :default [db ex]
  ;; Compiles one filter/join expression into [sql & replacements]:
  ;;   nil            -> NULL
  ;;   [table field]  -> qualified, quoted column reference
  ;;   keyword        -> quoted field name
  ;;   string         -> "?" placeholder; the string itself becomes a replacement
  ;;   symbol         -> its name, upper-cased, inlined into the SQL
  ;;   (op arg ...)   -> parenthesised: unary prefix (NOT x), binary infix
  ;;                     (a = b), n-ary infix (a AND b AND c), otherwise a
  ;;                     function call "fn"(a, b)
  ;;   anything else  -> returned as-is (e.g. numbers, booleans) and later
  ;;                     stringified by the c-* combinators
  (let [compile-arg (partial compile-expression db)
        ;; Operators are rendered directly from their name. Previously they
        ;; went through compile-expression, which only works for symbols:
        ;; the is-*? predicates also accept keywords and strings (via `name`),
        ;; but a keyword operator came out as a quoted field name and a string
        ;; operator as a "?" placeholder with the operator as a replacement.
        compile-op (fn [op] (c-return (string/upper-case (name op))))]
    (condp #(%1 %2) ex
      nil? (c-return "NULL")
      vector? (c-return (str (table-name db (first ex)) \. (field-name db (second ex))))
      keyword? (c-return (field-name db ex))
      string? ["?" ex] ;;(sql-string db ex)
      symbol? (c-return (string/upper-case (name ex)))
      sequential? (let [op (first ex)
                        args (rest ex)]
                    (-> (condp #(%1 %2) op
                          is-unary? (if (= (count args) 1)
                                      (c-join " " (cons (compile-op op) (map compile-arg args)))
                                      (throw (ex-info "Unary operators can only take one argument."
                                                      {:operator op
                                                       :arguments args})))
                          is-binary? (if (= (count args) 2)
                                       (c-join " " (interpose (compile-op op) (map compile-arg args)))
                                       (throw (ex-info "Binary operators must take two arguments."
                                                       {:operator op
                                                        :arguments args})))
                          is-operator? (c-join " " (interpose (compile-op op) (map compile-arg args)))
                          ;; anything else is treated as a function call
                          (c-str (c-return (function-name db op))
                                 (c-add-parentheses (c-join ", " (map compile-arg args)))))
                        c-add-parentheses))
      (c-return ex))))

(defn ^:private make-table-name
  "Quoted table name for a FROM/JOIN list, followed by AS <alias> when an
  alias is supplied and differs from the table name."
  [db table & [alias]]
  (let [quoted-table (table-name db table)]
    (if (and (not= table alias) (not (nil? alias)))
      (str quoted-table " AS " (field-name db alias))
      quoted-table)))

(defn ^:private make-field-name
  "Quoted, table-qualified column reference, followed by AS <alias> when an
  alias is supplied and differs from the field name."
  [db table field & [alias]]
  (let [qualified (str (table-name db table) \. (field-name db field))]
    (if (and (not= field alias) (not (nil? alias)))
      (str qualified " AS " (field-name db alias))
      qualified)))

(defmulti compile-fields (fn [db _] db))
(defmethod compile-fields :default [db fields-map]
  ;; fields-map is {[table_alias field] field_alias}; with no fields the
  ;; projection falls back to "*".
  (c-return
   (if (empty? fields-map)
     "*"
     (string/join ", " (map (fn [[[tbl fld] as-name]]
                              (make-field-name db tbl fld as-name))
                            fields-map)))))

(defmulti compile-tables (fn [db _] db))
(defmethod compile-tables :default [db tables-map]
  ;; tables-map is {table alias}; entries are rendered comma-separated.
  (let [rendered (map (fn [[tbl as-name]] (make-table-name db tbl as-name))
                      tables-map)]
    (c-return (string/join ", " rendered))))

(defmulti compile-join (fn [db _] db))
(defmethod compile-join :default [db [join-type table-map on]]
  ;; Renders one [type table-map on] entry as " <KIND> JOIN <tables> ON <expr>".
  ;; Any type other than :left / :right is rendered as an inner join.
  (let [kind (case join-type
               :left " LEFT OUTER"
               :right " RIGHT OUTER"
               " INNER")]
    (c-str (c-return (str kind " JOIN "))
           (compile-tables db table-map)
           (c-return " ON ")
           (compile-expression db on))))

(defmulti compile-joins (fn [db _] db))
(defmethod compile-joins :default [db joins]
  ;; Compiles the :joins vector ([type table-map on] entries) in order.
  ;; Every fragment from the default compile-join already begins with its own
  ;; space (" INNER JOIN ..."), like the " WHERE " / " ORDER BY " clauses, so
  ;; fragments are concatenated directly. Joining them with " " used to put
  ;; two spaces between consecutive joins.
  (->> joins
       (map (partial compile-join db))
       (c-join "")))

(defmulti compile-where (fn [db _] db))
(defmethod compile-where :default [db expr]
  ;; Returns nil (contributing nothing to the query) when there is no filter.
  (when expr
    (let [condition (compile-expression db expr)]
      (c-str (c-return " WHERE ") condition))))

(defmulti compile-sort-by (fn [db _] db))
(defmethod compile-sort-by :default [db fields]
  ;; `fields` is a seq of [[table_alias field] direction] pairs; renders
  ;; " ORDER BY <col> <DIR>,<col> <DIR>...". Returns nil when there is nothing
  ;; to sort on. `seq` is used rather than plain truthiness: an empty
  ;; vector/seq is truthy and previously produced a dangling " ORDER BY ".
  (when (seq fields)
    (->> (for [[[table field] dir] fields]
           (str (make-field-name db table field) \space (string/upper-case (name dir))))
         (string/join ",")
         (str " ORDER BY ")
         c-return)))

(defmulti compile-query (fn [db _] db))
(defmethod compile-query :default
  [db {table :table, fields :fields, joins :joins, where :where, order :sort-by}]
  ;; Assembles SELECT <fields> [FROM <tables>] <joins> <where> <order-by>.
  ;; Every clause is a compiled [sql & replacements] value (or nil, which
  ;; contributes nothing); c-str concatenates the SQL and keeps the
  ;; replacements in clause order.
  (let [select-keyword (c-return "SELECT ")
        from-keyword (when table (c-return " FROM "))]
    (c-str select-keyword
           (compile-fields db fields)
           from-keyword
           (compile-tables db table)
           (compile-joins db joins)
           (compile-where db where)
           (compile-sort-by db order))))



;; SQL Server: identifiers are delimited with square brackets.
(defmethod field-name :sql-server [_ ident]
  (str "[" (name ident) "]"))

(defmethod table-name :sql-server [_ ident]
  (str "[" (name ident) "]"))

;; MySQL: identifiers are delimited with backticks.
;; NOTE(review): neither dialect overrides function-name, so function calls are
;; still rendered with double quotes here — confirm that is intended.
(defmethod field-name :mysql [_ ident]
  (str "`" (name ident) "`"))

(defmethod table-name :mysql [_ ident]
  (str "`" (name ident) "`"))
















;; ==============================================================
;; The DSL for making query maps
;; ==============================================================

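;; Query sections. Those marked with -dashes- correspond to operations
;; implemented below (project, table/join, select, sort-by); grouping and
;; grouped filters are not implemented yet.
;;   -PROJECTION-
;;   -TABLES-
;;   -FILTERS-
;;   GROUPING
;;   GROUPED FILTERS
;;   -SORTING-

;; Shape of a query map:
;; {
;;  :table   => {tablename table_alias}
;;  :fields  => {[table_alias fieldname] field_alias}
;;  :joins   => [[type {tablename table_alias} on-expression] ...]
;;  :where   => expression
;;  :sort-by => [[[table_alias fieldname] direction] ...]
;; }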

(defn table
  "Starts a new Query over ARG: either a single table name (aliased to
  itself) or a map of table name to alias."
  [arg]
  (let [table-map (if (map? arg) arg {arg arg})]
    (assoc (->Query) :table table-map)))

(defn ambiguous-error
  "Always throws an ex-info reporting that FIELD cannot be attributed to a
  single table. The ex-data holds :field and :query (nil when not given)."
  [field & [query]]
  (let [message (str "Ambiguous field " field " in query with more than one table")]
    (throw (ex-info message {:field field, :query query}))))

(defn project
  "Replaces the projection of QUERY with FIELDS.
  FIELDS is either a collection of fields (each keeping its own name as its
  alias) or a map of field -> alias. A field is either an explicit
  [table_alias field] vector, used as-is, or a name. A name is looked up first
  among QUERY's existing aliases; otherwise it is qualified with the query's
  only table. Without a unique table (QUERY has joins), an unknown name
  raises `ambiguous-error`."
  [query fields]
  (let [single-table (if-not (:joins query)
                       (-> query :table first val))
        by-alias (u/flip-map (:fields query))
        source-of (fn [field]
                    (if (vector? field)
                      field
                      (or (get by-alias field)
                          (if single-table
                            [single-table field]
                            (ambiguous-error field query)))))
        requested (if (map? fields)
                    fields
                    (zipmap fields fields))]
    (assoc query
      :fields (into {} (map (fn [[field as-name]] [(source-of field) as-name])
                            requested)))))

(defn rename
  "Changes the aliases of fields in QUERY's projection.
  FIELD-RENAMES maps a field to its new alias. A key is either an explicit
  [table_alias field] vector, used as-is, or an existing alias in QUERY's
  :fields; any other key throws ex-info. Precondition: a new alias must not
  collide with an existing alias that is not itself renamed by this call,
  which still allows swaps such as {:id :name, :name :id}."
  [query field-renames]
  {:pre [(map? field-renames)
         ;; the intersection of the new aliases with the old aliases NOT renamed by this operation
         (empty? (set/intersection (set (vals field-renames))
                                   (set/difference (set (vals (:fields query)))
                                                   (set (keys field-renames)))))]}
  (let [alias-lookup (u/flip-map (:fields query))
        ;; The `cond` is now the fn's return value. Previously a stray trailing
        ;; `(get alias-lookup field)` was returned instead, so vector keys
        ;; resolved to nil.
        original-name (fn [field]
                        (cond (vector? field) field
                              (contains? alias-lookup field) (get alias-lookup field)
                              :else (throw (ex-info (str "Cannot rename field " (pr-str field) ". Field does not exist in query.")
                                                    {:field field
                                                     :query query
                                                     :renames field-renames}))))]
    (update-in query
               [:fields] #(into % (for [[key val] field-renames]
                                    [(original-name key) val])))))

(defn ^:private resolve-field
  "Maps FIELD to its [table_alias field] source. FIELD is first looked up as
  an alias in ALIASES (a {source alias} map); failing that it is qualified
  with TABLE, and without a TABLE `ambiguous-error` is raised."
  [table aliases field]
  (if-let [source ((u/flip-map aliases) field)]
    source
    (if table
      [table field]
      (ambiguous-error field))))

(defn ^:private resolve-fields
  "Replaces every bare keyword in EXPRESSION with its [table_alias field]
  source via resolve-field, recursing through nested collections.
  Vectors are left untouched: inside an expression a vector is already an
  explicit [table field] reference (see compile-expression). Walking into it,
  as the previous postwalk did, tried to resolve the table and field parts as
  aliases. That corrupted the reference, or threw ambiguous-error in joined
  queries where explicit qualification is needed most."
  [table aliases expression]
  (letfn [(resolve-expr [expr]
            (cond
              (keyword? expr) (resolve-field table aliases expr)
              (vector? expr) expr
              (coll? expr) (clojure.walk/walk resolve-expr identity expr)
              :else expr))]
    (resolve-expr expression)))

(defn ^:private combine-wheres
  "ANDs together the non-nil where-expressions WHERES, left to right.
  When the incoming expression (or, failing that, the accumulated one) is
  already an (and ...) form, its arguments are spliced in rather than nested.
  Returns nil when every expression is nil."
  [& wheres]
  (letfn [(and-form? [expr]
            (and (sequential? expr)
                 (= "and" (name (first expr)))))
          (combine [acc where]
            (cond
              (nil? acc) where
              (nil? where) acc
              (and-form? where) `(and ~acc ~@(rest where))
              (and-form? acc) `(and ~@(rest acc) ~where)
              :else `(and ~acc ~where)))]
    (reduce combine nil wheres)))

(defn join
  "Joins query RIGHT onto query LEFT and returns the combined query.
  - Field aliases present in both queries become implicit equality
    conditions (left = right) and appear once in the result: the right-hand
    version is kept for :type :right, the left-hand version otherwise.
  - :on adds an explicit condition; its keywords are resolved against the
    combined field aliases with no default table.
  - :type is :left, :right or :inner (used when nil).
  - Joins already on LEFT come first, then this join, then RIGHT's joins.
    Both where clauses are ANDed together.
  With no common field and no :on, the join condition is `true`."
  [left right & {:keys [on type]}]
  (let [;; :joins is the key the rest of the DSL uses (this used to read :join,
        ;; silently dropping every join LEFT already had)
        joins-vector (or (:joins left) [])
        common-fields (set/intersection (set (vals (:fields left)))
                                        (set (vals (:fields right))))
        ;; the entries of a {source alias} map whose alias is not shared
        without-common (fn [fields]
                         (into {} (remove (comp common-fields val) fields)))
        ;; compare the option itself (this used to be `(:type :right)`, a
        ;; keyword lookup on :right that is always nil)
        joined-fields (if (= type :right)
                        (merge (without-common (:fields left))
                               (:fields right))
                        (merge (:fields left)
                               (without-common (:fields right))))
        ;; NOTE(review): (:table left)/(:table right) are {table alias} maps,
        ;; not aliases; they only matter if the alias lookup misses, which
        ;; should not happen since common-fields are drawn from the aliases.
        implicit-on (if (seq common-fields)
                      (map (fn [field]
                             `(= ~(resolve-field (:table left) (:fields left) field)
                                 ~(resolve-field (:table right) (:fields right) field)))
                           common-fields))
        on (if on
             [(resolve-fields nil joined-fields on)])
        join-condition (if-let [condition (seq (concat implicit-on on))]
                         `(and ~@condition)
                         true)]
    (-> left
        (assoc :fields joined-fields)
        (assoc :joins (into (conj joins-vector
                                  [(or type :inner) (:table right) join-condition])
                            (:joins right)))
        (assoc :where (combine-wheres (:where left) (:where right))))))

(defn select
  "Adds EXPRESSION as a filter on QUERY, ANDed with any existing where clause.
  Keywords in EXPRESSION are resolved against QUERY's field aliases, falling
  back to the query's only table when it has no joins."
  [query expression]
  (let [default-table (when-not (:joins query)
                        (-> query :table first val))]
    (->> expression
         (resolve-fields default-table (:fields query))
         (combine-wheres (:where query))
         (assoc query :where))))

(defn sort-by
  "Sets the ORDER BY of QUERY, replacing any previous sort.
  FIELDS is a single field or a sequential collection of entries. An entry is
  either a field (sorted :asc) or a [field direction] vector, with direction
  :asc or :desc. A field is a name, resolved against QUERY's aliases like in
  `select`, or an explicit [table_alias field] vector. An explicit vector
  must be wrapped, e.g. [[:users :id] :desc], to distinguish it from
  [field direction]. An empty FIELDS collection clears the sort."
  [query fields]
  (let [table-name (if-not (:joins query)
                     (-> query :table first val))
        resolve-one (fn [field]
                      (if (vector? field)
                        field
                        (resolve-field table-name (:fields query) field)))
        fields-seq (if (sequential? fields)
                     fields
                     [fields])]
    (assoc query
      :sort-by (not-empty
                (vec (for [field fields-seq]
                       (if (vector? field)
                         ;; Destructure [field direction]. This branch used to
                         ;; pass the whole vector to resolve-field, yielding
                         ;; [table [field dir]], which compile-sort-by cannot
                         ;; destructure, and the direction was lost.
                         (let [[f dir] field]
                           [(resolve-one f) (or dir :asc)])
                         [(resolve-one field) :asc])))))))


;; The mutation functions below are unimplemented placeholders. They throw
;; instead of silently returning nil, so callers cannot mistake a no-op for a
;; successful write.

(defn insert!
  "Not implemented yet; always throws ex-info. QUERY must have no joins."
  [query & records]
  {:pre [(empty? (:joins query))]}
  (throw (ex-info "insert! is not implemented yet"
                  {:query query :records records})))

(defn update!
  "Not implemented yet; always throws ex-info."
  [query & partial-records]
  (throw (ex-info "update! is not implemented yet"
                  {:query query :partial-records partial-records})))

(defn delete!
  "Not implemented yet; always throws ex-info."
  [query]
  (throw (ex-info "delete! is not implemented yet"
                  {:query query})))


;; REPL scratch examples. A `comment` form evaluates to nil, so nothing in here
;; runs when the namespace is loaded.
;; NOTE(review): several examples pass the join condition positionally, but
;; `join` reads it from keyword options (`& {:keys [on type]}`), so such a
;; condition is not picked up as :on (depending on the Clojure version it may
;; instead raise an error). These examples may predate that signature — confirm.
(comment

  ;; Two filtered single-table queries joined on renamed id columns, then
  ;; filtered and re-projected. The string in `is-carlo` is compiled to a "?"
  ;; placeholder with the string as a replacement, not spliced into the SQL.
  ;; The printed result goes through print-method, which compiles with a nil
  ;; database type, so the :mysql binding does not affect what is printed.
  (binding [*database-type* :mysql]
    (let [users (-> (table :users)
                    (project [:id :username :password])
                    (select '(= :deleted false)))
          people (-> (table :people)
                     (project [:id :fname :sname])
                     (select '(= :deleted false)))
          uid-pid-match '(= :uid :pid)
          is-carlo `(= :fname "Carlo'; SELECT * FROM users --")]
      (-> (join (-> users
                    (rename {:id :uid}))
                (join (-> people
                          (rename {:id :pid}))
                      (-> (table {:others :o})
                          (project {:id :oid}))
                      '(= :pid :oid))
                uid-pid-match)
          (select is-carlo)
          (project [:fname :sname :oid]))))

  ;; Filter combining OR and NOT over a joined query.
  (-> (table :users)
      (join (table :something-else-with-a-username)
            true)
      (select '(or (= :username "john")
                   (not (= :username "carlo"))))
      (project [:username]))

  ;; Self-join of one table under two aliases (child/parent), using renames
  ;; to keep the two :name columns apart.
  (-> (table {:nodes :child})
      (project [:parent-id, :name])
      (rename {:name :child.name})
      (join (-> (table {:nodes :parent})
                (project [:id, :name])
                (rename {:name :parent.name}))
            '(= :parent-id :id))
      (project [:child.name :parent.name]))

  ;; Both sides project :id, so the join gets an implicit :id equality.
  (-> (table :users)
      (project [:id])
      (join (-> (table :people)
                (project [:id]))
            true))
  ;; Swapping two aliases with rename (allowed by rename's precondition).
  (-> (table :users)
      (project [:id :name])
      (rename {:id :name
               :name :id}))
  ;; The same swap expressed directly through a project map.
  (-> (table :users)
      (project {:id :name
                :name :id}))

  ;; No common aliases and no condition: the join condition becomes `true`.
  (-> (table :anotherStack)
      (project [:anotherNumber])
      (join (-> (table :collection)
                (project [:number]))))

  ;; A non-operator head in an expression compiles to a function call.
  (-> (table :users)
      (select '(= (left :username 1) "bloo"))))