Add method for subscribing to a message type

This commit is contained in:
Artemis Tosini 2024-08-27 21:05:17 +00:00
parent 44aa1f8aac
commit 70271a9ce0
Signed by: artemist
GPG key ID: EE5227935FE3FF18
2 changed files with 54 additions and 12 deletions

View file

@ -15,10 +15,14 @@ use color_eyre::eyre::WrapErr;
use color_eyre::eyre::{self, bail, OptionExt};
use futures::{SinkExt as _, StreamExt};
use serde::Deserialize;
use tokio::sync::broadcast;
use vapore_proto::{enums_clientserver::EMsg, steammessages_base::CMsgProtoBufHeader};
use crate::message::{CMProtoBufMessage, CMRawProtoBufMessage};
/// Maximum number of messages in the buffer for by-message-type subscriptions
const CHANNEL_CAPACITY: usize = 16;
#[derive(Debug, Deserialize)]
struct GetCMListResult<'a> {
// No need to check servers, we're only implementing websockets
@ -70,6 +74,9 @@ struct SessionInner {
/// Messages ready for receivers by job id
/// TODO: Support multiple messages for same job ID
receive_messages: HashMap<u64, CMRawProtoBufMessage>,
/// Senders for per-type subscriptions
subscribe_senders: HashMap<EMsg, broadcast::Sender<CMRawProtoBufMessage>>,
}
impl SessionInner {
@ -111,6 +118,23 @@ impl Context {
for message in raw_messages.into_iter() {
log::trace!("Got message: {:?}", message);
// Send based on message type to subscribers
let action = message.action;
if session.subscribe_senders.contains_key(&action) {
let sender_unused = session
.subscribe_senders
.get(&action)
.unwrap()
.send(message.clone())
.is_err();
if sender_unused {
log::debug!("No more subscribers for type {:?}", action);
session.subscribe_senders.remove(&action);
}
}
// Send based on jobid to call_service_method etc.
let jobid = message.header.jobid_target();
if let Some(waker) = session.receive_wakers.remove(&jobid) {
if session.receive_messages.insert(jobid, message).is_some() {
@ -204,6 +228,7 @@ impl CMSession {
send_waker: None,
receive_wakers: HashMap::new(),
receive_messages: HashMap::new(),
subscribe_senders: HashMap::new(),
};
let inner_wrapped = Arc::new(Mutex::new(inner));
@ -225,7 +250,7 @@ impl CMSession {
method: String,
body: T,
) -> CallServiceMethod<T, U> {
log::trace!("Calling service method `{}`", method);
log::trace!("Calling service method {}", method);
CallServiceMethod::<T, U> {
session: self,
@ -236,14 +261,17 @@ impl CMSession {
}
}
/// Send a message without a jobid
pub async fn send_notification<T: protobuf::Message>(
/// Send a message without a jobid.
/// Returns as soon as the message is in the send buffer
pub fn send_notification<T: protobuf::Message>(
&mut self,
action: EMsg,
body: T,
) -> eyre::Result<()> {
let mut inner = self.inner.lock().expect("Lock was poisoned");
log::trace!("Sending notification of type {:?}", action);
let header = CMsgProtoBufHeader {
steamid: inner.steam_id,
realm: Some(inner.realm),
@ -265,6 +293,22 @@ impl CMSession {
Ok(())
}
/// Subscribe to receive a notification every time a message of a
/// given type is received
pub fn subscribe_message_type(
&mut self,
action: EMsg,
) -> broadcast::Receiver<CMRawProtoBufMessage> {
let mut inner = self.inner.lock().expect("Lock was poisoned");
if let Some(sender) = inner.subscribe_senders.get(&action) {
sender.subscribe()
} else {
let (sender, receiver) = broadcast::channel(CHANNEL_CAPACITY);
inner.subscribe_senders.insert(action, sender);
receiver
}
}
}
#[must_use = "futures do nothing unless polled"]

View file

@ -25,15 +25,13 @@ pub async fn main() -> eyre::Result<()> {
let mut session = connection::CMSession::connect(&servers[0]).await?;
session
.send_notification(
session.send_notification(
EMsg::k_EMsgClientHello,
CMsgClientHello {
protocol_version: Some(0x1002c),
..Default::default()
},
)
.await?;
)?;
log::debug!("Sent hello");