Share the websocket between tasks
This commit is contained in:
parent
ded292f332
commit
44aa1f8aac
|
@ -1,4 +1,11 @@
|
|||
use std::collections::BTreeMap;
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap, VecDeque},
|
||||
future::Future,
|
||||
marker::PhantomData,
|
||||
pin::Pin,
|
||||
sync::{Arc, Mutex},
|
||||
task::{Poll, Waker},
|
||||
};
|
||||
|
||||
use async_tungstenite::{
|
||||
tokio::{connect_async, ConnectStream},
|
||||
|
@ -6,9 +13,9 @@ use async_tungstenite::{
|
|||
};
|
||||
use color_eyre::eyre::WrapErr;
|
||||
use color_eyre::eyre::{self, bail, OptionExt};
|
||||
use futures::{SinkExt as _, TryStreamExt};
|
||||
use futures::{SinkExt as _, StreamExt};
|
||||
use serde::Deserialize;
|
||||
use vapore_proto::{enums_clientserver, steammessages_base};
|
||||
use vapore_proto::{enums_clientserver::EMsg, steammessages_base::CMsgProtoBufHeader};
|
||||
|
||||
use crate::message::{CMProtoBufMessage, CMRawProtoBufMessage};
|
||||
|
||||
|
@ -40,8 +47,8 @@ pub async fn bootstrap_find_servers() -> eyre::Result<Vec<String>> {
|
|||
.collect())
|
||||
}
|
||||
|
||||
pub struct CMSession {
|
||||
socket: WebSocketStream<ConnectStream>,
|
||||
#[derive(Clone)]
|
||||
struct SessionInner {
|
||||
/// Steam ID of current user. When set to None we are not logged in
|
||||
steam_id: Option<u64>,
|
||||
/// Next jobid to use for messages that start a "job"
|
||||
|
@ -52,6 +59,134 @@ pub struct CMSession {
|
|||
/// Session ID for our socket, assigned by the server after login.
|
||||
/// Should be 0 before we login
|
||||
client_session_id: i32,
|
||||
|
||||
/// Messages ready to send
|
||||
send_queue: VecDeque<tungstenite::Message>,
|
||||
/// Waker for the sending thread
|
||||
send_waker: Option<Waker>,
|
||||
|
||||
/// Recievers waiting for responses by job id
|
||||
receive_wakers: HashMap<u64, Waker>,
|
||||
/// Messages ready for receivers by job id
|
||||
/// TODO: Support multiple messages for same job ID
|
||||
receive_messages: HashMap<u64, CMRawProtoBufMessage>,
|
||||
}
|
||||
|
||||
impl SessionInner {
|
||||
pub fn alloc_jobid(&mut self) -> u64 {
|
||||
let jobid = self.next_jobid;
|
||||
self.next_jobid += 1;
|
||||
jobid
|
||||
}
|
||||
|
||||
pub fn wake_sender(&mut self) {
|
||||
if let Some(waker) = self.send_waker.take() {
|
||||
waker.wake()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Context {
|
||||
socket: WebSocketStream<ConnectStream>,
|
||||
session: Arc<Mutex<SessionInner>>,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
fn handle_receive(
|
||||
self: Pin<&mut Self>,
|
||||
message: tungstenite::Result<tungstenite::Message>,
|
||||
) -> eyre::Result<()> {
|
||||
// Technically everything should be Binary but I think I saw some Text before
|
||||
let message_data = match message? {
|
||||
tungstenite::Message::Text(t) => t.into_bytes(),
|
||||
tungstenite::Message::Binary(b) => b,
|
||||
_ => eyre::bail!("Unexpected WebSocket frame type"),
|
||||
};
|
||||
|
||||
let raw_messages = CMRawProtoBufMessage::try_parse_multi(&message_data)
|
||||
.wrap_err("Parsing raw messages")?;
|
||||
|
||||
let mut session = self.session.lock().expect("Lock was poisoned");
|
||||
|
||||
for message in raw_messages.into_iter() {
|
||||
log::trace!("Got message: {:?}", message);
|
||||
|
||||
let jobid = message.header.jobid_target();
|
||||
if let Some(waker) = session.receive_wakers.remove(&jobid) {
|
||||
if session.receive_messages.insert(jobid, message).is_some() {
|
||||
log::info!("Received duplicate message for jobid {}", jobid);
|
||||
};
|
||||
waker.wake();
|
||||
} else {
|
||||
log::info!("Received message for unknown jobid {}", jobid);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_send(mut self: Pin<&mut Self>, cx: &mut std::task::Context) -> eyre::Result<()> {
|
||||
// TODO: figure out how to not cloe the Arc
|
||||
let session_arc = self.session.clone();
|
||||
let mut session = session_arc.lock().expect("Lock was poisoned");
|
||||
|
||||
while !session.send_queue.is_empty() {
|
||||
match self.socket.poll_ready_unpin(cx) {
|
||||
Poll::Ready(ret) => ret?,
|
||||
Poll::Pending => return Ok(()),
|
||||
}
|
||||
let message = session.send_queue.pop_front().unwrap();
|
||||
self.socket.start_send_unpin(message)?;
|
||||
}
|
||||
|
||||
match self.socket.poll_flush_unpin(cx) {
|
||||
Poll::Ready(ret) => ret?,
|
||||
Poll::Pending => (),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_inner(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> eyre::Result<()> {
|
||||
{
|
||||
let mut session = self.session.lock().expect("Lock was poisoned");
|
||||
session.send_waker = Some(cx.waker().clone());
|
||||
}
|
||||
|
||||
if let Poll::Ready(maybe_message) = self.as_mut().socket.poll_next_unpin(cx) {
|
||||
let message = maybe_message.ok_or_eyre("Socket was closed while trying to recieve")?;
|
||||
if let Err(err) = self.as_mut().handle_receive(message) {
|
||||
log::warn!("Got error while processing message: {}", err);
|
||||
}
|
||||
}
|
||||
|
||||
// All errors in send are critical, since everything's encoded by the time we get there
|
||||
self.handle_send(cx)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for Context {
|
||||
type Output = eyre::Result<()>;
|
||||
|
||||
fn poll(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Self::Output> {
|
||||
if let Err(error) = self.poll_inner(cx) {
|
||||
log::error!("Failed while talking to socket: {}", error);
|
||||
todo!("Reopen new socket");
|
||||
};
|
||||
|
||||
// We should always be pending unless we're going to return an error
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CMSession {
|
||||
inner: Arc<Mutex<SessionInner>>,
|
||||
}
|
||||
|
||||
impl CMSession {
|
||||
|
@ -60,90 +195,59 @@ impl CMSession {
|
|||
.await
|
||||
.wrap_err("Connecting to Steam server")?;
|
||||
|
||||
Ok(Self {
|
||||
socket,
|
||||
let inner = SessionInner {
|
||||
steam_id: None,
|
||||
next_jobid: 0,
|
||||
realm: 1,
|
||||
client_session_id: 0,
|
||||
send_queue: VecDeque::new(),
|
||||
send_waker: None,
|
||||
receive_wakers: HashMap::new(),
|
||||
receive_messages: HashMap::new(),
|
||||
};
|
||||
|
||||
let inner_wrapped = Arc::new(Mutex::new(inner));
|
||||
|
||||
let context = Context {
|
||||
socket,
|
||||
session: inner_wrapped.clone(),
|
||||
};
|
||||
|
||||
tokio::spawn(context);
|
||||
|
||||
Ok(Self {
|
||||
inner: inner_wrapped,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn call_service_method<Request: protobuf::Message, Response: protobuf::Message>(
|
||||
pub fn call_service_method<T: protobuf::Message, U: protobuf::Message>(
|
||||
&mut self,
|
||||
method: String,
|
||||
body: Request,
|
||||
) -> eyre::Result<CMProtoBufMessage<Response>> {
|
||||
body: T,
|
||||
) -> CallServiceMethod<T, U> {
|
||||
log::trace!("Calling service method `{}`", method);
|
||||
|
||||
let action = if self.is_authed() {
|
||||
enums_clientserver::EMsg::k_EMsgServiceMethodCallFromClient
|
||||
} else {
|
||||
enums_clientserver::EMsg::k_EMsgServiceMethodCallFromClientNonAuthed
|
||||
};
|
||||
|
||||
let jobid = self.next_jobid;
|
||||
self.next_jobid += 1;
|
||||
|
||||
let header = steammessages_base::CMsgProtoBufHeader {
|
||||
steamid: self.steam_id,
|
||||
target_job_name: Some(method),
|
||||
realm: Some(self.realm),
|
||||
client_sessionid: Some(self.client_session_id),
|
||||
jobid_source: Some(jobid),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let message = CMProtoBufMessage {
|
||||
action,
|
||||
header,
|
||||
CallServiceMethod::<T, U> {
|
||||
session: self,
|
||||
body,
|
||||
};
|
||||
let serialized = message.serialize()?;
|
||||
self.socket
|
||||
.send(tungstenite::protocol::Message::Binary(serialized))
|
||||
.await?;
|
||||
let response_message = self
|
||||
.socket
|
||||
.try_next()
|
||||
.await?
|
||||
.ok_or_eyre("No message recieved")?;
|
||||
|
||||
let tungstenite::protocol::Message::Binary(response_binary) = response_message else {
|
||||
bail!("Message recieved was not binary")
|
||||
};
|
||||
|
||||
let responses_raw = CMRawProtoBufMessage::try_parse_multi(&response_binary)?;
|
||||
if responses_raw.len() != 1 {
|
||||
todo!("Multiple responses")
|
||||
method,
|
||||
jobid: None,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
|
||||
let response_raw = responses_raw.into_iter().next().unwrap();
|
||||
|
||||
if response_raw.action != enums_clientserver::EMsg::k_EMsgServiceMethodResponse {
|
||||
bail!(
|
||||
"Wanted ServiceMethodResponse, got {:?}",
|
||||
response_raw.action
|
||||
);
|
||||
}
|
||||
|
||||
if response_raw.header.jobid_target() != jobid {
|
||||
bail!("Got wrong jobid")
|
||||
}
|
||||
|
||||
CMProtoBufMessage::<Response>::deserialize(response_raw)
|
||||
}
|
||||
|
||||
/// Send a message without a jobid
|
||||
pub async fn send_notification<T: protobuf::Message>(
|
||||
&mut self,
|
||||
action: enums_clientserver::EMsg,
|
||||
action: EMsg,
|
||||
body: T,
|
||||
) -> eyre::Result<()> {
|
||||
let header = steammessages_base::CMsgProtoBufHeader {
|
||||
steamid: self.steam_id,
|
||||
realm: Some(self.realm),
|
||||
client_sessionid: Some(self.client_session_id),
|
||||
let mut inner = self.inner.lock().expect("Lock was poisoned");
|
||||
|
||||
let header = CMsgProtoBufHeader {
|
||||
steamid: inner.steam_id,
|
||||
realm: Some(inner.realm),
|
||||
client_sessionid: Some(inner.client_session_id),
|
||||
..Default::default()
|
||||
};
|
||||
let message = CMProtoBufMessage {
|
||||
|
@ -154,15 +258,93 @@ impl CMSession {
|
|||
|
||||
let serialized = message.serialize()?;
|
||||
|
||||
self.socket
|
||||
.send(tungstenite::protocol::Message::Binary(serialized))
|
||||
.await?;
|
||||
inner
|
||||
.send_queue
|
||||
.push_back(tungstenite::protocol::Message::Binary(serialized));
|
||||
inner.wake_sender();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether the current session is authenticated
|
||||
pub fn is_authed(&self) -> bool {
|
||||
self.steam_id.is_some()
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct CallServiceMethod<'a, T: protobuf::Message, U: protobuf::Message> {
|
||||
session: &'a mut CMSession,
|
||||
body: T,
|
||||
method: String,
|
||||
jobid: Option<u64>,
|
||||
_phantom: PhantomData<U>,
|
||||
}
|
||||
|
||||
impl<'a, T: protobuf::Message, U: protobuf::Message> Unpin for CallServiceMethod<'a, T, U> {}
|
||||
|
||||
impl<'a, T: protobuf::Message, U: protobuf::Message> CallServiceMethod<'a, T, U> {
|
||||
fn finalize_response(
|
||||
&self,
|
||||
response: CMRawProtoBufMessage,
|
||||
) -> eyre::Result<CMProtoBufMessage<U>> {
|
||||
if response.action != EMsg::k_EMsgServiceMethodResponse {
|
||||
bail!("Wanted ServiceMethodResponse, got {:?}", response.action);
|
||||
}
|
||||
|
||||
if response.header.jobid_target() != self.jobid.unwrap() {
|
||||
bail!("Got wrong jobid")
|
||||
}
|
||||
|
||||
CMProtoBufMessage::<U>::deserialize(response)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: protobuf::Message, U: protobuf::Message> Future for CallServiceMethod<'_, T, U> {
|
||||
type Output = eyre::Result<CMProtoBufMessage<U>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
|
||||
let session_arc = self.session.inner.clone();
|
||||
let mut session = session_arc.lock().expect("Lock was poisoned");
|
||||
|
||||
// We only have to send the message once, use jobid for that flag
|
||||
if self.jobid.is_none() {
|
||||
let jobid = session.alloc_jobid();
|
||||
self.jobid = Some(jobid);
|
||||
|
||||
let action = if session.steam_id.is_some() {
|
||||
EMsg::k_EMsgServiceMethodCallFromClient
|
||||
} else {
|
||||
EMsg::k_EMsgServiceMethodCallFromClientNonAuthed
|
||||
};
|
||||
|
||||
let header = CMsgProtoBufHeader {
|
||||
steamid: session.steam_id,
|
||||
target_job_name: Some(self.method.clone()),
|
||||
realm: Some(session.realm),
|
||||
client_sessionid: Some(session.client_session_id),
|
||||
jobid_source: self.jobid,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let message = CMProtoBufMessage {
|
||||
action,
|
||||
header,
|
||||
body: self.body.clone(),
|
||||
};
|
||||
let serialized = message.serialize()?;
|
||||
session
|
||||
.send_queue
|
||||
.push_back(tungstenite::protocol::Message::Binary(serialized));
|
||||
|
||||
session.receive_wakers.insert(jobid, cx.waker().clone());
|
||||
|
||||
session.wake_sender();
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
let jobid = self.jobid.unwrap();
|
||||
|
||||
let Some(response) = session.receive_messages.remove(&jobid) else {
|
||||
session.receive_wakers.insert(jobid, cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
};
|
||||
|
||||
Poll::Ready(self.finalize_response(response))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue