summaryrefslogtreecommitdiff
path: root/src/clojure_sql/compiler.clj
blob: ca8d2d9628b7d57acbac745104beeb22fb4c0e46 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
(ns clojure-sql.compiler
  (:require [clojure.string :as string]))



(defn add-parentheses
  "Returns `s` (converted with `str`) enclosed in a pair of parentheses."
  [s]
  (str "(" s ")"))


;; ==============================================================
;; DB specific escaping methods
;; ==============================================================

;; Quotes a field name for the given db (dispatches on db).
;; Default: standard SQL delimited identifier. The name is wrapped in double
;; quotes, and any double quote inside it is doubled so it cannot terminate
;; the identifier early.
(defmulti field-name (fn [db _] db))
(defmethod field-name :default [_ field]
  (str \" (string/replace (name field) "\"" "\"\"") \"))

;; Quotes a table name for the given db (dispatches on db).
;; Default: standard SQL delimited identifier. The name is wrapped in double
;; quotes, and any double quote inside it is doubled so it cannot terminate
;; the identifier early.
(defmulti table-name (fn [db _] db))
(defmethod table-name :default [_ table]
  (str \" (string/replace (name table) "\"" "\"\"") \"))

;; Quotes a function name for the given db (dispatches on db).
;; Default: standard SQL delimited identifier. The name is wrapped in double
;; quotes, and any double quote inside it is doubled so it cannot terminate
;; the identifier early.
(defmulti function-name (fn [db _] db))
(defmethod function-name :default [_ function]
  (str \" (string/replace (name function) "\"" "\"\"") \"))





;; ==============================================================
;; Utility functions for the compile-* functions
;; ==============================================================

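;; compile-* multimethods dispatch on `db` and have the signature:
;;  (db, expr) -> [SQL & replacements]
;; where `replacements` are the values for the `?` placeholders in SQL, in
;; the order the placeholders appear.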

;; Operator classification for the head of a list expression.
;; Each predicate compares the name of a symbol, keyword or string
;; (case-sensitively) against a fixed set of SQL operators. Any other value
;; (nil, numbers, collections) is simply reported as "not an operator" rather
;; than making `name` throw.
(defn- named-in? [names x]
  (boolean (and (or (string? x) (instance? clojure.lang.Named x))
                (names (name x)))))

(def is-unary? (partial named-in? #{"not"}))
(def is-binary? (partial named-in? #{"=" "<" ">" "<=" ">=" "is" "in" "+" "-" "/" "*"}))
(def is-operator? (partial named-in? #{"and" "or"}))

;; A compiled result is a vector [sql & replacements]. The helpers below build
;; such vectors and combine several of them into one.

(defn c-return
  "Wraps a plain SQL fragment as a compiled result with no replacements."
  [val]
  (vector val))

(defn c-lift
  "Lifts a string function `f` to work on compiled results.
  The returned fn calls `f` with `orig-args` followed by the SQL part of each
  compiled argument, and appends all of the arguments' replacements in order."
  [f & orig-args]
  (fn [& c-args]
    (let [sql (apply f (concat orig-args (map first c-args)))]
      (into [sql] (mapcat rest c-args)))))

(def c-str (c-lift str))

(defn c-join
  "Joins the SQL parts of the compiled results in `args` with `sep`,
  concatenating their replacements in order."
  [sep args]
  (into [(string/join sep (map first args))]
        (mapcat rest args)))
(defn c-add-parentheses
  "Compiled-result form of add-parentheses: wraps the SQL part in parentheses
  and keeps the replacements unchanged."
  [[sql & replacements]]
  (into [(add-parentheses sql)] replacements))



;; ==============================================================
;; compile-* functions (turning a map into a query string)
;; ==============================================================

(defmulti compile-expression (fn [db _] db))
(defmethod compile-expression :default [db ex]
  ;; Compiles a single expression into [sql & replacements]:
  ;;   nil            -> NULL
  ;;   [table field]  -> table-name "." field-name
  ;;   keyword        -> field-name
  ;;   string         -> a "?" placeholder whose replacement is the string
  ;;   symbol         -> its name, upper-cased
  ;;   (head & args)  -> unary/binary/n-ary operator or function call,
  ;;                     always wrapped in parentheses
  ;;   anything else  -> returned unchanged as the SQL part
  ;; Throws ex-info when a unary operator does not get exactly one argument
  ;; or a binary operator does not get exactly two.
  (letfn [(compile-each [exprs]
            (map #(compile-expression db %) exprs))
          (compile-infix [op operands]
            ;; operands separated by the compiled operator, e.g. a AND b AND c
            (c-join " " (interpose (compile-expression db op)
                                   (compile-each operands))))]
    (cond
      (nil? ex) (c-return "NULL")

      (vector? ex)
      (let [[table field] ex]
        (c-return (str (table-name db table) \. (field-name db field))))

      (keyword? ex) (c-return (field-name db ex))

      (string? ex) ["?" ex]

      (symbol? ex) (c-return (string/upper-case (name ex)))

      (sequential? ex)
      (let [op (first ex)
            args (rest ex)]
        (c-add-parentheses
         (cond
           (is-unary? op)
           (if (= 2 (count ex))
             ;; the operator itself is compiled as part of the list
             (c-join " " (compile-each ex))
             (throw (ex-info "Unary operators can only take one argument."
                             {:operator op
                              :arguments args})))

           (is-binary? op)
           (if (= 3 (count ex))
             (compile-infix op args)
             (throw (ex-info "Binary operators must take two arguments."
                             {:operator op
                              :arguments args})))

           (is-operator? op)
           (compile-infix op args)

           :else
           (c-str (c-return (function-name db op))
                  (c-add-parentheses (c-join ", " (compile-each args)))))))

      :else (c-return ex))))

(defn ^:private make-table-name
  "Renders a quoted table name, followed by `AS alias` when an alias is
  supplied and differs from the table itself."
  [db table & [alias]]
  (let [quoted (table-name db table)]
    (if (and (some? alias) (not= table alias))
      (str quoted " AS " (field-name db alias))
      quoted)))

(defn ^:private make-field-name
  "Renders a table-qualified field reference, followed by `AS alias` when an
  alias is supplied and differs from the field name."
  [db table field & [alias]]
  (let [qualified (str (table-name db table) \. (field-name db field))]
    (if (or (nil? alias) (= field alias))
      qualified
      (str qualified " AS " (field-name db alias)))))

(defmulti compile-fields (fn [db _] db))
(defmethod compile-fields :default [db fields-map]
  ;; fields-map entries are [[table field] alias]. When there are none the
  ;; select list is "*"; otherwise the rendered fields are joined by ", ".
  ;; No replacements are produced.
  (c-return
   (if (empty? fields-map)
     "*"
     (string/join ", " (map (fn [[[table field] alias]]
                              (make-field-name db table field alias))
                            fields-map)))))

(defmulti compile-tables (fn [db _] db))
(defmethod compile-tables :default [db tables-map]
  ;; tables-map entries are [table alias]; each is rendered (with an optional
  ;; AS clause) and the results are joined by ", ". No replacements.
  (let [rendered (map (fn [[table alias]] (make-table-name db table alias))
                      tables-map)]
    (c-return (string/join ", " rendered))))

(defmulti compile-join (fn [db _] db))
(defmethod compile-join :default [db [join-type tables on-expr]]
  ;; A join is [type tables-map on-expression]. :left and :right become outer
  ;; joins; any other type falls back to an INNER join. The result starts with
  ;; a space so joins can be concatenated directly after the FROM clause.
  (let [kind (case join-type
               :left " LEFT OUTER"
               :right " RIGHT OUTER"
               " INNER")]
    (c-str (c-return (str kind " JOIN "))
           (compile-tables db tables)
           (c-return " ON ")
           (compile-expression db on-expr))))

(defmulti compile-joins (fn [db _] db))
(defmethod compile-joins :default [db joins]
  ;; Compiles every join and concatenates them with no separator (each
  ;; compiled join carries its own leading space).
  (c-join "" (map #(compile-join db %) joins)))

(defmulti compile-where (fn [db _] db))
(defmethod compile-where :default [db expr]
  ;; Returns nil when there is no condition; otherwise " WHERE <expr>" with
  ;; the expression's replacements.
  (when expr
    (let [condition (compile-expression db expr)]
      (c-str (c-return " WHERE ") condition))))

(defmulti compile-sort-by (fn [db _] db))
(defmethod compile-sort-by :default [db fields]
  ;; `fields` is a sequence of [[table field] dir]; each entry is rendered as
  ;; the qualified field followed by (name dir) upper-cased, and entries are
  ;; joined with ",". Returns nil when `fields` is nil or empty, so no dangling
  ;; " ORDER BY " is emitted. This matches the `seq` check in compile-fields.
  (when (seq fields)
    (->> (for [[[table field] dir] fields]
           (str (make-field-name db table field) \space (string/upper-case (name dir))))
         (string/join ",")
         (str " ORDER BY ")
         c-return)))

(defmulti compile-query (fn [db _] db))
(defmethod compile-query :default [db {:keys [table fields joins where sort-by]}]
  ;; Assembles SELECT <fields> [FROM <tables>] <joins> [WHERE ...] [ORDER BY ...]
  ;; from the query map, concatenating each part's replacements in order.
  ;; FROM is only emitted when `table` actually contains entries, so an empty
  ;; table map no longer leaves a dangling " FROM ". This matches the `seq`
  ;; check in compile-fields.
  (c-str (c-return "SELECT ")
         (compile-fields db fields)
         (when (seq table)
           (c-return " FROM "))
         (compile-tables db table)
         (compile-joins db joins)
         (compile-where db where)
         (compile-sort-by db sort-by)))



;; ==============================================================
;; A few DB specific overrides
;; ==============================================================


;; SQL SERVER
;; Identifiers are wrapped in [brackets]. A literal ] inside the name is
;; written as ]] so it cannot close the delimited identifier early.
(defmethod field-name :sql-server [_ field]
  (str \[ (string/replace (name field) "]" "]]") \]))

(defmethod table-name :sql-server [_ table]
  (str \[ (string/replace (name table) "]" "]]") \]))

;; MySQL
;; Identifiers are wrapped in backticks. A literal backtick inside the name is
;; doubled so it cannot close the quoted identifier early.
(defmethod field-name :mysql [_ field]
  (str \` (string/replace (name field) "`" "``") \`))

(defmethod table-name :mysql [_ table]
  (str \` (string/replace (name table) "`" "``") \`))