// WebSocketHandshake.cpp
#include "config.h"
#if ENABLE(WEB_SOCKETS)
#include "WebSocketHandshake.h"
#include "AtomicString.h"
#include "Cookie.h"
#include "CookieJar.h"
#include "Document.h"
#include "HTTPHeaderMap.h"
#include "KURL.h"
#include "Logging.h"
#include "ScriptExecutionContext.h"
#include "SecurityOrigin.h"
#include "StringBuilder.h"
#include <wtf/MD5.h>
#include <wtf/RandomNumber.h>
#include <wtf/StdLibExtras.h>
#include <wtf/StringExtras.h>
#include <wtf/Vector.h>
#include <wtf/text/CString.h>
namespace WebCore {
// Filler characters inserted at random positions into Sec-WebSocket-Key1/Key2
// values by generateSecWebSocketKey(). The set is printable ASCII 0x21-0x7E with
// the digits '0'-'9' removed (and no space), so filler never adds to the digits
// that encode the key number or to the embedded spaces.
static const char randomCharacterInSecWebSocketKey[] = "!\"#$%&'()*+,-./:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~";
// Extracts the status code from the first (status) line of an HTTP response,
// i.e. the text between the first and second spaces of "HTTP/1.1 101 Reason".
//
// |lineLength| receives the number of bytes preceding the first '\n', or |len|
// when no '\n' occurs within the first |len| bytes.
// Returns:
//   - a null String when the status line is not yet terminated by '\n'
//     (more data is needed);
//   - an empty String when the line does not contain two spaces;
//   - the code text otherwise.
//
// |len| is size_t so the caller's buffer length is used without narrowing to int.
static String extractResponseCode(const char* header, size_t len, size_t& lineLength)
{
    const char* space1 = 0;
    const char* space2 = 0;
    size_t i = 0;
    for (; i < len; ++i) {
        const char c = header[i];
        if (c == '\n')
            break;
        if (c != ' ')
            continue;
        if (!space1)
            space1 = header + i;
        else if (!space2)
            space2 = header + i;
    }
    lineLength = i;
    if (i == len)
        return String();
    if (!space1 || !space2)
        return "";
    return String(space1 + 1, space2 - space1 - 1);
}
// Builds the request-target for the handshake: the URL path ("/" when the path
// is empty), followed by "?" and the query whenever the URL has a (non-null) query.
static String resourceName(const KURL& url)
{
    const String path = url.path();
    const String query = url.query();
    String result = path.isEmpty() ? String("/") : path;
    if (!query.isNull())
        result += "?" + query;
    ASSERT(!result.isEmpty());
    ASSERT(!result.contains(' '));
    return result;
}
// Returns the lowercased host, with ":port" appended only when the URL carries an
// explicit port that differs from the scheme default (80 for ws, 443 for wss).
static String hostName(const KURL& url, bool secure)
{
    ASSERT(url.protocolIs("wss") == secure);
    const unsigned short defaultPort = secure ? 443 : 80;
    StringBuilder result;
    result.append(url.host().lower());
    if (url.port() && url.port() != defaultPort) {
        result.append(":");
        result.append(String::number(url.port()));
    }
    return result.toString();
}
// Produces a console-friendly excerpt of |len| bytes at |p|: at most the first
// 128 bytes, with "..." appended when the input was longer.
static String trimConsoleMessage(const char* p, size_t len)
{
    const size_t maxExcerpt = 128;
    if (len <= maxExcerpt)
        return String(p, len);
    return String(p, maxExcerpt) + "...";
}
static void generateSecWebSocketKey(uint32_t& number, String& key)
{
uint32_t space = static_cast<uint32_t>(WTF::randomNumber() * 12) + 1;
uint32_t max = 4294967295U / space;
number = static_cast<uint32_t>(WTF::randomNumber() * max);
uint32_t product = number * space;
String s = String::number(product);
int n = static_cast<int>(WTF::randomNumber() * 12) + 1;
DEFINE_STATIC_LOCAL(String, randomChars, (randomCharacterInSecWebSocketKey));
for (int i = 0; i < n; i++) {
int pos = static_cast<int>(WTF::randomNumber() * (s.length() + 1));
int chpos = static_cast<int>(WTF::randomNumber() * randomChars.length());
s.insert(randomChars.substring(chpos, 1), pos);
}
DEFINE_STATIC_LOCAL(String, spaceChar, (" "));
for (uint32_t i = 0; i < space; i++) {
int pos = static_cast<int>(WTF::randomNumber() * s.length() - 1) + 1;
s.insert(spaceChar, pos);
}
key = s;
}
// Fills the 8-byte key3 buffer with random bytes in [0, 255].
static void generateKey3(unsigned char key3[8])
{
    for (unsigned char* byte = key3; byte != key3 + 8; ++byte)
        *byte = static_cast<unsigned char>(WTF::randomNumber() * 256);
}
// Writes |number| into buf[0..3] in big-endian order (most significant byte first).
static void setChallengeNumber(unsigned char* buf, uint32_t number)
{
    buf[0] = static_cast<unsigned char>(number >> 24);
    buf[1] = static_cast<unsigned char>(number >> 16);
    buf[2] = static_cast<unsigned char>(number >> 8);
    buf[3] = static_cast<unsigned char>(number);
}
static void generateExpectedChallengeResponse(uint32_t number1, uint32_t number2, unsigned char key3[8], unsigned char expectedChallenge[16])
{
unsigned char challenge[16];
setChallengeNumber(&challenge[0], number1);
setChallengeNumber(&challenge[4], number2);
memcpy(&challenge[8], key3, 8);
MD5 md5;
md5.addBytes(challenge, sizeof(challenge));
Vector<uint8_t, 16> digest = md5.checksum();
memcpy(expectedChallenge, digest.data(), 16);
}
// Prepares a client handshake for |url|. Generates both Sec-WebSocket-Key
// values and the 8-byte key3, then precomputes the challenge response the
// server is expected to return.
WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, ScriptExecutionContext* context)
    : m_url(url)
    , m_clientProtocol(protocol)
    , m_secure(m_url.protocolIs("wss"))
    , m_context(context)
    , m_mode(Incomplete)
{
    uint32_t firstKeyNumber;
    uint32_t secondKeyNumber;
    generateSecWebSocketKey(firstKeyNumber, m_secWebSocketKey1);
    generateSecWebSocketKey(secondKeyNumber, m_secWebSocketKey2);
    generateKey3(m_key3);
    generateExpectedChallengeResponse(firstKeyNumber, secondKeyNumber, m_key3, m_expectedChallengeResponse);
}
// No explicit teardown is performed.
WebSocketHandshake::~WebSocketHandshake()
{
}

// The URL the handshake targets.
const KURL& WebSocketHandshake::url() const { return m_url; }

// Stores url.copy() rather than |url| itself.
void WebSocketHandshake::setURL(const KURL& url) { m_url = url.copy(); }

// Host part of the target URL, lowercased.
const String WebSocketHandshake::host() const
{
    const String hostPart = m_url.host();
    return hostPart.lower();
}

// Sub-protocol requested by the client (may be empty).
const String& WebSocketHandshake::clientProtocol() const { return m_clientProtocol; }

void WebSocketHandshake::setClientProtocol(const String& protocol) { m_clientProtocol = protocol; }

// True when the URL scheme was "wss" at construction time.
bool WebSocketHandshake::secure() const { return m_secure; }
// Serialized security origin of the owning script context; sent as the
// Origin header. Dereferences m_context, so it must still be set.
String WebSocketHandshake::clientOrigin() const
{
    const String origin = m_context->securityOrigin()->toString();
    return origin;
}

// Rebuilds "ws[s]://host[:port]/resource" for this handshake. checkResponseHeaders()
// compares it against the server's Sec-WebSocket-Location value.
String WebSocketHandshake::clientLocation() const
{
    StringBuilder location;
    location.append(m_secure ? "wss" : "ws");
    location.append("://");
    location.append(hostName(m_url, m_secure));
    location.append(resourceName(m_url));
    return location.toString();
}
// Serializes the client's opening handshake. The result is the request line,
// the header fields (each terminated by CRLF), an empty line, and then the
// 8 raw key3 bytes. The text part is UTF-8 encoded.
CString WebSocketHandshake::clientHandshakeMessage() const
{
    StringBuilder request;
    request.append("GET ");
    request.append(resourceName(m_url));
    request.append(" HTTP/1.1\r\n");

    Vector<String> headerLines;
    headerLines.append("Upgrade: WebSocket");
    headerLines.append("Connection: Upgrade");
    headerLines.append("Host: " + hostName(m_url, m_secure));
    headerLines.append("Origin: " + clientOrigin());
    if (!m_clientProtocol.isEmpty())
        headerLines.append("Sec-WebSocket-Protocol: " + m_clientProtocol);

    // Cookies are looked up under the equivalent http(s) URL; only Document
    // contexts provide a cookie jar here.
    const KURL cookieURL = httpURLForAuthenticationAndCookies();
    if (m_context->isDocument()) {
        const String cookie = cookieRequestHeaderFieldValue(static_cast<Document*>(m_context), cookieURL);
        if (!cookie.isEmpty())
            headerLines.append("Cookie: " + cookie);
    }
    headerLines.append("Sec-WebSocket-Key1: " + m_secWebSocketKey1);
    headerLines.append("Sec-WebSocket-Key2: " + m_secWebSocketKey2);

    for (size_t index = 0; index < headerLines.size(); ++index) {
        request.append(headerLines[index]);
        request.append("\r\n");
    }
    request.append("\r\n");

    // Append key3 as raw bytes after the header block.
    const CString headerBytes = request.toString().utf8();
    char* buffer = 0;
    CString message = CString::newUninitialized(headerBytes.length() + sizeof(m_key3), buffer);
    memcpy(buffer, headerBytes.data(), headerBytes.length());
    memcpy(buffer + headerBytes.length(), m_key3, sizeof(m_key3));
    return message;
}
// Structured description of the handshake request (URL, origin, protocol).
// When the context is a Document with cookies for the equivalent http(s)
// URL, a "Cookie" field is added as well.
WebSocketHandshakeRequest WebSocketHandshake::clientHandshakeRequest() const
{
    WebSocketHandshakeRequest request(m_url, clientOrigin(), m_clientProtocol);
    const KURL cookieURL = httpURLForAuthenticationAndCookies();
    if (!m_context->isDocument())
        return request;

    const String cookie = cookieRequestHeaderFieldValue(static_cast<Document*>(m_context), cookieURL);
    if (!cookie.isEmpty())
        request.addExtraHeaderField("Cookie", cookie);
    return request;
}
// Discards all state captured from a server response and returns to Incomplete.
void WebSocketHandshake::reset()
{
    m_setCookie2 = String();
    m_setCookie = String();
    m_wsProtocol = String();
    m_wsLocation = String();
    m_wsOrigin = String();
    m_mode = Incomplete;
}

// Detaches from the script context by nulling m_context. Methods that dereference
// m_context (e.g. clientOrigin(), readServerHandshake()) must not be called afterwards.
void WebSocketHandshake::clearScriptExecutionContext()
{
    m_context = 0;
}
int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
{
m_mode = Incomplete;
size_t lineLength;
const String& code = extractResponseCode(header, len, lineLength);
if (code.isNull()) {
return -1;
}
if (code.isEmpty()) {
m_mode = Failed;
m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "No response code found: " + trimConsoleMessage(header, lineLength), 0, clientOrigin());
return len;
}
LOG(Network, "response code: %s", code.utf8().data());
if (code != "101") {
m_mode = Failed;
m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected response code:" + code, 0, clientOrigin());
return len;
}
m_mode = Normal;
if (!strnstr(header, "\r\n\r\n", len)) {
m_mode = Incomplete;
return -1;
}
HTTPHeaderMap headers;
const char* headerFields = strnstr(header, "\r\n", len); ASSERT(headerFields);
headerFields += 2; const char* p = readHTTPHeaders(headerFields, header + len, &headers);
if (!p) {
LOG(Network, "readHTTPHeaders failed");
m_mode = Failed;
return len;
}
if (!processHeaders(headers) || !checkResponseHeaders()) {
LOG(Network, "header process failed");
m_mode = Failed;
return p - header;
}
if (len < static_cast<size_t>(p - header + sizeof(m_expectedChallengeResponse))) {
m_mode = Incomplete;
return -1;
}
if (memcmp(p, m_expectedChallengeResponse, sizeof(m_expectedChallengeResponse))) {
m_mode = Failed;
return (p - header) + sizeof(m_expectedChallengeResponse);
}
m_mode = Connected;
return (p - header) + sizeof(m_expectedChallengeResponse);
}
// Current handshake state (Incomplete, Normal, Failed or Connected).
WebSocketHandshake::Mode WebSocketHandshake::mode() const { return m_mode; }

// Values captured from the server's response by processHeaders(), or assigned
// directly through these setters.
const String& WebSocketHandshake::serverWebSocketOrigin() const { return m_wsOrigin; }

void WebSocketHandshake::setServerWebSocketOrigin(const String& webSocketOrigin) { m_wsOrigin = webSocketOrigin; }

const String& WebSocketHandshake::serverWebSocketLocation() const { return m_wsLocation; }

void WebSocketHandshake::setServerWebSocketLocation(const String& webSocketLocation) { m_wsLocation = webSocketLocation; }

const String& WebSocketHandshake::serverWebSocketProtocol() const { return m_wsProtocol; }

void WebSocketHandshake::setServerWebSocketProtocol(const String& webSocketProtocol) { m_wsProtocol = webSocketProtocol; }

const String& WebSocketHandshake::serverSetCookie() const { return m_setCookie; }

void WebSocketHandshake::setServerSetCookie(const String& setCookie) { m_setCookie = setCookie; }

const String& WebSocketHandshake::serverSetCookie2() const { return m_setCookie2; }

void WebSocketHandshake::setServerSetCookie2(const String& setCookie2) { m_setCookie2 = setCookie2; }
// Returns a copy of the target URL with its scheme replaced by "https"
// (for wss) or "http" (for ws). It is used for cookie lookups.
KURL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
{
    KURL httpURL = m_url.copy();
    const char* scheme = m_secure ? "https" : "http";
    bool schemeWasSet = httpURL.setProtocol(scheme);
    ASSERT_UNUSED(schemeWasSet, schemeWasSet);
    return httpURL;
}
// Parses "name: value\r\n" header lines in [start, end) into |headers| until
// an empty line ("\r\n") is reached.
//
// Names are lowercased (ASCII A-Z only). Leading spaces in values are skipped.
// Returns a pointer just past the terminating empty line. On any syntax or
// UTF-8 error it logs a console message and returns 0.
const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end, HTTPHeaderMap* headers)
{
    Vector<char> name;
    Vector<char> value;
    for (const char* p = start; p < end; p++) {
        name.clear();
        value.clear();

        // Header name: bytes up to ':'.
        for (; p < end; p++) {
            switch (*p) {
            case '\r':
                if (name.isEmpty()) {
                    // Empty line: end of the header block.
                    if (p + 1 < end && *(p + 1) == '\n')
                        return p + 2;
                    m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF at " + trimConsoleMessage(p, end - p), 0, clientOrigin());
                    return 0;
                }
                m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected CR in name at " + trimConsoleMessage(name.data(), name.size()), 0, clientOrigin());
                return 0;
            case '\n':
                m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in name at " + trimConsoleMessage(name.data(), name.size()), 0, clientOrigin());
                return 0;
            case ':':
                break;
            default:
                if (*p >= 0x41 && *p <= 0x5a)
                    name.append(*p + 0x20);
                else
                    name.append(*p);
                continue;
            }
            if (*p == ':') {
                ++p;
                break;
            }
        }

        // Header value: skip leading spaces, then read bytes up to CR.
        for (; p < end && *p == 0x20; p++) { }
        for (; p < end; p++) {
            switch (*p) {
            case '\r':
                break;
            case '\n':
                m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in value at " + trimConsoleMessage(value.data(), value.size()), 0, clientOrigin());
                return 0;
            default:
                value.append(*p);
            }
            if (*p == '\r') {
                ++p;
                break;
            }
        }
        if (p >= end || *p != '\n') {
            m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF after value at " + trimConsoleMessage(p, end - p), 0, clientOrigin());
            return 0;
        }

        // A line such as ": value" has no field name. Reject it explicitly rather
        // than letting it surface as a misleading UTF-8 error, or as an entry
        // with an empty name.
        if (name.isEmpty()) {
            m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "header name is missing", 0, clientOrigin());
            return 0;
        }

        AtomicString nameStr(String::fromUTF8(name.data(), name.size()));
        // An empty value (e.g. "Sec-WebSocket-Protocol:") is legal. Handle it
        // explicitly so it cannot be confused with a decoding failure.
        // NOTE(review): an empty Vector's data() may be null, and String::fromUTF8
        // presumably maps a null pointer to a null String.
        String valueStr = value.isEmpty() ? String("") : String::fromUTF8(value.data(), value.size());
        if (nameStr.isNull()) {
            m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "invalid UTF-8 sequence in header name", 0, clientOrigin());
            return 0;
        }
        if (valueStr.isNull()) {
            m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "invalid UTF-8 sequence in header value", 0, clientOrigin());
            return 0;
        }
        LOG(Network, "name=%s value=%s", nameStr.string().utf8().data(), valueStr.utf8().data());
        headers->add(nameStr, valueStr);
    }
    // Unreachable when the caller has verified that "\r\n\r\n" is present.
    ASSERT_NOT_REACHED();
    return 0;
}
bool WebSocketHandshake::processHeaders(const HTTPHeaderMap& headers)
{
for (HTTPHeaderMap::const_iterator it = headers.begin(); it != headers.end(); ++it) {
switch (m_mode) {
case Normal:
if (it->first == "sec-websocket-origin")
m_wsOrigin = it->second;
else if (it->first == "sec-websocket-location")
m_wsLocation = it->second;
else if (it->first == "sec-websocket-protocol")
m_wsProtocol = it->second;
else if (it->first == "set-cookie")
m_setCookie = it->second;
else if (it->first == "set-cookie2")
m_setCookie2 = it->second;
continue;
case Incomplete:
case Failed:
case Connected:
ASSERT_NOT_REACHED();
}
ASSERT_NOT_REACHED();
}
return true;
}
bool WebSocketHandshake::checkResponseHeaders()
{
if (m_wsOrigin.isNull()) {
m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'sec-websocket-origin' header is missing", 0, clientOrigin());
return false;
}
if (m_wsLocation.isNull()) {
m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'sec-websocket-location' header is missing", 0, clientOrigin());
return false;
}
if (clientOrigin() != m_wsOrigin) {
m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: origin mismatch: " + clientOrigin() + " != " + m_wsOrigin, 0, clientOrigin());
return false;
}
if (clientLocation() != m_wsLocation) {
m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: location mismatch: " + clientLocation() + " != " + m_wsLocation, 0, clientOrigin());
return false;
}
if (!m_clientProtocol.isEmpty() && m_clientProtocol != m_wsProtocol) {
m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: protocol mismatch: " + m_clientProtocol + " != " + m_wsProtocol, 0, clientOrigin());
return false;
}
return true;
}
}
#endif // ENABLE(WEB_SOCKETS)