NIOでソケットプログラミング

個人的にはJava7の目玉機能の一つはNIO2だと思います。

ただ、NIO自体を体系的に勉強したわけじゃないから実はあまり分かっていない部分も多い今日この頃、あえて、NIOを利用したソケットプログラムをやっつけで書いてみました。(暇つぶしに。)

リクエストを受け取ると乱数を生成してクライアントに返すHTTPサーバですが、NIOの詳細がいまいち理解不足なためあっているかどうかわかりません。

とりあえず、リクエストの受付を行うハンドラと、リクエストの読み込みと、クライアントへのレスポンスの返信を行うハンドラを作成し、チャネルに対してセレクタとセレクションキー及び、作成したハンドラを登録する仕組みを採用。
例外処理とか、終了処理とかはいろいろ適当です。

個人的なメモとしてプログラムをとりあえず記録しておくことにしました。
(正しい使い方かどうか確証がないのでホントにメモです。)

【Handlerインターフェース】

import java.nio.channels.SelectionKey;

public interface Handler {

    public void handle(SelectionKey selectionKey);
}

【AcceptHandler】

import java.io.IOException;
import java.nio.channels.SelectionKey;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;

public class AcceptHandler implements Handler {
    
    @Override
    public void handle(SelectionKey selectionKey) {
        try {
            ServerSocketChannel serverSocketChannel = (ServerSocketChannel) selectionKey.channel();
            SocketChannel socketChannel = serverSocketChannel.accept();
            socketChannel.configureBlocking(false);
            // ReadWriteHandlerをアタッチメント
            socketChannel.register(selectionKey.selector(), SelectionKey.OP_READ, new ReadWriteHandler()); 
        } catch (IOException e) {
            e.printStackTrace();
        }
        
    }

}

【ReadWriteHandler】

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

public class ReadWriteHandler implements Handler {
    
    private static final int READ_BUFFER_SIZE = 8190;
    
    /** あえてバッファサイズを少なくして実験 */
    private static final int WRITE_BUFFER_SIZE = 10;
    
    private List<ByteBuffer> buffers = new ArrayList<ByteBuffer>();
       
    @Override
    public void handle(SelectionKey selectionKey) {
        if (selectionKey.isReadable()) {
            handleRequest(selectionKey);
        }
        if (selectionKey.isWritable()) {
            handleResponse(selectionKey);
        }
    }
    
    private void handleRequest(final SelectionKey selectionKey) {
        SocketChannel channel = (SocketChannel)selectionKey.channel();
        ByteBuffer buf = ByteBuffer.allocate(READ_BUFFER_SIZE);
        try {
            if(channel.read(buf) >= 0) {
                buf.flip();
                byte[] bytes = new byte[buf.limit()];
                buf.get(bytes);
                String request = new String(bytes);
                // 標準出力にリクエスト内容を表示
                System.out.println("[server] The received request is " + request);
                createResponse();
                selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_WRITE);
            }
        } catch (IOException e) {
            e.printStackTrace();
            closeChannel(channel);
        }
    }
    
    private void createResponse() {
        final long seed = System.nanoTime();
        final Random random = new Random(seed);
        final String responseText = "HTTP/1.1 200 OK\n\n" + Long.toString(random.nextLong());
        System.out.println("[server] The response message is " + responseText);
        final byte[] responseBytes = responseText.getBytes();
        if (responseBytes.length > WRITE_BUFFER_SIZE) {
            // バッファに収まるようにブロックサイズを算出
            int blockSize = (responseBytes.length % WRITE_BUFFER_SIZE) == 0 ? responseBytes.length / WRITE_BUFFER_SIZE : (responseBytes.length / WRITE_BUFFER_SIZE) + 1;
            for (int i = 0, j = 0; i < blockSize; i++, j += WRITE_BUFFER_SIZE) {
                ByteBuffer buf = ByteBuffer.allocate(WRITE_BUFFER_SIZE);
                byte[] chunk;
                if (i == (blockSize - 1)) {
                    int lastPos = (j + WRITE_BUFFER_SIZE) == responseBytes.length ? j + WRITE_BUFFER_SIZE : j + (responseBytes.length - j);
                    chunk = Arrays.copyOfRange(responseBytes, j, lastPos);
                } else {
                    chunk = Arrays.copyOfRange(responseBytes, j, j + WRITE_BUFFER_SIZE);
                }
                buf.put(chunk);
                buf.flip();
                buffers.add(buf);
            }
        } else {
            ByteBuffer buf = ByteBuffer.allocate(WRITE_BUFFER_SIZE);
            buf.put(responseText.getBytes());
            buf.flip();
            buffers.add(buf);
        }
    }
    
    private void handleResponse(final SelectionKey selectionKey) {
        if (!buffers.isEmpty()) {
            SocketChannel channel = (SocketChannel) selectionKey.channel();
            try {
                for (ByteBuffer buf : buffers) {
                    channel.write(buf);
                }
                
            } catch (IOException e) {
                e.printStackTrace();
            } finally {
                buffers.clear();
                closeChannel(channel);
            }

        } else {
            // ここに到達することはあり得ない?
            // ??? とりあえずChannelをクローズ?
            closeChannel((SocketChannel) selectionKey.channel());
        }
    }
    
    private void closeChannel(SocketChannel channel) {
        try {
            if (channel != null && channel.isOpen()) {
                channel.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

}

【SimpleHttpServer】

import java.io.IOException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.util.Iterator;

public class SimpleHttpServer {
       
    private static final long TIMEOUT = 10000;

    private ServerSocketChannel server;
    private Selector selector;
    private volatile boolean stop = false;
    
    public SimpleHttpServer(ServerSocketChannel server, Selector selector) {
        this.server = server;
        this.selector = selector;
    }

    public void start() {
        System.out.println("Start " + this.getClass().getSimpleName());
        while (!stop) {
            try {
                int ret = selector.select(TIMEOUT);
                if (ret > 0) {
                    for (Iterator<SelectionKey> it = selector.selectedKeys().iterator(); it.hasNext();) {
                        SelectionKey key = it.next();
                        it.remove();
                        if (key.isValid()) {
                            Handler handler = (Handler) key.attachment();
                            if (handler != null) {
                                handler.handle(key);
                            }
                        }
                    }
                }
            } catch (IOException e) {
                e.printStackTrace();
                break;
            }
        }
        for (SelectionKey key : selector.keys()) {
            try {
                if (key.channel().isOpen()) {
                    key.channel().close();
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        
        if (selector != null && selector.isOpen()) {
            try {
                selector.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        if (server != null && server.isOpen()) {
            try {
                server.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        System.out.println("See you.");
    }
    
    public void stop() {
        stop = true;
        selector.wakeup();
        System.out.println("Require to stop " + this.getClass().getSimpleName());
    }
}

【Main】

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;

public class Main {

    public static final String HOST_NAME = "localhost";
       
    public static final int PORT = 8081;
    
    public static void main(String[] args) {
        System.out.println("@@@ Start the simple http server.");

        ServerSocketChannel serverSocketChannel = null;
        Selector selector = null;
        
        try {
            // チャネルオープン
            serverSocketChannel = ServerSocketChannel.open();
            serverSocketChannel.socket().bind(new InetSocketAddress(HOST_NAME, PORT), 10);
            serverSocketChannel.configureBlocking(false);
            selector = Selector.open();
            // セレクタと選択キーの登録
            serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT, new AcceptHandler());
        } catch (IOException e) {
            e.printStackTrace();
            if (serverSocketChannel != null && serverSocketChannel.isOpen()) {
                try {
                    serverSocketChannel.close();
                } catch (IOException e1) {
                    e1.printStackTrace();
                }
            }
            if (selector != null && selector.isOpen()) {
                try {
                    selector.close();
                } catch (IOException e1) {
                    e1.printStackTrace();
                }
            }
            System.exit(1);
        }
        
        final SimpleHttpServer httpServer = new SimpleHttpServer(serverSocketChannel, selector);
        Runtime.getRuntime().addShutdownHook(new Thread() {
            @Override
            public void run() {
                httpServer.stop();
                System.out.println("@@@ Stop the simple http server.");
            }
        });
        
        httpServer.start();
        
    }

}

とりあえずブラウザからlocalhost:8081にアクセスすると生成された乱数が表示されます。

改めて、NIOでソケットを用いたプログラムを作成する場合、この書き方で合ってるのだろうか・・・。

「そうじゃねー」という突っ込みがあればよろしくお願いします。