diff --git a/src/bin/daemon.rs b/src/bin/daemon.rs index e6eb153..9a1a6a9 100644 --- a/src/bin/daemon.rs +++ b/src/bin/daemon.rs @@ -15,7 +15,7 @@ use std::os::unix::fs::PermissionsExt; use std::{fs, time::Duration}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, - net::UnixListener, + net::{UnixListener, UnixStream}, time::sleep, }; @@ -83,83 +83,92 @@ async fn main() -> Result<()> { async fn commands_loop(listener: UnixListener) -> Result<()> { loop { - let (mut stream, _addr) = listener.accept().await?; + let (stream, _addr) = listener.accept().await?; tokio::spawn(async move { - // ---------- Read request (start) ---------- - let mut len_bytes = [0u8; 4]; - if stream.read_exact(&mut len_bytes).await.is_err() { - eprintln!("Failed to read message length from client!"); - return; - } + handle_connection(stream).await; + }); + } +} - let request_len = u32::from_le_bytes(len_bytes) as usize; +async fn handle_connection(mut stream: UnixStream) { + // ---------- Read request (start) ---------- + let mut len_bytes = [0u8; 4]; + if stream.read_exact(&mut len_bytes).await.is_err() { + eprintln!("Failed to read message length from client!"); + return; + } - if request_len > MAX_MESSAGE_SIZE { - eprintln!( - "Failed to read message from client: request too large ({} bytes)!", - request_len - ); - return; - } + let request_len = u32::from_le_bytes(len_bytes) as usize; - let mut buffer = vec![0u8; request_len]; - if stream.read_exact(&mut buffer).await.is_err() { - eprintln!("Failed to read message from client!"); - return; - } + if request_len > MAX_MESSAGE_SIZE { + eprintln!( + "Failed to read message from client: request too large ({} bytes)!", + request_len + ); + return; + } - let request: Request = match serde_json::from_slice(&buffer) { - Ok(req) => req, - Err(err) => { - let response = - Response::new(false, format!("Failed to parse request: {}", err)); - let response_data = match serde_json::to_vec(&response) { - Ok(data) => data, - Err(_) => return, // Should not happen with this simple Response - }; - let response_len = response_data.len() as u32; - let _ = stream.write_all(&response_len.to_le_bytes()).await; - let _ = stream.write_all(&response_data).await; - return; - } - }; - // ---------- Read request (end) ---------- + let mut buffer = Vec::new(); + if (&mut stream) + .take(request_len as u64) + .read_to_end(&mut buffer) + .await + .is_err() + || buffer.len() != request_len + { + eprintln!("Failed to read message from client!"); + return; + } - // ---------- Generate response (start) ---------- - let command = parse_command(&request); - let response: Response; - if let Some(command) = command { - response = command.execute().await; - } else { - response = Response::new(false, "Unknown command"); - } - // ---------- Generate response (end) ---------- - - // ---------- Send response (start) ---------- + let request: Request = match serde_json::from_slice(&buffer) { + Ok(req) => req, + Err(err) => { + let response = Response::new(false, format!("Failed to parse request: {}", err)); let response_data = match serde_json::to_vec(&response) { Ok(data) => data, - Err(err) => { - eprintln!("Failed to serialize response: {}", err); - return; - } + Err(_) => return, // Should not happen with this simple Response }; let response_len = response_data.len() as u32; + let _ = stream.write_all(&response_len.to_le_bytes()).await; + let _ = stream.write_all(&response_data).await; + return; + } + }; + // ---------- Read request (end) ---------- - if stream.write_all(&response_len.to_le_bytes()).await.is_err() { - eprintln!("Failed to write response length to client!"); - return; - } - if stream.write_all(&response_data).await.is_err() { - eprintln!("Failed to write response to client!"); - return; - } - // ---------- Send response (end) ---------- + // ---------- Generate response (start) ---------- + let command = parse_command(&request); + let response: Response; + if let Some(command) = command { + response = command.execute().await; + } else { + response = Response::new(false, "Unknown command"); + } + // ---------- Generate response (end) ---------- - if response.status && response.message.eq("killed") { - std::process::exit(0); - } - }); + // ---------- Send response (start) ---------- + let response_data = match serde_json::to_vec(&response) { + Ok(data) => data, + Err(err) => { + eprintln!("Failed to serialize response: {}", err); + return; + } + }; + let response_len = response_data.len() as u32; + + if stream.write_all(&response_len.to_le_bytes()).await.is_err() { + eprintln!("Failed to write response length to client!"); + return; + } + if stream.write_all(&response_data).await.is_err() { + eprintln!("Failed to write response to client!"); + return; + } + // ---------- Send response (end) ---------- + + if response.status && response.message.eq("killed") { + std::process::exit(0); } }