summaryrefslogtreecommitdiff
path: root/src/clojure_sql/core.clj
blob: 319d281fe3e2d88936b9b298111c9f09bd3f0295 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
(ns clojure-sql.core
  (:refer-clojure :exclude [sort-by])
  (:require [clojure.set :as set]
            [clojure.string :as string]
            [clojure-sql.util :as u]
            [clojure.walk]))

(defn add-parentheses
  "Returns the string form of `s` enclosed in a pair of round brackets."
  [s]
  (str "(" s ")"))

(defmulti field-name
  "Quotes `field` for use as a column identifier. Dispatches on the database
  type passed as the first argument."
  (fn [db _field] db))

;; Fallback for database types without their own method: wrap the name in
;; double quotes.
(defmethod field-name :default [_db field]
  (let [quote-mark "\""]
    (str quote-mark (name field) quote-mark)))

(defmulti table-name
  "Quotes `table` for use as a table identifier. Dispatches on the database
  type passed as the first argument."
  (fn [db _table] db))

;; Fallback for database types without their own method: wrap the name in
;; double quotes.
(defmethod table-name :default [_db table]
  (let [quote-mark "\""]
    (str quote-mark (name table) quote-mark)))

(defmulti sql-string
  "Renders a string as a quoted SQL string literal. Dispatches on the database
  type passed as the first argument."
  (fn [db _s] db))

;; Fallback: single-quoted literal, with embedded single quotes doubled.
(defmethod sql-string :default [_db s]
  (let [escaped (string/replace s "'" "''")]
    (str "'" escaped "'")))



;; compile-* multimethods are of the signature:
;;  (db, expr) -> (SQL, replacements)

(defn is-unary?
  "True when `op` is the symbol of a unary operator (currently only `not`)."
  [op]
  (contains? '#{not} op))

(defn is-predicate?
  "True when `op` is one of the comparison/predicate symbols
  = < > <= >= is in."
  [op]
  (contains? '#{= < > <= >= is in} op))

;; A "compiled" value is a pair [sql replacements]: the SQL text produced so
;; far and the sequence of values to bind to its "?" placeholders.

(defn c-return
  "Wraps a plain value as a compiled pair with no replacement values."
  [val]
  (vector val nil))

(defn c-lift
  "Lifts `f` into a function over compiled pairs. The returned function applies
  `f` to `orig-args` followed by the first element of each pair, and
  concatenates the second elements (the replacements) in argument order."
  [f & orig-args]
  (fn [& c-args]
    (let [sql-parts (map first c-args)
          replacements (mapcat second c-args)]
      [(apply f (concat orig-args sql-parts))
       replacements])))

(def c-str
  "`str` over compiled pairs: concatenates the SQL parts and the replacements."
  (c-lift str))

(defn c-join
  "Joins the SQL parts of the compiled pairs in `args` with `sep`, and
  concatenates their replacements in order."
  [sep args]
  (let [sql-parts (map first args)
        replacements (mapcat second args)]
    [(string/join sep sql-parts) replacements]))
(def c-add-parentheses
  "`add-parentheses` over a compiled pair: brackets the SQL part and passes the
  replacements through."
  (c-lift add-parentheses))

(defmulti compile-expression
  "Compiles an expression tree into a compiled pair [sql replacements].
  Dispatches on the database type passed as the first argument."
  (fn [db _ex] db))

;; Default compilation rules, tried in this order:
;;   nil          -> NULL
;;   vector       -> qualified column, [table field]
;;   keyword      -> bare column
;;   string       -> "?" placeholder, with the string as a replacement value
;;   symbol       -> upper-cased operator / keyword name
;;   other seq    -> parenthesised expression; a 2-element form is emitted in
;;                   order (unary operator), anything else has its first
;;                   element interposed between the remaining ones
;;   anything else is passed through as-is.
;; The vector test must come before the sequential test, since vectors are
;; sequential too.
(defmethod compile-expression :default [db ex]
  (let [compile-sub (fn [sub] (compile-expression db sub))]
    (cond
      (nil? ex)
      (c-return "NULL")

      (vector? ex)
      (let [table (first ex)
            field (second ex)]
        (c-return (str (table-name db table) \. (field-name db field))))

      (keyword? ex)
      (c-return (field-name db ex))

      ;; Strings are bound as parameters rather than inlined via sql-string.
      (string? ex)
      ["?" [ex]]

      (symbol? ex)
      (c-return (string/upper-case (name ex)))

      (sequential? ex)
      (let [parts (if (= 2 (count ex))
                    (map compile-sub ex)
                    (interpose (compile-sub (first ex))
                               (map compile-sub (rest ex))))]
        (c-add-parentheses (c-join " " parts)))

      :else
      (c-return ex))))

(defn ^:private make-table-name
  "Renders a table reference for `db`. When `alias` is given and differs from
  `table`, the result is `<table> AS <alias>`; otherwise just the quoted table."
  [db table & [alias]]
  (let [quoted (table-name db table)]
    (if (and (some? alias) (not= table alias))
      (str quoted " AS " (field-name db alias))
      quoted)))

(defn ^:private make-field-name
  "Renders a table-qualified column reference for `db`. When `alias` is given
  and differs from `field`, ` AS <alias>` is appended."
  [db table field & [alias]]
  (let [qualified (str (table-name db table) \. (field-name db field))]
    (if (and (some? alias) (not= field alias))
      (str qualified " AS " (field-name db alias))
      qualified)))

(defmulti compile-fields
  "Compiles the projection map {[table field] alias} into the select list.
  Dispatches on the database type passed as the first argument."
  (fn [db _fields-map] db))

;; An empty or missing projection selects every column.
(defmethod compile-fields :default [db fields-map]
  (c-return
   (if (empty? fields-map)
     "*"
     (string/join ", "
                  (map (fn [[[table field] alias]]
                         (make-field-name db table field alias))
                       fields-map)))))

(defmulti compile-tables
  "Compiles the table map {table alias} into a comma-separated table list.
  Dispatches on the database type passed as the first argument."
  (fn [db _tables-map] db))

(defmethod compile-tables :default [db tables-map]
  (let [rendered (map (fn [[table alias]]
                        (make-table-name db table alias))
                      tables-map)]
    (c-return (string/join ", " rendered))))

(defmulti compile-join
  "Compiles one join triple [type table-map on] into a JOIN clause.
  Dispatches on the database type passed as the first argument."
  (fn [db _join] db))

;; :left and :right become outer joins; any other type gets no qualifier and
;; is emitted as a plain JOIN.
(defmethod compile-join :default [db [type table-map on]]
  (let [qualifier (get {:left " LEFT OUTER "
                        :right " RIGHT OUTER "}
                       type
                       "")]
    (c-str (c-return (str qualifier " JOIN "))
           (compile-tables db table-map)
           (c-return " ON ")
           (compile-expression db on))))

(defmulti compile-joins
  "Compiles a sequence of join triples into consecutive JOIN clauses.
  Dispatches on the database type passed as the first argument."
  (fn [db _joins] db))

(defmethod compile-joins :default [db joins]
  (c-join " " (map #(compile-join db %) joins)))

(defmulti compile-where
  "Compiles the filter expression into a WHERE clause.
  Dispatches on the database type passed as the first argument."
  (fn [db _expr] db))

;; Returns nil (no clause at all) when there is no filter expression.
(defmethod compile-where :default [db expr]
  (when expr
    (let [condition (compile-expression db expr)]
      (c-str (c-return " WHERE ") condition))))

(defmulti compile-sort-by
  "Compiles the sort specification, a sequence of [[table field] direction]
  entries, into an ORDER BY clause. Dispatches on the database type passed as
  the first argument."
  (fn [db _fields] db))

;; Returns nil (no clause at all) when there is nothing to sort by. The check
;; uses `seq` rather than plain truthiness so that an empty sequence -- which
;; is what sorting by an empty field list produces -- does not emit a dangling
;; " ORDER BY " with no columns.
(defmethod compile-sort-by :default [db fields]
  (when (seq fields)
    (let [terms (for [[[table field] dir] fields]
                  (str (make-field-name db table field)
                       \space
                       (string/upper-case (name dir))))]
      (c-return (str " ORDER BY " (string/join "," terms))))))

(defmulti compile-query
  "Compiles a whole query map into a compiled pair [sql replacements].
  Dispatches on the database type passed as the first argument."
  (fn [db _query] db))

;; Clauses are emitted in the order SELECT, FROM, joins, WHERE, ORDER BY.
;; Clause compilers that return nil contribute nothing to the result.
(defmethod compile-query :default [db query]
  (let [{:keys [table fields joins where sort-by]} query
        clauses [(c-return "SELECT ")
                 (compile-fields db fields)
                 (c-return " FROM ")
                 (compile-tables db table)
                 (compile-joins db joins)
                 (compile-where db where)
                 (compile-sort-by db sort-by)]]
    (apply c-str clauses)))




;; SQL SERVER
;; SQL Server quotes identifiers with square brackets.
(defmethod field-name :sql-server [_db field]
  (str "[" (name field) "]"))

(defmethod table-name :sql-server [_db table]
  (str "[" (name table) "]"))

;; mySQL
;; MySQL quotes identifiers with backticks.
(defmethod field-name :mysql [_db field]
  (str "`" (name field) "`"))

(defmethod table-name :mysql [_db table]
  (str "`" (name table) "`"))

;; MySQL string literal: double-quoted, with backslash escapes.
;; Backslashes must be escaped BEFORE double quotes. Otherwise an input that
;; already contains a backslash (e.g. one ending in `\`, or containing `\"`)
;; would neutralise the escape added for the quote and terminate the literal
;; early, allowing SQL injection.
;; NOTE(review): this assumes the server runs with backslash escapes enabled
;; and without ANSI_QUOTES -- confirm against the target configuration.
(defmethod sql-string :mysql [_ s]
  (let [escaped (-> s
                    (string/replace "\\" "\\\\")
                    (string/replace "\"" "\\\""))]
    (str \" escaped \")))





;; important sections:
;;   -PROJECTION-
;;   -TABLES-
;;   -FILTERS-
;;   GROUPING
;;   GROUPED FILTERS
;;   -SORTING-


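;; Shape of a query map:
;; table: tablename -> table_alias
;; fields: (table_alias, fieldname) -> field_alias
;; joins: [(type, tablename -> table_alias, on)]
;; where: expression
;; sort-by: [((table_alias, fieldname), direction)]
;; group-by: [field]      (planned, not implemented yet)
;; having: expression     (planned, not implemented yet)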

(def ^:dynamic *database-type*
  "Database type used to pick the compile-* methods when a query is
  dereferenced. nil selects the :default methods; rebind with `binding`."
  nil)

;; A query is a record with no declared fields; its contents are ordinary map
;; entries. Dereferencing a query compiles it for the database type currently
;; bound to *database-type*.
(defrecord Table []
  clojure.lang.IDeref
  (deref [query]
    (compile-query *database-type* query)))

(defn table
  "Creates a query over a single table. `arg` is either a map
  {table-name alias} or a bare table name, which is used as its own alias."
  [arg]
  (let [tables (if (map? arg)
                 arg
                 {arg arg})]
    (assoc (->Table) :table tables)))

(defn project
  "Replaces the projection of `query` with `fields`.

  `fields` is either a map {field alias} or a collection of fields, each of
  which is used as its own alias. Every field is resolved to a [table field]
  pair:
    - a vector is taken as already resolved;
    - a name that is an alias in the current projection resolves to that
      alias's source;
    - otherwise, if the query has no :joins entry, it is qualified with the
      alias of the query's table;
    - otherwise a RuntimeException is thrown, since the owning table cannot
      be determined."
  [query fields]
  (let [sole-table (when-not (:joins query)
                     (-> query :table first val))
        ;; alias -> [table field], the inverse of the current projection
        source-by-alias (into {}
                              (for [[source alias] (:fields query)]
                                [alias source]))
        resolve-source (fn [field]
                         (cond
                           (vector? field) field
                           (get source-by-alias field) (get source-by-alias field)
                           sole-table [sole-table field]
                           :else (throw (RuntimeException.
                                         (str "Ambiguous field " field " cannot be resolved")))))
        requested (if (map? fields)
                    fields
                    (zipmap fields fields))
        projection (into {}
                         (map (fn [[field alias]]
                                [(resolve-source field) alias])
                              requested))]
    (assoc query :fields projection)))

(defn rename
  "Re-aliases fields of the current projection of `query`.

  `field-renames` must be a map {field new-alias}. Each key is either a
  [table field] vector, used as-is, or an alias present in the current
  projection, which is resolved to its [table field] source. The resolved
  entries are merged into :fields, so a renamed field replaces its old alias.

  Throws ex-info (with :field, :query and :renames in its data) when a
  non-vector key is not an alias in the current projection."
  [query field-renames]
  {:pre [(map? field-renames)]}
  (let [;; alias -> [table field], the inverse of the current projection
        alias-lookup (->> query :fields
                          (map (comp vec reverse))
                          (into {}))
        ;; The cond must be the last form of this fn: a trailing alias lookup
        ;; here would discard its result and turn vector keys into nil.
        original-name (fn [field]
                        (cond (vector? field) field
                              (contains? alias-lookup field) (get alias-lookup field)
                              :else (throw (ex-info (str "Invalid field in rename: " (pr-str field))
                                                    {:field field
                                                     :query query
                                                     :renames field-renames}))))]
    (update-in query
               [:fields] #(->> (for [[key val] field-renames]
                                 [(original-name key) val])
                               (into %)))))

(defn ^:private resolve-field
  "Resolves `field` to its source. `aliases` is a projection map; it is
  flipped with u/flip-map and looked up by `field`. When that lookup finds
  nothing, the field is qualified with `table` if one is given; with no table
  a RuntimeException is thrown because the owner cannot be determined."
  [table aliases field]
  (let [source-by-alias (u/flip-map aliases)
        known-source (source-by-alias field)]
    (cond
      known-source known-source
      table [table field]
      :else (throw (RuntimeException.
                    (str "Ambiguous field " field " cannot be resolved"))))))

(defn ^:private resolve-fields
  "Walks `expression` bottom-up and replaces every keyword in it with the
  result of resolving that keyword via resolve-field. Everything else is left
  as it is."
  [table aliases expression]
  (let [resolve-node (fn [node]
                       (if (keyword? node)
                         (resolve-field table aliases node)
                         node))]
    (clojure.walk/postwalk resolve-node expression)))

(defn ^:private combine-wheres
  "Combines any number of filter expressions into one conjunction. nil
  expressions are skipped; with nothing to combine the result is nil.
  When one side is already an `and` form its operands are spliced in rather
  than nested: the incoming expression is checked first, then the accumulated
  one."
  [& wheres]
  (let [and-form? (fn [expr]
                    (and (sequential? expr)
                         (= (name (first expr)) "and")))
        combine (fn [acc where]
                  (cond
                    (nil? acc) where
                    (nil? where) acc
                    (and-form? where) `(and ~acc ~@(next where))
                    (and-form? acc) `(and ~@(next acc) ~where)
                    :else `(and ~acc ~where)))]
    (reduce combine nil wheres)))

(defn join
  "Joins query `right` onto query `left`.

  `on` is the join condition; keywords in it are resolved against the merged
  projections of both queries. `type` is optional and defaults to :inner.

  The result is `left` with:
    - :fields  the projections of both sides merged (right wins on clashes);
    - :joins   the joins already on `left`, then the new join triple
               [type table on], then any joins carried by `right`;
    - :where   the filters of both sides combined."
  [left right on & [type]]
  (let [;; Existing joins live under :joins -- the same key this function
        ;; writes -- so that joining an already-joined query keeps them.
        joins-vector (vec (:joins left))
        joined-fields (merge (:fields left) (:fields right))]
    (-> left
        (assoc :fields joined-fields)
        (assoc :joins (into (conj joins-vector
                                  [(or type :inner) (:table right) (resolve-fields nil joined-fields on)])
                            (:joins right)))
        (assoc :where (combine-wheres (:where left) (:where right))))))

(defn select
  "Adds the filter `expression` to `query`. Keywords in the expression are
  resolved against the query's projection; when the query has no :joins entry
  unresolved names are qualified with the alias of its table. The new filter
  is combined with any filter already present."
  [query expression]
  (let [sole-table (when-not (:joins query)
                     (-> query :table first val))
        resolved (resolve-fields sole-table (:fields query) expression)]
    (assoc query
      :where (combine-wheres (:where query) resolved))))

(defn sort-by
  "Sets the sort order of `query`.

  `fields` is a single field or a sequential collection of them. Each entry is
  either a field name, sorted ascending, or a vector [field direction].
  Field names are resolved against the query's projection; when the query has
  no :joins entry unresolved names are qualified with the alias of its table.

  The result carries :sort-by as a sequence of [[table field] direction]
  entries."
  [query fields]
  ;; Disabled precondition, kept for reference: every sort field should be an
  ;; alias of the current projection.
  #_{:pre [(let [flipped-query-fields (u/flip-map (:fields query))
               field-names-seq (map (fn [x] (if (vector? x) (first x) x))
                                    (if (sequential? fields) fields [fields]))]
           (every? flipped-query-fields field-names-seq))]}
  (let [table-name (if-not (:joins query)
                     (-> query :table first val))
        fields-seq (if (sequential? fields)
                     fields
                     [fields])]
    (assoc query
      :sort-by (for [field fields-seq]
                 (if (vector? field)
                   ;; [field direction]: resolve only the field part and keep
                   ;; the direction, so the entry has the same
                   ;; [[table field] direction] shape as the bare-field case.
                   (let [[column direction] field]
                     [(resolve-field table-name (:fields query) column)
                      (or direction :asc)])
                   [(resolve-field table-name (:fields query) field) :asc])))))



;; Scratch examples for the REPL. Nothing inside `comment` is evaluated when
;; the file is loaded.
(comment

  ;; Three-way join compiled with the :mysql methods. Both :id columns are
  ;; renamed before joining so that the join conditions can refer to them by
  ;; distinct aliases.
  (binding [*database-type* :mysql]
    (let [users (-> (table :users)
                    (project [:id :username :password])
                    (select '(= :deleted false)))
          people (-> (table :people)
                     (project [:id :fname :sname])
                     (select '(= :deleted false)))
          uid-pid-match '(= :uid :pid)
          is-carlo `(= :fname "Carlo")
          query (-> (join (-> users
                              (rename {:id :uid}))
                          (join (-> people
                                    (rename {:id :pid}))
                                (-> (table {:others :o})
                                    (project {:id :oid}))
                                '(= :pid :oid))
                          uid-pid-match)
                    (select is-carlo)
                    (project [:fname :sname :oid]))]
      @query))

  ;; Join with a constant `true` condition, followed by a filter that nests
  ;; `or` and `not`. No binding here, so *database-type* keeps its root value.
  (-> (table :users)
      (project [:username])
      (join (table :something-else-with-a-username)
            true)
      (select '(or (= :username "john")
                   (not (= :username "carlo"))))
      deref)


  ;; Self-join: the same table under two aliases, with the :name column
  ;; renamed on each side to tell the two apart.
  (-> (table {:nodes :child})
      (project [:parent-id, :name])
      (rename {:name :child.name})
      (join (-> (table {:nodes :parent})
                (project [:id, :name])
                (rename {:name :parent.name}))
            '(= :parent-id :id))
      (project [:child.name :parent.name])
      deref #_println)


  ;; Join of two projected tables with a constant `true` condition.
  (deref (-> (table :anotherStack)
             (project [:anotherNumber])
             (join (-> (table :collection)
                       (project [:number]))
                   true))))