diff --git a/src/server.rs b/src/server.rs index 35fb89f..158f717 100644 --- a/src/server.rs +++ b/src/server.rs @@ -6,6 +6,7 @@ use std::collections::HashMap; use std::error::Error; use std::io::{BufRead, BufReader, Write}; use std::ops::Deref; +use std::sync::mpsc::TrySendError; use std::sync::{mpsc, Arc, Mutex, RwLock}; use std::thread; use std::{ @@ -29,16 +30,10 @@ pub fn start_server(port: u16, queue_size: usize, threads: u64) { for _ in 0..threads { let rec = protected_receiver.clone(); let e = edges.clone(); - thread::spawn(move || { - loop { - let socket = rec.lock().unwrap().recv().unwrap(); - match handle_connection(e.deref(), socket) { - Ok(()) => {} - Err(e) => { - // TODO respond to the jsonrpc - println!("Error handling connection: {e}"); - } - } + thread::spawn(move || loop { + let socket = rec.lock().unwrap().recv().unwrap(); + if let Err(e) = handle_connection(e.deref(), socket) { + println!("Error handling connection: {e}"); } }); } @@ -48,7 +43,12 @@ pub fn start_server(port: u16, queue_size: usize, threads: u64) { match listener.accept() { Ok((socket, _)) => match sender.try_send(socket) { Ok(()) => {} - Err(e) => println!("Queue full: {e}"), + Err(TrySendError::Full(mut socket)) => { + let _ = socket.write_all(b"HTTP/1.1 503 Service Unavailable\r\n\r\n"); + } + Err(TrySendError::Disconnected(_)) => { + panic!("Internal communication channel disconnected."); + } }, Err(e) => println!("Error accepting connection: {e}"), } @@ -62,24 +62,35 @@ fn handle_connection( let request = read_request(&mut socket)?; match request.method.as_str() { "load_edges_binary" => { - let updated_edges = read_edges_binary(&request.params["file"].to_string())?; - let len = updated_edges.len(); - *edges.write().unwrap() = Arc::new(updated_edges); - socket.write_all(jsonrpc_response(request.id, len).as_bytes())?; + let response = match load_edges_binary(edges, &request.params["file"].to_string()) { + Ok(len) => jsonrpc_response(request.id, len), + Err(e) => { + jsonrpc_error_response(request.id, -32000, &format!("Error loading edges: {e}")) + } + }; + socket.write_all(response.as_bytes())?; } "compute_transfer" => { println!("Computing flow"); let e = edges.read().unwrap().clone(); compute_transfer(request, e.as_ref(), socket)?; } - "cancel" => {} - "update_edges" => {} - // TODO error handling - _ => {} + _ => socket + .write_all(jsonrpc_error_response(request.id, -32601, "Method not found").as_bytes())?, }; Ok(()) } +fn load_edges_binary( + edges: &RwLock>>>, + file: &String, +) -> Result> { + let updated_edges = read_edges_binary(file)?; + let len = updated_edges.len(); + *edges.write().unwrap() = Arc::new(updated_edges); + Ok(len) +} + fn compute_transfer( request: JsonRpcRequest, edges: &HashMap>, @@ -107,7 +118,6 @@ fn compute_transfer( "Computed flow with max distance {:?}: {}", max_distance, flow ); - // TODO error handling socket.write_all( chunked_response( &(jsonrpc_result( @@ -185,6 +195,23 @@ fn jsonrpc_result(id: JsonValue, result: impl Into) -> String { .dump() } +fn jsonrpc_error_response(id: JsonValue, code: i64, message: &str) -> String { + let payload = json::object! { + jsonrpc: "2.0", + id: id, + error: { + code: code, + message: message + } + } + .dump(); + format!( + "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}", + payload.len(), + payload + ) +} + fn chunked_header() -> String { "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n".to_string() }