bsd: implement SendMMsg and RecvMMsg (#3660)

* bsd: implement sendmmsg and recvmmsg

* Fix wrong increment of vlen
This commit is contained in:
Mary-nyan 2022-09-07 22:37:15 +02:00 committed by GitHub
parent 51bb8707ef
commit f3835dc78b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 528 additions and 0 deletions

View file

@ -886,6 +886,91 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
return WriteBsdResult(context, newSockFd, errno);
}
[CommandHipc(29)] // 7.0.0+
// RecvMMsg(u32 fd, u32 vlen, u32 flags, u32 reserved, nn::socket::TimeVal timeout) -> (i32 ret, u32 bsd_errno, buffer<bytes, 6> message);
public ResultCode RecvMMsg(ServiceCtx context)
{
int socketFd = context.RequestData.ReadInt32();
int vlen = context.RequestData.ReadInt32();
BsdSocketFlags socketFlags = (BsdSocketFlags)context.RequestData.ReadInt32();
uint reserved = context.RequestData.ReadUInt32();
TimeVal timeout = context.RequestData.ReadStruct<TimeVal>();
ulong receivePosition = context.Request.ReceiveBuff[0].Position;
ulong receiveLength = context.Request.ReceiveBuff[0].Size;
WritableRegion receiveRegion = context.Memory.GetWritableRegion(receivePosition, (int)receiveLength);
LinuxError errno = LinuxError.EBADF;
ISocket socket = _context.RetrieveSocket(socketFd);
int result = -1;
if (socket != null)
{
errno = BsdMMsgHdr.Deserialize(out BsdMMsgHdr message, receiveRegion.Memory.Span, vlen);
if (errno == LinuxError.SUCCESS)
{
errno = socket.RecvMMsg(out result, message, socketFlags, timeout);
if (errno == LinuxError.SUCCESS)
{
errno = BsdMMsgHdr.Serialize(receiveRegion.Memory.Span, message);
}
}
}
if (errno == LinuxError.SUCCESS)
{
SetResultErrno(socket, result);
receiveRegion.Dispose();
}
return WriteBsdResult(context, result, errno);
}
[CommandHipc(30)] // 7.0.0+
// SendMMsg(u32 fd, u32 vlen, u32 flags) -> (i32 ret, u32 bsd_errno, buffer<bytes, 6> message);
public ResultCode SendMMsg(ServiceCtx context)
{
int socketFd = context.RequestData.ReadInt32();
int vlen = context.RequestData.ReadInt32();
BsdSocketFlags socketFlags = (BsdSocketFlags)context.RequestData.ReadInt32();
ulong receivePosition = context.Request.ReceiveBuff[0].Position;
ulong receiveLength = context.Request.ReceiveBuff[0].Size;
WritableRegion receiveRegion = context.Memory.GetWritableRegion(receivePosition, (int)receiveLength);
LinuxError errno = LinuxError.EBADF;
ISocket socket = _context.RetrieveSocket(socketFd);
int result = -1;
if (socket != null)
{
errno = BsdMMsgHdr.Deserialize(out BsdMMsgHdr message, receiveRegion.Memory.Span, vlen);
if (errno == LinuxError.SUCCESS)
{
errno = socket.SendMMsg(out result, message, socketFlags);
if (errno == LinuxError.SUCCESS)
{
errno = BsdMMsgHdr.Serialize(receiveRegion.Memory.Span, message);
}
}
}
if (errno == LinuxError.SUCCESS)
{
SetResultErrno(socket, result);
receiveRegion.Dispose();
}
return WriteBsdResult(context, result, errno);
}
[CommandHipc(31)] // 7.0.0+
// EventFd(u64 initval, nn::socket::EventFdFlags flags) -> (i32 ret, u32 bsd_errno)
public ResultCode EventFd(ServiceCtx context)

View file

@ -25,7 +25,12 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
LinuxError SendTo(out int sendSize, ReadOnlySpan<byte> buffer, int size, BsdSocketFlags flags, IPEndPoint remoteEndPoint);
LinuxError RecvMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags, TimeVal timeout);
LinuxError SendMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags);
LinuxError GetSocketOption(BsdSocketOption option, SocketOptionLevel level, Span<byte> optionValue);
LinuxError SetSocketOption(BsdSocketOption option, SocketOptionLevel level, ReadOnlySpan<byte> optionValue);
bool Poll(int microSeconds, SelectMode mode);

View file

@ -1,5 +1,7 @@
using Ryujinx.Common.Logging;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
@ -356,5 +358,165 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{
return Send(out writeSize, buffer, BsdSocketFlags.None);
}
private bool CanSupportMMsgHdr(BsdMMsgHdr message)
{
for (int i = 0; i < message.Messages.Length; i++)
{
if (message.Messages[i].Name != null ||
message.Messages[i].Control != null)
{
return false;
}
}
return true;
}
private static IList<ArraySegment<byte>> ConvertMessagesToBuffer(BsdMMsgHdr message)
{
int segmentCount = 0;
int index = 0;
foreach (BsdMsgHdr msgHeader in message.Messages)
{
segmentCount += msgHeader.Iov.Length;
}
ArraySegment<byte>[] buffers = new ArraySegment<byte>[segmentCount];
foreach (BsdMsgHdr msgHeader in message.Messages)
{
foreach (byte[] iov in msgHeader.Iov)
{
buffers[index++] = new ArraySegment<byte>(iov);
}
// Clear the length
msgHeader.Length = 0;
}
return buffers;
}
private static void UpdateMessages(out int vlen, BsdMMsgHdr message, int transferedSize)
{
int bytesLeft = transferedSize;
int index = 0;
while (bytesLeft > 0)
{
// First ensure we haven't finished all buffers
if (index >= message.Messages.Length)
{
break;
}
BsdMsgHdr msgHeader = message.Messages[index];
int possiblyTransferedBytes = 0;
foreach (byte[] iov in msgHeader.Iov)
{
possiblyTransferedBytes += iov.Length;
}
int storedBytes;
if (bytesLeft > possiblyTransferedBytes)
{
storedBytes = possiblyTransferedBytes;
index++;
}
else
{
storedBytes = bytesLeft;
}
msgHeader.Length = (uint)storedBytes;
bytesLeft -= storedBytes;
}
Debug.Assert(bytesLeft == 0);
vlen = index + 1;
}
// TODO: Find a way to support passing the timeout somehow without changing the socket ReceiveTimeout.
public LinuxError RecvMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags, TimeVal timeout)
{
vlen = 0;
if (message.Messages.Length == 0)
{
return LinuxError.SUCCESS;
}
if (!CanSupportMMsgHdr(message))
{
Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported BsdMMsgHdr");
return LinuxError.EOPNOTSUPP;
}
if (message.Messages.Length == 0)
{
return LinuxError.SUCCESS;
}
try
{
int receiveSize = Socket.Receive(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError);
if (receiveSize > 0)
{
UpdateMessages(out vlen, message, receiveSize);
}
return WinSockHelper.ConvertError((WsaError)socketError);
}
catch (SocketException exception)
{
return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
}
}
public LinuxError SendMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags)
{
vlen = 0;
if (message.Messages.Length == 0)
{
return LinuxError.SUCCESS;
}
if (!CanSupportMMsgHdr(message))
{
Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported BsdMMsgHdr");
return LinuxError.EOPNOTSUPP;
}
if (message.Messages.Length == 0)
{
return LinuxError.SUCCESS;
}
try
{
int sendSize = Socket.Send(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError);
if (sendSize > 0)
{
UpdateMessages(out vlen, message, sendSize);
}
return WinSockHelper.ConvertError((WsaError)socketError);
}
catch (SocketException exception)
{
return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
}
}
}
}

View file

@ -0,0 +1,56 @@
using System;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{
class BsdMMsgHdr
{
public BsdMsgHdr[] Messages { get; }
private BsdMMsgHdr(BsdMsgHdr[] messages)
{
Messages = messages;
}
public static LinuxError Serialize(Span<byte> rawData, BsdMMsgHdr message)
{
rawData[0] = 0x8;
rawData = rawData[1..];
for (int index = 0; index < message.Messages.Length; index++)
{
LinuxError res = BsdMsgHdr.Serialize(ref rawData, message.Messages[index]);
if (res != LinuxError.SUCCESS)
{
return res;
}
}
return LinuxError.SUCCESS;
}
public static LinuxError Deserialize(out BsdMMsgHdr message, ReadOnlySpan<byte> rawData, int vlen)
{
message = null;
BsdMsgHdr[] messages = new BsdMsgHdr[vlen];
// Skip "header" byte (Nintendo also ignore it)
rawData = rawData[1..];
for (int index = 0; index < messages.Length; index++)
{
LinuxError res = BsdMsgHdr.Deserialize(out messages[index], ref rawData);
if (res != LinuxError.SUCCESS)
{
return res;
}
}
message = new BsdMMsgHdr(messages);
return LinuxError.SUCCESS;
}
}
}

View file

@ -0,0 +1,212 @@
using System;
using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{
class BsdMsgHdr
{
public byte[] Name { get; }
public byte[][] Iov { get; }
public byte[] Control { get; }
public BsdSocketFlags Flags { get; }
public uint Length;
private BsdMsgHdr(byte[] name, byte[][] iov, byte[] control, BsdSocketFlags flags, uint length)
{
Name = name;
Iov = iov;
Control = control;
Flags = flags;
Length = length;
}
public static LinuxError Serialize(ref Span<byte> rawData, BsdMsgHdr message)
{
int msgNameLength = message.Name == null ? 0 : message.Name.Length;
int iovCount = message.Iov == null ? 0 : message.Iov.Length;
int controlLength = message.Control == null ? 0 : message.Control.Length;
BsdSocketFlags flags = message.Flags;
if (!MemoryMarshal.TryWrite(rawData, ref msgNameLength))
{
return LinuxError.EFAULT;
}
rawData = rawData[sizeof(uint)..];
if (msgNameLength > 0)
{
if (rawData.Length < msgNameLength)
{
return LinuxError.EFAULT;
}
message.Name.CopyTo(rawData);
rawData = rawData[msgNameLength..];
}
if (!MemoryMarshal.TryWrite(rawData, ref iovCount))
{
return LinuxError.EFAULT;
}
rawData = rawData[sizeof(uint)..];
if (iovCount > 0)
{
for (int index = 0; index < iovCount; index++)
{
ulong iovLength = (ulong)message.Iov[index].Length;
if (!MemoryMarshal.TryWrite(rawData, ref iovLength))
{
return LinuxError.EFAULT;
}
rawData = rawData[sizeof(ulong)..];
if (iovLength > 0)
{
if ((ulong)rawData.Length < iovLength)
{
return LinuxError.EFAULT;
}
message.Iov[index].CopyTo(rawData);
rawData = rawData[(int)iovLength..];
}
}
}
if (!MemoryMarshal.TryWrite(rawData, ref controlLength))
{
return LinuxError.EFAULT;
}
rawData = rawData[sizeof(uint)..];
if (controlLength > 0)
{
if (rawData.Length < controlLength)
{
return LinuxError.EFAULT;
}
message.Control.CopyTo(rawData);
rawData = rawData[controlLength..];
}
if (!MemoryMarshal.TryWrite(rawData, ref flags))
{
return LinuxError.EFAULT;
}
rawData = rawData[sizeof(BsdSocketFlags)..];
if (!MemoryMarshal.TryWrite(rawData, ref message.Length))
{
return LinuxError.EFAULT;
}
rawData = rawData[sizeof(uint)..];
return LinuxError.SUCCESS;
}
public static LinuxError Deserialize(out BsdMsgHdr message, ref ReadOnlySpan<byte> rawData)
{
byte[] name = null;
byte[][] iov = null;
byte[] control = null;
message = null;
if (!MemoryMarshal.TryRead(rawData, out uint msgNameLength))
{
return LinuxError.EFAULT;
}
rawData = rawData[sizeof(uint)..];
if (msgNameLength > 0)
{
if (rawData.Length < msgNameLength)
{
return LinuxError.EFAULT;
}
name = rawData[..(int)msgNameLength].ToArray();
rawData = rawData[(int)msgNameLength..];
}
if (!MemoryMarshal.TryRead(rawData, out uint iovCount))
{
return LinuxError.EFAULT;
}
rawData = rawData[sizeof(uint)..];
if (iovCount > 0)
{
iov = new byte[iovCount][];
for (int index = 0; index < iov.Length; index++)
{
if (!MemoryMarshal.TryRead(rawData, out ulong iovLength))
{
return LinuxError.EFAULT;
}
rawData = rawData[sizeof(ulong)..];
if (iovLength > 0)
{
if ((ulong)rawData.Length < iovLength)
{
return LinuxError.EFAULT;
}
iov[index] = rawData[..(int)iovLength].ToArray();
rawData = rawData[(int)iovLength..];
}
}
}
if (!MemoryMarshal.TryRead(rawData, out uint controlLength))
{
return LinuxError.EFAULT;
}
rawData = rawData[sizeof(uint)..];
if (controlLength > 0)
{
if (rawData.Length < controlLength)
{
return LinuxError.EFAULT;
}
control = rawData[..(int)controlLength].ToArray();
rawData = rawData[(int)controlLength..];
}
if (!MemoryMarshal.TryRead(rawData, out BsdSocketFlags flags))
{
return LinuxError.EFAULT;
}
rawData = rawData[sizeof(BsdSocketFlags)..];
if (!MemoryMarshal.TryRead(rawData, out uint length))
{
return LinuxError.EFAULT;
}
rawData = rawData[sizeof(uint)..];
message = new BsdMsgHdr(name, iov, control, flags, length);
return LinuxError.SUCCESS;
}
}
}

View file

@ -0,0 +1,8 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{
public struct TimeVal
{
public ulong TvSec;
public ulong TvUsec;
}
}