diff --git a/include/modules/hyprland/backend.hpp b/include/modules/hyprland/backend.hpp index 2e0ef657..a6ebd191 100644 --- a/include/modules/hyprland/backend.hpp +++ b/include/modules/hyprland/backend.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -43,10 +44,11 @@ class IPC { std::thread ipcThread_; std::mutex callbackMutex_; + std::mutex socketMutex_; util::JsonParser parser_; std::list> callbacks_; - int socketfd_; // the hyprland socket file descriptor - pid_t socketOwnerPid_; - bool running_ = true; // the ipcThread will stop running when this is false + int socketfd_ = -1; // the hyprland socket file descriptor + pid_t socketOwnerPid_ = -1; + std::atomic running_ = true; // the ipcThread will stop running when this is false }; }; // namespace waybar::modules::hyprland diff --git a/include/modules/sway/ipc/client.hpp b/include/modules/sway/ipc/client.hpp index 281df7ab..eb0f32f9 100644 --- a/include/modules/sway/ipc/client.hpp +++ b/include/modules/sway/ipc/client.hpp @@ -14,6 +14,7 @@ #include "ipc.hpp" #include "util/SafeSignal.hpp" #include "util/sleeper_thread.hpp" +#include "util/scoped_fd.hpp" namespace waybar::modules::sway { @@ -45,8 +46,8 @@ class Ipc { struct ipc_response send(int fd, uint32_t type, const std::string& payload = ""); struct ipc_response recv(int fd); - int fd_; - int fd_event_; + util::ScopedFd fd_; + util::ScopedFd fd_event_; std::mutex mutex_; util::SleeperThread thread_; }; diff --git a/include/modules/wayfire/backend.hpp b/include/modules/wayfire/backend.hpp index d3173269..e1f259cc 100644 --- a/include/modules/wayfire/backend.hpp +++ b/include/modules/wayfire/backend.hpp @@ -12,6 +12,8 @@ #include #include +#include "util/scoped_fd.hpp" + namespace waybar::modules::wayfire { using EventHandler = std::function; @@ -71,23 +73,7 @@ struct State { auto update_view(const Json::Value& view) -> void; }; -struct Sock { - int fd; - - Sock(int fd) : fd{fd} {} - ~Sock() { close(fd); } - Sock(const Sock&) = delete; - auto operator=(const Sock&) = delete; - Sock(Sock&& rhs) noexcept { - fd = rhs.fd; - rhs.fd = -1; - } - auto& operator=(Sock&& rhs) noexcept { - fd = rhs.fd; - rhs.fd = -1; - return *this; - } -}; +using Sock = util::ScopedFd; class IPC : public std::enable_shared_from_this { static std::weak_ptr instance; diff --git a/include/util/scoped_fd.hpp b/include/util/scoped_fd.hpp new file mode 100644 index 00000000..e970109e --- /dev/null +++ b/include/util/scoped_fd.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include + +namespace waybar::util { + +class ScopedFd { + public: + explicit ScopedFd(int fd = -1) : fd_(fd) {} + ~ScopedFd() { + if (fd_ != -1) { + close(fd_); + } + } + + // ScopedFd is non-copyable + ScopedFd(const ScopedFd&) = delete; + ScopedFd& operator=(const ScopedFd&) = delete; + + // ScopedFd is moveable + ScopedFd(ScopedFd&& other) noexcept : fd_(other.fd_) { other.fd_ = -1; } + ScopedFd& operator=(ScopedFd&& other) noexcept { + if (this != &other) { + if (fd_ != -1) { + close(fd_); + } + fd_ = other.fd_; + other.fd_ = -1; + } + return *this; + } + + int get() const { return fd_; } + + operator int() const { return fd_; } + + void reset(int fd = -1) { + if (fd_ != -1) { + close(fd_); + } + fd_ = fd; + } + + int release() { + int fd = fd_; + fd_ = -1; + return fd; + } + + private: + int fd_; +}; + +} // namespace waybar::util diff --git a/src/modules/hyprland/backend.cpp b/src/modules/hyprland/backend.cpp index 7060d304..0f02b919 100644 --- a/src/modules/hyprland/backend.cpp +++ b/src/modules/hyprland/backend.cpp @@ -9,9 +9,14 @@ #include #include +#include +#include +#include #include #include +#include "util/scoped_fd.hpp" + namespace waybar::modules::hyprland { std::filesystem::path IPC::socketFolder_; @@ -45,8 +50,8 @@ std::filesystem::path IPC::getSocketFolder(const char* instanceSig) { IPC::IPC() { // will start IPC and relay events to parseIPC - ipcThread_ = std::thread([this]() { socketListener(); }); socketOwnerPid_ = getpid(); + ipcThread_ = std::thread([this]() { socketListener(); }); } IPC::~IPC() { @@ -54,19 +59,20 @@ IPC::~IPC() { // failed exec()) exits. if (getpid() != socketOwnerPid_) return; - running_ = false; + running_.store(false, std::memory_order_relaxed); spdlog::info("Hyprland IPC stopping..."); - if (socketfd_ != -1) { - spdlog::trace("Shutting down socket"); - if (shutdown(socketfd_, SHUT_RDWR) == -1) { - spdlog::error("Hyprland IPC: Couldn't shutdown socket"); - } - spdlog::trace("Closing socket"); - if (close(socketfd_) == -1) { - spdlog::error("Hyprland IPC: Couldn't close socket"); + { + std::lock_guard lock(socketMutex_); + if (socketfd_ != -1) { + spdlog::trace("Shutting down socket"); + if (shutdown(socketfd_, SHUT_RDWR) == -1 && errno != ENOTCONN) { + spdlog::error("Hyprland IPC: Couldn't shutdown socket"); + } } } - ipcThread_.join(); + if (ipcThread_.joinable()) { + ipcThread_.join(); + } } IPC& IPC::inst() { @@ -86,9 +92,9 @@ void IPC::socketListener() { spdlog::info("Hyprland IPC starting"); struct sockaddr_un addr; - socketfd_ = socket(AF_UNIX, SOCK_STREAM, 0); + const int socketfd = socket(AF_UNIX, SOCK_STREAM, 0); - if (socketfd_ == -1) { + if (socketfd == -1) { spdlog::error("Hyprland IPC: socketfd failed"); return; } @@ -102,38 +108,67 @@ void IPC::socketListener() { int l = sizeof(struct sockaddr_un); - if (connect(socketfd_, (struct sockaddr*)&addr, l) == -1) { - spdlog::error("Hyprland IPC: Unable to connect?"); + if (connect(socketfd, (struct sockaddr*)&addr, l) == -1) { + spdlog::error("Hyprland IPC: Unable to connect? {}", std::strerror(errno)); + close(socketfd); return; } - auto* file = fdopen(socketfd_, "r"); - if (file == nullptr) { - spdlog::error("Hyprland IPC: Couldn't open file descriptor"); - return; + + { + std::lock_guard lock(socketMutex_); + socketfd_ = socketfd; } - while (running_) { + + std::string pending; + while (running_.load(std::memory_order_relaxed)) { std::array buffer; // Hyprland socket2 events are max 1024 bytes + const ssize_t bytes_read = read(socketfd, buffer.data(), buffer.size()); - auto* receivedCharPtr = fgets(buffer.data(), buffer.size(), file); - - if (receivedCharPtr == nullptr) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - continue; + if (bytes_read == 0) { + if (running_.load(std::memory_order_relaxed)) { + spdlog::warn("Hyprland IPC: Socket closed by peer"); + } + break; } - std::string messageReceived(buffer.data()); - messageReceived = messageReceived.substr(0, messageReceived.find_first_of('\n')); - spdlog::debug("hyprland IPC received {}", messageReceived); - - try { - parseIPC(messageReceived); - } catch (std::exception& e) { - spdlog::warn("Failed to parse IPC message: {}, reason: {}", messageReceived, e.what()); - } catch (...) { - throw; + if (bytes_read < 0) { + if (errno == EINTR) { + continue; + } + if (!running_.load(std::memory_order_relaxed)) { + break; + } + spdlog::error("Hyprland IPC: read failed: {}", std::strerror(errno)); + break; } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + pending.append(buffer.data(), static_cast(bytes_read)); + for (auto newline_pos = pending.find('\n'); newline_pos != std::string::npos; + newline_pos = pending.find('\n')) { + std::string messageReceived = pending.substr(0, newline_pos); + pending.erase(0, newline_pos + 1); + if (messageReceived.empty()) { + continue; + } + spdlog::debug("hyprland IPC received {}", messageReceived); + + try { + parseIPC(messageReceived); + } catch (std::exception& e) { + spdlog::warn("Failed to parse IPC message: {}, reason: {}", messageReceived, e.what()); + } catch (...) { + throw; + } + } + } + { + std::lock_guard lock(socketMutex_); + if (socketfd_ != -1) { + if (close(socketfd_) == -1) { + spdlog::error("Hyprland IPC: Couldn't close socket"); + } + socketfd_ = -1; + } } spdlog::debug("Hyprland IPC stopped"); } @@ -178,7 +213,7 @@ void IPC::unregisterForIPC(EventHandler* ev_handler) { std::string IPC::getSocket1Reply(const std::string& rq) { // basically hyprctl - const auto serverSocket = socket(AF_UNIX, SOCK_STREAM, 0); + util::ScopedFd serverSocket(socket(AF_UNIX, SOCK_STREAM, 0)); if (serverSocket < 0) { throw std::runtime_error("Hyprland IPC: Couldn't open a socket (1)"); @@ -223,13 +258,11 @@ std::string IPC::getSocket1Reply(const std::string& rq) { if (sizeWritten < 0) { spdlog::error("Hyprland IPC: Couldn't read (5)"); - close(serverSocket); return ""; } response.append(buffer.data(), sizeWritten); } while (sizeWritten > 0); - close(serverSocket); return response; } diff --git a/src/modules/niri/backend.cpp b/src/modules/niri/backend.cpp index 68bb1724..c23aaa47 100644 --- a/src/modules/niri/backend.cpp +++ b/src/modules/niri/backend.cpp @@ -13,6 +13,7 @@ #include #include +#include "util/scoped_fd.hpp" #include "giomm/datainputstream.h" #include "giomm/dataoutputstream.h" #include "giomm/unixinputstream.h" @@ -30,7 +31,7 @@ int IPC::connectToSocket() { } struct sockaddr_un addr; - int socketfd = socket(AF_UNIX, SOCK_STREAM, 0); + util::ScopedFd socketfd(socket(AF_UNIX, SOCK_STREAM, 0)); if (socketfd == -1) { throw std::runtime_error("socketfd failed"); @@ -45,11 +46,10 @@ int IPC::connectToSocket() { int l = sizeof(struct sockaddr_un); if (connect(socketfd, (struct sockaddr*)&addr, l) == -1) { - close(socketfd); throw std::runtime_error("unable to connect"); } - return socketfd; + return socketfd.release(); } void IPC::startIPC() { @@ -235,7 +235,7 @@ void IPC::unregisterForIPC(EventHandler* ev_handler) { } Json::Value IPC::send(const Json::Value& request) { - int socketfd = connectToSocket(); + util::ScopedFd socketfd(connectToSocket()); auto unix_istream = Gio::UnixInputStream::create(socketfd, true); auto unix_ostream = Gio::UnixOutputStream::create(socketfd, false); diff --git a/src/modules/sway/ipc/client.cpp b/src/modules/sway/ipc/client.cpp index 4139a53b..3ebccccd 100644 --- a/src/modules/sway/ipc/client.cpp +++ b/src/modules/sway/ipc/client.cpp @@ -9,8 +9,8 @@ namespace waybar::modules::sway { Ipc::Ipc() { const std::string& socketPath = getSocketPath(); - fd_ = open(socketPath); - fd_event_ = open(socketPath); + fd_ = util::ScopedFd(open(socketPath)); + fd_event_ = util::ScopedFd(open(socketPath)); } Ipc::~Ipc() { @@ -21,15 +21,11 @@ Ipc::~Ipc() { if (write(fd_, "close-sway-ipc", 14) == -1) { spdlog::error("Failed to close sway IPC"); } - close(fd_); - fd_ = -1; } if (fd_event_ > 0) { if (write(fd_event_, "close-sway-ipc", 14) == -1) { spdlog::error("Failed to close sway IPC event handler"); } - close(fd_event_); - fd_event_ = -1; } } @@ -64,7 +60,7 @@ const std::string Ipc::getSocketPath() const { } int Ipc::open(const std::string& socketPath) const { - int32_t fd = socket(AF_UNIX, SOCK_STREAM, 0); + util::ScopedFd fd(socket(AF_UNIX, SOCK_STREAM, 0)); if (fd == -1) { throw std::runtime_error("Unable to open Unix socket"); } @@ -78,7 +74,7 @@ int Ipc::open(const std::string& socketPath) const { if (::connect(fd, reinterpret_cast(&addr), l) == -1) { throw std::runtime_error("Unable to connect to Sway"); } - return fd; + return fd.release(); } struct Ipc::ipc_response Ipc::recv(int fd) { diff --git a/src/modules/wayfire/backend.cpp b/src/modules/wayfire/backend.cpp index 545aaa89..42976d20 100644 --- a/src/modules/wayfire/backend.cpp +++ b/src/modules/wayfire/backend.cpp @@ -27,14 +27,14 @@ inline auto byteswap(uint32_t x) -> uint32_t { auto pack_and_write(Sock& sock, std::string&& buf) -> void { uint32_t len = buf.size(); if constexpr (std::endian::native != std::endian::little) len = byteswap(len); - (void)write(sock.fd, &len, 4); - (void)write(sock.fd, buf.data(), buf.size()); + (void)write(sock, &len, 4); + (void)write(sock, buf.data(), buf.size()); } auto read_exact(Sock& sock, size_t n) -> std::string { auto buf = std::string(n, 0); for (size_t i = 0; i < n;) { - auto r = read(sock.fd, &buf[i], n - i); + auto r = read(sock, &buf[i], n - i); if (r <= 0) { throw std::runtime_error("Wayfire IPC: read failed"); } @@ -111,7 +111,7 @@ auto IPC::connect() -> Sock { throw std::runtime_error{"Wayfire IPC: ipc not available"}; } - auto sock = socket(AF_UNIX, SOCK_STREAM, 0); + util::ScopedFd sock(socket(AF_UNIX, SOCK_STREAM, 0)); if (sock == -1) { throw std::runtime_error{"Wayfire IPC: socket() failed"}; } @@ -121,11 +121,10 @@ auto IPC::connect() -> Sock { addr.sun_path[sizeof(addr.sun_path) - 1] = 0; if (::connect(sock, (const sockaddr*)&addr, sizeof(addr)) == -1) { - close(sock); throw std::runtime_error{"Wayfire IPC: connect() failed"}; } - return {sock}; + return sock; } auto IPC::receive(Sock& sock) -> Json::Value { diff --git a/test/hyprland/backend.cpp b/test/hyprland/backend.cpp index b83b839c..cc7295ec 100644 --- a/test/hyprland/backend.cpp +++ b/test/hyprland/backend.cpp @@ -4,56 +4,114 @@ #include #endif -#include "fixtures/IPCTestFixture.hpp" +#include + +#include "modules/hyprland/backend.hpp" namespace fs = std::filesystem; namespace hyprland = waybar::modules::hyprland; -TEST_CASE_METHOD(IPCTestFixture, "XDGRuntimeDirExists", "[getSocketFolder]") { +namespace { +class IPCTestHelper : public hyprland::IPC { + public: + static void resetSocketFolder() { socketFolder_.clear(); } +}; + +std::size_t countOpenFds() { +#if defined(__linux__) + std::size_t count = 0; + for (const auto& _ : fs::directory_iterator("/proc/self/fd")) { + (void)_; + ++count; + } + return count; +#else + return 0; +#endif +} +} // namespace + +TEST_CASE("XDGRuntimeDirExists", "[getSocketFolder]") { // Test case: XDG_RUNTIME_DIR exists and contains "hypr" directory // Arrange - tempDir = fs::temp_directory_path() / "hypr_test/run/user/1000"; + constexpr auto instanceSig = "instance_sig"; + const fs::path tempDir = fs::temp_directory_path() / "hypr_test/run/user/1000"; + std::error_code ec; + fs::remove_all(tempDir, ec); fs::path expectedPath = tempDir / "hypr" / instanceSig; - fs::create_directories(tempDir / "hypr" / instanceSig); + fs::create_directories(expectedPath); setenv("XDG_RUNTIME_DIR", tempDir.c_str(), 1); + IPCTestHelper::resetSocketFolder(); // Act - fs::path actualPath = getSocketFolder(instanceSig); + fs::path actualPath = hyprland::IPC::getSocketFolder(instanceSig); // Assert expected result REQUIRE(actualPath == expectedPath); + fs::remove_all(tempDir, ec); } -TEST_CASE_METHOD(IPCTestFixture, "XDGRuntimeDirDoesNotExist", "[getSocketFolder]") { +TEST_CASE("XDGRuntimeDirDoesNotExist", "[getSocketFolder]") { // Test case: XDG_RUNTIME_DIR does not exist // Arrange + constexpr auto instanceSig = "instance_sig"; unsetenv("XDG_RUNTIME_DIR"); fs::path expectedPath = fs::path("/tmp") / "hypr" / instanceSig; + IPCTestHelper::resetSocketFolder(); // Act - fs::path actualPath = getSocketFolder(instanceSig); + fs::path actualPath = hyprland::IPC::getSocketFolder(instanceSig); // Assert expected result REQUIRE(actualPath == expectedPath); } -TEST_CASE_METHOD(IPCTestFixture, "XDGRuntimeDirExistsNoHyprDir", "[getSocketFolder]") { +TEST_CASE("XDGRuntimeDirExistsNoHyprDir", "[getSocketFolder]") { // Test case: XDG_RUNTIME_DIR exists but does not contain "hypr" directory // Arrange + constexpr auto instanceSig = "instance_sig"; fs::path tempDir = fs::temp_directory_path() / "hypr_test/run/user/1000"; + std::error_code ec; + fs::remove_all(tempDir, ec); fs::create_directories(tempDir); setenv("XDG_RUNTIME_DIR", tempDir.c_str(), 1); fs::path expectedPath = fs::path("/tmp") / "hypr" / instanceSig; + IPCTestHelper::resetSocketFolder(); // Act - fs::path actualPath = getSocketFolder(instanceSig); + fs::path actualPath = hyprland::IPC::getSocketFolder(instanceSig); // Assert expected result REQUIRE(actualPath == expectedPath); + fs::remove_all(tempDir, ec); } -TEST_CASE_METHOD(IPCTestFixture, "getSocket1Reply throws on no socket", "[getSocket1Reply]") { +TEST_CASE("getSocket1Reply throws on no socket", "[getSocket1Reply]") { + unsetenv("HYPRLAND_INSTANCE_SIGNATURE"); + IPCTestHelper::resetSocketFolder(); std::string request = "test_request"; - CHECK_THROWS(getSocket1Reply(request)); + CHECK_THROWS(hyprland::IPC::getSocket1Reply(request)); } + +#if defined(__linux__) +TEST_CASE("getSocket1Reply failure paths do not leak fds", "[getSocket1Reply][fd-leak]") { + const auto baseline = countOpenFds(); + + unsetenv("HYPRLAND_INSTANCE_SIGNATURE"); + for (int i = 0; i < 16; ++i) { + IPCTestHelper::resetSocketFolder(); + CHECK_THROWS(hyprland::IPC::getSocket1Reply("test_request")); + } + const auto after_missing_signature = countOpenFds(); + REQUIRE(after_missing_signature == baseline); + + setenv("HYPRLAND_INSTANCE_SIGNATURE", "definitely-not-running", 1); + for (int i = 0; i < 16; ++i) { + IPCTestHelper::resetSocketFolder(); + CHECK_THROWS(hyprland::IPC::getSocket1Reply("test_request")); + } + const auto after_connect_failures = countOpenFds(); + REQUIRE(after_connect_failures == baseline); +} +#endif diff --git a/test/hyprland/fixtures/IPCTestFixture.hpp b/test/hyprland/fixtures/IPCTestFixture.hpp deleted file mode 100644 index caa92975..00000000 --- a/test/hyprland/fixtures/IPCTestFixture.hpp +++ /dev/null @@ -1,25 +0,0 @@ -#include "modules/hyprland/backend.hpp" - -namespace fs = std::filesystem; -namespace hyprland = waybar::modules::hyprland; - -class IPCTestFixture : public hyprland::IPC { - public: - IPCTestFixture() : IPC() { IPC::socketFolder_ = ""; } - ~IPCTestFixture() { fs::remove_all(tempDir); } - - protected: - const char* instanceSig = "instance_sig"; - fs::path tempDir = fs::temp_directory_path() / "hypr_test"; - - private: -}; - -class IPCMock : public IPCTestFixture { - public: - // Mock getSocket1Reply to return an empty string - static std::string getSocket1Reply(const std::string& rq) { return ""; } - - protected: - const char* instanceSig = "instance_sig"; -};