In SSL socket, use timeout handler of underlying socket. Throw exception when reading from/writing to disconnected SSL socket.

This commit is contained in:
Vincent Richard 2014-01-19 16:36:45 +01:00
parent b6469f68a8
commit 2afe340b7b
18 changed files with 153 additions and 55 deletions

View File

@ -327,6 +327,19 @@ exception* socket_exception::clone() const { return new socket_exception(*this);
const char* socket_exception::name() const throw() { return "socket_exception"; } const char* socket_exception::name() const throw() { return "socket_exception"; }
//
// socket_not_connected_exception
//
socket_not_connected_exception::~socket_not_connected_exception() throw() {}
socket_not_connected_exception::socket_not_connected_exception(const string& what, const exception& other)
: socket_exception(what.empty()
? "Socket is not connected." : what, other) {}
exception* socket_not_connected_exception::clone() const { return new socket_not_connected_exception(*this); }
const char* socket_not_connected_exception::name() const throw() { return "socket_not_connected_exception"; }
// //
// connection_error // connection_error
// //

View File

@ -370,6 +370,23 @@ public:
}; };
/** Socket not connected: you are trying to write to/read from a socket which
* is not connected to a peer.
*/
class VMIME_EXPORT socket_not_connected_exception : public socket_exception
{
public:
socket_not_connected_exception(const string& what = "", const exception& other = NO_EXCEPTION);
~socket_not_connected_exception() throw();
exception* clone() const;
const char* name() const throw();
};
/** Error while connecting to the server: this may be a DNS resolution error /** Error while connecting to the server: this may be a DNS resolution error
* or a connection error (for example, time-out while connecting). * or a connection error (for example, time-out while connecting).
*/ */

View File

@ -482,7 +482,7 @@ void IMAPConnection::startTLS()
shared_ptr <tls::TLSSocket> tlsSocket = shared_ptr <tls::TLSSocket> tlsSocket =
tlsSession->getSocket(m_socket); tlsSession->getSocket(m_socket);
tlsSocket->handshake(m_timeoutHandler); tlsSocket->handshake();
m_socket = tlsSocket; m_socket = tlsSocket;
m_parser->setSocket(m_socket); m_parser->setSocket(m_socket);

View File

@ -552,7 +552,7 @@ void POP3Connection::startTLS()
shared_ptr <tls::TLSSocket> tlsSocket = shared_ptr <tls::TLSSocket> tlsSocket =
tlsSession->getSocket(m_socket); tlsSession->getSocket(m_socket);
tlsSocket->handshake(m_timeoutHandler); tlsSocket->handshake();
m_socket = tlsSocket; m_socket = tlsSocket;

View File

@ -141,6 +141,12 @@ public:
*/ */
virtual const string getPeerAddress() const = 0; virtual const string getPeerAddress() const = 0;
/** Return the timeout handler associated with this socket.
*
* @return timeout handler, or NULL if none is set
*/
virtual shared_ptr <timeoutHandler> getTimeoutHandler() = 0;
protected: protected:
socket() { } socket() { }

View File

@ -67,7 +67,7 @@ public:
* during the negociation process, exceptions::operation_timed_out * during the negociation process, exceptions::operation_timed_out
* if a time-out occurs * if a time-out occurs
*/ */
virtual void handshake(shared_ptr <timeoutHandler> toHandler = null) = 0; virtual void handshake() = 0;
/** Return the peer's certificate (chain) as sent by the peer. /** Return the peer's certificate (chain) as sent by the peer.
* *

View File

@ -87,11 +87,17 @@ TLSSocket_GnuTLS::~TLSSocket_GnuTLS()
void TLSSocket_GnuTLS::connect(const string& address, const port_t port) void TLSSocket_GnuTLS::connect(const string& address, const port_t port)
{ {
try
{
m_wrapped->connect(address, port); m_wrapped->connect(address, port);
handshake(null); handshake();
}
m_connected = true; catch (...)
{
disconnect();
throw;
}
} }
@ -132,6 +138,12 @@ const string TLSSocket_GnuTLS::getPeerAddress() const
} }
shared_ptr <timeoutHandler> TLSSocket_GnuTLS::getTimeoutHandler()
{
return m_wrapped->getTimeoutHandler();
}
void TLSSocket_GnuTLS::receive(string& buffer) void TLSSocket_GnuTLS::receive(string& buffer)
{ {
const size_t size = receiveRaw(m_buffer, sizeof(m_buffer)); const size_t size = receiveRaw(m_buffer, sizeof(m_buffer));
@ -239,14 +251,15 @@ unsigned int TLSSocket_GnuTLS::getStatus() const
} }
void TLSSocket_GnuTLS::handshake(shared_ptr <timeoutHandler> toHandler) void TLSSocket_GnuTLS::handshake()
{ {
shared_ptr <timeoutHandler> toHandler = m_wrapped->getTimeoutHandler();
if (toHandler) if (toHandler)
toHandler->resetTimeOut(); toHandler->resetTimeOut();
// Start handshaking process // Start handshaking process
m_handshaking = true; m_handshaking = true;
m_toHandler = toHandler;
try try
{ {
@ -280,13 +293,10 @@ void TLSSocket_GnuTLS::handshake(shared_ptr <timeoutHandler> toHandler)
catch (...) catch (...)
{ {
m_handshaking = false; m_handshaking = false;
m_toHandler = null;
throw; throw;
} }
m_handshaking = false; m_handshaking = false;
m_toHandler = null;
// Verify server's certificate(s) // Verify server's certificate(s)
shared_ptr <security::cert::certificateChain> certs = getPeerCertificates(); shared_ptr <security::cert::certificateChain> certs = getPeerCertificates();
@ -338,6 +348,8 @@ ssize_t TLSSocket_GnuTLS::gnutlsPullFunc
// returns -1 and errno is set to EGAIN... // returns -1 and errno is set to EGAIN...
if (sok->m_handshaking) if (sok->m_handshaking)
{ {
shared_ptr <timeoutHandler> toHandler = sok->m_wrapped->getTimeoutHandler();
while (true) while (true)
{ {
const ssize_t ret = static_cast <ssize_t> const ssize_t ret = static_cast <ssize_t>
@ -355,12 +367,12 @@ ssize_t TLSSocket_GnuTLS::gnutlsPullFunc
} }
// Check whether the time-out delay is elapsed // Check whether the time-out delay is elapsed
if (sok->m_toHandler && sok->m_toHandler->isTimeOut()) if (toHandler && toHandler->isTimeOut())
{ {
if (!sok->m_toHandler->handleTimeOut()) if (!toHandler->handleTimeOut())
throw exceptions::operation_timed_out(); throw exceptions::operation_timed_out();
sok->m_toHandler->resetTimeOut(); toHandler->resetTimeOut();
} }
} }
} }

View File

@ -54,7 +54,7 @@ public:
~TLSSocket_GnuTLS(); ~TLSSocket_GnuTLS();
void handshake(shared_ptr <timeoutHandler> toHandler = null); void handshake();
shared_ptr <security::cert::certificateChain> getPeerCertificates() const; shared_ptr <security::cert::certificateChain> getPeerCertificates() const;
@ -78,6 +78,8 @@ public:
const string getPeerName() const; const string getPeerName() const;
const string getPeerAddress() const; const string getPeerAddress() const;
shared_ptr <timeoutHandler> getTimeoutHandler();
private: private:
void internalThrow(); void internalThrow();
@ -99,7 +101,6 @@ private:
byte_t m_buffer[65536]; byte_t m_buffer[65536];
bool m_handshaking; bool m_handshaking;
shared_ptr <timeoutHandler> m_toHandler;
exception* m_ex; exception* m_ex;

View File

@ -87,12 +87,6 @@ TLSSocket_OpenSSL::~TLSSocket_OpenSSL()
try try
{ {
disconnect(); disconnect();
if (m_ssl)
{
SSL_free(m_ssl);
m_ssl = 0;
}
} }
catch (...) catch (...)
{ {
@ -130,20 +124,24 @@ void TLSSocket_OpenSSL::createSSLHandle()
void TLSSocket_OpenSSL::connect(const string& address, const port_t port) void TLSSocket_OpenSSL::connect(const string& address, const port_t port)
{ {
try
{
m_wrapped->connect(address, port); m_wrapped->connect(address, port);
createSSLHandle(); createSSLHandle();
handshake(null); handshake();
}
m_connected = true; catch (...)
{
disconnect();
throw;
}
} }
void TLSSocket_OpenSSL::disconnect() void TLSSocket_OpenSSL::disconnect()
{ {
if (m_connected)
{
if (m_ssl) if (m_ssl)
{ {
// Don't shut down the socket more than once. // Don't shut down the socket more than once.
@ -152,10 +150,15 @@ void TLSSocket_OpenSSL::disconnect()
if (!shutdownSent) if (!shutdownSent)
SSL_shutdown(m_ssl); SSL_shutdown(m_ssl);
SSL_free(m_ssl);
m_ssl = 0;
} }
m_wrapped->disconnect(); if (m_connected)
{
m_connected = false; m_connected = false;
m_wrapped->disconnect();
} }
} }
@ -184,6 +187,12 @@ const string TLSSocket_OpenSSL::getPeerAddress() const
} }
shared_ptr <timeoutHandler> TLSSocket_OpenSSL::getTimeoutHandler()
{
return m_wrapped->getTimeoutHandler();
}
void TLSSocket_OpenSSL::receive(string& buffer) void TLSSocket_OpenSSL::receive(string& buffer)
{ {
const size_t size = receiveRaw(m_buffer, sizeof(m_buffer)); const size_t size = receiveRaw(m_buffer, sizeof(m_buffer));
@ -209,6 +218,9 @@ void TLSSocket_OpenSSL::send(const char* str)
size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count) size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count)
{ {
if (!m_ssl)
throw exceptions::socket_not_connected_exception();
m_status &= ~STATUS_WOULDBLOCK; m_status &= ~STATUS_WOULDBLOCK;
int rc = SSL_read(m_ssl, buffer, static_cast <int>(count)); int rc = SSL_read(m_ssl, buffer, static_cast <int>(count));
@ -235,6 +247,9 @@ size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count)
void TLSSocket_OpenSSL::sendRaw(const byte_t* buffer, const size_t count) void TLSSocket_OpenSSL::sendRaw(const byte_t* buffer, const size_t count)
{ {
if (!m_ssl)
throw exceptions::socket_not_connected_exception();
m_status &= ~STATUS_WOULDBLOCK; m_status &= ~STATUS_WOULDBLOCK;
for (size_t size = count ; size > 0 ; ) for (size_t size = count ; size > 0 ; )
@ -264,6 +279,9 @@ void TLSSocket_OpenSSL::sendRaw(const byte_t* buffer, const size_t count)
size_t TLSSocket_OpenSSL::sendRawNonBlocking(const byte_t* buffer, const size_t count) size_t TLSSocket_OpenSSL::sendRawNonBlocking(const byte_t* buffer, const size_t count)
{ {
if (!m_ssl)
throw exceptions::socket_not_connected_exception();
m_status &= ~STATUS_WOULDBLOCK; m_status &= ~STATUS_WOULDBLOCK;
int rc = SSL_write(m_ssl, buffer, static_cast <int>(count)); int rc = SSL_write(m_ssl, buffer, static_cast <int>(count));
@ -288,14 +306,17 @@ size_t TLSSocket_OpenSSL::sendRawNonBlocking(const byte_t* buffer, const size_t
} }
void TLSSocket_OpenSSL::handshake(shared_ptr <timeoutHandler> toHandler) void TLSSocket_OpenSSL::handshake()
{ {
if (!m_ssl)
throw exceptions::socket_not_connected_exception();
shared_ptr <timeoutHandler> toHandler = m_wrapped->getTimeoutHandler();
if (toHandler) if (toHandler)
toHandler->resetTimeOut(); toHandler->resetTimeOut();
// Start handshaking process // Start handshaking process
m_toHandler = toHandler;
if (!m_ssl) if (!m_ssl)
createSSLHandle(); createSSLHandle();
@ -318,25 +339,20 @@ void TLSSocket_OpenSSL::handshake(shared_ptr <timeoutHandler> toHandler)
} }
// Check whether the time-out delay is elapsed // Check whether the time-out delay is elapsed
if (m_toHandler && m_toHandler->isTimeOut()) if (toHandler && toHandler->isTimeOut())
{ {
if (!m_toHandler->handleTimeOut()) if (!toHandler->handleTimeOut())
throw exceptions::operation_timed_out(); throw exceptions::operation_timed_out();
m_toHandler->resetTimeOut(); toHandler->resetTimeOut();
} }
} }
} }
catch (...) catch (...)
{ {
SSL_free(m_ssl);
m_ssl = 0;
m_toHandler = null;
throw; throw;
} }
m_toHandler = null;
// Verify server's certificate(s) // Verify server's certificate(s)
shared_ptr <security::cert::certificateChain> certs = getPeerCertificates(); shared_ptr <security::cert::certificateChain> certs = getPeerCertificates();
@ -401,6 +417,8 @@ void TLSSocket_OpenSSL::handleError(int rc)
switch (sslError) switch (sslError)
{ {
case SSL_ERROR_ZERO_RETURN: case SSL_ERROR_ZERO_RETURN:
disconnect();
return; return;
case SSL_ERROR_SYSCALL: case SSL_ERROR_SYSCALL:
@ -413,8 +431,7 @@ void TLSSocket_OpenSSL::handleError(int rc)
} }
else else
{ {
vmime::string msg; std::ostringstream oss;
std::ostringstream oss(msg);
oss << "The BIO reported an error: " << rc; oss << "The BIO reported an error: " << rc;
oss.flush(); oss.flush();
throw exceptions::tls_exception(oss.str()); throw exceptions::tls_exception(oss.str());

View File

@ -58,7 +58,7 @@ public:
~TLSSocket_OpenSSL(); ~TLSSocket_OpenSSL();
void handshake(shared_ptr <timeoutHandler> toHandler = null); void handshake();
shared_ptr <security::cert::certificateChain> getPeerCertificates() const; shared_ptr <security::cert::certificateChain> getPeerCertificates() const;
@ -82,6 +82,8 @@ public:
const string getPeerName() const; const string getPeerName() const;
const string getPeerAddress() const; const string getPeerAddress() const;
shared_ptr <timeoutHandler> getTimeoutHandler();
private: private:
static BIO_METHOD sm_customBIOMethod; static BIO_METHOD sm_customBIOMethod;
@ -108,8 +110,6 @@ private:
byte_t m_buffer[65536]; byte_t m_buffer[65536];
shared_ptr <timeoutHandler> m_toHandler;
SSL* m_ssl; SSL* m_ssl;
unsigned long m_status; unsigned long m_status;

View File

@ -654,6 +654,12 @@ unsigned int posixSocket::getStatus() const
} }
shared_ptr <net::timeoutHandler> posixSocket::getTimeoutHandler()
{
return m_timeoutHandler;
}
// //
// posixSocketFactory // posixSocketFactory

View File

@ -65,6 +65,8 @@ public:
const string getPeerName() const; const string getPeerName() const;
const string getPeerAddress() const; const string getPeerAddress() const;
shared_ptr <net::timeoutHandler> getTimeoutHandler();
protected: protected:
static void throwSocketError(const int err); static void throwSocketError(const int err);

View File

@ -460,6 +460,12 @@ void windowsSocket::waitForData(const WaitOpType t, bool& timedOut)
} }
shared_ptr <net::timeoutHandler> windowsSocket::getTimeoutHandler()
{
return m_timeoutHandler;
}
// //
// posixSocketFactory // posixSocketFactory

View File

@ -69,6 +69,8 @@ public:
const string getPeerName() const; const string getPeerName() const;
const string getPeerAddress() const; const string getPeerAddress() const;
shared_ptr <net::timeoutHandler> getTimeoutHandler();
protected: protected:
void throwSocketError(const int err); void throwSocketError(const int err);

View File

@ -96,6 +96,12 @@ const string SASLSocket::getPeerAddress() const
} }
shared_ptr <net::timeoutHandler> SASLSocket::getTimeoutHandler()
{
return m_wrapped->getTimeoutHandler();
}
void SASLSocket::receive(string& buffer) void SASLSocket::receive(string& buffer)
{ {
const size_t n = receiveRaw(m_recvBuffer, sizeof(m_recvBuffer)); const size_t n = receiveRaw(m_recvBuffer, sizeof(m_recvBuffer));

View File

@ -73,6 +73,8 @@ public:
const string getPeerName() const; const string getPeerName() const;
const string getPeerAddress() const; const string getPeerAddress() const;
shared_ptr <net::timeoutHandler> getTimeoutHandler();
private: private:
shared_ptr <SASLSession> m_session; shared_ptr <SASLSession> m_session;

View File

@ -79,6 +79,12 @@ const vmime::string testSocket::getPeerAddress() const
} }
vmime::shared_ptr <vmime::net::timeoutHandler> testSocket::getTimeoutHandler()
{
return vmime::null;
}
void testSocket::receive(vmime::string& buffer) void testSocket::receive(vmime::string& buffer)
{ {
buffer = m_inBuffer; buffer = m_inBuffer;

View File

@ -263,6 +263,8 @@ public:
const vmime::string getPeerName() const; const vmime::string getPeerName() const;
const vmime::string getPeerAddress() const; const vmime::string getPeerAddress() const;
vmime::shared_ptr <vmime::net::timeoutHandler> getTimeoutHandler();
/** Send data to client. /** Send data to client.
* *
* @param buffer data to send * @param buffer data to send