JavaのStreamで並列化できるのは最上位のStreamだけ

結論からいうと、Streamで並列化できるのは、最上位のStreamだけ。flatMapなどで,入れ子のStreamの処理については並行が効かない。(Javaのコードも読んだが,flatMapのStreamは内部コードでSequentialに変更されていることを確認した。)

基本的にはできるだけ上位の処理を並列化するのが,一般的には効率がよく並列化できるため,この実装は納得できるものではあるが,必ずしも万能なわけではない。

例えば,木構造を辿って処理するようなコードを書いていたりする場合に,木構造の偏りによって並列化が期待したほど効かないということがある。そもそも木構造を処理するのは別の方法のほうがよいかもしれないが。

確認のためのコードは以下。タスクを二重配列で持っていて,これをStreamで処理することを考える。

package parallel;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class Test1 {
    
    static class TestTask implements Runnable {
        @Override
        public void run() {
            try {
                Thread.sleep(100);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }
    
    public static void main(String[] args) throws InterruptedException {
        
        List<List<TestTask>> table = new ArrayList<>();

        int sum = 0;
        int max = 0;
        Random rgen = new Random(0);
        for (int j = 0; j < 3; ++j) {
            List<TestTask> list = IntStream.range(0, 5 + rgen.nextInt(10))
                    .mapToObj(i -> new TestTask())
                    .collect(Collectors.toList());
            table.add(list);

            sum += list.size();
            max = Math.max(max, list.size());
        }

        System.out.println("table: "+ table.stream()
                .map(l -> Integer.toString(l.size()))
                .collect(Collectors.joining(", ")));
        System.out.printf("sum: %d, max: %d\n", sum, max);
        System.out.println("parallelism: "+ ForkJoinPool.commonPool().getParallelism());
        System.out.println();

        {
            // 1.base-parallel
            long t0 = System.currentTimeMillis();
            table.stream()
                    .parallel()
                    .flatMap(l -> l.stream())
                    .forEach(task -> task.run());
            System.out.printf("1.base-parallel:     %,5d msec\n", (System.currentTimeMillis() - t0));
        }

        {
            // 2.flat-map-parallel
            long t0 = System.currentTimeMillis();
            table.stream()
                    .flatMap(l -> l.parallelStream())
                    .forEach(task -> task.run());
            System.out.printf("2.flat-map-parallel: %,5d msec\n", (System.currentTimeMillis() - t0));
        }

        {
            // 3.both-parallel
            long t0 = System.currentTimeMillis();
            table.parallelStream()
            .flatMap(l -> l.parallelStream()
                    .peek(task -> task.run())
                    .collect(Collectors.toList())
                    .stream())
                    .forEach(task -> {});
            System.out.printf("3.both-parallel:     %,5d msec\n", (System.currentTimeMillis() - t0));
        }
    }
}

結果は以下。

table: 5, 13, 14
sum: 32, max: 14
parallelism: 3

1.base-parallel:     1,424 msec
2.flat-map-parallel: 3,281 msec
3.both-parallel:       926 msec

各結果について順にみていく。

1.base-parallelで上位の配列でのみ並列化される。そのため,タスクの二重配列の多少偏りの影響を受ける。table[3].lengthの14の影響を受けて14*100msecの時間がかかることが確認できる。

2.flat-map-parallelでは,冒頭で述べた通りparallelは効かない。そのため,すべてのタスクをsequentialに処理するため,タスクの合計の32*100msecの時間がかかっている。

3.both-parallelでは,flatMap内で無駄に終端処理collectを挟むことで,並列化させている。これによって,最も短い時間で終了できている。

ちなみに,確認した反映ではJavadocなどの記載でもこのあたりの振る舞いについての記載は見つけられなかった。どこかに記載があることをご存知の方は教えていただけると助かります。