Share the websocket between tasks

This commit is contained in:
Artemis Tosini 2024-08-26 08:16:53 +00:00
parent ded292f332
commit 44aa1f8aac
Signed by: artemist
GPG key ID: EE5227935FE3FF18

View file

@ -1,4 +1,11 @@
use std::collections::BTreeMap; use std::{
collections::{BTreeMap, HashMap, VecDeque},
future::Future,
marker::PhantomData,
pin::Pin,
sync::{Arc, Mutex},
task::{Poll, Waker},
};
use async_tungstenite::{ use async_tungstenite::{
tokio::{connect_async, ConnectStream}, tokio::{connect_async, ConnectStream},
@ -6,9 +13,9 @@ use async_tungstenite::{
}; };
use color_eyre::eyre::WrapErr; use color_eyre::eyre::WrapErr;
use color_eyre::eyre::{self, bail, OptionExt}; use color_eyre::eyre::{self, bail, OptionExt};
use futures::{SinkExt as _, TryStreamExt}; use futures::{SinkExt as _, StreamExt};
use serde::Deserialize; use serde::Deserialize;
use vapore_proto::{enums_clientserver, steammessages_base}; use vapore_proto::{enums_clientserver::EMsg, steammessages_base::CMsgProtoBufHeader};
use crate::message::{CMProtoBufMessage, CMRawProtoBufMessage}; use crate::message::{CMProtoBufMessage, CMRawProtoBufMessage};
@ -40,8 +47,8 @@ pub async fn bootstrap_find_servers() -> eyre::Result<Vec<String>> {
.collect()) .collect())
} }
pub struct CMSession { #[derive(Clone)]
socket: WebSocketStream<ConnectStream>, struct SessionInner {
/// Steam ID of current user. When set to None we are not logged in /// Steam ID of current user. When set to None we are not logged in
steam_id: Option<u64>, steam_id: Option<u64>,
/// Next jobid to use for messages that start a "job" /// Next jobid to use for messages that start a "job"
@ -52,6 +59,134 @@ pub struct CMSession {
/// Session ID for our socket, assigned by the server after login. /// Session ID for our socket, assigned by the server after login.
/// Should be 0 before we login /// Should be 0 before we login
client_session_id: i32, client_session_id: i32,
/// Messages ready to send
send_queue: VecDeque<tungstenite::Message>,
/// Waker for the sending thread
send_waker: Option<Waker>,
/// Recievers waiting for responses by job id
receive_wakers: HashMap<u64, Waker>,
/// Messages ready for receivers by job id
/// TODO: Support multiple messages for same job ID
receive_messages: HashMap<u64, CMRawProtoBufMessage>,
}
impl SessionInner {
pub fn alloc_jobid(&mut self) -> u64 {
let jobid = self.next_jobid;
self.next_jobid += 1;
jobid
}
pub fn wake_sender(&mut self) {
if let Some(waker) = self.send_waker.take() {
waker.wake()
}
}
}
struct Context {
socket: WebSocketStream<ConnectStream>,
session: Arc<Mutex<SessionInner>>,
}
impl Context {
fn handle_receive(
self: Pin<&mut Self>,
message: tungstenite::Result<tungstenite::Message>,
) -> eyre::Result<()> {
// Technically everything should be Binary but I think I saw some Text before
let message_data = match message? {
tungstenite::Message::Text(t) => t.into_bytes(),
tungstenite::Message::Binary(b) => b,
_ => eyre::bail!("Unexpected WebSocket frame type"),
};
let raw_messages = CMRawProtoBufMessage::try_parse_multi(&message_data)
.wrap_err("Parsing raw messages")?;
let mut session = self.session.lock().expect("Lock was poisoned");
for message in raw_messages.into_iter() {
log::trace!("Got message: {:?}", message);
let jobid = message.header.jobid_target();
if let Some(waker) = session.receive_wakers.remove(&jobid) {
if session.receive_messages.insert(jobid, message).is_some() {
log::info!("Received duplicate message for jobid {}", jobid);
};
waker.wake();
} else {
log::info!("Received message for unknown jobid {}", jobid);
}
}
Ok(())
}
fn handle_send(mut self: Pin<&mut Self>, cx: &mut std::task::Context) -> eyre::Result<()> {
// TODO: figure out how to not cloe the Arc
let session_arc = self.session.clone();
let mut session = session_arc.lock().expect("Lock was poisoned");
while !session.send_queue.is_empty() {
match self.socket.poll_ready_unpin(cx) {
Poll::Ready(ret) => ret?,
Poll::Pending => return Ok(()),
}
let message = session.send_queue.pop_front().unwrap();
self.socket.start_send_unpin(message)?;
}
match self.socket.poll_flush_unpin(cx) {
Poll::Ready(ret) => ret?,
Poll::Pending => (),
}
Ok(())
}
fn poll_inner(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> eyre::Result<()> {
{
let mut session = self.session.lock().expect("Lock was poisoned");
session.send_waker = Some(cx.waker().clone());
}
if let Poll::Ready(maybe_message) = self.as_mut().socket.poll_next_unpin(cx) {
let message = maybe_message.ok_or_eyre("Socket was closed while trying to recieve")?;
if let Err(err) = self.as_mut().handle_receive(message) {
log::warn!("Got error while processing message: {}", err);
}
}
// All errors in send are critical, since everything's encoded by the time we get there
self.handle_send(cx)?;
Ok(())
}
}
impl Future for Context {
type Output = eyre::Result<()>;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
if let Err(error) = self.poll_inner(cx) {
log::error!("Failed while talking to socket: {}", error);
todo!("Reopen new socket");
};
// We should always be pending unless we're going to return an error
Poll::Pending
}
}
#[derive(Clone)]
pub struct CMSession {
inner: Arc<Mutex<SessionInner>>,
} }
impl CMSession { impl CMSession {
@ -60,90 +195,59 @@ impl CMSession {
.await .await
.wrap_err("Connecting to Steam server")?; .wrap_err("Connecting to Steam server")?;
Ok(Self { let inner = SessionInner {
socket,
steam_id: None, steam_id: None,
next_jobid: 0, next_jobid: 0,
realm: 1, realm: 1,
client_session_id: 0, client_session_id: 0,
send_queue: VecDeque::new(),
send_waker: None,
receive_wakers: HashMap::new(),
receive_messages: HashMap::new(),
};
let inner_wrapped = Arc::new(Mutex::new(inner));
let context = Context {
socket,
session: inner_wrapped.clone(),
};
tokio::spawn(context);
Ok(Self {
inner: inner_wrapped,
}) })
} }
pub async fn call_service_method<Request: protobuf::Message, Response: protobuf::Message>( pub fn call_service_method<T: protobuf::Message, U: protobuf::Message>(
&mut self, &mut self,
method: String, method: String,
body: Request, body: T,
) -> eyre::Result<CMProtoBufMessage<Response>> { ) -> CallServiceMethod<T, U> {
log::trace!("Calling service method `{}`", method); log::trace!("Calling service method `{}`", method);
let action = if self.is_authed() { CallServiceMethod::<T, U> {
enums_clientserver::EMsg::k_EMsgServiceMethodCallFromClient session: self,
} else {
enums_clientserver::EMsg::k_EMsgServiceMethodCallFromClientNonAuthed
};
let jobid = self.next_jobid;
self.next_jobid += 1;
let header = steammessages_base::CMsgProtoBufHeader {
steamid: self.steam_id,
target_job_name: Some(method),
realm: Some(self.realm),
client_sessionid: Some(self.client_session_id),
jobid_source: Some(jobid),
..Default::default()
};
let message = CMProtoBufMessage {
action,
header,
body, body,
}; method,
let serialized = message.serialize()?; jobid: None,
self.socket _phantom: PhantomData,
.send(tungstenite::protocol::Message::Binary(serialized))
.await?;
let response_message = self
.socket
.try_next()
.await?
.ok_or_eyre("No message recieved")?;
let tungstenite::protocol::Message::Binary(response_binary) = response_message else {
bail!("Message recieved was not binary")
};
let responses_raw = CMRawProtoBufMessage::try_parse_multi(&response_binary)?;
if responses_raw.len() != 1 {
todo!("Multiple responses")
} }
let response_raw = responses_raw.into_iter().next().unwrap();
if response_raw.action != enums_clientserver::EMsg::k_EMsgServiceMethodResponse {
bail!(
"Wanted ServiceMethodResponse, got {:?}",
response_raw.action
);
}
if response_raw.header.jobid_target() != jobid {
bail!("Got wrong jobid")
}
CMProtoBufMessage::<Response>::deserialize(response_raw)
} }
/// Send a message without a jobid /// Send a message without a jobid
pub async fn send_notification<T: protobuf::Message>( pub async fn send_notification<T: protobuf::Message>(
&mut self, &mut self,
action: enums_clientserver::EMsg, action: EMsg,
body: T, body: T,
) -> eyre::Result<()> { ) -> eyre::Result<()> {
let header = steammessages_base::CMsgProtoBufHeader { let mut inner = self.inner.lock().expect("Lock was poisoned");
steamid: self.steam_id,
realm: Some(self.realm), let header = CMsgProtoBufHeader {
client_sessionid: Some(self.client_session_id), steamid: inner.steam_id,
realm: Some(inner.realm),
client_sessionid: Some(inner.client_session_id),
..Default::default() ..Default::default()
}; };
let message = CMProtoBufMessage { let message = CMProtoBufMessage {
@ -154,15 +258,93 @@ impl CMSession {
let serialized = message.serialize()?; let serialized = message.serialize()?;
self.socket inner
.send(tungstenite::protocol::Message::Binary(serialized)) .send_queue
.await?; .push_back(tungstenite::protocol::Message::Binary(serialized));
inner.wake_sender();
Ok(()) Ok(())
} }
}
/// Whether the current session is authenticated #[must_use = "futures do nothing unless polled"]
pub fn is_authed(&self) -> bool { pub struct CallServiceMethod<'a, T: protobuf::Message, U: protobuf::Message> {
self.steam_id.is_some() session: &'a mut CMSession,
body: T,
method: String,
jobid: Option<u64>,
_phantom: PhantomData<U>,
}
impl<'a, T: protobuf::Message, U: protobuf::Message> Unpin for CallServiceMethod<'a, T, U> {}
impl<'a, T: protobuf::Message, U: protobuf::Message> CallServiceMethod<'a, T, U> {
fn finalize_response(
&self,
response: CMRawProtoBufMessage,
) -> eyre::Result<CMProtoBufMessage<U>> {
if response.action != EMsg::k_EMsgServiceMethodResponse {
bail!("Wanted ServiceMethodResponse, got {:?}", response.action);
}
if response.header.jobid_target() != self.jobid.unwrap() {
bail!("Got wrong jobid")
}
CMProtoBufMessage::<U>::deserialize(response)
}
}
impl<T: protobuf::Message, U: protobuf::Message> Future for CallServiceMethod<'_, T, U> {
type Output = eyre::Result<CMProtoBufMessage<U>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let session_arc = self.session.inner.clone();
let mut session = session_arc.lock().expect("Lock was poisoned");
// We only have to send the message once, use jobid for that flag
if self.jobid.is_none() {
let jobid = session.alloc_jobid();
self.jobid = Some(jobid);
let action = if session.steam_id.is_some() {
EMsg::k_EMsgServiceMethodCallFromClient
} else {
EMsg::k_EMsgServiceMethodCallFromClientNonAuthed
};
let header = CMsgProtoBufHeader {
steamid: session.steam_id,
target_job_name: Some(self.method.clone()),
realm: Some(session.realm),
client_sessionid: Some(session.client_session_id),
jobid_source: self.jobid,
..Default::default()
};
let message = CMProtoBufMessage {
action,
header,
body: self.body.clone(),
};
let serialized = message.serialize()?;
session
.send_queue
.push_back(tungstenite::protocol::Message::Binary(serialized));
session.receive_wakers.insert(jobid, cx.waker().clone());
session.wake_sender();
return Poll::Pending;
}
let jobid = self.jobid.unwrap();
let Some(response) = session.receive_messages.remove(&jobid) else {
session.receive_wakers.insert(jobid, cx.waker().clone());
return Poll::Pending;
};
Poll::Ready(self.finalize_response(response))
} }
} }