(ns jester.match
  "A simple pattern matcher, to avoid pulling in core.match. It is also designed to specifically match against types, so it has some special symbol handling."
  (:require [clojure.walk :as walk]))

;; Pattern-variable occurrence counts, keyed by variable name. Unbound except
;; inside `ensuring-unique-vars`, which binds it to {} while a pattern is
;; compiled; `pattern->condition` increments it for every pattern variable.
(def ^:dynamic ^:private *seen-vars*)

(defn ^:private special-type-symbol?
  "True when `symbol` is a symbol whose name (ignoring any namespace) is one
  of the built-in type names `string`, `number`, `boolean`, `any`, `none`,
  or one of the pattern markers `_` (wildcard) and `&` (rest)."
  [symbol]
  (and (symbol? symbol)
       (contains? #{"string" "number" "boolean" "any" "none" "_" "&"}
                  (name symbol))))

(defmacro ^:private ensuring-unique-vars
  "Evaluates `body` with `*seen-vars*` freshly bound to {} and returns its
  value. If any pattern variable was recorded more than once while `body`
  ran, throws ex-info with the offending names under `:names` instead.
  `body` must build its result eagerly: the binding ends when this returns."
  {:style/indent 0}
  [& body]
  `(binding [*seen-vars* {}]
     (let [result# (do ~@body)
           multiple-vars# (into #{}
                                (comp (filter (fn [[_# v#]] (> v# 1)))
                                      (map key))
                                *seen-vars*)]
       (when-not (empty? multiple-vars#)
         (throw (ex-info "Variables are bound multiple times"
                         {:names (mapv symbol multiple-vars#)})))
       result#)))

(declare pattern->condition)

(defn seq-pattern->condition
  "Builds a form that is truthy when the elements of the seqable `value`
  (a symbol or form) match the element patterns in `pattern`, in order.
  `& rest-pattern` matches whatever elements remain (possibly nil) against
  `rest-pattern`; without it, `value` must have exactly as many elements as
  `pattern`."
  [pattern value]
  (if-let [[pat-head & pat-tail] (seq pattern)]
    (if (= pat-head '&)
      (pattern->condition (first pat-tail) value)
      (let [head-s (gensym "head")
            tail-s (gensym "tail")]
        ;; `when-let` fails the match if the value runs out before the pattern.
        `(when-let [[~head-s & ~tail-s] (seq ~value)]
           (and ~(pattern->condition pat-head head-s)
                ~(seq-pattern->condition pat-tail tail-s)))))
    ;; Pattern exhausted: any leftover elements mean no match.
    `(empty? ~value)))

(defn ^:private pattern->condition
  "Builds a form that is truthy when the value of the form `value` matches
  `pattern`:

  - vector: a vector with exactly as many elements as the pattern, or, when
    the pattern is `[p1 ... pn & rest-pattern]`, at least n elements with the
    remainder (as a subvec) matching `rest-pattern`; elements match
    positionally.
  - seq `(head p1 ...)`: a seq whose first element is a symbol with the same
    name as `head`, the remaining elements matched as in
    `seq-pattern->condition`.
  - `_`: anything.
  - other special type symbols: a symbol with the same name.
  - other symbols: a pattern variable; matches anything and is counted in
    `*seen-vars*`, so this must run inside `ensuring-unique-vars`.
  - anything else: a value `=` to the pattern."
  [pattern value]
  (cond
    (vector? pattern)
    (let [value-s (gensym "value")
          [fixed more] (split-with #(not= % '&) pattern)
          n-fixed (count fixed)
          rest? (boolean (seq more))]
      `(let [~value-s ~value]
         (and (vector? ~value-s)
              ;; Checked before the element conditions so the `subvec` below
              ;; can never be out of bounds.
              ~(if rest?
                 `(>= (count ~value-s) ~n-fixed)
                 `(= (count ~value-s) ~n-fixed))
              ;; Built eagerly: *seen-vars* is only bound during compilation.
              ~@(into []
                      (map-indexed (fn [i p] (pattern->condition p `(get ~value-s ~i))))
                      fixed)
              ~@(when rest?
                  [(pattern->condition (second more) `(subvec ~value-s ~n-fixed))]))))

    (seq? pattern)
    (let [[head & tail] pattern]
      (assert (symbol? head) (str "Head of a seq pattern must be a symbol, not " (pr-str head)))
      (let [head-s (gensym "head")
            tail-s (gensym "tail")]
        `(let [v# ~value]
           (when (seq? v#)
             (let [[~head-s & ~tail-s] v#]
               (and (symbol? ~head-s)
                    (= ~(name head) (name ~head-s))
                    ~(seq-pattern->condition tail tail-s)))))))

    (special-type-symbol? pattern)
    (if (= (name pattern) "_")
      true
      `(let [v# ~value]
         (and (symbol? v#) (= ~(name pattern) (name v#)))))

    (symbol? pattern)
    ;; A pattern variable. ("_" never gets here: it is a special type symbol.)
    (do (set! *seen-vars* (update *seen-vars* (name pattern) (fnil inc 0)))
        true)

    :else
    `(= ~pattern ~value)))

(defn ^:private pattern->destructuring-form
  "Converts `pattern` into a `let` binding form that binds its pattern
  variables: vectors stay vectors, a seq pattern's head symbol is skipped,
  `&` is kept, and special type symbols and literals become `_`."
  [pattern]
  (cond
    (vector? pattern) (into [] (map pattern->destructuring-form) pattern)
    (seq? pattern) (into ['_] (map pattern->destructuring-form) (rest pattern))
    (symbol? pattern) (if (= pattern '&)
                        '&
                        (if (special-type-symbol? pattern) '_ pattern))
    :else '_))

(defmacro type-match-let
  "Matches `value` against `pattern` and evaluates `body` with the pattern's
  variables bound. Throws ex-info with `:pattern` and `:value` when the value
  does not match. Throws at macro-expansion time if a variable occurs more
  than once in `pattern`.

  (type-match-let [pattern value] body...)"
  {:style/indent 1}
  [[pattern value] & body]
  (ensuring-unique-vars
    (let [value-s (gensym "value")]
      `(let [~value-s ~value]
         (if ~(pattern->condition pattern value-s)
           (let [~(pattern->destructuring-form pattern) ~value-s]
             ~@body)
           (throw (ex-info "Value did not match pattern"
                           {:pattern '~pattern :value ~value-s})))))))

(defmacro type-match
  "Matches `value` against each pattern in turn and evaluates the result form
  of the first clause that matches, with that pattern's variables bound.
  Throws ex-info with `:patterns` and `:value` when no clause matches.
  Throws at macro-expansion time if the clauses do not come in pairs or a
  variable occurs more than once within one pattern.

  (type-match value
    pattern-1 result-1
    pattern-2 result-2 ...)"
  {:style/indent 1}
  [value & clauses]
  (assert (even? (count clauses))
          "type-match requires an even number of pattern/result forms")
  (let [value-s (gensym "value")
        pairs (partition 2 clauses)]
    `(let [~value-s ~value]
       (cond
         ~@(mapcat (fn [[pattern result]]
                     ;; One fresh *seen-vars* scope per clause.
                     (ensuring-unique-vars
                       [(pattern->condition pattern value-s)
                        `(let [~(pattern->destructuring-form pattern) ~value-s]
                           ~result)]))
                   pairs)
         :else (throw (ex-info "Value did not match any pattern"
                               {:patterns '~(mapv first pairs)
                                :value ~value-s}))))))

(defmacro defn-type-match
  "Defines a function `fn-name` whose arguments, collected into a vector, are
  dispatched with `type-match` over vector patterns. An optional docstring may
  precede the clauses. All patterns must be vectors of the same length. The
  symbols of the first pattern made only of plain (non-special) symbols become
  the parameter names; if there is none, parameters are gensyms.

  (defn-type-match fn-name doc?
    [pattern ...] result
    ...)"
  [fn-name & clauses]
  (let [[doc clauses] (if (and (string? (first clauses)) (odd? (count clauses)))
                        [(first clauses) (rest clauses)]
                        [nil clauses])
        patterns (map first (partition 2 clauses))]
    ;; Validate before looking inside the patterns, so a non-vector pattern
    ;; is reported by these messages rather than by a failure in `every?`.
    (assert (every? vector? patterns) "All patterns in a defn-type-match must be a vector")
    (assert (or (empty? patterns) (apply = (map count patterns)))
            "All patterns in a defn-type-match must have the same number of items")
    (let [all-symbols? (fn [pattern]
                         (every? (fn [s] (and (symbol? s) (not (special-type-symbol? s))))
                                 pattern))
          argument-symbols (or (seq (first (filter all-symbols? patterns)))
                               (map (fn [_] (gensym "arg")) (first patterns)))]
      `(defn ~fn-name ~@(when doc [doc]) [~@argument-symbols]
         (type-match [~@argument-symbols] ~@clauses)))))