summary refs log tree commit diff
diff options
context:
space:
mode:
authorPaweł Dybiec <pawel@dybiec.info>2022-12-27 17:47:24 +0000
committerPaweł Dybiec <pawel@dybiec.info>2022-12-27 17:50:29 +0000
commitc23b1ecb97ee16d4b2ecb2d9a97b39778d30f2ef (patch)
treeaee90b17c0e4274ba4f9afbf6d67eac0d60fba76
parentHandle shutdown gracefully (diff)
format
-rw-r--r--src/main.rs124
-rw-r--r--src/mattermost/client.rs7
2 files changed, 72 insertions, 59 deletions
diff --git a/src/main.rs b/src/main.rs
index da396c2..bc3b8b7 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -6,25 +6,28 @@ use tokio::{self};
 use tracing::{debug, warn};
 
 struct Vav {
-    db_connection: Option<sqlite::ConnectionWithFullMutex>
+    db_connection: Option<sqlite::ConnectionWithFullMutex>,
 }
-impl Vav{
-    fn new<T:AsRef<Path>>(path:Option<T>) -> Self {
-        let ret = Self{
-            db_connection:path.and_then(|path|sqlite::Connection::open_with_full_mutex(path).ok()),
+impl Vav {
+    fn new<T: AsRef<Path>>(path: Option<T>) -> Self {
+        let ret = Self {
+            db_connection: path
+                .and_then(|path| sqlite::Connection::open_with_full_mutex(path).ok()),
         };
-        if let Some(connection) = &ret.db_connection{
-            let create_table_res=connection.execute("CREATE TABLE keyval (key text NOT NULL PRIMARY KEY, value text NOT NULL)");
-            if let Err(err) = create_table_res{
+        if let Some(connection) = &ret.db_connection {
+            let create_table_res = connection.execute(
+                "CREATE TABLE keyval (key text NOT NULL PRIMARY KEY, value text NOT NULL)",
+            );
+            if let Err(err) = create_table_res {
                 warn!("Error while creating db {err}");
-            }      
+            }
         }
         ret
     }
 }
-const INSERT_STATEMENT:&str = "INSERT INTO keyval (key,value) VALUES (?,?)";
-const SELECT_ONE:&str = "SELECT key,value FROM keyval where key=?";
-const SELECT_ALL:&str = "SELECT key,value FROM keyval";
+const INSERT_STATEMENT: &str = "INSERT INTO keyval (key,value) VALUES (?,?)";
+const SELECT_ONE: &str = "SELECT key,value FROM keyval where key=?";
+const SELECT_ALL: &str = "SELECT key,value FROM keyval";
 #[async_trait::async_trait]
 impl mattermost::Handler for Vav {
     async fn handle(
@@ -42,52 +45,56 @@ impl mattermost::Handler for Vav {
         if !message.starts_with('!') {
             return Ok(());
         }
-        let message=message.strip_prefix('!').unwrap();
-        let (message,rest) = message.split_once(' ').unwrap_or((message,""));
-        match message{
-            "store" =>{
+        let message = message.strip_prefix('!').unwrap();
+        let (message, rest) = message.split_once(' ').unwrap_or((message, ""));
+        match message {
+            "store" => {
                 let message = rest;
-                let (name,value) =message.split_once(' ').ok_or(anyhow::anyhow!("missing value in command store"))?;
-                if let Some(connection) = &self.db_connection{
+                let (name, value) = message
+                    .split_once(' ')
+                    .ok_or(anyhow::anyhow!("missing value in command store"))?;
+                if let Some(connection) = &self.db_connection {
                     let mut statement = connection.prepare(INSERT_STATEMENT)?;
-                    statement.bind((1,name))?;
-                    statement.bind((2,value))?;
-                    if let Err(err) = statement.next(){
+                    statement.bind((1, name))?;
+                    statement.bind((2, value))?;
+                    if let Err(err) = statement.next() {
                         warn!("Error while writing to db {err}");
-                    }      
+                    }
                 }
-        
             }
-            "lookup" =>{
+            "lookup" => {
                 let name = rest;
-                let response =if let Some(connection) = &self.db_connection{
+                let response = if let Some(connection) = &self.db_connection {
                     let mut statement = connection.prepare(SELECT_ONE)?;
-                    statement.bind((1,name))?;
-                    match statement.next(){
-                        Ok(sqlite::State::Done) => {
-                            "no entry under that name".to_owned()
-                        },
-                        Ok(sqlite::State::Row) => {
-                            statement.read::<String,_>(1)?
-                        },
+                    statement.bind((1, name))?;
+                    match statement.next() {
+                        Ok(sqlite::State::Done) => "no entry under that name".to_owned(),
+                        Ok(sqlite::State::Row) => statement.read::<String, _>(1)?,
                         Err(err) => {
-                        warn!("Error while writing to db {err}");
-                        return Err(err.into())},
+                            warn!("Error while writing to db {err}");
+                            return Err(err.into());
+                        }
                     }
-                } else {"uggh, no db".to_owned()};
-                client.reply_to(posted.post,response).await?;
+                } else {
+                    "uggh, no db".to_owned()
+                };
+                client.reply_to(posted.post, response).await?;
             }
-            "list" =>{
-                let response =if let Some(connection) = &self.db_connection{
+            "list" => {
+                let response = if let Some(connection) = &self.db_connection {
                     let mut statement = connection.prepare(SELECT_ALL)?;
-                    let mut res =vec!["Stored keys:".to_owned()];
-                    while let Ok(result) =  statement.next(){
-                            if result == sqlite::State::Done{ break;}
-                            res.push(statement.read::<String,_>(0)?)
+                    let mut res = vec!["Stored keys:".to_owned()];
+                    while let Ok(result) = statement.next() {
+                        if result == sqlite::State::Done {
+                            break;
+                        }
+                        res.push(statement.read::<String, _>(0)?)
                     }
                     res.join("\n")
-                } else {"uggh, no db".to_owned()};
-                client.reply_to(posted.post,response).await?;
+                } else {
+                    "uggh, no db".to_owned()
+                };
+                client.reply_to(posted.post, response).await?;
             }
             _ => return Err(anyhow!("Unrecognized command {message}")),
         }
@@ -105,18 +112,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let mut client = mattermost::Client::new(auth, "https://mattermost.continuum.ii.uni.wroc.pl");
     client.update_bearer_token().await?;
     {
-    let (shutdown_send, shutdown_recv) = tokio::sync::mpsc::unbounded_channel();
-    let websocket_task =tokio::spawn( async move {
-    client.handle_websocket_stream(Vav::new(db),shutdown_recv).await});
-    match tokio::signal::ctrl_c().await {
-        Ok(()) => {},
-        Err(err) => {
-            eprintln!("Unable to listen for shutdown signal: {}", err);
-            // we also shut down in case of error
-        },
-    }
-    shutdown_send.send(())?;
-    websocket_task.await??;
+        let (shutdown_send, shutdown_recv) = tokio::sync::mpsc::unbounded_channel();
+        let websocket_task = tokio::spawn(async move {
+            client
+                .handle_websocket_stream(Vav::new(db), shutdown_recv)
+                .await
+        });
+        match tokio::signal::ctrl_c().await {
+            Ok(()) => {}
+            Err(err) => {
+                eprintln!("Unable to listen for shutdown signal: {}", err);
+                // we also shut down in case of error
+            }
+        }
+        shutdown_send.send(())?;
+        websocket_task.await??;
     }
     Ok(())
 }
diff --git a/src/mattermost/client.rs b/src/mattermost/client.rs
index 6025970..3d59837 100644
--- a/src/mattermost/client.rs
+++ b/src/mattermost/client.rs
@@ -56,7 +56,10 @@ impl Client {
     }
     async fn get_working_ws_stream(
         &self,
-    ) -> Result<async_tungstenite::WebSocketStream<async_tungstenite::tokio::ConnectStream>, anyhow::Error> {
+    ) -> Result<
+        async_tungstenite::WebSocketStream<async_tungstenite::tokio::ConnectStream>,
+        anyhow::Error,
+    > {
         let url = format!("{}/api/v4/websocket", self.url.replacen("http", "ws", 1));
         let token = self
             .bearer_token
@@ -76,7 +79,7 @@ impl Client {
         handler: T,
         mut shutdown: tokio::sync::mpsc::UnboundedReceiver<()>,
     ) -> Result<(), anyhow::Error> {
-        let mut ws_stream=self.get_working_ws_stream().await?;
+        let mut ws_stream = self.get_working_ws_stream().await?;
         loop {
             tokio::select! {
                 message = ws_stream.next() => {