lib: Replace eyre with thiserror

This commit is contained in:
Artemis Tosini 2024-09-05 01:34:05 +00:00
parent 0503e4ecbf
commit d539731705
Signed by: artemist
GPG key ID: ADFFE553DCBB831E
6 changed files with 106 additions and 56 deletions

1
Cargo.lock generated
View file

@ -1873,6 +1873,7 @@ dependencies = [
"reqwest", "reqwest",
"rsa", "rsa",
"serde", "serde",
"thiserror",
"tokio", "tokio",
"vapore-proto", "vapore-proto",
] ]

View file

@ -15,6 +15,7 @@ rand = "0.8.5"
reqwest = { version = "0.12", features = ["rustls-tls-native-roots"], default-features = false} reqwest = { version = "0.12", features = ["rustls-tls-native-roots"], default-features = false}
rsa = "0.9.6" rsa = "0.9.6"
serde = { version = "1.0.209", features = ["derive"] } serde = { version = "1.0.209", features = ["derive"] }
thiserror = "1.0.63"
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "macros", "time"]} tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "macros", "time"]}
vapore-proto.path = "../proto" vapore-proto.path = "../proto"

View file

@ -12,8 +12,6 @@ use async_tungstenite::{
tokio::{connect_async, ConnectStream}, tokio::{connect_async, ConnectStream},
tungstenite, WebSocketStream, tungstenite, WebSocketStream,
}; };
use color_eyre::eyre::WrapErr;
use color_eyre::eyre::{self, bail, OptionExt};
use futures::{SinkExt as _, StreamExt}; use futures::{SinkExt as _, StreamExt};
use serde::Deserialize; use serde::Deserialize;
use tokio::sync::broadcast; use tokio::sync::broadcast;
@ -22,7 +20,10 @@ use vapore_proto::{
steammessages_clientserver_login::CMsgClientHeartBeat, steammessages_clientserver_login::CMsgClientHeartBeat,
}; };
use crate::message::{CMProtoBufMessage, CMRawProtoBufMessage}; use crate::{
message::{CMProtoBufMessage, CMRawProtoBufMessage},
ClientError,
};
/// Maximum number of messages in the buffer for by-message-type subscriptions /// Maximum number of messages in the buffer for by-message-type subscriptions
const CHANNEL_CAPACITY: usize = 16; const CHANNEL_CAPACITY: usize = 16;
@ -46,7 +47,7 @@ pub struct GetCMListForConnectResponse<'a> {
message: &'a str, message: &'a str,
} }
pub async fn bootstrap_find_servers() -> eyre::Result<Vec<String>> { pub async fn bootstrap_find_servers() -> Result<Vec<String>, ClientError> {
let response = reqwest::get( let response = reqwest::get(
"https://api.steampowered.com/ISteamDirectory/GetCMListForConnect/v1/?cellid=0&format=vdf", "https://api.steampowered.com/ISteamDirectory/GetCMListForConnect/v1/?cellid=0&format=vdf",
) )
@ -56,11 +57,10 @@ pub async fn bootstrap_find_servers() -> eyre::Result<Vec<String>> {
let result: GetCMListForConnectResponse = keyvalues_serde::from_str(&response)?; let result: GetCMListForConnectResponse = keyvalues_serde::from_str(&response)?;
if result.success != 1 { if result.success != 1 {
eyre::bail!( return Err(ClientError::EResult(
"GetCMList returned bad result {} wtih message {}",
result.success, result.success,
result.message result.message.to_string(),
) ));
} }
Ok(result Ok(result
@ -118,16 +118,15 @@ impl Context {
fn handle_receive( fn handle_receive(
self: Pin<&mut Self>, self: Pin<&mut Self>,
message: tungstenite::Result<tungstenite::Message>, message: tungstenite::Result<tungstenite::Message>,
) -> eyre::Result<()> { ) -> Result<(), ClientError> {
// Technically everything should be Binary but I think I saw some Text before // Technically everything should be Binary but I think I saw some Text before
let message_data = match message? { let message_data = match message? {
tungstenite::Message::Text(t) => t.into_bytes(), tungstenite::Message::Text(t) => t.into_bytes(),
tungstenite::Message::Binary(b) => b, tungstenite::Message::Binary(b) => b,
_ => eyre::bail!("Unexpected WebSocket frame type"), _ => return Err(ClientError::BadWSMessageType),
}; };
let raw_messages = CMRawProtoBufMessage::try_parse_multi(&message_data) let raw_messages = CMRawProtoBufMessage::try_parse_multi(&message_data)?;
.wrap_err("Parsing raw messages")?;
let mut session = self.session.lock().expect("Lock was poisoned"); let mut session = self.session.lock().expect("Lock was poisoned");
@ -166,10 +165,13 @@ impl Context {
Ok(()) Ok(())
} }
fn handle_send(mut self: Pin<&mut Self>, cx: &mut std::task::Context) -> eyre::Result<()> { fn handle_send(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context,
) -> Result<(), ClientError> {
// TODO: figure out how to not cloe the Arc // TODO: figure out how to not cloe the Arc
let session_arc = self.session.clone(); let session_arc = self.session.clone();
let mut session = session_arc.lock().expect("Lock was poisoned"); let mut session = session_arc.lock()?;
while !session.send_queue.is_empty() { while !session.send_queue.is_empty() {
match self.socket.poll_ready_unpin(cx) { match self.socket.poll_ready_unpin(cx) {
@ -188,14 +190,17 @@ impl Context {
Ok(()) Ok(())
} }
fn poll_inner(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> eyre::Result<()> { fn poll_inner(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Result<(), ClientError> {
{ {
let mut session = self.session.lock().expect("Lock was poisoned"); let mut session = self.session.lock().expect("Lock was poisoned");
session.send_waker = Some(cx.waker().clone()); session.send_waker = Some(cx.waker().clone());
} }
if let Poll::Ready(maybe_message) = self.as_mut().socket.poll_next_unpin(cx) { if let Poll::Ready(maybe_message) = self.as_mut().socket.poll_next_unpin(cx) {
let message = maybe_message.ok_or_eyre("Socket was closed while trying to recieve")?; let message = maybe_message.ok_or(ClientError::ClosedSocket)?;
if let Err(err) = self.as_mut().handle_receive(message) { if let Err(err) = self.as_mut().handle_receive(message) {
log::warn!("Got error while processing message: {:?}", err); log::warn!("Got error while processing message: {:?}", err);
} }
@ -209,7 +214,7 @@ impl Context {
} }
impl Future for Context { impl Future for Context {
type Output = eyre::Result<()>; type Output = Result<(), ClientError>;
fn poll( fn poll(
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
@ -231,10 +236,8 @@ pub struct CMSession {
} }
impl CMSession { impl CMSession {
pub async fn connect(server: &str) -> eyre::Result<Self> { pub async fn connect(server: &str) -> Result<Self, ClientError> {
let (socket, _) = connect_async(server) let (socket, _) = connect_async(server).await?;
.await
.wrap_err("Connecting to Steam server")?;
let inner = SessionInner { let inner = SessionInner {
steam_id: None, steam_id: None,
@ -269,7 +272,7 @@ impl CMSession {
tokio::spawn(async move { cloned.send_heartbeat_task(interval).await }); tokio::spawn(async move { cloned.send_heartbeat_task(interval).await });
} }
async fn send_heartbeat_task(self, interval_secs: u32) -> eyre::Result<()> { async fn send_heartbeat_task(self, interval_secs: u32) -> Result<(), ClientError> {
let mut interval = tokio::time::interval(time::Duration::from_secs(interval_secs as u64)); let mut interval = tokio::time::interval(time::Duration::from_secs(interval_secs as u64));
loop { loop {
interval.tick().await; interval.tick().await;
@ -309,7 +312,7 @@ impl CMSession {
&self, &self,
action: EMsg, action: EMsg,
body: T, body: T,
) -> eyre::Result<()> { ) -> Result<(), ClientError> {
let mut inner = self.inner.lock().expect("Lock was poisoned"); let mut inner = self.inner.lock().expect("Lock was poisoned");
log::trace!("Sending notification of type {:?}", action); log::trace!("Sending notification of type {:?}", action);
@ -368,13 +371,9 @@ impl<'a, T: protobuf::Message, U: protobuf::Message> CallServiceMethod<'a, T, U>
fn finalize_response( fn finalize_response(
&self, &self,
response: CMRawProtoBufMessage, response: CMRawProtoBufMessage,
) -> eyre::Result<CMProtoBufMessage<U>> { ) -> Result<CMProtoBufMessage<U>, ClientError> {
if response.action != EMsg::k_EMsgServiceMethodResponse { if response.action != EMsg::k_EMsgServiceMethodResponse {
bail!("Wanted ServiceMethodResponse, got {:?}", response.action); return Err(ClientError::BadResponseAction(response.action));
}
if response.header.jobid_target() != self.jobid.unwrap() {
bail!("Got wrong jobid")
} }
CMProtoBufMessage::<U>::deserialize(response) CMProtoBufMessage::<U>::deserialize(response)
@ -382,7 +381,7 @@ impl<'a, T: protobuf::Message, U: protobuf::Message> CallServiceMethod<'a, T, U>
} }
impl<T: protobuf::Message, U: protobuf::Message> Future for CallServiceMethod<'_, T, U> { impl<T: protobuf::Message, U: protobuf::Message> Future for CallServiceMethod<'_, T, U> {
type Output = eyre::Result<CMProtoBufMessage<U>>; type Output = Result<CMProtoBufMessage<U>, ClientError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let session_arc = self.session.inner.clone(); let session_arc = self.session.inner.clone();

52
lib/src/error.rs Normal file
View file

@ -0,0 +1,52 @@
#[non_exhaustive]
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
#[error("Valve returned bad result {0} with message `{1}`")]
EResult(u32, String),
#[error("Request failure")]
Reqwest(#[from] reqwest::Error),
#[error("VDF Deserialization failure")]
Vdf(#[from] keyvalues_serde::Error),
#[error("WebSocket connection error")]
WebSocket(#[from] async_tungstenite::tungstenite::Error),
#[error("WebSocket was closed while trying to recieve")]
ClosedSocket,
#[error("ProtoBuf Deserialization error")]
Protobuf(#[from] protobuf::Error),
#[error("Invalid WebSocket message type from server")]
BadWSMessageType,
#[error("Decompression Error")]
DecompressionError(#[source] std::io::Error),
#[error("Invalid decompressed output (expected {0} bytes, got {1})")]
DecompressionInvalid(u32, usize),
#[error("Invalid message action {0}")]
InvalidAction(u32),
#[error("Invalid message length")]
InvalidMessageLength,
#[error("Message too short (need {0} bytes, got {1})")]
MessageTooShort(usize, usize),
#[error("Lock was poisoned")]
LockPoisoned,
#[error("Expected action ServiceMethodResponse, got {0:?}")]
BadResponseAction(vapore_proto::enums_clientserver::EMsg),
}
impl<T> From<std::sync::PoisonError<T>> for ClientError {
fn from(_value: std::sync::PoisonError<T>) -> Self {
// The guard won't be Send so we can't return it from async functions
ClientError::LockPoisoned
}
}

View file

@ -1,2 +1,5 @@
pub mod connection; pub mod connection;
pub mod error;
pub mod message; pub mod message;
pub use error::ClientError;

View file

@ -1,6 +1,5 @@
use std::io::Read; use std::io::Read;
use color_eyre::eyre;
use flate2::read::GzDecoder; use flate2::read::GzDecoder;
use protobuf::{Enum as _, Message as _}; use protobuf::{Enum as _, Message as _};
use vapore_proto::{ use vapore_proto::{
@ -8,6 +7,8 @@ use vapore_proto::{
steammessages_base::{CMsgMulti, CMsgProtoBufHeader}, steammessages_base::{CMsgMulti, CMsgProtoBufHeader},
}; };
use crate::ClientError;
/// A message sent over the socket. Can be either sent or recieved /// A message sent over the socket. Can be either sent or recieved
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct CMProtoBufMessage<T: protobuf::Message> { pub struct CMProtoBufMessage<T: protobuf::Message> {
@ -17,11 +18,15 @@ pub struct CMProtoBufMessage<T: protobuf::Message> {
} }
impl<T: protobuf::Message> CMProtoBufMessage<T> { impl<T: protobuf::Message> CMProtoBufMessage<T> {
pub fn serialize(&self) -> eyre::Result<Vec<u8>> { pub fn serialize(&self) -> Result<Vec<u8>, ClientError> {
// 4 bytes for type, 4 bytes for header length, then header and body // 4 bytes for type, 4 bytes for header length, then header and body
// No alignment requirements // No alignment requirements
let length = 4 + 4 + self.header.compute_size() + self.body.compute_size(); let length = 4 + 4 + self.header.compute_size() + self.body.compute_size();
let mut out = Vec::with_capacity(length.try_into()?); let mut out = Vec::with_capacity(
length
.try_into()
.map_err(|_| ClientError::InvalidMessageLength)?,
);
out.extend_from_slice(&(self.action.value() as u32 | 0x80000000).to_le_bytes()); out.extend_from_slice(&(self.action.value() as u32 | 0x80000000).to_le_bytes());
out.extend_from_slice(&self.header.cached_size().to_le_bytes()); out.extend_from_slice(&self.header.cached_size().to_le_bytes());
@ -31,7 +36,7 @@ impl<T: protobuf::Message> CMProtoBufMessage<T> {
Ok(out) Ok(out)
} }
pub fn deserialize(raw: CMRawProtoBufMessage) -> eyre::Result<Self> { pub fn deserialize(raw: CMRawProtoBufMessage) -> Result<Self, ClientError> {
let body = T::parse_from_bytes(&raw.body)?; let body = T::parse_from_bytes(&raw.body)?;
Ok(Self { Ok(Self {
@ -51,25 +56,18 @@ pub struct CMRawProtoBufMessage {
} }
impl CMRawProtoBufMessage { impl CMRawProtoBufMessage {
pub fn try_parse(binary: &[u8]) -> eyre::Result<Self> { pub fn try_parse(binary: &[u8]) -> Result<Self, ClientError> {
if binary.len() < 8 { if binary.len() < 8 {
eyre::bail!( return Err(ClientError::MessageTooShort(8, binary.len()));
"Message too short for type (need 8 bytes, was {} bytes)",
binary.len()
);
} }
let raw_action = u32::from_le_bytes(binary[0..4].try_into().unwrap()) & !0x8000_0000; let raw_action = u32::from_le_bytes(binary[0..4].try_into().unwrap()) & !0x8000_0000;
let action = EMsg::from_i32(raw_action as i32) let action = EMsg::from_i32(raw_action as i32)
.ok_or_else(|| eyre::eyre!("Unknown message action {}", raw_action))?; .ok_or_else(|| ClientError::InvalidAction(raw_action))?;
let header_length = u32::from_le_bytes(binary[4..8].try_into().unwrap()); let header_length = u32::from_le_bytes(binary[4..8].try_into().unwrap());
let header_end = 8 + header_length as usize; let header_end = 8 + header_length as usize;
if binary.len() < header_end { if binary.len() < header_end {
eyre::bail!( return Err(ClientError::MessageTooShort(header_end, binary.len()));
"Message too short for header (need {}, was {})",
header_end,
binary.len()
)
} }
let header = CMsgProtoBufHeader::parse_from_bytes(&binary[8..header_end])?; let header = CMsgProtoBufHeader::parse_from_bytes(&binary[8..header_end])?;
@ -82,7 +80,7 @@ impl CMRawProtoBufMessage {
}) })
} }
pub fn try_parse_multi(binary: &[u8]) -> eyre::Result<Vec<Self>> { pub fn try_parse_multi(binary: &[u8]) -> Result<Vec<Self>, ClientError> {
let root_raw = Self::try_parse(binary)?; let root_raw = Self::try_parse(binary)?;
if root_raw.action != EMsg::k_EMsgMulti { if root_raw.action != EMsg::k_EMsgMulti {
return Ok(vec![root_raw]); return Ok(vec![root_raw]);
@ -98,14 +96,14 @@ impl CMRawProtoBufMessage {
gzip_decompressed.reserve(size_unzipped as usize); gzip_decompressed.reserve(size_unzipped as usize);
let mut gz = GzDecoder::new(root.body.message_body()); let mut gz = GzDecoder::new(root.body.message_body());
gz.read_to_end(&mut gzip_decompressed)?; gz.read_to_end(&mut gzip_decompressed)
.map_err(ClientError::DecompressionError)?;
if gzip_decompressed.len() != size_unzipped as usize { if gzip_decompressed.len() != size_unzipped as usize {
eyre::bail!( return Err(ClientError::DecompressionInvalid(
"Expected decompressed len {}, got {}",
size_unzipped, size_unzipped,
gzip_decompressed.len() gzip_decompressed.len(),
); ));
} }
&gzip_decompressed &gzip_decompressed
@ -118,11 +116,7 @@ impl CMRawProtoBufMessage {
let full_length = u32::from_le_bytes(body[0..4].try_into().unwrap()); let full_length = u32::from_le_bytes(body[0..4].try_into().unwrap());
let message_end = 4 + full_length as usize; let message_end = 4 + full_length as usize;
if body.len() < message_end { if body.len() < message_end {
eyre::bail!( return Err(ClientError::MessageTooShort(message_end, body.len()));
"sub-message too short (need {}, got {})",
message_end,
body.len()
)
} }
match Self::try_parse(&body[4..message_end]) { match Self::try_parse(&body[4..message_end]) {