From f0d0353c732277baf84723bd1660fca2c1546d64 Mon Sep 17 00:00:00 2001
From: jaensen <4954577+jaensen@users.noreply.github.com>
Date: Wed, 17 May 2023 18:52:02 +0200
Subject: [PATCH] add validation for the 'from' an 'to' params of the
 'compute_transfer' function.

---
 Cargo.toml    |  1 +
 src/server.rs | 61 ++++++++++++++++++++++++++++++++++-----------------
 2 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index e2ae8bb..3634df7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,3 +12,4 @@ json = "^0.12.4"
 num-bigint = "^0.4.3"
 serde = { version = "1.0.149", features = ["serde_derive"] }
 serde_json = "1.0.89"
+regex = "1.8.1"
diff --git a/src/server.rs b/src/server.rs
index cafc5d9..8ab7c10 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -4,6 +4,7 @@ use crate::types::edge::EdgeDB;
 use crate::types::{Address, Edge, U256};
 use json::JsonValue;
 use num_bigint::BigUint;
+use regex::Regex;
 use std::error::Error;
 use std::fmt::{Debug, Display, Formatter};
 use std::io::Read;
@@ -35,6 +36,38 @@ impl Display for InputValidationError {
     }
 }
 
+fn validate_and_parse_ethereum_address(address: &str) -> Result<Address, Box<dyn Error>> {
+    let re = Regex::new(r"^0x[0-9a-fA-F]{40}$").unwrap();
+    if re.is_match(address) {
+        Ok(Address::from(address))
+    } else {
+        Err(Box::new(InputValidationError(format!(
+            "Invalid Ethereum address: {}",
+            address
+        ))))
+    }
+}
+
+fn validate_and_parse_u256(value_str: &str) -> Result<U256, Box<dyn Error>> {
+    match BigUint::from_str(value_str) {
+        Ok(parsed_value) => {
+            if parsed_value > U256::MAX.into() {
+                Err(Box::new(InputValidationError(format!(
+                    "Value {} is too large. Maximum value is {}.",
+                    parsed_value,
+                    U256::MAX
+                ))))
+            } else {
+                Ok(U256::from_bigint_truncating(parsed_value))
+            }
+        }
+        Err(e) => Err(Box::new(InputValidationError(format!(
+            "Invalid value: {}. Couldn't parse value: {}",
+            value_str, e
+        )))),
+    }
+}
+
 pub fn start_server(listen_at: &str, queue_size: usize, threads: u64) {
     let edges: Arc<RwLock<Arc<EdgeDB>>> = Arc::new(RwLock::new(Arc::new(EdgeDB::default())));
 
@@ -156,38 +189,26 @@ fn compute_transfer(
     socket.write_all(chunked_header().as_bytes())?;
 
     let parsed_value_param = match request.params["value"].as_str() {
-        Some(value_str) => match BigUint::from_str(value_str) {
-            Ok(parsed_value) => parsed_value,
-            Err(e) => {
-                return Err(Box::new(InputValidationError(format!(
-                    "Invalid value: {}. Couldn't parse value: {}",
-                    value_str, e
-                ))));
-            }
-        },
-        None => U256::MAX.into(),
+        Some(value_str) => validate_and_parse_u256(value_str)?,
+        None => U256::MAX,
     };
 
-    if parsed_value_param > U256::MAX.into() {
-        return Err(Box::new(InputValidationError(format!(
-            "Value {} is too large. Maximum value is {}.",
-            parsed_value_param,
-            U256::MAX
-        ))));
-    }
+    let from_address = validate_and_parse_ethereum_address(&request.params["from"].to_string())?;
+    let to_address = validate_and_parse_ethereum_address(&request.params["to"].to_string())?;
 
     let max_distances = if request.params["iterative"].as_bool().unwrap_or_default() {
         vec![Some(1), Some(2), None]
     } else {
         vec![None]
     };
+
     let max_transfers = request.params["max_transfers"].as_u64();
     for max_distance in max_distances {
         let (flow, transfers) = graph::compute_flow(
-            &Address::from(request.params["from"].to_string().as_str()),
-            &Address::from(request.params["to"].to_string().as_str()),
+            &from_address,
+            &to_address,
             edges,
-            U256::from_bigint_truncating(parsed_value_param.clone()),
+            parsed_value_param,
             max_distance,
             max_transfers,
         );