core: hle: kernel: Use weak_ptr where possible for SessionRequestHandler and SessionRequestManager.

This commit is contained in:
bunnei 2022-03-10 23:45:54 -08:00
parent ce33503adf
commit 07c9d9bdbd
7 changed files with 25 additions and 14 deletions

View file

@ -385,7 +385,7 @@ public:
T PopRaw(); T PopRaw();
template <class T> template <class T>
std::shared_ptr<T> PopIpcInterface() { std::weak_ptr<T> PopIpcInterface() {
ASSERT(context->Session()->IsDomain()); ASSERT(context->Session()->IsDomain());
ASSERT(context->GetDomainMessageHeader().input_object_count > 0); ASSERT(context->GetDomainMessageHeader().input_object_count > 0);
return context->GetDomainHandler<T>(Pop<u32>() - 1); return context->GetDomainHandler<T>(Pop<u32>() - 1);

View file

@ -45,7 +45,7 @@ bool SessionRequestManager::HasSessionRequestHandler(const HLERequestContext& co
LOG_CRITICAL(IPC, "object_id {} is too big!", object_id); LOG_CRITICAL(IPC, "object_id {} is too big!", object_id);
return false; return false;
} }
return DomainHandler(object_id - 1) != nullptr; return DomainHandler(object_id - 1).lock() != nullptr;
} else { } else {
return session_handler != nullptr; return session_handler != nullptr;
} }

View file

@ -94,6 +94,7 @@ protected:
std::weak_ptr<ServiceThread> service_thread; std::weak_ptr<ServiceThread> service_thread;
}; };
using SessionRequestHandlerWeakPtr = std::weak_ptr<SessionRequestHandler>;
using SessionRequestHandlerPtr = std::shared_ptr<SessionRequestHandler>; using SessionRequestHandlerPtr = std::shared_ptr<SessionRequestHandler>;
/** /**
@ -139,7 +140,7 @@ public:
} }
} }
SessionRequestHandlerPtr DomainHandler(std::size_t index) const { SessionRequestHandlerWeakPtr DomainHandler(std::size_t index) const {
ASSERT_MSG(index < DomainHandlerCount(), "Unexpected handler index {}", index); ASSERT_MSG(index < DomainHandlerCount(), "Unexpected handler index {}", index);
return domain_handlers.at(index); return domain_handlers.at(index);
} }
@ -328,10 +329,10 @@ public:
template <typename T> template <typename T>
std::shared_ptr<T> GetDomainHandler(std::size_t index) const { std::shared_ptr<T> GetDomainHandler(std::size_t index) const {
return std::static_pointer_cast<T>(manager->DomainHandler(index)); return std::static_pointer_cast<T>(manager.lock()->DomainHandler(index).lock());
} }
void SetSessionRequestManager(std::shared_ptr<SessionRequestManager> manager_) { void SetSessionRequestManager(std::weak_ptr<SessionRequestManager> manager_) {
manager = std::move(manager_); manager = std::move(manager_);
} }
@ -374,7 +375,7 @@ private:
u32 handles_offset{}; u32 handles_offset{};
u32 domain_offset{}; u32 domain_offset{};
std::shared_ptr<SessionRequestManager> manager; std::weak_ptr<SessionRequestManager> manager;
KernelCore& kernel; KernelCore& kernel;
Core::Memory::Memory& memory; Core::Memory::Memory& memory;

View file

@ -57,7 +57,12 @@ ResultCode KPort::EnqueueSession(KServerSession* session) {
R_UNLESS(state == State::Normal, ResultPortClosed); R_UNLESS(state == State::Normal, ResultPortClosed);
server.EnqueueSession(session); server.EnqueueSession(session);
server.GetSessionRequestHandler()->ClientConnected(server.AcceptSession());
if (auto session_ptr = server.GetSessionRequestHandler().lock()) {
session_ptr->ClientConnected(server.AcceptSession());
} else {
UNREACHABLE();
}
return ResultSuccess; return ResultSuccess;
} }

View file

@ -30,11 +30,11 @@ public:
/// Whether or not this server port has an HLE handler available. /// Whether or not this server port has an HLE handler available.
bool HasSessionRequestHandler() const { bool HasSessionRequestHandler() const {
return session_handler != nullptr; return !session_handler.expired();
} }
/// Gets the HLE handler for this port. /// Gets the HLE handler for this port.
SessionRequestHandlerPtr GetSessionRequestHandler() const { SessionRequestHandlerWeakPtr GetSessionRequestHandler() const {
return session_handler; return session_handler;
} }
@ -42,7 +42,7 @@ public:
* Sets the HLE handler template for the port. ServerSessions crated by connecting to this port * Sets the HLE handler template for the port. ServerSessions crated by connecting to this port
* will inherit a reference to this handler. * will inherit a reference to this handler.
*/ */
void SetSessionHandler(SessionRequestHandlerPtr&& handler) { void SetSessionHandler(SessionRequestHandlerWeakPtr&& handler) {
session_handler = std::move(handler); session_handler = std::move(handler);
} }
@ -66,7 +66,7 @@ private:
void CleanupSessions(); void CleanupSessions();
SessionList session_list; SessionList session_list;
SessionRequestHandlerPtr session_handler; SessionRequestHandlerWeakPtr session_handler;
KPort* parent{}; KPort* parent{};
}; };

View file

@ -98,7 +98,12 @@ ResultCode KServerSession::HandleDomainSyncRequest(Kernel::HLERequestContext& co
UNREACHABLE(); UNREACHABLE();
return ResultSuccess; // Ignore error if asserts are off return ResultSuccess; // Ignore error if asserts are off
} }
return manager->DomainHandler(object_id - 1)->HandleSyncRequest(*this, context); if (auto strong_ptr = manager->DomainHandler(object_id - 1).lock()) {
return strong_ptr->HandleSyncRequest(*this, context);
} else {
UNREACHABLE();
return ResultSuccess;
}
case IPC::DomainMessageHeader::CommandType::CloseVirtualHandle: { case IPC::DomainMessageHeader::CommandType::CloseVirtualHandle: {
LOG_DEBUG(IPC, "CloseVirtualHandle, object_id=0x{:08X}", object_id); LOG_DEBUG(IPC, "CloseVirtualHandle, object_id=0x{:08X}", object_id);

View file

@ -980,7 +980,7 @@ private:
LOG_DEBUG(Service_AM, "called"); LOG_DEBUG(Service_AM, "called");
IPC::RequestParser rp{ctx}; IPC::RequestParser rp{ctx};
applet->GetBroker().PushNormalDataFromGame(rp.PopIpcInterface<IStorage>()); applet->GetBroker().PushNormalDataFromGame(rp.PopIpcInterface<IStorage>().lock());
IPC::ResponseBuilder rb{ctx, 2}; IPC::ResponseBuilder rb{ctx, 2};
rb.Push(ResultSuccess); rb.Push(ResultSuccess);
@ -1007,7 +1007,7 @@ private:
LOG_DEBUG(Service_AM, "called"); LOG_DEBUG(Service_AM, "called");
IPC::RequestParser rp{ctx}; IPC::RequestParser rp{ctx};
applet->GetBroker().PushInteractiveDataFromGame(rp.PopIpcInterface<IStorage>()); applet->GetBroker().PushInteractiveDataFromGame(rp.PopIpcInterface<IStorage>().lock());
ASSERT(applet->IsInitialized()); ASSERT(applet->IsInitialized());
applet->ExecuteInteractive(); applet->ExecuteInteractive();