diff --git a/Cargo.lock b/Cargo.lock index 77e118d2..acda4b25 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2233,7 +2233,9 @@ dependencies = [ "prost-build", "prost-reflect", "prost-reflect-build", - "prost-types", + "prost-wkt", + "prost-wkt-build", + "prost-wkt-types", "quinn", "quinn-plaintext", "rand 0.8.5", @@ -3382,8 +3384,8 @@ dependencies = [ "aho-corasick", "bstr", "log", - "regex-automata 0.4.7", - "regex-syntax 0.8.4", + "regex-automata 0.4.14", + "regex-syntax 0.8.10", ] [[package]] @@ -4093,7 +4095,7 @@ dependencies = [ "globset", "log", "memchr", - "regex-automata 0.4.7", + "regex-automata 0.4.14", "same-file", "walkdir", "winapi-util", @@ -4203,6 +4205,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "inventory" +version = "0.3.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "009ae045c87e7082cb72dab0ccd01ae075dd00141ddc108f43a0ea150a9e7227" +dependencies = [ + "rustversion", +] + [[package]] name = "ip_network" version = "0.4.1" @@ -6558,9 +6569,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.13.2" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b2ecbe40f08db5c006b5764a2645f7f3f141ce756412ac9e1dd6087e6d32995" +checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" dependencies = [ "bytes", "prost-derive", @@ -6568,11 +6579,10 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.13.2" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8650aabb6c35b860610e9cff5dc1af886c9e25073b7b1712a68972af4281302" +checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" dependencies = [ - "bytes", "heck 0.5.0", "itertools 0.12.1", "log", @@ -6589,9 +6599,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.13.2" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acf0c195eebb4af52c752bec4f52f645da98b6e92077a04110c7f349477ae5ac" +checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", "itertools 0.12.1", @@ -6635,13 +6645,59 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.13.2" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60caa6738c7369b940c3d49246a8d1749323674c65cb13010134f5c9bad5b519" +checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" dependencies = [ "prost", ] +[[package]] +name = "prost-wkt" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "497e1e938f0c09ef9cabe1d49437b4016e03e8f82fbbe5d1c62a9b61b9decae1" +dependencies = [ + "chrono", + "inventory", + "prost", + "serde", + "serde_derive", + "serde_json", + "typetag", +] + +[[package]] +name = "prost-wkt-build" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07b8bf115b70a7aa5af1fd5d6e9418492e9ccb6e4785e858c938e28d132a884b" +dependencies = [ + "heck 0.5.0", + "prost", + "prost-build", + "prost-types", + "quote", +] + +[[package]] +name = "prost-wkt-types" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8cdde6df0a98311c839392ca2f2f0bcecd545f86a62b4e3c6a49c336e970fe5" +dependencies = [ + "chrono", + "prost", + "prost-build", + "prost-types", + "prost-wkt", + "prost-wkt-build", + "regex", + "serde", + "serde_derive", + "serde_json", +] + [[package]] name = "ptr_meta" version = "0.1.4" @@ -6750,9 +6806,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.40" +version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" dependencies = [ "proc-macro2", ] @@ -6975,14 +7031,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.6" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.7", - "regex-syntax 0.8.4", + "regex-automata 0.4.14", + "regex-syntax 0.8.10", ] [[package]] @@ -6996,13 +7052,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.4", + "regex-syntax 0.8.10", ] [[package]] @@ -7013,9 +7069,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" [[package]] name = "relative-path" @@ -9982,6 +10038,30 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "typetag" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be2212c8a9b9bcfca32024de14998494cf9a5dfa59ea1b829de98bac374b86bf" +dependencies = [ + "erased-serde", + "inventory", + "once_cell", + "serde", + "typetag-impl", +] + +[[package]] +name = "typetag-impl" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27a7a9b72ba121f6f1f6c3632b85604cac41aedb5ddc70accbebb6cac83de846" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "uds_windows" version = "1.1.0" diff --git a/easytier-rpc-build/src/lib.rs b/easytier-rpc-build/src/lib.rs index d2162578..f41f1e18 100644 --- a/easytier-rpc-build/src/lib.rs +++ b/easytier-rpc-build/src/lib.rs @@ -41,6 +41,8 @@ impl prost_build::ServiceGenerator for ServiceGenerator { let mut match_output_type_methods = String::new(); let mut match_output_proto_type_methods = String::new(); let mut match_handle_methods = String::new(); + // generate trait default method Xxx::json_call_method match branch + let mut match_trait_json_methods = String::new(); let mut match_method_try_from = String::new(); @@ -180,6 +182,22 @@ impl prost_build::ServiceGenerator for ServiceGenerator { namespace = NAMESPACE, ) .unwrap(); + + write!( + match_trait_json_methods, + r#" "{name}" | "{proto_name}" => {{ + let req: {input_type} = ::serde_json::from_value(json).map_err(|e| {namespace}::error::Error::MalformatRpcPacket(format!("json error: {{}}", e)))?; + let resp = self.{typed_method}(ctrl, req).await?; + Ok(::serde_json::to_value(resp).map_err(|e| {namespace}::error::Error::MalformatRpcPacket(format!("json error: {{}}", e)))?) + }} +"#, + name = method.name, + proto_name = method.proto_name, + input_type = method.input_type, + typed_method = method.name, + namespace = NAMESPACE, + ) + .unwrap(); } ServiceGenerator::write_comments(&mut buf, 0, &service.comments).unwrap(); @@ -192,6 +210,18 @@ pub trait {name} {{ type Controller: {namespace}::controller::Controller; {trait_methods} + + async fn json_call_method( + &self, + ctrl: Self::Controller, + method_name: &str, + json: ::serde_json::Value, + ) -> {namespace}::error::Result<::serde_json::Value> {{ + match method_name {{ +{match_trait_json_methods} + _ => Err({namespace}::error::Error::InvalidMethodIndex(0, method_name.to_string())), + }} + }} }} #[async_trait::async_trait] @@ -262,7 +292,7 @@ impl Clone for {client_name}Factory { impl {namespace}::__rt::RpcClientFactory for {client_name}Factory where C: {namespace}::controller::Controller {{ type Descriptor = {descriptor_name}; - type ClientImpl = Box + Send + 'static>; + type ClientImpl = Box + Send + Sync + 'static>; type Controller = C; fn new(handler: impl {namespace}::handler::Handler) -> Self::ClientImpl {{ @@ -394,6 +424,7 @@ impl {namespace}::descriptor::MethodDescriptor for {method_descriptor_name} {{ match_output_type_methods = match_output_type_methods, match_output_proto_type_methods = match_output_proto_type_methods, match_handle_methods = match_handle_methods, + match_trait_json_methods = match_trait_json_methods, namespace = NAMESPACE, ).unwrap(); } diff --git a/easytier-web/src/client_manager/session.rs b/easytier-web/src/client_manager/session.rs index 40b3980e..7d88da86 100644 --- a/easytier-web/src/client_manager/session.rs +++ b/easytier-web/src/client_manager/session.rs @@ -339,10 +339,14 @@ impl Session { self.data.clone() } - pub fn scoped_rpc_client(&self) -> SessionRpcClient { + pub fn scoped_client(&self) -> F::ClientImpl { self.rpc_mgr .rpc_client() - .scoped_client::>(1, 1, "".to_string()) + .scoped_client::(1, 1, "".to_string()) + } + + pub fn scoped_rpc_client(&self) -> SessionRpcClient { + self.scoped_client::>() } pub async fn get_token(&self) -> Option { diff --git a/easytier-web/src/restful/mod.rs b/easytier-web/src/restful/mod.rs index d6933db9..a92c4dab 100644 --- a/easytier-web/src/restful/mod.rs +++ b/easytier-web/src/restful/mod.rs @@ -2,6 +2,7 @@ mod auth; pub(crate) mod captcha; mod network; pub(crate) mod oidc; +mod rpc; mod users; use std::{net::SocketAddr, sync::Arc}; @@ -248,6 +249,7 @@ impl RestfulServer { .route("/api/v1/summary", get(Self::handle_get_summary)) .route("/api/v1/sessions", get(Self::handle_list_all_sessions)) .merge(NetworkApi::build_route()) + .merge(rpc::router()) .route_layer(login_required!(Backend)) .merge(auth::router().layer(Extension(self.feature_flags.clone()))) .merge(oidc::router()) diff --git a/easytier-web/src/restful/rpc.rs b/easytier-web/src/restful/rpc.rs new file mode 100644 index 00000000..f635114d --- /dev/null +++ b/easytier-web/src/restful/rpc.rs @@ -0,0 +1,175 @@ +use axum::{ + extract::{Path, State}, + http::StatusCode, + routing::post, + Json, Router, +}; +use axum_login::AuthUser as _; + +use super::{other_error, AppState, HttpHandleError}; + +#[derive(Debug, serde::Deserialize)] +pub struct ProxyRpcRequest { + pub service_name: String, + pub method_name: String, + pub payload: serde_json::Value, +} + +macro_rules! match_service { + ($factory:ty, $method_name:expr, $payload:expr, $session:expr) => {{ + let client = $session.scoped_client::<$factory>(); + client + .json_call_method( + easytier::proto::rpc_types::controller::BaseController::default(), + &$method_name, + $payload, + ) + .await + }}; +} + +pub async fn handle_proxy_rpc( + auth_session: super::users::AuthSession, + State(client_mgr): AppState, + Path(machine_id): Path, + Json(req): Json, +) -> Result, HttpHandleError> { + let user_id = auth_session + .user + .as_ref() + .ok_or((StatusCode::UNAUTHORIZED, other_error("Unauthorized").into()))? + .id(); + + let session = client_mgr + .get_session_by_machine_id(user_id, &machine_id) + .ok_or(( + StatusCode::NOT_FOUND, + other_error("Session not found").into(), + ))?; + + let ProxyRpcRequest { + service_name, + method_name, + payload, + } = req; + + let resp = match service_name.as_str() { + "api.manage.WebClientService" => match_service!( + easytier::proto::api::manage::WebClientServiceClientFactory< + easytier::proto::rpc_types::controller::BaseController, + >, + method_name, + payload, + session + ), + "api.instance.PeerManageRpcService" => match_service!( + easytier::proto::api::instance::PeerManageRpcClientFactory< + easytier::proto::rpc_types::controller::BaseController, + >, + method_name, + payload, + session + ), + "api.instance.ConnectorManageRpcService" => match_service!( + easytier::proto::api::instance::ConnectorManageRpcClientFactory< + easytier::proto::rpc_types::controller::BaseController, + >, + method_name, + payload, + session + ), + "api.instance.MappedListenerManageRpcService" => match_service!( + easytier::proto::api::instance::MappedListenerManageRpcClientFactory< + easytier::proto::rpc_types::controller::BaseController, + >, + method_name, + payload, + session + ), + "api.instance.VpnPortalRpcService" => match_service!( + easytier::proto::api::instance::VpnPortalRpcClientFactory< + easytier::proto::rpc_types::controller::BaseController, + >, + method_name, + payload, + session + ), + "api.instance.TcpProxyRpcService" => match_service!( + easytier::proto::api::instance::TcpProxyRpcClientFactory< + easytier::proto::rpc_types::controller::BaseController, + >, + method_name, + payload, + session + ), + "api.instance.AclManageRpcService" => match_service!( + easytier::proto::api::instance::AclManageRpcClientFactory< + easytier::proto::rpc_types::controller::BaseController, + >, + method_name, + payload, + session + ), + "api.instance.PortForwardManageRpcService" => match_service!( + easytier::proto::api::instance::PortForwardManageRpcClientFactory< + easytier::proto::rpc_types::controller::BaseController, + >, + method_name, + payload, + session + ), + "api.instance.StatsRpcService" => match_service!( + easytier::proto::api::instance::StatsRpcClientFactory< + easytier::proto::rpc_types::controller::BaseController, + >, + method_name, + payload, + session + ), + "api.instance.CredentialManageRpcService" => match_service!( + easytier::proto::api::instance::CredentialManageRpcClientFactory< + easytier::proto::rpc_types::controller::BaseController, + >, + method_name, + payload, + session + ), + "api.logger.LoggerRpcService" => match_service!( + easytier::proto::api::logger::LoggerRpcClientFactory< + easytier::proto::rpc_types::controller::BaseController, + >, + method_name, + payload, + session + ), + "api.config.ConfigRpcService" => match_service!( + easytier::proto::api::config::ConfigRpcClientFactory< + easytier::proto::rpc_types::controller::BaseController, + >, + method_name, + payload, + session + ), + _ => { + return Err(( + StatusCode::BAD_REQUEST, + other_error(format!("Unknown service: {}", service_name)).into(), + )) + } + }; + + match resp { + Ok(v) => Ok(Json(v)), + Err(e) => Err(( + StatusCode::INTERNAL_SERVER_ERROR, + other_error(format!("RPC Error: {:?}", e)).into(), + )), + } +} + +pub fn router() -> Router { + Router::new().route( + "/api/v1/machines/:machine-id/proxy-rpc", + post(handle_proxy_rpc), + ) +} diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 0fc57154..68fbb868 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -116,8 +116,9 @@ uuid = { version = "1.5.0", features = [ once_cell = "1.18.0" # for rpc -prost = "0.13" -prost-types = "0.13" +prost = "0.13.5" +prost-wkt = "0.6" +prost-wkt-types = "0.6" anyhow = "1.0" url = { version = "2.5", features = ["serde"] } @@ -308,7 +309,8 @@ jemalloc-sys = { package = "tikv-jemalloc-sys", version = "0.6.0", features = [ tonic-build = "0.12" globwalk = "0.8.1" regex = "1" -prost-build = "0.13.2" +prost-build = "0.13.5" +prost-wkt-build = "0.6" easytier-rpc-build = { path = "../easytier-rpc-build", features = [ "internal-namespace", ] } diff --git a/easytier/build.rs b/easytier/build.rs index 363a70ef..ac7bedc8 100644 --- a/easytier/build.rs +++ b/easytier/build.rs @@ -1,5 +1,8 @@ #[cfg(target_os = "windows")] -use std::{env, io::Cursor, path::PathBuf}; +use std::io::Cursor; +use std::{env, path::PathBuf}; + +use prost_wkt_build::{FileDescriptorSet, Message as _}; #[cfg(target_os = "windows")] struct WindowsBuild {} @@ -157,30 +160,25 @@ fn main() -> Result<(), Box> { println!("cargo:rerun-if-changed={proto_file}"); } + let out = PathBuf::from(env::var("OUT_DIR").unwrap()); + let descriptor_file = out.join("descriptors.bin"); + let mut config = prost_build::Config::new(); config + .type_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]") + .extern_path(".google.protobuf.Any", "::prost_wkt_types::Any") + .extern_path(".google.protobuf.Timestamp", "::prost_wkt_types::Timestamp") + .extern_path(".google.protobuf.Value", "::prost_wkt_types::Value") + .file_descriptor_set_path(&descriptor_file) .protoc_arg("--experimental_allow_proto3_optional") - .type_attribute(".acl", "#[derive(serde::Serialize, serde::Deserialize)]") - .type_attribute(".common", "#[derive(serde::Serialize, serde::Deserialize)]") - .type_attribute(".error", "#[derive(serde::Serialize, serde::Deserialize)]") - .type_attribute(".api", "#[derive(serde::Serialize, serde::Deserialize)]") - .type_attribute(".web", "#[derive(serde::Serialize, serde::Deserialize)]") - .type_attribute(".config", "#[derive(serde::Serialize, serde::Deserialize)]") - .type_attribute( - "peer_rpc.GetIpListResponse", - "#[derive(serde::Serialize, serde::Deserialize)]", - ) .type_attribute("peer_rpc.DirectConnectedPeerInfo", "#[derive(Hash)]") .type_attribute("peer_rpc.PeerInfoForGlobalMap", "#[derive(Hash)]") .type_attribute("peer_rpc.ForeignNetworkRouteInfoKey", "#[derive(Hash, Eq)]") .type_attribute( "peer_rpc.RouteForeignNetworkSummary.Info", - "#[derive(Hash, Eq, serde::Serialize, serde::Deserialize)]", - ) - .type_attribute( - "peer_rpc.RouteForeignNetworkSummary", - "#[derive(Hash, Eq, serde::Serialize, serde::Deserialize)]", + "#[derive(Hash, Eq)]", ) + .type_attribute("peer_rpc.RouteForeignNetworkSummary", "#[derive(Hash, Eq)]") .type_attribute("common.RpcDescriptor", "#[derive(Hash, Eq)]") .field_attribute(".api.manage.NetworkConfig", "#[serde(default)]") .service_generator(Box::new(easytier_rpc_build::ServiceGenerator::default())) @@ -193,6 +191,10 @@ fn main() -> Result<(), Box> { .file_descriptor_set_bytes("crate::proto::DESCRIPTOR_POOL_BYTES") .compile_protos_with_config(config, &proto_files_reflect, &["src/proto/"])?; + let descriptor_bytes = std::fs::read(descriptor_file).unwrap(); + let descriptor = FileDescriptorSet::decode(&descriptor_bytes[..]).unwrap(); + prost_wkt_build::add_serde(out, descriptor); + check_locale(); Ok(()) } diff --git a/easytier/src/connector/udp_hole_punch/sym_to_cone.rs b/easytier/src/connector/udp_hole_punch/sym_to_cone.rs index ddc11401..c66bc531 100644 --- a/easytier/src/connector/udp_hole_punch/sym_to_cone.rs +++ b/easytier/src/connector/udp_hole_punch/sym_to_cone.rs @@ -342,7 +342,8 @@ impl PunchSymToConeHoleClient { async fn get_rpc_stub( &self, dst_peer_id: PeerId, - ) -> Box + std::marker::Send + 'static> { + ) -> Box + std::marker::Send + Sync + 'static> + { self.peer_mgr .get_peer_rpc_mgr() .rpc_client() diff --git a/easytier/src/proto/api.rs b/easytier/src/proto/api.rs index c9cdb768..0750c3cc 100644 --- a/easytier/src/proto/api.rs +++ b/easytier/src/proto/api.rs @@ -260,3 +260,99 @@ pub mod logger { pub mod manage { include!(concat!(env!("OUT_DIR"), "/api.manage.rs")); } + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use prost::Message; + + use super::manage::{ + ListNetworkInstanceRequest, ListNetworkInstanceResponse, WebClientService, + WebClientServiceClient, WebClientServiceDescriptor, WebClientServiceMethodDescriptor, + }; + use crate::proto::common::Uuid; + use crate::proto::rpc_types::controller::BaseController; + use crate::proto::rpc_types::descriptor::ServiceDescriptor; + use crate::proto::rpc_types::error::Error; + use crate::proto::rpc_types::handler::Handler; + + #[derive(Clone, Default)] + struct WebClientServiceJsonCallHandler; + + #[async_trait::async_trait] + impl Handler for WebClientServiceJsonCallHandler { + type Descriptor = WebClientServiceDescriptor; + type Controller = BaseController; + + async fn call( + &self, + _ctrl: Self::Controller, + method: ::Method, + input: Bytes, + ) -> crate::proto::rpc_types::error::Result { + match method { + WebClientServiceMethodDescriptor::ListNetworkInstance => { + let _req = ListNetworkInstanceRequest::decode(input.as_ref()).unwrap(); + let resp = ListNetworkInstanceResponse { + inst_ids: vec![Uuid { + part1: 1, + part2: 2, + part3: 3, + part4: 4, + }], + }; + Ok(Bytes::from(resp.encode_to_vec())) + } + _ => Err(Error::ExecutionError(anyhow::anyhow!( + "unsupported method in test handler" + ))), + } + } + } + + #[tokio::test] + async fn web_client_service_call_json_method_supports_snake_and_proto_method_name() { + let client = WebClientServiceClient::new(WebClientServiceJsonCallHandler); + + let snake_result = client + .json_call_method( + BaseController::default(), + "list_network_instance", + serde_json::json!({}), + ) + .await + .unwrap(); + assert_eq!( + snake_result["inst_ids"][0], + serde_json::json!({ + "part1": 1, + "part2": 2, + "part3": 3, + "part4": 4 + }) + ); + + let proto_result = client + .json_call_method( + BaseController::default(), + "ListNetworkInstance", + serde_json::json!({}), + ) + .await + .unwrap(); + assert_eq!(proto_result["inst_ids"].as_array().unwrap().len(), 1); + } + + #[tokio::test] + async fn web_client_service_call_json_method_rejects_unknown_method() { + let client = WebClientServiceClient::new(WebClientServiceJsonCallHandler); + let ret = client + .json_call_method( + BaseController::default(), + "not_exist_method", + serde_json::json!({}), + ) + .await; + assert!(ret.is_err()); + } +} diff --git a/easytier/src/proto/tests.rs b/easytier/src/proto/tests.rs index b161a11e..8d9379e7 100644 --- a/easytier/src/proto/tests.rs +++ b/easytier/src/proto/tests.rs @@ -7,6 +7,101 @@ use tokio::task::JoinSet; use super::rpc_impl::RpcController; +#[derive(Clone, Default)] +struct GreetingJsonCallHandler; + +#[async_trait::async_trait] +impl crate::proto::rpc_types::handler::Handler for GreetingJsonCallHandler { + type Descriptor = GreetingDescriptor; + type Controller = crate::proto::rpc_types::controller::BaseController; + + async fn call( + &self, + _ctrl: Self::Controller, + method: ::Method, + input: bytes::Bytes, + ) -> crate::proto::rpc_types::error::Result { + use prost::Message; + match method { + GreetingMethodDescriptor::SayHello => { + let req = SayHelloRequest::decode(input)?; + let resp = SayHelloResponse { + greeting: format!("Hello {}!", req.name), + }; + Ok(bytes::Bytes::from(resp.encode_to_vec())) + } + GreetingMethodDescriptor::SayGoodbye => { + let req = SayGoodbyeRequest::decode(input)?; + let resp = SayGoodbyeResponse { + greeting: format!("Goodbye, {}!", req.name), + }; + Ok(bytes::Bytes::from(resp.encode_to_vec())) + } + } + } +} + +#[tokio::test] +async fn greeting_client_json_call_method_supports_snake_and_proto_method_name() { + let client = GreetingClient::new(GreetingJsonCallHandler); + + let snake = client + .json_call_method( + crate::proto::rpc_types::controller::BaseController::default(), + "say_hello", + serde_json::json!({"name": "world"}), + ) + .await + .unwrap(); + assert_eq!(snake["greeting"], serde_json::json!("Hello world!")); + + let proto = client + .json_call_method( + crate::proto::rpc_types::controller::BaseController::default(), + "SayHello", + serde_json::json!({"name": "world"}), + ) + .await + .unwrap(); + assert_eq!(proto["greeting"], serde_json::json!("Hello world!")); +} + +#[tokio::test] +async fn greeting_client_json_call_method_rejects_invalid_json() { + let client = GreetingClient::new(GreetingJsonCallHandler); + + let err = client + .json_call_method( + crate::proto::rpc_types::controller::BaseController::default(), + "say_hello", + serde_json::json!({"name": 123}), + ) + .await + .unwrap_err(); + assert!(matches!( + err, + crate::proto::rpc_types::error::Error::MalformatRpcPacket(_) + )); +} + +#[tokio::test] +async fn greeting_client_json_call_method_rejects_unknown_method() { + let client = GreetingClient::new(GreetingJsonCallHandler); + + let err = client + .json_call_method( + crate::proto::rpc_types::controller::BaseController::default(), + "not_exist_method", + serde_json::json!({"name": "world"}), + ) + .await + .unwrap_err(); + assert!(matches!( + err, + crate::proto::rpc_types::error::Error::InvalidMethodIndex(0, _) + )); +} + #[derive(Clone)] pub struct GreetingService { pub delay_ms: u64, diff --git a/easytier/src/rpc_service/api.rs b/easytier/src/rpc_service/api.rs index 23d2149e..e414e9e9 100644 --- a/easytier/src/rpc_service/api.rs +++ b/easytier/src/rpc_service/api.rs @@ -31,7 +31,7 @@ use crate::{ stats::StatsRpcService, vpn_portal::VpnPortalRpcService, }, tunnel::{tcp::TcpTunnelListener, TunnelListener}, - web_client::DefaultHooks, + web_client::{DefaultHooks, WebClientHooks}, }; pub struct ApiRpcServer { @@ -64,7 +64,7 @@ impl ApiRpcServer { impl ApiRpcServer { pub fn from_tunnel(tunnel: T, instance_manager: Arc) -> Self { let rpc_server = StandAloneServer::new(tunnel); - register_api_rpc_service(&instance_manager, rpc_server.registry()); + register_api_rpc_service(&instance_manager, rpc_server.registry(), None); Self { rpc_server } } } @@ -87,9 +87,10 @@ impl Drop for ApiRpcServer { } } -fn register_api_rpc_service( +pub fn register_api_rpc_service( instance_manager: &Arc, registry: &ServiceRegistry, + hooks: Option>, ) { registry.register( PeerManageRpcServer::new(PeerManageRpcService::new(instance_manager.clone())), @@ -148,7 +149,7 @@ fn register_api_rpc_service( registry.register( WebClientServiceServer::new(InstanceManageRpcService::new( instance_manager.clone(), - Arc::new(DefaultHooks), + hooks.unwrap_or(Arc::new(DefaultHooks)), )), "", ); diff --git a/easytier/src/rpc_service/mod.rs b/easytier/src/rpc_service/mod.rs index e06d05ab..cd6ca67e 100644 --- a/easytier/src/rpc_service/mod.rs +++ b/easytier/src/rpc_service/mod.rs @@ -1,5 +1,4 @@ mod acl_manage; -mod api; mod config; mod connector_manage; mod credential_manage; @@ -11,6 +10,7 @@ mod proxy; mod stats; mod vpn_portal; +pub mod api; pub mod instance_manage; pub mod logger; pub mod remote_client; diff --git a/easytier/src/web_client/controller.rs b/easytier/src/web_client/controller.rs index 84b3c057..c8c4c4b8 100644 --- a/easytier/src/web_client/controller.rs +++ b/easytier/src/web_client/controller.rs @@ -1,8 +1,8 @@ use std::sync::Arc; use crate::{ - instance_manager::NetworkInstanceManager, - rpc_service::instance_manage::InstanceManageRpcService, web_client::WebClientHooks, + instance_manager::NetworkInstanceManager, proto::rpc_impl::service_registry::ServiceRegistry, + rpc_service::api::register_api_rpc_service, web_client::WebClientHooks, }; pub struct Controller { @@ -39,8 +39,8 @@ impl Controller { self.hostname.clone() } - pub fn get_rpc_service(&self) -> InstanceManageRpcService { - InstanceManageRpcService::new(self.manager.clone(), self.hooks.clone()) + pub fn register_api_rpc_service(&self, registry: &ServiceRegistry) { + register_api_rpc_service(&self.manager, registry, Some(self.hooks.clone())); } pub(super) fn notify_manager_stopping(&self) { diff --git a/easytier/src/web_client/session.rs b/easytier/src/web_client/session.rs index c4e0d251..75de2df1 100644 --- a/easytier/src/web_client/session.rs +++ b/easytier/src/web_client/session.rs @@ -9,7 +9,6 @@ use tokio::{ use crate::{ common::{constants::EASYTIER_VERSION, get_machine_id}, proto::{ - api::manage::WebClientServiceServer, rpc_impl::bidirect::BidirectRpcManager, rpc_types::controller::BaseController, web::{ @@ -43,10 +42,7 @@ impl Session { let rpc_mgr = BidirectRpcManager::new(); rpc_mgr.run_with_tunnel(tunnel); - rpc_mgr.rpc_server().registry().register( - WebClientServiceServer::new(controller.get_rpc_service()), - "", - ); + controller.register_api_rpc_service(rpc_mgr.rpc_server().registry()); let (tx, _rx1) = broadcast::channel(2); let heartbeat_ctx = HeartbeatCtx {