Updates:
- Address PR feedback. - Add SecureTransport backend for macOS.
This commit is contained in:
parent
98685d48e3
commit
0e191c2711
|
@ -64,8 +64,9 @@ option(YUZU_DOWNLOAD_TIME_ZONE_DATA "Always download time zone binaries" OFF)
|
||||||
CMAKE_DEPENDENT_OPTION(YUZU_USE_FASTER_LD "Check if a faster linker is available" ON "NOT WIN32" OFF)
|
CMAKE_DEPENDENT_OPTION(YUZU_USE_FASTER_LD "Check if a faster linker is available" ON "NOT WIN32" OFF)
|
||||||
|
|
||||||
set(DEFAULT_ENABLE_OPENSSL ON)
|
set(DEFAULT_ENABLE_OPENSSL ON)
|
||||||
if (ANDROID OR WIN32)
|
if (ANDROID OR WIN32 OR APPLE)
|
||||||
# - Windows defaults to the Schannel backend.
|
# - Windows defaults to the Schannel backend.
|
||||||
|
# - macOS defaults to the SecureTransport backend.
|
||||||
# - Android currently has no SSL backend as the NDK doesn't include any SSL
|
# - Android currently has no SSL backend as the NDK doesn't include any SSL
|
||||||
# library; a proper 'native' backend would have to go through Java.
|
# library; a proper 'native' backend would have to go through Java.
|
||||||
# But you can force builds for those platforms to use OpenSSL if you have
|
# But you can force builds for those platforms to use OpenSSL if you have
|
||||||
|
|
|
@ -868,6 +868,10 @@ if(ENABLE_OPENSSL)
|
||||||
target_sources(core PRIVATE
|
target_sources(core PRIVATE
|
||||||
hle/service/ssl/ssl_backend_openssl.cpp)
|
hle/service/ssl/ssl_backend_openssl.cpp)
|
||||||
target_link_libraries(core PRIVATE OpenSSL::SSL)
|
target_link_libraries(core PRIVATE OpenSSL::SSL)
|
||||||
|
elseif (APPLE)
|
||||||
|
target_sources(core PRIVATE
|
||||||
|
hle/service/ssl/ssl_backend_securetransport.cpp)
|
||||||
|
target_link_libraries(core PRIVATE "-framework Security")
|
||||||
elseif (WIN32)
|
elseif (WIN32)
|
||||||
target_sources(core PRIVATE
|
target_sources(core PRIVATE
|
||||||
hle/service/ssl/ssl_backend_schannel.cpp)
|
hle/service/ssl/ssl_backend_schannel.cpp)
|
||||||
|
|
|
@ -443,15 +443,28 @@ void BSD::Close(HLERequestContext& ctx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void BSD::DuplicateSocket(HLERequestContext& ctx) {
|
void BSD::DuplicateSocket(HLERequestContext& ctx) {
|
||||||
IPC::RequestParser rp{ctx};
|
struct InputParameters {
|
||||||
const s32 fd = rp.Pop<s32>();
|
s32 fd;
|
||||||
[[maybe_unused]] const u64 unused = rp.Pop<u64>();
|
u64 reserved;
|
||||||
|
};
|
||||||
|
static_assert(sizeof(InputParameters) == 0x10);
|
||||||
|
|
||||||
Expected<s32, Errno> res = DuplicateSocketImpl(fd);
|
struct OutputParameters {
|
||||||
|
s32 ret;
|
||||||
|
Errno bsd_errno;
|
||||||
|
};
|
||||||
|
static_assert(sizeof(OutputParameters) == 0x8);
|
||||||
|
|
||||||
|
IPC::RequestParser rp{ctx};
|
||||||
|
auto input = rp.PopRaw<InputParameters>();
|
||||||
|
|
||||||
|
Expected<s32, Errno> res = DuplicateSocketImpl(input.fd);
|
||||||
IPC::ResponseBuilder rb{ctx, 4};
|
IPC::ResponseBuilder rb{ctx, 4};
|
||||||
rb.Push(ResultSuccess);
|
rb.Push(ResultSuccess);
|
||||||
rb.Push(res.value_or(0)); // ret
|
rb.PushRaw(OutputParameters{
|
||||||
rb.Push(res ? 0 : static_cast<s32>(res.error())); // bsd errno
|
.ret = res.value_or(0),
|
||||||
|
.bsd_errno = res ? Errno::SUCCESS : res.error(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void BSD::EventFd(HLERequestContext& ctx) {
|
void BSD::EventFd(HLERequestContext& ctx) {
|
||||||
|
|
|
@ -131,14 +131,15 @@ static std::vector<u8> SerializeAddrInfoAsHostEnt(const std::vector<Network::Add
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::pair<u32, GetAddrInfoError> GetHostByNameRequestImpl(HLERequestContext& ctx) {
|
static std::pair<u32, GetAddrInfoError> GetHostByNameRequestImpl(HLERequestContext& ctx) {
|
||||||
struct Parameters {
|
struct InputParameters {
|
||||||
u8 use_nsd_resolve;
|
u8 use_nsd_resolve;
|
||||||
u32 cancel_handle;
|
u32 cancel_handle;
|
||||||
u64 process_id;
|
u64 process_id;
|
||||||
};
|
};
|
||||||
|
static_assert(sizeof(InputParameters) == 0x10);
|
||||||
|
|
||||||
IPC::RequestParser rp{ctx};
|
IPC::RequestParser rp{ctx};
|
||||||
const auto parameters = rp.PopRaw<Parameters>();
|
const auto parameters = rp.PopRaw<InputParameters>();
|
||||||
|
|
||||||
LOG_WARNING(
|
LOG_WARNING(
|
||||||
Service,
|
Service,
|
||||||
|
@ -164,21 +165,39 @@ static std::pair<u32, GetAddrInfoError> GetHostByNameRequestImpl(HLERequestConte
|
||||||
void SFDNSRES::GetHostByNameRequest(HLERequestContext& ctx) {
|
void SFDNSRES::GetHostByNameRequest(HLERequestContext& ctx) {
|
||||||
auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
|
auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
|
||||||
|
|
||||||
|
struct OutputParameters {
|
||||||
|
NetDbError netdb_error;
|
||||||
|
Errno bsd_errno;
|
||||||
|
u32 data_size;
|
||||||
|
};
|
||||||
|
static_assert(sizeof(OutputParameters) == 0xc);
|
||||||
|
|
||||||
IPC::ResponseBuilder rb{ctx, 5};
|
IPC::ResponseBuilder rb{ctx, 5};
|
||||||
rb.Push(ResultSuccess);
|
rb.Push(ResultSuccess);
|
||||||
rb.Push(static_cast<s32>(GetAddrInfoErrorToNetDbError(emu_gai_err))); // netdb error code
|
rb.PushRaw(OutputParameters{
|
||||||
rb.Push(static_cast<s32>(GetAddrInfoErrorToErrno(emu_gai_err))); // errno
|
.netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
|
||||||
rb.Push(data_size); // serialized size
|
.bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
|
||||||
|
.data_size = data_size,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void SFDNSRES::GetHostByNameRequestWithOptions(HLERequestContext& ctx) {
|
void SFDNSRES::GetHostByNameRequestWithOptions(HLERequestContext& ctx) {
|
||||||
auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
|
auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
|
||||||
|
|
||||||
|
struct OutputParameters {
|
||||||
|
u32 data_size;
|
||||||
|
NetDbError netdb_error;
|
||||||
|
Errno bsd_errno;
|
||||||
|
};
|
||||||
|
static_assert(sizeof(OutputParameters) == 0xc);
|
||||||
|
|
||||||
IPC::ResponseBuilder rb{ctx, 5};
|
IPC::ResponseBuilder rb{ctx, 5};
|
||||||
rb.Push(ResultSuccess);
|
rb.Push(ResultSuccess);
|
||||||
rb.Push(data_size); // serialized size
|
rb.PushRaw(OutputParameters{
|
||||||
rb.Push(static_cast<s32>(GetAddrInfoErrorToNetDbError(emu_gai_err))); // netdb error code
|
.data_size = data_size,
|
||||||
rb.Push(static_cast<s32>(GetAddrInfoErrorToErrno(emu_gai_err))); // errno
|
.netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
|
||||||
|
.bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<u8> SerializeAddrInfo(const std::vector<Network::AddrInfo>& vec,
|
static std::vector<u8> SerializeAddrInfo(const std::vector<Network::AddrInfo>& vec,
|
||||||
|
@ -221,14 +240,15 @@ static std::vector<u8> SerializeAddrInfo(const std::vector<Network::AddrInfo>& v
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::pair<u32, GetAddrInfoError> GetAddrInfoRequestImpl(HLERequestContext& ctx) {
|
static std::pair<u32, GetAddrInfoError> GetAddrInfoRequestImpl(HLERequestContext& ctx) {
|
||||||
struct Parameters {
|
struct InputParameters {
|
||||||
u8 use_nsd_resolve;
|
u8 use_nsd_resolve;
|
||||||
u32 cancel_handle;
|
u32 cancel_handle;
|
||||||
u64 process_id;
|
u64 process_id;
|
||||||
};
|
};
|
||||||
|
static_assert(sizeof(InputParameters) == 0x10);
|
||||||
|
|
||||||
IPC::RequestParser rp{ctx};
|
IPC::RequestParser rp{ctx};
|
||||||
const auto parameters = rp.PopRaw<Parameters>();
|
const auto parameters = rp.PopRaw<InputParameters>();
|
||||||
|
|
||||||
LOG_WARNING(
|
LOG_WARNING(
|
||||||
Service,
|
Service,
|
||||||
|
@ -264,23 +284,42 @@ static std::pair<u32, GetAddrInfoError> GetAddrInfoRequestImpl(HLERequestContext
|
||||||
void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) {
|
void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) {
|
||||||
auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
|
auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
|
||||||
|
|
||||||
|
struct OutputParameters {
|
||||||
|
Errno bsd_errno;
|
||||||
|
GetAddrInfoError gai_error;
|
||||||
|
u32 data_size;
|
||||||
|
};
|
||||||
|
static_assert(sizeof(OutputParameters) == 0xc);
|
||||||
|
|
||||||
IPC::ResponseBuilder rb{ctx, 5};
|
IPC::ResponseBuilder rb{ctx, 5};
|
||||||
rb.Push(ResultSuccess);
|
rb.Push(ResultSuccess);
|
||||||
rb.Push(static_cast<s32>(GetAddrInfoErrorToErrno(emu_gai_err))); // errno
|
rb.PushRaw(OutputParameters{
|
||||||
rb.Push(static_cast<s32>(emu_gai_err)); // getaddrinfo error code
|
.bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
|
||||||
rb.Push(data_size); // serialized size
|
.gai_error = emu_gai_err,
|
||||||
|
.data_size = data_size,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) {
|
void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) {
|
||||||
// Additional options are ignored
|
// Additional options are ignored
|
||||||
auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
|
auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
|
||||||
|
|
||||||
|
struct OutputParameters {
|
||||||
|
u32 data_size;
|
||||||
|
GetAddrInfoError gai_error;
|
||||||
|
NetDbError netdb_error;
|
||||||
|
Errno bsd_errno;
|
||||||
|
};
|
||||||
|
static_assert(sizeof(OutputParameters) == 0x10);
|
||||||
|
|
||||||
IPC::ResponseBuilder rb{ctx, 6};
|
IPC::ResponseBuilder rb{ctx, 6};
|
||||||
rb.Push(ResultSuccess);
|
rb.Push(ResultSuccess);
|
||||||
rb.Push(data_size); // serialized size
|
rb.PushRaw(OutputParameters{
|
||||||
rb.Push(static_cast<s32>(emu_gai_err)); // getaddrinfo error code
|
.data_size = data_size,
|
||||||
rb.Push(static_cast<s32>(GetAddrInfoErrorToNetDbError(emu_gai_err))); // netdb error code
|
.gai_error = emu_gai_err,
|
||||||
rb.Push(static_cast<s32>(GetAddrInfoErrorToErrno(emu_gai_err))); // errno
|
.netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
|
||||||
|
.bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void SFDNSRES::ResolverSetOptionRequest(HLERequestContext& ctx) {
|
void SFDNSRES::ResolverSetOptionRequest(HLERequestContext& ctx) {
|
||||||
|
|
|
@ -64,7 +64,7 @@ public:
|
||||||
std::shared_ptr<SslContextSharedData>& shared_data,
|
std::shared_ptr<SslContextSharedData>& shared_data,
|
||||||
std::unique_ptr<SSLConnectionBackend>&& backend)
|
std::unique_ptr<SSLConnectionBackend>&& backend)
|
||||||
: ServiceFramework{system_, "ISslConnection"}, ssl_version{version},
|
: ServiceFramework{system_, "ISslConnection"}, ssl_version{version},
|
||||||
shared_data_{shared_data}, backend_{std::move(backend)} {
|
shared_data{shared_data}, backend{std::move(backend)} {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
static const FunctionInfo functions[] = {
|
static const FunctionInfo functions[] = {
|
||||||
{0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"},
|
{0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"},
|
||||||
|
@ -112,10 +112,10 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
~ISslConnection() {
|
~ISslConnection() {
|
||||||
shared_data_->connection_count--;
|
shared_data->connection_count--;
|
||||||
if (fd_to_close_.has_value()) {
|
if (fd_to_close.has_value()) {
|
||||||
const s32 fd = *fd_to_close_;
|
const s32 fd = *fd_to_close;
|
||||||
if (!do_not_close_socket_) {
|
if (!do_not_close_socket) {
|
||||||
LOG_ERROR(Service_SSL,
|
LOG_ERROR(Service_SSL,
|
||||||
"do_not_close_socket was changed after setting socket; is this right?");
|
"do_not_close_socket was changed after setting socket; is this right?");
|
||||||
} else {
|
} else {
|
||||||
|
@ -132,30 +132,30 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SslVersion ssl_version;
|
SslVersion ssl_version;
|
||||||
std::shared_ptr<SslContextSharedData> shared_data_;
|
std::shared_ptr<SslContextSharedData> shared_data;
|
||||||
std::unique_ptr<SSLConnectionBackend> backend_;
|
std::unique_ptr<SSLConnectionBackend> backend;
|
||||||
std::optional<int> fd_to_close_;
|
std::optional<int> fd_to_close;
|
||||||
bool do_not_close_socket_ = false;
|
bool do_not_close_socket = false;
|
||||||
bool get_server_cert_chain_ = false;
|
bool get_server_cert_chain = false;
|
||||||
std::shared_ptr<Network::SocketBase> socket_;
|
std::shared_ptr<Network::SocketBase> socket;
|
||||||
bool did_set_host_name_ = false;
|
bool did_set_host_name = false;
|
||||||
bool did_handshake_ = false;
|
bool did_handshake = false;
|
||||||
|
|
||||||
ResultVal<s32> SetSocketDescriptorImpl(s32 fd) {
|
ResultVal<s32> SetSocketDescriptorImpl(s32 fd) {
|
||||||
LOG_DEBUG(Service_SSL, "called, fd={}", fd);
|
LOG_DEBUG(Service_SSL, "called, fd={}", fd);
|
||||||
ASSERT(!did_handshake_);
|
ASSERT(!did_handshake);
|
||||||
auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
|
auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
|
||||||
ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
|
ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
|
||||||
s32 ret_fd;
|
s32 ret_fd;
|
||||||
// Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
|
// Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
|
||||||
if (do_not_close_socket_) {
|
if (do_not_close_socket) {
|
||||||
auto res = bsd->DuplicateSocketImpl(fd);
|
auto res = bsd->DuplicateSocketImpl(fd);
|
||||||
if (!res.has_value()) {
|
if (!res.has_value()) {
|
||||||
LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd);
|
LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd);
|
||||||
return ResultInvalidSocket;
|
return ResultInvalidSocket;
|
||||||
}
|
}
|
||||||
fd = *res;
|
fd = *res;
|
||||||
fd_to_close_ = fd;
|
fd_to_close = fd;
|
||||||
ret_fd = fd;
|
ret_fd = fd;
|
||||||
} else {
|
} else {
|
||||||
ret_fd = -1;
|
ret_fd = -1;
|
||||||
|
@ -165,34 +165,34 @@ private:
|
||||||
LOG_ERROR(Service_SSL, "invalid socket fd {}", fd);
|
LOG_ERROR(Service_SSL, "invalid socket fd {}", fd);
|
||||||
return ResultInvalidSocket;
|
return ResultInvalidSocket;
|
||||||
}
|
}
|
||||||
socket_ = std::move(*sock);
|
socket = std::move(*sock);
|
||||||
backend_->SetSocket(socket_);
|
backend->SetSocket(socket);
|
||||||
return ret_fd;
|
return ret_fd;
|
||||||
}
|
}
|
||||||
|
|
||||||
Result SetHostNameImpl(const std::string& hostname) {
|
Result SetHostNameImpl(const std::string& hostname) {
|
||||||
LOG_DEBUG(Service_SSL, "called. hostname={}", hostname);
|
LOG_DEBUG(Service_SSL, "called. hostname={}", hostname);
|
||||||
ASSERT(!did_handshake_);
|
ASSERT(!did_handshake);
|
||||||
Result res = backend_->SetHostName(hostname);
|
Result res = backend->SetHostName(hostname);
|
||||||
if (res == ResultSuccess) {
|
if (res == ResultSuccess) {
|
||||||
did_set_host_name_ = true;
|
did_set_host_name = true;
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
Result SetVerifyOptionImpl(u32 option) {
|
Result SetVerifyOptionImpl(u32 option) {
|
||||||
ASSERT(!did_handshake_);
|
ASSERT(!did_handshake);
|
||||||
LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option);
|
LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option);
|
||||||
return ResultSuccess;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
Result SetIOModeImpl(u32 _mode) {
|
Result SetIoModeImpl(u32 input_mode) {
|
||||||
auto mode = static_cast<IoMode>(_mode);
|
auto mode = static_cast<IoMode>(input_mode);
|
||||||
ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking);
|
ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking);
|
||||||
ASSERT_OR_EXECUTE(socket_, { return ResultNoSocket; });
|
ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; });
|
||||||
|
|
||||||
const bool non_block = mode == IoMode::NonBlocking;
|
const bool non_block = mode == IoMode::NonBlocking;
|
||||||
const Network::Errno error = socket_->SetNonBlock(non_block);
|
const Network::Errno error = socket->SetNonBlock(non_block);
|
||||||
if (error != Network::Errno::SUCCESS) {
|
if (error != Network::Errno::SUCCESS) {
|
||||||
LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block);
|
LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block);
|
||||||
}
|
}
|
||||||
|
@ -200,18 +200,18 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
Result SetSessionCacheModeImpl(u32 mode) {
|
Result SetSessionCacheModeImpl(u32 mode) {
|
||||||
ASSERT(!did_handshake_);
|
ASSERT(!did_handshake);
|
||||||
LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode);
|
LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode);
|
||||||
return ResultSuccess;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
Result DoHandshakeImpl() {
|
Result DoHandshakeImpl() {
|
||||||
ASSERT_OR_EXECUTE(!did_handshake_ && socket_, { return ResultNoSocket; });
|
ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; });
|
||||||
ASSERT_OR_EXECUTE_MSG(
|
ASSERT_OR_EXECUTE_MSG(
|
||||||
did_set_host_name_, { return ResultInternalError; },
|
did_set_host_name, { return ResultInternalError; },
|
||||||
"Expected SetHostName before DoHandshake");
|
"Expected SetHostName before DoHandshake");
|
||||||
Result res = backend_->DoHandshake();
|
Result res = backend->DoHandshake();
|
||||||
did_handshake_ = res.IsSuccess();
|
did_handshake = res.IsSuccess();
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -225,7 +225,7 @@ private:
|
||||||
u32 size;
|
u32 size;
|
||||||
u32 offset;
|
u32 offset;
|
||||||
};
|
};
|
||||||
if (!get_server_cert_chain_) {
|
if (!get_server_cert_chain) {
|
||||||
// Just return the first one, unencoded.
|
// Just return the first one, unencoded.
|
||||||
ASSERT_OR_EXECUTE_MSG(
|
ASSERT_OR_EXECUTE_MSG(
|
||||||
!certs.empty(), { return {}; }, "Should be at least one server cert");
|
!certs.empty(), { return {}; }, "Should be at least one server cert");
|
||||||
|
@ -248,9 +248,9 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<std::vector<u8>> ReadImpl(size_t size) {
|
ResultVal<std::vector<u8>> ReadImpl(size_t size) {
|
||||||
ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; });
|
ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
|
||||||
std::vector<u8> res(size);
|
std::vector<u8> res(size);
|
||||||
ResultVal<size_t> actual = backend_->Read(res);
|
ResultVal<size_t> actual = backend->Read(res);
|
||||||
if (actual.Failed()) {
|
if (actual.Failed()) {
|
||||||
return actual.Code();
|
return actual.Code();
|
||||||
}
|
}
|
||||||
|
@ -259,8 +259,8 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> WriteImpl(std::span<const u8> data) {
|
ResultVal<size_t> WriteImpl(std::span<const u8> data) {
|
||||||
ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; });
|
ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
|
||||||
return backend_->Write(data);
|
return backend->Write(data);
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<s32> PendingImpl() {
|
ResultVal<s32> PendingImpl() {
|
||||||
|
@ -295,7 +295,7 @@ private:
|
||||||
void SetIoMode(HLERequestContext& ctx) {
|
void SetIoMode(HLERequestContext& ctx) {
|
||||||
IPC::RequestParser rp{ctx};
|
IPC::RequestParser rp{ctx};
|
||||||
const u32 mode = rp.Pop<u32>();
|
const u32 mode = rp.Pop<u32>();
|
||||||
const Result res = SetIOModeImpl(mode);
|
const Result res = SetIoModeImpl(mode);
|
||||||
IPC::ResponseBuilder rb{ctx, 2};
|
IPC::ResponseBuilder rb{ctx, 2};
|
||||||
rb.Push(res);
|
rb.Push(res);
|
||||||
}
|
}
|
||||||
|
@ -307,22 +307,26 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
void DoHandshakeGetServerCert(HLERequestContext& ctx) {
|
void DoHandshakeGetServerCert(HLERequestContext& ctx) {
|
||||||
|
struct OutputParameters {
|
||||||
|
u32 certs_size;
|
||||||
|
u32 certs_count;
|
||||||
|
};
|
||||||
|
static_assert(sizeof(OutputParameters) == 0x8);
|
||||||
|
|
||||||
const Result res = DoHandshakeImpl();
|
const Result res = DoHandshakeImpl();
|
||||||
u32 certs_count = 0;
|
OutputParameters out{};
|
||||||
u32 certs_size = 0;
|
|
||||||
if (res == ResultSuccess) {
|
if (res == ResultSuccess) {
|
||||||
auto certs = backend_->GetServerCerts();
|
auto certs = backend->GetServerCerts();
|
||||||
if (certs.Succeeded()) {
|
if (certs.Succeeded()) {
|
||||||
const std::vector<u8> certs_buf = SerializeServerCerts(*certs);
|
const std::vector<u8> certs_buf = SerializeServerCerts(*certs);
|
||||||
ctx.WriteBuffer(certs_buf);
|
ctx.WriteBuffer(certs_buf);
|
||||||
certs_count = static_cast<u32>(certs->size());
|
out.certs_count = static_cast<u32>(certs->size());
|
||||||
certs_size = static_cast<u32>(certs_buf.size());
|
out.certs_size = static_cast<u32>(certs_buf.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
IPC::ResponseBuilder rb{ctx, 4};
|
IPC::ResponseBuilder rb{ctx, 4};
|
||||||
rb.Push(res);
|
rb.Push(res);
|
||||||
rb.Push(certs_size);
|
rb.PushRaw(out);
|
||||||
rb.Push(certs_count);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Read(HLERequestContext& ctx) {
|
void Read(HLERequestContext& ctx) {
|
||||||
|
@ -371,10 +375,10 @@ private:
|
||||||
|
|
||||||
switch (parameters.option) {
|
switch (parameters.option) {
|
||||||
case OptionType::DoNotCloseSocket:
|
case OptionType::DoNotCloseSocket:
|
||||||
do_not_close_socket_ = static_cast<bool>(parameters.value);
|
do_not_close_socket = static_cast<bool>(parameters.value);
|
||||||
break;
|
break;
|
||||||
case OptionType::GetServerCertChain:
|
case OptionType::GetServerCertChain:
|
||||||
get_server_cert_chain_ = static_cast<bool>(parameters.value);
|
get_server_cert_chain = static_cast<bool>(parameters.value);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option,
|
LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option,
|
||||||
|
@ -390,7 +394,7 @@ class ISslContext final : public ServiceFramework<ISslContext> {
|
||||||
public:
|
public:
|
||||||
explicit ISslContext(Core::System& system_, SslVersion version)
|
explicit ISslContext(Core::System& system_, SslVersion version)
|
||||||
: ServiceFramework{system_, "ISslContext"}, ssl_version{version},
|
: ServiceFramework{system_, "ISslContext"}, ssl_version{version},
|
||||||
shared_data_{std::make_shared<SslContextSharedData>()} {
|
shared_data{std::make_shared<SslContextSharedData>()} {
|
||||||
static const FunctionInfo functions[] = {
|
static const FunctionInfo functions[] = {
|
||||||
{0, &ISslContext::SetOption, "SetOption"},
|
{0, &ISslContext::SetOption, "SetOption"},
|
||||||
{1, nullptr, "GetOption"},
|
{1, nullptr, "GetOption"},
|
||||||
|
@ -412,7 +416,7 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SslVersion ssl_version;
|
SslVersion ssl_version;
|
||||||
std::shared_ptr<SslContextSharedData> shared_data_;
|
std::shared_ptr<SslContextSharedData> shared_data;
|
||||||
|
|
||||||
void SetOption(HLERequestContext& ctx) {
|
void SetOption(HLERequestContext& ctx) {
|
||||||
struct Parameters {
|
struct Parameters {
|
||||||
|
@ -439,17 +443,17 @@ private:
|
||||||
IPC::ResponseBuilder rb{ctx, 2, 0, 1};
|
IPC::ResponseBuilder rb{ctx, 2, 0, 1};
|
||||||
rb.Push(backend_res.Code());
|
rb.Push(backend_res.Code());
|
||||||
if (backend_res.Succeeded()) {
|
if (backend_res.Succeeded()) {
|
||||||
rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data_,
|
rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data,
|
||||||
std::move(*backend_res));
|
std::move(*backend_res));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetConnectionCount(HLERequestContext& ctx) {
|
void GetConnectionCount(HLERequestContext& ctx) {
|
||||||
LOG_WARNING(Service_SSL, "connection_count={}", shared_data_->connection_count);
|
LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count);
|
||||||
|
|
||||||
IPC::ResponseBuilder rb{ctx, 3};
|
IPC::ResponseBuilder rb{ctx, 3};
|
||||||
rb.Push(ResultSuccess);
|
rb.Push(ResultSuccess);
|
||||||
rb.Push(shared_data_->connection_count);
|
rb.Push(shared_data->connection_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ImportServerPki(HLERequestContext& ctx) {
|
void ImportServerPki(HLERequestContext& ctx) {
|
||||||
|
|
|
@ -51,37 +51,37 @@ public:
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
|
|
||||||
ssl_ = SSL_new(ssl_ctx);
|
ssl = SSL_new(ssl_ctx);
|
||||||
if (!ssl_) {
|
if (!ssl) {
|
||||||
LOG_ERROR(Service_SSL, "SSL_new failed");
|
LOG_ERROR(Service_SSL, "SSL_new failed");
|
||||||
return CheckOpenSSLErrors();
|
return CheckOpenSSLErrors();
|
||||||
}
|
}
|
||||||
|
|
||||||
SSL_set_connect_state(ssl_);
|
SSL_set_connect_state(ssl);
|
||||||
|
|
||||||
bio_ = BIO_new(bio_meth);
|
bio = BIO_new(bio_meth);
|
||||||
if (!bio_) {
|
if (!bio) {
|
||||||
LOG_ERROR(Service_SSL, "BIO_new failed");
|
LOG_ERROR(Service_SSL, "BIO_new failed");
|
||||||
return CheckOpenSSLErrors();
|
return CheckOpenSSLErrors();
|
||||||
}
|
}
|
||||||
|
|
||||||
BIO_set_data(bio_, this);
|
BIO_set_data(bio, this);
|
||||||
BIO_set_init(bio_, 1);
|
BIO_set_init(bio, 1);
|
||||||
SSL_set_bio(ssl_, bio_, bio_);
|
SSL_set_bio(ssl, bio, bio);
|
||||||
|
|
||||||
return ResultSuccess;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetSocket(std::shared_ptr<Network::SocketBase> socket) override {
|
void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
|
||||||
socket_ = socket;
|
socket = std::move(socket_in);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result SetHostName(const std::string& hostname) override {
|
Result SetHostName(const std::string& hostname) override {
|
||||||
if (!SSL_set1_host(ssl_, hostname.c_str())) { // hostname for verification
|
if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification
|
||||||
LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname);
|
LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname);
|
||||||
return CheckOpenSSLErrors();
|
return CheckOpenSSLErrors();
|
||||||
}
|
}
|
||||||
if (!SSL_set_tlsext_host_name(ssl_, hostname.c_str())) { // hostname for SNI
|
if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI
|
||||||
LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname);
|
LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname);
|
||||||
return CheckOpenSSLErrors();
|
return CheckOpenSSLErrors();
|
||||||
}
|
}
|
||||||
|
@ -89,18 +89,18 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
Result DoHandshake() override {
|
Result DoHandshake() override {
|
||||||
SSL_set_verify_result(ssl_, X509_V_OK);
|
SSL_set_verify_result(ssl, X509_V_OK);
|
||||||
const int ret = SSL_do_handshake(ssl_);
|
const int ret = SSL_do_handshake(ssl);
|
||||||
const long verify_result = SSL_get_verify_result(ssl_);
|
const long verify_result = SSL_get_verify_result(ssl);
|
||||||
if (verify_result != X509_V_OK) {
|
if (verify_result != X509_V_OK) {
|
||||||
LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}",
|
LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}",
|
||||||
X509_verify_cert_error_string(verify_result));
|
X509_verify_cert_error_string(verify_result));
|
||||||
return CheckOpenSSLErrors();
|
return CheckOpenSSLErrors();
|
||||||
}
|
}
|
||||||
if (ret <= 0) {
|
if (ret <= 0) {
|
||||||
const int ssl_err = SSL_get_error(ssl_, ret);
|
const int ssl_err = SSL_get_error(ssl, ret);
|
||||||
if (ssl_err == SSL_ERROR_ZERO_RETURN ||
|
if (ssl_err == SSL_ERROR_ZERO_RETURN ||
|
||||||
(ssl_err == SSL_ERROR_SYSCALL && got_read_eof_)) {
|
(ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) {
|
||||||
LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
|
LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
|
@ -110,18 +110,18 @@ public:
|
||||||
|
|
||||||
ResultVal<size_t> Read(std::span<u8> data) override {
|
ResultVal<size_t> Read(std::span<u8> data) override {
|
||||||
size_t actual;
|
size_t actual;
|
||||||
const int ret = SSL_read_ex(ssl_, data.data(), data.size(), &actual);
|
const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual);
|
||||||
return HandleReturn("SSL_read_ex", actual, ret);
|
return HandleReturn("SSL_read_ex", actual, ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> Write(std::span<const u8> data) override {
|
ResultVal<size_t> Write(std::span<const u8> data) override {
|
||||||
size_t actual;
|
size_t actual;
|
||||||
const int ret = SSL_write_ex(ssl_, data.data(), data.size(), &actual);
|
const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual);
|
||||||
return HandleReturn("SSL_write_ex", actual, ret);
|
return HandleReturn("SSL_write_ex", actual, ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) {
|
ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) {
|
||||||
const int ssl_err = SSL_get_error(ssl_, ret);
|
const int ssl_err = SSL_get_error(ssl, ret);
|
||||||
CheckOpenSSLErrors();
|
CheckOpenSSLErrors();
|
||||||
switch (ssl_err) {
|
switch (ssl_err) {
|
||||||
case SSL_ERROR_NONE:
|
case SSL_ERROR_NONE:
|
||||||
|
@ -137,7 +137,7 @@ public:
|
||||||
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what);
|
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what);
|
||||||
return ResultWouldBlock;
|
return ResultWouldBlock;
|
||||||
default:
|
default:
|
||||||
if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_) {
|
if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) {
|
||||||
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
|
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
|
||||||
return size_t(0);
|
return size_t(0);
|
||||||
}
|
}
|
||||||
|
@ -147,7 +147,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
|
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
|
||||||
STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_);
|
STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl);
|
||||||
if (!chain) {
|
if (!chain) {
|
||||||
LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
|
LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
|
@ -169,8 +169,8 @@ public:
|
||||||
|
|
||||||
~SSLConnectionBackendOpenSSL() {
|
~SSLConnectionBackendOpenSSL() {
|
||||||
// these are null-tolerant:
|
// these are null-tolerant:
|
||||||
SSL_free(ssl_);
|
SSL_free(ssl);
|
||||||
BIO_free(bio_);
|
BIO_free(bio);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void KeyLogCallback(const SSL* ssl, const char* line) {
|
static void KeyLogCallback(const SSL* ssl, const char* line) {
|
||||||
|
@ -188,9 +188,9 @@ public:
|
||||||
static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) {
|
static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) {
|
||||||
auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
|
auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
|
||||||
ASSERT_OR_EXECUTE_MSG(
|
ASSERT_OR_EXECUTE_MSG(
|
||||||
self->socket_, { return 0; }, "OpenSSL asked to send but we have no socket");
|
self->socket, { return 0; }, "OpenSSL asked to send but we have no socket");
|
||||||
BIO_clear_retry_flags(bio);
|
BIO_clear_retry_flags(bio);
|
||||||
auto [actual, err] = self->socket_->Send({reinterpret_cast<const u8*>(buf), len}, 0);
|
auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0);
|
||||||
switch (err) {
|
switch (err) {
|
||||||
case Network::Errno::SUCCESS:
|
case Network::Errno::SUCCESS:
|
||||||
*actual_p = actual;
|
*actual_p = actual;
|
||||||
|
@ -207,14 +207,14 @@ public:
|
||||||
static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) {
|
static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) {
|
||||||
auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
|
auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
|
||||||
ASSERT_OR_EXECUTE_MSG(
|
ASSERT_OR_EXECUTE_MSG(
|
||||||
self->socket_, { return 0; }, "OpenSSL asked to recv but we have no socket");
|
self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket");
|
||||||
BIO_clear_retry_flags(bio);
|
BIO_clear_retry_flags(bio);
|
||||||
auto [actual, err] = self->socket_->Recv(0, {reinterpret_cast<u8*>(buf), len});
|
auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len});
|
||||||
switch (err) {
|
switch (err) {
|
||||||
case Network::Errno::SUCCESS:
|
case Network::Errno::SUCCESS:
|
||||||
*actual_p = actual;
|
*actual_p = actual;
|
||||||
if (actual == 0) {
|
if (actual == 0) {
|
||||||
self->got_read_eof_ = true;
|
self->got_read_eof = true;
|
||||||
}
|
}
|
||||||
return actual ? 1 : 0;
|
return actual ? 1 : 0;
|
||||||
case Network::Errno::AGAIN:
|
case Network::Errno::AGAIN:
|
||||||
|
@ -246,11 +246,11 @@ public:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SSL* ssl_ = nullptr;
|
SSL* ssl = nullptr;
|
||||||
BIO* bio_ = nullptr;
|
BIO* bio = nullptr;
|
||||||
bool got_read_eof_ = false;
|
bool got_read_eof = false;
|
||||||
|
|
||||||
std::shared_ptr<Network::SocketBase> socket_;
|
std::shared_ptr<Network::SocketBase> socket;
|
||||||
};
|
};
|
||||||
|
|
||||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
|
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
|
||||||
|
|
|
@ -48,6 +48,12 @@ static void OneTimeInit() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (getenv("SSLKEYLOGFILE")) {
|
||||||
|
LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting "
|
||||||
|
"keys; not logging keys!");
|
||||||
|
// Not fatal.
|
||||||
|
}
|
||||||
|
|
||||||
one_time_init_success = true;
|
one_time_init_success = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,25 +76,25 @@ public:
|
||||||
return ResultSuccess;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetSocket(std::shared_ptr<Network::SocketBase> socket) override {
|
void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
|
||||||
socket_ = socket;
|
socket = std::move(socket_in);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result SetHostName(const std::string& hostname) override {
|
Result SetHostName(const std::string& hostname_in) override {
|
||||||
hostname_ = hostname;
|
hostname = hostname_in;
|
||||||
return ResultSuccess;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
Result DoHandshake() override {
|
Result DoHandshake() override {
|
||||||
while (1) {
|
while (1) {
|
||||||
Result r;
|
Result r;
|
||||||
switch (handshake_state_) {
|
switch (handshake_state) {
|
||||||
case HandshakeState::Initial:
|
case HandshakeState::Initial:
|
||||||
if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
|
if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
|
||||||
(r = CallInitializeSecurityContext()) != ResultSuccess) {
|
(r = CallInitializeSecurityContext()) != ResultSuccess) {
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
// CallInitializeSecurityContext updated `handshake_state_`.
|
// CallInitializeSecurityContext updated `handshake_state`.
|
||||||
continue;
|
continue;
|
||||||
case HandshakeState::ContinueNeeded:
|
case HandshakeState::ContinueNeeded:
|
||||||
case HandshakeState::IncompleteMessage:
|
case HandshakeState::IncompleteMessage:
|
||||||
|
@ -96,20 +102,20 @@ public:
|
||||||
(r = FillCiphertextReadBuf()) != ResultSuccess) {
|
(r = FillCiphertextReadBuf()) != ResultSuccess) {
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
if (ciphertext_read_buf_.empty()) {
|
if (ciphertext_read_buf.empty()) {
|
||||||
LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
|
LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
if ((r = CallInitializeSecurityContext()) != ResultSuccess) {
|
if ((r = CallInitializeSecurityContext()) != ResultSuccess) {
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
// CallInitializeSecurityContext updated `handshake_state_`.
|
// CallInitializeSecurityContext updated `handshake_state`.
|
||||||
continue;
|
continue;
|
||||||
case HandshakeState::DoneAfterFlush:
|
case HandshakeState::DoneAfterFlush:
|
||||||
if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) {
|
if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) {
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
handshake_state_ = HandshakeState::Connected;
|
handshake_state = HandshakeState::Connected;
|
||||||
return ResultSuccess;
|
return ResultSuccess;
|
||||||
case HandshakeState::Connected:
|
case HandshakeState::Connected:
|
||||||
LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook");
|
LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook");
|
||||||
|
@ -121,24 +127,24 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
Result FillCiphertextReadBuf() {
|
Result FillCiphertextReadBuf() {
|
||||||
const size_t fill_size = read_buf_fill_size_ ? read_buf_fill_size_ : 4096;
|
const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096;
|
||||||
read_buf_fill_size_ = 0;
|
read_buf_fill_size = 0;
|
||||||
// This unnecessarily zeroes the buffer; oh well.
|
// This unnecessarily zeroes the buffer; oh well.
|
||||||
const size_t offset = ciphertext_read_buf_.size();
|
const size_t offset = ciphertext_read_buf.size();
|
||||||
ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
|
ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
|
||||||
ciphertext_read_buf_.resize(offset + fill_size, 0);
|
ciphertext_read_buf.resize(offset + fill_size, 0);
|
||||||
const auto read_span = std::span(ciphertext_read_buf_).subspan(offset, fill_size);
|
const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size);
|
||||||
const auto [actual, err] = socket_->Recv(0, read_span);
|
const auto [actual, err] = socket->Recv(0, read_span);
|
||||||
switch (err) {
|
switch (err) {
|
||||||
case Network::Errno::SUCCESS:
|
case Network::Errno::SUCCESS:
|
||||||
ASSERT(static_cast<size_t>(actual) <= fill_size);
|
ASSERT(static_cast<size_t>(actual) <= fill_size);
|
||||||
ciphertext_read_buf_.resize(offset + actual);
|
ciphertext_read_buf.resize(offset + actual);
|
||||||
return ResultSuccess;
|
return ResultSuccess;
|
||||||
case Network::Errno::AGAIN:
|
case Network::Errno::AGAIN:
|
||||||
ciphertext_read_buf_.resize(offset);
|
ciphertext_read_buf.resize(offset);
|
||||||
return ResultWouldBlock;
|
return ResultWouldBlock;
|
||||||
default:
|
default:
|
||||||
ciphertext_read_buf_.resize(offset);
|
ciphertext_read_buf.resize(offset);
|
||||||
LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
|
LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
|
@ -146,13 +152,13 @@ public:
|
||||||
|
|
||||||
// Returns success if the write buffer has been completely emptied.
|
// Returns success if the write buffer has been completely emptied.
|
||||||
Result FlushCiphertextWriteBuf() {
|
Result FlushCiphertextWriteBuf() {
|
||||||
while (!ciphertext_write_buf_.empty()) {
|
while (!ciphertext_write_buf.empty()) {
|
||||||
const auto [actual, err] = socket_->Send(ciphertext_write_buf_, 0);
|
const auto [actual, err] = socket->Send(ciphertext_write_buf, 0);
|
||||||
switch (err) {
|
switch (err) {
|
||||||
case Network::Errno::SUCCESS:
|
case Network::Errno::SUCCESS:
|
||||||
ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf_.size());
|
ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size());
|
||||||
ciphertext_write_buf_.erase(ciphertext_write_buf_.begin(),
|
ciphertext_write_buf.erase(ciphertext_write_buf.begin(),
|
||||||
ciphertext_write_buf_.begin() + actual);
|
ciphertext_write_buf.begin() + actual);
|
||||||
break;
|
break;
|
||||||
case Network::Errno::AGAIN:
|
case Network::Errno::AGAIN:
|
||||||
return ResultWouldBlock;
|
return ResultWouldBlock;
|
||||||
|
@ -175,9 +181,9 @@ public:
|
||||||
// only used if `initial_call_done`
|
// only used if `initial_call_done`
|
||||||
{
|
{
|
||||||
// [0]
|
// [0]
|
||||||
.cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()),
|
.cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
|
||||||
.BufferType = SECBUFFER_TOKEN,
|
.BufferType = SECBUFFER_TOKEN,
|
||||||
.pvBuffer = ciphertext_read_buf_.data(),
|
.pvBuffer = ciphertext_read_buf.data(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
|
// [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
|
||||||
|
@ -211,30 +217,30 @@ public:
|
||||||
.pBuffers = output_buffers.data(),
|
.pBuffers = output_buffers.data(),
|
||||||
};
|
};
|
||||||
ASSERT_OR_EXECUTE_MSG(
|
ASSERT_OR_EXECUTE_MSG(
|
||||||
input_buffers[0].cbBuffer == ciphertext_read_buf_.size(),
|
input_buffers[0].cbBuffer == ciphertext_read_buf.size(),
|
||||||
{ return ResultInternalError; }, "read buffer too large");
|
{ return ResultInternalError; }, "read buffer too large");
|
||||||
|
|
||||||
bool initial_call_done = handshake_state_ != HandshakeState::Initial;
|
bool initial_call_done = handshake_state != HandshakeState::Initial;
|
||||||
if (initial_call_done) {
|
if (initial_call_done) {
|
||||||
LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext",
|
LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext",
|
||||||
ciphertext_read_buf_.size());
|
ciphertext_read_buf.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
const SECURITY_STATUS ret =
|
const SECURITY_STATUS ret =
|
||||||
InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt_ : nullptr,
|
InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr,
|
||||||
// Caller ensured we have set a hostname:
|
// Caller ensured we have set a hostname:
|
||||||
const_cast<char*>(hostname_.value().c_str()), req,
|
const_cast<char*>(hostname.value().c_str()), req,
|
||||||
0, // Reserved1
|
0, // Reserved1
|
||||||
0, // TargetDataRep not used with Schannel
|
0, // TargetDataRep not used with Schannel
|
||||||
initial_call_done ? &input_desc : nullptr,
|
initial_call_done ? &input_desc : nullptr,
|
||||||
0, // Reserved2
|
0, // Reserved2
|
||||||
initial_call_done ? nullptr : &ctxt_, &output_desc, &attr,
|
initial_call_done ? nullptr : &ctxt, &output_desc, &attr,
|
||||||
nullptr); // ptsExpiry
|
nullptr); // ptsExpiry
|
||||||
|
|
||||||
if (output_buffers[0].pvBuffer) {
|
if (output_buffers[0].pvBuffer) {
|
||||||
const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
|
const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
|
||||||
output_buffers[0].cbBuffer);
|
output_buffers[0].cbBuffer);
|
||||||
ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), span.begin(), span.end());
|
ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end());
|
||||||
FreeContextBuffer(output_buffers[0].pvBuffer);
|
FreeContextBuffer(output_buffers[0].pvBuffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -251,64 +257,64 @@ public:
|
||||||
LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED");
|
LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED");
|
||||||
if (input_buffers[1].BufferType == SECBUFFER_EXTRA) {
|
if (input_buffers[1].BufferType == SECBUFFER_EXTRA) {
|
||||||
LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer);
|
LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer);
|
||||||
ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf_.size());
|
ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size());
|
||||||
ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(),
|
ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
|
||||||
ciphertext_read_buf_.end() - input_buffers[1].cbBuffer);
|
ciphertext_read_buf.end() - input_buffers[1].cbBuffer);
|
||||||
} else {
|
} else {
|
||||||
ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY);
|
ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY);
|
||||||
ciphertext_read_buf_.clear();
|
ciphertext_read_buf.clear();
|
||||||
}
|
}
|
||||||
handshake_state_ = HandshakeState::ContinueNeeded;
|
handshake_state = HandshakeState::ContinueNeeded;
|
||||||
return ResultSuccess;
|
return ResultSuccess;
|
||||||
case SEC_E_INCOMPLETE_MESSAGE:
|
case SEC_E_INCOMPLETE_MESSAGE:
|
||||||
LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE");
|
LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE");
|
||||||
ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING);
|
ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING);
|
||||||
read_buf_fill_size_ = input_buffers[1].cbBuffer;
|
read_buf_fill_size = input_buffers[1].cbBuffer;
|
||||||
handshake_state_ = HandshakeState::IncompleteMessage;
|
handshake_state = HandshakeState::IncompleteMessage;
|
||||||
return ResultSuccess;
|
return ResultSuccess;
|
||||||
case SEC_E_OK:
|
case SEC_E_OK:
|
||||||
LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK");
|
LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK");
|
||||||
ciphertext_read_buf_.clear();
|
ciphertext_read_buf.clear();
|
||||||
handshake_state_ = HandshakeState::DoneAfterFlush;
|
handshake_state = HandshakeState::DoneAfterFlush;
|
||||||
return GrabStreamSizes();
|
return GrabStreamSizes();
|
||||||
default:
|
default:
|
||||||
LOG_ERROR(Service_SSL,
|
LOG_ERROR(Service_SSL,
|
||||||
"InitializeSecurityContext failed (probably certificate/protocol issue): {}",
|
"InitializeSecurityContext failed (probably certificate/protocol issue): {}",
|
||||||
Common::NativeErrorToString(ret));
|
Common::NativeErrorToString(ret));
|
||||||
handshake_state_ = HandshakeState::Error;
|
handshake_state = HandshakeState::Error;
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Result GrabStreamSizes() {
|
Result GrabStreamSizes() {
|
||||||
const SECURITY_STATUS ret =
|
const SECURITY_STATUS ret =
|
||||||
QueryContextAttributes(&ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_);
|
QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes);
|
||||||
if (ret != SEC_E_OK) {
|
if (ret != SEC_E_OK) {
|
||||||
LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
|
LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
|
||||||
Common::NativeErrorToString(ret));
|
Common::NativeErrorToString(ret));
|
||||||
handshake_state_ = HandshakeState::Error;
|
handshake_state = HandshakeState::Error;
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
return ResultSuccess;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> Read(std::span<u8> data) override {
|
ResultVal<size_t> Read(std::span<u8> data) override {
|
||||||
if (handshake_state_ != HandshakeState::Connected) {
|
if (handshake_state != HandshakeState::Connected) {
|
||||||
LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
|
LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
if (data.size() == 0 || got_read_eof_) {
|
if (data.size() == 0 || got_read_eof) {
|
||||||
return size_t(0);
|
return size_t(0);
|
||||||
}
|
}
|
||||||
while (1) {
|
while (1) {
|
||||||
if (!cleartext_read_buf_.empty()) {
|
if (!cleartext_read_buf.empty()) {
|
||||||
const size_t read_size = std::min(cleartext_read_buf_.size(), data.size());
|
const size_t read_size = std::min(cleartext_read_buf.size(), data.size());
|
||||||
std::memcpy(data.data(), cleartext_read_buf_.data(), read_size);
|
std::memcpy(data.data(), cleartext_read_buf.data(), read_size);
|
||||||
cleartext_read_buf_.erase(cleartext_read_buf_.begin(),
|
cleartext_read_buf.erase(cleartext_read_buf.begin(),
|
||||||
cleartext_read_buf_.begin() + read_size);
|
cleartext_read_buf.begin() + read_size);
|
||||||
return read_size;
|
return read_size;
|
||||||
}
|
}
|
||||||
if (!ciphertext_read_buf_.empty()) {
|
if (!ciphertext_read_buf.empty()) {
|
||||||
SecBuffer empty{
|
SecBuffer empty{
|
||||||
.cbBuffer = 0,
|
.cbBuffer = 0,
|
||||||
.BufferType = SECBUFFER_EMPTY,
|
.BufferType = SECBUFFER_EMPTY,
|
||||||
|
@ -316,16 +322,16 @@ public:
|
||||||
};
|
};
|
||||||
std::array<SecBuffer, 5> buffers{{
|
std::array<SecBuffer, 5> buffers{{
|
||||||
{
|
{
|
||||||
.cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()),
|
.cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
|
||||||
.BufferType = SECBUFFER_DATA,
|
.BufferType = SECBUFFER_DATA,
|
||||||
.pvBuffer = ciphertext_read_buf_.data(),
|
.pvBuffer = ciphertext_read_buf.data(),
|
||||||
},
|
},
|
||||||
empty,
|
empty,
|
||||||
empty,
|
empty,
|
||||||
empty,
|
empty,
|
||||||
}};
|
}};
|
||||||
ASSERT_OR_EXECUTE_MSG(
|
ASSERT_OR_EXECUTE_MSG(
|
||||||
buffers[0].cbBuffer == ciphertext_read_buf_.size(),
|
buffers[0].cbBuffer == ciphertext_read_buf.size(),
|
||||||
{ return ResultInternalError; }, "read buffer too large");
|
{ return ResultInternalError; }, "read buffer too large");
|
||||||
SecBufferDesc desc{
|
SecBufferDesc desc{
|
||||||
.ulVersion = SECBUFFER_VERSION,
|
.ulVersion = SECBUFFER_VERSION,
|
||||||
|
@ -333,7 +339,7 @@ public:
|
||||||
.pBuffers = buffers.data(),
|
.pBuffers = buffers.data(),
|
||||||
};
|
};
|
||||||
SECURITY_STATUS ret =
|
SECURITY_STATUS ret =
|
||||||
DecryptMessage(&ctxt_, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr);
|
DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr);
|
||||||
switch (ret) {
|
switch (ret) {
|
||||||
case SEC_E_OK:
|
case SEC_E_OK:
|
||||||
ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER,
|
ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER,
|
||||||
|
@ -342,24 +348,23 @@ public:
|
||||||
{ return ResultInternalError; });
|
{ return ResultInternalError; });
|
||||||
ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER,
|
ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER,
|
||||||
{ return ResultInternalError; });
|
{ return ResultInternalError; });
|
||||||
cleartext_read_buf_.assign(static_cast<u8*>(buffers[1].pvBuffer),
|
cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer),
|
||||||
static_cast<u8*>(buffers[1].pvBuffer) +
|
static_cast<u8*>(buffers[1].pvBuffer) +
|
||||||
buffers[1].cbBuffer);
|
buffers[1].cbBuffer);
|
||||||
if (buffers[3].BufferType == SECBUFFER_EXTRA) {
|
if (buffers[3].BufferType == SECBUFFER_EXTRA) {
|
||||||
ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf_.size());
|
ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size());
|
||||||
ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(),
|
ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
|
||||||
ciphertext_read_buf_.end() -
|
ciphertext_read_buf.end() - buffers[3].cbBuffer);
|
||||||
buffers[3].cbBuffer);
|
|
||||||
} else {
|
} else {
|
||||||
ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY);
|
ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY);
|
||||||
ciphertext_read_buf_.clear();
|
ciphertext_read_buf.clear();
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
case SEC_E_INCOMPLETE_MESSAGE:
|
case SEC_E_INCOMPLETE_MESSAGE:
|
||||||
break;
|
break;
|
||||||
case SEC_I_CONTEXT_EXPIRED:
|
case SEC_I_CONTEXT_EXPIRED:
|
||||||
// Server hung up by sending close_notify.
|
// Server hung up by sending close_notify.
|
||||||
got_read_eof_ = true;
|
got_read_eof = true;
|
||||||
return size_t(0);
|
return size_t(0);
|
||||||
default:
|
default:
|
||||||
LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
|
LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
|
||||||
|
@ -371,43 +376,43 @@ public:
|
||||||
if (r != ResultSuccess) {
|
if (r != ResultSuccess) {
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
if (ciphertext_read_buf_.empty()) {
|
if (ciphertext_read_buf.empty()) {
|
||||||
got_read_eof_ = true;
|
got_read_eof = true;
|
||||||
return size_t(0);
|
return size_t(0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> Write(std::span<const u8> data) override {
|
ResultVal<size_t> Write(std::span<const u8> data) override {
|
||||||
if (handshake_state_ != HandshakeState::Connected) {
|
if (handshake_state != HandshakeState::Connected) {
|
||||||
LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
|
LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
if (data.size() == 0) {
|
if (data.size() == 0) {
|
||||||
return size_t(0);
|
return size_t(0);
|
||||||
}
|
}
|
||||||
data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes_.cbMaximumMessage));
|
data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage));
|
||||||
if (!cleartext_write_buf_.empty()) {
|
if (!cleartext_write_buf.empty()) {
|
||||||
// Already in the middle of a write. It wouldn't make sense to not
|
// Already in the middle of a write. It wouldn't make sense to not
|
||||||
// finish sending the entire buffer since TLS has
|
// finish sending the entire buffer since TLS has
|
||||||
// header/MAC/padding/etc.
|
// header/MAC/padding/etc.
|
||||||
if (data.size() != cleartext_write_buf_.size() ||
|
if (data.size() != cleartext_write_buf.size() ||
|
||||||
std::memcmp(data.data(), cleartext_write_buf_.data(), data.size())) {
|
std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) {
|
||||||
LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
|
LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
return WriteAlreadyEncryptedData();
|
return WriteAlreadyEncryptedData();
|
||||||
} else {
|
} else {
|
||||||
cleartext_write_buf_.assign(data.begin(), data.end());
|
cleartext_write_buf.assign(data.begin(), data.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<u8> header_buf(stream_sizes_.cbHeader, 0);
|
std::vector<u8> header_buf(stream_sizes.cbHeader, 0);
|
||||||
std::vector<u8> tmp_data_buf = cleartext_write_buf_;
|
std::vector<u8> tmp_data_buf = cleartext_write_buf;
|
||||||
std::vector<u8> trailer_buf(stream_sizes_.cbTrailer, 0);
|
std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0);
|
||||||
|
|
||||||
std::array<SecBuffer, 3> buffers{{
|
std::array<SecBuffer, 3> buffers{{
|
||||||
{
|
{
|
||||||
.cbBuffer = stream_sizes_.cbHeader,
|
.cbBuffer = stream_sizes.cbHeader,
|
||||||
.BufferType = SECBUFFER_STREAM_HEADER,
|
.BufferType = SECBUFFER_STREAM_HEADER,
|
||||||
.pvBuffer = header_buf.data(),
|
.pvBuffer = header_buf.data(),
|
||||||
},
|
},
|
||||||
|
@ -417,7 +422,7 @@ public:
|
||||||
.pvBuffer = tmp_data_buf.data(),
|
.pvBuffer = tmp_data_buf.data(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
.cbBuffer = stream_sizes_.cbTrailer,
|
.cbBuffer = stream_sizes.cbTrailer,
|
||||||
.BufferType = SECBUFFER_STREAM_TRAILER,
|
.BufferType = SECBUFFER_STREAM_TRAILER,
|
||||||
.pvBuffer = trailer_buf.data(),
|
.pvBuffer = trailer_buf.data(),
|
||||||
},
|
},
|
||||||
|
@ -431,16 +436,16 @@ public:
|
||||||
.pBuffers = buffers.data(),
|
.pBuffers = buffers.data(),
|
||||||
};
|
};
|
||||||
|
|
||||||
const SECURITY_STATUS ret = EncryptMessage(&ctxt_, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
|
const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
|
||||||
if (ret != SEC_E_OK) {
|
if (ret != SEC_E_OK) {
|
||||||
LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
|
LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), header_buf.begin(),
|
ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(),
|
||||||
header_buf.end());
|
header_buf.end());
|
||||||
ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), tmp_data_buf.begin(),
|
ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(),
|
||||||
tmp_data_buf.end());
|
tmp_data_buf.end());
|
||||||
ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), trailer_buf.begin(),
|
ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(),
|
||||||
trailer_buf.end());
|
trailer_buf.end());
|
||||||
return WriteAlreadyEncryptedData();
|
return WriteAlreadyEncryptedData();
|
||||||
}
|
}
|
||||||
|
@ -451,15 +456,15 @@ public:
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
// write buf is empty
|
// write buf is empty
|
||||||
const size_t cleartext_bytes_written = cleartext_write_buf_.size();
|
const size_t cleartext_bytes_written = cleartext_write_buf.size();
|
||||||
cleartext_write_buf_.clear();
|
cleartext_write_buf.clear();
|
||||||
return cleartext_bytes_written;
|
return cleartext_bytes_written;
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
|
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
|
||||||
PCCERT_CONTEXT returned_cert = nullptr;
|
PCCERT_CONTEXT returned_cert = nullptr;
|
||||||
const SECURITY_STATUS ret =
|
const SECURITY_STATUS ret =
|
||||||
QueryContextAttributes(&ctxt_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
|
QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
|
||||||
if (ret != SEC_E_OK) {
|
if (ret != SEC_E_OK) {
|
||||||
LOG_ERROR(Service_SSL,
|
LOG_ERROR(Service_SSL,
|
||||||
"QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}",
|
"QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}",
|
||||||
|
@ -480,8 +485,8 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
~SSLConnectionBackendSchannel() {
|
~SSLConnectionBackendSchannel() {
|
||||||
if (handshake_state_ != HandshakeState::Initial) {
|
if (handshake_state != HandshakeState::Initial) {
|
||||||
DeleteSecurityContext(&ctxt_);
|
DeleteSecurityContext(&ctxt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -509,21 +514,21 @@ public:
|
||||||
// Another error was returned and we shouldn't allow initialization
|
// Another error was returned and we shouldn't allow initialization
|
||||||
// to continue.
|
// to continue.
|
||||||
Error,
|
Error,
|
||||||
} handshake_state_ = HandshakeState::Initial;
|
} handshake_state = HandshakeState::Initial;
|
||||||
|
|
||||||
CtxtHandle ctxt_;
|
CtxtHandle ctxt;
|
||||||
SecPkgContext_StreamSizes stream_sizes_;
|
SecPkgContext_StreamSizes stream_sizes;
|
||||||
|
|
||||||
std::shared_ptr<Network::SocketBase> socket_;
|
std::shared_ptr<Network::SocketBase> socket;
|
||||||
std::optional<std::string> hostname_;
|
std::optional<std::string> hostname;
|
||||||
|
|
||||||
std::vector<u8> ciphertext_read_buf_;
|
std::vector<u8> ciphertext_read_buf;
|
||||||
std::vector<u8> ciphertext_write_buf_;
|
std::vector<u8> ciphertext_write_buf;
|
||||||
std::vector<u8> cleartext_read_buf_;
|
std::vector<u8> cleartext_read_buf;
|
||||||
std::vector<u8> cleartext_write_buf_;
|
std::vector<u8> cleartext_write_buf;
|
||||||
|
|
||||||
bool got_read_eof_ = false;
|
bool got_read_eof = false;
|
||||||
size_t read_buf_fill_size_ = 0;
|
size_t read_buf_fill_size = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
|
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
|
||||||
|
|
|
@ -570,10 +570,10 @@ Socket::Socket(Socket&& rhs) noexcept {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::pair<T, Errno> Socket::GetSockOpt(SOCKET fd_, int option) {
|
std::pair<T, Errno> Socket::GetSockOpt(SOCKET fd_so, int option) {
|
||||||
T value{};
|
T value{};
|
||||||
socklen_t len = sizeof(value);
|
socklen_t len = sizeof(value);
|
||||||
const int result = getsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<char*>(&value), &len);
|
const int result = getsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<char*>(&value), &len);
|
||||||
if (result != SOCKET_ERROR) {
|
if (result != SOCKET_ERROR) {
|
||||||
ASSERT(len == sizeof(value));
|
ASSERT(len == sizeof(value));
|
||||||
return {value, Errno::SUCCESS};
|
return {value, Errno::SUCCESS};
|
||||||
|
@ -582,9 +582,9 @@ std::pair<T, Errno> Socket::GetSockOpt(SOCKET fd_, int option) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) {
|
Errno Socket::SetSockOpt(SOCKET fd_so, int option, T value) {
|
||||||
const int result =
|
const int result =
|
||||||
setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value));
|
setsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value));
|
||||||
if (result != SOCKET_ERROR) {
|
if (result != SOCKET_ERROR) {
|
||||||
return Errno::SUCCESS;
|
return Errno::SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue