From 532b8cad139ef611d7583788d05f3b35e45ee5a6 Mon Sep 17 00:00:00 2001 From: gdkchan Date: Fri, 1 Jan 2021 14:59:26 -0300 Subject: [PATCH] Update KAddressArbiter implementation to 11.x kernel (#1851) * Update KAddressArbiter implementation to 11.x kernel * InsertSortedByPriority is no longer needed --- .../HOS/Kernel/Threading/KAddressArbiter.cs | 152 +++++++----------- 1 file changed, 60 insertions(+), 92 deletions(-) diff --git a/Ryujinx.HLE/HOS/Kernel/Threading/KAddressArbiter.cs b/Ryujinx.HLE/HOS/Kernel/Threading/KAddressArbiter.cs index 3ddcffc1b..3fd07f90d 100644 --- a/Ryujinx.HLE/HOS/Kernel/Threading/KAddressArbiter.cs +++ b/Ryujinx.HLE/HOS/Kernel/Threading/KAddressArbiter.cs @@ -1,5 +1,6 @@ using Ryujinx.HLE.HOS.Kernel.Common; using Ryujinx.HLE.HOS.Kernel.Process; +using System; using System.Collections.Generic; using System.Linq; using System.Threading; @@ -83,7 +84,14 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading KThread currentThread = KernelStatic.GetCurrentThread(); - (KernelResult result, KThread newOwnerThread) = MutexUnlock(currentThread, mutexAddress); + (int mutexValue, KThread newOwnerThread) = MutexUnlock(currentThread, mutexAddress); + + KernelResult result = KernelResult.Success; + + if (!KernelTransfer.KernelToUserInt32(_context, mutexAddress, mutexValue)) + { + result = KernelResult.InvalidMemState; + } if (result != KernelResult.Success && newOwnerThread != null) { @@ -96,11 +104,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading return result; } - public KernelResult WaitProcessWideKeyAtomic( - ulong mutexAddress, - ulong condVarAddress, - int threadHandle, - long timeout) + public KernelResult WaitProcessWideKeyAtomic(ulong mutexAddress, ulong condVarAddress, int threadHandle, long timeout) { _context.CriticalSection.Enter(); @@ -117,13 +121,15 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading return KernelResult.ThreadTerminating; } - (KernelResult result, _) = MutexUnlock(currentThread, mutexAddress); + (int mutexValue, _) = MutexUnlock(currentThread, mutexAddress); - if (result != KernelResult.Success) + KernelTransfer.KernelToUserInt32(_context, condVarAddress, 1); + + if (!KernelTransfer.KernelToUserInt32(_context, mutexAddress, mutexValue)) { _context.CriticalSection.Leave(); - return result; + return KernelResult.InvalidMemState; } currentThread.MutexAddress = mutexAddress; @@ -163,7 +169,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading return currentThread.ObjSyncResult; } - private (KernelResult, KThread) MutexUnlock(KThread currentThread, ulong mutexAddress) + private (int, KThread) MutexUnlock(KThread currentThread, ulong mutexAddress) { KThread newOwnerThread = currentThread.RelinquishMutex(mutexAddress, out int count); @@ -184,46 +190,24 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading newOwnerThread.ReleaseAndResume(); } - KernelResult result = KernelResult.Success; - - if (!KernelTransfer.KernelToUserInt32(_context, mutexAddress, mutexValue)) - { - result = KernelResult.InvalidMemState; - } - - return (result, newOwnerThread); + return (mutexValue, newOwnerThread); } public void SignalProcessWideKey(ulong address, int count) { - Queue signaledThreads = new Queue(); - _context.CriticalSection.Enter(); - IOrderedEnumerable sortedThreads = _condVarThreads.OrderBy(x => x.DynamicPriority); + WakeThreads(_condVarThreads, count, TryAcquireMutex, x => x.CondVarAddress == address); - foreach (KThread thread in sortedThreads.Where(x => x.CondVarAddress == address)) + if (!_condVarThreads.Any(x => x.CondVarAddress == address)) { - TryAcquireMutex(thread); - - signaledThreads.Enqueue(thread); - - // If the count is <= 0, we should signal all threads waiting. - if (count >= 1 && --count == 0) - { - break; - } - } - - while (signaledThreads.TryDequeue(out KThread thread)) - { - _condVarThreads.Remove(thread); + KernelTransfer.KernelToUserInt32(_context, address, 0); } _context.CriticalSection.Leave(); } - private KThread TryAcquireMutex(KThread requester) + private static void TryAcquireMutex(KThread requester) { ulong address = requester.MutexAddress; @@ -235,7 +219,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading requester.SignaledObj = null; requester.ObjSyncResult = KernelResult.InvalidMemState; - return null; + return; } ref int mutexRef = ref currentProcess.CpuMemory.GetRef(address); @@ -267,7 +251,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading requester.ReleaseAndResume(); - return null; + return; } mutexValue &= ~HasListenersMask; @@ -287,8 +271,6 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading requester.ReleaseAndResume(); } - - return mutexOwner; } public KernelResult WaitForAddressIfEqual(ulong address, int value, long timeout) @@ -327,7 +309,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading currentThread.MutexAddress = address; currentThread.WaitingInArbitration = true; - InsertSortedByPriority(_arbiterThreads, currentThread); + _arbiterThreads.Add(currentThread); currentThread.Reschedule(ThreadSchedState.Paused); @@ -362,11 +344,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading return KernelResult.InvalidState; } - public KernelResult WaitForAddressIfLessThan( - ulong address, - int value, - bool shouldDecrement, - long timeout) + public KernelResult WaitForAddressIfLessThan(ulong address, int value, bool shouldDecrement, long timeout) { KThread currentThread = KernelStatic.GetCurrentThread(); @@ -409,7 +387,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading currentThread.MutexAddress = address; currentThread.WaitingInArbitration = true; - InsertSortedByPriority(_arbiterThreads, currentThread); + _arbiterThreads.Add(currentThread); currentThread.Reschedule(ThreadSchedState.Paused); @@ -444,30 +422,6 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading return KernelResult.InvalidState; } - private void InsertSortedByPriority(List threads, KThread thread) - { - int nextIndex = -1; - - for (int index = 0; index < threads.Count; index++) - { - if (threads[index].DynamicPriority > thread.DynamicPriority) - { - nextIndex = index; - - break; - } - } - - if (nextIndex != -1) - { - threads.Insert(nextIndex, thread); - } - else - { - threads.Add(thread); - } - } - public KernelResult Signal(ulong address, int count) { _context.CriticalSection.Enter(); @@ -520,7 +474,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading { _context.CriticalSection.Enter(); - int offset; + int addend; // The value is decremented if the number of threads waiting is less // or equal to the Count of threads to be signaled, or Count is zero @@ -529,7 +483,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading foreach (KThread thread in _arbiterThreads.Where(x => x.MutexAddress == address)) { - if (++waitingCount > count) + if (++waitingCount >= count) { break; } @@ -537,11 +491,22 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading if (waitingCount > 0) { - offset = waitingCount <= count || count <= 0 ? -1 : 0; + if (count <= 0) + { + addend = -2; + } + else if (waitingCount < count) + { + addend = -1; + } + else + { + addend = 0; + } } else { - offset = 1; + addend = 1; } KProcess currentProcess = KernelStatic.GetCurrentProcess(); @@ -568,7 +533,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading return KernelResult.InvalidState; } } - while (Interlocked.CompareExchange(ref valueRef, currentValue + offset, currentValue) != currentValue); + while (Interlocked.CompareExchange(ref valueRef, currentValue + addend, currentValue) != currentValue); WakeArbiterThreads(address, count); @@ -579,20 +544,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading private void WakeArbiterThreads(ulong address, int count) { - Queue signaledThreads = new Queue(); - - foreach (KThread thread in _arbiterThreads.Where(x => x.MutexAddress == address)) - { - signaledThreads.Enqueue(thread); - - // If the count is <= 0, we should signal all threads waiting. - if (count >= 1 && --count == 0) - { - break; - } - } - - while (signaledThreads.TryDequeue(out KThread thread)) + static void RemoveArbiterThread(KThread thread) { thread.SignaledObj = null; thread.ObjSyncResult = KernelResult.Success; @@ -600,8 +552,24 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading thread.ReleaseAndResume(); thread.WaitingInArbitration = false; + } - _arbiterThreads.Remove(thread); + WakeThreads(_arbiterThreads, count, RemoveArbiterThread, x => x.MutexAddress == address); + } + + private static void WakeThreads( + List threads, + int count, + Action removeCallback, + Func predicate) + { + var candidates = threads.Where(predicate).OrderBy(x => x.DynamicPriority); + var toSignal = (count > 0 ? candidates.Take(count) : candidates).ToArray(); + + foreach (KThread thread in toSignal) + { + removeCallback(thread); + threads.Remove(thread); } } }