Share the websocket between tasks
This commit is contained in:
parent
ded292f332
commit
44aa1f8aac
|
@ -1,4 +1,11 @@
|
||||||
use std::collections::BTreeMap;
|
use std::{
|
||||||
|
collections::{BTreeMap, HashMap, VecDeque},
|
||||||
|
future::Future,
|
||||||
|
marker::PhantomData,
|
||||||
|
pin::Pin,
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
task::{Poll, Waker},
|
||||||
|
};
|
||||||
|
|
||||||
use async_tungstenite::{
|
use async_tungstenite::{
|
||||||
tokio::{connect_async, ConnectStream},
|
tokio::{connect_async, ConnectStream},
|
||||||
|
@ -6,9 +13,9 @@ use async_tungstenite::{
|
||||||
};
|
};
|
||||||
use color_eyre::eyre::WrapErr;
|
use color_eyre::eyre::WrapErr;
|
||||||
use color_eyre::eyre::{self, bail, OptionExt};
|
use color_eyre::eyre::{self, bail, OptionExt};
|
||||||
use futures::{SinkExt as _, TryStreamExt};
|
use futures::{SinkExt as _, StreamExt};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use vapore_proto::{enums_clientserver, steammessages_base};
|
use vapore_proto::{enums_clientserver::EMsg, steammessages_base::CMsgProtoBufHeader};
|
||||||
|
|
||||||
use crate::message::{CMProtoBufMessage, CMRawProtoBufMessage};
|
use crate::message::{CMProtoBufMessage, CMRawProtoBufMessage};
|
||||||
|
|
||||||
|
@ -40,8 +47,8 @@ pub async fn bootstrap_find_servers() -> eyre::Result<Vec<String>> {
|
||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CMSession {
|
#[derive(Clone)]
|
||||||
socket: WebSocketStream<ConnectStream>,
|
struct SessionInner {
|
||||||
/// Steam ID of current user. When set to None we are not logged in
|
/// Steam ID of current user. When set to None we are not logged in
|
||||||
steam_id: Option<u64>,
|
steam_id: Option<u64>,
|
||||||
/// Next jobid to use for messages that start a "job"
|
/// Next jobid to use for messages that start a "job"
|
||||||
|
@ -52,6 +59,134 @@ pub struct CMSession {
|
||||||
/// Session ID for our socket, assigned by the server after login.
|
/// Session ID for our socket, assigned by the server after login.
|
||||||
/// Should be 0 before we login
|
/// Should be 0 before we login
|
||||||
client_session_id: i32,
|
client_session_id: i32,
|
||||||
|
|
||||||
|
/// Messages ready to send
|
||||||
|
send_queue: VecDeque<tungstenite::Message>,
|
||||||
|
/// Waker for the sending thread
|
||||||
|
send_waker: Option<Waker>,
|
||||||
|
|
||||||
|
/// Recievers waiting for responses by job id
|
||||||
|
receive_wakers: HashMap<u64, Waker>,
|
||||||
|
/// Messages ready for receivers by job id
|
||||||
|
/// TODO: Support multiple messages for same job ID
|
||||||
|
receive_messages: HashMap<u64, CMRawProtoBufMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SessionInner {
|
||||||
|
pub fn alloc_jobid(&mut self) -> u64 {
|
||||||
|
let jobid = self.next_jobid;
|
||||||
|
self.next_jobid += 1;
|
||||||
|
jobid
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn wake_sender(&mut self) {
|
||||||
|
if let Some(waker) = self.send_waker.take() {
|
||||||
|
waker.wake()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Context {
|
||||||
|
socket: WebSocketStream<ConnectStream>,
|
||||||
|
session: Arc<Mutex<SessionInner>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Context {
|
||||||
|
fn handle_receive(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
message: tungstenite::Result<tungstenite::Message>,
|
||||||
|
) -> eyre::Result<()> {
|
||||||
|
// Technically everything should be Binary but I think I saw some Text before
|
||||||
|
let message_data = match message? {
|
||||||
|
tungstenite::Message::Text(t) => t.into_bytes(),
|
||||||
|
tungstenite::Message::Binary(b) => b,
|
||||||
|
_ => eyre::bail!("Unexpected WebSocket frame type"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let raw_messages = CMRawProtoBufMessage::try_parse_multi(&message_data)
|
||||||
|
.wrap_err("Parsing raw messages")?;
|
||||||
|
|
||||||
|
let mut session = self.session.lock().expect("Lock was poisoned");
|
||||||
|
|
||||||
|
for message in raw_messages.into_iter() {
|
||||||
|
log::trace!("Got message: {:?}", message);
|
||||||
|
|
||||||
|
let jobid = message.header.jobid_target();
|
||||||
|
if let Some(waker) = session.receive_wakers.remove(&jobid) {
|
||||||
|
if session.receive_messages.insert(jobid, message).is_some() {
|
||||||
|
log::info!("Received duplicate message for jobid {}", jobid);
|
||||||
|
};
|
||||||
|
waker.wake();
|
||||||
|
} else {
|
||||||
|
log::info!("Received message for unknown jobid {}", jobid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_send(mut self: Pin<&mut Self>, cx: &mut std::task::Context) -> eyre::Result<()> {
|
||||||
|
// TODO: figure out how to not cloe the Arc
|
||||||
|
let session_arc = self.session.clone();
|
||||||
|
let mut session = session_arc.lock().expect("Lock was poisoned");
|
||||||
|
|
||||||
|
while !session.send_queue.is_empty() {
|
||||||
|
match self.socket.poll_ready_unpin(cx) {
|
||||||
|
Poll::Ready(ret) => ret?,
|
||||||
|
Poll::Pending => return Ok(()),
|
||||||
|
}
|
||||||
|
let message = session.send_queue.pop_front().unwrap();
|
||||||
|
self.socket.start_send_unpin(message)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.socket.poll_flush_unpin(cx) {
|
||||||
|
Poll::Ready(ret) => ret?,
|
||||||
|
Poll::Pending => (),
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_inner(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> eyre::Result<()> {
|
||||||
|
{
|
||||||
|
let mut session = self.session.lock().expect("Lock was poisoned");
|
||||||
|
session.send_waker = Some(cx.waker().clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Poll::Ready(maybe_message) = self.as_mut().socket.poll_next_unpin(cx) {
|
||||||
|
let message = maybe_message.ok_or_eyre("Socket was closed while trying to recieve")?;
|
||||||
|
if let Err(err) = self.as_mut().handle_receive(message) {
|
||||||
|
log::warn!("Got error while processing message: {}", err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All errors in send are critical, since everything's encoded by the time we get there
|
||||||
|
self.handle_send(cx)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Future for Context {
|
||||||
|
type Output = eyre::Result<()>;
|
||||||
|
|
||||||
|
fn poll(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<Self::Output> {
|
||||||
|
if let Err(error) = self.poll_inner(cx) {
|
||||||
|
log::error!("Failed while talking to socket: {}", error);
|
||||||
|
todo!("Reopen new socket");
|
||||||
|
};
|
||||||
|
|
||||||
|
// We should always be pending unless we're going to return an error
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct CMSession {
|
||||||
|
inner: Arc<Mutex<SessionInner>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CMSession {
|
impl CMSession {
|
||||||
|
@ -60,90 +195,59 @@ impl CMSession {
|
||||||
.await
|
.await
|
||||||
.wrap_err("Connecting to Steam server")?;
|
.wrap_err("Connecting to Steam server")?;
|
||||||
|
|
||||||
Ok(Self {
|
let inner = SessionInner {
|
||||||
socket,
|
|
||||||
steam_id: None,
|
steam_id: None,
|
||||||
next_jobid: 0,
|
next_jobid: 0,
|
||||||
realm: 1,
|
realm: 1,
|
||||||
client_session_id: 0,
|
client_session_id: 0,
|
||||||
|
send_queue: VecDeque::new(),
|
||||||
|
send_waker: None,
|
||||||
|
receive_wakers: HashMap::new(),
|
||||||
|
receive_messages: HashMap::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let inner_wrapped = Arc::new(Mutex::new(inner));
|
||||||
|
|
||||||
|
let context = Context {
|
||||||
|
socket,
|
||||||
|
session: inner_wrapped.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
tokio::spawn(context);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
inner: inner_wrapped,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn call_service_method<Request: protobuf::Message, Response: protobuf::Message>(
|
pub fn call_service_method<T: protobuf::Message, U: protobuf::Message>(
|
||||||
&mut self,
|
&mut self,
|
||||||
method: String,
|
method: String,
|
||||||
body: Request,
|
body: T,
|
||||||
) -> eyre::Result<CMProtoBufMessage<Response>> {
|
) -> CallServiceMethod<T, U> {
|
||||||
log::trace!("Calling service method `{}`", method);
|
log::trace!("Calling service method `{}`", method);
|
||||||
|
|
||||||
let action = if self.is_authed() {
|
CallServiceMethod::<T, U> {
|
||||||
enums_clientserver::EMsg::k_EMsgServiceMethodCallFromClient
|
session: self,
|
||||||
} else {
|
|
||||||
enums_clientserver::EMsg::k_EMsgServiceMethodCallFromClientNonAuthed
|
|
||||||
};
|
|
||||||
|
|
||||||
let jobid = self.next_jobid;
|
|
||||||
self.next_jobid += 1;
|
|
||||||
|
|
||||||
let header = steammessages_base::CMsgProtoBufHeader {
|
|
||||||
steamid: self.steam_id,
|
|
||||||
target_job_name: Some(method),
|
|
||||||
realm: Some(self.realm),
|
|
||||||
client_sessionid: Some(self.client_session_id),
|
|
||||||
jobid_source: Some(jobid),
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
let message = CMProtoBufMessage {
|
|
||||||
action,
|
|
||||||
header,
|
|
||||||
body,
|
body,
|
||||||
};
|
method,
|
||||||
let serialized = message.serialize()?;
|
jobid: None,
|
||||||
self.socket
|
_phantom: PhantomData,
|
||||||
.send(tungstenite::protocol::Message::Binary(serialized))
|
|
||||||
.await?;
|
|
||||||
let response_message = self
|
|
||||||
.socket
|
|
||||||
.try_next()
|
|
||||||
.await?
|
|
||||||
.ok_or_eyre("No message recieved")?;
|
|
||||||
|
|
||||||
let tungstenite::protocol::Message::Binary(response_binary) = response_message else {
|
|
||||||
bail!("Message recieved was not binary")
|
|
||||||
};
|
|
||||||
|
|
||||||
let responses_raw = CMRawProtoBufMessage::try_parse_multi(&response_binary)?;
|
|
||||||
if responses_raw.len() != 1 {
|
|
||||||
todo!("Multiple responses")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let response_raw = responses_raw.into_iter().next().unwrap();
|
|
||||||
|
|
||||||
if response_raw.action != enums_clientserver::EMsg::k_EMsgServiceMethodResponse {
|
|
||||||
bail!(
|
|
||||||
"Wanted ServiceMethodResponse, got {:?}",
|
|
||||||
response_raw.action
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if response_raw.header.jobid_target() != jobid {
|
|
||||||
bail!("Got wrong jobid")
|
|
||||||
}
|
|
||||||
|
|
||||||
CMProtoBufMessage::<Response>::deserialize(response_raw)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send a message without a jobid
|
/// Send a message without a jobid
|
||||||
pub async fn send_notification<T: protobuf::Message>(
|
pub async fn send_notification<T: protobuf::Message>(
|
||||||
&mut self,
|
&mut self,
|
||||||
action: enums_clientserver::EMsg,
|
action: EMsg,
|
||||||
body: T,
|
body: T,
|
||||||
) -> eyre::Result<()> {
|
) -> eyre::Result<()> {
|
||||||
let header = steammessages_base::CMsgProtoBufHeader {
|
let mut inner = self.inner.lock().expect("Lock was poisoned");
|
||||||
steamid: self.steam_id,
|
|
||||||
realm: Some(self.realm),
|
let header = CMsgProtoBufHeader {
|
||||||
client_sessionid: Some(self.client_session_id),
|
steamid: inner.steam_id,
|
||||||
|
realm: Some(inner.realm),
|
||||||
|
client_sessionid: Some(inner.client_session_id),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let message = CMProtoBufMessage {
|
let message = CMProtoBufMessage {
|
||||||
|
@ -154,15 +258,93 @@ impl CMSession {
|
||||||
|
|
||||||
let serialized = message.serialize()?;
|
let serialized = message.serialize()?;
|
||||||
|
|
||||||
self.socket
|
inner
|
||||||
.send(tungstenite::protocol::Message::Binary(serialized))
|
.send_queue
|
||||||
.await?;
|
.push_back(tungstenite::protocol::Message::Binary(serialized));
|
||||||
|
inner.wake_sender();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Whether the current session is authenticated
|
#[must_use = "futures do nothing unless polled"]
|
||||||
pub fn is_authed(&self) -> bool {
|
pub struct CallServiceMethod<'a, T: protobuf::Message, U: protobuf::Message> {
|
||||||
self.steam_id.is_some()
|
session: &'a mut CMSession,
|
||||||
|
body: T,
|
||||||
|
method: String,
|
||||||
|
jobid: Option<u64>,
|
||||||
|
_phantom: PhantomData<U>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: protobuf::Message, U: protobuf::Message> Unpin for CallServiceMethod<'a, T, U> {}
|
||||||
|
|
||||||
|
impl<'a, T: protobuf::Message, U: protobuf::Message> CallServiceMethod<'a, T, U> {
|
||||||
|
fn finalize_response(
|
||||||
|
&self,
|
||||||
|
response: CMRawProtoBufMessage,
|
||||||
|
) -> eyre::Result<CMProtoBufMessage<U>> {
|
||||||
|
if response.action != EMsg::k_EMsgServiceMethodResponse {
|
||||||
|
bail!("Wanted ServiceMethodResponse, got {:?}", response.action);
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.header.jobid_target() != self.jobid.unwrap() {
|
||||||
|
bail!("Got wrong jobid")
|
||||||
|
}
|
||||||
|
|
||||||
|
CMProtoBufMessage::<U>::deserialize(response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: protobuf::Message, U: protobuf::Message> Future for CallServiceMethod<'_, T, U> {
|
||||||
|
type Output = eyre::Result<CMProtoBufMessage<U>>;
|
||||||
|
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
|
||||||
|
let session_arc = self.session.inner.clone();
|
||||||
|
let mut session = session_arc.lock().expect("Lock was poisoned");
|
||||||
|
|
||||||
|
// We only have to send the message once, use jobid for that flag
|
||||||
|
if self.jobid.is_none() {
|
||||||
|
let jobid = session.alloc_jobid();
|
||||||
|
self.jobid = Some(jobid);
|
||||||
|
|
||||||
|
let action = if session.steam_id.is_some() {
|
||||||
|
EMsg::k_EMsgServiceMethodCallFromClient
|
||||||
|
} else {
|
||||||
|
EMsg::k_EMsgServiceMethodCallFromClientNonAuthed
|
||||||
|
};
|
||||||
|
|
||||||
|
let header = CMsgProtoBufHeader {
|
||||||
|
steamid: session.steam_id,
|
||||||
|
target_job_name: Some(self.method.clone()),
|
||||||
|
realm: Some(session.realm),
|
||||||
|
client_sessionid: Some(session.client_session_id),
|
||||||
|
jobid_source: self.jobid,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let message = CMProtoBufMessage {
|
||||||
|
action,
|
||||||
|
header,
|
||||||
|
body: self.body.clone(),
|
||||||
|
};
|
||||||
|
let serialized = message.serialize()?;
|
||||||
|
session
|
||||||
|
.send_queue
|
||||||
|
.push_back(tungstenite::protocol::Message::Binary(serialized));
|
||||||
|
|
||||||
|
session.receive_wakers.insert(jobid, cx.waker().clone());
|
||||||
|
|
||||||
|
session.wake_sender();
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
|
||||||
|
let jobid = self.jobid.unwrap();
|
||||||
|
|
||||||
|
let Some(response) = session.receive_messages.remove(&jobid) else {
|
||||||
|
session.receive_wakers.insert(jobid, cx.waker().clone());
|
||||||
|
return Poll::Pending;
|
||||||
|
};
|
||||||
|
|
||||||
|
Poll::Ready(self.finalize_response(response))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue