aboutsummaryrefslogtreecommitdiffstats
path: root/src/vmime/net/tls
diff options
context:
space:
mode:
Diffstat (limited to 'src/vmime/net/tls')
-rw-r--r--src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp123
-rw-r--r--src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp5
-rw-r--r--src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp62
-rw-r--r--src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp5
4 files changed, 115 insertions, 80 deletions
diff --git a/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp b/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp
index 13b7eb24..3832326c 100644
--- a/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp
+++ b/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.cpp
@@ -30,6 +30,8 @@
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
+#include <errno.h>
+
#include "vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp"
#include "vmime/net/tls/gnutls/TLSSession_GnuTLS.hpp"
@@ -57,7 +59,7 @@ shared_ptr <TLSSocket> TLSSocket::wrap(shared_ptr <TLSSession> session, shared_p
TLSSocket_GnuTLS::TLSSocket_GnuTLS(shared_ptr <TLSSession_GnuTLS> session, shared_ptr <socket> sok)
: m_session(session), m_wrapped(sok), m_connected(false),
- m_handshaking(false), m_ex(NULL), m_status(0)
+ m_ex(NULL), m_status(0)
{
gnutls_transport_set_ptr(*m_session->m_gnutlsSession, this);
@@ -144,6 +146,18 @@ shared_ptr <timeoutHandler> TLSSocket_GnuTLS::getTimeoutHandler()
}
+bool TLSSocket_GnuTLS::waitForRead(const int msecs)
+{
+ return m_wrapped->waitForRead(msecs);
+}
+
+
+bool TLSSocket_GnuTLS::waitForWrite(const int msecs)
+{
+ return m_wrapped->waitForWrite(msecs);
+}
+
+
void TLSSocket_GnuTLS::receive(string& buffer)
{
const size_t size = receiveRaw(m_buffer, sizeof(m_buffer));
@@ -165,7 +179,7 @@ void TLSSocket_GnuTLS::send(const char* str)
size_t TLSSocket_GnuTLS::receiveRaw(byte_t* buffer, const size_t count)
{
- m_status &= ~STATUS_WOULDBLOCK;
+ m_status &= ~(STATUS_WANT_WRITE | STATUS_WANT_READ);
const ssize_t ret = gnutls_record_recv
(*m_session->m_gnutlsSession,
@@ -178,7 +192,11 @@ size_t TLSSocket_GnuTLS::receiveRaw(byte_t* buffer, const size_t count)
{
if (ret == GNUTLS_E_AGAIN)
{
- m_status |= STATUS_WOULDBLOCK;
+ if (gnutls_record_get_direction(*m_session->m_gnutlsSession) == 0)
+ m_status |= STATUS_WANT_READ;
+ else
+ m_status |= STATUS_WANT_WRITE;
+
return 0;
}
@@ -191,7 +209,7 @@ size_t TLSSocket_GnuTLS::receiveRaw(byte_t* buffer, const size_t count)
void TLSSocket_GnuTLS::sendRaw(const byte_t* buffer, const size_t count)
{
- m_status &= ~STATUS_WOULDBLOCK;
+ m_status &= ~(STATUS_WANT_WRITE | STATUS_WANT_READ);
for (size_t size = count ; size > 0 ; )
{
@@ -206,7 +224,11 @@ void TLSSocket_GnuTLS::sendRaw(const byte_t* buffer, const size_t count)
{
if (ret == GNUTLS_E_AGAIN)
{
- platform::getHandler()->wait();
+ if (gnutls_record_get_direction(*m_session->m_gnutlsSession) == 0)
+ m_wrapped->waitForRead();
+ else
+ m_wrapped->waitForWrite();
+
continue;
}
@@ -223,6 +245,8 @@ void TLSSocket_GnuTLS::sendRaw(const byte_t* buffer, const size_t count)
size_t TLSSocket_GnuTLS::sendRawNonBlocking(const byte_t* buffer, const size_t count)
{
+ m_status &= ~(STATUS_WANT_WRITE | STATUS_WANT_READ);
+
ssize_t ret = gnutls_record_send
(*m_session->m_gnutlsSession,
buffer, static_cast <size_t>(count));
@@ -234,7 +258,11 @@ size_t TLSSocket_GnuTLS::sendRawNonBlocking(const byte_t* buffer, const size_t c
{
if (ret == GNUTLS_E_AGAIN)
{
- m_status |= STATUS_WOULDBLOCK;
+ if (gnutls_record_get_direction(*m_session->m_gnutlsSession) == 0)
+ m_status |= STATUS_WANT_READ;
+ else
+ m_status |= STATUS_WANT_WRITE;
+
return 0;
}
@@ -259,8 +287,6 @@ void TLSSocket_GnuTLS::handshake()
toHandler->resetTimeOut();
// Start handshaking process
- m_handshaking = true;
-
try
{
while (true)
@@ -272,11 +298,17 @@ void TLSSocket_GnuTLS::handshake()
if (ret < 0)
{
- if (ret == GNUTLS_E_AGAIN ||
- ret == GNUTLS_E_INTERRUPTED)
+ if (ret == GNUTLS_E_AGAIN)
+ {
+ if (gnutls_record_get_direction(*m_session->m_gnutlsSession) == 0)
+ m_wrapped->waitForRead();
+ else
+ m_wrapped->waitForWrite();
+ }
+ else if (ret == GNUTLS_E_INTERRUPTED)
{
// Non-fatal error
- platform::getHandler()->wait();
+ m_wrapped->waitForRead();
}
else
{
@@ -292,12 +324,9 @@ void TLSSocket_GnuTLS::handshake()
}
catch (...)
{
- m_handshaking = false;
throw;
}
- m_handshaking = false;
-
// Verify server's certificate(s)
shared_ptr <security::cert::certificateChain> certs = getPeerCertificates();
@@ -321,14 +350,21 @@ ssize_t TLSSocket_GnuTLS::gnutlsPushFunc
(sok->m_wrapped->sendRawNonBlocking
(reinterpret_cast <const byte_t*>(data), len));
- if (ret == 0 && sok->m_wrapped->getStatus() & socket::STATUS_WOULDBLOCK)
- return GNUTLS_E_AGAIN;
+ if (ret == 0)
+ {
+ if (sok->m_wrapped->getStatus() & socket::STATUS_WOULDBLOCK)
+ gnutls_transport_set_errno(*sok->m_session->m_gnutlsSession, EAGAIN);
+ else
+ gnutls_transport_set_errno(*sok->m_session->m_gnutlsSession, 0);
+
+ return -1;
+ }
return ret;
}
catch (exception& e)
{
- // Workaround for bad behaviour when throwing C++ exceptions
+ // Workaround for non-portable behaviour when throwing C++ exceptions
// from C functions (GNU TLS)
sok->m_ex = e.clone();
return -1;
@@ -343,54 +379,25 @@ ssize_t TLSSocket_GnuTLS::gnutlsPullFunc
try
{
- // Workaround for cross-platform asynchronous handshaking:
- // gnutls_handshake() only returns GNUTLS_E_AGAIN if recv()
- // returns -1 and errno is set to EGAIN...
- if (sok->m_handshaking)
- {
- shared_ptr <timeoutHandler> toHandler = sok->m_wrapped->getTimeoutHandler();
-
- while (true)
- {
- const ssize_t ret = static_cast <ssize_t>
- (sok->m_wrapped->receiveRaw
- (reinterpret_cast <byte_t*>(data), len));
+ const ssize_t n = static_cast <ssize_t>
+ (sok->m_wrapped->receiveRaw
+ (reinterpret_cast <byte_t*>(data), len));
- if (ret == 0)
- {
- // No data available yet
- platform::getHandler()->wait();
- }
- else
- {
- return ret;
- }
-
- // Check whether the time-out delay is elapsed
- if (toHandler && toHandler->isTimeOut())
- {
- if (!toHandler->handleTimeOut())
- throw exceptions::operation_timed_out();
-
- toHandler->resetTimeOut();
- }
- }
- }
- else
+ if (n == 0)
{
- const ssize_t n = static_cast <ssize_t>
- (sok->m_wrapped->receiveRaw
- (reinterpret_cast <byte_t*>(data), len));
-
- if (n == 0 && sok->m_wrapped->getStatus() & socket::STATUS_WOULDBLOCK)
- return GNUTLS_E_AGAIN;
+ if (sok->m_wrapped->getStatus() & socket::STATUS_WOULDBLOCK)
+ gnutls_transport_set_errno(*sok->m_session->m_gnutlsSession, EAGAIN);
+ else
+ gnutls_transport_set_errno(*sok->m_session->m_gnutlsSession, 0);
- return n;
+ return -1;
}
+
+ return n;
}
catch (exception& e)
{
- // Workaround for bad behaviour when throwing C++ exceptions
+ // Workaround for non-portable behaviour when throwing C++ exceptions
// from C functions (GNU TLS)
sok->m_ex = e.clone();
return -1;
diff --git a/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp b/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp
index ddba9d0e..faa3a423 100644
--- a/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp
+++ b/src/vmime/net/tls/gnutls/TLSSocket_GnuTLS.hpp
@@ -63,6 +63,9 @@ public:
void disconnect();
bool isConnected() const;
+ bool waitForRead(const int msecs = 30000);
+ bool waitForWrite(const int msecs = 30000);
+
void receive(string& buffer);
size_t receiveRaw(byte_t* buffer, const size_t count);
@@ -100,8 +103,6 @@ private:
byte_t m_buffer[65536];
- bool m_handshaking;
-
exception* m_ex;
unsigned int m_status;
diff --git a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp
index 595a0091..bec41612 100644
--- a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp
+++ b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.cpp
@@ -113,7 +113,7 @@ void TLSSocket_OpenSSL::createSSLHandle()
SSL_set_bio(m_ssl, sockBio, sockBio);
SSL_set_connect_state(m_ssl);
- SSL_set_mode(m_ssl, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
+ SSL_set_mode(m_ssl, SSL_MODE_AUTO_RETRY | SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
}
else
{
@@ -193,6 +193,18 @@ shared_ptr <timeoutHandler> TLSSocket_OpenSSL::getTimeoutHandler()
}
+bool TLSSocket_OpenSSL::waitForRead(const int msecs)
+{
+ return m_wrapped->waitForRead(msecs);
+}
+
+
+bool TLSSocket_OpenSSL::waitForWrite(const int msecs)
+{
+ return m_wrapped->waitForWrite(msecs);
+}
+
+
void TLSSocket_OpenSSL::receive(string& buffer)
{
const size_t size = receiveRaw(m_buffer, sizeof(m_buffer));
@@ -221,7 +233,7 @@ size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count)
if (!m_ssl)
throw exceptions::socket_not_connected_exception();
- m_status &= ~STATUS_WOULDBLOCK;
+ m_status &= ~(STATUS_WANT_WRITE | STATUS_WANT_READ);
int rc = SSL_read(m_ssl, buffer, static_cast <int>(count));
@@ -232,9 +244,14 @@ size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count)
{
int error = SSL_get_error(m_ssl, rc);
- if (error == SSL_ERROR_WANT_WRITE || error == SSL_ERROR_WANT_READ)
+ if (error == SSL_ERROR_WANT_WRITE)
+ {
+ m_status |= STATUS_WANT_WRITE;
+ return 0;
+ }
+ else if (error == SSL_ERROR_WANT_READ)
{
- m_status |= STATUS_WOULDBLOCK;
+ m_status |= STATUS_WANT_READ;
return 0;
}
@@ -250,7 +267,7 @@ void TLSSocket_OpenSSL::sendRaw(const byte_t* buffer, const size_t count)
if (!m_ssl)
throw exceptions::socket_not_connected_exception();
- m_status &= ~STATUS_WOULDBLOCK;
+ m_status &= ~(STATUS_WANT_WRITE | STATUS_WANT_READ);
for (size_t size = count ; size > 0 ; )
{
@@ -260,9 +277,14 @@ void TLSSocket_OpenSSL::sendRaw(const byte_t* buffer, const size_t count)
{
int error = SSL_get_error(m_ssl, rc);
- if (error == SSL_ERROR_WANT_WRITE || error == SSL_ERROR_WANT_READ)
+ if (error == SSL_ERROR_WANT_READ)
+ {
+ m_wrapped->waitForRead();
+ continue;
+ }
+ else if (error == SSL_ERROR_WANT_WRITE)
{
- platform::getHandler()->wait();
+ m_wrapped->waitForWrite();
continue;
}
@@ -282,7 +304,7 @@ size_t TLSSocket_OpenSSL::sendRawNonBlocking(const byte_t* buffer, const size_t
if (!m_ssl)
throw exceptions::socket_not_connected_exception();
- m_status &= ~STATUS_WOULDBLOCK;
+ m_status &= ~(STATUS_WANT_WRITE | STATUS_WANT_READ);
int rc = SSL_write(m_ssl, buffer, static_cast <int>(count));
@@ -293,9 +315,14 @@ size_t TLSSocket_OpenSSL::sendRawNonBlocking(const byte_t* buffer, const size_t
{
int error = SSL_get_error(m_ssl, rc);
- if (error == SSL_ERROR_WANT_WRITE || error == SSL_ERROR_WANT_READ)
+ if (error == SSL_ERROR_WANT_WRITE)
{
- m_status |= STATUS_WOULDBLOCK;
+ m_status |= STATUS_WANT_WRITE;
+ return 0;
+ }
+ else if (error == SSL_ERROR_WANT_READ)
+ {
+ m_status |= STATUS_WANT_READ;
return 0;
}
@@ -328,15 +355,12 @@ void TLSSocket_OpenSSL::handshake()
{
const int err = SSL_get_error(m_ssl, rc);
- if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
- {
- // No data available yet
- platform::getHandler()->wait();
- }
- else
- {
- handleError(rc);
- }
+ if (err == SSL_ERROR_WANT_READ)
+ m_wrapped->waitForRead();
+ else if (err == SSL_ERROR_WANT_WRITE)
+ m_wrapped->waitForWrite();
+ else
+ handleError(rc);
// Check whether the time-out delay is elapsed
if (toHandler && toHandler->isTimeOut())
diff --git a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp
index 5fbed19d..20712263 100644
--- a/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp
+++ b/src/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp
@@ -67,6 +67,9 @@ public:
void disconnect();
bool isConnected() const;
+ bool waitForRead(const int msecs = 30000);
+ bool waitForWrite(const int msecs = 30000);
+
void receive(string& buffer);
size_t receiveRaw(byte_t* buffer, const size_t count);
@@ -115,7 +118,7 @@ private:
unsigned long m_status;
// Last exception thrown from C BIO functions
- std::auto_ptr <std::exception> m_ex;
+ std::auto_ptr <exception> m_ex;
};