summaryrefslogtreecommitdiff
path: root/src/main/java/au/id/zancanaro/javacheck/ShrinkTree.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/au/id/zancanaro/javacheck/ShrinkTree.java')
-rw-r--r--src/main/java/au/id/zancanaro/javacheck/ShrinkTree.java128
1 files changed, 128 insertions, 0 deletions
diff --git a/src/main/java/au/id/zancanaro/javacheck/ShrinkTree.java b/src/main/java/au/id/zancanaro/javacheck/ShrinkTree.java
new file mode 100644
index 0000000..a424806
--- /dev/null
+++ b/src/main/java/au/id/zancanaro/javacheck/ShrinkTree.java
@@ -0,0 +1,128 @@
+package au.id.zancanaro.javacheck;
+
+import java.io.IOException;
+import java.io.Writer;
+import java.util.*;
+import java.util.function.Function;
+import java.util.function.Predicate;
+
+@SuppressWarnings("unused")
+public class ShrinkTree<T> {
+ private final T value;
+ private final Iterable<ShrinkTree<T>> children;
+
+ public ShrinkTree(T value, Iterable<ShrinkTree<T>> children) {
+ this.value = value;
+ this.children = children;
+ }
+
+ public T getValue() {
+ return value;
+ }
+
+ public Iterator<ShrinkTree<T>> getChildren() {
+ return children.iterator();
+ }
+
+ public static <T> ShrinkTree<T> pure(T value) {
+ return new ShrinkTree<>(value, Collections.emptyList());
+ }
+
+ public static <T> ShrinkTree<T> join(ShrinkTree<ShrinkTree<T>> tree) {
+ return new ShrinkTree<>(
+ tree.getValue().getValue(),
+ () -> Iterators.concat(
+ Iterators.mappingIterator(ShrinkTree::join, tree.children.iterator()),
+ tree.getValue().children.iterator()));
+ }
+
+ private static <T> Iterator<ShrinkTree<T>[]> permutations(ShrinkTree<T>[] trees) {
+ return Iterators.flatten(
+ Iterators.rangeIterator(trees.length,
+ index -> Iterators.mappingIterator(child -> {
+ @SuppressWarnings("unchecked")
+ ShrinkTree<T>[] result = (ShrinkTree<T>[]) new ShrinkTree[trees.length];
+ for (int i = 0; i < trees.length; ++i) {
+ result[i] = (i == index ? child : trees[i]);
+ }
+ return result;
+ }, trees[index].getChildren())
+ ));
+ }
+
+ private static <T> List<T> makeHeadList(ShrinkTree<T>[] trees) {
+ List<T> heads = new ArrayList<>(trees.length);
+ for (ShrinkTree<T> tree : trees) {
+ heads.add(tree.getValue());
+ }
+ return heads;
+ }
+
+ public static <T, R> ShrinkTree<R> zip(Function<List<T>, R> fn, ShrinkTree<T>[] trees) {
+ return new ShrinkTree<>(
+ fn.apply(makeHeadList(trees)),
+ () -> Iterators.mappingIterator(
+ shrinks -> ShrinkTree.zip(fn, shrinks),
+ ShrinkTree.permutations(trees)));
+ }
+
+ private static <T> Iterator<ShrinkTree<T>[]> removeEach(ShrinkTree<T>[] trees) {
+ return Iterators.concat(
+ Iterators.rangeIterator(trees.length, index -> {
+ @SuppressWarnings("unchecked")
+ ShrinkTree<T>[] result = (ShrinkTree<T>[]) new ShrinkTree[trees.length - 1];
+ for (int i = 0; i < trees.length - 1; ++i) {
+ result[i] = trees[(i >= index ? i + 1 : i)];
+ }
+ return result;
+ }),
+ permutations(trees));
+ }
+
+ public static <T, R> ShrinkTree<R> shrink(Function<List<T>, R> fn, ShrinkTree<T>[] trees) {
+ return new ShrinkTree<>(
+ fn.apply(makeHeadList(trees)),
+ () -> Iterators.mappingIterator(
+ shrinks -> ShrinkTree.shrink(fn, shrinks),
+ ShrinkTree.removeEach(trees)));
+ }
+
+ public <R> ShrinkTree<R> map(Function<T, R> f) {
+ return new ShrinkTree<>(
+ f.apply(this.value),
+ () -> Iterators.mappingIterator(tree -> tree.map(f), this.children.iterator()));
+ }
+
+ public <R> ShrinkTree<R> flatMap(Function<T, ShrinkTree<R>> f) {
+ return ShrinkTree.join(this.map(f));
+ }
+
+ public ShrinkTree<T> filter(Predicate<T> predicate) {
+ if (predicate.test(this.getValue())) {
+ return new ShrinkTree<>(
+ this.getValue(),
+ () -> Iterators.mappingIterator(tree -> tree.filter(predicate),
+ Iterators.filteringIterator(
+ tree -> predicate.test(tree.getValue()),
+ this.getChildren())));
+ } else {
+ throw new IllegalArgumentException("Current value doesn't match predicate: whoops!");
+ }
+ }
+
+ @SuppressWarnings("unused")
+ public void print(Writer output) throws IOException {
+ print(output, Object::toString);
+ }
+
+ @SuppressWarnings("unused")
+ public void print(Writer output, Function<T, String> toString) throws IOException {
+ output.write(toString.apply(this.getValue()));
+ output.write('[');
+ for (ShrinkTree<T> child : children) {
+ child.print(output, toString);
+ }
+ output.write(']');
+ output.flush();
+ }
+}