00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #include <iostream>
00024 #include <fstream>
00025 #include <sys/types.h>
00026 #include "socket.hh"
00027
00028 namespace Network
00029 {
00030
00031 Socket::Socket(SOCKET_KIND kind, SOCKET_VERSION version) :
00032 _kind(kind), _version(version), _state_timeout(0),
00033 _socket(0), _recv_flags(kind), _proto_kind(text), _empty_lines(false),
00034 _buffer(""), _tls(false)
00035 {
00036 _delim.push_back("\0");
00037 #ifdef LIBSOCKET_WIN
00038 WSADATA wsadata;
00039 if (WSAStartup(MAKEWORD(1, 1), &wsadata) != 0)
00040 throw WSAStartupError("WSAStartup failed", HERE);
00041 #endif
00042 #ifndef IPV6_ENABLED
00043 if (version == V6)
00044 throw Ipv6SupportError("lib was not compiled with ipv6 support", HERE);
00045 #endif
00046 }
00047
00048 Socket::Socket(SOCKET_KIND kind, PROTO_KIND pkind, SOCKET_VERSION version) :
00049 _kind(kind), _version(version), _state_timeout(0),
00050 _socket(0), _recv_flags(kind), _proto_kind(pkind), _empty_lines(false),
00051 _buffer(""), _tls(false)
00052 {
00053 _delim.push_back("\0");
00054 #ifdef LIBSOCKET_WIN
00055 WSADATA wsadata;
00056 if (WSAStartup(MAKEWORD(1, 1), &wsadata) != 0)
00057 throw WSAStartupError("WSAStartup failed", HERE);
00058 #endif
00059 #ifndef IPV6_ENABLED
00060 if (version == V6)
00061 throw Ipv6SupportError("lib was not compiled with ipv6 support", HERE);
00062 #endif
00063 }
00064
00065 Socket::~Socket()
00066 {
00067 }
00068
00069 void Socket::enable_tls()
00070 {
00071 #ifdef TLS
00072 int ret;
00073
00074 if (_kind != TCP)
00075 throw TLSError("You need to have a TCP connection", HERE);
00076 if (!connected())
00077 throw NoConnection("You need to have a connection", HERE);
00078
00079 gnutls_transport_set_ptr(_session, (gnutls_transport_ptr)_socket);
00080 ret = gnutls_handshake(_session);
00081 if (ret < 0)
00082 {
00083 close(_socket);
00084 gnutls_deinit(_session);
00085 throw TLSError(gnutls_strerror(ret), HERE);
00086 }
00087 #else
00088 throw TLSSupportError("lib was not compiled with TLS support", HERE);
00089 #endif
00090 }
00091
00092 void Socket::init_tls(GnuTLSKind kind,
00093 unsigned size, const std::string &certfile,
00094 const std::string &keyfile,
00095 const std::string &trustfile,
00096 const std::string &crlfile)
00097 {
00098 #ifdef TLS
00099 static bool init = false;
00100 static gnutls_dh_params dh_params;
00101 const int protocol_tls[] = { GNUTLS_TLS1, 0 };
00102 const int protocol_ssl[] = { GNUTLS_SSL3, 0 };
00103 const int cert_type_priority[] = { GNUTLS_CRT_X509,
00104 GNUTLS_CRT_OPENPGP, 0 };
00105
00106 if (!init)
00107 {
00108 gnutls_global_init();
00109 init = true;
00110 }
00111 _tls = true;
00112 _tls_main = true;
00113 gnutls_certificate_allocate_credentials(&_x509_cred);
00114 if (keyfile.size() > 0 && certfile.size() > 0)
00115 {
00116 std::ifstream key(keyfile.c_str()), cert(certfile.c_str());
00117 if (!key.is_open() || !cert.is_open())
00118 throw InvalidFile("key or cert invalid", HERE);
00119 key.close();
00120 cert.close();
00121
00122 _nbbits = size;
00123 if (trustfile.size() > 0)
00124 gnutls_certificate_set_x509_trust_file(_x509_cred, trustfile.c_str(),
00125 GNUTLS_X509_FMT_PEM);
00126 if (crlfile.size() > 0)
00127 gnutls_certificate_set_x509_crl_file(_x509_cred, crlfile.c_str(),
00128 GNUTLS_X509_FMT_PEM);
00129 gnutls_certificate_set_x509_key_file(_x509_cred, certfile.c_str(),
00130 keyfile.c_str(),
00131 GNUTLS_X509_FMT_PEM);
00132 gnutls_dh_params_init(&dh_params);
00133 gnutls_dh_params_generate2(dh_params, _nbbits);
00134 gnutls_certificate_set_dh_params(_x509_cred, dh_params);
00135
00136 if (gnutls_init(&_session, GNUTLS_SERVER))
00137 throw TLSError("gnutls_init failed", HERE);
00138 }
00139 else
00140 {
00141 if (gnutls_init(&_session, GNUTLS_CLIENT))
00142 throw TLSError("gnutls_init failed", HERE);
00143 }
00144
00145 gnutls_set_default_priority(_session);
00146 if (kind == TLS)
00147 gnutls_protocol_set_priority(_session, protocol_tls);
00148 else
00149 gnutls_protocol_set_priority(_session, protocol_ssl);
00150
00151 if (keyfile.size() > 0 && certfile.size() > 0)
00152 {
00153 gnutls_credentials_set(_session, GNUTLS_CRD_CERTIFICATE, _x509_cred);
00154 gnutls_certificate_server_set_request(_session, GNUTLS_CERT_REQUEST);
00155 gnutls_dh_set_prime_bits(_session, _nbbits);
00156 }
00157 else
00158 {
00159 gnutls_certificate_type_set_priority(_session, cert_type_priority);
00160 gnutls_credentials_set(_session, GNUTLS_CRD_CERTIFICATE, _x509_cred);
00161 }
00162 #else
00163 throw TLSSupportError("lib was not compiled with TLS support", HERE);
00164 #endif
00165 }
00166
00167 void Socket::_close(int socket) const
00168 {
00169 #ifndef LIBSOCKET_WIN
00170 if (socket < 0 || close(socket) < 0)
00171 throw CloseError("Close Error", HERE);
00172 socket = 0;
00173 #else
00174 if (socket < 0 || closesocket(socket) < 0)
00175 throw CloseError("Close Error", HERE);
00176 socket = 0;
00177 #endif
00178 #ifdef TLS
00179 if (_tls)
00180 {
00181 std::cout << "Deletion..." << std::endl;
00182 gnutls_deinit(_session);
00183 if (_tls_main)
00184 {
00185 gnutls_certificate_free_credentials(_x509_cred);
00186 gnutls_global_deinit();
00187 }
00188 }
00189 #endif
00190 }
00191
00192 void Socket::_listen(int socket) const
00193 {
00194 if (socket < 0 || listen(socket, 5) < 0)
00195 throw ListenError("Listen Error", HERE);
00196 }
00197
00198 void Socket::_write_str(int socket, const std::string& str) const
00199 {
00200 int res = 1;
00201 unsigned int count = 0;
00202 const char *buf;
00203
00204 buf = str.c_str();
00205 if (socket < 0)
00206 throw NoConnection("No Socket", HERE);
00207 while (res && count < str.size())
00208 {
00209 #ifdef IPV6_ENABLED
00210 if (V4 == _version)
00211 #endif
00212 #ifdef TLS
00213 if (_tls)
00214 res = gnutls_record_send(_session, buf + count, str.size() - count);
00215 else
00216 #endif
00217 res = sendto(socket, buf + count, str.size() - count, SENDTO_FLAGS,
00218 (const struct sockaddr*)&_addr, sizeof(_addr));
00219 #ifdef IPV6_ENABLED
00220 else
00221 res = sendto(socket, buf + count, str.size() - count, SENDTO_FLAGS,
00222 (const struct sockaddr*)&_addr6, sizeof(_addr6));
00223 #endif
00224 if (res <= 0)
00225 throw ConnectionClosed("Connection Closed", HERE);
00226 count += res;
00227 }
00228 }
00229
00230 void Socket::_write_str_bin(int socket, const std::string& str) const
00231 {
00232 int res = 1;
00233 unsigned int count = 0;
00234 #ifdef LIBSOCKET_WIN
00235 char* buf = new char[str.size() + 2];
00236 #else
00237 char buf[str.size() + 2];
00238 #endif
00239 buf[0] = str.size() / 256;
00240 buf[1] = str.size() % 256;
00241 memcpy(buf + 2, str.c_str(), str.size());
00242 if (socket < 0)
00243 throw NoConnection("No Socket", HERE);
00244 while (res && count < str.size() + 2)
00245 {
00246 #ifdef IPV6_ENABLED
00247 if (V4 == _version)
00248 #endif
00249 #ifdef TLS
00250 if (_tls)
00251 res = gnutls_record_send(_session, buf + count, str.size() + 2 - count);
00252 else
00253 #endif
00254 res = sendto(socket, buf + count, str.size() + 2 - count,
00255 SENDTO_FLAGS,
00256 (const struct sockaddr*)&_addr, sizeof(_addr));
00257 #ifdef IPV6_ENABLED
00258 else
00259 res = sendto(socket, buf + count, str.size() + 2 - count,
00260 \ SENDTO_FLAGS,
00261 (const struct sockaddr*)&_addr6, sizeof(_addr6));
00262 #endif
00263 if (res <= 0)
00264 throw ConnectionClosed("Connection Closed", HERE);
00265 count += res;
00266 }
00267 #ifdef LIBSOCKET_WIN
00268 delete[] buf;
00269 #endif
00270 }
00271
00272 void Socket::_set_timeout(bool enable, int socket, int timeout)
00273 {
00274 fd_set fdset;
00275 struct timeval timetowait;
00276 int res;
00277
00278 if (enable)
00279 timetowait.tv_sec = timeout;
00280 else
00281 timetowait.tv_sec = 65535;
00282 timetowait.tv_usec = 0;
00283 FD_ZERO(&fdset);
00284 FD_SET(socket, &fdset);
00285 if (enable)
00286 res = select(socket + 1, &fdset, NULL, NULL, &timetowait);
00287 else
00288 res = select(socket + 1, &fdset, NULL, NULL, NULL);
00289 if (res < 0)
00290 throw SelectError("Select error", HERE);
00291 if (res == 0)
00292 throw Timeout("Timeout on socket", HERE);
00293 }
00294
00295 void Socket::write(const std::string& str)
00296 {
00297 if (_proto_kind == binary)
00298 _write_str_bin(_socket, str);
00299 else
00300 _write_str(_socket, str);
00301 }
00302
00303 bool Socket::connected() const
00304 {
00305 return _socket != 0;
00306 }
00307
00308 void Socket::allow_empty_lines()
00309 {
00310 _empty_lines = true;
00311 }
00312
00313 int Socket::get_socket()
00314 {
00315 return _socket;
00316 }
00317
00318 void Socket::add_delim(const std::string& delim)
00319 {
00320 _delim.push_back(delim);
00321 }
00322
00323 void Socket::del_delim(const std::string& delim)
00324 {
00325 std::list<std::string>::iterator it, it2;
00326
00327 for (it = _delim.begin(); it != _delim.end(); )
00328 {
00329 if (*it == delim)
00330 {
00331 it2 = it++;
00332 _delim.erase(it2);
00333 }
00334 else
00335 it++;
00336 }
00337 }
00338
00339 std::pair<int, int> Socket::_find_delim(const std::string& str, int start) const
00340 {
00341 int i = -1;
00342 int pos = -1, size = 0;
00343 std::list<std::string>::const_iterator it;
00344
00345
00346 if (_delim.size() > 0)
00347 {
00348 it = _delim.begin();
00349 while (it != _delim.end())
00350 {
00351 if (*it == "")
00352 i = str.find('\0', start);
00353 else
00354 i = str.find(*it, start);
00355 if ((i >= 0) && ((unsigned int)i < str.size()) &&
00356 (pos < 0 || i < pos))
00357 {
00358 pos = i;
00359 size = it->size() ? it->size() : 1;
00360 }
00361 it++;
00362 }
00363 }
00364 return std::pair<int, int>(pos, size);
00365 }
00366
00367 Socket& operator<<(Socket& s, const std::string& str)
00368 {
00369 s.write(str);
00370 return s;
00371 }
00372
00373 Socket& operator>>(Socket& s, std::string& str)
00374 {
00375 str = s.read();
00376 return s;
00377 }
00378 }