diff options
Diffstat (limited to 'src/vmime/net/tls')
-rw-r--r-- | src/vmime/net/tls/TLSSocket.hpp | 2 | ||||
-rw-r--r-- | src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp | 36 | ||||
-rw-r--r-- | src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp | 5 | ||||
-rw-r--r-- | src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp | 85 | ||||
-rw-r--r-- | src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp | 6 |
5 files changed, 82 insertions, 52 deletions
diff --git a/src/vmime/net/tls/TLSSocket.hpp b/src/vmime/net/tls/TLSSocket.hpp index e2668ad4..ec3a83ef 100644 --- a/src/vmime/net/tls/TLSSocket.hpp +++ b/src/vmime/net/tls/TLSSocket.hpp @@ -67,7 +67,7 @@ public: * during the negociation process, exceptions::operation_timed_out * if a time-out occurs */ - virtual void handshake(shared_ptr <timeoutHandler> toHandler = null) = 0; + virtual void handshake() = 0; /** Return the peer's certificate (chain) as sent by the peer. * diff --git a/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp b/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp index edc8811a..13b7eb24 100644 --- a/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp +++ b/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp @@ -87,11 +87,17 @@ TLSSocket_GnuTLS::~TLSSocket_GnuTLS() void TLSSocket_GnuTLS::connect(const string& address, const port_t port) { - m_wrapped->connect(address, port); - - handshake(null); + try + { + m_wrapped->connect(address, port); - m_connected = true; + handshake(); + } + catch (...) + { + disconnect(); + throw; + } } @@ -132,6 +138,12 @@ const string TLSSocket_GnuTLS::getPeerAddress() const } +shared_ptr <timeoutHandler> TLSSocket_GnuTLS::getTimeoutHandler() +{ + return m_wrapped->getTimeoutHandler(); +} + + void TLSSocket_GnuTLS::receive(string& buffer) { const size_t size = receiveRaw(m_buffer, sizeof(m_buffer)); @@ -239,14 +251,15 @@ unsigned int TLSSocket_GnuTLS::getStatus() const } -void TLSSocket_GnuTLS::handshake(shared_ptr <timeoutHandler> toHandler) +void TLSSocket_GnuTLS::handshake() { + shared_ptr <timeoutHandler> toHandler = m_wrapped->getTimeoutHandler(); + if (toHandler) toHandler->resetTimeOut(); // Start handshaking process m_handshaking = true; - m_toHandler = toHandler; try { @@ -280,13 +293,10 @@ void TLSSocket_GnuTLS::handshake(shared_ptr <timeoutHandler> toHandler) catch (...) { m_handshaking = false; - m_toHandler = null; - throw; } m_handshaking = false; - m_toHandler = null; // Verify server's certificate(s) shared_ptr <security::cert::certificateChain> certs = getPeerCertificates(); @@ -338,6 +348,8 @@ ssize_t TLSSocket_GnuTLS::gnutlsPullFunc // returns -1 and errno is set to EGAIN... if (sok->m_handshaking) { + shared_ptr <timeoutHandler> toHandler = sok->m_wrapped->getTimeoutHandler(); + while (true) { const ssize_t ret = static_cast <ssize_t> @@ -355,12 +367,12 @@ ssize_t TLSSocket_GnuTLS::gnutlsPullFunc } // Check whether the time-out delay is elapsed - if (sok->m_toHandler && sok->m_toHandler->isTimeOut()) + if (toHandler && toHandler->isTimeOut()) { - if (!sok->m_toHandler->handleTimeOut()) + if (!toHandler->handleTimeOut()) throw exceptions::operation_timed_out(); - sok->m_toHandler->resetTimeOut(); + toHandler->resetTimeOut(); } } } diff --git a/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp b/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp index 885fac13..ddba9d0e 100644 --- a/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp +++ b/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp @@ -54,7 +54,7 @@ public: ~TLSSocket_GnuTLS(); - void handshake(shared_ptr <timeoutHandler> toHandler = null); + void handshake(); shared_ptr <security::cert::certificateChain> getPeerCertificates() const; @@ -78,6 +78,8 @@ public: const string getPeerName() const; const string getPeerAddress() const; + shared_ptr <timeoutHandler> getTimeoutHandler(); + private: void internalThrow(); @@ -99,7 +101,6 @@ private: byte_t m_buffer[65536]; bool m_handshaking; - shared_ptr <timeoutHandler> m_toHandler; exception* m_ex; diff --git a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp index 9857b8fb..595a0091 100644 --- a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp +++ b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp @@ -87,12 +87,6 @@ TLSSocket_OpenSSL::~TLSSocket_OpenSSL() try { disconnect(); - - if (m_ssl) - { - SSL_free(m_ssl); - m_ssl = 0; - } } catch (...) { @@ -130,32 +124,41 @@ void TLSSocket_OpenSSL::createSSLHandle() void TLSSocket_OpenSSL::connect(const string& address, const port_t port) { - m_wrapped->connect(address, port); - - createSSLHandle(); + try + { + m_wrapped->connect(address, port); - handshake(null); + createSSLHandle(); - m_connected = true; + handshake(); + } + catch (...) + { + disconnect(); + throw; + } } void TLSSocket_OpenSSL::disconnect() { - if (m_connected) + if (m_ssl) { - if (m_ssl) - { - // Don't shut down the socket more than once. - int shutdownState = SSL_get_shutdown(m_ssl); - bool shutdownSent = (shutdownState & SSL_SENT_SHUTDOWN) == SSL_SENT_SHUTDOWN; + // Don't shut down the socket more than once. + int shutdownState = SSL_get_shutdown(m_ssl); + bool shutdownSent = (shutdownState & SSL_SENT_SHUTDOWN) == SSL_SENT_SHUTDOWN; - if (!shutdownSent) - SSL_shutdown(m_ssl); - } + if (!shutdownSent) + SSL_shutdown(m_ssl); - m_wrapped->disconnect(); + SSL_free(m_ssl); + m_ssl = 0; + } + + if (m_connected) + { m_connected = false; + m_wrapped->disconnect(); } } @@ -184,6 +187,12 @@ const string TLSSocket_OpenSSL::getPeerAddress() const } +shared_ptr <timeoutHandler> TLSSocket_OpenSSL::getTimeoutHandler() +{ + return m_wrapped->getTimeoutHandler(); +} + + void TLSSocket_OpenSSL::receive(string& buffer) { const size_t size = receiveRaw(m_buffer, sizeof(m_buffer)); @@ -209,6 +218,9 @@ void TLSSocket_OpenSSL::send(const char* str) size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count) { + if (!m_ssl) + throw exceptions::socket_not_connected_exception(); + m_status &= ~STATUS_WOULDBLOCK; int rc = SSL_read(m_ssl, buffer, static_cast <int>(count)); @@ -235,6 +247,9 @@ size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count) void TLSSocket_OpenSSL::sendRaw(const byte_t* buffer, const size_t count) { + if (!m_ssl) + throw exceptions::socket_not_connected_exception(); + m_status &= ~STATUS_WOULDBLOCK; for (size_t size = count ; size > 0 ; ) @@ -264,6 +279,9 @@ void TLSSocket_OpenSSL::sendRaw(const byte_t* buffer, const size_t count) size_t TLSSocket_OpenSSL::sendRawNonBlocking(const byte_t* buffer, const size_t count) { + if (!m_ssl) + throw exceptions::socket_not_connected_exception(); + m_status &= ~STATUS_WOULDBLOCK; int rc = SSL_write(m_ssl, buffer, static_cast <int>(count)); @@ -288,14 +306,17 @@ size_t TLSSocket_OpenSSL::sendRawNonBlocking(const byte_t* buffer, const size_t } -void TLSSocket_OpenSSL::handshake(shared_ptr <timeoutHandler> toHandler) +void TLSSocket_OpenSSL::handshake() { + if (!m_ssl) + throw exceptions::socket_not_connected_exception(); + + shared_ptr <timeoutHandler> toHandler = m_wrapped->getTimeoutHandler(); + if (toHandler) toHandler->resetTimeOut(); // Start handshaking process - m_toHandler = toHandler; - if (!m_ssl) createSSLHandle(); @@ -318,25 +339,20 @@ void TLSSocket_OpenSSL::handshake(shared_ptr <timeoutHandler> toHandler) } // Check whether the time-out delay is elapsed - if (m_toHandler && m_toHandler->isTimeOut()) + if (toHandler && toHandler->isTimeOut()) { - if (!m_toHandler->handleTimeOut()) + if (!toHandler->handleTimeOut()) throw exceptions::operation_timed_out(); - m_toHandler->resetTimeOut(); + toHandler->resetTimeOut(); } } } catch (...) { - SSL_free(m_ssl); - m_ssl = 0; - m_toHandler = null; throw; } - m_toHandler = null; - // Verify server's certificate(s) shared_ptr <security::cert::certificateChain> certs = getPeerCertificates(); @@ -401,6 +417,8 @@ void TLSSocket_OpenSSL::handleError(int rc) switch (sslError) { case SSL_ERROR_ZERO_RETURN: + + disconnect(); return; case SSL_ERROR_SYSCALL: @@ -413,8 +431,7 @@ void TLSSocket_OpenSSL::handleError(int rc) } else { - vmime::string msg; - std::ostringstream oss(msg); + std::ostringstream oss; oss << "The BIO reported an error: " << rc; oss.flush(); throw exceptions::tls_exception(oss.str()); diff --git a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp index 410fffcf..5fbed19d 100644 --- a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp +++ b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp @@ -58,7 +58,7 @@ public: ~TLSSocket_OpenSSL(); - void handshake(shared_ptr <timeoutHandler> toHandler = null); + void handshake(); shared_ptr <security::cert::certificateChain> getPeerCertificates() const; @@ -82,6 +82,8 @@ public: const string getPeerName() const; const string getPeerAddress() const; + shared_ptr <timeoutHandler> getTimeoutHandler(); + private: static BIO_METHOD sm_customBIOMethod; @@ -108,8 +110,6 @@ private: byte_t m_buffer[65536]; - shared_ptr <timeoutHandler> m_toHandler; - SSL* m_ssl; unsigned long m_status; |