summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/jester/comparisons.clj36
-rw-r--r--src/jester/expansion.clj460
-rw-r--r--src/jester/match.clj142
-rw-r--r--src/jester/types.clj633
4 files changed, 1271 insertions, 0 deletions
diff --git a/src/jester/comparisons.clj b/src/jester/comparisons.clj
new file mode 100644
index 0000000..c47469e
--- /dev/null
+++ b/src/jester/comparisons.clj
@@ -0,0 +1,36 @@
+(ns jester.comparisons
+ "Comparison functions to specify how Jester comparisons work.")
+
+(defn string=
+ "Compare two (potentially-nil) strings. If either is nil, return nil.
+ Otherwise, return true iff a = b, ignoring case, and false
+ otherwise."
+ [^String a, ^String b]
+ (boolean (and a b (.equalsIgnoreCase a b))))
+
+(defn string<
+ "Compare two (potentially-nil) strings. If either is nil, return nil.
+ Otherwise, return true iff a < b, ignoring case, and false
+ otherwise."
+ [^String a, ^String b]
+ (and a b (< (.compareToIgnoreCase a b) 0)))
+
+(defn number=
+ "Compare two (potentially-nil) numbers. If either is nil, return nil.
+ If either is NaN, return false. Otherwise, return true iff a = b,
+ and false otherwise."
+ [a b]
+ (boolean (and a b (== a b))))
+
+(defn number<
+ "Compare two (potentially-nil) numbers. If either is nil, return nil.
+ If either is NaN, return false. Otherwise, return true iff a < b,
+ and false otherwise."
+ [a b]
+ (boolean (and a b (< a b))))
+
+(defn boolean=
+ "Compare two (potentially-nil) booleans. If either is nil, return nil.
+ Otherwise, return true iff a = b, and false otherwise."
+ [a b]
+ (boolean (and (some? a) (some? b) (= a b))))
diff --git a/src/jester/expansion.clj b/src/jester/expansion.clj
new file mode 100644
index 0000000..b6c915b
--- /dev/null
+++ b/src/jester/expansion.clj
@@ -0,0 +1,460 @@
+(ns jester.expansion
+ (:require [clojure.string :as str]
+ [clojure.walk :as walk]
+ [jester.comparisons :refer [string= string<
+ number= number<
+ boolean=]]
+ [jester.match :refer [type-match]]
+ [jester.types :refer [in-constraint-environment
+ constrain
+ ground-type
+ type-variable]]))
+
+(defn ^:private type-of
+ "Return the type of an expanded form. This relies on the metadata,
+ except for primitive types where metadata can't be stored."
+ [expanded-form]
+ (cond (number? expanded-form) `(enum ~expanded-form)
+ (string? expanded-form) `(enum ~expanded-form)
+ (boolean? expanded-form) `(enum ~expanded-form)
+ (nil? expanded-form) `(optional ~(type-variable))
+ :else (::type (meta expanded-form))))
+
+(defn ^:private with-type [object type]
+ (with-meta object {::type type}))
+
+(def ^:private ^:dynamic *parameters*)
+
+(declare expand-form*)
+
+(defn regex-start-index
+ "Return the index of the first character in `string` which matches the
+ start of an instance of `regex`."
+ [regex string]
+ (let [matcher (re-matcher #"\??\." string)]
+ (when (.find matcher)
+ (.start matcher))))
+
+(defn expand-symbol [form]
+ (let [string (name form)]
+ (if-let [sep-index (regex-start-index #"\??\." string)]
+ (let [root (.substring string 0 sep-index)
+ rest (.substring string sep-index)
+ root-sym (symbol root)]
+ (expand-form* `(~(symbol rest)
+ ~root-sym)))
+ (with-type form
+ (or (get *parameters* form)
+ (let [type (type-variable)]
+ (set! *parameters* (assoc *parameters* form type))
+ type))))))
+
+(defn getter-expander [form]
+ (let [[head arg] form
+ string (name head)]
+ (when-not (or (str/starts-with? string ".")
+ (str/starts-with? string "?."))
+ (throw (ex-info "Unknown value in head of form"
+ {:form form})))
+ (let [;; This complicated regex splits the string into parts
+ ;; starting with either "?" or "?."
+ parts (str/split string #"(?=\?\.|(?<!\?)\.)")
+ part ^String (last parts)
+ rest (butlast parts)]
+ (cond
+ (str/starts-with? part ".")
+ (let [type (type-variable)
+ part-sym (symbol (.substring part 1))
+ expanded (expand-form*
+ (if (empty? rest)
+ arg
+ `(~(symbol (str/join rest)) ~arg)))]
+ (constrain (type-of expanded) {part-sym type})
+ (with-type `(get ~expanded '~part-sym)
+ type))
+ (str/starts-with? part "?.")
+ (let [type (type-variable)
+ part-sym (symbol (.substring part 2))
+ expanded (expand-form*
+ (if (empty? rest)
+ arg
+ `(~(symbol (str/join rest)) ~arg)))]
+ (constrain (type-of expanded)
+ `(optional {~part-sym (optional ~type)}))
+ (with-type `(get ~expanded '~part-sym)
+ `(optional ~type)))
+ :else
+ (throw (ex-info "Invalid head for getter"
+ {:head head
+ :part part
+ :rest rest}))))))
+
+(defn expand-map [form]
+ (let [expressions (reduce-kv (fn [acc k value-form]
+ (assoc acc (list 'quote k) (expand-form* value-form)))
+ {} form)]
+ (with-type expressions
+ (reduce-kv (fn [acc k expanded]
+ (assoc acc (second k) (type-of expanded)))
+ {} expressions))))
+
+(defn expand-vector [form]
+ (let [terms (mapv expand-form* form)]
+ (with-type terms
+ (mapv #(type-of %) terms))))
+
+(def expanders {})
+(def operators {})
+
+(defn expand-operator-form [operator form]
+ (let [result (type-variable)
+ args (map expand-form* (rest form))]
+ (constrain ((:type operator)) `(-> ~(mapv type-of args) ~result))
+ (with-type `(~(:name operator) ~@args)
+ result)))
+
+(defn expand-form* [form]
+ (cond
+ (number? form) form
+ (string? form) form
+ (boolean? form) form
+ (nil? form) form
+ (symbol? form) (expand-symbol form)
+ (map? form) (expand-map form)
+ (vector? form) (expand-vector form)
+ (seq? form) (if (symbol? (first form))
+ (if-let [operator (get operators (first form))]
+ (if (get expanders (first form))
+ (throw (ex-info (format "Found both an operator and an expander for %s"
+ (pr-str (first form)))
+ {:head (first form)
+ :form form}))
+ (expand-operator-form operator form))
+ (if-let [expander (get expanders (first form))]
+ (expander form)
+ (getter-expander form)))
+ (throw (ex-info "Invalid form"
+ {:form form})))
+ :else (throw (ex-info "Invalid form"
+ {:form form}))))
+
+(defn ^:private force-all
+ "Walk over a structure, iterating over child objects and forcing
+ sequences. The result is identical the input object.
+
+ (let [x ...]
+ (identical? x (force-all x)))
+ ;;=> true "
+ [obj]
+ (cond
+ ;; Reduce over sequences, vectors, and sets
+ (or (seq? obj)
+ (vector? obj)
+ (set? obj))
+ (reduce (fn [result item]
+ (force-all item)
+ result)
+ obj obj)
+ ;; Reduce over the keys and values in maps
+ (map? obj)
+ (reduce-kv (fn [result key value]
+ (force-all key)
+ (force-all value)
+ result)
+ obj obj)
+ ;; Return everything else unchanged.
+ :else
+ obj))
+
+(defn expand-form [form]
+ (binding [*parameters* {}]
+ (in-constraint-environment
+ (let [expanded (force-all (expand-form* form))
+ grounded (ground-type *parameters* :argument)]
+ ;; Once we've solved the input types, add constraints to
+ ;; ensure that information is captured when we solve the
+ ;; output type. We only need to add them as lower bounds,
+ ;; because :return grounding solves for the greatest lower
+ ;; bound.
+ (doseq [[param type] grounded]
+ (constrain type (get *parameters* param)))
+ (if (instance? clojure.lang.IMeta expanded)
+ (with-meta expanded
+ {:input grounded
+ :output (ground-type (type-of expanded) :return)})
+ expanded)))))
+
+(defmacro define-expander
+ {:style/indent 2}
+ [operator-name [& args] & body]
+ `(alter-var-root #'expanders
+ assoc '~operator-name (fn [[_# ~@args]]
+ ~@body)))
+
+(defn ^:private resolve-symbol [sym]
+ (let [resolved (meta (ns-resolve *ns* sym))
+ ns (:ns resolved)]
+ (when resolved
+ (symbol (-> ns ns-name name)
+ (-> resolved :name name)))))
+
+(defmacro define-simple-operator
+ {:style/indent :defn}
+ ([name type]
+ (if-let [resolved (resolve-symbol name)]
+ `(define-simple-operator ~name ~resolved ~type)
+ (throw (ex-info (format "Cannot resolve symbol for operator: %s" (pr-str name))
+ {:namespace *ns*
+ :symbol name}))))
+ ([name operator-symbol type]
+ (let [type (type-match type
+ (∀ args value)
+ `(let [~@(mapcat #(list % `(list 'var (gensym '~%))) args)]
+ ~(quote-type value (set args)))
+ _
+ (list 'quote type))]
+ (if (namespace operator-symbol)
+ `(alter-var-root #'operators assoc '~name
+ {:type (fn [] ~type)
+ :name '~operator-symbol})
+ (throw (ex-info (format "Operator symbol must have a namespace" (pr-str name))
+ {:symbol operator-symbol}))))))
+
+(defn ^:private e-use [form type]
+ (let [result (expand-form* form)]
+ (constrain (type-of result) type)
+ result))
+
+(defmacro ^:private with-temporary-parameter-type [[parameter type] & body]
+ `(try
+ (let [old# (get *parameters* ~parameter)]
+ (try
+ (set! *parameters* (assoc *parameters* ~parameter ~type))
+ (force-all (do ~@body))
+ (finally
+ (if old#
+ (set! *parameters* (assoc *parameters* ~parameter old#))
+ (set! *parameters* (dissoc *parameters* ~parameter))))))))
+
+(defn ? [optional default]
+ (if (nil? optional) default optional))
+
+(define-simple-operator ?
+ (∀ [x] (-> [(optional x) x] x)))
+
+(define-simple-operator string= (-> [string string] boolean))
+(define-simple-operator string=? (-> [(optional string) (optional string)] (optional boolean)))
+
+(define-expander find [[var & {list :in condition :when}] value]
+ (let [type (type-variable)
+ result (with-temporary-parameter-type [var type]
+ (expand-form* value))]
+ (with-type `(reduce (fn [_# ~var]
+ (when ~(with-temporary-parameter-type [var type]
+ (e-use (or condition true) '(optional boolean)))
+ (reduced ~result)))
+ nil ~(e-use list `(list ~type)))
+ `(optional ~(type-of result)))))
+
+(define-expander for [[var & {list :in condition :when}] value]
+ (let [type (type-variable)
+ result (with-temporary-parameter-type [var type]
+ (expand-form* value))]
+ (with-type `(reduce (fn [acc# ~var]
+ (if ~(with-temporary-parameter-type [var type]
+ (e-use (or condition true) '(optional boolean)))
+ (conj acc# ~result)
+ acc#))
+ [] ~(e-use list `(list ~type)))
+ `(list ~(type-of result)))))
+
+(define-expander all-of [[var & {list :in condition :when}] value]
+ (let [type (type-variable)]
+ (with-type `(reduce (fn [_# ~var]
+ (if ~(with-temporary-parameter-type [var type]
+ `(or (not ~(e-use (or condition true) '(optional boolean)))
+ ~(e-use value 'boolean)))
+ true
+ (reduced false)))
+ true ~(e-use list `(list ~type)))
+ 'boolean)))
+
+(define-expander one-of [[var & {list :in condition :when}] value]
+ (let [type (type-variable)]
+ (with-type `(reduce (fn [_# ~var]
+ (if ~(with-temporary-parameter-type [var type]
+ `(or (not ~(e-use (or condition true) '(optional boolean)))
+ ~(e-use value 'boolean)))
+ true
+ (reduced false)))
+ true ~(e-use list `(list ~type)))
+ 'boolean)))
+
+(define-expander with-optional [[var optional] value]
+ (assert (symbol? var) "Binding in with-optional must be a symbol")
+ (let [type (type-variable)
+ result (with-temporary-parameter-type [var type]
+ (expand-form* value))]
+ (with-type `(let [~var ~(e-use optional `(optional ~type))]
+ (when (some? ~var)
+ ~result))
+ `(optional ~(type-of result)))))
+
+(defn ^:private quote-type [type escape?]
+ (cond
+ (escape? type)
+ type
+ (seq? type)
+ (cons 'list (map #(quote-type % escape?) type))
+ ;; Reduce over sequences, vectors, and sets
+ (or (vector? type))
+ (reduce (fn [result item]
+ (conj result (quote-type item escape?)))
+ [] type)
+ ;; Reduce over the values in maps
+ (map? type)
+ (reduce-kv (fn [result key value]
+ (assoc result key (quote-type value)))
+ {} type)
+ ;; Return everything else quoted.
+ :else
+ (list 'quote type)))
+
+(define-simple-operator number= (-> [number number] boolean))
+(def number=? number=)
+(define-simple-operator number=? (-> [(optional number) (optional number)] (optional boolean)))
+
+;; Working out how exactly these should behave in the presence of
+;; optionals is a bit tricky, mostly because of short-circuiting. The
+;; easiest option is to not short circuit and define it as: "and? is
+;; nil if any argument is nil, false if any argument is false, and
+;; true otherwise". The obvious question then is "why does nil win
+;; over false? Why not have it false, then nil, then true?" Working
+;; out which makes the most sense might require some thinking.
+;; TODO: work this out.
+(define-simple-operator and
+ (-> [& (list boolean)] boolean))
+(define-simple-operator and? clojure.core/and
+ (-> [& (list (optional boolean))] (optional boolean)))
+
+(define-simple-operator or
+ (-> [& (list boolean)] boolean))
+(defmacro or? [& values]
+ (if-let [[x & xs] (seq values)]
+ `(let [val# ~x]
+ (cond (nil? val#) nil
+ (true? val#) true
+ :else `(or? ~@xs)))
+ false))
+(define-simple-operator or?
+ (-> [& (list (optional boolean))] (optional boolean)))
+
+(define-simple-operator not
+ (-> [boolean] boolean))
+(defn not? [value]
+ (cond (nil? value) nil
+ (true? value) false
+ :else true))
+(define-simple-operator not?
+ (-> [(optional boolean)] (optional boolean)))
+
+(define-simple-operator get
+ (∀ [k v] (-> [(map k v) k] (optional v))))
+
+(define-simple-operator str
+ (-> [& (list string)] string))
+
+(def str? str)
+(define-simple-operator str?
+ (-> [& (list (optional string))] string))
+
+(define-simple-operator count
+ (-> [(list any)] number))
+
+(def has? some?)
+(define-simple-operator has?
+ (-> [(optional any)] boolean))
+
+
+
+(import java.time.OffsetDateTime
+ java.time.ZonedDateTime
+ java.time.temporal.ChronoField
+ java.time.format.DateTimeFormatterBuilder)
+
+(def date-time-format
+ (-> (DateTimeFormatterBuilder.)
+ (.parseCaseInsensitive)
+ (.appendValue ChronoField/YEAR)
+ (.appendLiteral "-")
+ (.appendValue ChronoField/MONTH_OF_YEAR)
+ (.appendLiteral "-")
+ (.appendValue ChronoField/DAY_OF_MONTH)
+ (-> .optionalStart
+ (.appendLiteral "T")
+ (.appendValue ChronoField/HOUR_OF_DAY)
+ (.appendLiteral ":")
+ (.appendValue ChronoField/MINUTE_OF_HOUR)
+ (-> .optionalStart
+ (.appendLiteral ":")
+ (.appendValue ChronoField/SECOND_OF_MINUTE)
+ .optionalEnd)
+ .optionalEnd)
+ (-> .optionalStart
+ (.appendOffset "+HH:mm" "Z")
+ .optionalEnd)
+ (.parseDefaulting ChronoField/HOUR_OF_DAY 0)
+ (.parseDefaulting ChronoField/MINUTE_OF_HOUR 0)
+ (.parseDefaulting ChronoField/SECOND_OF_MINUTE 0)
+ (.toFormatter)
+ (.withZone java.time.ZoneOffset/UTC)))
+
+(defn parse-timestamp [string]
+ (.withOffsetSameInstant (OffsetDateTime/parse string date-time-format)
+ (java.time.ZoneOffset/UTC)))
+
+(parse-timestamp "2021-02-23")
+
+(define-simple-operator parse-timestamp
+ (-> [string] timestamp))
+
+(defn time-unit->chrono-unit [time-unit]
+ (condp #(.equalsIgnoreCase ^String %1 ^String %2)
+ time-unit
+ "second" java.time.temporal.ChronoUnit/SECONDS
+ "minute" java.time.temporal.ChronoUnit/MINUTES
+ "hour" java.time.temporal.ChronoUnit/HOURS
+ "day" java.time.temporal.ChronoUnit/DAYS
+ "month" java.time.temporal.ChronoUnit/MONTHS
+ "year" java.time.temporal.ChronoUnit/YEARS))
+(defn truncate-timestamp [^OffsetDateTime time time-unit]
+ (.truncatedTo time (time-unit->chrono-unit time-unit)))
+(defn truncate-timestamp? [time time-unit]
+ (and time time-unit (truncate-timestamp time time-unit)))
+
+(define-simple-operator truncate-timestamp
+ (-> [timestamp (enum "second" "minute" "hour" "day" "month" "year")] timestamp))
+(define-simple-operator truncate-timestamp?
+ (-> [(optional timestamp)
+ (optional (enum "second" "minute" "hour" "day" "month" "year"))]
+ (optional timestamp)))
+
+(define-simple-operator then jester.core/then
+ (∀ [x y] (-> [x y] y)))
+
+;; (alter-var-root #'operators
+;; assoc 'str {:type #(quote (-> [& (list string)] string))
+;; :name 'str})
+
+;; ;; (alter-var-root #'operators
+;; ;; assoc 'get {:type #(let [k (type-variable)
+;; ;; v (type-variable)]
+;; ;; `(-> [(map ~k ~v) ~k] (optional ~v)))})
+
+;; (alter-var-root #'operators
+;; assoc 'str? {:type #(quote (-> [& (list (optional string))] string))
+;; :name 'str})
+
+;; (alter-var-root #'operators
+;; assoc 'count {:type #(quote (-> [(list any)] number))
+;; :name 'count})
diff --git a/src/jester/match.clj b/src/jester/match.clj
new file mode 100644
index 0000000..5d18aec
--- /dev/null
+++ b/src/jester/match.clj
@@ -0,0 +1,142 @@
+(ns jester.match
+ "A simple pattern matcher, to avoid pulling in core.match. It is also
+ designed to specifically match against types, so it has some special
+ symbol handling."
+ (:require [clojure.walk :as walk]))
+
+(def ^:dynamic ^:private *seen-vars*)
+
+(defn ^:private special-type-symbol? [symbol]
+ (and (symbol? symbol)
+ (contains? #{"string" "number" "boolean" "any" "none" "_" "&"}
+ (name symbol))))
+
+(defmacro ^:private ensuring-unique-vars
+ {:style/indent 0}
+ [& body]
+ `(binding [*seen-vars* {}]
+ (let [result# (do ~@body)
+ multiple-vars# (into #{}
+ (comp (filter (fn [[_# v#]] (> v# 1)))
+ (map key))
+ *seen-vars*)]
+ (when-not (empty? multiple-vars#)
+ (throw (ex-info "Variables are bound multiple times"
+ {:names (mapv symbol multiple-vars#)})))
+ result#)))
+
+(declare pattern->condition)
+
+(defn seq-pattern->condition [pattern value]
+ (if-let [[pat-head & pat-tail] (seq pattern)]
+ (if (= pat-head '&)
+ (pattern->condition (first pat-tail) value)
+ (let [head-s (gensym "head")
+ tail-s (gensym "tail")]
+ `(let [[~head-s & ~tail-s] ~value]
+ (and ~(pattern->condition pat-head head-s)
+ ~(seq-pattern->condition pat-tail tail-s)))))
+ true))
+
+(defn ^:private pattern->condition [pattern value]
+ (cond
+ (vector? pattern)
+ (let [value-s (gensym "value")]
+ `(let [~value-s ~value]
+ (and (vector? ~value-s)
+ ~@(->> pattern
+ (take-while #(not= % '&))
+ (map-indexed (fn [i p]
+ (pattern->condition p `(get ~value-s ~i))))
+ doall)
+ ~@(when (seq (drop-while #(not= % '&) pattern))
+ [(pattern->condition (second (drop-while #(not= % '&) pattern))
+ `(subvec ~value-s
+ ~(count
+ (take-while #(not= % '&)
+ pattern))))]))))
+ (seq? pattern)
+ (let [[head & tail] pattern]
+ (assert (symbol? head) (str "Head of a seq pattern must be a symbol, not " (pr-str head)))
+ (let [head-s (gensym "head")
+ tail-s (gensym "tail")]
+ `(let [v# ~value]
+ (when (seq? v#)
+ (let [[~head-s & ~tail-s] v#]
+ (and (symbol? ~head-s)
+ (= ~(name head) (name ~head-s))
+ ~(seq-pattern->condition tail tail-s)))))))
+ (special-type-symbol? pattern)
+ (if (= (name pattern) "_")
+ true
+ `(let [v# ~value]
+ (and (symbol? v#)
+ (= ~(name pattern) (name v#)))))
+ (symbol? pattern)
+ (do (when-not (= (name pattern) "_")
+ (set! *seen-vars* (update *seen-vars* (name pattern) (fnil inc 0))))
+ true)
+ :else
+ `(= ~pattern ~value)))
+
+(defn ^:private pattern->destructuring-form [pattern]
+ (cond
+ (vector? pattern) (into [] (map pattern->destructuring-form) pattern)
+ (seq? pattern) (into ['_] (map pattern->destructuring-form) (rest pattern))
+ (symbol? pattern) (if (= pattern '&)
+ '&
+ (if (special-type-symbol? pattern)
+ '_
+ pattern))
+ :else '_))
+
+(defmacro type-match-let
+ {:style/indent 1}
+ [[pattern value] & body]
+ (ensuring-unique-vars
+ (let [value-s (gensym "value")]
+ `(let [~value-s ~value]
+ (if ~(pattern->condition pattern value-s)
+ (let [~(pattern->destructuring-form pattern) ~value-s]
+ ~@body)
+ (throw (ex-info "Value did not match pattern"
+ {:pattern '~pattern
+ :value ~value-s})))))))
+
+(defmacro type-match
+ {:style/indent 1}
+ [value & clauses]
+ (let [value-s (gensym "value")]
+ `(let [~value-s ~value]
+ (cond
+ ~@(mapcat (fn [[pattern result]]
+ (ensuring-unique-vars
+ [(pattern->condition pattern value-s)
+ `(let [~(pattern->destructuring-form pattern)
+ ~value-s]
+ ~result)]))
+ (partition 2 clauses))
+ :else
+ (throw (ex-info "Value did not match any pattern"
+ {:patterns '~(mapv first (partition 2 clauses))
+ :value ~value-s}))))))
+
+(defmacro defn-type-match [fn-name & clauses]
+ (let [[doc clauses] (if (and (string? (first clauses))
+ (odd? (count clauses)))
+ [(first clauses) (rest clauses)]
+ [nil clauses])
+ patterns (map first (partition 2 clauses))
+ all-symbols? (fn [pattern]
+ (every? (fn [s]
+ (and (symbol? s) (not (special-type-symbol? s))))
+ pattern))
+ argument-symbols (or (seq (first (filter all-symbols? patterns)))
+ (map #(do % (gensym "arg")) (first patterns)))]
+ (assert (every? vector? patterns) "All patterns in a defn-type-match must be a vector")
+ (assert (apply = (map count patterns)) "All patterns in a defn-type-match must have the same number of items")
+ `(defn ~fn-name
+ ~@(when doc [doc])
+ [~@argument-symbols]
+ (type-match [~@argument-symbols]
+ ~@clauses))))
diff --git a/src/jester/types.clj b/src/jester/types.clj
new file mode 100644
index 0000000..9451ea5
--- /dev/null
+++ b/src/jester/types.clj
@@ -0,0 +1,633 @@
+(ns jester.types
+ (:require [clojure.string :as str]
+ [jester.comparisons :refer [string= number= boolean=]]
+ [jester.match :refer [type-match-let type-match defn-type-match]]))
+
+;; Valid jester types are:
+;;
+;; These types are all "atomic":
+;; - any (supertype of all types)
+;; - none (subtype of all types)
+;; - number, string, boolean
+;; - (enum "a" ...), subtype of string
+;; - (enum 0 ...), subtype of number
+;; - (enum true ...), subtype of boolean
+
+;; These types are all "compound":
+;; - (optional α)
+;; - (list α), a list of values satisfying α
+;; - [α β], a tuple of α and β
+;; - [α β & γ], a tuple of α and β, and an additional sequence which matches γ
+;; - (map α β), a map of αs to βs
+;; - (-> [α β & γ] δ), a function taking an α, a β, and a sequence matching γ to δ
+;; - records, denoted by a map of property names to types
+
+;; Subtyping relationships are:
+;;
+;; - α ≤ any
+;; - (enum "a" ...) ≤ string
+;; - (enum 0 ...) ≤ number
+;; - (enum true ...) ≤ boolean
+;; - α ≤ (optional α)
+;; - [α α] ≤ (list α)
+;; - [α α & (list α)] ≤ (list α)
+;; - {l₁ α, l₂ β} ≤ {l₁ α}
+;; ... and more!
+
+(defn ^:private split-rest [coll]
+ (let [[main rest] (split-with #(not= % '&) coll)]
+ (if (seq rest)
+ (do (assert (= (count rest) 2) "Tail must consist of only one element")
+ [main (second rest)])
+ [main nil])))
+
+(defn ^:private assert-type*
+ [object type path root-object root-type]
+ (letfn [(type-error [message & format-args]
+ (throw (ex-info (apply format message format-args)
+ {:object object, :root-object root-object
+ :type type, :root-type root-type
+ :path path})))]
+ (type-match type
+ any true
+ none (type-error "No value can satisfy none, but got %s" (pr-str object))
+
+ string (when-not (string? object)
+ (type-error "Expected a string, but got %s" (pr-str object)))
+ number (when-not (number? object)
+ (type-error "Expected a number, but got %s" (pr-str object)))
+ boolean (when-not (boolean? object)
+ (type-error "Expected a boolean, but got %s" (pr-str object)))
+
+ (enum & values)
+ (when-not (some #(cond
+ (string? %) (and (string? object) (string= % object))
+ (number? %) (and (number? object) (number= % object))
+ (boolean? %) (and (boolean? object) (boolean= % object))
+ :else (assert false "Invalid enum value"))
+ values)
+ (if (empty? values)
+ (type-error "Empty enumeration type cannot have a value, but got %s"
+ (pr-str object))
+ (type-error "Expected %s, but got %s"
+ (cond
+ (= (count values) 1) (pr-str (first values))
+ (= (count values) 2) (str (pr-str (first values)) " or " (pr-str (second values)))
+ :else (str (str/join ", " (map pr-str (butlast values))) ", or " (pr-str (last values))))
+ (pr-str object))))
+
+ (optional t)
+ (when-not (nil? object)
+ (assert-type* object t path root-object root-type))
+
+ [& types]
+ (if (sequential? object)
+ (let [[main rest] (split-rest types)]
+ (if (< (count object) (count main))
+ (type-error "Expected at least %s item%s, but got %s" (count main) (if (> (count main) 1) "s" "") (count object))
+ (do (reduce (fn [index [item-type item]]
+ (assert-type* item item-type (conj path index) root-object root-type)
+ (inc index))
+ 0 (map vector main object))
+ (if rest
+ (assert-type* (drop (count main) object) rest (conj path (list 'drop (count main))) root-object root-type)
+ (when-not (= (count object) (count main))
+ (type-error "Expected at most %s item%s, but got %s" (count main) (if (> (count main) 1) "s" "") (count object)))))))
+ (type-error "Expected a sequential object, but got %s" (pr-str object)))
+
+ (list item-type)
+ (if (sequential? object)
+ (reduce (fn [index item]
+ (assert-type* item item-type (conj path index) root-object root-type)
+ (inc index))
+ 0 object)
+ (type-error "Expected a sequential object, but got %s" (pr-str object)))
+
+ (map k v)
+ (if (map? object)
+ (reduce-kv (fn [_ key value]
+ (assert-type* key k type root-object root-type)
+ (assert-type* value v (conj path key) root-object root-type))
+ nil object)
+ (type-error "Expected a map, but got %s" (pr-str object)))
+
+ (-> args result)
+ (type-error "Cannot verify function type of %s for %s" (pr-str type) (pr-str object))
+
+ _
+ (if (map? type)
+ (if (map? object)
+ (reduce-kv (fn [_ key key-type]
+ (assert-type* (get object key) key-type (conj path key) root-object root-type))
+ nil type)
+ (type-error "Expected an object (map), but got %s" (pr-str object)))
+ (type-error "Invalid type: %s" (pr-str type))))))
+
+(defn assert-type
+ "Assert that `object` is of type `type`. Throws an exception with
+ failure details if this is not the case. Returns nil otherwise."
+ [object type]
+ (assert-type* object type [] object type)
+ nil)
+
+(defn-type-match subtype?
+ "Returns true if every object of type `sub` also satisfies the type
+ `super`. Returns false otherwise."
+
+ [_ any] true
+ [none _] true
+
+ [(enum & values) string] (every? string? values)
+ [(enum & values) number] (every? number? values)
+ [(enum & values) boolean] (every? boolean? values)
+
+ [(enum & sub-values) (enum & super-values)]
+ (every? (fn [sub-value]
+ (cond (string? sub-value) (some #(and (string? %) (string= sub-value %))
+ super-values)
+ (number? sub-value) (some #(and (number? %) (number= sub-value %))
+ super-values)
+ (boolean? sub-value) (some #(and (boolean? %) (boolean= sub-value %))
+ super-values)
+ :else (assert false "We should not reach here.")))
+ sub-values)
+
+ [(optional subtype) (optional supertype)]
+ (subtype? subtype supertype)
+
+ [subtype (optional supertype)]
+ (subtype? subtype supertype)
+
+ ;; This is unfortunately quite complicated. :(
+ [[& subtypes] [& supertypes]]
+ (let [[sub-main sub-rest] (split-rest subtypes)
+ [super-main super-rest] (split-rest supertypes)]
+ (and
+ ;; The subtype must guarantree at least as many elements as the
+ ;; supertype.
+ (>= (count sub-main) (count super-main))
+ ;; Any compulsory types from both arrays are subtypes
+ (every? true? (map subtype? sub-main super-main))
+ (or
+ ;; If there are no excess compulsory subtypes then we're good.
+ (= (count sub-main) (count super-main))
+ ;; Otherwise, we need to check them against the super-rest.
+ (if super-rest
+ (subtype? (->> sub-main
+ (drop (count super-main))
+ vec)
+ super-rest)
+ ;; If there's no super-rest, then we just have excess types,
+ ;; which means there's no subtype relationship here.
+ false))
+ ;; If the subtype has a rest, then the supertype does too and
+ ;; they have an appropriate subtyping relationship
+ (if sub-rest
+ (and super-rest
+ (subtype? sub-rest super-rest))
+ true)))
+
+ [[& subtypes] (list supertype)]
+ (let [[sub-main sub-rest] (split-rest subtypes)]
+ (and (every? #(subtype? % supertype) sub-main)
+ (if sub-rest
+ (subtype? sub-rest `(list ~supertype))
+ true)))
+
+ [(list subtype) (list supertype)]
+ (subtype? subtype supertype)
+
+ [(map sub-key sub-val) (map super-key super-val)]
+ (and (subtype? super-key sub-key)
+ (subtype? sub-val super-val))
+
+ [(-> sub-args sub-result) (-> super-args super-result)]
+ (and (subtype? super-args sub-args)
+ (subtype? sub-result super-result))
+
+ [subtype supertype]
+ (cond
+ (and (map? subtype) (map? supertype))
+ (reduce-kv (fn [_ prop prop-type]
+ (or (subtype? (get subtype prop) prop-type)
+ (reduced false)))
+ nil supertype)
+
+ (and (symbol? subtype) (symbol? supertype))
+ (= (name subtype) (name supertype))
+
+ :else
+ false))
+
+(def ^:private ^:dynamic *upper-bounds*)
+(def ^:private ^:dynamic *lower-bounds*)
+
+(defn call-in-constraint-environment [thunk]
+ (binding [*upper-bounds* {}
+ *lower-bounds* {}]
+ (thunk)))
+
+(defmacro in-constraint-environment
+ {:style/indent 0}
+ [& body]
+ `(call-in-constraint-environment #(do ~@body)))
+
+(defn-type-match constrain
+ "Add a constraint that `sub` is a subtype of `super` within the
+ current constraint environment (see `in-constraint-environment`).
+ The return value of this function is meaningless, and it will throw
+ an exception if the provided constraint leads to a contradiction."
+
+ ;; No point doing anything for the constraints that aren't really constraints.
+ [_ any] nil
+ [none _] nil
+
+ [subtype (var n)]
+ (when-not (contains? (get *lower-bounds* n) subtype)
+ (set! *lower-bounds* (update *lower-bounds* n (fnil conj #{}) subtype))
+ (doseq [upper (get *upper-bounds* n)]
+ (constrain subtype upper))
+ ;; The type hierarchy splits for all of the "atomic" types, so if
+ ;; we're constrained to be a supertype of one of the divisions
+ ;; then add a constraint on the top as well. While there is still
+ ;; the `any` type at the top of the hierachy, we only want it to
+ ;; show up for truly unconstrained types. This makes things like
+ ;; (? a 10) infer `a` to be (optional number) rather than the less
+ ;; accurate/helpful (optional any).
+ (type-match subtype
+ string (constrain `(var ~n) `(optional string))
+ number (constrain `(var ~n) `(optional number))
+ boolean (constrain `(var ~n) `(optional boolean))
+ (enum & values) (do (when (some string? values)
+ (constrain `(var ~n) `(optional string)))
+ (when (some number? values)
+ (constrain `(var ~n) `(optional number)))
+ (when (some boolean? values)
+ (constrain `(var ~n) `(optional boolean))))
+ _ nil))
+
+ [(var n) supertype]
+ (when-not (contains? (get *upper-bounds* n) supertype)
+ (set! *upper-bounds* (update *upper-bounds* n (fnil conj #{}) supertype))
+ (doseq [lower (get *lower-bounds* n)]
+ (constrain lower supertype))
+ ;; If we're a subtype of an optional type then add an artificial
+ ;; "unoptional" constraint to track the constraint against the
+ ;; wrapped type variable.
+ (type-match supertype
+ (optional x) (constrain `(unoptional (var ~n))
+ x)
+ _ nil))
+
+ [subtype (optional super)]
+ (constrain (loop [subtype subtype]
+ (type-match subtype
+ (optional sub) (recur sub)
+ _ subtype))
+ super)
+
+ [(unoptional _) _]
+ nil ;; This is a weird extra constraint we add to propagate the
+ ;; restrictions imposed by having an optional on the RHS. It
+ ;; doesn't impose any other constraint, so just ingore it.
+
+ [[& subtypes] [& supertypes]]
+ (let [[sub-main sub-rest] (split-rest subtypes)
+ [super-main super-rest] (split-rest supertypes)]
+ (when (< (count sub-main) (count super-main))
+ (throw (ex-info "Constraint error: cannot solve type relation"
+ {:subtype (vec subtypes)
+ :supertype (vec supertypes)})))
+ (doall (map constrain sub-main super-main))
+ (when-not (= (count sub-main) (count super-main))
+ (if super-rest
+ (constrain (->> sub-main
+ (drop (count super-main))
+ vec)
+ super-rest)
+ (throw (ex-info "Constraint error: cannot solve type relation"
+ {:subtype (vec subtypes)
+ :supertype (vec supertypes)}))))
+ (when sub-rest
+ (if super-rest
+ (constrain sub-rest super-rest)
+ (throw (ex-info "Constraint error: cannot solve type relation"
+ {:subtype (vec subtypes)
+ :supertype (vec supertypes)})))))
+
+ [[& subtypes] (list supertype)]
+ (let [[sub-main sub-rest] (split-rest subtypes)]
+ (doall (map #(constrain % supertype) sub-main))
+ (when sub-rest
+ (constrain sub-rest `(list ~supertype))))
+
+ [(list sub) (list super)]
+ (constrain sub super)
+
+ [(map subk subv) (map superk superv)]
+ (do (constrain superk subk)
+ (constrain subv superv))
+
+ [(-> sub-args sub-result) (-> super-args super-result)]
+ (do (constrain super-args sub-args)
+ (constrain sub-result super-result))
+
+ [subtype supertype]
+ (cond
+ (and (map? subtype) (map? supertype))
+ (doseq [[property super-value] supertype
+ :let [sub-value (get subtype property)]]
+ (if (nil? sub-value)
+ (throw (ex-info "Object type does not meet subtype constraint"
+ {:subtype subtype
+ :supertype supertype}))
+ (constrain sub-value super-value)))
+
+ (not (subtype? subtype supertype))
+ (throw (ex-info "Constraint error: cannot solve type relation"
+ {:subtype subtype
+ :supertype supertype}))))
+
+;; ;; This helper just prints out all the constraints that get added.
+;; (def constrain (let [c constrain]
+;; (fn [a b]
+;; (prn a '≤ b)
+;; (c a b))))
+
+(declare union-types)
+
+(defn-type-match ^:private intersect-types
+ "Return a type that is a subtype of both `t1` and `t2`. The
+ resulting type should be the greatest such type (that is, for all x
+ that are subtypes of both t1 and t2, x is a subtype of, or equal
+ to, (intersect-types t1 t2))."
+
+ [any t] t
+ [t any] t
+ [none _] 'none
+ [_ none] 'none
+
+ [(enum & t1values) (enum & t2values)]
+ (if-let [common-values (seq (filter (fn [t1value]
+ (cond (string? t1value) (some #(and (string? %) (string= t1value %))
+ t2values)
+ (number? t1value) (some #(and (number? %) (number= t1value %))
+ t2values)
+ (boolean? t1value) (some #(and (boolean? %) (boolean= t1value %))
+ t2values)
+ :else (assert false "We should not reach here.")))
+ t1values))]
+ (cons 'enum common-values)
+ 'none)
+
+ [(optional t1) (optional t2)]
+ (list 'optional (intersect-types t1 t2))
+
+ [(optional t1) t2]
+ (intersect-types t1 t2)
+ [t1 (optional t2)]
+ (intersect-types t1 t2)
+
+ [[& t1s] [& t2s]]
+ (let [[t1-main t1-rest] (split-rest t1s)
+ [t2-main t2-rest] (split-rest t2s)]
+ (cond
+ (empty? t1-main)
+ (intersect-types t1-rest t2s)
+ (empty? t2-main)
+ (intersect-types t1s t2-rest)
+ (= (count t1-main) (count t2-main))
+ (into (mapv intersect-types t1-main t2-main)
+ (when (and t1-rest t2-rest)
+ ['& (intersect-types t1-rest t2-rest)]))
+ (< (count t1-main) (count t2-main))
+ (if t1-rest
+ (into (mapv intersect-types t1-main t2-main)
+ (concat (intersect-types t1-rest
+ (->> t2-main
+ (drop (count t1-main))
+ vec))
+ (when t2-rest
+ ['& (intersect-types t1-rest t2-rest)])))
+ (throw (ex-info "No value can possible satisfy type"
+ {:type (list 'intersect t1 t2)})))
+ (> (count t1-main) (count t2-main))
+ (if t2-rest
+ (into (mapv intersect-types t1-main t2-main)
+ (concat (intersect-types (->> t1-main
+ (drop (count t2-main))
+ vec)
+ t2-rest)
+ (when t1-rest
+ ['& (intersect-types t1-rest t2-rest)])))
+ (throw (ex-info "No value can possible satisfy type"
+ {:type (list 'intersect t1 t2)})))
+ :else
+ (throw (ex-info "No value can possible satisfy type"
+ {:type (list 'intersect t1 t2)}))))
+
+ [[& t1s] (list t2)]
+ (mapv #(intersect-types % t2) t1s)
+
+ [(list t1) [& t2s]]
+ (mapv #(intersect-types t1 %) t2s)
+
+ [(list t1) (list t2)]
+ (list 'list (intersect-types t1 t2))
+
+ [(map t1k t1v) (map t2k t2v)]
+ (list 'map
+ (union-types t1k t2k)
+ (intersect-types t1v t2v))
+
+ [(-> t1k t1v) (-> t2k t2v)]
+ (list '->
+ (union-types t1k t2k)
+ (intersect-types t1v t2v))
+
+ [t1 t2]
+ (cond
+ (and (map? t1) (map? t2)) (merge-with intersect-types t1 t2)
+ (subtype? t1 t2) t1
+ (subtype? t2 t1) t2
+ :else (throw (ex-info "No value can possible satisfy type"
+ {:type (list 'intersect t1 t2)}))))
+
+(defn-type-match ^:private union-types
+ "Return a type that is a supertype of both `t1` and `t2`. The
+ resulting type should be the least such type (that is, for all x
+ that are supertypes of both t1 and t2, x is a supertype of, or equal
+ to, (union-types t1 t2))."
+
+ [any _] 'any
+ [_ any] 'any
+ [none t] t
+ [t none] t
+
+ [(enum & t1values) (enum & t2values)]
+ (cons 'enum
+ (concat (remove (fn [t1value]
+ (cond (string? t1value) (some #(and (string? %) (string= t1value %))
+ t2values)
+ (number? t1value) (some #(and (number? %) (number= t1value %))
+ t2values)
+ (boolean? t1value) (some #(and (boolean? %) (boolean= t1value %))
+ t2values)
+ :else (assert false "We should not reach here.")))
+ t1values)
+ t2values))
+
+ [(optional t1) (optional t2)]
+ (list 'optional (union-types t1 t2))
+
+ [(optional t1) t2]
+ (list 'optional (union-types t1 t2))
+ [t1 (optional t2)]
+ (list 'optional (union-types t1 t2))
+
+ [[& t1s] [& t2s]]
+ (let [[t1-main t1-rest] (split-rest t1s)
+ [t2-main t2-rest] (split-rest t2s)]
+ (cond
+ (empty? t1-main)
+ (union-types t1-rest t2s)
+ (empty? t2-main)
+ (union-types t1s t2-rest)
+ (= (count t1-main) (count t2-main))
+ (into (mapv union-types t1-main t2-main)
+ (when (and t1-rest t2-rest)
+ ['& (union-types t1-rest t2-rest)]))
+ (< (count t1-main) (count t2-main))
+ (if t1-rest
+ (into (mapv union-types t1-main t2-main)
+ (concat (union-types t1-rest
+ (->> t2-main
+ (drop (count t1-main))
+ vec))
+ (when t2-rest
+ ['& (union-types t1-rest t2-rest)])))
+ (throw (ex-info "Cannot construct type union"
+ {:type (list 'union t1 t2)})))
+ (> (count t1-main) (count t2-main))
+ (if t2-rest
+ (into (mapv union-types t1-main t2-main)
+ (concat (union-types (->> t1-main
+ (drop (count t2-main))
+ vec)
+ t2-rest)
+ (when t1-rest
+ ['& (union-types t1-rest t2-rest)])))
+ (throw (ex-info "Cannot construct type union"
+ {:type (list 'union t1 t2)})))
+ :else
+ (throw (ex-info "Cannot construct type union"
+ {:type (list 'union t1 t2)}))))
+
+ [[& t1s] (list t2)]
+ (let [[t1-main t1-rest] (split-rest t1s)]
+ (into (mapv #(union-types % t2) t1-main)
+ (if t1-rest
+ ['& (union-types t1-rest t2)]
+ ['& `(list ~t2)])))
+
+ [(list t1) [& t2s]]
+ (let [[t2-main t2-rest] (split-rest t2s)]
+ (into (mapv #(union-types t1 %) t2-main)
+ (if t2-rest
+ ['& (union-types t1 t2-rest)]
+ ['& `(list ~t1)])))
+
+ [(list t1) (list t2)]
+ (list 'list (union-types t1 t2))
+
+ [(map t1k t1v) (map t2k t2v)]
+ (list 'map
+ (intersect-types t1k t2k)
+ (union-types t1v t2v))
+
+ [(-> t1k t1v) (-> t2k t2v)]
+ (list '->
+ (intersect-types t1k t2k)
+ (union-types t1v t2v))
+
+ [t1 t2]
+ (cond
+ (subtype? t1 t2) t2
+ (subtype? t2 t1) t1
+ :else (throw (ex-info "Cannot construct type union"
+ {:type (list 'union t1 t2)}))))
+
+(defn-type-match ground-type
+ "Attempt to eliminate type-variables from `type` by solving any
+ constraints that can be solved. If `position` is :argument then
+ variables will be solved for their least upper bound, if `position`
+ is :return
+ bound."
+
+ [(var n) position]
+ (do (when-not (or (= position :argument) (= position :return))
+ (throw (ex-info (str "Invalid position: " (pr-str position))
+ {:position position})))
+ (if-let [type-bounds (->> (get (case position
+ :argument *upper-bounds*
+ :return *lower-bounds*)
+ n)
+ (map #(ground-type % position))
+ seq)]
+ (reduce (case position
+ :argument intersect-types
+ :return union-types)
+ type-bounds)
+ 'any))
+
+ [(enum & values) _]
+ (cons 'enum values)
+
+ [(optional t) position]
+ (list 'optional
+ (type-match (ground-type t position)
+ (optional t2) t2
+ t2 t2))
+
+ [(unoptional t) position]
+ (type-match (ground-type t position)
+ (optional t2) t2
+ t2 t2)
+
+ [[& ts] position]
+ (let [[t-main t-rest] (split-rest ts)]
+ (into (mapv #(ground-type % position) t-main)
+ (when t-rest
+ ['& (ground-type t-rest position)])))
+
+ [(list t) position]
+ (list 'list (ground-type t position))
+
+ [(map k v) position]
+ (list 'map
+ (ground-type k (case position
+ :argument :return
+ :return :argument))
+ (ground-type v position))
+
+ [(-> k v) position]
+ (list '->
+ (ground-type k (case position
+ :argument :return
+ :return :argument))
+ (ground-type v position))
+
+ [type position]
+ (cond
+ (map? type) (persistent!
+ (reduce-kv (fn [acc prop prop-type]
+ (assoc! acc prop (ground-type prop-type position)))
+ (transient {}) type))
+ (symbol? type) (symbol (name type))
+ :else (throw (ex-info "Cannot ground invalid type" {:type type}))))
+
+(defn type-variable
+ "Create a new type variable, distinct from all existing type
+ variables."
+ []
+ (list 'var (gensym "type")))