From 27bad0598a3ddce0417388c3945368200150d413 Mon Sep 17 00:00:00 2001
From: bunnei <bunneidev@gmail.com>
Date: Tue, 23 Jan 2018 18:03:09 -0500
Subject: [PATCH] hle: Integrate Domain handling into ServerSession.

---
 src/core/hle/ipc_helpers.h             |  6 ++--
 src/core/hle/kernel/client_session.h   |  3 +-
 src/core/hle/kernel/hle_ipc.cpp        | 10 ++----
 src/core/hle/kernel/hle_ipc.h          | 16 ++-------
 src/core/hle/kernel/server_session.cpp | 47 +++++++++++++++++++++++---
 src/core/hle/kernel/server_session.h   | 18 +++++++++-
 src/core/hle/service/sm/controller.cpp | 12 +++----
 7 files changed, 74 insertions(+), 38 deletions(-)

diff --git a/src/core/hle/ipc_helpers.h b/src/core/hle/ipc_helpers.h
index e5c26e079..d62731678 100644
--- a/src/core/hle/ipc_helpers.h
+++ b/src/core/hle/ipc_helpers.h
@@ -76,7 +76,7 @@ public:
         // The entire size of the raw data section in u32 units, including the 16 bytes of mandatory
         // padding.
         u32 raw_data_size = sizeof(IPC::DataPayloadHeader) / 4 + 4 + normal_params_size;
-        if (context.IsDomain()) {
+        if (context.Session()->IsDomain()) {
             raw_data_size += sizeof(DomainMessageHeader) / 4 + num_domain_objects;
         } else {
             // If we're not in a domain, turn the domain object parameters into move handles.
@@ -100,7 +100,7 @@ public:
 
         AlignWithPadding();
 
-        if (context.IsDomain()) {
+        if (context.Session()->IsDomain()) {
             IPC::DomainMessageHeader domain_header{};
             domain_header.num_objects = num_domain_objects;
             PushRaw(domain_header);
@@ -114,7 +114,7 @@ public:
     template <class T, class... Args>
     void PushIpcInterface(Args&&... args) {
         auto iface = std::make_shared<T>(std::forward<Args>(args)...);
-        if (context->IsDomain()) {
+        if (context->Session()->IsDomain()) {
             context->AddDomainObject(std::move(iface));
         } else {
             auto sessions = Kernel::ServerSession::CreateSessionPair(iface->GetServiceName());
diff --git a/src/core/hle/kernel/client_session.h b/src/core/hle/kernel/client_session.h
index f2765cc1e..2258f95bc 100644
--- a/src/core/hle/kernel/client_session.h
+++ b/src/core/hle/kernel/client_session.h
@@ -7,6 +7,7 @@
 #include <memory>
 #include <string>
 #include "common/common_types.h"
+#include "core/hle/kernel/kernel.h"
 #include "core/hle/result.h"
 
 namespace Kernel {
@@ -32,7 +33,7 @@ public:
         return HANDLE_TYPE;
     }
 
-    ResultCode SendSyncRequest(SharedPtr<Thread> thread) override;
+    ResultCode SendSyncRequest(SharedPtr<Thread> thread);
 
     std::string name; ///< Name of client port (optional)
 
diff --git a/src/core/hle/kernel/hle_ipc.cpp b/src/core/hle/kernel/hle_ipc.cpp
index 2cd6de12e..db104e8a2 100644
--- a/src/core/hle/kernel/hle_ipc.cpp
+++ b/src/core/hle/kernel/hle_ipc.cpp
@@ -25,10 +25,6 @@ void SessionRequestHandler::ClientDisconnected(SharedPtr<ServerSession> server_s
     boost::range::remove_erase(connected_sessions, server_session);
 }
 
-HLERequestContext::HLERequestContext(SharedPtr<Kernel::Domain> domain) : domain(std::move(domain)) {
-    cmd_buf[0] = 0;
-}
-
 HLERequestContext::HLERequestContext(SharedPtr<Kernel::ServerSession> server_session)
     : server_session(std::move(server_session)) {
     cmd_buf[0] = 0;
@@ -86,7 +82,7 @@ void HLERequestContext::ParseCommandBuffer(u32_le* src_cmdbuf, bool incoming) {
     // Padding to align to 16 bytes
     rp.AlignWithPadding();
 
-    if (IsDomain() && (command_header->type == IPC::CommandType::Request || !incoming)) {
+    if (Session()->IsDomain() && (command_header->type == IPC::CommandType::Request || !incoming)) {
         // If this is an incoming message, only CommandType "Request" has a domain header
         // All outgoing domain messages have the domain header
         domain_message_header =
@@ -199,12 +195,12 @@ ResultCode HLERequestContext::WriteToOutgoingCommandBuffer(u32_le* dst_cmdbuf, P
 
     // TODO(Subv): Translate the X/A/B/W buffers.
 
-    if (IsDomain()) {
+    if (Session()->IsDomain()) {
         ASSERT(domain_message_header->num_objects == domain_objects.size());
         // Write the domain objects to the command buffer, these go after the raw untranslated data.
         // TODO(Subv): This completely ignores C buffers.
         size_t domain_offset = size - domain_message_header->num_objects;
-        auto& request_handlers = domain->request_handlers;
+        auto& request_handlers = server_session->domain_request_handlers;
 
         for (auto& object : domain_objects) {
             request_handlers.emplace_back(object);
diff --git a/src/core/hle/kernel/hle_ipc.h b/src/core/hle/kernel/hle_ipc.h
index 80fa48d7f..71e5609b8 100644
--- a/src/core/hle/kernel/hle_ipc.h
+++ b/src/core/hle/kernel/hle_ipc.h
@@ -86,7 +86,6 @@ protected:
  */
 class HLERequestContext {
 public:
-    HLERequestContext(SharedPtr<Kernel::Domain> domain);
     HLERequestContext(SharedPtr<Kernel::ServerSession> session);
     ~HLERequestContext();
 
@@ -95,18 +94,11 @@ public:
         return cmd_buf.data();
     }
 
-    /**
-     * Returns the domain through which this request was made.
-     */
-    const SharedPtr<Kernel::Domain>& Domain() const {
-        return domain;
-    }
-
     /**
      * Returns the session through which this request was made. This can be used as a map key to
      * access per-client data on services.
      */
-    const SharedPtr<Kernel::ServerSession>& ServerSession() const {
+    const SharedPtr<Kernel::ServerSession>& Session() const {
         return server_session;
     }
 
@@ -151,10 +143,6 @@ public:
         return domain_message_header;
     }
 
-    bool IsDomain() const {
-        return domain != nullptr;
-    }
-
     template <typename T>
     SharedPtr<T> GetCopyObject(size_t index) {
         ASSERT(index < copy_objects.size());
@@ -189,7 +177,6 @@ public:
 
 private:
     std::array<u32, IPC::COMMAND_BUFFER_LENGTH> cmd_buf;
-    SharedPtr<Kernel::Domain> domain;
     SharedPtr<Kernel::ServerSession> server_session;
     // TODO(yuriks): Check common usage of this and optimize size accordingly
     boost::container::small_vector<SharedPtr<Object>, 8> move_objects;
@@ -209,6 +196,7 @@ private:
     unsigned data_payload_offset{};
     unsigned buffer_c_offset{};
     u32_le command{};
+    bool is_domain{};
 };
 
 } // namespace Kernel
diff --git a/src/core/hle/kernel/server_session.cpp b/src/core/hle/kernel/server_session.cpp
index 09d02a691..b79bf7bab 100644
--- a/src/core/hle/kernel/server_session.cpp
+++ b/src/core/hle/kernel/server_session.cpp
@@ -4,6 +4,7 @@
 
 #include <tuple>
 
+#include "core/hle/ipc_helpers.h"
 #include "core/hle/kernel/client_port.h"
 #include "core/hle/kernel/client_session.h"
 #include "core/hle/kernel/handle_table.h"
@@ -61,6 +62,38 @@ ResultCode ServerSession::HandleSyncRequest(SharedPtr<Thread> thread) {
     // from its ClientSession, so wake up any threads that may be waiting on a svcReplyAndReceive or
     // similar.
 
+    Kernel::HLERequestContext context(this);
+    u32* cmd_buf = (u32*)Memory::GetPointer(thread->GetTLSAddress());
+    context.PopulateFromIncomingCommandBuffer(cmd_buf, *Kernel::g_current_process,
+                                              Kernel::g_handle_table);
+
+    // If the session has been converted to a domain, handle the doomain request
+    if (IsDomain()) {
+        auto& domain_message_header = context.GetDomainMessageHeader();
+        if (domain_message_header) {
+            // If there is a DomainMessageHeader, then this is CommandType "Request"
+            const u32 object_id{context.GetDomainMessageHeader()->object_id};
+            switch (domain_message_header->command) {
+            case IPC::DomainMessageHeader::CommandType::SendMessage:
+                return domain_request_handlers[object_id - 1]->HandleSyncRequest(context);
+
+            case IPC::DomainMessageHeader::CommandType::CloseVirtualHandle: {
+                LOG_DEBUG(IPC, "CloseVirtualHandle, object_id=0x%08X", object_id);
+
+                domain_request_handlers[object_id - 1] = nullptr;
+
+                IPC::RequestBuilder rb{context, 2};
+                rb.Push(RESULT_SUCCESS);
+                return RESULT_SUCCESS;
+            }
+            }
+
+            LOG_CRITICAL(IPC, "Unknown domain command=%d", domain_message_header->command.Value());
+            UNIMPLEMENTED();
+        }
+        return domain_request_handlers.front()->HandleSyncRequest(context);
+    }
+
     // If this ServerSession has an associated HLE handler, forward the request to it.
     ResultCode result{RESULT_SUCCESS};
     if (hle_handler != nullptr) {
@@ -69,11 +102,6 @@ ResultCode ServerSession::HandleSyncRequest(SharedPtr<Thread> thread) {
         if (translate_result.IsError())
             return translate_result;
 
-        Kernel::HLERequestContext context(this);
-        u32* cmd_buf = (u32*)Memory::GetPointer(Kernel::GetCurrentThread()->GetTLSAddress());
-        context.PopulateFromIncomingCommandBuffer(cmd_buf, *Kernel::g_current_process,
-                                                  Kernel::g_handle_table);
-
         result = hle_handler->HandleSyncRequest(context);
     } else {
         // Add the thread to the list of threads that have issued a sync request with this
@@ -84,6 +112,15 @@ ResultCode ServerSession::HandleSyncRequest(SharedPtr<Thread> thread) {
     // If this ServerSession does not have an HLE implementation, just wake up the threads waiting
     // on it.
     WakeupAllWaitingThreads();
+
+    // Handle scenario when ConvertToDomain command was issued, as we must do the conversion at the
+    // end of the command such that only commands following this one are handled as domains
+    if (convert_to_domain) {
+        ASSERT_MSG(domain_request_handlers.empty(), "already a domain");
+        domain_request_handlers.push_back(std::move(hle_handler));
+        convert_to_domain = false;
+    }
+
     return result;
 }
 
diff --git a/src/core/hle/kernel/server_session.h b/src/core/hle/kernel/server_session.h
index 6ff4ef8c1..144692106 100644
--- a/src/core/hle/kernel/server_session.h
+++ b/src/core/hle/kernel/server_session.h
@@ -79,7 +79,10 @@ public:
     std::string name;                ///< The name of this session (optional)
     std::shared_ptr<Session> parent; ///< The parent session, which links to the client endpoint.
     std::shared_ptr<SessionRequestHandler>
-        hle_handler; ///< This session's HLE request handler (optional)
+        hle_handler; ///< This session's HLE request handler (applicable when not a domain)
+
+    /// This is the list of domain request handlers (after conversion to a domain)
+    std::vector<std::shared_ptr<SessionRequestHandler>> domain_request_handlers;
 
     /// List of threads that are pending a response after a sync request. This list is processed in
     /// a LIFO manner, thus, the last request will be dispatched first.
@@ -91,6 +94,16 @@ public:
     /// TODO(Subv): Find a better name for this.
     SharedPtr<Thread> currently_handling;
 
+    /// Returns true if the session has been converted to a domain, otherwise False
+    bool IsDomain() const {
+        return !domain_request_handlers.empty();
+    }
+
+    /// Converts the session to a domain at the end of the current command
+    void ConvertToDomain() {
+        convert_to_domain = true;
+    }
+
 private:
     ServerSession();
     ~ServerSession() override;
@@ -102,6 +115,9 @@ private:
      * @return The created server session
      */
     static ResultVal<SharedPtr<ServerSession>> Create(std::string name = "Unknown");
+
+    /// When set to True, converts the session to a domain at the end of the command
+    bool convert_to_domain{};
 };
 
 /**
diff --git a/src/core/hle/service/sm/controller.cpp b/src/core/hle/service/sm/controller.cpp
index e91d9d856..3eead315a 100644
--- a/src/core/hle/service/sm/controller.cpp
+++ b/src/core/hle/service/sm/controller.cpp
@@ -10,23 +10,21 @@ namespace Service {
 namespace SM {
 
 void Controller::ConvertSessionToDomain(Kernel::HLERequestContext& ctx) {
-    auto domain = Kernel::Domain::CreateFromSession(*ctx.ServerSession()->parent).Unwrap();
+    ASSERT_MSG(!ctx.Session()->IsDomain(), "session is alread a domain");
+    ctx.Session()->ConvertToDomain();
 
     IPC::RequestBuilder rb{ctx, 3};
     rb.Push(RESULT_SUCCESS);
-    rb.Push(static_cast<u32>(domain->request_handlers.size()));
+    rb.Push<u32>(1); // Converted sessions start with 1 request handler
 
-    LOG_DEBUG(Service, "called, domain=%d", domain->GetObjectId());
+    LOG_DEBUG(Service, "called, server_session=%d", ctx.Session()->GetObjectId());
 }
 
 void Controller::DuplicateSession(Kernel::HLERequestContext& ctx) {
     IPC::RequestBuilder rb{ctx, 2, 0, 1};
     rb.Push(RESULT_SUCCESS);
     // TODO(Subv): Check if this is correct
-    if (ctx.IsDomain())
-        rb.PushMoveObjects(ctx.Domain());
-    else
-        rb.PushMoveObjects(ctx.ServerSession());
+    rb.PushMoveObjects(ctx.Session());
 
     LOG_DEBUG(Service, "called");
 }