diff options
Diffstat (limited to 'src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp')
-rw-r--r-- | src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp | 85 |
1 files changed, 51 insertions, 34 deletions
diff --git a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp index 9857b8fb..595a0091 100644 --- a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp +++ b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp @@ -87,12 +87,6 @@ TLSSocket_OpenSSL::~TLSSocket_OpenSSL() try { disconnect(); - - if (m_ssl) - { - SSL_free(m_ssl); - m_ssl = 0; - } } catch (...) { @@ -130,32 +124,41 @@ void TLSSocket_OpenSSL::createSSLHandle() void TLSSocket_OpenSSL::connect(const string& address, const port_t port) { - m_wrapped->connect(address, port); - - createSSLHandle(); + try + { + m_wrapped->connect(address, port); - handshake(null); + createSSLHandle(); - m_connected = true; + handshake(); + } + catch (...) + { + disconnect(); + throw; + } } void TLSSocket_OpenSSL::disconnect() { - if (m_connected) + if (m_ssl) { - if (m_ssl) - { - // Don't shut down the socket more than once. - int shutdownState = SSL_get_shutdown(m_ssl); - bool shutdownSent = (shutdownState & SSL_SENT_SHUTDOWN) == SSL_SENT_SHUTDOWN; + // Don't shut down the socket more than once. + int shutdownState = SSL_get_shutdown(m_ssl); + bool shutdownSent = (shutdownState & SSL_SENT_SHUTDOWN) == SSL_SENT_SHUTDOWN; - if (!shutdownSent) - SSL_shutdown(m_ssl); - } + if (!shutdownSent) + SSL_shutdown(m_ssl); - m_wrapped->disconnect(); + SSL_free(m_ssl); + m_ssl = 0; + } + + if (m_connected) + { m_connected = false; + m_wrapped->disconnect(); } } @@ -184,6 +187,12 @@ const string TLSSocket_OpenSSL::getPeerAddress() const } +shared_ptr <timeoutHandler> TLSSocket_OpenSSL::getTimeoutHandler() +{ + return m_wrapped->getTimeoutHandler(); +} + + void TLSSocket_OpenSSL::receive(string& buffer) { const size_t size = receiveRaw(m_buffer, sizeof(m_buffer)); @@ -209,6 +218,9 @@ void TLSSocket_OpenSSL::send(const char* str) size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count) { + if (!m_ssl) + throw exceptions::socket_not_connected_exception(); + m_status &= ~STATUS_WOULDBLOCK; int rc = SSL_read(m_ssl, buffer, static_cast <int>(count)); @@ -235,6 +247,9 @@ size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count) void TLSSocket_OpenSSL::sendRaw(const byte_t* buffer, const size_t count) { + if (!m_ssl) + throw exceptions::socket_not_connected_exception(); + m_status &= ~STATUS_WOULDBLOCK; for (size_t size = count ; size > 0 ; ) @@ -264,6 +279,9 @@ void TLSSocket_OpenSSL::sendRaw(const byte_t* buffer, const size_t count) size_t TLSSocket_OpenSSL::sendRawNonBlocking(const byte_t* buffer, const size_t count) { + if (!m_ssl) + throw exceptions::socket_not_connected_exception(); + m_status &= ~STATUS_WOULDBLOCK; int rc = SSL_write(m_ssl, buffer, static_cast <int>(count)); @@ -288,14 +306,17 @@ size_t TLSSocket_OpenSSL::sendRawNonBlocking(const byte_t* buffer, const size_t } -void TLSSocket_OpenSSL::handshake(shared_ptr <timeoutHandler> toHandler) +void TLSSocket_OpenSSL::handshake() { + if (!m_ssl) + throw exceptions::socket_not_connected_exception(); + + shared_ptr <timeoutHandler> toHandler = m_wrapped->getTimeoutHandler(); + if (toHandler) toHandler->resetTimeOut(); // Start handshaking process - m_toHandler = toHandler; - if (!m_ssl) createSSLHandle(); @@ -318,25 +339,20 @@ void TLSSocket_OpenSSL::handshake(shared_ptr <timeoutHandler> toHandler) } // Check whether the time-out delay is elapsed - if (m_toHandler && m_toHandler->isTimeOut()) + if (toHandler && toHandler->isTimeOut()) { - if (!m_toHandler->handleTimeOut()) + if (!toHandler->handleTimeOut()) throw exceptions::operation_timed_out(); - m_toHandler->resetTimeOut(); + toHandler->resetTimeOut(); } } } catch (...) { - SSL_free(m_ssl); - m_ssl = 0; - m_toHandler = null; throw; } - m_toHandler = null; - // Verify server's certificate(s) shared_ptr <security::cert::certificateChain> certs = getPeerCertificates(); @@ -401,6 +417,8 @@ void TLSSocket_OpenSSL::handleError(int rc) switch (sslError) { case SSL_ERROR_ZERO_RETURN: + + disconnect(); return; case SSL_ERROR_SYSCALL: @@ -413,8 +431,7 @@ void TLSSocket_OpenSSL::handleError(int rc) } else { - vmime::string msg; - std::ostringstream oss(msg); + std::ostringstream oss; oss << "The BIO reported an error: " << rc; oss.flush(); throw exceptions::tls_exception(oss.str()); |