/*
    Copyright (c) 2007-2010 iMatix Corporation

    This file is part of 0MQ.

    0MQ is free software; you can redistribute it and/or modify it under
    the terms of the Lesser GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    0MQ is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    Lesser GNU General Public License for more details.

    You should have received a copy of the Lesser GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include <assert.h>

#include <zmq.h>

#include "jzmq.hpp"
#include "util.hpp"
#include "org_zeromq_ZMQ_Poller.h"


//  Helpers used by run_poll; both are defined further down in this file.
static void *fetch_socket (JNIEnv *env, jobject socket);
static int fetch_socket_fd (JNIEnv *env, jobject socket);


/**
 * Native backend of org.zeromq.ZMQ.Poller.run_poll.
 *
 * Builds a zmq_pollitem_t array from the non-null entries of socket_0mq:
 * SelectableChannels are polled by descriptor (fetch_socket_fd), other
 * objects by their 0MQ handle (fetch_socket).  It then calls zmq_poll and
 * copies each item's revents back into revent_0mq, at the same index as
 * its socket.
 *
 * @param env         JNI environment.
 * @param obj         the Poller instance (unused).
 * @param count       number of non-null entries expected in socket_0mq.
 * @param socket_0mq  possibly sparse array of pollable objects.
 * @param event_0mq   requested events, parallel to socket_0mq (read only).
 * @param revent_0mq  returned events, parallel to socket_0mq.  It is written
 *                    only for polled entries and only when zmq_poll > 0.
 * @param timeout     passed to zmq_poll as a long.
 * @return zmq_poll's return value, or 0 when nothing was polled:
 *         - count <= 0;
 *         - an array is shorter than count;
 *         - the number of usable non-null entries differs from count;
 *         - an entry was invalid (an EINVAL exception is raised).
 */
JNIEXPORT jlong JNICALL Java_org_zeromq_ZMQ_00024Poller_run_1poll (JNIEnv *env,
                                                                   jobject obj,
                                                                   jint count,
                                                                   jobjectArray socket_0mq,
                                                                   jshortArray event_0mq,
                                                                   jshortArray revent_0mq,
                                                                   jlong timeout)
{
    int ls = (int) count;
    if (ls <= 0)
        return 0;

    int ls_0mq = 0;
    int le_0mq = 0;
    int lr_0mq = 0;

    if (socket_0mq)
        ls_0mq = env->GetArrayLength (socket_0mq);
    if (event_0mq)
        le_0mq = env->GetArrayLength (event_0mq);
    if (revent_0mq)
        lr_0mq = env->GetArrayLength (revent_0mq);

    //  Every array must be able to hold `count` entries.
    if (ls > ls_0mq || ls > le_0mq || ls > lr_0mq)
        return 0;

    //  events/revents are indexed in parallel with the sockets array, so
    //  only scan indices that are valid in all three.  If a non-null socket
    //  lies beyond that range, the stored item count falls short of
    //  `count` and the poll is skipped, instead of reading out of bounds.
    int scan = ls_0mq;
    if (le_0mq < scan)
        scan = le_0mq;
    if (lr_0mq < scan)
        scan = lr_0mq;

    jshort *e_0mq = env->GetShortArrayElements (event_0mq, 0);
    if (e_0mq == NULL)
        return 0;

    zmq_pollitem_t *pitem = new zmq_pollitem_t [ls];
    //  pidx [k] is the Java array index that pitem [k] came from, so
    //  results can be written back without re-reading the Java array.
    int *pidx = new int [ls];
    int pc = 0;             //  non-null entries seen; may exceed ls
    bool failed = false;

    // Add 0MQ sockets.  Array containing them can be "sparse": there
    // may be null elements.  The count argument has the real number
    // of valid sockets in the array.
    for (int i = 0; i < scan; ++i) {
        jobject s_0mq = env->GetObjectArrayElement (socket_0mq, i);
        if (!s_0mq)
            continue;

        void *s = NULL;
        int fd = fetch_socket_fd (env, s_0mq);
        if (fd == 0)
            s = fetch_socket (env, s_0mq);
        env->DeleteLocalRef (s_0mq);

        if (fd < 0 || (fd == 0 && s == NULL)) {
            //  JNI forbids throwing (and most other calls) while an
            //  exception is pending.  So drop any exception left by the
            //  lookup, report EINVAL, and stop scanning.
            env->ExceptionClear ();
            raise_exception (env, EINVAL);
            failed = true;
            break;
        }

        //  Store at most `ls` items so an inconsistent count cannot
        //  overflow pitem; the pc == ls check below rejects the mismatch.
        if (pc < ls) {
            pitem [pc].socket = s;
            pitem [pc].fd = fd;
            pitem [pc].events = e_0mq [i];
            pitem [pc].revents = 0;
            pidx [pc] = i;
        }
        ++pc;
    }

    //  Events were only read, so JNI_ABORT skips the copy-back.
    env->ReleaseShortArrayElements (event_0mq, e_0mq, JNI_ABORT);

    //  Poll only if the count of non-null sockets equals the passed-in arg.
    int rc = 0;
    if (!failed && pc == ls)
        rc = zmq_poll (pitem, ls, (long) timeout);

    //  Set 0MQ results at each item's original array index.
    if (rc > 0) {
        jshort *r_0mq = env->GetShortArrayElements (revent_0mq, 0);
        if (r_0mq) {
            for (int k = 0; k < ls; ++k)
                r_0mq [pidx [k]] = pitem [k].revents;
            env->ReleaseShortArrayElements (revent_0mq, r_0mq, 0);
        }
    }

    delete [] pidx;
    delete [] pitem;
    return rc;
}

  
/**
 * Get the native 0MQ socket handle of a Java Socket by calling its
 * getSocketHandle()J method.
 *
 * The method ID is looked up on the class of the first object passed in.
 * It is cached only after a successful lookup.
 *
 * @return the handle, or NULL when:
 *         - the method cannot be found (NoSuchMethodError left pending);
 *         - the call throws (exception left pending);
 *         - getSocketHandle returns 0.
 */
static void *fetch_socket (JNIEnv *env,
                           jobject socket)
{
    static jmethodID get_socket_handle_mid = NULL;

    if (get_socket_handle_mid == NULL) {
        jclass cls = env->GetObjectClass (socket);
        jmethodID mid = env->GetMethodID (cls, "getSocketHandle", "()J");
        env->DeleteLocalRef (cls);
        //  Do not rely on assert(): it compiles away under NDEBUG, and the
        //  call below would then go through a NULL method ID.
        if (mid == NULL)
            return NULL;
        get_socket_handle_mid = mid;
    }

    void *s = (void*) env->CallLongMethod (socket, get_socket_handle_mid);
    if (env->ExceptionCheck ())
        s = NULL;

    return s;
}

/**
 * Get the file descriptor of a java.nio.channels.SelectableChannel by
 * reading its "fdVal" int field.
 *
 * @return one of:
 *         - 0 if socket is not a SelectableChannel;
 *         - -1 on error (a class, global-ref or fdVal field lookup failed);
 *         - otherwise the value of fdVal.
 *
 * NOTE(review): a channel whose descriptor is 0 cannot be told apart from
 * "not a SelectableChannel" under this return convention.
 * NOTE(review): fdVal is an implementation field, not public
 * SelectableChannel API; presumably provided by the JDK's channel
 * implementation classes — confirm on each target JVM.
 */
static int fetch_socket_fd (JNIEnv *env, jobject socket)
{
    static jclass channel_cls = NULL;

    if (channel_cls == NULL) {
        jclass found = env->FindClass ("java/nio/channels/SelectableChannel");
        //  Report failure instead of assert()ing: under NDEBUG a NULL class
        //  would otherwise be passed to IsInstanceOf.
        if (found == NULL)
            return -1;
        channel_cls = (jclass) env->NewGlobalRef (found);
        env->DeleteLocalRef (found);
        if (channel_cls == NULL)
            return -1;
    }
    if (!env->IsInstanceOf (socket, channel_cls))
        return 0;

    jclass cls = env->GetObjectClass (socket);
    jfieldID fid = env->GetFieldID (cls, "fdVal", "I");
    env->DeleteLocalRef (cls);
    if (fid == NULL)
        return -1;

    /* return the descriptor */
    return env->GetIntField (socket, fid);
}
