ws2_32: Avoid overflows in get_rcvsnd_timeo.

This commit is contained in:
Jacek Caban 2015-04-27 16:00:37 +02:00 committed by Alexandre Julliard
parent d15ca4edb9
commit 9e3a3f46fa
1 changed files with 21 additions and 18 deletions

View File

@ -1251,30 +1251,28 @@ static char *strdup_lower(const char *str)
/* Utility: get the SO_RCVTIMEO or SO_SNDTIMEO socket option /* Utility: get the SO_RCVTIMEO or SO_SNDTIMEO socket option
* from an fd and return the value converted to milli seconds * from an fd and return the value converted to milli seconds
* or -1 if there is an infinite time out */ * or 0 if there is an infinite time out */
static inline int get_rcvsnd_timeo( int fd, int optname) static inline INT64 get_rcvsnd_timeo( int fd, int optname)
{ {
struct timeval tv; struct timeval tv;
socklen_t len = sizeof(tv); socklen_t len = sizeof(tv);
int ret = getsockopt(fd, SOL_SOCKET, optname, &tv, &len); int res = getsockopt(fd, SOL_SOCKET, optname, &tv, &len);
if( ret >= 0) if (res < 0)
ret = tv.tv_sec * 1000 + tv.tv_usec / 1000; return 0;
if( ret <= 0 ) /* tv == {0,0} means infinite time out */ return (UINT64)tv.tv_sec * 1000 + tv.tv_usec / 1000;
return -1;
return ret;
} }
/* macro wrappers for portability */ /* macro wrappers for portability */
#ifdef SO_RCVTIMEO #ifdef SO_RCVTIMEO
#define GET_RCVTIMEO(fd) get_rcvsnd_timeo( (fd), SO_RCVTIMEO) #define GET_RCVTIMEO(fd) get_rcvsnd_timeo( (fd), SO_RCVTIMEO)
#else #else
#define GET_RCVTIMEO(fd) (-1) #define GET_RCVTIMEO(fd) (0)
#endif #endif
#ifdef SO_SNDTIMEO #ifdef SO_SNDTIMEO
#define GET_SNDTIMEO(fd) get_rcvsnd_timeo( (fd), SO_SNDTIMEO) #define GET_SNDTIMEO(fd) get_rcvsnd_timeo( (fd), SO_SNDTIMEO)
#else #else
#define GET_SNDTIMEO(fd) (-1) #define GET_SNDTIMEO(fd) (0)
#endif #endif
/* utility: given an fd, will block until one of the events occurs */ /* utility: given an fd, will block until one of the events occurs */
@ -5076,18 +5074,20 @@ static int WS2_sendto( SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount,
while (wsa->first_iovec < wsa->n_iovecs) while (wsa->first_iovec < wsa->n_iovecs)
{ {
struct pollfd pfd; struct pollfd pfd;
int timeout = GET_SNDTIMEO(fd); int poll_timeout = -1;
INT64 timeout = GET_SNDTIMEO(fd);
if (timeout != -1) if (timeout)
{ {
timeout -= GetTickCount() - timeout_start; timeout -= GetTickCount() - timeout_start;
if (timeout < 0) timeout = 0; if (timeout < 0) poll_timeout = 0;
else poll_timeout = timeout <= INT_MAX ? timeout : INT_MAX;
} }
pfd.fd = fd; pfd.fd = fd;
pfd.events = POLLOUT; pfd.events = POLLOUT;
if (!timeout || !poll( &pfd, 1, timeout )) if (!poll_timeout || !poll( &pfd, 1, poll_timeout ))
{ {
err = WSAETIMEDOUT; err = WSAETIMEDOUT;
goto error; /* msdn says a timeout in send is fatal */ goto error; /* msdn says a timeout in send is fatal */
@ -7130,18 +7130,21 @@ static int WS2_recv_base( SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount,
if ( is_blocking ) if ( is_blocking )
{ {
struct pollfd pfd; struct pollfd pfd;
int timeout = GET_RCVTIMEO(fd); int poll_timeout = -1;
if (timeout != -1) INT64 timeout = GET_RCVTIMEO(fd);
if (timeout)
{ {
timeout -= GetTickCount() - timeout_start; timeout -= GetTickCount() - timeout_start;
if (timeout < 0) timeout = 0; if (timeout < 0) poll_timeout = 0;
else poll_timeout = timeout <= INT_MAX ? timeout : INT_MAX;
} }
pfd.fd = fd; pfd.fd = fd;
pfd.events = POLLIN; pfd.events = POLLIN;
if (*lpFlags & WS_MSG_OOB) pfd.events |= POLLPRI; if (*lpFlags & WS_MSG_OOB) pfd.events |= POLLPRI;
if (!timeout || !poll( &pfd, 1, timeout )) if (!poll_timeout || !poll( &pfd, 1, poll_timeout ))
{ {
err = WSAETIMEDOUT; err = WSAETIMEDOUT;
/* a timeout is not fatal */ /* a timeout is not fatal */