Browse Source

Handle disconnected users

Sergey Chushin 3 năm trước cách đây
mục cha
commit
b4e4bcca2c
5 tập tin đã thay đổi với 70 bổ sung41 xóa
  1. 43 18
      src/client.rs
  2. 1 1
      src/connection.rs
  3. 8 4
      src/db.rs
  4. 5 8
      src/protocol.rs
  5. 13 10
      src/server.rs

+ 43 - 18
src/client.rs

@@ -1,17 +1,22 @@
 use std::sync::Arc;
 use std::sync::Arc;
 
 
-use tokio::io::{AsyncWrite, AsyncRead};
+use tokio::io::{AsyncRead, AsyncWrite};
 use tokio::sync::mpsc;
 use tokio::sync::mpsc;
 use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
 use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
+use tokio::task::JoinHandle;
 
 
 use crate::connection::Connection;
 use crate::connection::Connection;
 use crate::db::{Db, User};
 use crate::db::{Db, User};
-use crate::protocol::{MumblePacket, VoicePacket, MumblePacketWriter};
-use crate::proto::mumble::{UserState, Ping};
-use tokio::task::JoinHandle;
+use crate::proto::mumble::{Ping, UserRemove, UserState};
+use crate::protocol::{MumblePacket, MumblePacketWriter, VoicePacket};
 
 
 pub enum Message {
 pub enum Message {
-    NewUser(u32)
+    UserConnected(u32),
+    UserDisconnected(u32),
+}
+
+pub enum ResponseMessage {
+    Disconnected
 }
 }
 
 
 pub struct Client {
 pub struct Client {
@@ -21,42 +26,45 @@ pub struct Client {
 }
 }
 
 
 pub enum Error {
 pub enum Error {
-    UserNotFound,
     StreamError(crate::protocol::Error),
     StreamError(crate::protocol::Error),
 }
 }
 
 
 struct Handler<W> {
 struct Handler<W> {
     db: Arc<Db>,
     db: Arc<Db>,
     writer: MumblePacketWriter<W>,
     writer: MumblePacketWriter<W>,
-    response_sender: UnboundedSender<Message>,
+    session_id: u32,
+    response_sender: UnboundedSender<ResponseMessage>,
 }
 }
 
 
 enum InnerMessage {
 enum InnerMessage {
     Message(Message),
     Message(Message),
     Packet(MumblePacket),
     Packet(MumblePacket),
+    Disconnected,
 }
 }
 
 
-type ResponseReceiver = UnboundedReceiver<Message>;
+type ResponseReceiver = UnboundedReceiver<ResponseMessage>;
 
 
 impl Client {
 impl Client {
     pub async fn new<S>(connection: Connection<S>, db: Arc<Db>) -> (Client, ResponseReceiver)
     pub async fn new<S>(connection: Connection<S>, db: Arc<Db>) -> (Client, ResponseReceiver)
-    where
-        S: 'static + AsyncRead + AsyncWrite + Unpin + Send,
+        where
+            S: 'static + AsyncRead + AsyncWrite + Unpin + Send,
     {
     {
         let (sender, mut receiver) = mpsc::unbounded_channel();
         let (sender, mut receiver) = mpsc::unbounded_channel();
         let (response_sender, response_receiver) = mpsc::unbounded_channel();
         let (response_sender, response_receiver) = mpsc::unbounded_channel();
 
 
         let writer = connection.writer;
         let writer = connection.writer;
+        let session_id = connection.session_id;
         let handler_task = tokio::spawn(async move {
         let handler_task = tokio::spawn(async move {
             let mut handler = Handler {
             let mut handler = Handler {
                 db,
                 db,
                 writer,
                 writer,
+                session_id,
                 response_sender,
                 response_sender,
             };
             };
             loop {
             loop {
                 let message = match receiver.recv().await {
                 let message = match receiver.recv().await {
-                    None => return,
                     Some(msg) => msg,
                     Some(msg) => msg,
+                    None => return,
                 };
                 };
 
 
                 match message {
                 match message {
@@ -72,6 +80,10 @@ impl Client {
                             return;
                             return;
                         }
                         }
                     }
                     }
+                    InnerMessage::Disconnected => {
+                        handler.self_disconnected().await;
+                        return;
+                    }
                 }
                 }
             }
             }
         });
         });
@@ -80,12 +92,13 @@ impl Client {
         let mut reader = connection.reader;
         let mut reader = connection.reader;
         let packet_task = tokio::spawn(async move {
         let packet_task = tokio::spawn(async move {
             loop {
             loop {
-                let packet = match reader.read().await{
-                    Ok(packet) => packet,
-                    Err(_) => return, //TODO
+                match reader.read().await {
+                    Ok(packet) => sender.send(InnerMessage::Packet(packet)),
+                    Err(_) => {
+                        sender.send(InnerMessage::Disconnected);
+                        return;
+                    }
                 };
                 };
-
-                sender.send(InnerMessage::Packet(packet));
             }
             }
         });
         });
 
 
@@ -124,7 +137,7 @@ impl<W> Handler<W>
             MumblePacket::UdpTunnel(voice) => {
             MumblePacket::UdpTunnel(voice) => {
                 match voice {
                 match voice {
                     VoicePacket::Ping(_) => {
                     VoicePacket::Ping(_) => {
-                        self.writer.write(MumblePacket::UdpTunnel(voice));
+                        self.writer.write(MumblePacket::UdpTunnel(voice)).await;
                         println!("VoicePing");
                         println!("VoicePing");
                     }
                     }
                     VoicePacket::AudioData(_) => { println!("AudioData"); }
                     VoicePacket::AudioData(_) => { println!("AudioData"); }
@@ -137,7 +150,8 @@ impl<W> Handler<W>
 
 
     async fn handle_message(&mut self, message: Message) -> Result<(), Error> {
     async fn handle_message(&mut self, message: Message) -> Result<(), Error> {
         match message {
         match message {
-            Message::NewUser(session_id) => self.new_user_connected(session_id).await?,
+            Message::UserConnected(session_id) => self.new_user_connected(session_id).await?,
+            Message::UserDisconnected(session_id) => self.user_disconnected(session_id).await?,
         }
         }
 
 
         Ok(())
         Ok(())
@@ -149,6 +163,17 @@ impl<W> Handler<W>
         }
         }
         Ok(())
         Ok(())
     }
     }
+
+    async fn user_disconnected(&mut self, session_id: u32) -> Result<(), Error> {
+        let mut user_remove = UserRemove::new();
+        user_remove.set_session(session_id);
+        Ok(self.writer.write(MumblePacket::UserRemove(user_remove)).await?)
+    }
+
+    async fn self_disconnected(&mut self) {
+        self.db.remove_connected_user(self.session_id).await;
+        self.response_sender.send(ResponseMessage::Disconnected);
+    }
 }
 }
 
 
 impl From<User> for UserState {
 impl From<User> for UserState {

+ 1 - 1
src/connection.rs

@@ -72,7 +72,7 @@ impl<S> Connection<S>
         let mut permission_query = PermissionQuery::new();
         let mut permission_query = PermissionQuery::new();
         permission_query.set_permissions(134743822);
         permission_query.set_permissions(134743822);
         permission_query.set_channel_id(0);
         permission_query.set_channel_id(0);
-        writer.write(MumblePacket::PermissionQuery(permission_query));
+        writer.write(MumblePacket::PermissionQuery(permission_query)).await?;
 
 
         //User states
         //User states
         let connected_users = db.get_connected_users().await;
         let connected_users = db.get_connected_users().await;

+ 8 - 4
src/db.rs

@@ -1,8 +1,7 @@
 use std::collections::HashMap;
 use std::collections::HashMap;
 
 
-use tokio::sync::RwLock;
-
 use serde::{Deserialize, Serialize};
 use serde::{Deserialize, Serialize};
+use tokio::sync::RwLock;
 
 
 const ROOT_CHANNEL_ID: u32 = 0;
 const ROOT_CHANNEL_ID: u32 = 0;
 const USER_TREE_NAME: &[u8] = b"users";
 const USER_TREE_NAME: &[u8] = b"users";
@@ -81,16 +80,21 @@ impl Db {
 
 
     pub async fn get_connected_users(&self) -> Vec<User> {
     pub async fn get_connected_users(&self) -> Vec<User> {
         let users = self.connected_users.read().await;
         let users = self.connected_users.read().await;
-        users.values().map(|el| el.clone()).collect()
+        users.values().cloned().collect()
     }
     }
 
 
     pub async fn get_user_by_session_id(&self, session_id: u32) -> Option<User> {
     pub async fn get_user_by_session_id(&self, session_id: u32) -> Option<User> {
         let connected_users = self.connected_users.read().await;
         let connected_users = self.connected_users.read().await;
         if let Some(user) = connected_users.get(&session_id) {
         if let Some(user) = connected_users.get(&session_id) {
-            return Some(user.clone())
+            return Some(user.clone());
         }
         }
         None
         None
     }
     }
+
+    pub async fn remove_connected_user(&self, session_id: u32) {
+        let mut connected_users = self.connected_users.write().await;
+        connected_users.remove(&session_id);
+    }
 }
 }
 
 
 impl Clone for User {
 impl Clone for User {

+ 5 - 8
src/protocol.rs

@@ -73,7 +73,7 @@ pub enum VoicePacket {
 
 
 pub enum Error {
 pub enum Error {
     UnknownPacketType,
     UnknownPacketType,
-    ConnectionError,
+    ConnectionError(std::io::Error),
     ParsingError,
     ParsingError,
 }
 }
 
 
@@ -442,17 +442,14 @@ fn serialize_voice_packet(packet: VoicePacket) -> Vec<u8> {
 }
 }
 
 
 impl From<std::io::Error> for Error {
 impl From<std::io::Error> for Error {
-    fn from(_: std::io::Error) -> Self {
-        Error::ConnectionError
+    fn from(err: std::io::Error) -> Self {
+        Error::ConnectionError(err)
     }
     }
 }
 }
 
 
 impl From<ProtobufError> for Error {
 impl From<ProtobufError> for Error {
-    fn from(error: ProtobufError) -> Self {
-        match error {
-            ProtobufError::IoError(_) | ProtobufError::WireError(_) => Error::ConnectionError,
-            ProtobufError::Utf8(_) | ProtobufError::MessageNotInitialized { .. } => Error::ParsingError
-        }
+    fn from(_: ProtobufError) -> Self {
+        Error::ParsingError
     }
     }
 }
 }
 
 

+ 13 - 10
src/server.rs

@@ -1,19 +1,15 @@
-use std::collections::{HashMap, HashSet};
-use std::future::Future;
+use std::collections::HashMap;
 use std::net::{IpAddr, SocketAddr};
 use std::net::{IpAddr, SocketAddr};
 use std::sync::Arc;
 use std::sync::Arc;
-use std::task::{Context, Poll};
 
 
 use tokio::net::{TcpListener, TcpStream};
 use tokio::net::{TcpListener, TcpStream};
-use tokio::sync::{mpsc, RwLock};
-use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
+use tokio::sync::RwLock;
 use tokio_rustls::{server::TlsStream, TlsAcceptor};
 use tokio_rustls::{server::TlsStream, TlsAcceptor};
 use tokio_rustls::rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
 use tokio_rustls::rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
 
 
+use crate::client::{Client, Message, ResponseMessage};
 use crate::connection::{Connection, ConnectionConfig};
 use crate::connection::{Connection, ConnectionConfig};
 use crate::db::Db;
 use crate::db::Db;
-use crate::protocol::MumblePacket;
-use crate::client::{Client, Message};
 
 
 pub struct Config {
 pub struct Config {
     pub ip_address: IpAddr,
     pub ip_address: IpAddr,
@@ -58,7 +54,7 @@ async fn process(db: Arc<Db>, stream: TlsStream<TcpStream>, clients: Clients) {
         max_bandwidth: 128000,
         max_bandwidth: 128000,
         welcome_text: "Welcome!".to_string(),
         welcome_text: "Welcome!".to_string(),
     };
     };
-    let mut connection = match Connection::setup_connection(db.clone(), stream, connection_config).await {
+    let connection = match Connection::setup_connection(db.clone(), stream, connection_config).await {
         Ok(connection) => connection,
         Ok(connection) => connection,
         Err(_) => {
         Err(_) => {
             eprintln!("Error establishing a connection");
             eprintln!("Error establishing a connection");
@@ -71,7 +67,7 @@ async fn process(db: Arc<Db>, stream: TlsStream<TcpStream>, clients: Clients) {
     {
     {
         let mut clients = clients.write().await;
         let mut clients = clients.write().await;
         for client in clients.values() {
         for client in clients.values() {
-            client.post_message(Message::NewUser(session_id))
+            client.post_message(Message::UserConnected(session_id))
         }
         }
         clients.insert(session_id, client);
         clients.insert(session_id, client);
     }
     }
@@ -83,7 +79,14 @@ async fn process(db: Arc<Db>, stream: TlsStream<TcpStream>, clients: Clients) {
         };
         };
 
 
         match message {
         match message {
-            _ => {}
+            ResponseMessage::Disconnected => {
+                let mut clients = clients.write().await;
+                clients.remove(&session_id);
+                for client in clients.values() {
+                    client.post_message(Message::UserDisconnected(session_id))
+                }
+                return;
+            }
         }
         }
     }
     }
 }
 }