diff --git a/src/net/tls/openssl/TLSSocket_OpenSSL.cpp b/src/net/tls/openssl/TLSSocket_OpenSSL.cpp index e08041bf..ef6647d6 100644 --- a/src/net/tls/openssl/TLSSocket_OpenSSL.cpp +++ b/src/net/tls/openssl/TLSSocket_OpenSSL.cpp @@ -60,7 +60,7 @@ BIO_METHOD TLSSocket_OpenSSL::sm_customBIOMethod = TLSSocket_OpenSSL::bio_write, TLSSocket_OpenSSL::bio_read, TLSSocket_OpenSSL::bio_puts, - TLSSocket_OpenSSL::bio_gets, + NULL, // gets TLSSocket_OpenSSL::bio_ctrl, TLSSocket_OpenSSL::bio_create, TLSSocket_OpenSSL::bio_destroy, @@ -77,7 +77,7 @@ shared_ptr TLSSocket::wrap(shared_ptr session, shared_p TLSSocket_OpenSSL::TLSSocket_OpenSSL(shared_ptr session, shared_ptr sok) - : m_session(session), m_wrapped(sok), m_connected(false), m_ssl(0), m_ex(NULL) + : m_session(session), m_wrapped(sok), m_connected(false), m_ssl(0), m_status(0), m_ex(NULL) { } @@ -107,6 +107,7 @@ void TLSSocket_OpenSSL::createSSLHandle() { BIO* sockBio = BIO_new(&sm_customBIOMethod); sockBio->ptr = this; + sockBio->init = 1; m_ssl = SSL_new(m_session->getContext()); @@ -207,11 +208,25 @@ void TLSSocket_OpenSSL::send(const char* str) size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count) { - int rc = SSL_read(m_ssl, buffer, static_cast (count)); - handleError(rc); + m_status &= ~STATUS_WOULDBLOCK; - if (rc < 0) - return 0; + int rc = SSL_read(m_ssl, buffer, static_cast (count)); + + if (m_ex.get()) + internalThrow(); + + if (rc <= 0) + { + int error = SSL_get_error(m_ssl, rc); + + if (error == SSL_ERROR_WANT_WRITE || error == SSL_ERROR_WANT_READ) + { + m_status |= STATUS_WOULDBLOCK; + return 0; + } + + handleError(rc); + } return rc; } @@ -219,18 +234,31 @@ size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count) void TLSSocket_OpenSSL::sendRaw(const byte_t* buffer, const size_t count) { - int rc = SSL_write(m_ssl, buffer, static_cast (count)); - handleError(rc); + sendRawNonBlocking(buffer, count); } size_t TLSSocket_OpenSSL::sendRawNonBlocking(const byte_t* buffer, const size_t count) { - int rc = SSL_write(m_ssl, buffer, static_cast (count)); - handleError(rc); + m_status &= ~STATUS_WOULDBLOCK; - if (rc < 0) - rc = 0; + int rc = SSL_write(m_ssl, buffer, static_cast (count)); + + if (m_ex.get()) + internalThrow(); + + if (rc <= 0) + { + int error = SSL_get_error(m_ssl, rc); + + if (error == SSL_ERROR_WANT_WRITE || error == SSL_ERROR_WANT_READ) + { + m_status |= STATUS_WOULDBLOCK; + return 0; + } + + handleError(rc); + } return rc; } @@ -249,9 +277,31 @@ void TLSSocket_OpenSSL::handshake(shared_ptr toHandler) try { - // int ret = SSL_connect(m_ssl); - int ret = SSL_do_handshake(m_ssl); - handleError(ret); + int rc; + + while ((rc = SSL_do_handshake(m_ssl)) <= 0) + { + const int err = SSL_get_error(m_ssl, rc); + + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) + { + // No data available yet + platform::getHandler()->wait(); + } + else + { + handleError(rc); + } + + // Check whether the time-out delay is elapsed + if (m_toHandler && m_toHandler->isTimeOut()) + { + if (!m_toHandler->handleTimeOut()) + throw exceptions::operation_timed_out(); + + m_toHandler->resetTimeOut(); + } + } } catch (...) { @@ -386,7 +436,7 @@ void TLSSocket_OpenSSL::handleError(int rc) unsigned int TLSSocket_OpenSSL::getStatus() const { - return m_wrapped->getStatus(); + return m_status; } @@ -396,23 +446,28 @@ unsigned int TLSSocket_OpenSSL::getStatus() const // static int TLSSocket_OpenSSL::bio_write(BIO* bio, const char* buf, int len) { + BIO_clear_retry_flags(bio); + if (buf == NULL || len <= 0) - return 0; + return -1; TLSSocket_OpenSSL *sok = reinterpret_cast (bio->ptr); + if (!bio->init || !sok) + return -1; + try { - while (true) + const size_t n = sok->m_wrapped->sendRawNonBlocking + (reinterpret_cast (buf), len); + + if (n == 0 && sok->m_wrapped->getStatus() & socket::STATUS_WOULDBLOCK) { - const size_t n = sok->m_wrapped->sendRawNonBlocking - (reinterpret_cast (buf), len); - - if (n == 0 && sok->m_wrapped->getStatus() & socket::STATUS_WOULDBLOCK) - continue; - - return static_cast (len); + BIO_set_retry_write(bio); + return -1; } + + return static_cast (len); } catch (exception& e) { @@ -426,23 +481,28 @@ int TLSSocket_OpenSSL::bio_write(BIO* bio, const char* buf, int len) // static int TLSSocket_OpenSSL::bio_read(BIO* bio, char* buf, int len) { + BIO_clear_retry_flags(bio); + if (buf == NULL || len <= 0) - return 0; + return -1; TLSSocket_OpenSSL *sok = reinterpret_cast (bio->ptr); + if (!bio->init || !sok) + return -1; + try { - while (true) + const size_t n = sok->m_wrapped->receiveRaw + (reinterpret_cast (buf), len); + + if (n == 0 || sok->m_wrapped->getStatus() & socket::STATUS_WOULDBLOCK) { - const size_t n = sok->m_wrapped->receiveRaw - (reinterpret_cast (buf), len); - - if (n == 0 && sok->m_wrapped->getStatus() & socket::STATUS_WOULDBLOCK) - continue; - - return static_cast (n); + BIO_set_retry_read(bio); + return -1; } + + return static_cast (n); } catch (exception& e) { @@ -461,29 +521,53 @@ int TLSSocket_OpenSSL::bio_puts(BIO* bio, const char* str) // static -int TLSSocket_OpenSSL::bio_gets(BIO* /* bio */, char* /* buf */, int /* len */) +long TLSSocket_OpenSSL::bio_ctrl(BIO* bio, int cmd, long num, void* ptr) { - return -1; -} + long ret = 1; - -// static -long TLSSocket_OpenSSL::bio_ctrl(BIO* /* bio */, int cmd, long /* num */, void* /* ptr */) -{ - if (cmd == BIO_CTRL_FLUSH) + switch (cmd) { - // OpenSSL library needs this - return 1; + case BIO_CTRL_INFO: + + ret = 0; + break; + + case BIO_CTRL_GET_CLOSE: + + ret = bio->shutdown; + break; + + case BIO_CTRL_SET_CLOSE: + + bio->shutdown = static_cast (num); + break; + + case BIO_CTRL_PENDING: + case BIO_CTRL_WPENDING: + + ret = 0; + break; + + case BIO_CTRL_DUP: + case BIO_CTRL_FLUSH: + + ret = 1; + break; + + default: + + ret = 0; + break; } - return 0; + return ret; } // static int TLSSocket_OpenSSL::bio_create(BIO* bio) { - bio->init = 1; + bio->init = 0; bio->num = 0; bio->ptr = NULL; bio->flags = 0; @@ -498,9 +582,12 @@ int TLSSocket_OpenSSL::bio_destroy(BIO* bio) if (bio == NULL) return 0; - bio->ptr = NULL; - bio->init = 0; - bio->flags = 0; + if (bio->shutdown) + { + bio->ptr = NULL; + bio->init = 0; + bio->flags = 0; + } return 1; } diff --git a/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp b/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp index 3dda9fa5..410fffcf 100644 --- a/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp +++ b/vmime/net/tls/openssl/TLSSocket_OpenSSL.hpp @@ -112,6 +112,8 @@ private: SSL* m_ssl; + unsigned long m_status; + // Last exception thrown from C BIO functions std::auto_ptr m_ex; };