diff --git a/CMakeLists.txt b/CMakeLists.txt index 8062bcfd50..5878540b66 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1001,6 +1001,8 @@ add_library(native STATIC ext/native/net/http_client.h ext/native/net/resolve.cpp ext/native/net/resolve.h + ext/native/net/sinks.cpp + ext/native/net/sinks.h ext/native/net/url.cpp ext/native/net/url.h ext/native/profiler/profiler.cpp diff --git a/Core/FileLoaders/CachingFileLoader.cpp b/Core/FileLoaders/CachingFileLoader.cpp index f33734aa49..0b34e2c7bb 100644 --- a/Core/FileLoaders/CachingFileLoader.cpp +++ b/Core/FileLoaders/CachingFileLoader.cpp @@ -68,11 +68,22 @@ void CachingFileLoader::Seek(s64 absolutePos) { } size_t CachingFileLoader::ReadAt(s64 absolutePos, size_t bytes, void *data) { + if (absolutePos >= filesize_) { + bytes = 0; + } else if (absolutePos + (s64)bytes >= filesize_) { + bytes = filesize_ - absolutePos; + } + size_t readSize = ReadFromCache(absolutePos, bytes, data); // While in case the cache size is too small for the entire read. while (readSize < bytes) { SaveIntoCache(absolutePos + readSize, bytes - readSize); - readSize += ReadFromCache(absolutePos + readSize, bytes - readSize, (u8 *)data + readSize); + size_t bytesFromCache = ReadFromCache(absolutePos + readSize, bytes - readSize, (u8 *)data + readSize); + readSize += bytesFromCache; + if (bytesFromCache == 0) { + // We can't read any more. + break; + } } StartReadAhead(absolutePos + readSize); diff --git a/Core/FileLoaders/DiskCachingFileLoader.cpp b/Core/FileLoaders/DiskCachingFileLoader.cpp index e346cf35fe..27379a9bc9 100644 --- a/Core/FileLoaders/DiskCachingFileLoader.cpp +++ b/Core/FileLoaders/DiskCachingFileLoader.cpp @@ -76,13 +76,24 @@ void DiskCachingFileLoader::Seek(s64 absolutePos) { size_t DiskCachingFileLoader::ReadAt(s64 absolutePos, size_t bytes, void *data) { size_t readSize; + if (absolutePos >= filesize_) { + bytes = 0; + } else if (absolutePos + (s64)bytes >= filesize_) { + bytes = filesize_ - absolutePos; + } + if (cache_ && cache_->IsValid()) { readSize = cache_->ReadFromCache(absolutePos, bytes, data); // While in case the cache size is too small for the entire read. while (readSize < bytes) { readSize += cache_->SaveIntoCache(backend_, absolutePos + readSize, bytes - readSize, (u8 *)data + readSize); // If there are already-cached blocks afterward, we have to read them. - readSize += cache_->ReadFromCache(absolutePos + readSize, bytes - readSize, (u8 *)data + readSize); + size_t bytesFromCache = cache_->ReadFromCache(absolutePos + readSize, bytes - readSize, (u8 *)data + readSize); + readSize += bytesFromCache; + if (bytesFromCache == 0) { + // We can't read any more. + break; + } } } else { readSize = backend_->ReadAt(absolutePos, bytes, data); diff --git a/Core/FileLoaders/RamCachingFileLoader.cpp b/Core/FileLoaders/RamCachingFileLoader.cpp index 28eda51cdc..20598d631c 100644 --- a/Core/FileLoaders/RamCachingFileLoader.cpp +++ b/Core/FileLoaders/RamCachingFileLoader.cpp @@ -79,7 +79,12 @@ size_t RamCachingFileLoader::ReadAt(s64 absolutePos, size_t bytes, void *data) { // While in case the cache size is too small for the entire read. while (readSize < bytes) { SaveIntoCache(absolutePos + readSize, bytes - readSize); - readSize += ReadFromCache(absolutePos + readSize, bytes - readSize, (u8 *)data + readSize); + size_t bytesFromCache = ReadFromCache(absolutePos + readSize, bytes - readSize, (u8 *)data + readSize); + readSize += bytesFromCache; + if (bytesFromCache == 0) { + // We can't read any more. + break; + } } } @@ -134,7 +139,7 @@ size_t RamCachingFileLoader::ReadFromCache(s64 pos, size_t bytes, void *data) { u8 *p = (u8 *)data; // Clamp bytes to what's actually available. - if (pos + bytes > filesize_) { + if (pos + (s64)bytes > filesize_) { // Should've been caught above, but just in case. if (pos >= filesize_) { return 0; diff --git a/Core/Loaders.cpp b/Core/Loaders.cpp index 3630589f04..db28d4d315 100644 --- a/Core/Loaders.cpp +++ b/Core/Loaders.cpp @@ -186,7 +186,7 @@ IdentifiedFileType Identify_File(FileLoader *fileLoader) { return FILETYPE_ARCHIVE_RAR; } else if (!strcasecmp(extension.c_str(),".r01")) { return FILETYPE_ARCHIVE_RAR; - } else if (!strcasecmp(extension.substr(1).c_str(), ".7z")) { + } else if (!extension.empty() && !strcasecmp(extension.substr(1).c_str(), ".7z")) { return FILETYPE_ARCHIVE_7Z; } return FILETYPE_UNKNOWN; diff --git a/Windows/GEDebugger/TabState.cpp b/Windows/GEDebugger/TabState.cpp index 0207ffcbda..3b1bc8b444 100644 --- a/Windows/GEDebugger/TabState.cpp +++ b/Windows/GEDebugger/TabState.cpp @@ -1007,10 +1007,14 @@ TabStateTexture::TabStateTexture(HINSTANCE _hInstance, HWND _hParent) } TabStateWatch::TabStateWatch(HINSTANCE _hInstance, HWND _hParent) - : TabStateValues(&watchList[0], 0, (LPCSTR)IDD_GEDBG_TAB_VALUES, _hInstance, _hParent) { + : TabStateValues(nullptr, 0, (LPCSTR)IDD_GEDBG_TAB_VALUES, _hInstance, _hParent) { } void TabStateWatch::Update() { - values->UpdateRows(&watchList[0], (int)watchList.size()); + if (watchList.empty()) { + values->UpdateRows(nullptr, 0); + } else { + values->UpdateRows(&watchList[0], (int)watchList.size()); + } TabStateValues::Update(); } diff --git a/ext/native/Android.mk b/ext/native/Android.mk index 1f3eea6530..87e0004273 100644 --- a/ext/native/Android.mk +++ b/ext/native/Android.mk @@ -62,6 +62,7 @@ LOCAL_SRC_FILES :=\ net/http_server.cpp \ net/http_headers.cpp \ net/resolve.cpp \ + net/sinks.cpp \ net/url.cpp \ profiler/profiler.cpp \ thread/executor.cpp \ diff --git a/ext/native/file/fd_util.cpp b/ext/native/file/fd_util.cpp index f8f9fa7e86..ecb0027e98 100644 --- a/ext/native/file/fd_util.cpp +++ b/ext/native/file/fd_util.cpp @@ -75,7 +75,7 @@ ssize_t Write(int fd, const std::string &str) { return WriteLine(fd, str.c_str(), str.size()); } -bool WaitUntilReady(int fd, double timeout) { +bool WaitUntilReady(int fd, double timeout, bool for_write) { struct timeval tv; tv.tv_sec = floor(timeout); tv.tv_usec = (timeout - floor(timeout)) * 1000000.0; @@ -84,7 +84,12 @@ bool WaitUntilReady(int fd, double timeout) { FD_ZERO(&fds); FD_SET(fd, &fds); // First argument to select is the highest socket in the set + 1. - int rval = select(fd + 1, &fds, NULL, NULL, &tv); + int rval; + if (for_write) { + rval = select(fd + 1, NULL, &fds, NULL, &tv); + } else { + rval = select(fd + 1, &fds, NULL, NULL, &tv); + } if (rval < 0) { // Error calling select. return false; @@ -115,7 +120,10 @@ void SetNonBlocking(int sock, bool non_blocking) { ELOG("Error setting socket nonblocking status"); } #else - WLOG("NonBlocking mode not supported on Win32"); + u_long val = non_blocking ? 1 : 0; + if (ioctlsocket(sock, FIONBIO, &val) != 0) { + ELOG("Error setting socket nonblocking status"); + } #endif } diff --git a/ext/native/file/fd_util.h b/ext/native/file/fd_util.h index a38e59b357..6d64e018ab 100644 --- a/ext/native/file/fd_util.h +++ b/ext/native/file/fd_util.h @@ -18,7 +18,7 @@ ssize_t Write(int fd, const std::string &str); // Returns true if the fd became ready, false if it didn't or // if there was another error. -bool WaitUntilReady(int fd, double timeout); +bool WaitUntilReady(int fd, double timeout, bool for_write = false); void SetNonBlocking(int fd, bool non_blocking); diff --git a/ext/native/native.vcxproj b/ext/native/native.vcxproj index f750063f7c..ba4fccf8a6 100644 --- a/ext/native/native.vcxproj +++ b/ext/native/native.vcxproj @@ -264,6 +264,7 @@ + @@ -723,6 +724,7 @@ + diff --git a/ext/native/native.vcxproj.filters b/ext/native/native.vcxproj.filters index 269ac5835d..c3226a5147 100644 --- a/ext/native/native.vcxproj.filters +++ b/ext/native/native.vcxproj.filters @@ -308,6 +308,9 @@ math + + net + @@ -745,6 +748,9 @@ thin3d + + net + diff --git a/ext/native/net/http_headers.cpp b/ext/native/net/http_headers.cpp index 0fdf628d3f..fdfaa2c1ea 100644 --- a/ext/native/net/http_headers.cpp +++ b/ext/native/net/http_headers.cpp @@ -1,11 +1,13 @@ #include "net/http_headers.h" +#include #include #include #include "base/logging.h" #include "base/stringutil.h" #include "file/fd_util.h" +#include "net/sinks.h" namespace http { @@ -39,6 +41,15 @@ bool RequestHeader::GetParamValue(const char *param_name, std::string *value) co return false; } +bool RequestHeader::GetOther(const char *name, std::string *value) const { + auto it = other.find(name); + if (it != other.end()) { + *value = it->second; + return true; + } + return false; +} + // Intended to be a mad fast parser. It's not THAT fast currently, there's still // things to optimize, but meh. int RequestHeader::ParseHttpHeader(const char *buffer) { @@ -56,7 +67,7 @@ int RequestHeader::ParseHttpHeader(const char *buffer) { buffer += 5; } else { method = UNSUPPORTED; - status = 501; + status = 405; return -1; } SkipSpace(&buffer); @@ -99,56 +110,52 @@ int RequestHeader::ParseHttpHeader(const char *buffer) { // The header is formatted as key: value. int key_len = colon - buffer; - char *key = new char[key_len + 1]; - strncpy(key, buffer, key_len); - key[key_len] = 0; - StringUpper(key, key_len); + const char *key = buffer; // Go to after the colon to get the value. buffer = colon + 1; SkipSpace(&buffer); int value_len = (int)strlen(buffer); - if (!strcmp(key, "USER-AGENT")) { + if (!strncasecmp(key, "User-Agent", key_len)) { user_agent = new char[value_len + 1]; memcpy(user_agent, buffer, value_len + 1); ILOG("user-agent: %s", user_agent); - } else if (!strcmp(key, "REFERER")) { + } else if (!strncasecmp(key, "Referer", key_len)) { referer = new char[value_len + 1]; memcpy(referer, buffer, value_len + 1); - } else if (!strcmp(key, "CONTENT-LENGTH")) { + } else if (!strncasecmp(key, "Content-Length", key_len)) { content_length = atoi(buffer); ILOG("Content-Length: %i", (int)content_length); + } else { + std::string key_str(key, key_len); + std::transform(key_str.begin(), key_str.end(), key_str.begin(), tolower); + other[key_str] = buffer; } - delete [] key; return 0; } -void RequestHeader::ParseHeaders(int fd) { - int line_count = 0; - // Loop through request headers. - while (true) { - if (!fd_util::WaitUntilReady(fd, 5.0)) { // Wait max 5 secs. - // Timed out or error. - ok = false; - return; - } - char buffer[1024]; - fd_util::ReadLine(fd, buffer, 1023); - StringTrimEndNonAlphaNum(buffer); - if (buffer[0] == '\0') - break; - ParseHttpHeader(buffer); - line_count++; - if (type == SIMPLE) { - // Done! - ILOG("Simple: Done parsing http request."); - break; - } - } - ILOG("finished parsing request."); - ok = line_count > 1; +void RequestHeader::ParseHeaders(net::InputSink *sink) { + int line_count = 0; + std::string line; + while (sink->ReadLine(line)) { + if (line.length() == 0) { + // Blank line, this means end of headers. + break; + } + + ParseHttpHeader(line.c_str()); + line_count++; + if (type == SIMPLE) { + // Done! + ILOG("Simple: Done parsing http request."); + break; + } + } + + ILOG("finished parsing request."); + ok = line_count > 1; } } // namespace http diff --git a/ext/native/net/http_headers.h b/ext/native/net/http_headers.h index d333e1f91a..470f589972 100644 --- a/ext/native/net/http_headers.h +++ b/ext/native/net/http_headers.h @@ -1,8 +1,14 @@ #ifndef _NET_HTTP_HTTP_HEADERS #define _NET_HTTP_HTTP_HEADERS +#include +#include #include "base/buffer.h" +namespace net { +class InputSink; +}; + namespace http { class RequestHeader { @@ -12,11 +18,13 @@ class RequestHeader { // Public variables since it doesn't make sense // to bother with accessors for all these. int status; + // Intentional misspelling. char *referer; char *user_agent; char *resource; char *params; int content_length; + std::unordered_map other; enum RequestType { SIMPLE, FULL, }; @@ -29,8 +37,9 @@ class RequestHeader { }; Method method; bool ok; - void ParseHeaders(int fd); + void ParseHeaders(net::InputSink *sink); bool GetParamValue(const char *param_name, std::string *value) const; + bool GetOther(const char *name, std::string *value) const; private: int ParseHttpHeader(const char *buffer); bool first_header_; diff --git a/ext/native/net/http_server.cpp b/ext/native/net/http_server.cpp index 68170251ea..c5a6ccd065 100644 --- a/ext/native/net/http_server.cpp +++ b/ext/native/net/http_server.cpp @@ -2,6 +2,7 @@ #ifdef _WIN32 +#define NOMINMAX #include #include #include @@ -15,8 +16,11 @@ #include /* inet (3) funtions */ #include /* misc. UNIX functions */ +#define closesocket close + #endif +#include #include #include @@ -25,50 +29,74 @@ #include "base/buffer.h" #include "file/fd_util.h" #include "net/http_server.h" +#include "net/sinks.h" #include "thread/executor.h" namespace http { +// Note: charset here helps prevent XSS. +const char *const DEFAULT_MIME_TYPE = "text/html; charset=utf-8"; + Request::Request(int fd) : fd_(fd) { - in_buffer_ = new Buffer; - out_buffer_ = new Buffer; - header_.ParseHeaders(fd_); + in_ = new net::InputSink(fd); + out_ = new net::OutputSink(fd); + header_.ParseHeaders(in_); - if (header_.ok) { - // Read the rest, too. - if (header_.content_length >= 0) { - in_buffer_->Read(fd_, header_.content_length); - } - ILOG("The request carried with it %i bytes", (int)in_buffer_->size()); - } else { - Close(); - } + if (header_.ok) { + ILOG("The request carried with it %i bytes", (int)header_.content_length); + } else { + Close(); + } } Request::~Request() { - Close(); + Close(); - CHECK(in_buffer_->empty()); - delete in_buffer_; - CHECK(out_buffer_->empty()); - delete out_buffer_; + CHECK(in_->Empty()); + delete in_; + CHECK(out_->Empty()); + delete out_; } -void Request::WriteHttpResponseHeader(int status, int size) const { - Buffer *buffer = out_buffer_; - buffer->Printf("HTTP/1.0 %d OK\r\n", status); - buffer->Append("Server: SuperDuperServer v0.1\r\n"); - buffer->Append("Content-Type: text/html\r\n"); - if (size >= 0) { - buffer->Printf("Content-Length: %i\r\n", size); - } - buffer->Append("\r\n"); +void Request::WriteHttpResponseHeader(int status, int64_t size, const char *mimeType, const char *otherHeaders) const { + const char *statusStr; + switch (status) { + case 200: statusStr = "OK"; break; + case 206: statusStr = "Partial Content"; break; + case 301: statusStr = "Moved Permanently"; break; + case 302: statusStr = "Found"; break; + case 304: statusStr = "Not Modified"; break; + case 400: statusStr = "Bad Request"; break; + case 403: statusStr = "Forbidden"; break; + case 404: statusStr = "Not Found"; break; + case 405: statusStr = "Method Not Allowed"; break; + case 406: statusStr = "Not Acceptable"; break; + case 410: statusStr = "Gone"; break; + case 416: statusStr = "Range Not Satisfiable"; break; + case 418: statusStr = "I'm a teapot"; break; + case 500: statusStr = "Internal Server Error"; break; + case 503: statusStr = "Service Unavailable"; break; + default: statusStr = "OK"; break; + } + + net::OutputSink *buffer = Out(); + buffer->Printf("HTTP/1.0 %03d %s\r\n", status, statusStr); + buffer->Push("Server: SuperDuperServer v0.1\r\n"); + buffer->Printf("Content-Type: %s\r\n", mimeType ? mimeType : DEFAULT_MIME_TYPE); + buffer->Push("Connection: close\r\n"); + if (size >= 0) { + buffer->Printf("Content-Length: %llu\r\n", size); + } + if (otherHeaders) { + buffer->Push(otherHeaders, (int)strlen(otherHeaders)); + } + buffer->Push("\r\n"); } void Request::WritePartial() const { CHECK(fd_); - out_buffer_->Flush(fd_); + out_->Flush(); } void Request::Write() { @@ -79,7 +107,7 @@ void Request::Write() { void Request::Close() { if (fd_) { - close(fd_); + closesocket(fd_); fd_ = 0; } } @@ -87,12 +115,17 @@ void Request::Close() { Server::Server(threading::Executor *executor) : port_(0), executor_(executor) { RegisterHandler("/", std::bind(&Server::HandleListing, this, placeholder::_1)); + SetFallbackHandler(std::bind(&Server::Handle404, this, placeholder::_1)); } void Server::RegisterHandler(const char *url_path, UrlHandlerFunc handler) { handlers_[std::string(url_path)] = handler; } +void Server::SetFallbackHandler(UrlHandlerFunc handler) { + fallback_ = handler; +} + bool Server::Run(int port) { ILOG("HTTP server started on port %i", port); port_ = port; @@ -139,31 +172,41 @@ void Server::HandleConnection(int conn_fd) { return; } HandleRequestDefault(request); - request.WritePartial(); + + // TODO: Way to mark the content body as read, read it here if never read. + // This allows the handler to stream if need be. + + // TODO: Could handle keep alive here. + request.Write(); } void Server::HandleRequest(const Request &request) { - HandleRequestDefault(request); + HandleRequestDefault(request); } void Server::HandleRequestDefault(const Request &request) { - // First, look through all handlers. If we got one, use it. - for (auto iter = handlers_.begin(); iter != handlers_.end(); ++iter) { - if (iter->first == request.resource()) { - (iter->second)(request); - return; - } - } - ILOG("No handler for '%s', falling back to 404.", request.resource()); - const char *payload = "404 not found\r\n"; - request.WriteHttpResponseHeader(404, (int)strlen(payload)); - request.out_buffer()->Append(payload); + // First, look through all handlers. If we got one, use it. + auto handler = handlers_.find(request.resource()); + if (handler != handlers_.end()) { + (handler->second)(request); + } else { + // Let's hit the 404 handler instead. + fallback_(request); + } +} + +void Server::Handle404(const Request &request) { + ILOG("No handler for '%s', falling back to 404.", request.resource()); + const char *payload = "404 not found\r\n"; + request.WriteHttpResponseHeader(404, (int)strlen(payload)); + request.Out()->Push(payload); } void Server::HandleListing(const Request &request) { - for (auto iter = handlers_.begin(); iter != handlers_.end(); ++iter) { - request.out_buffer()->Printf("%s", iter->first.c_str()); - } + request.WriteHttpResponseHeader(200, -1); + for (auto iter = handlers_.begin(); iter != handlers_.end(); ++iter) { + request.Out()->Printf("%s\n", iter->first.c_str()); + } } } // namespace http diff --git a/ext/native/net/http_server.h b/ext/native/net/http_server.h index 5d45722671..6ff92dc0f0 100644 --- a/ext/native/net/http_server.h +++ b/ext/native/net/http_server.h @@ -8,6 +8,11 @@ #include "net/http_headers.h" #include "thread/executor.h" +namespace net { +class InputSink; +class OutputSink; +}; + namespace http { class Request { @@ -19,12 +24,20 @@ class Request { return header_.resource; } + RequestHeader::Method Method() const { + return header_.method; + } + bool GetParamValue(const char *param_name, std::string *value) const { return header_.GetParamValue(param_name, value); } + // Use lowercase. + bool GetHeader(const char *name, std::string *value) const { + return header_.GetOther(name, value); + } - Buffer *in_buffer() const { return in_buffer_; } - Buffer *out_buffer() const { return out_buffer_; } + net::InputSink *In() const { return in_; } + net::OutputSink *Out() const { return out_; } // TODO: Remove, in favor of PartialWrite and friends. int fd() const { return fd_; } @@ -36,13 +49,13 @@ class Request { bool IsOK() const { return fd_ > 0; } // If size is negative, no Content-Length: line is written. - void WriteHttpResponseHeader(int status, int size = -1) const; + void WriteHttpResponseHeader(int status, int64_t size = -1, const char *mimeType = nullptr, const char *otherHeaders = nullptr) const; - private: - Buffer *in_buffer_; - Buffer *out_buffer_; - RequestHeader header_; - int fd_; +private: + net::InputSink *in_; + net::OutputSink *out_; + RequestHeader header_; + int fd_; }; // Register handlers on this class to serve stuff. @@ -59,6 +72,7 @@ class Server { bool Run(int port); void RegisterHandler(const char *url_path, UrlHandlerFunc handler); + void SetFallbackHandler(UrlHandlerFunc handler); // If you want to customize things at a lower level than just a simple path handler, // then inherit and override this. Implementations should forward to HandleRequestDefault @@ -75,10 +89,12 @@ class Server { // Neat built-in handlers that are tied to the server. void HandleListing(const Request &request); + void Handle404(const Request &request); int port_; UrlHandlerMap handlers_; + UrlHandlerFunc fallback_; threading::Executor *executor_; }; diff --git a/ext/native/net/sinks.cpp b/ext/native/net/sinks.cpp new file mode 100644 index 0000000000..397ff76ca4 --- /dev/null +++ b/ext/native/net/sinks.cpp @@ -0,0 +1,373 @@ +#pragma optimize("", off) + +#ifdef _WIN32 + +#define NOMINMAX +#include +#include +#include + +#else + +#include /* socket definitions */ +#include /* socket types */ +#include /* for waitpid() */ +#include /* struct sockaddr_in */ +#include /* inet (3) funtions */ +#include /* misc. UNIX functions */ + +#endif + +#include +#include + +#include "base/logging.h" +#include "net/sinks.h" +#include "file/fd_util.h" + +namespace net { + +InputSink::InputSink(size_t fd) : fd_(fd), read_(0), write_(0), valid_(0) { + fd_util::SetNonBlocking((int)fd_, true); +} + +bool InputSink::ReadLineWithEnding(std::string &s) { + size_t newline = FindNewline(); + if (newline == BUFFER_SIZE) { + Block(); + newline = FindNewline(); + } + if (newline == BUFFER_SIZE) { + // Timed out. + return false; + } + + s.resize(newline + 1); + if (read_ + newline + 1 > BUFFER_SIZE) { + // Need to do two reads. + size_t chunk1 = BUFFER_SIZE - read_; + size_t chunk2 = read_ + newline + 1 - BUFFER_SIZE; + memcpy(&s[0], buf_ + read_, chunk1); + memcpy(&s[chunk1], buf_, chunk2); + } else { + memcpy(&s[0], buf_ + read_, newline + 1); + } + AccountDrain(newline + 1); + + return true; +} + +std::string InputSink::ReadLineWithEnding() { + std::string s; + ReadLineWithEnding(s); + return s; +} + +bool InputSink::ReadLine(std::string &s) { + bool result = ReadLineWithEnding(s); + if (result) { + size_t l = s.length(); + if (l >= 2 && s[l - 2] == '\r' && s[l - 1] == '\n') { + s.resize(l - 2); + } else if (l >= 1 && s[l - 1] == '\n') { + s.resize(l - 1); + } + } + return result; +} + +std::string InputSink::ReadLine() { + std::string s; + ReadLine(s); + return s; +} + +size_t InputSink::FindNewline() const { + // Technically, \r\n, but most parsers are lax... let's follow suit. + size_t until_end = std::min(valid_, BUFFER_SIZE - read_); + for (const char *p = buf_ + read_, *end = buf_ + read_ + until_end; p < end; ++p) { + if (*p == '\n') { + return p - (buf_ + read_); + } + } + + // Were there more bytes wrapped around? + if (read_ + valid_ > BUFFER_SIZE) { + size_t wrapped = read_ + valid_ - BUFFER_SIZE; + for (const char *p = buf_, *end = buf_ + wrapped; p < end; ++p) { + if (*p == '\n') { + // Offset by the skipped portion before wrapping. + return (p - buf_) + until_end; + } + } + } + + // Never found, return an invalid position to indicate. + return BUFFER_SIZE; +} + +bool InputSink::TakeExact(char *buf, size_t bytes) { + while (bytes > 0) { + size_t drained = TakeAtMost(buf, bytes); + buf += drained; + bytes -= drained; + + if (drained == 0) { + if (!Block()) { + // Timed out reading more bytes. + return false; + } + } + } + + return true; +} + +size_t InputSink::TakeAtMost(char *buf, size_t bytes) { + Fill(); + + // The least of: contiguous to read, actually populated in buffer, and wanted. + size_t avail = std::min(std::min(BUFFER_SIZE - read_, valid_), bytes); + + if (avail != 0) { + memcpy(buf, buf_ + read_, avail); + AccountDrain(avail); + } + + return avail; +} + +bool InputSink::Skip(size_t bytes) { + while (bytes > 0) { + size_t drained = std::min(valid_, bytes); + AccountDrain(drained); + bytes -= drained; + + // Nothing left to read? Get more. + if (drained == 0) { + if (!Block()) { + // Timed out reading more bytes. + return false; + } + } + } + + return true; +} + +void InputSink::Fill() { + // Avoid small reads if possible. + if (BUFFER_SIZE - valid_ > PRESSURE) { + // Whatever isn't valid and follows write_ is what's available. + size_t avail = BUFFER_SIZE - std::max(write_, valid_); + + int bytes = recv(fd_, buf_ + write_, (int)avail, 0); + AccountFill(bytes); + } +} + +bool InputSink::Block() { + if (!fd_util::WaitUntilReady((int)fd_, 5.0)) { + return false; + } + + Fill(); + return true; +} + +void InputSink::AccountFill(int bytes) { + if (bytes < 0) { + ELOG("Error reading from socket"); + return; + } + + // Okay, move forward (might be by zero.) + valid_ += bytes; + write_ += bytes; + if (write_ >= BUFFER_SIZE) { + write_ -= BUFFER_SIZE; + } +} + +void InputSink::AccountDrain(size_t bytes) { + valid_ -= bytes; + read_ += bytes; + if (read_ >= BUFFER_SIZE) { + read_ -= BUFFER_SIZE; + } +} + +bool InputSink::Empty() { + return valid_ == 0; +} + +OutputSink::OutputSink(size_t fd) : fd_(fd), read_(0), write_(0), valid_(0) { + fd_util::SetNonBlocking((int)fd_, true); +} + +bool OutputSink::Push(const std::string &s) { + return Push(&s[0], s.length()); +} + +bool OutputSink::Push(const char *buf, size_t bytes) { + while (bytes > 0) { + size_t pushed = PushAtMost(buf, bytes); + buf += pushed; + bytes -= pushed; + + if (pushed == 0) { + if (!Block()) { + // We couldn't write all the bytes. + return false; + } + } + } + + return true; +} + +bool OutputSink::PushCRLF(const std::string &s) { + if (Push(s)) { + return Push("r\n", 2); + } + return false; +} + +size_t OutputSink::PushAtMost(const char *buf, size_t bytes) { + Drain(); + + if (valid_ == 0 && bytes > PRESSURE) { + // Special case for pushing larger buffers: let's try to send directly. + int sentBytes = send(fd_, buf, (int)bytes, 0); + // If it was 0 or EWOULDBLOCK, that's fine, we'll enqueue as we can. + if (sentBytes > 0) { + return sentBytes; + } + } + + // Look for contiguous free space after write_ that's valid. + size_t avail = std::min(BUFFER_SIZE - std::max(write_, valid_), bytes); + + if (avail != 0) { + memcpy(buf_ + write_, buf, avail); + AccountPush(avail); + } + + return avail; +} + + +bool OutputSink::Printf(const char *fmt, ...) { + // Let's start by checking how much space we have. + size_t avail = BUFFER_SIZE - std::max(write_, valid_); + + va_list args; + va_start(args, fmt); + // Make a backup in case we don't have sufficient space. + va_list backup; + va_copy(backup, args); + + bool success = true; + + int result = vsnprintf(buf_ + write_, avail, fmt, args); + if (result >= avail) { + // There wasn't enough space. Let's use a buffer instead. + // This could be caused by wraparound. + char temp[BUFFER_SIZE]; + result = vsnprintf(temp, BUFFER_SIZE, fmt, args); + + if (result < BUFFER_SIZE && result > 0) { + // In case it did return the null terminator. + if (temp[result - 1] == '\0') { + result--; + } + + success = Push(temp, result); + // We've written so there's nothing more. + result = 0; + } + } + va_end(args); + va_end(backup); + + // Okay, did we actually write? + if (result >= avail) { + // This means the result string was too big for the buffer. + ELOG("Not enough space to format output."); + return false; + } else if (result < 0) { + ELOG("vsnprintf failed."); + return false; + } + + if (result > 0) { + AccountPush(result); + } + + return success; +} + +bool OutputSink::Block() { + if (!fd_util::WaitUntilReady((int)fd_, 5.0, true)) { + return false; + } + + Drain(); + return true; +} + +bool OutputSink::Flush() { + while (valid_ > 0) { + size_t avail = std::min(BUFFER_SIZE - read_, valid_); + + int bytes = send(fd_, buf_ + read_, (int)avail, 0); + AccountDrain(bytes); + + if (bytes == 0) { + // This may also drain. Either way, keep looping. + if (!Block()) { + return false; + } + } + } + + return true; +} + +void OutputSink::Drain() { + // Avoid small reads if possible. + if (valid_ > PRESSURE) { + // Let's just do contiguous valid. + size_t avail = std::min(BUFFER_SIZE - read_, valid_); + + int bytes = send(fd_, buf_ + read_, (int)avail, 0); + AccountDrain(bytes); + } +} + +void OutputSink::AccountPush(size_t bytes) { + valid_ += bytes; + write_ += bytes; + if (write_ >= BUFFER_SIZE) { + write_ -= BUFFER_SIZE; + } +} + +void OutputSink::AccountDrain(int bytes) { + if (bytes < 0) { + ELOG("Error writing to socket"); + return; + } + + valid_ -= bytes; + read_ += bytes; + if (read_ >= BUFFER_SIZE) { + read_ -= BUFFER_SIZE; + } +} + +bool OutputSink::Empty() { + return valid_ == 0; +} + +}; diff --git a/ext/native/net/sinks.h b/ext/native/net/sinks.h new file mode 100644 index 0000000000..0138d74a6c --- /dev/null +++ b/ext/native/net/sinks.h @@ -0,0 +1,72 @@ +#pragma once + +#include + +namespace net { + +class InputSink { +public: + InputSink(size_t fd); + + bool ReadLine(std::string &s); + std::string ReadLine(); + bool ReadLineWithEnding(std::string &s); + std::string ReadLineWithEnding(); + + // Read exactly this number of bytes, or fail. + bool TakeExact(char *buf, size_t bytes); + // Read whatever is convenient (may even return 0 bytes when there's more coming eventually.) + size_t TakeAtMost(char *buf, size_t bytes); + // Skip exactly this number of bytes, or fail. + bool Skip(size_t bytes); + + bool Empty(); + +private: + void Fill(); + bool Block(); + void AccountFill(int bytes); + void AccountDrain(size_t bytes); + size_t FindNewline() const; + + static const size_t BUFFER_SIZE = 32 * 1024; + static const size_t PRESSURE = 8 * 1024; + + size_t fd_; + char buf_[BUFFER_SIZE]; + size_t read_; + size_t write_; + size_t valid_; +}; + +class OutputSink { +public: + OutputSink(size_t fd); + + bool Push(const std::string &s); + bool Push(const char *buf, size_t bytes); + size_t PushAtMost(const char *buf, size_t bytes); + bool PushCRLF(const std::string &s); + bool Printf(const char *fmt, ...); + + bool Flush(); + + bool Empty(); + +private: + void Drain(); + bool Block(); + void AccountPush(size_t bytes); + void AccountDrain(int bytes); + + static const size_t BUFFER_SIZE = 32 * 1024; + static const size_t PRESSURE = 8 * 1024; + + size_t fd_; + char buf_[BUFFER_SIZE]; + size_t read_; + size_t write_; + size_t valid_; +}; + +};