Better error handling. Fixed return values in custom BIO. Added support for SSL_ERROR_WANT_READ/SSL_ERROR_WANT_WRITE in handshaking.

This commit is contained in:
Vincent Richard 2013-12-18 21:21:30 +01:00
parent 645c572ab5
commit 9a4b72b47a
2 changed files with 138 additions and 49 deletions

View File

@ -60,7 +60,7 @@ BIO_METHOD TLSSocket_OpenSSL::sm_customBIOMethod =
TLSSocket_OpenSSL::bio_write,
TLSSocket_OpenSSL::bio_read,
TLSSocket_OpenSSL::bio_puts,
TLSSocket_OpenSSL::bio_gets,
NULL, // gets
TLSSocket_OpenSSL::bio_ctrl,
TLSSocket_OpenSSL::bio_create,
TLSSocket_OpenSSL::bio_destroy,
@ -77,7 +77,7 @@ shared_ptr <TLSSocket> TLSSocket::wrap(shared_ptr <TLSSession> session, shared_p
TLSSocket_OpenSSL::TLSSocket_OpenSSL(shared_ptr <TLSSession_OpenSSL> session, shared_ptr <socket> sok)
: m_session(session), m_wrapped(sok), m_connected(false), m_ssl(0), m_ex(NULL)
: m_session(session), m_wrapped(sok), m_connected(false), m_ssl(0), m_status(0), m_ex(NULL)
{
}
@ -107,6 +107,7 @@ void TLSSocket_OpenSSL::createSSLHandle()
{
BIO* sockBio = BIO_new(&sm_customBIOMethod);
sockBio->ptr = this;
sockBio->init = 1;
m_ssl = SSL_new(m_session->getContext());
@ -207,11 +208,25 @@ void TLSSocket_OpenSSL::send(const char* str)
size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count)
{
int rc = SSL_read(m_ssl, buffer, static_cast <int>(count));
handleError(rc);
m_status &= ~STATUS_WOULDBLOCK;
if (rc < 0)
return 0;
int rc = SSL_read(m_ssl, buffer, static_cast <int>(count));
if (m_ex.get())
internalThrow();
if (rc <= 0)
{
int error = SSL_get_error(m_ssl, rc);
if (error == SSL_ERROR_WANT_WRITE || error == SSL_ERROR_WANT_READ)
{
m_status |= STATUS_WOULDBLOCK;
return 0;
}
handleError(rc);
}
return rc;
}
@ -219,18 +234,31 @@ size_t TLSSocket_OpenSSL::receiveRaw(byte_t* buffer, const size_t count)
void TLSSocket_OpenSSL::sendRaw(const byte_t* buffer, const size_t count)
{
int rc = SSL_write(m_ssl, buffer, static_cast <int>(count));
handleError(rc);
sendRawNonBlocking(buffer, count);
}
size_t TLSSocket_OpenSSL::sendRawNonBlocking(const byte_t* buffer, const size_t count)
{
int rc = SSL_write(m_ssl, buffer, static_cast <int>(count));
handleError(rc);
m_status &= ~STATUS_WOULDBLOCK;
if (rc < 0)
rc = 0;
int rc = SSL_write(m_ssl, buffer, static_cast <int>(count));
if (m_ex.get())
internalThrow();
if (rc <= 0)
{
int error = SSL_get_error(m_ssl, rc);
if (error == SSL_ERROR_WANT_WRITE || error == SSL_ERROR_WANT_READ)
{
m_status |= STATUS_WOULDBLOCK;
return 0;
}
handleError(rc);
}
return rc;
}
@ -249,9 +277,31 @@ void TLSSocket_OpenSSL::handshake(shared_ptr <timeoutHandler> toHandler)
try
{
// int ret = SSL_connect(m_ssl);
int ret = SSL_do_handshake(m_ssl);
handleError(ret);
int rc;
while ((rc = SSL_do_handshake(m_ssl)) <= 0)
{
const int err = SSL_get_error(m_ssl, rc);
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
{
// No data available yet
platform::getHandler()->wait();
}
else
{
handleError(rc);
}
// Check whether the time-out delay is elapsed
if (m_toHandler && m_toHandler->isTimeOut())
{
if (!m_toHandler->handleTimeOut())
throw exceptions::operation_timed_out();
m_toHandler->resetTimeOut();
}
}
}
catch (...)
{
@ -386,7 +436,7 @@ void TLSSocket_OpenSSL::handleError(int rc)
unsigned int TLSSocket_OpenSSL::getStatus() const
{
return m_wrapped->getStatus();
return m_status;
}
@ -396,23 +446,28 @@ unsigned int TLSSocket_OpenSSL::getStatus() const
// static
int TLSSocket_OpenSSL::bio_write(BIO* bio, const char* buf, int len)
{
BIO_clear_retry_flags(bio);
if (buf == NULL || len <= 0)
return 0;
return -1;
TLSSocket_OpenSSL *sok = reinterpret_cast <TLSSocket_OpenSSL*>(bio->ptr);
if (!bio->init || !sok)
return -1;
try
{
while (true)
const size_t n = sok->m_wrapped->sendRawNonBlocking
(reinterpret_cast <const byte_t*>(buf), len);
if (n == 0 && sok->m_wrapped->getStatus() & socket::STATUS_WOULDBLOCK)
{
const size_t n = sok->m_wrapped->sendRawNonBlocking
(reinterpret_cast <const byte_t*>(buf), len);
if (n == 0 && sok->m_wrapped->getStatus() & socket::STATUS_WOULDBLOCK)
continue;
return static_cast <int>(len);
BIO_set_retry_write(bio);
return -1;
}
return static_cast <int>(len);
}
catch (exception& e)
{
@ -426,23 +481,28 @@ int TLSSocket_OpenSSL::bio_write(BIO* bio, const char* buf, int len)
// static
int TLSSocket_OpenSSL::bio_read(BIO* bio, char* buf, int len)
{
BIO_clear_retry_flags(bio);
if (buf == NULL || len <= 0)
return 0;
return -1;
TLSSocket_OpenSSL *sok = reinterpret_cast <TLSSocket_OpenSSL*>(bio->ptr);
if (!bio->init || !sok)
return -1;
try
{
while (true)
const size_t n = sok->m_wrapped->receiveRaw
(reinterpret_cast <byte_t*>(buf), len);
if (n == 0 || sok->m_wrapped->getStatus() & socket::STATUS_WOULDBLOCK)
{
const size_t n = sok->m_wrapped->receiveRaw
(reinterpret_cast <byte_t*>(buf), len);
if (n == 0 && sok->m_wrapped->getStatus() & socket::STATUS_WOULDBLOCK)
continue;
return static_cast <int>(n);
BIO_set_retry_read(bio);
return -1;
}
return static_cast <int>(n);
}
catch (exception& e)
{
@ -461,29 +521,53 @@ int TLSSocket_OpenSSL::bio_puts(BIO* bio, const char* str)
// static
int TLSSocket_OpenSSL::bio_gets(BIO* /* bio */, char* /* buf */, int /* len */)
long TLSSocket_OpenSSL::bio_ctrl(BIO* bio, int cmd, long num, void* ptr)
{
return -1;
}
long ret = 1;
// static
long TLSSocket_OpenSSL::bio_ctrl(BIO* /* bio */, int cmd, long /* num */, void* /* ptr */)
{
if (cmd == BIO_CTRL_FLUSH)
switch (cmd)
{
// OpenSSL library needs this
return 1;
case BIO_CTRL_INFO:
ret = 0;
break;
case BIO_CTRL_GET_CLOSE:
ret = bio->shutdown;
break;
case BIO_CTRL_SET_CLOSE:
bio->shutdown = static_cast <int>(num);
break;
case BIO_CTRL_PENDING:
case BIO_CTRL_WPENDING:
ret = 0;
break;
case BIO_CTRL_DUP:
case BIO_CTRL_FLUSH:
ret = 1;
break;
default:
ret = 0;
break;
}
return 0;
return ret;
}
// static
int TLSSocket_OpenSSL::bio_create(BIO* bio)
{
bio->init = 1;
bio->init = 0;
bio->num = 0;
bio->ptr = NULL;
bio->flags = 0;
@ -498,9 +582,12 @@ int TLSSocket_OpenSSL::bio_destroy(BIO* bio)
if (bio == NULL)
return 0;
bio->ptr = NULL;
bio->init = 0;
bio->flags = 0;
if (bio->shutdown)
{
bio->ptr = NULL;
bio->init = 0;
bio->flags = 0;
}
return 1;
}

View File

@ -112,6 +112,8 @@ private:
SSL* m_ssl;
unsigned long m_status;
// Last exception thrown from C BIO functions
std::auto_ptr <std::exception> m_ex;
};