audio_core\hle\adts_reader.cpp: Use BitField to parse ADTS header (#6719)

This commit is contained in:
SachinVin 2023-07-29 00:45:58 +05:30 committed by GitHub
parent 539a1a0b6e
commit 51996c54f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 161 additions and 61 deletions

View file

@ -5,20 +5,24 @@
#include "common/common_types.h" #include "common/common_types.h"
namespace AudioCore {
struct ADTSData { struct ADTSData {
u8 header_length; u8 header_length = 0;
bool MPEG2; bool mpeg2 = false;
u8 profile; u8 profile = 0;
u8 channels; u8 channels = 0;
u8 channel_idx; u8 channel_idx = 0;
u8 framecount; u8 framecount = 0;
u8 samplerate_idx; u8 samplerate_idx = 0;
u32 length; u32 length = 0;
u32 samplerate; u32 samplerate = 0;
}; };
ADTSData ParseADTS(const char* buffer); ADTSData ParseADTS(const u8* buffer);
// last two bytes of MF AAC decoder user data // last two bytes of MF AAC decoder user data
// see https://docs.microsoft.com/en-us/windows/desktop/medfound/aac-decoder#example-media-types // see https://docs.microsoft.com/en-us/windows/desktop/medfound/aac-decoder#example-media-types
u16 MFGetAACTag(const ADTSData& input); u16 MFGetAACTag(const ADTSData& input);
} // namespace AudioCore

View file

@ -3,44 +3,59 @@
// Refer to the license.txt file included. // Refer to the license.txt file included.
#include <array> #include <array>
#include "adts.h" #include "adts.h"
#include "common/bit_field.h"
namespace AudioCore {
constexpr std::array<u32, 16> freq_table = {96000, 88200, 64000, 48000, 44100, 32000, 24000, 22050, constexpr std::array<u32, 16> freq_table = {96000, 88200, 64000, 48000, 44100, 32000, 24000, 22050,
16000, 12000, 11025, 8000, 7350, 0, 0, 0}; 16000, 12000, 11025, 8000, 7350, 0, 0, 0};
constexpr std::array<u8, 8> channel_table = {0, 1, 2, 3, 4, 5, 6, 8}; constexpr std::array<u8, 8> channel_table = {0, 1, 2, 3, 4, 5, 6, 8};
ADTSData ParseADTS(const char* buffer) { struct ADTSHeader {
u32 tmp = 0; union {
ADTSData out; std::array<u8, 7> raw{};
BitFieldBE<52, 12, u64> sync_word;
BitFieldBE<51, 1, u64> mpeg2;
BitFieldBE<49, 2, u64> layer;
BitFieldBE<48, 1, u64> protection_absent;
BitFieldBE<46, 2, u64> profile;
BitFieldBE<42, 4, u64> samplerate_idx;
BitFieldBE<41, 1, u64> private_bit;
BitFieldBE<38, 3, u64> channel_idx;
BitFieldBE<37, 1, u64> originality;
BitFieldBE<36, 1, u64> home;
BitFieldBE<35, 1, u64> copyright_id;
BitFieldBE<34, 1, u64> copyright_id_start;
BitFieldBE<21, 13, u64> frame_length;
BitFieldBE<10, 11, u64> buffer_fullness;
BitFieldBE<8, 2, u64> frame_count;
};
};
ADTSData ParseADTS(const u8* buffer) {
ADTSHeader header;
memcpy(header.raw.data(), buffer, sizeof(header.raw));
// sync word 0xfff // sync word 0xfff
tmp = (buffer[0] << 8) | (buffer[1] & 0xf0); if (header.sync_word != 0xfff) {
if ((tmp & 0xffff) != 0xfff0) { return {};
out.length = 0;
return out;
} }
ADTSData out{};
// bit 16 = no CRC // bit 16 = no CRC
out.header_length = (buffer[1] & 0x1) ? 7 : 9; out.header_length = header.protection_absent ? 7 : 9;
out.MPEG2 = (buffer[1] >> 3) & 0x1; out.mpeg2 = static_cast<bool>(header.mpeg2);
// bit 17 to 18 // bit 17 to 18
out.profile = (buffer[2] >> 6) + 1; out.profile = static_cast<u8>(header.profile) + 1;
// bit 19 to 22 // bit 19 to 22
tmp = (buffer[2] >> 2) & 0xf; out.samplerate_idx = static_cast<u8>(header.samplerate_idx);
out.samplerate_idx = tmp; out.samplerate = header.samplerate_idx > 15 ? 0 : freq_table[header.samplerate_idx];
out.samplerate = (tmp > 15) ? 0 : freq_table[tmp];
// bit 24 to 26 // bit 24 to 26
tmp = ((buffer[2] & 0x1) << 2) | ((buffer[3] >> 6) & 0x3); out.channel_idx = static_cast<u8>(header.channel_idx);
out.channel_idx = tmp; out.channels = (header.channel_idx > 7) ? 0 : channel_table[header.channel_idx];
out.channels = (tmp > 7) ? 0 : channel_table[tmp];
// bit 55 to 56 // bit 55 to 56
out.framecount = (buffer[6] & 0x3) + 1; out.framecount = static_cast<u8>(header.frame_count + 1);
// bit 31 to 43 // bit 31 to 43
tmp = (buffer[3] & 0x3) << 11; out.length = static_cast<u32>(header.frame_length);
tmp |= (buffer[4] << 3) & 0x7f8;
tmp |= (buffer[5] >> 5) & 0x7;
out.length = tmp;
return out; return out;
} }
@ -61,3 +76,4 @@ u16 MFGetAACTag(const ADTSData& input) {
return tag; return tag;
} }
} // namespace AudioCore

View file

@ -24,7 +24,7 @@ private:
std::optional<BinaryMessage> Decode(const BinaryMessage& request); std::optional<BinaryMessage> Decode(const BinaryMessage& request);
void Clear(); void Clear();
bool InitializeDecoder(ADTSData& adts_header); bool InitializeDecoder(AudioCore::ADTSData& adts_header);
static OSStatus DataFunc(AudioConverterRef in_audio_converter, u32* io_number_data_packets, static OSStatus DataFunc(AudioConverterRef in_audio_converter, u32* io_number_data_packets,
AudioBufferList* io_data, AudioBufferList* io_data,
@ -33,7 +33,7 @@ private:
Memory::MemorySystem& memory; Memory::MemorySystem& memory;
ADTSData adts_config; AudioCore::ADTSData adts_config;
AudioStreamBasicDescription output_format = {}; AudioStreamBasicDescription output_format = {};
AudioConverterRef converter = nullptr; AudioConverterRef converter = nullptr;
@ -101,7 +101,7 @@ std::optional<BinaryMessage> AudioToolboxDecoder::Impl::ProcessRequest(
} }
} }
bool AudioToolboxDecoder::Impl::InitializeDecoder(ADTSData& adts_header) { bool AudioToolboxDecoder::Impl::InitializeDecoder(AudioCore::ADTSData& adts_header) {
if (converter) { if (converter) {
if (adts_config.channels == adts_header.channels && if (adts_config.channels == adts_header.channels &&
adts_config.samplerate == adts_header.samplerate) { adts_config.samplerate == adts_header.samplerate) {
@ -183,8 +183,9 @@ std::optional<BinaryMessage> AudioToolboxDecoder::Impl::Decode(const BinaryMessa
return {}; return {};
} }
auto data = memory.GetFCRAMPointer(request.decode_aac_request.src_addr - Memory::FCRAM_PADDR); const auto data =
auto adts_header = ParseADTS(reinterpret_cast<const char*>(data)); memory.GetFCRAMPointer(request.decode_aac_request.src_addr - Memory::FCRAM_PADDR);
auto adts_header = AudioCore::ParseADTS(data);
curr_data = data + adts_header.header_length; curr_data = data + adts_header.header_length;
curr_data_len = request.decode_aac_request.size - adts_header.header_length; curr_data_len = request.decode_aac_request.size - adts_header.header_length;

View file

@ -27,7 +27,7 @@ public:
~Impl(); ~Impl();
std::optional<BinaryMessage> ProcessRequest(const BinaryMessage& request); std::optional<BinaryMessage> ProcessRequest(const BinaryMessage& request);
bool SetMediaType(const ADTSData& adts_data); bool SetMediaType(const AudioCore::ADTSData& adts_data);
private: private:
std::optional<BinaryMessage> Initalize(const BinaryMessage& request); std::optional<BinaryMessage> Initalize(const BinaryMessage& request);
@ -36,8 +36,8 @@ private:
Memory::MemorySystem& memory; Memory::MemorySystem& memory;
std::unique_ptr<AMediaCodec, AMediaCodecRelease> decoder; std::unique_ptr<AMediaCodec, AMediaCodecRelease> decoder;
// default: 2 channles, 48000 samplerate // default: 2 channles, 48000 samplerate
ADTSData mADTSData{ AudioCore::ADTSData mADTSData{
/*header_length*/ 7, /*MPEG2*/ false, /*profile*/ 2, /*header_length*/ 7, /*mpeg2*/ false, /*profile*/ 2,
/*channels*/ 2, /*channel_idx*/ 2, /*framecount*/ 0, /*channels*/ 2, /*channel_idx*/ 2, /*framecount*/ 0,
/*samplerate_idx*/ 3, /*length*/ 0, /*samplerate*/ 48000}; /*samplerate_idx*/ 3, /*length*/ 0, /*samplerate*/ 48000};
}; };
@ -54,7 +54,7 @@ std::optional<BinaryMessage> MediaNDKDecoder::Impl::Initalize(const BinaryMessag
return response; return response;
} }
bool MediaNDKDecoder::Impl::SetMediaType(const ADTSData& adts_data) { bool MediaNDKDecoder::Impl::SetMediaType(const AudioCore::ADTSData& adts_data) {
const char* mime = "audio/mp4a-latm"; const char* mime = "audio/mp4a-latm";
if (decoder && mADTSData.profile == adts_data.profile && if (decoder && mADTSData.profile == adts_data.profile &&
mADTSData.channel_idx == adts_data.channel_idx && mADTSData.channel_idx == adts_data.channel_idx &&
@ -141,8 +141,9 @@ std::optional<BinaryMessage> MediaNDKDecoder::Impl::Decode(const BinaryMessage&
return response; return response;
} }
u8* data = memory.GetFCRAMPointer(request.decode_aac_request.src_addr - Memory::FCRAM_PADDR); const u8* data =
ADTSData adts_data = ParseADTS(reinterpret_cast<const char*>(data)); memory.GetFCRAMPointer(request.decode_aac_request.src_addr - Memory::FCRAM_PADDR);
ADTSData adts_data = AudioCore::ParseADTS(data);
SetMediaType(adts_data); SetMediaType(adts_data);
response.decode_aac_response.sample_rate = GetSampleRateEnum(adts_data.samplerate); response.decode_aac_response.sample_rate = GetSampleRateEnum(adts_data.samplerate);
response.decode_aac_response.num_channels = adts_data.channels; response.decode_aac_response.num_channels = adts_data.channels;

View file

@ -23,7 +23,8 @@ private:
std::optional<BinaryMessage> Decode(const BinaryMessage& request); std::optional<BinaryMessage> Decode(const BinaryMessage& request);
MFOutputState DecodingLoop(ADTSData adts_header, std::array<std::vector<u8>, 2>& out_streams); MFOutputState DecodingLoop(AudioCore::ADTSData adts_header,
std::array<std::vector<u8>, 2>& out_streams);
bool transform_initialized = false; bool transform_initialized = false;
bool format_selected = false; bool format_selected = false;
@ -139,7 +140,7 @@ std::optional<BinaryMessage> WMFDecoder::Impl::Initalize(const BinaryMessage& re
return response; return response;
} }
MFOutputState WMFDecoder::Impl::DecodingLoop(ADTSData adts_header, MFOutputState WMFDecoder::Impl::DecodingLoop(AudioCore::ADTSData adts_header,
std::array<std::vector<u8>, 2>& out_streams) { std::array<std::vector<u8>, 2>& out_streams) {
std::optional<std::vector<f32>> output_buffer; std::optional<std::vector<f32>> output_buffer;
@ -210,14 +211,14 @@ std::optional<BinaryMessage> WMFDecoder::Impl::Decode(const BinaryMessage& reque
request.decode_aac_request.src_addr); request.decode_aac_request.src_addr);
return std::nullopt; return std::nullopt;
} }
u8* data = memory.GetFCRAMPointer(request.decode_aac_request.src_addr - Memory::FCRAM_PADDR); const u8* data =
memory.GetFCRAMPointer(request.decode_aac_request.src_addr - Memory::FCRAM_PADDR);
std::array<std::vector<u8>, 2> out_streams; std::array<std::vector<u8>, 2> out_streams;
unique_mfptr<IMFSample> sample; unique_mfptr<IMFSample> sample;
MFInputState input_status = MFInputState::OK; MFInputState input_status = MFInputState::OK;
MFOutputState output_status = MFOutputState::OK; MFOutputState output_status = MFOutputState::OK;
std::optional<ADTSMeta> adts_meta = std::optional<ADTSMeta> adts_meta = DetectMediaType(data, request.decode_aac_request.size);
DetectMediaType((char*)data, request.decode_aac_request.size);
if (!adts_meta) { if (!adts_meta) {
LOG_ERROR(Audio_DSP, "Unable to deduce decoding parameters from ADTS stream"); LOG_ERROR(Audio_DSP, "Unable to deduce decoding parameters from ADTS stream");

View file

@ -110,8 +110,9 @@ unique_mfptr<IMFSample> CreateSample(const void* data, DWORD len, DWORD alignmen
return sample; return sample;
} }
bool SelectInputMediaType(IMFTransform* transform, int in_stream_id, const ADTSData& adts, bool SelectInputMediaType(IMFTransform* transform, int in_stream_id,
const UINT8* user_data, UINT32 user_data_len, GUID audio_format) { const AudioCore::ADTSData& adts, const UINT8* user_data,
UINT32 user_data_len, GUID audio_format) {
HRESULT hr = S_OK; HRESULT hr = S_OK;
unique_mfptr<IMFMediaType> t; unique_mfptr<IMFMediaType> t;
@ -190,12 +191,12 @@ bool SelectOutputMediaType(IMFTransform* transform, int out_stream_id, GUID audi
return false; return false;
} }
std::optional<ADTSMeta> DetectMediaType(char* buffer, std::size_t len) { std::optional<ADTSMeta> DetectMediaType(const u8* buffer, std::size_t len) {
if (len < 7) { if (len < 7) {
return std::nullopt; return std::nullopt;
} }
ADTSData tmp; AudioCore::ADTSData tmp;
ADTSMeta result; ADTSMeta result;
// see https://docs.microsoft.com/en-us/windows/desktop/api/mmreg/ns-mmreg-heaacwaveinfo_tag // see https://docs.microsoft.com/en-us/windows/desktop/api/mmreg/ns-mmreg-heaacwaveinfo_tag
// for the meaning of the byte array below // for the meaning of the byte array below
@ -207,7 +208,7 @@ std::optional<ADTSMeta> DetectMediaType(char* buffer, std::size_t len) {
UINT8 aac_tmp[] = {0x01, 0x00, 0xfe, 00, 00, 00, 00, 00, 00, 00, 00, 00, 0x00, 0x00}; UINT8 aac_tmp[] = {0x01, 0x00, 0xfe, 00, 00, 00, 00, 00, 00, 00, 00, 00, 0x00, 0x00};
uint16_t tag = 0; uint16_t tag = 0;
tmp = ParseADTS(buffer); tmp = AudioCore::ParseADTS(buffer);
if (tmp.length == 0) { if (tmp.length == 0) {
return std::nullopt; return std::nullopt;
} }
@ -215,7 +216,7 @@ std::optional<ADTSMeta> DetectMediaType(char* buffer, std::size_t len) {
tag = MFGetAACTag(tmp); tag = MFGetAACTag(tmp);
aac_tmp[12] |= (tag & 0xff00) >> 8; aac_tmp[12] |= (tag & 0xff00) >> 8;
aac_tmp[13] |= (tag & 0x00ff); aac_tmp[13] |= (tag & 0x00ff);
std::memcpy(&(result.ADTSHeader), &tmp, sizeof(ADTSData)); std::memcpy(&(result.ADTSHeader), &tmp, sizeof(AudioCore::ADTSData));
std::memcpy(&(result.AACTag), aac_tmp, 14); std::memcpy(&(result.AACTag), aac_tmp, 14);
return result; return result;
} }

View file

@ -99,7 +99,7 @@ void ReportError(std::string msg, HRESULT hr);
// data type for transferring ADTS metadata between functions // data type for transferring ADTS metadata between functions
struct ADTSMeta { struct ADTSMeta {
ADTSData ADTSHeader; AudioCore::ADTSData ADTSHeader;
u8 AACTag[14]; u8 AACTag[14];
}; };
@ -110,10 +110,10 @@ bool InitMFDLL();
unique_mfptr<IMFTransform> MFDecoderInit(GUID audio_format = MFAudioFormat_AAC); unique_mfptr<IMFTransform> MFDecoderInit(GUID audio_format = MFAudioFormat_AAC);
unique_mfptr<IMFSample> CreateSample(const void* data, DWORD len, DWORD alignment = 1, unique_mfptr<IMFSample> CreateSample(const void* data, DWORD len, DWORD alignment = 1,
LONGLONG duration = 0); LONGLONG duration = 0);
bool SelectInputMediaType(IMFTransform* transform, int in_stream_id, const ADTSData& adts, bool SelectInputMediaType(IMFTransform* transform, int in_stream_id,
const UINT8* user_data, UINT32 user_data_len, const AudioCore::ADTSData& adts, const UINT8* user_data,
GUID audio_format = MFAudioFormat_AAC); UINT32 user_data_len, GUID audio_format = MFAudioFormat_AAC);
std::optional<ADTSMeta> DetectMediaType(char* buffer, std::size_t len); std::optional<ADTSMeta> DetectMediaType(const u8* buffer, std::size_t len);
bool SelectOutputMediaType(IMFTransform* transform, int out_stream_id, bool SelectOutputMediaType(IMFTransform* transform, int out_stream_id,
GUID audio_format = MFAudioFormat_PCM); GUID audio_format = MFAudioFormat_PCM);
void MFFlush(IMFTransform* transform); void MFFlush(IMFTransform* transform);

View file

@ -12,6 +12,7 @@ add_executable(tests
core/memory/vm_manager.cpp core/memory/vm_manager.cpp
precompiled_headers.h precompiled_headers.h
audio_core/hle/hle.cpp audio_core/hle/hle.cpp
audio_core/hle/adts_reader.cpp
audio_core/lle/lle.cpp audio_core/lle/lle.cpp
audio_core/audio_fixures.h audio_core/audio_fixures.h
audio_core/decoder_tests.cpp audio_core/decoder_tests.cpp

View file

@ -0,0 +1,75 @@
// Copyright 2023 Citra Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#include <catch2/catch_test_macros.hpp>
#include "audio_core/hle/adts.h"
namespace {
constexpr std::array<u32, 16> freq_table = {96000, 88200, 64000, 48000, 44100, 32000, 24000, 22050,
16000, 12000, 11025, 8000, 7350, 0, 0, 0};
constexpr std::array<u8, 8> channel_table = {0, 1, 2, 3, 4, 5, 6, 8};
AudioCore::ADTSData ParseADTS_Old(const unsigned char* buffer) {
u32 tmp = 0;
AudioCore::ADTSData out{};
// sync word 0xfff
tmp = (buffer[0] << 8) | (buffer[1] & 0xf0);
if ((tmp & 0xffff) != 0xfff0) {
out.length = 0;
return out;
}
// bit 16 = no CRC
out.header_length = (buffer[1] & 0x1) ? 7 : 9;
out.mpeg2 = (buffer[1] >> 3) & 0x1;
// bit 17 to 18
out.profile = (buffer[2] >> 6) + 1;
// bit 19 to 22
tmp = (buffer[2] >> 2) & 0xf;
out.samplerate_idx = tmp;
out.samplerate = (tmp > 15) ? 0 : freq_table[tmp];
// bit 24 to 26
tmp = ((buffer[2] & 0x1) << 2) | ((buffer[3] >> 6) & 0x3);
out.channel_idx = tmp;
out.channels = (tmp > 7) ? 0 : channel_table[tmp];
// bit 55 to 56
out.framecount = (buffer[6] & 0x3) + 1;
// bit 31 to 43
tmp = (buffer[3] & 0x3) << 11;
tmp |= (buffer[4] << 3) & 0x7f8;
tmp |= (buffer[5] >> 5) & 0x7;
out.length = tmp;
return out;
}
} // namespace
TEST_CASE("ParseADTS fuzz", "[audio_core][hle]") {
for (u32 i = 0; i < 0x10000; i++) {
std::array<u8, 7> adts_header;
std::string adts_header_string = "ADTS Header: ";
for (auto& it : adts_header) {
it = static_cast<u8>(rand());
adts_header_string.append(fmt::format("{:2X} ", it));
}
INFO(adts_header_string);
AudioCore::ADTSData out_old_impl =
ParseADTS_Old(reinterpret_cast<const unsigned char*>(adts_header.data()));
AudioCore::ADTSData out = AudioCore::ParseADTS(adts_header.data());
REQUIRE(out_old_impl.length == out.length);
REQUIRE(out_old_impl.channels == out.channels);
REQUIRE(out_old_impl.channel_idx == out.channel_idx);
REQUIRE(out_old_impl.framecount == out.framecount);
REQUIRE(out_old_impl.header_length == out.header_length);
REQUIRE(out_old_impl.mpeg2 == out.mpeg2);
REQUIRE(out_old_impl.profile == out.profile);
REQUIRE(out_old_impl.samplerate == out.samplerate);
REQUIRE(out_old_impl.samplerate_idx == out.samplerate_idx);
}
}