From 2afe340b7b4621b81c3bad10ead23ce50ad83baa Mon Sep 17 00:00:00 2001 From: Vincent Richard Date: Sun, 19 Jan 2014 16:36:45 +0100 Subject: [PATCH] In SSL socket, use timeout handler of underlying socket. Throw exception when reading from/writing to disconnected SSL socket. --- src/vmime/exception.cpp | 13 +++ src/vmime/exception.hpp | 17 ++++ src/vmime/net/imap/IMAPConnection.cpp | 2 +- src/vmime/net/pop3/POP3Connection.cpp | 2 +- src/vmime/net/socket.hpp | 6 ++ src/vmime/net/tls/TLSSocket.hpp | 2 +- src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp | 36 +++++--- src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp | 5 +- .../net/tls/openssl/TLSSocket_OpenSSL.cpp | 87 +++++++++++-------- .../net/tls/openssl/TLSSocket_OpenSSL.hpp | 6 +- src/vmime/platforms/posix/posixSocket.cpp | 6 ++ src/vmime/platforms/posix/posixSocket.hpp | 2 + src/vmime/platforms/windows/windowsSocket.cpp | 6 ++ src/vmime/platforms/windows/windowsSocket.hpp | 2 + src/vmime/security/sasl/SASLSocket.cpp | 6 ++ src/vmime/security/sasl/SASLSocket.hpp | 2 + tests/testUtils.cpp | 6 ++ tests/testUtils.hpp | 2 + 18 files changed, 153 insertions(+), 55 deletions(-) diff --git a/src/vmime/exception.cpp b/src/vmime/exception.cpp index 042ac4f4..dff7d497 100644 --- a/src/vmime/exception.cpp +++ b/src/vmime/exception.cpp @@ -327,6 +327,19 @@ exception* socket_exception::clone() const { return new socket_exception(*this); const char* socket_exception::name() const throw() { return "socket_exception"; } +// +// socket_not_connected_exception +// + +socket_not_connected_exception::~socket_not_connected_exception() throw() {} +socket_not_connected_exception::socket_not_connected_exception(const string& what, const exception& other) + : socket_exception(what.empty() + ? "Socket is not connected." : what, other) {} + +exception* socket_not_connected_exception::clone() const { return new socket_not_connected_exception(*this); } +const char* socket_not_connected_exception::name() const throw() { return "socket_not_connected_exception"; } + + // // connection_error // diff --git a/src/vmime/exception.hpp b/src/vmime/exception.hpp index e2afcc62..3c547756 100644 --- a/src/vmime/exception.hpp +++ b/src/vmime/exception.hpp @@ -370,6 +370,23 @@ public: }; +/** Socket not connected: you are trying to write to/read from a socket which + * is not connected to a peer. + */ + +class VMIME_EXPORT socket_not_connected_exception : public socket_exception +{ +public: + + socket_not_connected_exception(const string& what = "", const exception& other = NO_EXCEPTION); + ~socket_not_connected_exception() throw(); + + exception* clone() const; + const char* name() const throw(); + +}; + + /** Error while connecting to the server: this may be a DNS resolution error * or a connection error (for example, time-out while connecting). */ diff --git a/src/vmime/net/imap/IMAPConnection.cpp b/src/vmime/net/imap/IMAPConnection.cpp index bab4e58b..7a90b142 100644 --- a/src/vmime/net/imap/IMAPConnection.cpp +++ b/src/vmime/net/imap/IMAPConnection.cpp @@ -482,7 +482,7 @@ void IMAPConnection::startTLS() shared_ptr tlsSocket = tlsSession->getSocket(m_socket); - tlsSocket->handshake(m_timeoutHandler); + tlsSocket->handshake(); m_socket = tlsSocket; m_parser->setSocket(m_socket); diff --git a/src/vmime/net/pop3/POP3Connection.cpp b/src/vmime/net/pop3/POP3Connection.cpp index 5fa923f4..283cc91b 100644 --- a/src/vmime/net/pop3/POP3Connection.cpp +++ b/src/vmime/net/pop3/POP3Connection.cpp @@ -552,7 +552,7 @@ void POP3Connection::startTLS() shared_ptr tlsSocket = tlsSession->getSocket(m_socket); - tlsSocket->handshake(m_timeoutHandler); + tlsSocket->handshake(); m_socket = tlsSocket; diff --git a/src/vmime/net/socket.hpp b/src/vmime/net/socket.hpp index 537c34bb..7f878a73 100644 --- a/src/vmime/net/socket.hpp +++ b/src/vmime/net/socket.hpp @@ -141,6 +141,12 @@ public: */ virtual const string getPeerAddress() const = 0; + /** Return the timeout handler associated with this socket. + * + * @return timeout handler, or NULL if none is set + */ + virtual shared_ptr getTimeoutHandler() = 0; + protected: socket() { } diff --git a/src/vmime/net/tls/TLSSocket.hpp b/src/vmime/net/tls/TLSSocket.hpp index e2668ad4..ec3a83ef 100644 --- a/src/vmime/net/tls/TLSSocket.hpp +++ b/src/vmime/net/tls/TLSSocket.hpp @@ -67,7 +67,7 @@ public: * during the negociation process, exceptions::operation_timed_out * if a time-out occurs */ - virtual void handshake(shared_ptr toHandler = null) = 0; + virtual void handshake() = 0; /** Return the peer's certificate (chain) as sent by the peer. * diff --git a/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp b/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp index edc8811a..13b7eb24 100644 --- a/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp +++ b/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp @@ -87,11 +87,17 @@ TLSSocket_GnuTLS::~TLSSocket_GnuTLS() void TLSSocket_GnuTLS::connect(const string& address, const port_t port) { - m_wrapped->connect(address, port); + try + { + m_wrapped->connect(address, port); - handshake(null); - - m_connected = true; + handshake(); + } + catch (...) + { + disconnect(); + throw; + } } @@ -132,6 +138,12 @@ const string TLSSocket_GnuTLS::getPeerAddress() const } +shared_ptr TLSSocket_GnuTLS::getTimeoutHandler() +{ + return m_wrapped->getTimeoutHandler(); +} + + void TLSSocket_GnuTLS::receive(string& buffer) { const size_t size = receiveRaw(m_buffer, sizeof(m_buffer)); @@ -239,14 +251,15 @@ unsigned int TLSSocket_GnuTLS::getStatus() const } -void TLSSocket_GnuTLS::handshake(shared_ptr toHandler) +void TLSSocket_GnuTLS::handshake() { + shared_ptr toHandler = m_wrapped->getTimeoutHandler(); + if (toHandler) toHandler->resetTimeOut(); // Start handshaking process m_handshaking = true; - m_toHandler = toHandler; try { @@ -280,13 +293,10 @@ void TLSSocket_GnuTLS::handshake(shared_ptr toHandler) catch (...) { m_handshaking = false; - m_toHandler = null; - throw; } m_handshaking = false; - m_toHandler = null; // Verify server's certificate(s) shared_ptr certs = getPeerCertificates(); @@ -338,6 +348,8 @@ ssize_t TLSSocket_GnuTLS::gnutlsPullFunc // returns -1 and errno is set to EGAIN... if (sok->m_handshaking) { + shared_ptr toHandler = sok->m_wrapped->getTimeoutHandler(); + while (true) { const ssize_t ret = static_cast @@ -355,12 +367,12 @@ ssize_t TLSSocket_GnuTLS::gnutlsPullFunc } // Check whether the time-out delay is elapsed - if (sok->m_toHandler && sok->m_toHandler->isTimeOut()) + if (toHandler && toHandler->isTimeOut()) { - if (!sok->m_toHandler->handleTimeOut()) + if (!toHandler->handleTimeOut()) throw exceptions::operation_timed_out(); - sok->m_toHandler->resetTimeOut(); + toHandler->resetTimeOut(); } } } diff --git a/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp b/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp index 885fac13..ddba9d0e 100644 --- a/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp +++ b/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp @@ -54,7 +54,7 @@ public: ~TLSSocket_GnuTLS(); - void handshake(shared_ptr toHandler = null); + void handshake(); shared_ptr getPeerCertificates() const; @@ -78,6 +78,8 @@ public: const string getPeerName() const; const string getPeerAddress() const; + shared_ptr getTimeoutHandler(); + private: void internalThrow(); @@ -99,7 +101,6 @@ private: byte_t m_buffer[65536]; bool m_handshaking; - shared_ptr m_toHandler; exception* m_ex; diff --git a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp index 9857b8fb..595a0091 100644 --- a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp +++ b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp @@ -87,12 +87,6 @@ TLSSocket_OpenSSL::~TLSSocket_OpenSSL() try { disconnect(); - - if (m_ssl) - { - SSL_free(m_ssl); - m_ssl = 0; - } } catch (...) { @@ -130,32 +124,41 @@ void TLSSocket_OpenSSL::createSSLHandle() void TLSSocket_OpenSSL::connect(const string& address, const port_t port) { - m_wrapped->connect(address, port); + try + { + m_wrapped->connect(address, port); - createSSLHandle(); + createSSLHandle(); - handshake(null); - - m_connected = true; + handshake(); + } + catch (...) + { + disconnect(); + throw; + } } void TLSSocket_OpenSSL::disconnect() { + if (m_ssl) + { + // Don't shut down the socket more than once. + int shutdownState = SSL_get_shutdown(m_ssl); + bool shutdownSent = (shutdownState & SSL_SENT_SHUTDOWN) == SSL_SENT_SHUTDOWN; + + if (!shutdownSent) + SSL_shutdown(m_ssl); + + SSL_free(m_ssl); + m_ssl = 0; + } + if (m_connected) { - if (m_ssl) - { - // Don't shut down the socket more than once. - int shutdownState = SSL_get_shutdown(m_ssl); - bool shutdownSent = (shutdownState & SSL_SENT_SHUTDOWN) == SSL_SENT_SHUTDOWN; - - if (!shutdownSent) - SSL_shutdown(m_ssl); - } - - m_wrapped->disconnect(); m_connected = false; + m_wrapped->disconnect(); } } @@ -184,6 +187,12 @@ const string TLSSocket_OpenSSL::getPeerAddress() const } +shared_ptr TLSSocket_OpenSSL::getTimeoutHandler() +{ + return m_wrapped->getTimeoutHandler(); +} + + void TLSSocket_OpenSSL::receive(string& buffer) { const size_t size = receiveRaw(m_buffer, sizeof(m_buffer)); @@ -209,6 +218,9 @@ void TLSSocket_OpenSSL::send(const char* str) size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count) { + if (!m_ssl) + throw exceptions::socket_not_connected_exception(); + m_status &= ~STATUS_WOULDBLOCK; int rc = SSL_read(m_ssl, buffer, static_cast (count)); @@ -235,6 +247,9 @@ size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count) void TLSSocket_OpenSSL::sendRaw(const byte_t* buffer, const size_t count) { + if (!m_ssl) + throw exceptions::socket_not_connected_exception(); + m_status &= ~STATUS_WOULDBLOCK; for (size_t size = count ; size > 0 ; ) @@ -264,6 +279,9 @@ void TLSSocket_OpenSSL::sendRaw(const byte_t* buffer, const size_t count) size_t TLSSocket_OpenSSL::sendRawNonBlocking(const byte_t* buffer, const size_t count) { + if (!m_ssl) + throw exceptions::socket_not_connected_exception(); + m_status &= ~STATUS_WOULDBLOCK; int rc = SSL_write(m_ssl, buffer, static_cast (count)); @@ -288,14 +306,17 @@ size_t TLSSocket_OpenSSL::sendRawNonBlocking(const byte_t* buffer, const size_t } -void TLSSocket_OpenSSL::handshake(shared_ptr toHandler) +void TLSSocket_OpenSSL::handshake() { + if (!m_ssl) + throw exceptions::socket_not_connected_exception(); + + shared_ptr toHandler = m_wrapped->getTimeoutHandler(); + if (toHandler) toHandler->resetTimeOut(); // Start handshaking process - m_toHandler = toHandler; - if (!m_ssl) createSSLHandle(); @@ -318,25 +339,20 @@ void TLSSocket_OpenSSL::handshake(shared_ptr toHandler) } // Check whether the time-out delay is elapsed - if (m_toHandler && m_toHandler->isTimeOut()) + if (toHandler && toHandler->isTimeOut()) { - if (!m_toHandler->handleTimeOut()) + if (!toHandler->handleTimeOut()) throw exceptions::operation_timed_out(); - m_toHandler->resetTimeOut(); + toHandler->resetTimeOut(); } } } catch (...) { - SSL_free(m_ssl); - m_ssl = 0; - m_toHandler = null; throw; } - m_toHandler = null; - // Verify server's certificate(s) shared_ptr certs = getPeerCertificates(); @@ -401,6 +417,8 @@ void TLSSocket_OpenSSL::handleError(int rc) switch (sslError) { case SSL_ERROR_ZERO_RETURN: + + disconnect(); return; case SSL_ERROR_SYSCALL: @@ -413,8 +431,7 @@ void TLSSocket_OpenSSL::handleError(int rc) } else { - vmime::string msg; - std::ostringstream oss(msg); + std::ostringstream oss; oss << "The BIO reported an error: " << rc; oss.flush(); throw exceptions::tls_exception(oss.str()); diff --git a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp index 410fffcf..5fbed19d 100644 --- a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp +++ b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp @@ -58,7 +58,7 @@ public: ~TLSSocket_OpenSSL(); - void handshake(shared_ptr toHandler = null); + void handshake(); shared_ptr getPeerCertificates() const; @@ -82,6 +82,8 @@ public: const string getPeerName() const; const string getPeerAddress() const; + shared_ptr getTimeoutHandler(); + private: static BIO_METHOD sm_customBIOMethod; @@ -108,8 +110,6 @@ private: byte_t m_buffer[65536]; - shared_ptr m_toHandler; - SSL* m_ssl; unsigned long m_status; diff --git a/src/vmime/platforms/posix/posixSocket.cpp b/src/vmime/platforms/posix/posixSocket.cpp index ab434116..e7eba9f1 100644 --- a/src/vmime/platforms/posix/posixSocket.cpp +++ b/src/vmime/platforms/posix/posixSocket.cpp @@ -654,6 +654,12 @@ unsigned int posixSocket::getStatus() const } +shared_ptr posixSocket::getTimeoutHandler() +{ + return m_timeoutHandler; +} + + // // posixSocketFactory diff --git a/src/vmime/platforms/posix/posixSocket.hpp b/src/vmime/platforms/posix/posixSocket.hpp index 4ec3edec..5d29d710 100644 --- a/src/vmime/platforms/posix/posixSocket.hpp +++ b/src/vmime/platforms/posix/posixSocket.hpp @@ -65,6 +65,8 @@ public: const string getPeerName() const; const string getPeerAddress() const; + shared_ptr getTimeoutHandler(); + protected: static void throwSocketError(const int err); diff --git a/src/vmime/platforms/windows/windowsSocket.cpp b/src/vmime/platforms/windows/windowsSocket.cpp index cb96481a..bd20e5d4 100644 --- a/src/vmime/platforms/windows/windowsSocket.cpp +++ b/src/vmime/platforms/windows/windowsSocket.cpp @@ -460,6 +460,12 @@ void windowsSocket::waitForData(const WaitOpType t, bool& timedOut) } +shared_ptr windowsSocket::getTimeoutHandler() +{ + return m_timeoutHandler; +} + + // // posixSocketFactory diff --git a/src/vmime/platforms/windows/windowsSocket.hpp b/src/vmime/platforms/windows/windowsSocket.hpp index cb8a6e67..31e1488b 100644 --- a/src/vmime/platforms/windows/windowsSocket.hpp +++ b/src/vmime/platforms/windows/windowsSocket.hpp @@ -69,6 +69,8 @@ public: const string getPeerName() const; const string getPeerAddress() const; + shared_ptr getTimeoutHandler(); + protected: void throwSocketError(const int err); diff --git a/src/vmime/security/sasl/SASLSocket.cpp b/src/vmime/security/sasl/SASLSocket.cpp index 12d634c2..541fc904 100644 --- a/src/vmime/security/sasl/SASLSocket.cpp +++ b/src/vmime/security/sasl/SASLSocket.cpp @@ -96,6 +96,12 @@ const string SASLSocket::getPeerAddress() const } +shared_ptr SASLSocket::getTimeoutHandler() +{ + return m_wrapped->getTimeoutHandler(); +} + + void SASLSocket::receive(string& buffer) { const size_t n = receiveRaw(m_recvBuffer, sizeof(m_recvBuffer)); diff --git a/src/vmime/security/sasl/SASLSocket.hpp b/src/vmime/security/sasl/SASLSocket.hpp index e52911b4..d2d82411 100644 --- a/src/vmime/security/sasl/SASLSocket.hpp +++ b/src/vmime/security/sasl/SASLSocket.hpp @@ -73,6 +73,8 @@ public: const string getPeerName() const; const string getPeerAddress() const; + shared_ptr getTimeoutHandler(); + private: shared_ptr m_session; diff --git a/tests/testUtils.cpp b/tests/testUtils.cpp index 437b476b..ee642bea 100644 --- a/tests/testUtils.cpp +++ b/tests/testUtils.cpp @@ -79,6 +79,12 @@ const vmime::string testSocket::getPeerAddress() const } +vmime::shared_ptr testSocket::getTimeoutHandler() +{ + return vmime::null; +} + + void testSocket::receive(vmime::string& buffer) { buffer = m_inBuffer; diff --git a/tests/testUtils.hpp b/tests/testUtils.hpp index f0ecf454..9e72158a 100644 --- a/tests/testUtils.hpp +++ b/tests/testUtils.hpp @@ -263,6 +263,8 @@ public: const vmime::string getPeerName() const; const vmime::string getPeerAddress() const; + vmime::shared_ptr getTimeoutHandler(); + /** Send data to client. * * @param buffer data to send