Writing a custom spliterator in Java 8

In this article I’m going to give two examples of writing a custom spliterator in Java 8. What is a spliterator and why would you need to write your own? Well, a spliterator is used by the Java streams code when you call stream() on a collection or other object. The two most important methods in the Spliterator interface are as follows:

  • boolean tryAdvance(Consumer<? super T> action);
  • Spliterator<T> trySplit();

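Every custom spliterator has to supply a working tryAdvance. It is this method which is invoked to get each element of a stream to process. The trySplit method only needs a real implementation if you are going to create a parallel stream; it is invoked to split the source into sections which can be safely processed in parallel, and a sequential-only spliterator can simply return null from it. The interface has two more abstract methods, estimateSize and characteristics, which also appear in the examples below.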
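To make that contract concrete, here is a minimal sketch of a complete spliterator. CountdownSpliterator is a made-up class for illustration (it is not in the repo): it emits start, start - 1, ..., 1 and then reports that it is exhausted. It returns null from trySplit, so it is only meant for sequential streams.

```java
import java.util.List;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

// Hypothetical example class: emits start, start - 1, ..., 1 and then stops.
class CountdownSpliterator implements Spliterator<Integer> {
    private int remaining;

    CountdownSpliterator(int start) {
        this.remaining = Math.max(start, 0);
    }

    @Override
    public boolean tryAdvance(Consumer<? super Integer> action) {
        if (remaining <= 0) {
            return false;            // nothing left: the stream ends here
        }
        action.accept(remaining--);  // hand exactly one element to the stream
        return true;
    }

    @Override
    public Spliterator<Integer> trySplit() {
        return null;                 // never split: sequential use only
    }

    @Override
    public long estimateSize() {
        return remaining;
    }

    @Override
    public int characteristics() {
        return ORDERED | NONNULL;
    }

    // Builds a sequential stream (second argument false) and collects it.
    static List<Integer> countdown(int start) {
        return StreamSupport.stream(new CountdownSpliterator(start), false)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(countdown(3)); // prints [3, 2, 1]
    }
}
```

The stream keeps calling tryAdvance until it returns false, which is why the finite case needs the check at the top of the method.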

Now let’s run through a few examples. I’m going to include all of the code snippets inline, but if you want a working example, all of the source code is available on my GitHub:

https://github.com/hedleyproctor/java8-examples

To begin with, let’s consider a case where you don’t need a parallel stream, but you do have a custom class for which you want to write a spliterator. Suppose I work on an application that processes HTML data, and I decide that to get test data for my application, I could just scrape random pages off the web. I could use a library like jsoup to fetch each page and put any links it contains on a queue, so that whenever I need another page, I can just retrieve the next link. A simple implementation could look like this:

import org.jsoup.Jsoup;
import org.jsoup.nodes.Element;
import org.jsoup.select.Elements;
 
import java.io.IOException;
import java.util.LinkedList;
import java.util.Queue;
 
public class WebPageProvider {
 
    private Queue<String> urls = new LinkedList<String>();
 
    public WebPageProvider() {
        urls.add("https://en.wikipedia.org/wiki/Main_Page");
    }
 
    // Document is this example's own wrapper class (it provides getImages()), not jsoup's Document
    public Document getPage() {
        org.jsoup.nodes.Document doc = null;
 
        while (doc == null) {
            String nextPageURL = urls.remove();
            System.out.println("Next page: " + nextPageURL);
            try {
                doc = Jsoup.connect(nextPageURL).get();
            } catch (IOException e) {
                // we'll try the next one on our list
            }
        }
 
        // get links and put on our queue
        Elements links = doc.select("a[href]");
        for (Element link : links) {
            String newURL = link.attr("abs:href");
            // System.out.println(newURL);
            urls.add(newURL);
        }
        return new Document(doc);
    }
}

Now, what I’d really like to be able to do is to use all of the useful methods in streams to be able to provide different sorts of test data. For example, suppose I just wanted images, I could map each web page to get the list of images on the page, then call flatMap to flatten the stream of List objects back to a stream of Image objects, like this:

StreamSupport.stream(new WebPageSpliterator(new WebPageProvider()), false)
        .map(Document::getImages)
        .flatMap(List::stream)
        .limit(10);

Or perhaps filter to only include documents with five or more images:

StreamSupport.stream(new WebPageSpliterator(new WebPageProvider()), false)
        .filter(doc -> doc.getImages().size() >= 5)
        .limit(10);
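If you want to try these two pipelines without hitting the network, here is a rough stand-in. Everything in it is hypothetical: Page plays the role of the Document wrapper (page number n simply has n fake images), and Stream.iterate replaces the web crawl as the infinite source.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class PagePipelineSketch {

    // Hypothetical stand-in for the Document wrapper: page n has n "images".
    static class Page {
        final int number;

        Page(int number) {
            this.number = number;
        }

        List<String> getImages() {
            List<String> images = new ArrayList<>();
            for (int i = 0; i < number; i++) {
                images.add("page" + number + "-img" + i);
            }
            return images;
        }
    }

    // Infinite stream of pages 1, 2, 3, ... standing in for the crawl.
    static Stream<Page> pages() {
        return Stream.iterate(1, n -> n + 1).map(Page::new);
    }

    // Same shape as the first pipeline: map, flatten, then cut off.
    static List<String> firstImages(int howMany) {
        return pages()
                .map(Page::getImages)    // Stream<List<String>>
                .flatMap(List::stream)   // Stream<String>
                .limit(howMany)          // without this the stream never ends
                .collect(Collectors.toList());
    }

    // Same shape as the second pipeline: filter on image count, then cut off.
    static List<Integer> pagesWithAtLeast(int minImages, int howMany) {
        return pages()
                .filter(page -> page.getImages().size() >= minImages)
                .limit(howMany)
                .map(page -> page.number)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(firstImages(4));
        System.out.println(pagesWithAtLeast(5, 3)); // prints [5, 6, 7]
    }
}
```

In both cases it is limit that makes the pipeline terminate, since the source itself never runs out.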

Seems useful, so how do we implement the spliterator? Well, it’s pretty trivial:

import java.util.Spliterator;
import java.util.function.Consumer;
 
public class WebPageSpliterator implements Spliterator<Document> {
    private WebPageProvider webPageProvider;
 
    public WebPageSpliterator(WebPageProvider webPageProvider) {
        this.webPageProvider = webPageProvider;
    }
 
    @Override
    public boolean tryAdvance(Consumer<? super Document> action) {
        action.accept(webPageProvider.getPage());
        return true;
    }
 
    @Override
    public Spliterator<Document> trySplit() {
        return null;
    }
 
    @Override
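    public long estimateSize() {
        // the source is effectively infinite; Long.MAX_VALUE is the documented value for that
        return Long.MAX_VALUE;
    }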
 
    @Override
    public int characteristics() {
        return 0;
    }
}

You can see that all we’ve had to do is implement the tryAdvance method. Since the backing provider can provide an infinite number of web pages (assuming pages keep linking to other pages) there is no complex logic needed inside this method. It simply calls the accept method of the Consumer code passed into it (Consumer is a Java 8 functional interface, allowing callers to pass in a lambda) and then returns true, to signify that more pages can be returned if required.
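For a sequential-only source like this, the JDK also offers a shortcut: Spliterators.AbstractSpliterator, where you only write tryAdvance and the base class supplies the other three methods. The sketch below is not how the repo does it; SupplierSpliterator is a hypothetical helper that wraps any Supplier as an endless source, and the squares example is only there to have something to run.

```java
import java.util.List;
import java.util.Spliterators;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

// Hypothetical helper: turns any Supplier into an endless spliterator.
class SupplierSpliterator<T> extends Spliterators.AbstractSpliterator<T> {
    private final Supplier<T> supplier;

    SupplierSpliterator(Supplier<T> supplier) {
        // Long.MAX_VALUE = size unknown or infinite; 0 = no characteristics
        super(Long.MAX_VALUE, 0);
        this.supplier = supplier;
    }

    @Override
    public boolean tryAdvance(Consumer<? super T> action) {
        action.accept(supplier.get());
        return true;   // always claims there is more to come
    }

    // Demo: the first few square numbers from an endless supplier.
    static List<Integer> firstSquares(int howMany) {
        int[] next = {1};   // mutable counter captured by the lambda
        Supplier<Integer> squares = () -> {
            int n = next[0]++;
            return n * n;
        };
        return StreamSupport.stream(new SupplierSpliterator<>(squares), false)
                .limit(howMany)   // the only thing that stops the stream
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(firstSquares(4)); // prints [1, 4, 9, 16]
    }
}
```

Because tryAdvance never returns false, any pipeline built on a source like this needs a short-circuiting operation such as limit or findFirst to finish.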

Now let’s consider a more complex example involving parallel processing. When would you need to write a custom spliterator for parallel processing? Well, one situation is when you have a stream of objects, but the stream has an internal ordering or structure, meaning that a naive split of the stream at an arbitrary point might not produce sections that can validly be processed in parallel. In my GitHub repo, I’ve given two separate examples of this type of scenario. In one, you have a character stream which actually represents a custom record format, so you need to split the stream at the record boundaries. In the other, you have a stream of Payment objects, but really these are grouped into payment batches, and you must split the stream at a payment batch boundary. Let’s look at this second example. The payment batch test data is created like this:

    private List<Payment> createSampleData() {
        List<Payment> paymentList = new ArrayList<>();
        for (int i=0; i<1000; i++) {
            paymentList.add(new Payment(10,"A"));
            paymentList.add(new Payment(20,"A"));
            paymentList.add(new Payment(30,"A"));
            // total = 60
 
            paymentList.add(new Payment(20,"B"));
            paymentList.add(new Payment(30,"B"));
            paymentList.add(new Payment(40,"B"));
            paymentList.add(new Payment(50,"B"));
            paymentList.add(new Payment(60,"B"));
            // total = 200
 
            paymentList.add(new Payment(30,"C"));
            paymentList.add(new Payment(30,"C"));
            paymentList.add(new Payment(20,"C"));
            // total = 80
        }
        return paymentList;
    }

We want to total each batch and then average those totals per category. You can see that if you did this in parallel but didn’t split on the batch boundaries, you would get the wrong answers, because you would count more batches than actually exist, e.g. by splitting the second batch into two. We can verify this, then implement a custom spliterator and check that with it the results are correct. First, let’s create a collector to do the calculation:

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
 
public class PaymentBatchTotaller 
    implements Collector<Payment,PaymentBatchTotaller.Accumulator,Map<String,Double>> {
 
    public static class Total {
        public double amount;
        public int numberOfBatches;
    }
 
    public static class Accumulator {
        Map<String,Total> totalsByCategory = new HashMap<>();
        String currentPaymentCategory;
    }
 
    @Override
    public Supplier<Accumulator> supplier() {
        return Accumulator::new;
    }
 
    @Override
    public BiConsumer<Accumulator,Payment> accumulator() {
        return (accumulator,payment) -> {
            // store this amount
            Total batchTotalForThisCategory = accumulator.totalsByCategory.get(payment.getCategory());
            if (batchTotalForThisCategory == null) {
                batchTotalForThisCategory = new Total();
                accumulator.totalsByCategory.put(payment.getCategory(),batchTotalForThisCategory);
            }
            batchTotalForThisCategory.amount += payment.getAmount();
 
            // if this was start of a new batch, increment the counter
            if (!payment.getCategory().equals(accumulator.currentPaymentCategory)) {
                batchTotalForThisCategory.numberOfBatches += 1;
                accumulator.currentPaymentCategory = payment.getCategory();
            }
        };
    }
 
    @Override
    public BinaryOperator<Accumulator> combiner() {
        return (accumulator1,accumulator2) -> {
            for (String category : accumulator1.totalsByCategory.keySet()) {
                Total total2 = accumulator2.totalsByCategory.get(category);
                if (total2 == null) {
                    accumulator2.totalsByCategory.put(category,accumulator1.totalsByCategory.get(category));
                } else {
                    Total total1 = accumulator1.totalsByCategory.get(category);
                    total2.amount += total1.amount;
                    total2.numberOfBatches += total1.numberOfBatches;
                }
            }
            return accumulator2;
        };
    }
 
    @Override
    public Function<Accumulator, Map<String, Double>> finisher() {
        return (accumulator) -> {
            Map<String,Double> results = new HashMap<>();
            for (Map.Entry<String,Total> entry : accumulator.totalsByCategory.entrySet()) {
                String category = entry.getKey();
                Total total = entry.getValue();
                double averageForBatchInThisCategory = total.amount / total.numberOfBatches;
                results.put(category,averageForBatchInThisCategory);
            }
            return results;
        };
    }
 
    @Override
    public Set<Characteristics> characteristics() {
        // emptySet() is the type-safe form of the raw EMPTY_SET constant
        return Collections.emptySet();
    }
}

You can see that this collector keeps a running total for each payment category, along with the number of batches seen in that category; the finisher method then divides each total by the number of batches to get the average batch total. (If you aren’t familiar with custom collectors, you might like to read my previous article on the subject.) Now we can check what happens if we use this collector on an ordinary parallel stream:

List<Payment> payments = createSampleData();
 
// won't work in parallel!
Map<String,Double> averageTotalsPerBatchAndCategory = payments.parallelStream().collect(new PaymentBatchTotaller());
 
Set<Map.Entry<String,Double>> entrySet = averageTotalsPerBatchAndCategory.entrySet();
assertEquals(entrySet.size(),3);
for (Map.Entry<String,Double> total : averageTotalsPerBatchAndCategory.entrySet()) {
    if (total.getKey().equals("A")) {
       assertEquals(60d,total.getValue());
    } else if (total.getKey().equals("B")) {
       assertEquals(200d,total.getValue());
    } else {
       assertEquals(80d,total.getValue());
    }
}
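Whether this test actually fails depends on where the default list spliterator happens to split, but the underlying problem is easy to reproduce by hand. The sketch below is a simplified model, not the collector itself: it works on plain category strings, counts batches the same way the accumulator does, and simulates a two-way split by counting each chunk separately and adding the counts. The class and method names are made up for the illustration.

```java
import java.util.Arrays;
import java.util.List;

class BatchCountSketch {

    // A batch starts whenever the category differs from the previous one.
    static int countBatches(List<String> categories) {
        int batches = 0;
        String previous = null;
        for (String category : categories) {
            if (!category.equals(previous)) {
                batches++;
                previous = category;
            }
        }
        return batches;
    }

    // Simulates a parallel run: each chunk gets a fresh count,
    // and the combiner step just adds the two counts together.
    static int countBatchesSplitAt(List<String> categories, int splitPosition) {
        return countBatches(categories.subList(0, splitPosition))
                + countBatches(categories.subList(splitPosition, categories.size()));
    }

    public static void main(String[] args) {
        // one round of the sample data: 3 x A, 5 x B, 3 x C
        List<String> categories =
                Arrays.asList("A", "A", "A", "B", "B", "B", "B", "B", "C", "C", "C");
        System.out.println(countBatches(categories));           // prints 3
        System.out.println(countBatchesSplitAt(categories, 5)); // prints 4 (B counted twice)
        System.out.println(countBatchesSplitAt(categories, 3)); // prints 3 (split on a boundary)
    }
}
```

With the split at position 5, batch B is counted twice, so its average would come out as 200 / 2 = 100 rather than 200. A split on a boundary leaves the count, and therefore the average, intact.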

To begin with, our spliterator must keep hold of its backing list, and will need to keep track of its current and end positions in the list:

public class PaymentBatchSpliterator implements Spliterator<Payment> {
 
    private List<Payment> paymentList;
    private int current;
    private int last;  // inclusive
 
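    public PaymentBatchSpliterator(List<Payment> payments) {
        this.paymentList = payments;
        last = paymentList.size() - 1;
    }

    // covers only positions first..last; trySplit uses this for the second half
    public PaymentBatchSpliterator(List<Payment> payments, int first, int last) {
        this.paymentList = payments;
        this.current = first;
        this.last = last;
    }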

The implementation of tryAdvance is fairly simple. Providing we aren’t at the end of the list yet, we need to call accept on the Consumer code passed in, then increment our current counter and return true:

@Override
public boolean tryAdvance(Consumer<? super Payment> action) {
    if (current <= last) {
        action.accept(paymentList.get(current));
        current++;
        return true;
    }
    return false;
}

Now we come to the real logic, the implementation of trySplit. We can implement this by saying: generate a possible split position half way along the remaining range, then check whether it is a boundary between payment batches. If it isn’t, move forward until it is. The code looks like this:

    @Override
    public Spliterator<Payment> trySplit() {
        if ((last - current) < 100) {
            return null;
        }
 
        // first stab at finding a split position
        int splitPosition = current + (last - current) / 2;
        // if the categories are the same, we can't split here, as we are in the middle of a batch
        String categoryBeforeSplit = paymentList.get(splitPosition-1).getCategory();
        String categoryAfterSplit = paymentList.get(splitPosition).getCategory();
 
        // keep moving forward until we reach a split between categories
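        while (categoryBeforeSplit.equals(categoryAfterSplit)) {
            splitPosition++;
            if (splitPosition > last) {
                // the rest of our range is one batch, so there is nowhere safe to split
                return null;
            }
            categoryBeforeSplit = categoryAfterSplit;
            categoryAfterSplit = paymentList.get(splitPosition).getCategory();
        }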
 
        // safe to create a new spliterator
        PaymentBatchSpliterator secondHalf = new PaymentBatchSpliterator(paymentList,splitPosition,last);
        // reset our own last value
        last = splitPosition - 1;
 
        return secondHalf;
    }
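The boundary search is the part most worth testing in isolation. Here is a sketch of just that step, pulled out into a standalone method that works on a list of category strings rather than Payment objects. The method name, the -1 "no boundary" convention and the extra guard for very short ranges are my assumptions for the illustration, not something taken from the repo.

```java
import java.util.Arrays;
import java.util.List;

class SplitPositionSketch {

    // Returns the index of the first element of the second half, or -1 if
    // there is no batch boundary between the midpoint and 'last' (inclusive).
    static int findSplitPosition(List<String> categories, int current, int last) {
        // start at the midpoint, but leave at least one element in the first half
        int splitPosition = Math.max(current + (last - current) / 2, current + 1);
        while (splitPosition <= last) {
            String before = categories.get(splitPosition - 1);
            String after = categories.get(splitPosition);
            if (!before.equals(after)) {
                return splitPosition;   // category changes here: safe to cut
            }
            splitPosition++;
        }
        return -1;   // ran off the end without finding a boundary
    }

    public static void main(String[] args) {
        List<String> categories =
                Arrays.asList("A", "A", "A", "B", "B", "B", "B", "B", "C", "C", "C");
        // midpoint is index 5 (inside batch B), so the search moves on to index 8
        System.out.println(findSplitPosition(categories, 0, 10)); // prints 8

        List<String> oneBatch = Arrays.asList("A", "A", "A", "A");
        System.out.println(findSplitPosition(oneBatch, 0, 3));    // prints -1
    }
}
```

With the result 8, the first half would cover positions 0 to 7 (all of A and B) and the second half positions 8 to 10 (all of C), so no batch is cut in two.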

There is one more little detail not to be missed. We must make estimateSize() return a sensible value. Why? Well, this is called internally by the stream code to check if it needs to do any more splitting, so if it always returns 0, your stream will never be split! The implementation is trivial:

    @Override
    public long estimateSize() {
        // +1 because last is inclusive
        return last - current + 1;
    }
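A quick way to sanity-check the size arithmetic is to split by hand and confirm that the two halves add up to the original estimate. The sketch below does that with a stripped-down, hypothetical list spliterator over Integers that has no batch logic at all. Note one deliberate difference from the trySplit above: this one hands off the first half, because it reports ORDERED and the Spliterator Javadoc requires an ORDERED spliterator to return a prefix of its elements.

```java
import java.util.List;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.StreamSupport;

// Hypothetical example: list-backed spliterator with an inclusive 'last' index.
class InclusiveRangeSpliterator implements Spliterator<Integer> {
    private final List<Integer> list;
    private int current;
    private final int last;   // inclusive

    InclusiveRangeSpliterator(List<Integer> list, int current, int last) {
        this.list = list;
        this.current = current;
        this.last = last;
    }

    @Override
    public boolean tryAdvance(Consumer<? super Integer> action) {
        if (current > last) {
            return false;
        }
        action.accept(list.get(current++));
        return true;
    }

    @Override
    public Spliterator<Integer> trySplit() {
        if (last - current < 1) {
            return null;   // fewer than two elements left
        }
        int mid = current + (last - current) / 2;
        // hand off the prefix [current, mid] and keep [mid + 1, last]
        Spliterator<Integer> prefix = new InclusiveRangeSpliterator(list, current, mid);
        current = mid + 1;
        return prefix;
    }

    @Override
    public long estimateSize() {
        return last - current + 1;   // +1 because 'last' is inclusive
    }

    @Override
    public int characteristics() {
        return ORDERED | SIZED | SUBSIZED;
    }

    static List<Integer> oneTo(int n) {
        return IntStream.rangeClosed(1, n).boxed().collect(Collectors.toList());
    }

    static int parallelSum(List<Integer> list) {
        return StreamSupport
                .stream(new InclusiveRangeSpliterator(list, 0, list.size() - 1), true)
                .mapToInt(Integer::intValue)
                .sum();
    }

    public static void main(String[] args) {
        List<Integer> list = oneTo(1000);
        InclusiveRangeSpliterator rest = new InclusiveRangeSpliterator(list, 0, 999);
        Spliterator<Integer> prefix = rest.trySplit();
        System.out.println(prefix.estimateSize() + " + " + rest.estimateSize()); // prints 500 + 500
        System.out.println(parallelSum(list)); // prints 500500
    }
}
```

Without the + 1, each spliterator would report one element fewer than it really holds, which matters as soon as you also report SIZED.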

Finally we can test this by using the spliterator in our test code when we count the totals:

        Map<String,Double> averageTotalsPerBatchAndCategory =
                StreamSupport.stream(new PaymentBatchSpliterator(payments),true).collect(new PaymentBatchTotaller());

This will generate the correct totals. If you want to look at the character stream example, please check the GitHub repo. You might also be interested in some of my other blog posts on Java 8:

Streams tutorial
Using Optional in Java 8
