diff --git a/src/video_core/rasterizer_cache.h b/src/video_core/rasterizer_cache.h index 599d39f62..de1eab86b 100644 --- a/src/video_core/rasterizer_cache.h +++ b/src/video_core/rasterizer_cache.h @@ -4,7 +4,9 @@ #pragma once -#include +#include + +#include #include "common/common_types.h" #include "core/core.h" @@ -17,62 +19,72 @@ template class RasterizerCache : NonCopyable { public: /// Mark the specified region as being invalidated - void InvalidateRegion(VAddr region_addr, size_t region_size) { - for (auto iter = cached_objects.cbegin(); iter != cached_objects.cend();) { - const auto& object{iter->second}; + void InvalidateRegion(VAddr addr, u64 size) { + if (size == 0) + return; - ++iter; + const ObjectInterval interval{addr, addr + size}; + for (auto& pair : boost::make_iterator_range(object_cache.equal_range(interval))) { + for (auto& cached_object : pair.second) { + if (!cached_object) + continue; - if (object->GetAddr() <= (region_addr + region_size) && - region_addr <= (object->GetAddr() + object->GetSizeInBytes())) { - // Regions overlap, so invalidate - Unregister(object); + remove_objects.emplace(cached_object); } } + + for (auto& remove_object : remove_objects) { + Unregister(remove_object); + } + + remove_objects.clear(); + } + + /// Invalidates everything in the cache + void InvalidateAll() { + while (object_cache.begin() != object_cache.end()) { + Unregister(*object_cache.begin()->second.begin()); + } } protected: /// Tries to get an object from the cache with the specified address T TryGet(VAddr addr) const { - const auto& search{cached_objects.find(addr)}; - if (search != cached_objects.end()) { - return search->second; + const ObjectInterval interval{addr}; + for (auto& pair : boost::make_iterator_range(object_cache.equal_range(interval))) { + for (auto& cached_object : pair.second) { + if (cached_object->GetAddr() == addr) { + return cached_object; + } + } } - return nullptr; } - /// Gets a reference to the cache - const std::unordered_map& GetCache() const { - return cached_objects; - } - /// Register an object into the cache void Register(const T& object) { - const auto& search{cached_objects.find(object->GetAddr())}; - if (search != cached_objects.end()) { - // Registered already - return; - } - + object_cache.add({GetInterval(object), ObjectSet{object}}); auto& rasterizer = Core::System::GetInstance().Renderer().Rasterizer(); rasterizer.UpdatePagesCachedCount(object->GetAddr(), object->GetSizeInBytes(), 1); - cached_objects[object->GetAddr()] = std::move(object); } /// Unregisters an object from the cache void Unregister(const T& object) { - const auto& search{cached_objects.find(object->GetAddr())}; - if (search == cached_objects.end()) { - // Unregistered already - return; - } - auto& rasterizer = Core::System::GetInstance().Renderer().Rasterizer(); rasterizer.UpdatePagesCachedCount(object->GetAddr(), object->GetSizeInBytes(), -1); - cached_objects.erase(search); + object_cache.subtract({GetInterval(object), ObjectSet{object}}); } private: - std::unordered_map cached_objects; + using ObjectSet = std::set; + using ObjectCache = boost::icl::interval_map; + using ObjectInterval = typename ObjectCache::interval_type; + + static auto GetInterval(const T& object) { + return ObjectInterval::right_open(object->GetAddr(), + object->GetAddr() + object->GetSizeInBytes()); + } + + ObjectCache object_cache; + ObjectSet remove_objects; };