hle: Improve safety (#2778)
* timezone: Make timezone implementation safe * hle: Do not use TrimEnd to parse ASCII strings This adds an util that handle reading an ASCII string in a safe way. Previously it was possible to read malformed data that could cause various undefined behaviours in multiple services. * hid: Remove an useless unsafe modifier on keyboard update * Address gdkchan's comment * Address gdkchan's comment
This commit is contained in:
parent
b4dc33efc2
commit
51fa1b2cb0
9 changed files with 141 additions and 172 deletions
|
@ -1,3 +1,4 @@
|
|||
using Ryujinx.HLE.Utilities;
|
||||
using System.IO;
|
||||
using System.Text;
|
||||
|
||||
|
@ -30,10 +31,10 @@ namespace Ryujinx.HLE.FileSystem.Content
|
|||
|
||||
reader.ReadBytes(2); // Padding
|
||||
|
||||
PlatformString = Encoding.ASCII.GetString(reader.ReadBytes(0x20)).TrimEnd('\0');
|
||||
Hex = Encoding.ASCII.GetString(reader.ReadBytes(0x40)).TrimEnd('\0');
|
||||
VersionString = Encoding.ASCII.GetString(reader.ReadBytes(0x18)).TrimEnd('\0');
|
||||
VersionTitle = Encoding.ASCII.GetString(reader.ReadBytes(0x80)).TrimEnd('\0');
|
||||
PlatformString = StringUtils.ReadInlinedAsciiString(reader, 0x20);
|
||||
Hex = StringUtils.ReadInlinedAsciiString(reader, 0x40);
|
||||
VersionString = StringUtils.ReadInlinedAsciiString(reader, 0x18);
|
||||
VersionTitle = StringUtils.ReadInlinedAsciiString(reader, 0x80);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ namespace Ryujinx.HLE.HOS.Services.Hid
|
|||
{
|
||||
public KeyboardDevice(Switch device, bool active) : base(device, active) { }
|
||||
|
||||
public unsafe void Update(KeyboardInput keyState)
|
||||
public void Update(KeyboardInput keyState)
|
||||
{
|
||||
ref RingLifo<KeyboardState> lifo = ref _device.Hid.SharedMemory.Keyboard;
|
||||
|
||||
|
|
|
@ -218,11 +218,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres
|
|||
|
||||
private ResultCode GetHostByNameRequestImpl(ServiceCtx context, ulong inputBufferPosition, ulong inputBufferSize, ulong outputBufferPosition, ulong outputBufferSize, ulong optionsBufferPosition, ulong optionsBufferSize)
|
||||
{
|
||||
byte[] rawName = new byte[inputBufferSize];
|
||||
|
||||
context.Memory.Read(inputBufferPosition, rawName);
|
||||
|
||||
string name = Encoding.ASCII.GetString(rawName).TrimEnd('\0');
|
||||
string name = MemoryHelper.ReadAsciiString(context.Memory, inputBufferPosition, (int)inputBufferSize);
|
||||
|
||||
// TODO: Use params.
|
||||
bool enableNsdResolve = (context.RequestData.ReadInt32() & 1) != 0;
|
||||
|
|
|
@ -116,7 +116,7 @@ namespace Ryujinx.HLE.HOS.Services.Time
|
|||
// SetupTimeZoneManager(nn::time::LocationName location_name, nn::time::SteadyClockTimePoint timezone_update_timepoint, u32 total_location_name_count, nn::time::TimeZoneRuleVersion timezone_rule_version, buffer<nn::time::TimeZoneBinary, 0x21> timezone_binary)
|
||||
public ResultCode SetupTimeZoneManager(ServiceCtx context)
|
||||
{
|
||||
string locationName = Encoding.ASCII.GetString(context.RequestData.ReadBytes(0x24)).TrimEnd('\0');
|
||||
string locationName = StringUtils.ReadInlinedAsciiString(context.RequestData, 0x24);
|
||||
SteadyClockTimePoint timeZoneUpdateTimePoint = context.RequestData.ReadStruct<SteadyClockTimePoint>();
|
||||
uint totalLocationNameCount = context.RequestData.ReadUInt32();
|
||||
UInt128 timeZoneRuleVersion = context.RequestData.ReadStruct<UInt128>();
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
using Ryujinx.Common.Logging;
|
||||
using Ryujinx.Cpu;
|
||||
using Ryujinx.HLE.HOS.Services.Time.TimeZone;
|
||||
using Ryujinx.HLE.Utilities;
|
||||
using System;
|
||||
using System.Text;
|
||||
|
||||
|
@ -35,7 +36,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.StaticService
|
|||
return ResultCode.PermissionDenied;
|
||||
}
|
||||
|
||||
string locationName = Encoding.ASCII.GetString(context.RequestData.ReadBytes(0x24)).TrimEnd('\0');
|
||||
string locationName = StringUtils.ReadInlinedAsciiString(context.RequestData, 0x24);
|
||||
|
||||
return _timeZoneContentManager.SetDeviceLocationName(locationName);
|
||||
}
|
||||
|
@ -97,7 +98,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.StaticService
|
|||
throw new InvalidOperationException();
|
||||
}
|
||||
|
||||
string locationName = Encoding.ASCII.GetString(context.RequestData.ReadBytes(0x24)).TrimEnd('\0');
|
||||
string locationName = StringUtils.ReadInlinedAsciiString(context.RequestData, 0x24);
|
||||
|
||||
ResultCode resultCode = _timeZoneContentManager.LoadTimeZoneRule(out TimeZoneRule rules, locationName);
|
||||
|
||||
|
|
|
@ -125,7 +125,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.StaticService
|
|||
|
||||
(ulong bufferPosition, ulong bufferSize) = context.Request.GetBufferType0x21();
|
||||
|
||||
string locationName = Encoding.ASCII.GetString(context.RequestData.ReadBytes(0x24)).TrimEnd('\0');
|
||||
string locationName = StringUtils.ReadInlinedAsciiString(context.RequestData, 0x24);
|
||||
|
||||
ResultCode result;
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
using Ryujinx.Common.Utilities;
|
||||
using Ryujinx.HLE.Utilities;
|
||||
using System;
|
||||
using System.Buffers.Binary;
|
||||
using System.IO;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Text;
|
||||
|
@ -107,40 +108,24 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
public int TransitionTime;
|
||||
}
|
||||
|
||||
private static int Detzcode32(byte[] bytes)
|
||||
private static int Detzcode32(ReadOnlySpan<byte> bytes)
|
||||
{
|
||||
return BinaryPrimitives.ReadInt32BigEndian(bytes);
|
||||
}
|
||||
|
||||
private static int Detzcode32(int value)
|
||||
{
|
||||
if (BitConverter.IsLittleEndian)
|
||||
{
|
||||
Array.Reverse(bytes, 0, bytes.Length);
|
||||
return BinaryPrimitives.ReverseEndianness(value);
|
||||
}
|
||||
|
||||
return BitConverter.ToInt32(bytes, 0);
|
||||
return value;
|
||||
}
|
||||
|
||||
private static unsafe int Detzcode32(int* data)
|
||||
private static long Detzcode64(ReadOnlySpan<byte> bytes)
|
||||
{
|
||||
int result = *data;
|
||||
if (BitConverter.IsLittleEndian)
|
||||
{
|
||||
byte[] bytes = BitConverter.GetBytes(result);
|
||||
Array.Reverse(bytes, 0, bytes.Length);
|
||||
result = BitConverter.ToInt32(bytes, 0);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private static unsafe long Detzcode64(long* data)
|
||||
{
|
||||
long result = *data;
|
||||
if (BitConverter.IsLittleEndian)
|
||||
{
|
||||
byte[] bytes = BitConverter.GetBytes(result);
|
||||
Array.Reverse(bytes, 0, bytes.Length);
|
||||
result = BitConverter.ToInt64(bytes, 0);
|
||||
}
|
||||
|
||||
return result;
|
||||
return BinaryPrimitives.ReadInt64BigEndian(bytes);
|
||||
}
|
||||
|
||||
private static bool DifferByRepeat(long t1, long t0)
|
||||
|
@ -148,7 +133,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
return (t1 - t0) == SecondsPerRepeat;
|
||||
}
|
||||
|
||||
private static unsafe bool TimeTypeEquals(TimeZoneRule outRules, byte aIndex, byte bIndex)
|
||||
private static bool TimeTypeEquals(TimeZoneRule outRules, byte aIndex, byte bIndex)
|
||||
{
|
||||
if (aIndex < 0 || aIndex >= outRules.TypeCount || bIndex < 0 || bIndex >= outRules.TypeCount)
|
||||
{
|
||||
|
@ -158,17 +143,14 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
TimeTypeInfo a = outRules.Ttis[aIndex];
|
||||
TimeTypeInfo b = outRules.Ttis[bIndex];
|
||||
|
||||
fixed (char* chars = outRules.Chars)
|
||||
{
|
||||
return a.GmtOffset == b.GmtOffset &&
|
||||
a.IsDaySavingTime == b.IsDaySavingTime &&
|
||||
a.IsStandardTimeDaylight == b.IsStandardTimeDaylight &&
|
||||
a.IsGMT == b.IsGMT &&
|
||||
StringUtils.CompareCStr(chars + a.AbbreviationListIndex, chars + b.AbbreviationListIndex) == 0;
|
||||
}
|
||||
StringUtils.CompareCStr(outRules.Chars[a.AbbreviationListIndex..], outRules.Chars[b.AbbreviationListIndex..]) == 0;
|
||||
}
|
||||
|
||||
private static int GetQZName(char[] name, int namePosition, char delimiter)
|
||||
private static int GetQZName(ReadOnlySpan<char> name, int namePosition, char delimiter)
|
||||
{
|
||||
int i = namePosition;
|
||||
|
||||
|
@ -403,7 +385,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
return 0;
|
||||
}
|
||||
|
||||
private static bool ParsePosixName(Span<char> name, out TimeZoneRule outRules, bool lastDitch)
|
||||
private static bool ParsePosixName(ReadOnlySpan<char> name, out TimeZoneRule outRules, bool lastDitch)
|
||||
{
|
||||
outRules = new TimeZoneRule
|
||||
{
|
||||
|
@ -414,7 +396,8 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
};
|
||||
|
||||
int stdLen;
|
||||
Span<char> stdName = name;
|
||||
|
||||
ReadOnlySpan<char> stdName = name;
|
||||
int namePosition = 0;
|
||||
int stdOffset = 0;
|
||||
|
||||
|
@ -433,7 +416,8 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
|
||||
int stdNamePosition = namePosition;
|
||||
|
||||
namePosition = GetQZName(name.ToArray(), namePosition, '>');
|
||||
namePosition = GetQZName(name, namePosition, '>');
|
||||
|
||||
if (name[namePosition] != '>')
|
||||
{
|
||||
return false;
|
||||
|
@ -465,7 +449,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
int destLen = 0;
|
||||
int dstOffset = 0;
|
||||
|
||||
Span<char> destName = name.Slice(namePosition);
|
||||
ReadOnlySpan<char> destName = name.Slice(namePosition);
|
||||
|
||||
if (TzCharsArraySize < charCount)
|
||||
{
|
||||
|
@ -903,7 +887,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
return ParsePosixName(name.ToCharArray(), out outRules, false);
|
||||
}
|
||||
|
||||
internal static unsafe bool ParseTimeZoneBinary(out TimeZoneRule outRules, Stream inputData)
|
||||
internal static bool ParseTimeZoneBinary(out TimeZoneRule outRules, Stream inputData)
|
||||
{
|
||||
outRules = new TimeZoneRule
|
||||
{
|
||||
|
@ -967,12 +951,11 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
|
||||
timeCount = 0;
|
||||
|
||||
fixed (byte* workBufferPtrStart = workBuffer)
|
||||
{
|
||||
byte* p = workBufferPtrStart;
|
||||
Span<byte> p = workBuffer;
|
||||
for (int i = 0; i < outRules.TimeCount; i++)
|
||||
{
|
||||
long at = Detzcode64((long*)p);
|
||||
long at = Detzcode64(p);
|
||||
outRules.Types[i] = 1;
|
||||
|
||||
if (timeCount != 0 && at <= outRules.Ats[timeCount - 1])
|
||||
|
@ -988,13 +971,15 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
|
||||
outRules.Ats[timeCount++] = at;
|
||||
|
||||
p += TimeTypeSize;
|
||||
p = p[TimeTypeSize..];
|
||||
}
|
||||
|
||||
timeCount = 0;
|
||||
for (int i = 0; i < outRules.TimeCount; i++)
|
||||
{
|
||||
byte type = *p++;
|
||||
byte type = p[0];
|
||||
p = p[1..];
|
||||
|
||||
if (outRules.TypeCount <= type)
|
||||
{
|
||||
return false;
|
||||
|
@ -1011,18 +996,20 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
for (int i = 0; i < outRules.TypeCount; i++)
|
||||
{
|
||||
TimeTypeInfo ttis = outRules.Ttis[i];
|
||||
ttis.GmtOffset = Detzcode32((int*)p);
|
||||
p += 4;
|
||||
ttis.GmtOffset = Detzcode32(p);
|
||||
p = p[sizeof(int)..];
|
||||
|
||||
if (*p >= 2)
|
||||
if (p[0] >= 2)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
ttis.IsDaySavingTime = *p != 0;
|
||||
p++;
|
||||
ttis.IsDaySavingTime = p[0] != 0;
|
||||
p = p[1..];
|
||||
|
||||
int abbreviationListIndex = p[0];
|
||||
p = p[1..];
|
||||
|
||||
int abbreviationListIndex = *p++;
|
||||
if (abbreviationListIndex >= outRules.CharCount)
|
||||
{
|
||||
return false;
|
||||
|
@ -1033,12 +1020,9 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
outRules.Ttis[i] = ttis;
|
||||
}
|
||||
|
||||
fixed (char* chars = outRules.Chars)
|
||||
{
|
||||
Encoding.ASCII.GetChars(p, outRules.CharCount, chars, outRules.CharCount);
|
||||
}
|
||||
Encoding.ASCII.GetChars(p[..outRules.CharCount].ToArray()).CopyTo(outRules.Chars.AsSpan());
|
||||
|
||||
p += outRules.CharCount;
|
||||
p = p[outRules.CharCount..];
|
||||
outRules.Chars[outRules.CharCount] = '\0';
|
||||
|
||||
for (int i = 0; i < outRules.TypeCount; i++)
|
||||
|
@ -1049,14 +1033,14 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
}
|
||||
else
|
||||
{
|
||||
if (*p >= 2)
|
||||
if (p[0] >= 2)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
outRules.Ttis[i].IsStandardTimeDaylight = *p++ != 0;
|
||||
outRules.Ttis[i].IsStandardTimeDaylight = p[0] != 0;
|
||||
p = p[1..];
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for (int i = 0; i < outRules.TypeCount; i++)
|
||||
|
@ -1067,17 +1051,18 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
}
|
||||
else
|
||||
{
|
||||
if (*p >= 2)
|
||||
if (p[0] >= 2)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
outRules.Ttis[i].IsGMT = *p++ != 0;
|
||||
outRules.Ttis[i].IsGMT = p[0] != 0;
|
||||
p = p[1..];
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
long position = (p - workBufferPtrStart);
|
||||
long position = (workBuffer.Length - p.Length);
|
||||
long nRead = streamLength - position;
|
||||
|
||||
if (nRead < 0)
|
||||
|
@ -1107,18 +1092,18 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
int abbreviationCount = 0;
|
||||
charCount = outRules.CharCount;
|
||||
|
||||
fixed (char* chars = outRules.Chars)
|
||||
{
|
||||
Span<char> chars = outRules.Chars;
|
||||
|
||||
for (int i = 0; i < tempRules.TypeCount; i++)
|
||||
{
|
||||
fixed (char* tempChars = tempRules.Chars)
|
||||
{
|
||||
char* tempAbbreviation = tempChars + tempRules.Ttis[i].AbbreviationListIndex;
|
||||
ReadOnlySpan<char> tempChars = tempRules.Chars;
|
||||
ReadOnlySpan<char> tempAbbreviation = tempChars[tempRules.Ttis[i].AbbreviationListIndex..];
|
||||
|
||||
int j;
|
||||
|
||||
for (j = 0; j < charCount; j++)
|
||||
{
|
||||
if (StringUtils.CompareCStr(chars + j, tempAbbreviation) == 0)
|
||||
if (StringUtils.CompareCStr(chars[j..], tempAbbreviation) == 0)
|
||||
{
|
||||
tempRules.Ttis[i].AbbreviationListIndex = j;
|
||||
abbreviationCount++;
|
||||
|
@ -1143,7 +1128,6 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (abbreviationCount == tempRules.TypeCount)
|
||||
{
|
||||
|
@ -1181,7 +1165,6 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (outRules.TypeCount == 0)
|
||||
{
|
||||
|
@ -1467,17 +1450,11 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
|||
{
|
||||
calendarAdditionalInfo.IsDaySavingTime = rules.Ttis[ttiIndex].IsDaySavingTime;
|
||||
|
||||
unsafe
|
||||
{
|
||||
fixed (char* timeZoneAbbreviation = &rules.Chars[rules.Ttis[ttiIndex].AbbreviationListIndex])
|
||||
{
|
||||
ReadOnlySpan<char> timeZoneAbbreviation = rules.Chars.AsSpan()[rules.Ttis[ttiIndex].AbbreviationListIndex..];
|
||||
|
||||
int timeZoneSize = Math.Min(StringUtils.LengthCstr(timeZoneAbbreviation), 8);
|
||||
for (int i = 0; i < timeZoneSize; i++)
|
||||
{
|
||||
calendarAdditionalInfo.TimezoneName[i] = timeZoneAbbreviation[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
timeZoneAbbreviation[..timeZoneSize].CopyTo(calendarAdditionalInfo.TimezoneName.AsSpan());
|
||||
}
|
||||
|
||||
return result;
|
||||
|
|
|
@ -1,34 +1,19 @@
|
|||
using System.Runtime.InteropServices;
|
||||
using Ryujinx.Common.Memory;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
|
||||
{
|
||||
[StructLayout(LayoutKind.Sequential, Pack = 0x4, Size = 0x2C)]
|
||||
struct TzifHeader
|
||||
{
|
||||
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
|
||||
public char[] Magic;
|
||||
|
||||
public char Version;
|
||||
|
||||
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 15)]
|
||||
public byte[] Reserved;
|
||||
|
||||
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
|
||||
public byte[] TtisGMTCount;
|
||||
|
||||
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
|
||||
public byte[] TtisSTDCount;
|
||||
|
||||
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
|
||||
public byte[] LeapCount;
|
||||
|
||||
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
|
||||
public byte[] TimeCount;
|
||||
|
||||
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
|
||||
public byte[] TypeCount;
|
||||
|
||||
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
|
||||
public byte[] CharCount;
|
||||
public Array4<byte> Magic;
|
||||
public byte Version;
|
||||
private Array15<byte> _reserved;
|
||||
public int TtisGMTCount;
|
||||
public int TtisSTDCount;
|
||||
public int LeapCount;
|
||||
public int TimeCount;
|
||||
public int TypeCount;
|
||||
public int CharCount;
|
||||
}
|
||||
}
|
|
@ -36,6 +36,15 @@ namespace Ryujinx.HLE.Utilities
|
|||
return output;
|
||||
}
|
||||
|
||||
public static string ReadInlinedAsciiString(BinaryReader reader, int maxSize)
|
||||
{
|
||||
byte[] data = reader.ReadBytes(maxSize);
|
||||
|
||||
int stringSize = Array.IndexOf<byte>(data, 0);
|
||||
|
||||
return Encoding.ASCII.GetString(data, 0, stringSize < 0 ? maxSize : stringSize);
|
||||
}
|
||||
|
||||
public static byte[] HexToBytes(string hexString)
|
||||
{
|
||||
// Ignore last character if HexLength % 2 != 0.
|
||||
|
@ -107,7 +116,7 @@ namespace Ryujinx.HLE.Utilities
|
|||
}
|
||||
}
|
||||
|
||||
public static unsafe int CompareCStr(char* s1, char* s2)
|
||||
public static int CompareCStr(ReadOnlySpan<char> s1, ReadOnlySpan<char> s2)
|
||||
{
|
||||
int s1Index = 0;
|
||||
int s2Index = 0;
|
||||
|
@ -121,7 +130,7 @@ namespace Ryujinx.HLE.Utilities
|
|||
return s2[s2Index] - s1[s1Index];
|
||||
}
|
||||
|
||||
public static unsafe int LengthCstr(char* s)
|
||||
public static int LengthCstr(ReadOnlySpan<char> s)
|
||||
{
|
||||
int i = 0;
|
||||
|
||||
|
|
Loading…
Reference in a new issue