#include "distrib.h"
#include <stdio.h>      /* snprintf() */
#include <stdlib.h>     /* exit() */
#include <stdint.h>     /* uint32_t */
#include <errno.h>
#include <unistd.h>     /* fork(), execvp(), read(), write(), close(), _exit() */
#include <pthread.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>      /* getaddrinfo(), freeaddrinfo(), gai_strerror() */
#include <string.h>     /* memset(), memcpy() */
#include <fcntl.h>      /* fcntl(), F_GETFD, F_SETFD, FD_CLOEXEC */
#include <iostream>
#include <fstream>
#include <string>
#include <vector>

using namespace std;

/* Message types sent by a client to the server's connection_thread().
 * Every message starts with one of these values written as a raw int. */
#define MSG_WANT_DATA 1
#define MSG_SEND_DATA 2

/*
 * Read exactly len bytes from fd into buf.
 * Retries on EINTR and on short reads, since TCP may deliver one logical
 * message in several pieces.
 * Returns false on a read error, or if the peer closed the connection
 * before len bytes arrived.
 */
static bool read_all(int fd, void * buf, size_t len)
{
    unsigned char * p = static_cast<unsigned char *>(buf);
    while (len > 0)
    {
        ssize_t n = read(fd, p, len);
        if (n < 0)
        {
            if (errno == EINTR)
                continue;
            return false;
        }
        if (n == 0)     /* EOF: the peer closed the connection */
            return false;
        p += n;
        len -= static_cast<size_t>(n);
    }
    return true;
}

/*
 * Write exactly len bytes from buf to fd, retrying on EINTR and on short
 * writes.
 * Returns false on a write error.
 * NOTE(review): on a socket whose peer has gone away, write() raises
 * SIGPIPE unless that signal is ignored elsewhere in the program.
 */
static bool write_all(int fd, const void * buf, size_t len)
{
    const unsigned char * p = static_cast<const unsigned char *>(buf);
    while (len > 0)
    {
        ssize_t n = write(fd, p, len);
        if (n < 0)
        {
            if (errno == EINTR)
                continue;
            return false;
        }
        if (n == 0)     /* no progress possible; avoid spinning */
            return false;
        p += n;
        len -= static_cast<size_t>(n);
    }
    return true;
}

/* Mark fd close-on-exec so the ssh children started by clientConnect()
 * do not inherit it. */
static void set_cloexec(int fd)
{
    int flags = fcntl(fd, F_GETFD);
    if (flags != -1)
        fcntl(fd, F_SETFD, flags | FD_CLOEXEC);
}

/*
 * pthread entry point that forwards its void* argument to Fn as a T*.
 * This avoids calling a void(T*) function through a void*(*)(void*)
 * pointer, which is undefined behaviour.
 */
template <typename T, void (*Fn)(T *)>
static void * thread_trampoline(void * arg)
{
    Fn(static_cast<T *>(arg));
    return NULL;
}

distrib::distrib()
{
    pthread_cond_init(&m_listen_cond, NULL);
    pthread_mutex_init(&m_listen_mutex, NULL);
    pthread_mutex_init(&m_task_mutex, NULL);
    m_num_clients = 0;
    m_data = NULL;
    m_server = true;
    m_next_task = 0;
    /* 0 means "listen thread not ready yet"; startServer() waits until
     * distrib_server() publishes the real port here. */
    m_serverport = 0;
    m_client_socket = -1;
}

/*
 * Append every whitespace-separated token in filename to m_hosts.
 * Returns 0 on success, or 1 if the file cannot be opened.
 */
int distrib::readHostFile(const char * filename)
{
    ifstream ifs(filename);
    if (!ifs.is_open())
        return 1;

    /* Extraction fails only once no token is left, so a last host name
     * without a trailing newline is still recorded. */
    string host;
    while (ifs >> host)
        m_hosts.push_back(host);
    return 0;
}

/*
 * Launch one remote client (via clientConnect()) per host in m_hosts.
 * Returns the number of hosts for which launching failed.
 */
int distrib::startClients(const std::vector<std::string> & client_options)
{
    int ret = 0;
    for (size_t i = 0, sz = m_hosts.size(); i < sz; i++)
        ret += clientConnect(m_hosts[i], client_options);
    return ret;
}

/*
 * Fork a child that execs
 *   ssh <host> fart --host <m_servername> --port <m_serverport> <client_options...>
 * Parent side: records the child pid in m_children and bumps m_num_clients.
 * Returns 0 on success, or 1 if fork() failed.
 * The child never returns: it either execs ssh or exits with status 33.
 * Must be called after startServer() so m_servername/m_serverport are set.
 */
int distrib::clientConnect(const string & host,
        const std::vector<std::string> & client_options)
{
    /* Build the complete command line before forking. The listen thread
     * may already be running, and the child of a multithreaded process
     * should not allocate memory before exec. */
    char server_port_str[15];
    snprintf(server_port_str, sizeof(server_port_str), "%d", m_serverport);

    vector<string> args;
    args.push_back("ssh");
    args.push_back(host);
    args.push_back("fart");
    args.push_back("--host");
    args.push_back(m_servername);
    args.push_back("--port");
    args.push_back(server_port_str);
    args.insert(args.end(), client_options.begin(), client_options.end());

    /* execvp() wants a NULL-terminated array of C strings */
    vector<const char *> argv;
    argv.reserve(args.size() + 1);
    for (size_t i = 0; i < args.size(); i++)
        argv.push_back(args[i].c_str());
    argv.push_back(NULL);

#if 0 /* debug */
    cout << "executing: 'ssh', ";
    for (size_t i = 0; i < args.size(); i++)
        cout << "'" << argv[i] << "', ";
    cout << endl;
#endif

    pid_t id = fork();
    if (id < 0)     /* fork() error */
    {
        cerr << "Error forking: " << errno << endl;
        return 1;
    }

    if (id == 0)    /* in the child */
    {
        execvp("ssh", const_cast<char * const *>(&argv[0]));
        /* only reached if exec failed */
        cerr << "Error " << errno << " with execvp()!" << endl;
        /* _exit(): do not run the parent's atexit handlers or flush its
         * inherited stdio buffers a second time */
        _exit(33);
    }

    /* in the parent */
    m_children.push_back(id);
    m_num_clients++;
    return 0;
}

/*
 * Per-client server thread.
 * Takes ownership of arg (allocated by distrib_server()) and deletes it.
 * Then serves MSG_WANT_DATA / MSG_SEND_DATA requests on the client socket
 * until the client disconnects or an I/O error occurs, and finally closes
 * the socket.
 */
void connection_thread(distrib::connection_thread_arg_t * arg)
{
    distrib * the_distrib = arg->the_distrib;
    int client_socket = arg->client_socket;
    delete arg;

    bool done = false;

    /* loop listening for messages from the client */
    while (!done)
    {
        int msg_type;
        if (!read_all(client_socket, &msg_type, sizeof(msg_type)))
            break;  /* read error or client disconnected */

        switch (msg_type)
        {
            case MSG_WANT_DATA:
            {
                /* reply with the next task index (-1 when none are left) */
                int task = the_distrib->getTask();
                if (!write_all(client_socket, &task, sizeof(task)))
                    done = true;
            }
            break;

            case MSG_SEND_DATA:
            {
                /* NOTE(review): the client side of send_data() writes
                 * num_bytes of payload, but this reader always expects
                 * exactly 3 * UNIT_TASK_SIZE bytes. Clients must always
                 * send full-size chunks. */
                unsigned char data[3 * UNIT_TASK_SIZE];
                int task;
                if (!read_all(client_socket, &task, sizeof(task))
                        || !read_all(client_socket, &data[0], sizeof(data)))
                    done = true;
                else
                    the_distrib->send_data(task, &data[0], sizeof(data));
            }
            break;

            default:
                /* unknown message types are ignored */
                break;
        }
    }

    close(client_socket);
}

/*
 * Body of the listen thread started by startServer().
 * Creates a listening TCP socket and prints its address. It then publishes
 * the host name and port in the_distrib (under m_listen_mutex) and signals
 * m_listen_cond. After that it accepts clients forever, spawning one
 * detached connection_thread() per client.
 * Exits the process (status 39/40/41) if the listen socket cannot be set up.
 */
void distrib_server(distrib * the_distrib)
{
    char hostname[1000];
    if (gethostname(&hostname[0], sizeof(hostname)) != 0)
        hostname[0] = '\0';
    /* gethostname() does not guarantee termination on truncation */
    hostname[sizeof(hostname) - 1] = '\0';

    int listen_socket = socket(PF_INET, SOCK_STREAM, 0);
    if (listen_socket == -1)
    {
        cerr << "Error " << errno << " creating listen socket!" << endl;
        exit(39);
    }
    set_cloexec(listen_socket);

    /* No bind(): on Linux, listen() on an unbound socket binds it to
     * INADDR_ANY and an ephemeral port (see ip(7)). getsockname() below
     * reports which port was chosen. */
    if (listen(listen_socket, 5) == -1)
    {
        cerr << "Error " << errno << " when trying to listen!" << endl;
        exit(40);
    }

    struct sockaddr_in addr;
    socklen_t addr_len = sizeof(addr);
    if (getsockname(listen_socket, reinterpret_cast<struct sockaddr *>(&addr),
                &addr_len) == -1)
    {
        cerr << "Error " << errno << " querying listen socket!" << endl;
        exit(41);
    }
    uint32_t ip_addr = ntohl(addr.sin_addr.s_addr);
    int port = ntohs(addr.sin_port);

    cout << "Listening on "
        << ((ip_addr >> 24) & 0xFF) << '.'
        << ((ip_addr >> 16) & 0xFF) << '.'
        << ((ip_addr >> 8) & 0xFF) << '.'
        << (ip_addr & 0xFF) << ':' << port << endl;

    /* Publish the address and signal readiness of the listen thread.
     * startServer() waits for m_serverport to become non-zero. */
    pthread_mutex_lock(&the_distrib->m_listen_mutex);
    the_distrib->m_servername = hostname;
    the_distrib->m_serverport = port;
    pthread_cond_signal(&the_distrib->m_listen_cond);
    pthread_mutex_unlock(&the_distrib->m_listen_mutex);

    for (;;)
    {
        struct sockaddr_in client_addr;
        socklen_t client_addr_len = sizeof(client_addr);
        int client_socket = accept(listen_socket,
                reinterpret_cast<struct sockaddr *>(&client_addr),
                &client_addr_len);
        if (client_socket < 0)
        {
            /* transient failures: keep serving */
            if (errno == EINTR || errno == ECONNABORTED)
                continue;
            break;
        }
        set_cloexec(client_socket);

        distrib::connection_thread_arg_t * arg =
            new distrib::connection_thread_arg_t;
        arg->the_distrib = the_distrib;
        arg->client_socket = client_socket;

        pthread_t client_thread;
        if (pthread_create(&client_thread, NULL,
                    &thread_trampoline<distrib::connection_thread_arg_t,
                                       &connection_thread>,
                    arg) == 0)
        {
            /* Nobody joins connection threads. Detach them so their
             * resources are released when they exit. */
            pthread_detach(client_thread);
        }
        else
        {
            cerr << "Error creating connection thread!" << endl;
            delete arg;
            close(client_socket);
        }
    }

    close(listen_socket);
}

/*
 * Start the listen thread (distrib_server()) and block until it has
 * published m_servername/m_serverport.
 * Returns 0 on success, or the pthread_create() error code.
 */
int distrib::startServer()
{
    m_server = true;

    pthread_mutex_lock(&m_listen_mutex);

    /* start the listen thread */
    int ret = pthread_create(&m_server_thread, NULL,
            &thread_trampoline<distrib, &distrib_server>, this);
    if (ret != 0)
    {
        /* do not return with the mutex still held */
        pthread_mutex_unlock(&m_listen_mutex);
        return ret;
    }

    /* Wait for the listen thread to be running. Loop on the predicate
     * because pthread_cond_wait() may wake spuriously. */
    while (m_serverport == 0)
        pthread_cond_wait(&m_listen_cond, &m_listen_mutex);
    pthread_mutex_unlock(&m_listen_mutex);

    return 0;
}

/*
 * Switch this object to client mode and connect m_client_socket to
 * server:port.
 * Returns 0 on success, 1 if the socket cannot be created, or 2 if the
 * server cannot be resolved or connected to. On failure the socket is
 * closed and m_client_socket is set to -1.
 */
int distrib::startClient(const char * server, int port)
{
    m_server = false;

    m_client_socket = socket(PF_INET, SOCK_STREAM, 0);
    if (m_client_socket < 0)
    {
        cerr << "Error creating client socket: " << errno << endl;
        return 1;
    }

    struct addrinfo hint;
    memset(&hint, 0, sizeof(hint));
    hint.ai_family = AF_INET;
    hint.ai_socktype = SOCK_STREAM;

    char portstr[15];
    snprintf(portstr, sizeof(portstr), "%d", port);

    struct addrinfo * res = NULL;
    int gai_ret = getaddrinfo(server, portstr, &hint, &res);
    if (gai_ret != 0)
    {
        cerr << "Error resolving server "
            << (server != NULL ? server : "(null)") << ": "
            << gai_strerror(gai_ret) << endl;
        close(m_client_socket);
        m_client_socket = -1;
        return 2;
    }

    int conn_ret = connect(m_client_socket, res->ai_addr, res->ai_addrlen);
    int conn_errno = errno;     /* freeaddrinfo() may clobber errno */
    freeaddrinfo(res);
    if (conn_ret == -1)
    {
        cerr << "Error connecting from client socket: " << conn_errno << endl;
        close(m_client_socket);
        m_client_socket = -1;
        return 2;
    }

    return 0;
}

/*
 * Deliver the results for one task.
 * Server mode: copy num_bytes of data into m_data at offset
 *   3 * task * UNIT_TASK_SIZE, clipped to m_data_size.
 * Client mode: send a MSG_SEND_DATA message (header, task, payload) to the
 *   server.
 * Returns 0 on success. Returns -1 on a negative task (server) or
 * num_bytes (client), or on a socket write error.
 */
int distrib::send_data(int task, unsigned char * data, int num_bytes)
{
    if (m_server)
    {
        if (m_data == NULL || num_bytes <= 0)
            return 0;
        /* task may come from a remote client: never write before m_data */
        if (task < 0)
            return -1;

        /* 64-bit arithmetic so a large task index cannot overflow */
        long long offset = 3LL * task * UNIT_TASK_SIZE;
        long long size = m_data_size;
        if (offset >= size)
            return 0;

        long long num_to_copy = num_bytes;
        if (offset + num_to_copy > size)
            num_to_copy = size - offset;
        memcpy(m_data + offset, data, static_cast<size_t>(num_to_copy));
        return 0;
    }

    if (num_bytes < 0)
        return -1;

    int msg_header = MSG_SEND_DATA;
    /* send data */
    if (!write_all(m_client_socket, &msg_header, sizeof(msg_header))
            || !write_all(m_client_socket, &task, sizeof(task))
            || !write_all(m_client_socket, data, static_cast<size_t>(num_bytes)))
    {
        return -1;
    }
    return 0;
}

/*
 * Obtain the next task index to work on.
 * Server mode: hand out m_next_task under m_task_mutex, or -1 once
 *   m_num_tasks have been handed out.
 * Client mode: ask the server with MSG_WANT_DATA and return its reply, or
 *   -1 on any socket error, including the server closing the connection.
 */
int distrib::getTask()
{
    if (m_server)
    {
        pthread_mutex_lock(&m_task_mutex);
        int task = -1;
        if (m_next_task < m_num_tasks)
        {
            task = m_next_task;
            m_next_task++;
        }
        pthread_mutex_unlock(&m_task_mutex);
        return task;
    }

    int msg_header = MSG_WANT_DATA;
    if (!write_all(m_client_socket, &msg_header, sizeof(msg_header)))
        return -1;

    /* Wait for the reply. A short read (e.g. the server went away) must
     * not be mistaken for task 0. */
    int task = 0;
    if (!read_all(m_client_socket, &task, sizeof(task)))
        return -1;
    return task;
}