diff --git a/distrib/distrib.cc b/distrib/distrib.cc index 3b9b8ab..d2a12c9 100644 --- a/distrib/distrib.cc +++ b/distrib/distrib.cc @@ -12,6 +12,7 @@ #include #include #include +#include /* TCP_NODELAY */ using namespace std; #define MSG_WANT_DATA 1 @@ -20,14 +21,17 @@ using namespace std; distrib::distrib() { pthread_cond_init(&m_listen_cond, NULL); + pthread_cond_init(&m_tasks_complete_cond, NULL); pthread_mutex_init(&m_listen_mutex, NULL); pthread_mutex_init(&m_task_mutex, NULL); pthread_mutex_init(&m_tasks_in_progress_mutex, NULL); + pthread_mutex_init(&m_tasks_complete_mutex, NULL); m_num_clients = 0; m_data = NULL; m_server = true; m_next_task = 0; m_client_socket = -1; + m_tasks_complete = 0; } distrib::~distrib() @@ -123,40 +127,61 @@ void connection_thread(distrib::connection_thread_arg_t * arg) int client_socket = arg->client_socket; delete arg; + int flag = 1; + if (setsockopt(client_socket, + IPPROTO_TCP, + TCP_NODELAY, + (char *) &flag, + sizeof(flag)) < 0) + { + cerr << "Failed to set TCP_NODELAY on client socket in connection thread!" << endl; + } + bool done = false; /* loop listening for messages from the client */ while (!done) { int msg_type; - int nread = read(client_socket, &msg_type, sizeof(msg_type)); - if (nread < 0) - break; - switch (msg_type) + size_t nread = read(client_socket, &msg_type, sizeof(msg_type)); + if (nread == sizeof(msg_type)) { - case MSG_WANT_DATA: - { - int task = the_distrib->getTask(); - if (write(client_socket, &task, sizeof(task)) < 0) - done = true; - } - break; - case MSG_SEND_DATA: - { - unsigned char data[3 * UNIT_TASK_SIZE]; - int task; - if (read(client_socket, &task, sizeof(task)) < 0) - done = true; - else if (read(client_socket, &data[0], sizeof(data)) < 0) - done = true; - else - the_distrib->send_data(task, &data[0], sizeof(data)); - } - break; - default: - break; + switch (msg_type) + { + case MSG_WANT_DATA: + { + int task = the_distrib->getTask(); + if (write(client_socket, &task, sizeof(task)) < 0) + done = true; + } + break; + case MSG_SEND_DATA: + { + unsigned char data[3 * UNIT_TASK_SIZE]; + int task; + if (read(client_socket, &task, sizeof(task)) < 0) + done = true; + else if (read(client_socket, &data[0], sizeof(data)) < 0) + done = true; + else + the_distrib->send_data(task, &data[0], sizeof(data)); + } + break; + default: + break; + } + } + else if (nread > 0 && nread < sizeof(msg_type)) + { + cerr << "Error: nread = " << nread << "!" << endl; + } + else if (nread < 0) + { + break; } } close(client_socket); + cerr << "Closing connection thread!" << endl; + pthread_exit(NULL); } void distrib_server(distrib * the_distrib) @@ -218,6 +243,20 @@ void distrib_server(distrib * the_distrib) if (client_socket < 0) break; + int cip = ntohl(client_addr.sin_addr.s_addr); + + cout << "Connection from " + << (unsigned int) ((cip >> 24) & 0xFF) + << '.' + << (unsigned int) ((cip >> 16) & 0xFF) + << '.' + << (unsigned int) ((cip >> 8) & 0xFF) + << '.' + << (unsigned int) (cip & 0xFF) + << ':' + << client_addr.sin_port + << endl; + distrib::connection_thread_arg_t * arg = new distrib::connection_thread_arg_t; arg->the_distrib = the_distrib; @@ -229,6 +268,8 @@ void distrib_server(distrib * the_distrib) (void * (*)(void *)) &connection_thread, arg); } + + cout << "Listen thread exiting!" << endl; } int distrib::startServer() @@ -262,6 +303,16 @@ int distrib::startClient(const char * server, int port) cerr << "Error creating client socket: " << errno << endl; return 1; } + int flag = 1; + if (setsockopt(m_client_socket, + IPPROTO_TCP, + TCP_NODELAY, + (char *) &flag, + sizeof(flag)) < 0) + { + cerr << "Failed to set TCP_NODELAY on client socket!" << endl; + return 2; + } struct addrinfo hint; memset(&hint, 0, sizeof(hint)); @@ -324,7 +375,7 @@ int distrib::getTask() } pthread_mutex_unlock(&m_task_mutex); if (task > -1) - recordTask(task); + startTask(task); } else { @@ -339,7 +390,7 @@ int distrib::getTask() return task; } -void distrib::recordTask(int task) +void distrib::startTask(int task) { pthread_mutex_lock(&m_tasks_in_progress_mutex); m_tasks_in_progress[task] = 1; @@ -350,5 +401,27 @@ void distrib::taskDone(int task) { pthread_mutex_lock(&m_tasks_in_progress_mutex); m_tasks_in_progress.erase(task); + m_tasks_complete++; + if (m_tasks_complete == m_num_tasks) + { + pthread_mutex_lock(&m_tasks_complete_mutex); + pthread_cond_signal(&m_tasks_complete_cond); + pthread_mutex_unlock(&m_tasks_complete_mutex); + } pthread_mutex_unlock(&m_tasks_in_progress_mutex); } + +void distrib::waitAllTasks() +{ + int done; + pthread_mutex_lock(&m_tasks_in_progress_mutex); + done = m_tasks_complete; + if (done < m_num_tasks) + pthread_mutex_lock(&m_tasks_complete_mutex); + pthread_mutex_unlock(&m_tasks_in_progress_mutex); + if (done < m_num_tasks) + { + pthread_cond_wait(&m_tasks_complete_cond, &m_tasks_complete_mutex); + pthread_mutex_unlock(&m_tasks_complete_mutex); + } +} diff --git a/distrib/distrib.h b/distrib/distrib.h index 3afbea2..f87ef97 100644 --- a/distrib/distrib.h +++ b/distrib/distrib.h @@ -7,7 +7,7 @@ #include #include -#define UNIT_TASK_SIZE 50 +#define UNIT_TASK_SIZE 100 class distrib { @@ -28,6 +28,7 @@ class distrib int getTask(); int send_data(int task, unsigned char * data, int num_bytes); int getNumTasksInProgress() { return m_tasks_in_progress.size(); } + void waitAllTasks(); typedef struct { @@ -41,7 +42,7 @@ class distrib protected: int clientConnect(const std::string & host, const std::vector & client_options); - void recordTask(int task); + void startTask(int task); void taskDone(int task); std::vector m_hosts; @@ -55,12 +56,15 @@ class distrib int m_num_clients; unsigned char * m_data; int m_data_size; - int m_num_tasks; + int m_num_tasks; + int m_tasks_complete; int m_next_task; bool m_server; pthread_mutex_t m_task_mutex; std::map m_tasks_in_progress; pthread_mutex_t m_tasks_in_progress_mutex; + pthread_mutex_t m_tasks_complete_mutex; + pthread_cond_t m_tasks_complete_cond; }; #endif diff --git a/main/fart.cc b/main/fart.cc index 05ab7b9..a9571cf 100644 --- a/main/fart.cc +++ b/main/fart.cc @@ -139,37 +139,6 @@ int main(int argc, char * argv[]) struct timeval before, after; gettimeofday(&before, NULL); /* start timing */ -#if 0 -void Scene::taskLoop() -{ - unsigned char data[3 * UNIT_TASK_SIZE]; - for (;;) - { - int task_id = m_distrib.getTask(); - if (task_id < 0) - break; - int pixel = task_id * UNIT_TASK_SIZE; - int i = pixel / m_width; - int j = pixel % m_width; - for (int t = 0; t < UNIT_TASK_SIZE; t++) - { - renderPixel(j, i, &data[3 * t]); - j++; - if (j >= m_width) - { - j = 0; - i++; - if (i >= m_height) - break; - } - } - int ret = m_distrib.send_data(task_id, data, 3 * UNIT_TASK_SIZE); - if (ret != 0) - break; - } -} -#endif - if (distributed) { /* start the distribution infrastructure */ @@ -186,11 +155,41 @@ void Scene::taskLoop() the_distrib.startServer(); the_distrib.startClients(client_options); - /* TODO: wait until all tasks are complete */ + + /* wait until all tasks are complete */ + the_distrib.waitAllTasks(); } else { the_distrib.startClient(server_name, server_port); + + unsigned char data[3 * UNIT_TASK_SIZE]; + for (;;) + { + int task_id = the_distrib.getTask(); + if (task_id < 0) + break; + int pixel = task_id * UNIT_TASK_SIZE; + int i = pixel / width; + int j = pixel % width; + for (int t = 0; t < UNIT_TASK_SIZE; t++) + { + scene.renderPixel(j, i, &data[3 * t]); + j++; + if (j >= width) + { + j = 0; + i++; + if (i >= height) + break; + } + } + int ret = the_distrib.send_data(task_id, + &data[0], + 3 * UNIT_TASK_SIZE); + if (ret != 0) + break; + } } } else