Browse Source

Add multiple users support

Sergey Chushin 3 years ago
parent
commit
605dcb5933
7 changed files with 366 additions and 76 deletions
  1. 1 1
      Cargo.toml
  2. 178 0
      src/client.rs
  3. 67 45
      src/connection.rs
  4. 53 20
      src/db.rs
  5. 1 0
      src/main.rs
  6. 15 5
      src/protocol.rs
  7. 51 5
      src/server.rs

+ 1 - 1
Cargo.toml

@@ -8,7 +8,7 @@ edition = "2018"
 
 [dependencies]
 clap = "2.33.3"
-tokio = { version = "1.5.0", features = ["rt-multi-thread", "net", "io-util"] }
+tokio = { version = "1.5.0", features = ["rt-multi-thread", "net", "io-util", "sync"] }
 tokio-rustls = "0.22.0"
 protobuf = "2.22.1"
 sled = "0.34.6"

+ 178 - 0
src/client.rs

@@ -0,0 +1,178 @@
+use std::sync::Arc;
+
+use tokio::io::{AsyncWrite, AsyncRead};
+use tokio::sync::mpsc;
+use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
+
+use crate::connection::Connection;
+use crate::db::{Db, User};
+use crate::protocol::{MumblePacket, VoicePacket, MumblePacketWriter};
+use crate::proto::mumble::{UserState, Ping};
+use tokio::task::JoinHandle;
+
+pub enum Message {
+    NewUser(u32)
+}
+
+pub struct Client {
+    inner_sender: UnboundedSender<InnerMessage>,
+    handler_task: JoinHandle<()>,
+    packet_task: JoinHandle<()>,
+}
+
+pub enum Error {
+    UserNotFound,
+    StreamError(crate::protocol::Error),
+}
+
+struct Handler<W> {
+    db: Arc<Db>,
+    writer: MumblePacketWriter<W>,
+    response_sender: UnboundedSender<Message>,
+}
+
+enum InnerMessage {
+    Message(Message),
+    Packet(MumblePacket),
+}
+
+type ResponseReceiver = UnboundedReceiver<Message>;
+
+impl Client {
+    pub async fn new<S>(connection: Connection<S>, db: Arc<Db>) -> (Client, ResponseReceiver)
+    where
+        S: 'static + AsyncRead + AsyncWrite + Unpin + Send,
+    {
+        let (sender, mut receiver) = mpsc::unbounded_channel();
+        let (response_sender, response_receiver) = mpsc::unbounded_channel();
+
+        let writer = connection.writer;
+        let handler_task = tokio::spawn(async move {
+            let mut handler = Handler {
+                db,
+                writer,
+                response_sender,
+            };
+            loop {
+                let message = match receiver.recv().await {
+                    None => return,
+                    Some(msg) => msg,
+                };
+
+                match message {
+                    InnerMessage::Message(msg) => {
+                        let result = handler.handle_message(msg).await;
+                        if result.is_err() {
+                            return;
+                        }
+                    }
+                    InnerMessage::Packet(packet) => {
+                        let result = handler.handle_packet(packet).await;
+                        if result.is_err() {
+                            return;
+                        }
+                    }
+                }
+            }
+        });
+
+        let inner_sender = sender.clone();
+        let mut reader = connection.reader;
+        let packet_task = tokio::spawn(async move {
+            loop {
+                let packet = match reader.read().await{
+                    Ok(packet) => packet,
+                    Err(_) => return, //TODO
+                };
+
+                sender.send(InnerMessage::Packet(packet));
+            }
+        });
+
+        return (Client {
+            inner_sender,
+            handler_task,
+            packet_task,
+        }, response_receiver);
+    }
+
+    pub fn post_message(&self, message: Message) {
+        self.inner_sender.send(InnerMessage::Message(message));
+    }
+}
+
+impl Drop for Client {
+    fn drop(&mut self) {
+        self.handler_task.abort();
+        self.packet_task.abort();
+    }
+}
+
+impl<W> Handler<W>
+    where
+        W: AsyncWrite + Unpin + Send,
+{
+    async fn handle_packet(&mut self, packet: MumblePacket) -> Result<(), Error> {
+        match packet {
+            MumblePacket::Ping(ping) => {
+                if ping.has_timestamp() {
+                    let mut ping = Ping::new();
+                    ping.set_timestamp(ping.get_timestamp());
+                    self.writer.write(MumblePacket::Ping(ping)).await?;
+                }
+            }
+            MumblePacket::UdpTunnel(voice) => {
+                match voice {
+                    VoicePacket::Ping(_) => {
+                        self.writer.write(MumblePacket::UdpTunnel(voice));
+                        println!("VoicePing");
+                    }
+                    VoicePacket::AudioData(_) => { println!("AudioData"); }
+                }
+            }
+            _ => println!("unimplemented!")
+        }
+        Ok(())
+    }
+
+    async fn handle_message(&mut self, message: Message) -> Result<(), Error> {
+        match message {
+            Message::NewUser(session_id) => self.new_user_connected(session_id).await?,
+        }
+
+        Ok(())
+    }
+
+    async fn new_user_connected(&mut self, session_id: u32) -> Result<(), Error> {
+        if let Some(user) = self.db.get_user_by_session_id(session_id).await {
+            self.writer.write(MumblePacket::from(user)).await?;
+        }
+        Ok(())
+    }
+}
+
+impl From<User> for UserState {
+    fn from(user: User) -> Self {
+        let mut user_state = UserState::new();
+        if let Some(id) = user.id {
+            user_state.set_user_id(id)
+        }
+        user_state.set_name(user.username);
+        user_state.set_channel_id(user.channel_id);
+        user_state.set_session(user.session_id);
+        user_state
+    }
+}
+
+impl From<User> for MumblePacket {
+    fn from(user: User) -> Self {
+        MumblePacket::UserState(UserState::from(user))
+    }
+}
+
+impl From<crate::protocol::Error> for Error {
+    fn from(err: crate::protocol::Error) -> Self {
+        Error::StreamError(err)
+    }
+}
+

+ 67 - 45
src/connection.rs

@@ -1,93 +1,115 @@
 use std::sync::Arc;
 
-use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
 
-use crate::db::{Db, User};
-use crate::proto::mumble::{ChannelState, Ping, UserState, Version};
-use crate::protocol::{MUMBLE_PROTOCOL_VERSION, MumblePacket, MumblePacketStream};
+use crate::db::Db;
+use crate::proto::mumble::{ChannelState, CodecVersion, PermissionQuery, ServerConfig, ServerSync, UserState, Version};
+use crate::protocol::{MUMBLE_PROTOCOL_VERSION, MumblePacket, MumblePacketReader, MumblePacketWriter};
 
 pub struct Connection<S> {
-    db: Arc<Db>,
-    stream: MumblePacketStream<S>,
+    pub reader: MumblePacketReader<ReadHalf<S>>,
+    pub writer: MumblePacketWriter<WriteHalf<S>>,
+    pub session_id: u32,
+}
+
+pub struct ConnectionConfig {
+    pub max_bandwidth: u32,
+    pub welcome_text: String,
 }
 
 pub enum Error {
     ConnectionSetupError,
     AuthenticationError,
+    StreamError(crate::protocol::Error),
 }
 
 impl<S> Connection<S>
     where
         S: AsyncRead + AsyncWrite + Unpin + Send,
 {
-    pub async fn setup_connection(db: Arc<Db>, mut stream: MumblePacketStream<S>) -> Result<Connection<S>, Error> {
-        let _ = match stream.read().await? {
+    pub async fn setup_connection(db: Arc<Db>, stream: S, config: ConnectionConfig) -> Result<Connection<S>, Error> {
+        let (mut reader, mut writer) = crate::protocol::new(stream);
+
+        //Version exchange
+        let _ = match reader.read().await? {
             MumblePacket::Version(version) => version,
             _ => return Err(Error::ConnectionSetupError)
         };
         let mut version = Version::new();
         version.set_version(MUMBLE_PROTOCOL_VERSION);
-        stream.write(MumblePacket::Version(version)).await?;
+        writer.write(MumblePacket::Version(version)).await?;
 
-        let auth = match stream.read().await? {
+        //Authentication
+        let mut auth = match reader.read().await? {
             MumblePacket::Authenticate(auth) => auth,
             _ => return Err(Error::ConnectionSetupError)
         };
         if !auth.has_username() {
             return Err(Error::AuthenticationError);
         }
-        db.add_new_user(User {
-            username: auth.get_username().to_string(),
-            channel_id: 0,
-            is_connected: true,
-        }).await;
+        let session_id = db.add_new_user(auth.take_username()).await;
+
+        //Crypt setup TODO
 
-        //TODO UDP crypt setup
+        //CodecVersion
+        let mut codec_version = CodecVersion::new();
+        codec_version.set_alpha(-2147483632);
+        codec_version.set_beta(0);
+        codec_version.set_prefer_alpha(true);
+        codec_version.set_opus(true);
+        writer.write(MumblePacket::CodecVersion(codec_version)).await?;
 
-        let channels = db.get_channels();
+        //Channel state
+        let channels = db.get_channels().await;
         for channel in channels {
             let mut channel_state = ChannelState::new();
-            channel_state.set_name(channel.name);
             channel_state.set_channel_id(channel.id);
-            stream.write(MumblePacket::ChannelState(channel_state)).await?;
+            channel_state.set_name(channel.name);
+            writer.write(MumblePacket::ChannelState(channel_state)).await?;
         }
 
-        let connected_users = db.get_connected_users();
+        //PermissionQuery
+        let mut permission_query = PermissionQuery::new();
+        permission_query.set_permissions(134743822);
+        permission_query.set_channel_id(0);
+        writer.write(MumblePacket::PermissionQuery(permission_query));
+
+        //User states
+        let connected_users = db.get_connected_users().await;
         for user in connected_users {
             let mut user_state = UserState::new();
             user_state.set_name(user.username);
+            user_state.set_session(user.session_id);
             user_state.set_channel_id(user.channel_id);
-            stream.write(MumblePacket::UserState(user_state)).await?;
+            writer.write(MumblePacket::UserState(user_state)).await?;
         }
 
-        Ok(Connection {
-            db,
-            stream,
-        })
-    }
+        //Server sync
+        let mut server_sync = ServerSync::new();
+        server_sync.set_session(session_id);
+        server_sync.set_welcome_text(config.welcome_text);
+        server_sync.set_max_bandwidth(config.max_bandwidth);
+        server_sync.set_permissions(134743822);
+        writer.write(MumblePacket::ServerSync(server_sync)).await?;
 
-    pub async fn read_packet(&mut self) -> Result<MumblePacket, Error> {
-        Ok(self.stream.read().await?)
-    }
+        //ServerConfig
+        let mut server_config = ServerConfig::new();
+        server_config.set_max_users(10);
+        server_config.set_allow_html(true);
+        server_config.set_message_length(5000);
+        server_config.set_image_message_length(131072);
+        writer.write(MumblePacket::ServerConfig(server_config)).await?;
 
-    pub async fn handle_packet(&mut self, packet: MumblePacket) -> Result<(), Error> {
-        match packet {
-            MumblePacket::Ping(ping) => {
-                if ping.has_timestamp() {
-                    let mut ping = Ping::new();
-                    ping.set_timestamp(ping.get_timestamp());
-                    self.stream.write(MumblePacket::Ping(ping)).await?;
-                }
-            }
-            _ => println!("unimplemented!")
-        }
-        Ok(())
+        Ok(Connection {
+            reader,
+            writer,
+            session_id,
+        })
     }
 }
 
 impl From<crate::protocol::Error> for Error {
-    fn from(_: crate::protocol::Error) -> Self {
-        Error::ConnectionSetupError
+    fn from(err: crate::protocol::Error) -> Self {
+        Error::StreamError(err)
     }
-}
-
+}

+ 53 - 20
src/db.rs

@@ -1,20 +1,34 @@
+use std::collections::HashMap;
+
+use tokio::sync::RwLock;
+
 use serde::{Deserialize, Serialize};
 
+const ROOT_CHANNEL_ID: u32 = 0;
 const USER_TREE_NAME: &[u8] = b"users";
 const CHANNEL_TREE_NAME: &[u8] = b"channels";
-const ROOT_CHANNEL_KEY: &[u8] = &0_u64.to_be_bytes();
+
+type SessionId = u32;
 
 pub struct Db {
     db: sled::Db,
     users: sled::Tree,
     channels: sled::Tree,
+    connected_users: RwLock<HashMap<SessionId, User>>,
 }
 
-#[derive(Serialize, Deserialize)]
 pub struct User {
+    pub id: Option<u32>,
     pub username: String,
     pub channel_id: u32,
-    pub is_connected: bool,
+    pub session_id: SessionId,
+}
+
+#[derive(Serialize, Deserialize)]
+struct PersistentUserData {
+    id: u32,
+    username: String,
+    channel_id: u32,
 }
 
 #[derive(Serialize, Deserialize)]
@@ -34,39 +48,58 @@ impl Db {
             name: "Root".to_string(),
         }).unwrap();
         channels.compare_and_swap(
-            ROOT_CHANNEL_KEY,
+            ROOT_CHANNEL_ID.to_be_bytes(),
             Option::<&[u8]>::None,
             Some(root_channel))
-            .unwrap().unwrap();
+            .unwrap();
 
         Db {
             db,
             users,
             channels,
+            connected_users: RwLock::new(HashMap::new()),
         }
     }
 
-    pub async fn add_new_user(&self, user: User) {
-        let id = self.users.len().to_be_bytes();
-
-        self.users.insert(
-            id,
-            bincode::serialize(&user).unwrap(),
-        ).unwrap();
-
-        self.users.flush_async().await.unwrap();
+    pub async fn add_new_user(&self, username: String) -> u32 {
+        let mut connected_users = self.connected_users.write().await;
+        let session_id = connected_users.len() as u32;
+        connected_users.insert(session_id, User {
+            id: None,
+            username,
+            channel_id: ROOT_CHANNEL_ID,
+            session_id,
+        });
+        session_id
     }
 
-    pub fn get_channels(&self) -> Vec<Channel> {
+    pub async fn get_channels(&self) -> Vec<Channel> {
         self.channels.iter().values()
             .map(|channel| bincode::deserialize(&channel.unwrap()).unwrap())
             .collect()
     }
 
-    pub fn get_connected_users(&self) -> Vec<User> {
-        self.users.iter().values()
-            .map(|user| bincode::deserialize(&user.unwrap()).unwrap())
-            .filter(|user: &User| user.is_connected)
-            .collect()
+    pub async fn get_connected_users(&self) -> Vec<User> {
+        let users = self.connected_users.read().await;
+        users.values().map(|el| el.clone()).collect()
+    }
+
+    pub async fn get_user_by_session_id(&self, session_id: u32) -> Option<User> {
+        let connected_users = self.connected_users.read().await;
+        if let Some(user) = connected_users.get(&session_id) {
+            return Some(user.clone())
+        }
+        None
     }
 }
+
+impl Clone for User {
+    fn clone(&self) -> Self {
+        User {
+            id: self.id,
+            username: self.username.clone(),
+            channel_id: self.channel_id,
+            session_id: self.session_id,
+        }
+    }
+}

+ 1 - 0
src/main.rs

@@ -10,6 +10,7 @@ mod proto;
 mod protocol;
 mod connection;
 mod db;
+mod client;
 
 fn main() {
     let matches = App::new("Rumble")

+ 15 - 5
src/protocol.rs

@@ -1,5 +1,5 @@
 use protobuf::{Message, ProtobufError};
-use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
 
 use crate::proto::mumble::{ACL as Acl, Authenticate, BanList, ChannelRemove, ChannelState,
                            CodecVersion, ContextAction, ContextActionModify, CryptSetup, PermissionDenied,
@@ -105,20 +105,24 @@ enum Codecs {
     Opus,
 }
 
-pub fn new<S, R, W>(stream: S) -> (MumblePacketReader<R>, MumblePacketWriter<W>)
+pub fn new<S>(stream: S) -> (MumblePacketReader<ReadHalf<S>>, MumblePacketWriter<WriteHalf<S>>)
     where
         S: AsyncRead + AsyncWrite + Unpin + Send,
-        R: AsyncRead + Unpin + Send,
-        W: AsyncWrite + Unpin + Send,
 {
     let (reader, writer) = tokio::io::split(stream);
-    (MumblePacketReader { reader }, MumblePacketWriter { writer })
+    (MumblePacketReader::new(reader), MumblePacketWriter::new(writer))
 }
 
 impl<R> MumblePacketReader<R>
     where
         R: AsyncRead + Unpin + Send,
 {
+    pub fn new(reader: R) -> Self {
+        MumblePacketReader {
+            reader
+        }
+    }
+
     pub async fn read(&mut self) -> Result<MumblePacket, Error> {
         let packet_type = self.reader.read_u16().await?;
         let payload_length = self.reader.read_u32().await?;
@@ -286,6 +290,12 @@ impl<W> MumblePacketWriter<W>
     where
         W: AsyncWrite + Unpin + Send,
 {
+    pub fn new(writer: W) -> Self {
+        MumblePacketWriter {
+            writer
+        }
+    }
+
     pub async fn write(&mut self, packet: MumblePacket) -> Result<(), Error> {
         match packet {
             MumblePacket::UdpTunnel(value) => {

+ 51 - 5
src/server.rs

@@ -1,14 +1,19 @@
+use std::collections::{HashMap, HashSet};
+use std::future::Future;
 use std::net::{IpAddr, SocketAddr};
 use std::sync::Arc;
+use std::task::{Context, Poll};
 
 use tokio::net::{TcpListener, TcpStream};
+use tokio::sync::{mpsc, RwLock};
+use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
 use tokio_rustls::{server::TlsStream, TlsAcceptor};
 use tokio_rustls::rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
 
-use crate::connection::Connection;
+use crate::connection::{Connection, ConnectionConfig};
 use crate::db::Db;
-use crate::proto::mumble::Version;
-use crate::protocol::MumblePacketStream;
+use crate::protocol::MumblePacket;
+use crate::client::{Client, Message};
 
 pub struct Config {
     pub ip_address: IpAddr,
@@ -18,8 +23,11 @@ pub struct Config {
     pub path_to_db_file: String,
 }
 
+type Clients = Arc<RwLock<HashMap<u32, Client>>>;
+
 pub async fn run(config: Config) -> std::io::Result<()> {
     let db = Arc::new(Db::open(&config.path_to_db_file));
+
     let mut tls_config = ServerConfig::new(NoClientAuth::new());
     tls_config.set_single_cert(vec![config.certificate], config.private_key)
         .expect("Invalid private key");
@@ -29,18 +37,56 @@ pub async fn run(config: Config) -> std::io::Result<()> {
         SocketAddr::new(config.ip_address, config.port)
     ).await?;
 
+    let clients = Arc::new(RwLock::new(HashMap::new()));
     loop {
         let (stream, _) = listener.accept().await?;
         let acceptor = acceptor.clone();
         let db = Arc::clone(&db);
+        let clients = Arc::clone(&clients);
 
         tokio::spawn(async move {
             let stream = acceptor.accept(stream).await;
             if let Ok(stream) = stream {
-                process(db, MumblePacketStream::new(stream)).await;
+                process(db, stream, clients).await;
             }
         });
     }
 }
 
-async fn process(db: Arc<Db>, stream: MumblePacketStream<TlsStream<TcpStream>>) {}
+async fn process(db: Arc<Db>, stream: TlsStream<TcpStream>, clients: Clients) {
+    let connection_config = ConnectionConfig {
+        max_bandwidth: 128000,
+        welcome_text: "Welcome!".to_string(),
+    };
+    let mut connection = match Connection::setup_connection(db.clone(), stream, connection_config).await {
+        Ok(connection) => connection,
+        Err(_) => {
+            eprintln!("Error establishing a connection");
+            return;
+        }
+    };
+    let session_id = connection.session_id;
+    let (client, mut response_receiver) = Client::new(connection, db).await;
+
+    {
+        let mut clients = clients.write().await;
+        for client in clients.values() {
+            client.post_message(Message::NewUser(session_id))
+        }
+        clients.insert(session_id, client);
+    }
+
+    loop {
+        let message = match response_receiver.recv().await {
+            Some(msg) => msg,
+            None => return,
+        };
+
+        match message {
+            _ => {}
+        }
+    }
+}
+
+
+