Fork/Joinフレームワークを用いたマージソート

Java7ではFork/Joinフレームワークが新たに追加になっているおかげでマージソートなどでデータをソートするアルゴリズムを効率的に処理するプログラミングが比較的簡単に作成することが可能です。

今回は自分用のメモとしてFork/Joinフレームワークを利用したマージソート処理を記載します。

まず最初にRecursiveTaskを継承したマージソートを行うMergeSortTaskクラスを作成します。

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.RecursiveTask;

public abstract class MergeSortTask<T> extends RecursiveTask<List<T>> {

    private static final long serialVersionUID = 1L;

    private final List<T> values;
    
    protected MergeSortTask(List<T> values) {
        this.values = values;
    }

    @Override
    protected List<T> compute() {
        if (values.size() > 1) {
            // ソート対象リストサイズの中間値を算出
            int mid = values.size() / 2;
            final MergeSortTask<T> task1 = createMergeSortTask(values.subList(0, mid));
            final MergeSortTask<T> task2 = createMergeSortTask(values.subList(mid, values.size()));
            // 分割した最初のタスクをfork
            task1.fork();
            // 次タスクの処理を実行
            List<T> sortedValues2 = task2.compute();
            // forkした最初の処理を待ち合わせる
            List<T> sortedValues1 = task1.join();
            return merge(sortedValues1, sortedValues2);
        } else {
            return values;
        }
    }
    
    protected abstract MergeSortTask<T> createMergeSortTask(List<T> values);
    
    /*
     * マージメソッド
     */
    protected List<T> merge(List<T> arg1, List<T> arg2) {
        final List<T> mergedValues = new ArrayList<>(arg1.size() + arg2.size());
        int i = 0, j =0,  k = 0;
        while (i < arg1.size() || j < arg2.size()) {
            // 値の入れ替え
            if (j >= arg2.size() || ((i < arg1.size()) && compare(arg1.get(i), arg2.get(j)))) {
                mergedValues.add(k, arg1.get(i++));
            } else {
                mergedValues.add(k, arg2.get(j++));
            }
            k++;
        }
        return mergedValues;
    }
    
    /*
     * 大小比較メソッド
     */
    protected abstract boolean compare(T arg1, T arg2);

}

ポイントはcomputeメソッド内でソート対象となるリストを分割して、最初のタスクをforkし、次のタスクを処理後にforkしたタスクの待ち合わせを行いmergeメソッドでマージしているという点です。

抽象メソッドとなっている部分は昇順と降順用のソートクラスを作成するためにあえてテンプレートメソッドパターンを利用しています。

それぞれ昇順用と降順用のサブクラスは以下のようになります。

import java.util.List;

/**
 * 昇順マージソートタスク
 */
public class ComparableAscMergeSortTask<T extends Comparable<? super T>> extends MergeSortTask<T> {
    
    private static final long serialVersionUID = 1L;

    public ComparableAscMergeSortTask(List<T> values) {
        super(values);
    }

    @Override
    protected MergeSortTask<T> createMergeSortTask(List<T> values) {
        return new ComparableAscMergeSortTask<>(values);
    }

    @Override
    protected boolean compare(T arg1, T arg2) {
        return arg1.compareTo(arg2) < 0 ? true : false;
    }
}
import java.util.List;

/**
 * 降順マージソートタスク
 */
public class ComparableDescMergeSortTask<T extends Comparable<? super T>> extends MergeSortTask<T> {

    private static final long serialVersionUID = 1L;

    public ComparableDescMergeSortTask(List<T> values) {
        super(values);
    }
    
    @Override
    protected MergeSortTask<T> createMergeSortTask(List<T> values) {
        return new ComparableDescMergeSortTask<>(values);
    }

    @Override
    protected boolean compare(T arg1, T arg2) {
        return arg1.compareTo(arg2) > 0 ? true : false;
    }

}

それぞれの違いはcompareメソッド内の比較演算子の向きぐらいです。

メインクラスは以下の通り。

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;

public class MergeSortMain {

    public static void main(String[] args) {
        List<Integer> values = new ArrayList<>();
        for (int i= 0; i < 10; i++) {
            long seed = System.nanoTime();
            Random random = new Random(seed);
            values.add(Math.abs(random.nextInt()));
        }

        ForkJoinPool forkJoinPool = new ForkJoinPool();
        System.out.println("========= ASC =========");
        List<Integer> results = forkJoinPool.invoke(new ComparableAscMergeSortTask<Integer>(values));
        for (Integer r : results) {
            System.out.println(r);
        }

        System.out.println("========= DESC =========");
        results = forkJoinPool.invoke(new ComparableDescMergeSortTask<Integer>(values));
        for (Integer r : results) {
            System.out.println(r);
        }
    }

}

ForkJoinPoolクラスを生成し、invokeメソッドにそれぞれのタスクを引数で渡し、ソート結果をリストで受け取ります。

実行結果は以下のような感じです。


========= ASC =========
220809008
313991963
599017177
882423875
946292192
1064163273
1156976803
1363513242
2090228414
2138086498
========= DESC =========
2138086498
2090228414
1363513242
1156976803
1064163273
946292192
882423875
599017177
313991963
220809008

メモリにやさしくない感はありますがサンプルなので。