From 70271a9ce08f222f8356346316d87426056dcf77 Mon Sep 17 00:00:00 2001 From: Artemis Tosini Date: Tue, 27 Aug 2024 21:05:17 +0000 Subject: [PATCH] Add method for subscribing to a message type --- daemon/src/connection.rs | 50 +++++++++++++++++++++++++++++++++++++--- daemon/src/main.rs | 16 ++++++------- 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/daemon/src/connection.rs b/daemon/src/connection.rs index aff0a22..a8dafd8 100644 --- a/daemon/src/connection.rs +++ b/daemon/src/connection.rs @@ -15,10 +15,14 @@ use color_eyre::eyre::WrapErr; use color_eyre::eyre::{self, bail, OptionExt}; use futures::{SinkExt as _, StreamExt}; use serde::Deserialize; +use tokio::sync::broadcast; use vapore_proto::{enums_clientserver::EMsg, steammessages_base::CMsgProtoBufHeader}; use crate::message::{CMProtoBufMessage, CMRawProtoBufMessage}; +/// Maximum number of messages in the buffer for by-message-type subscriptions +const CHANNEL_CAPACITY: usize = 16; + #[derive(Debug, Deserialize)] struct GetCMListResult<'a> { // No need to check servers, we're only implementing websockets @@ -70,6 +74,9 @@ struct SessionInner { /// Messages ready for receivers by job id /// TODO: Support multiple messages for same job ID receive_messages: HashMap, + + /// Senders for per-type subscriptions + subscribe_senders: HashMap>, } impl SessionInner { @@ -111,6 +118,23 @@ impl Context { for message in raw_messages.into_iter() { log::trace!("Got message: {:?}", message); + // Send based on message type to subscribers + let action = message.action; + if session.subscribe_senders.contains_key(&action) { + let sender_unused = session + .subscribe_senders + .get(&action) + .unwrap() + .send(message.clone()) + .is_err(); + + if sender_unused { + log::debug!("No more subscribers for type {:?}", action); + session.subscribe_senders.remove(&action); + } + } + + // Send based on jobid to call_service_method etc. let jobid = message.header.jobid_target(); if let Some(waker) = session.receive_wakers.remove(&jobid) { if session.receive_messages.insert(jobid, message).is_some() { @@ -204,6 +228,7 @@ impl CMSession { send_waker: None, receive_wakers: HashMap::new(), receive_messages: HashMap::new(), + subscribe_senders: HashMap::new(), }; let inner_wrapped = Arc::new(Mutex::new(inner)); @@ -225,7 +250,7 @@ impl CMSession { method: String, body: T, ) -> CallServiceMethod { - log::trace!("Calling service method `{}`", method); + log::trace!("Calling service method {}", method); CallServiceMethod:: { session: self, @@ -236,14 +261,17 @@ impl CMSession { } } - /// Send a message without a jobid - pub async fn send_notification( + /// Send a message without a jobid. + /// Returns as soon as the message is in the send buffer + pub fn send_notification( &mut self, action: EMsg, body: T, ) -> eyre::Result<()> { let mut inner = self.inner.lock().expect("Lock was poisoned"); + log::trace!("Sending notification of type {:?}", action); + let header = CMsgProtoBufHeader { steamid: inner.steam_id, realm: Some(inner.realm), @@ -265,6 +293,22 @@ impl CMSession { Ok(()) } + + /// Subscribe to receive a notification every time a message of a + /// given type is received + pub fn subscribe_message_type( + &mut self, + action: EMsg, + ) -> broadcast::Receiver { + let mut inner = self.inner.lock().expect("Lock was poisoned"); + if let Some(sender) = inner.subscribe_senders.get(&action) { + sender.subscribe() + } else { + let (sender, receiver) = broadcast::channel(CHANNEL_CAPACITY); + inner.subscribe_senders.insert(action, sender); + receiver + } + } } #[must_use = "futures do nothing unless polled"] diff --git a/daemon/src/main.rs b/daemon/src/main.rs index 9b4e14f..e4f7a43 100644 --- a/daemon/src/main.rs +++ b/daemon/src/main.rs @@ -25,15 +25,13 @@ pub async fn main() -> eyre::Result<()> { let mut session = connection::CMSession::connect(&servers[0]).await?; - session - .send_notification( - EMsg::k_EMsgClientHello, - CMsgClientHello { - protocol_version: Some(0x1002c), - ..Default::default() - }, - ) - .await?; + session.send_notification( + EMsg::k_EMsgClientHello, + CMsgClientHello { + protocol_version: Some(0x1002c), + ..Default::default() + }, + )?; log::debug!("Sent hello");