chore: update Rust to 2024 edition (#2066)

This commit is contained in:
Luna Yao
2026-04-09 18:22:12 +02:00
committed by GitHub
parent a8feb9ac2b
commit a879dd1b14
158 changed files with 1327 additions and 1231 deletions
+4
View File
@@ -14,6 +14,10 @@ exclude = [
"easytier-contrib/easytier-ohrs", # it needs ohrs sdk "easytier-contrib/easytier-ohrs", # it needs ohrs sdk
] ]
[workspace.package]
edition = "2024"
rust-version = "1.93.0"
[profile.dev] [profile.dev]
panic = "unwind" panic = "unwind"
debug = 2 debug = 2
@@ -1,7 +1,7 @@
[package] [package]
name = "easytier-android-jni" name = "easytier-android-jni"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition.workspace = true
[lib] [lib]
crate-type = ["cdylib"] crate-type = ["cdylib"]
@@ -1,7 +1,7 @@
use easytier::proto::api::manage::{NetworkInstanceRunningInfo, NetworkInstanceRunningInfoMap}; use easytier::proto::api::manage::{NetworkInstanceRunningInfo, NetworkInstanceRunningInfoMap};
use jni::JNIEnv;
use jni::objects::{JClass, JObjectArray, JString}; use jni::objects::{JClass, JObjectArray, JString};
use jni::sys::{jint, jstring}; use jni::sys::{jint, jstring};
use jni::JNIEnv;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use std::ffi::{CStr, CString}; use std::ffi::{CStr, CString};
use std::ptr; use std::ptr;
@@ -15,7 +15,7 @@ pub struct KeyValuePair {
} }
// 声明外部 C 函数 // 声明外部 C 函数
extern "C" { unsafe extern "C" {
fn set_tun_fd(inst_name: *const std::ffi::c_char, fd: std::ffi::c_int) -> std::ffi::c_int; fn set_tun_fd(inst_name: *const std::ffi::c_char, fd: std::ffi::c_int) -> std::ffi::c_int;
fn get_error_msg(out: *mut *const std::ffi::c_char); fn get_error_msg(out: *mut *const std::ffi::c_char);
fn free_string(s: *const std::ffi::c_char); fn free_string(s: *const std::ffi::c_char);
@@ -68,7 +68,7 @@ fn throw_exception(env: &mut JNIEnv, message: &str) {
} }
/// 设置 TUN 文件描述符 /// 设置 TUN 文件描述符
#[no_mangle] #[unsafe(no_mangle)]
pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_setTunFd( pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_setTunFd(
mut env: JNIEnv, mut env: JNIEnv,
_class: JClass, _class: JClass,
@@ -87,17 +87,17 @@ pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_setTunFd(
unsafe { unsafe {
let result = set_tun_fd(inst_name_cstr.as_ptr(), fd); let result = set_tun_fd(inst_name_cstr.as_ptr(), fd);
if result != 0 { if result != 0
if let Some(error) = get_last_error() { && let Some(error) = get_last_error()
throw_exception(&mut env, &error); {
} throw_exception(&mut env, &error);
} }
result result
} }
} }
/// 解析配置 /// 解析配置
#[no_mangle] #[unsafe(no_mangle)]
pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_parseConfig( pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_parseConfig(
mut env: JNIEnv, mut env: JNIEnv,
_class: JClass, _class: JClass,
@@ -115,17 +115,17 @@ pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_parseConfig(
unsafe { unsafe {
let result = parse_config(config_cstr.as_ptr()); let result = parse_config(config_cstr.as_ptr());
if result != 0 { if result != 0
if let Some(error) = get_last_error() { && let Some(error) = get_last_error()
throw_exception(&mut env, &error); {
} throw_exception(&mut env, &error);
} }
result result
} }
} }
/// 运行网络实例 /// 运行网络实例
#[no_mangle] #[unsafe(no_mangle)]
pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_runNetworkInstance( pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_runNetworkInstance(
mut env: JNIEnv, mut env: JNIEnv,
_class: JClass, _class: JClass,
@@ -143,17 +143,17 @@ pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_runNetworkInstance(
unsafe { unsafe {
let result = run_network_instance(config_cstr.as_ptr()); let result = run_network_instance(config_cstr.as_ptr());
if result != 0 { if result != 0
if let Some(error) = get_last_error() { && let Some(error) = get_last_error()
throw_exception(&mut env, &error); {
} throw_exception(&mut env, &error);
} }
result result
} }
} }
/// 保持网络实例 /// 保持网络实例
#[no_mangle] #[unsafe(no_mangle)]
pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_retainNetworkInstance( pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_retainNetworkInstance(
mut env: JNIEnv, mut env: JNIEnv,
_class: JClass, _class: JClass,
@@ -165,10 +165,10 @@ pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_retainNetworkInstance(
if instance_names.is_null() { if instance_names.is_null() {
unsafe { unsafe {
let result = retain_network_instance(ptr::null(), 0); let result = retain_network_instance(ptr::null(), 0);
if result != 0 { if result != 0
if let Some(error) = get_last_error() { && let Some(error) = get_last_error()
throw_exception(&mut env, &error); {
} throw_exception(&mut env, &error);
} }
return result; return result;
} }
@@ -187,10 +187,10 @@ pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_retainNetworkInstance(
if array_length == 0 { if array_length == 0 {
unsafe { unsafe {
let result = retain_network_instance(ptr::null(), 0); let result = retain_network_instance(ptr::null(), 0);
if result != 0 { if result != 0
if let Some(error) = get_last_error() { && let Some(error) = get_last_error()
throw_exception(&mut env, &error); {
} throw_exception(&mut env, &error);
} }
return result; return result;
} }
@@ -234,17 +234,17 @@ pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_retainNetworkInstance(
unsafe { unsafe {
let result = retain_network_instance(c_string_ptrs.as_ptr(), c_string_ptrs.len()); let result = retain_network_instance(c_string_ptrs.as_ptr(), c_string_ptrs.len());
if result != 0 { if result != 0
if let Some(error) = get_last_error() { && let Some(error) = get_last_error()
throw_exception(&mut env, &error); {
} throw_exception(&mut env, &error);
} }
result result
} }
} }
/// 收集网络信息 /// 收集网络信息
#[no_mangle] #[unsafe(no_mangle)]
pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_collectNetworkInfos( pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_collectNetworkInfos(
mut env: JNIEnv, mut env: JNIEnv,
_class: JClass, _class: JClass,
@@ -304,7 +304,7 @@ pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_collectNetworkInfos(
} }
/// 获取最后的错误信息 /// 获取最后的错误信息
#[no_mangle] #[unsafe(no_mangle)]
pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_getLastError( pub extern "system" fn Java_com_easytier_jni_EasyTierJNI_getLastError(
env: JNIEnv, env: JNIEnv,
_class: JClass, _class: JClass,
+1 -1
View File
@@ -1,7 +1,7 @@
[package] [package]
name = "easytier-ffi" name = "easytier-ffi"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition.workspace = true
[lib] [lib]
crate-type = ["cdylib"] crate-type = ["cdylib"]
+7 -7
View File
@@ -30,7 +30,7 @@ fn set_error_msg(msg: &str) {
/// # Safety /// # Safety
/// Set the tun fd /// Set the tun fd
#[no_mangle] #[unsafe(no_mangle)]
pub unsafe extern "C" fn set_tun_fd( pub unsafe extern "C" fn set_tun_fd(
inst_name: *const std::ffi::c_char, inst_name: *const std::ffi::c_char,
fd: std::ffi::c_int, fd: std::ffi::c_int,
@@ -59,7 +59,7 @@ pub unsafe extern "C" fn set_tun_fd(
/// # Safety /// # Safety
/// Get the last error message /// Get the last error message
#[no_mangle] #[unsafe(no_mangle)]
pub unsafe extern "C" fn get_error_msg(out: *mut *const std::ffi::c_char) { pub unsafe extern "C" fn get_error_msg(out: *mut *const std::ffi::c_char) {
let msg_buf = ERROR_MSG.lock().unwrap(); let msg_buf = ERROR_MSG.lock().unwrap();
if msg_buf.is_empty() { if msg_buf.is_empty() {
@@ -74,7 +74,7 @@ pub unsafe extern "C" fn get_error_msg(out: *mut *const std::ffi::c_char) {
} }
} }
#[no_mangle] #[unsafe(no_mangle)]
pub extern "C" fn free_string(s: *const std::ffi::c_char) { pub extern "C" fn free_string(s: *const std::ffi::c_char) {
if s.is_null() { if s.is_null() {
return; return;
@@ -86,7 +86,7 @@ pub extern "C" fn free_string(s: *const std::ffi::c_char) {
/// # Safety /// # Safety
/// Parse the config /// Parse the config
#[no_mangle] #[unsafe(no_mangle)]
pub unsafe extern "C" fn parse_config(cfg_str: *const std::ffi::c_char) -> std::ffi::c_int { pub unsafe extern "C" fn parse_config(cfg_str: *const std::ffi::c_char) -> std::ffi::c_int {
let cfg_str = unsafe { let cfg_str = unsafe {
assert!(!cfg_str.is_null()); assert!(!cfg_str.is_null());
@@ -105,7 +105,7 @@ pub unsafe extern "C" fn parse_config(cfg_str: *const std::ffi::c_char) -> std::
/// # Safety /// # Safety
/// Run the network instance /// Run the network instance
#[no_mangle] #[unsafe(no_mangle)]
pub unsafe extern "C" fn run_network_instance(cfg_str: *const std::ffi::c_char) -> std::ffi::c_int { pub unsafe extern "C" fn run_network_instance(cfg_str: *const std::ffi::c_char) -> std::ffi::c_int {
let cfg_str = unsafe { let cfg_str = unsafe {
assert!(!cfg_str.is_null()); assert!(!cfg_str.is_null());
@@ -144,7 +144,7 @@ pub unsafe extern "C" fn run_network_instance(cfg_str: *const std::ffi::c_char)
/// # Safety /// # Safety
/// Retain the network instance /// Retain the network instance
#[no_mangle] #[unsafe(no_mangle)]
pub unsafe extern "C" fn retain_network_instance( pub unsafe extern "C" fn retain_network_instance(
inst_names: *const *const std::ffi::c_char, inst_names: *const *const std::ffi::c_char,
length: usize, length: usize,
@@ -188,7 +188,7 @@ pub unsafe extern "C" fn retain_network_instance(
/// # Safety /// # Safety
/// Collect the network infos /// Collect the network infos
#[no_mangle] #[unsafe(no_mangle)]
pub unsafe extern "C" fn collect_network_infos( pub unsafe extern "C" fn collect_network_infos(
infos: *mut KeyValuePair, infos: *mut KeyValuePair,
max_length: usize, max_length: usize,
+1 -1
View File
@@ -1,7 +1,7 @@
[package] [package]
name = "easytier-uptime" name = "easytier-uptime"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition.workspace = true
[dependencies] [dependencies]
tokio = { version = "1.0", features = ["full"] } tokio = { version = "1.0", features = ["full"] }
@@ -1,7 +1,7 @@
use std::ops::{Div, Mul}; use std::ops::{Div, Mul};
use axum::extract::{Path, State};
use axum::Json; use axum::Json;
use axum::extract::{Path, State};
use sea_orm::{ use sea_orm::{
ColumnTrait, Condition, EntityTrait, IntoActiveModel, ModelTrait, Order, PaginatorTrait, ColumnTrait, Condition, EntityTrait, IntoActiveModel, ModelTrait, Order, PaginatorTrait,
QueryFilter, QueryOrder, QuerySelect, Set, TryIntoModel, QueryFilter, QueryOrder, QuerySelect, Set, TryIntoModel,
@@ -14,7 +14,7 @@ use crate::api::{
models::*, models::*,
}; };
use crate::db::entity::{self, health_records, shared_nodes}; use crate::db::entity::{self, health_records, shared_nodes};
use crate::db::{operations::*, Db}; use crate::db::{Db, operations::*};
use crate::health_checker_manager::HealthCheckerManager; use crate::health_checker_manager::HealthCheckerManager;
use axum_extra::extract::Query; use axum_extra::extract::Query;
use std::sync::Arc; use std::sync::Arc;
@@ -273,7 +273,7 @@ pub struct InstanceFilterParams {
use crate::config::AppConfig; use crate::config::AppConfig;
use axum::http::{HeaderMap, StatusCode}; use axum::http::{HeaderMap, StatusCode};
use chrono::{Duration, Utc}; use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::Serialize; use serde::Serialize;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@@ -370,19 +370,19 @@ pub async fn admin_get_nodes(
let ids = NodeOperations::filter_node_ids_by_tag(&app_state.db, &tag).await?; let ids = NodeOperations::filter_node_ids_by_tag(&app_state.db, &tag).await?;
filtered_ids = Some(ids); filtered_ids = Some(ids);
} }
if let Some(tags) = filters.tags { if let Some(tags) = filters.tags
if !tags.is_empty() { && !tags.is_empty()
let ids_any = NodeOperations::filter_node_ids_by_tags_any(&app_state.db, &tags).await?; {
filtered_ids = match filtered_ids { let ids_any = NodeOperations::filter_node_ids_by_tags_any(&app_state.db, &tags).await?;
Some(mut existing) => { filtered_ids = match filtered_ids {
existing.extend(ids_any); Some(mut existing) => {
existing.sort(); existing.extend(ids_any);
existing.dedup(); existing.sort();
Some(existing) existing.dedup();
} Some(existing)
None => Some(ids_any), }
}; None => Some(ids_any),
} };
} }
if let Some(ids) = filtered_ids { if let Some(ids) = filtered_ids {
if ids.is_empty() { if ids.is_empty() {
@@ -1,5 +1,5 @@
use axum::routing::{delete, get, post, put};
use axum::Router; use axum::Router;
use axum::routing::{delete, get, post, put};
use tower_http::compression::CompressionLayer; use tower_http::compression::CompressionLayer;
use tower_http::cors::CorsLayer; use tower_http::cors::CorsLayer;
@@ -1,7 +1,7 @@
use crate::db::entity::*;
use crate::db::Db; use crate::db::Db;
use crate::db::entity::*;
use sea_orm::*; use sea_orm::*;
use tokio::time::{sleep, Duration}; use tokio::time::{Duration, sleep};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
/// 数据清理策略配置 /// 数据清理策略配置
@@ -5,12 +5,12 @@ pub mod operations;
use std::fmt; use std::fmt;
use sea_orm::{ use sea_orm::{
prelude::*, sea_query::OnConflict, ColumnTrait as _, DatabaseConnection, DbErr, EntityTrait, ColumnTrait as _, DatabaseConnection, DbErr, EntityTrait, QueryFilter as _, Set,
QueryFilter as _, Set, SqlxSqliteConnector, Statement, TransactionTrait as _, SqlxSqliteConnector, Statement, TransactionTrait as _, prelude::*, sea_query::OnConflict,
}; };
use sea_orm_migration::MigratorTrait as _; use sea_orm_migration::MigratorTrait as _;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{migrate::MigrateDatabase as _, Sqlite, SqlitePool}; use sqlx::{Sqlite, SqlitePool, migrate::MigrateDatabase as _};
use crate::migrator; use crate::migrator;
@@ -1,8 +1,8 @@
use crate::api::CreateNodeRequest; use crate::api::CreateNodeRequest;
use crate::db::entity::*;
use crate::db::Db; use crate::db::Db;
use crate::db::HealthStats; use crate::db::HealthStats;
use crate::db::HealthStatus; use crate::db::HealthStatus;
use crate::db::entity::*;
use sea_orm::*; use sea_orm::*;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
@@ -19,9 +19,9 @@ use sqlx::any;
use tracing::{debug, error, info, instrument, warn}; use tracing::{debug, error, info, instrument, warn};
use crate::db::{ use crate::db::{
Db, HealthStatus,
entity::shared_nodes, entity::shared_nodes,
operations::{HealthOperations, NodeOperations}, operations::{HealthOperations, NodeOperations},
Db, HealthStatus,
}; };
pub struct HealthCheckOneNode { pub struct HealthCheckOneNode {
@@ -1,11 +1,11 @@
use std::{collections::HashSet, sync::Arc, time::Duration}; use std::{collections::HashSet, sync::Arc, time::Duration};
use anyhow::Context as _; use anyhow::Context as _;
use tokio::time::{interval, Interval}; use tokio::time::{Interval, interval};
use tracing::{error, info}; use tracing::{error, info};
use crate::{ use crate::{
db::{entity::shared_nodes, operations::NodeOperations, Db}, db::{Db, entity::shared_nodes, operations::NodeOperations},
health_checker::HealthChecker, health_checker::HealthChecker,
}; };
+4 -2
View File
@@ -10,7 +10,7 @@ mod migrator;
use api::routes::create_routes; use api::routes::create_routes;
use clap::Parser; use clap::Parser;
use config::AppConfig; use config::AppConfig;
use db::{operations::NodeOperations, Db}; use db::{Db, operations::NodeOperations};
use easytier::common::log; use easytier::common::log;
use health_checker::HealthChecker; use health_checker::HealthChecker;
use health_checker_manager::HealthCheckerManager; use health_checker_manager::HealthCheckerManager;
@@ -49,7 +49,9 @@ async fn main() -> anyhow::Result<()> {
// 如果提供了管理员密码,设置环境变量 // 如果提供了管理员密码,设置环境变量
if let Some(password) = args.admin_password { if let Some(password) = args.admin_password {
env::set_var("ADMIN_PASSWORD", password); unsafe {
env::set_var("ADMIN_PASSWORD", password);
}
} }
tracing::info!( tracing::info!(
+1 -1
View File
@@ -3,7 +3,7 @@ name = "easytier-gui"
version = "2.6.0" version = "2.6.0"
description = "EasyTier GUI" description = "EasyTier GUI"
authors = ["you"] authors = ["you"]
edition = "2021" edition.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+1 -1
View File
@@ -4,7 +4,7 @@
*--------------------------------------------------------------------------------------------*/ *--------------------------------------------------------------------------------------------*/
use super::Command; use super::Command;
use anyhow::{anyhow, Result}; use anyhow::{Result, anyhow};
use std::env; use std::env;
use std::ffi::OsStr; use std::ffi::OsStr;
use std::process::{Command as StdCommand, Output}; use std::process::{Command as StdCommand, Output};
+2 -2
View File
@@ -30,10 +30,10 @@ use std::os::unix::process::ExitStatusExt;
use std::path::Path; use std::path::Path;
use std::ptr; use std::ptr;
use libc::{fileno, wait, EINTR, SHUT_WR}; use libc::{EINTR, SHUT_WR, fileno, wait};
use security_framework_sys::authorization::{ use security_framework_sys::authorization::{
errAuthorizationSuccess, kAuthorizationFlagDefaults, kAuthorizationFlagDestroyRights,
AuthorizationCreate, AuthorizationExecuteWithPrivileges, AuthorizationFree, AuthorizationRef, AuthorizationCreate, AuthorizationExecuteWithPrivileges, AuthorizationFree, AuthorizationRef,
errAuthorizationSuccess, kAuthorizationFlagDefaults, kAuthorizationFlagDestroyRights,
}; };
const ENV_PATH: &str = "PATH"; const ENV_PATH: &str = "PATH";
@@ -11,11 +11,11 @@ use std::process::{ExitStatus, Output};
use winapi::shared::minwindef::{DWORD, LPVOID}; use winapi::shared::minwindef::{DWORD, LPVOID};
use winapi::um::processthreadsapi::{GetCurrentProcess, OpenProcessToken}; use winapi::um::processthreadsapi::{GetCurrentProcess, OpenProcessToken};
use winapi::um::securitybaseapi::GetTokenInformation; use winapi::um::securitybaseapi::GetTokenInformation;
use winapi::um::winnt::{TokenElevation, HANDLE, TOKEN_ELEVATION, TOKEN_QUERY}; use winapi::um::winnt::{HANDLE, TOKEN_ELEVATION, TOKEN_QUERY, TokenElevation};
use windows::core::{w, HSTRING, PCWSTR};
use windows::Win32::Foundation::HWND; use windows::Win32::Foundation::HWND;
use windows::Win32::UI::Shell::ShellExecuteW; use windows::Win32::UI::Shell::ShellExecuteW;
use windows::Win32::UI::WindowsAndMessaging::SW_HIDE; use windows::Win32::UI::WindowsAndMessaging::SW_HIDE;
use windows::core::{HSTRING, PCWSTR, w};
/// The implementation of state check and elevated executing varies on each platform /// The implementation of state check and elevated executing varies on each platform
impl Command { impl Command {
+34 -34
View File
@@ -21,9 +21,9 @@ use easytier::{
instance_manager::NetworkInstanceManager, instance_manager::NetworkInstanceManager,
launcher::NetworkConfig, launcher::NetworkConfig,
rpc_service::ApiRpcServer, rpc_service::ApiRpcServer,
tunnel::TunnelListener,
tunnel::ring::RingTunnelListener, tunnel::ring::RingTunnelListener,
tunnel::tcp::TcpTunnelListener, tunnel::tcp::TcpTunnelListener,
tunnel::TunnelListener,
utils::{self}, utils::{self},
}; };
use std::ops::Deref; use std::ops::Deref;
@@ -559,10 +559,10 @@ fn toggle_window_visibility(app: &tauri::AppHandle) {
} }
fn get_exe_path() -> String { fn get_exe_path() -> String {
if let Ok(appimage_path) = std::env::var("APPIMAGE") { if let Ok(appimage_path) = std::env::var("APPIMAGE")
if !appimage_path.is_empty() { && !appimage_path.is_empty()
return appimage_path; {
} return appimage_path;
} }
std::env::current_exe() std::env::current_exe()
.map(|p| p.to_string_lossy().to_string()) .map(|p| p.to_string_lossy().to_string())
@@ -596,8 +596,8 @@ mod manager {
use easytier::proto::rpc_types::controller::BaseController; use easytier::proto::rpc_types::controller::BaseController;
use easytier::rpc_service::logger::LoggerRpcService; use easytier::rpc_service::logger::LoggerRpcService;
use easytier::rpc_service::remote_client::PersistentConfig; use easytier::rpc_service::remote_client::PersistentConfig;
use easytier::tunnel::ring::RingTunnelConnector;
use easytier::tunnel::TunnelConnector; use easytier::tunnel::TunnelConnector;
use easytier::tunnel::ring::RingTunnelConnector;
use easytier::web_client::WebClientHooks; use easytier::web_client::WebClientHooks;
pub(super) struct GuiHooks { pub(super) struct GuiHooks {
@@ -979,34 +979,34 @@ mod manager {
.get_rpc_client(app.clone()) .get_rpc_client(app.clone())
.ok_or_else(|| anyhow::anyhow!("RPC client not found"))?; .ok_or_else(|| anyhow::anyhow!("RPC client not found"))?;
for id in enabled_networks { for id in enabled_networks {
if let Ok(uuid) = id.parse() { if let Ok(uuid) = id.parse()
if !self.storage.enabled_networks.contains(&uuid) { && !self.storage.enabled_networks.contains(&uuid)
let config = self {
.storage let config = self
.network_configs .storage
.get(&uuid) .network_configs
.map(|i| i.value().1.clone()); .get(&uuid)
let Some(config) = config else { .map(|i| i.value().1.clone());
continue; let Some(config) = config else {
}; continue;
let toml_config = config.gen_config()?; };
self.pre_run_network_instance_hook(&app, &toml_config) let toml_config = config.gen_config()?;
.await self.pre_run_network_instance_hook(&app, &toml_config)
.map_err(|e| anyhow::anyhow!(e))?; .await
client .map_err(|e| anyhow::anyhow!(e))?;
.run_network_instance( client
BaseController::default(), .run_network_instance(
RunNetworkInstanceRequest { BaseController::default(),
inst_id: None, RunNetworkInstanceRequest {
config: Some(config), inst_id: None,
overwrite: false, config: Some(config),
}, overwrite: false,
) },
.await?; )
self.post_run_network_instance_hook(&app, &uuid) .await?;
.await self.post_run_network_instance_hook(&app, &uuid)
.map_err(|e| anyhow::anyhow!(e))?; .await
} .map_err(|e| anyhow::anyhow!(e))?;
} }
} }
Ok(()) Ok(())
+1 -2
View File
@@ -2,13 +2,12 @@
name = "easytier-rpc-build" name = "easytier-rpc-build"
description = "Protobuf RPC Service Generator for EasyTier" description = "Protobuf RPC Service Generator for EasyTier"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition.workspace = true
homepage = "https://github.com/EasyTier/EasyTier" homepage = "https://github.com/EasyTier/EasyTier"
repository = "https://github.com/EasyTier/EasyTier" repository = "https://github.com/EasyTier/EasyTier"
authors = ["kkrainbow"] authors = ["kkrainbow"]
keywords = ["vpn", "p2p", "network", "easytier"] keywords = ["vpn", "p2p", "network", "easytier"]
categories = ["network-programming", "command-line-utilities"] categories = ["network-programming", "command-line-utilities"]
rust-version = "1.93.0"
license-file = "LICENSE" license-file = "LICENSE"
readme = "README.md" readme = "README.md"
+1 -1
View File
@@ -1,7 +1,7 @@
[package] [package]
name = "easytier-web" name = "easytier-web"
version = "2.6.0" version = "2.6.0"
edition = "2021" edition.workspace = true
description = "Config server for easytier. easytier-core gets config from this and web frontend use it as restful api server." description = "Config server for easytier. easytier-core gets config from this and web frontend use it as restful api server."
[dependencies] [dependencies]
+4 -4
View File
@@ -2,8 +2,8 @@ pub mod session;
pub mod storage; pub mod storage;
use std::sync::{ use std::sync::{
atomic::{AtomicU32, Ordering},
Arc, Arc,
atomic::{AtomicU32, Ordering},
}; };
use dashmap::DashMap; use dashmap::DashMap;
@@ -19,11 +19,11 @@ use maxminddb::geoip2;
use session::{Location, Session}; use session::{Location, Session};
use storage::{Storage, StorageToken}; use storage::{Storage, StorageToken};
use crate::webhook::SharedWebhookConfig;
use crate::FeatureFlags; use crate::FeatureFlags;
use crate::webhook::SharedWebhookConfig;
use tokio::task::JoinSet; use tokio::task::JoinSet;
use crate::db::{entity::user_running_network_configs, Db, UserIdInDb}; use crate::db::{Db, UserIdInDb, entity::user_running_network_configs};
#[derive(rust_embed::Embed)] #[derive(rust_embed::Embed)]
#[folder = "resources/"] #[folder = "resources/"]
@@ -340,7 +340,7 @@ mod tests {
}; };
use sqlx::Executor; use sqlx::Executor;
use crate::{client_manager::ClientManager, db::Db, FeatureFlags}; use crate::{FeatureFlags, client_manager::ClientManager, db::Db};
#[tokio::test] #[tokio::test]
async fn test_client() { async fn test_client() {
+25 -25
View File
@@ -20,11 +20,11 @@ use easytier::{
rpc_service::remote_client::{ListNetworkProps, Storage as _}, rpc_service::remote_client::{ListNetworkProps, Storage as _},
tunnel::Tunnel, tunnel::Tunnel,
}; };
use tokio::sync::{broadcast, RwLock}; use tokio::sync::{RwLock, broadcast};
use super::storage::{Storage, StorageToken, WeakRefStorage}; use super::storage::{Storage, StorageToken, WeakRefStorage};
use crate::webhook::SharedWebhookConfig;
use crate::FeatureFlags; use crate::FeatureFlags;
use crate::webhook::SharedWebhookConfig;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Location { pub struct Location {
@@ -87,30 +87,30 @@ impl SessionData {
impl Drop for SessionData { impl Drop for SessionData {
fn drop(&mut self) { fn drop(&mut self) {
if let Ok(storage) = Storage::try_from(self.storage.clone()) { if let Ok(storage) = Storage::try_from(self.storage.clone())
if let Some(token) = self.storage_token.as_ref() { && let Some(token) = self.storage_token.as_ref()
storage.remove_client(token); {
storage.remove_client(token);
// Notify the webhook receiver when a node disconnects. // Notify the webhook receiver when a node disconnects.
if self.webhook_config.is_enabled() { if self.webhook_config.is_enabled() {
let webhook = self.webhook_config.clone(); let webhook = self.webhook_config.clone();
let machine_id = token.machine_id.to_string(); let machine_id = token.machine_id.to_string();
let user_id = Some(token.user_id); let user_id = Some(token.user_id);
let token_value = token.token.clone(); let token_value = token.token.clone();
let web_instance_id = webhook.web_instance_id.clone(); let web_instance_id = webhook.web_instance_id.clone();
let binding_version = self.binding_version; let binding_version = self.binding_version;
tokio::spawn(async move { tokio::spawn(async move {
webhook webhook
.notify_node_disconnected(&crate::webhook::NodeDisconnectedRequest { .notify_node_disconnected(&crate::webhook::NodeDisconnectedRequest {
machine_id, machine_id,
token: token_value, token: token_value,
user_id, user_id,
web_instance_id, web_instance_id,
binding_version, binding_version,
}) })
.await; .await;
}); });
}
} }
} }
} }
+4 -4
View File
@@ -8,11 +8,11 @@ use easytier::{
}; };
use entity::user_running_network_configs; use entity::user_running_network_configs;
use sea_orm::{ use sea_orm::{
prelude::Expr, sea_query::OnConflict, ColumnTrait as _, DatabaseConnection, DbErr, EntityTrait, ColumnTrait as _, DatabaseConnection, DbErr, EntityTrait, QueryFilter as _, Set,
QueryFilter as _, Set, SqlxSqliteConnector, TransactionTrait as _, SqlxSqliteConnector, TransactionTrait as _, prelude::Expr, sea_query::OnConflict,
}; };
use sea_orm_migration::MigratorTrait as _; use sea_orm_migration::MigratorTrait as _;
use sqlx::{migrate::MigrateDatabase as _, types::chrono, Sqlite, SqlitePool}; use sqlx::{Sqlite, SqlitePool, migrate::MigrateDatabase as _, types::chrono};
use uuid::Uuid; use uuid::Uuid;
use crate::migrator; use crate::migrator;
@@ -280,7 +280,7 @@ mod tests {
use easytier::{proto::api::manage::NetworkConfig, rpc_service::remote_client::Storage}; use easytier::{proto::api::manage::NetworkConfig, rpc_service::remote_client::Storage};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter as _}; use sea_orm::{ColumnTrait, EntityTrait, QueryFilter as _};
use crate::db::{entity::user_running_network_configs, Db, ListNetworkProps}; use crate::db::{Db, ListNetworkProps, entity::user_running_network_configs};
#[tokio::test] #[tokio::test]
async fn test_user_network_config_management() { async fn test_user_network_config_management() {
+1 -1
View File
@@ -16,7 +16,7 @@ use easytier::{
log, log,
network::{local_ipv4, local_ipv6}, network::{local_ipv4, local_ipv6},
}, },
tunnel::{tcp::TcpTunnelListener, udp::UdpTunnelListener, TunnelListener}, tunnel::{TunnelListener, tcp::TcpTunnelListener, udp::UdpTunnelListener},
utils::setup_panic_handler, utils::setup_panic_handler,
}; };
+12 -11
View File
@@ -1,7 +1,7 @@
use axum::{ use axum::{
Router,
http::StatusCode, http::StatusCode,
routing::{get, post, put}, routing::{get, post, put},
Router,
}; };
use axum_login::login_required; use axum_login::login_required;
use axum_messages::Message; use axum_messages::Message;
@@ -14,8 +14,8 @@ use std::sync::Arc;
use crate::FeatureFlags; use crate::FeatureFlags;
use super::{ use super::{
users::{AuthSession, Credentials},
AppStateInner, AppStateInner,
users::{AuthSession, Credentials},
}; };
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
@@ -44,7 +44,7 @@ mod put {
use axum_login::AuthUser; use axum_login::AuthUser;
use easytier::proto::common::Void; use easytier::proto::common::Void;
use crate::restful::{other_error, users::ChangePassword, HttpHandleError}; use crate::restful::{HttpHandleError, other_error, users::ChangePassword};
use super::*; use super::*;
@@ -71,14 +71,14 @@ mod put {
} }
mod post { mod post {
use axum::{extract::Extension, Json}; use axum::{Json, extract::Extension};
use easytier::proto::common::Void; use easytier::proto::common::Void;
use crate::restful::{ use crate::restful::{
captcha::extension::{axum_tower_sessions::CaptchaAxumTowerSessionStaticExt, CaptchaUtil}, HttpHandleError,
captcha::extension::{CaptchaUtil, axum_tower_sessions::CaptchaAxumTowerSessionStaticExt},
other_error, other_error,
users::RegisterNewUser, users::RegisterNewUser,
HttpHandleError,
}; };
use super::*; use super::*;
@@ -99,7 +99,7 @@ mod post {
return Err(( return Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json::from(other_error(format!("{:?}", e))), Json::from(other_error(format!("{:?}", e))),
)) ));
} }
}; };
@@ -150,14 +150,15 @@ mod post {
mod get { mod get {
use crate::restful::{ use crate::restful::{
HttpHandleError,
captcha::{ captcha::{
builder::spec::SpecCaptcha,
extension::{axum_tower_sessions::CaptchaAxumTowerSessionExt as _, CaptchaUtil},
NewCaptcha as _, NewCaptcha as _,
builder::spec::SpecCaptcha,
extension::{CaptchaUtil, axum_tower_sessions::CaptchaAxumTowerSessionExt as _},
}, },
other_error, HttpHandleError, other_error,
}; };
use axum::{response::Response, Json}; use axum::{Json, response::Response};
use easytier::proto::common::Void; use easytier::proto::common::Void;
use tower_sessions::Session; use tower_sessions::Session;
@@ -2,8 +2,8 @@ use super::super::base::randoms::Randoms;
use super::super::utils::color::Color; use super::super::utils::color::Color;
use super::super::utils::font; use super::super::utils::font;
use base64::prelude::BASE64_STANDARD;
use base64::Engine; use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use rusttype::Font; use rusttype::Font;
use std::fmt::Debug; use std::fmt::Debug;
@@ -9,14 +9,14 @@ use super::super::{CaptchaFont, NewCaptcha};
use image::{ImageBuffer, Rgba}; use image::{ImageBuffer, Rgba};
use imageproc::drawing; use imageproc::drawing;
use rand::{rngs::ThreadRng, Rng}; use rand::{Rng, rngs::ThreadRng};
use rusttype::{Font, Scale}; use rusttype::{Font, Scale};
use std::io::{Cursor, Write}; use std::io::{Cursor, Write};
use std::sync::Arc; use std::sync::Arc;
mod color { mod color {
use image::Rgba; use image::Rgba;
use rand::{rngs::ThreadRng, Rng}; use rand::{Rng, rngs::ThreadRng};
pub fn gen_background_color(rng: &mut ThreadRng) -> Rgba<u8> { pub fn gen_background_color(rng: &mut ThreadRng) -> Rgba<u8> {
let red = rng.gen_range(200..=255); let red = rng.gen_range(200..=255);
let green = rng.gen_range(200..=255); let green = rng.gen_range(200..=255);
@@ -133,7 +133,7 @@ impl<'a, 'b> CaptchaBuilder<'a, 'b> {
fn draw_line(&self, image: &mut ImageBuffer<Rgba<u8>, Vec<u8>>, rng: &mut ThreadRng) { fn draw_line(&self, image: &mut ImageBuffer<Rgba<u8>, Vec<u8>>, rng: &mut ThreadRng) {
let line_color = color::gen_line_color(rng); let line_color = color::gen_line_color(rng);
let is_h = rng.gen(); let is_h = rng.r#gen();
let (start, end) = if is_h { let (start, end) = if is_h {
let xa = rng.gen_range(0.0..(self.width as f32) / 2.0); let xa = rng.gen_range(0.0..(self.width as f32) / 2.0);
let ya = rng.gen_range(0.0..(self.height as f32)); let ya = rng.gen_range(0.0..(self.height as f32));
+6 -6
View File
@@ -8,13 +8,13 @@ mod users;
use std::{net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
use axum::extract::Path; use axum::extract::Path;
use axum::http::{header, Request, StatusCode}; use axum::http::{Request, StatusCode, header};
use axum::middleware::{self as axum_mw, Next}; use axum::middleware::{self as axum_mw, Next};
use axum::response::Response; use axum::response::Response;
use axum::routing::{delete, post}; use axum::routing::{delete, post};
use axum::{extract::State, routing::get, Extension, Json, Router}; use axum::{Extension, Json, Router, extract::State, routing::get};
use axum_login::tower_sessions::{ExpiredDeletion, SessionManagerLayer}; use axum_login::tower_sessions::{ExpiredDeletion, SessionManagerLayer};
use axum_login::{login_required, AuthManagerLayerBuilder, AuthUser, AuthzBackend}; use axum_login::{AuthManagerLayerBuilder, AuthUser, AuthzBackend, login_required};
use axum_messages::MessagesManagerLayer; use axum_messages::MessagesManagerLayer;
use easytier::common::config::{ConfigLoader, TomlConfigLoader}; use easytier::common::config::{ConfigLoader, TomlConfigLoader};
use easytier::common::scoped_task::ScopedTask; use easytier::common::scoped_task::ScopedTask;
@@ -23,17 +23,17 @@ use easytier::proto::rpc_types;
use network::NetworkApi; use network::NetworkApi;
use sea_orm::DbErr; use sea_orm::DbErr;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tower_sessions::Expiry;
use tower_sessions::cookie::time::Duration; use tower_sessions::cookie::time::Duration;
use tower_sessions::cookie::{Key, SameSite}; use tower_sessions::cookie::{Key, SameSite};
use tower_sessions::Expiry;
use tower_sessions_sqlx_store::SqliteStore; use tower_sessions_sqlx_store::SqliteStore;
use users::{AuthSession, Backend}; use users::{AuthSession, Backend};
use crate::client_manager::storage::StorageToken; use crate::FeatureFlags;
use crate::client_manager::ClientManager; use crate::client_manager::ClientManager;
use crate::client_manager::storage::StorageToken;
use crate::db::{Db, UserIdInDb}; use crate::db::{Db, UserIdInDb};
use crate::webhook::SharedWebhookConfig; use crate::webhook::SharedWebhookConfig;
use crate::FeatureFlags;
/// Embed assets for web dashboard, build frontend first /// Embed assets for web dashboard, build frontend first
#[cfg(feature = "embed")] #[cfg(feature = "embed")]
+2 -2
View File
@@ -1,7 +1,7 @@
use axum::extract::Path; use axum::extract::Path;
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::routing::{delete, post}; use axum::routing::{delete, post};
use axum::{extract::State, routing::get, Json, Router}; use axum::{Json, Router, extract::State, routing::get};
use axum_login::AuthUser; use axum_login::AuthUser;
use easytier::launcher::NetworkConfig; use easytier::launcher::NetworkConfig;
use easytier::proto::common::Void; use easytier::proto::common::Void;
@@ -16,7 +16,7 @@ use crate::db::UserIdInDb;
use super::users::AuthSession; use super::users::AuthSession;
use super::{ use super::{
convert_db_error, other_error, AppState, AppStateInner, Error, HttpHandleError, RpcError, AppState, AppStateInner, Error, HttpHandleError, RpcError, convert_db_error, other_error,
}; };
fn convert_rpc_error(e: RpcError) -> (StatusCode, Json<Error>) { fn convert_rpc_error(e: RpcError) -> (StatusCode, Json<Error>) {
+13 -12
View File
@@ -4,8 +4,8 @@ use std::time::Duration;
use subtle::ConstantTimeEq; use subtle::ConstantTimeEq;
use axum::routing::get;
use axum::Router; use axum::Router;
use axum::routing::get;
use openidconnect::core::{ use openidconnect::core::{
CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey, CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey,
CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreProviderMetadata,
@@ -216,7 +216,9 @@ impl OidcConfig {
} = opts; } = opts;
if oidc_issuer_url.is_none() || oidc_client_id.is_none() || oidc_redirect_url.is_none() { if oidc_issuer_url.is_none() || oidc_client_id.is_none() || oidc_redirect_url.is_none() {
return Err(anyhow::anyhow!("--oidc-issuer-url, --oidc-client-id and --oidc-redirect-url are required when using OIDC authentication")); return Err(anyhow::anyhow!(
"--oidc-issuer-url, --oidc-client-id and --oidc-redirect-url are required when using OIDC authentication"
));
} }
if oidc_username_claim.trim().is_empty() { if oidc_username_claim.trim().is_empty() {
return Err(anyhow::anyhow!("--oidc-username-claim cannot be empty")); return Err(anyhow::anyhow!("--oidc-username-claim cannot be empty"));
@@ -373,18 +375,17 @@ mod route {
) )
.into_response(); .into_response();
} }
if let Some(verifier) = pkce_verifier { if let Some(verifier) = pkce_verifier
if let Err(e) = session && let Err(e) = session
.insert("oidc_pkce_verifier", verifier.secret().clone()) .insert("oidc_pkce_verifier", verifier.secret().clone())
.await .await
{ {
tracing::error!("Failed to store pkce_verifier in session: {:?}", e); tracing::error!("Failed to store pkce_verifier in session: {:?}", e);
return ( return (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("Session error")), Json(other_error("Session error")),
) )
.into_response(); .into_response();
}
} }
if let Err(e) = session.insert("oidc_pkce_used", pkce_enabled).await { if let Err(e) = session.insert("oidc_pkce_used", pkce_enabled).await {
tracing::error!("Failed to store pkce_used in session: {:?}", e); tracing::error!("Failed to store pkce_used in session: {:?}", e);
+3 -3
View File
@@ -1,15 +1,15 @@
use axum::{ use axum::{
Json, Router,
extract::{Path, State}, extract::{Path, State},
http::StatusCode, http::StatusCode,
routing::post, routing::post,
Json, Router,
}; };
use axum_login::AuthUser as _; use axum_login::AuthUser as _;
use easytier::proto::rpc_types::controller::BaseController; use easytier::proto::rpc_types::controller::BaseController;
use crate::db::UserIdInDb; use crate::db::UserIdInDb;
use super::{other_error, AppState, HttpHandleError}; use super::{AppState, HttpHandleError, other_error};
#[derive(Debug, serde::Deserialize)] #[derive(Debug, serde::Deserialize)]
pub struct ProxyRpcRequest { pub struct ProxyRpcRequest {
@@ -120,7 +120,7 @@ async fn handle_proxy_rpc_by_session(
return Err(( return Err((
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
other_error(format!("Unknown service: {}", service_name)).into(), other_error(format!("Unknown service: {}", service_name)).into(),
)) ));
} }
}; };
+3 -3
View File
@@ -39,9 +39,9 @@ impl AuthUser for User {
fn session_auth_hash(&self) -> &[u8] { fn session_auth_hash(&self) -> &[u8] {
self.db_user.password.as_bytes() // We use the password hash as the auth self.db_user.password.as_bytes() // We use the password hash as the auth
// hash--what this means // hash--what this means
// is when the user changes their password the // is when the user changes their password the
// auth session becomes invalid. // auth session becomes invalid.
} }
} }
+2 -1
View File
@@ -1,8 +1,9 @@
use axum::{ use axum::{
Router,
extract::State, extract::State,
http::header, http::header,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
routing, Router, routing,
}; };
use axum_embed::ServeEmbed; use axum_embed::ServeEmbed;
use easytier::common::scoped_task::ScopedTask; use easytier::common::scoped_task::ScopedTask;
+2 -2
View File
@@ -4,11 +4,11 @@ description = "A full meshed p2p VPN, connecting all your devices in one network
homepage = "https://github.com/EasyTier/EasyTier" homepage = "https://github.com/EasyTier/EasyTier"
repository = "https://github.com/EasyTier/EasyTier" repository = "https://github.com/EasyTier/EasyTier"
version = "2.6.0" version = "2.6.0"
edition = "2021" edition.workspace = true
rust-version.workspace = true
authors = ["kkrainbow"] authors = ["kkrainbow"]
keywords = ["vpn", "p2p", "network", "easytier"] keywords = ["vpn", "p2p", "network", "easytier"]
categories = ["network-programming", "command-line-utilities"] categories = ["network-programming", "command-line-utilities"]
rust-version = "1.93.0"
license-file = "LICENSE" license-file = "LICENSE"
readme = "README.md" readme = "README.md"
+3 -1
View File
@@ -86,7 +86,9 @@ impl WindowsBuild {
} else { } else {
Self::download_protoc() Self::download_protoc()
}; };
std::env::set_var("PROTOC", protoc_path); unsafe {
std::env::set_var("PROTOC", protoc_path);
}
} }
} }
+4 -4
View File
@@ -3,7 +3,6 @@ use std::{io, mem::ManuallyDrop, net::SocketAddr, os::windows::io::AsRawSocket};
use anyhow::Context; use anyhow::Context;
use network_interface::NetworkInterfaceConfig; use network_interface::NetworkInterfaceConfig;
use windows::{ use windows::{
core::BSTR,
Win32::{ Win32::{
Foundation::{BOOL, FALSE}, Foundation::{BOOL, FALSE},
NetworkManagement::WindowsFirewall::{ NetworkManagement::WindowsFirewall::{
@@ -12,15 +11,16 @@ use windows::{
NET_FW_RULE_DIR_OUT, NET_FW_RULE_DIR_OUT,
}, },
Networking::WinSock::{ Networking::WinSock::{
htonl, setsockopt, WSAGetLastError, WSAIoctl, IPPROTO_IP, IPPROTO_IPV6, IP_UNICAST_IF, IPPROTO_IP, IPPROTO_IPV6, IPV6_UNICAST_IF, SIO_UDP_CONNRESET, SOCKET,
IPV6_UNICAST_IF, IP_UNICAST_IF, SIO_UDP_CONNRESET, SOCKET, SOCKET_ERROR, SOCKET_ERROR, WSAGetLastError, WSAIoctl, htonl, setsockopt,
}, },
System::Com::{ System::Com::{
CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, COINIT_MULTITHREADED, CLSCTX_ALL, COINIT_MULTITHREADED, CoCreateInstance, CoInitializeEx, CoUninitialize,
}, },
System::Ole::{SafeArrayCreateVector, SafeArrayPutElement}, System::Ole::{SafeArrayCreateVector, SafeArrayPutElement},
System::Variant::{VARENUM, VARIANT, VT_ARRAY, VT_BSTR, VT_VARIANT}, System::Variant::{VARENUM, VARIANT, VT_ARRAY, VT_BSTR, VT_VARIANT},
}, },
core::BSTR,
}; };
pub fn disable_connection_reset<S: AsRawSocket>(socket: &S) -> io::Result<()> { pub fn disable_connection_reset<S: AsRawSocket>(socket: &S) -> io::Result<()> {
+19 -19
View File
@@ -507,7 +507,7 @@ impl AclProcessor {
matched_rule: Some(RuleId::Default), matched_rule: Some(RuleId::Default),
should_log: false, should_log: false,
log_context: Some(AclLogContext::UnsupportedChainType), log_context: Some(AclLogContext::UnsupportedChainType),
} };
} }
}; };
@@ -679,28 +679,28 @@ impl AclProcessor {
} }
// Source port check // Source port check
if let Some(src_port) = packet_info.src_port { if let Some(src_port) = packet_info.src_port
if !rule.src_port_ranges.is_empty() { && !rule.src_port_ranges.is_empty()
let matches = rule {
.src_port_ranges let matches = rule
.iter() .src_port_ranges
.any(|(start, end)| src_port >= *start && src_port <= *end); .iter()
if !matches { .any(|(start, end)| src_port >= *start && src_port <= *end);
return false; if !matches {
} return false;
} }
} }
// Destination port check // Destination port check
if let Some(dst_port) = packet_info.dst_port { if let Some(dst_port) = packet_info.dst_port
if !rule.dst_port_ranges.is_empty() { && !rule.dst_port_ranges.is_empty()
let matches = rule {
.dst_port_ranges let matches = rule
.iter() .dst_port_ranges
.any(|(start, end)| dst_port >= *start && dst_port <= *end); .iter()
if !matches { .any(|(start, end)| dst_port >= *start && dst_port <= *end);
return false; if !matches {
} return false;
} }
} }
+1 -1
View File
@@ -9,7 +9,7 @@ use zstd::bulk;
use zerocopy::{AsBytes as _, FromBytes as _}; use zerocopy::{AsBytes as _, FromBytes as _};
use crate::tunnel::packet_def::{CompressorAlgo, CompressorTail, ZCPacket, COMPRESSOR_TAIL_SIZE}; use crate::tunnel::packet_def::{COMPRESSOR_TAIL_SIZE, CompressorAlgo, CompressorTail, ZCPacket};
type Error = anyhow::Error; type Error = anyhow::Error;
+66 -63
View File
@@ -6,10 +6,10 @@ use std::{
}; };
use anyhow::Context; use anyhow::Context;
use base64::{prelude::BASE64_STANDARD, Engine as _}; use base64::{Engine as _, prelude::BASE64_STANDARD};
use cfg_if::cfg_if; use cfg_if::cfg_if;
use clap::builder::PossibleValue;
use clap::ValueEnum; use clap::ValueEnum;
use clap::builder::PossibleValue;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use strum::{Display, EnumString, VariantArray}; use strum::{Display, EnumString, VariantArray};
use tokio::io::AsyncReadExt as _; use tokio::io::AsyncReadExt as _;
@@ -621,14 +621,14 @@ impl ConfigLoader for TomlConfigLoader {
if locked_config.proxy_network.is_none() { if locked_config.proxy_network.is_none() {
locked_config.proxy_network = Some(vec![]); locked_config.proxy_network = Some(vec![]);
} }
if let Some(mapped_cidr) = mapped_cidr.as_ref() { if let Some(mapped_cidr) = mapped_cidr.as_ref()
if cidr.network_length() != mapped_cidr.network_length() { && cidr.network_length() != mapped_cidr.network_length()
return Err(anyhow::anyhow!( {
"Mapped CIDR must have the same network length as the original CIDR: {} != {}", return Err(anyhow::anyhow!(
cidr.network_length(), "Mapped CIDR must have the same network length as the original CIDR: {} != {}",
mapped_cidr.network_length() cidr.network_length(),
)); mapped_cidr.network_length()
} ));
} }
// insert if no duplicate // insert if no duplicate
if !locked_config if !locked_config
@@ -881,10 +881,10 @@ impl ConfigLoader for TomlConfigLoader {
let mut flag_map: serde_json::Map<String, serde_json::Value> = Default::default(); let mut flag_map: serde_json::Map<String, serde_json::Value> = Default::default();
for (key, value) in default_flags_hashmap { for (key, value) in default_flags_hashmap {
if let Some(v) = cur_flags_hashmap.get(&key) { if let Some(v) = cur_flags_hashmap.get(&key)
if *v != value { && *v != value
flag_map.insert(key, v.clone()); {
} flag_map.insert(key, v.clone());
} }
} }
@@ -1089,6 +1089,7 @@ pub async fn load_config_from_file(
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use super::*; use super::*;
use crate::tests::{remove_env_var, set_env_var};
use std::io::Write; use std::io::Write;
use std::path::PathBuf; use std::path::PathBuf;
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
@@ -1212,8 +1213,8 @@ proto = "tcp"
#[tokio::test] #[tokio::test]
async fn test_env_var_expansion_and_readonly_flag() { async fn test_env_var_expansion_and_readonly_flag() {
// 设置测试环境变量 // 设置测试环境变量
std::env::set_var("TEST_SECRET", "my-test-secret-123"); set_env_var("TEST_SECRET", "my-test-secret-123");
std::env::set_var("TEST_NETWORK", "test-network"); set_env_var("TEST_NETWORK", "test-network");
// 创建临时配置文件,包含环境变量占位符 // 创建临时配置文件,包含环境变量占位符
let mut temp_file = NamedTempFile::new().unwrap(); let mut temp_file = NamedTempFile::new().unwrap();
@@ -1253,8 +1254,8 @@ network_secret = "${TEST_SECRET}"
); );
// 清理环境变量 // 清理环境变量
std::env::remove_var("TEST_SECRET"); remove_env_var("TEST_SECRET");
std::env::remove_var("TEST_NETWORK"); remove_env_var("TEST_NETWORK");
} }
/// RPC API 安全测试(只读配置保护) /// RPC API 安全测试(只读配置保护)
@@ -1267,7 +1268,7 @@ network_secret = "${TEST_SECRET}"
/// `easytier/src/rpc_service/instance_manage.rs` 中实现 /// `easytier/src/rpc_service/instance_manage.rs` 中实现
#[tokio::test] #[tokio::test]
async fn test_readonly_config_api_protection() { async fn test_readonly_config_api_protection() {
std::env::set_var("API_TEST_SECRET", "secret-value"); set_env_var("API_TEST_SECRET", "secret-value");
// 创建包含环境变量的配置 // 创建包含环境变量的配置
let mut temp_file = NamedTempFile::new().unwrap(); let mut temp_file = NamedTempFile::new().unwrap();
@@ -1298,7 +1299,7 @@ network_secret = "${API_TEST_SECRET}"
"Permission flag should be set correctly" "Permission flag should be set correctly"
); );
std::env::remove_var("API_TEST_SECRET"); remove_env_var("API_TEST_SECRET");
} }
/// CLI 参数测试(--disable-env-parsing 开关) /// CLI 参数测试(--disable-env-parsing 开关)
@@ -1308,7 +1309,7 @@ network_secret = "${API_TEST_SECRET}"
/// - 配置不会被标记为只读 /// - 配置不会被标记为只读
#[tokio::test] #[tokio::test]
async fn test_disable_env_parsing_flag() { async fn test_disable_env_parsing_flag() {
std::env::set_var("DISABLED_TEST_VAR", "should-not-expand"); set_env_var("DISABLED_TEST_VAR", "should-not-expand");
// 创建包含环境变量占位符的配置 // 创建包含环境变量占位符的配置
let mut temp_file = NamedTempFile::new().unwrap(); let mut temp_file = NamedTempFile::new().unwrap();
@@ -1346,7 +1347,7 @@ network_secret = "${DISABLED_TEST_VAR}"
"Config should be NO_DELETE due to no config_dir, not env vars" "Config should be NO_DELETE due to no config_dir, not env vars"
); );
std::env::remove_var("DISABLED_TEST_VAR"); remove_env_var("DISABLED_TEST_VAR");
} }
/// 多实例隔离测试 /// 多实例隔离测试
@@ -1357,8 +1358,8 @@ network_secret = "${DISABLED_TEST_VAR}"
#[tokio::test] #[tokio::test]
async fn test_multiple_instances_with_different_env_vars() { async fn test_multiple_instances_with_different_env_vars() {
// 实例1:使用第一组环境变量 // 实例1:使用第一组环境变量
std::env::set_var("INSTANCE_SECRET", "instance1-secret"); set_env_var("INSTANCE_SECRET", "instance1-secret");
std::env::set_var("INSTANCE_NAME", "instance-one"); set_env_var("INSTANCE_NAME", "instance-one");
let mut temp_file1 = NamedTempFile::new().unwrap(); let mut temp_file1 = NamedTempFile::new().unwrap();
let config_content = r#" let config_content = r#"
@@ -1388,8 +1389,8 @@ network_secret = "${INSTANCE_SECRET}"
); );
// 实例2:修改环境变量后加载同一模板 // 实例2:修改环境变量后加载同一模板
std::env::set_var("INSTANCE_SECRET", "instance2-secret"); set_env_var("INSTANCE_SECRET", "instance2-secret");
std::env::set_var("INSTANCE_NAME", "instance-two"); set_env_var("INSTANCE_NAME", "instance-two");
let mut temp_file2 = NamedTempFile::new().unwrap(); let mut temp_file2 = NamedTempFile::new().unwrap();
temp_file2.write_all(config_content.as_bytes()).unwrap(); temp_file2.write_all(config_content.as_bytes()).unwrap();
@@ -1419,8 +1420,8 @@ network_secret = "${INSTANCE_SECRET}"
); );
// 清理 // 清理
std::env::remove_var("INSTANCE_SECRET"); remove_env_var("INSTANCE_SECRET");
std::env::remove_var("INSTANCE_NAME"); remove_env_var("INSTANCE_NAME");
} }
/// 实际配置字段测试(network_secret、peer.uri 等) /// 实际配置字段测试(network_secret、peer.uri 等)
@@ -1433,11 +1434,11 @@ network_secret = "${INSTANCE_SECRET}"
#[tokio::test] #[tokio::test]
async fn test_real_config_fields_expansion() { async fn test_real_config_fields_expansion() {
// 设置各种实际场景的环境变量 // 设置各种实际场景的环境变量
std::env::set_var("ET_SECRET", "production-secret-key"); set_env_var("ET_SECRET", "production-secret-key");
std::env::set_var("PEER_HOST", "peer.example.com"); set_env_var("PEER_HOST", "peer.example.com");
std::env::set_var("PEER_PORT", "11011"); set_env_var("PEER_PORT", "11011");
std::env::set_var("LISTEN_PORT", "11010"); set_env_var("LISTEN_PORT", "11010");
std::env::set_var("NETWORK_NAME", "prod-network"); set_env_var("NETWORK_NAME", "prod-network");
// 创建包含多个实际字段的完整配置 // 创建包含多个实际字段的完整配置
let mut temp_file = NamedTempFile::new().unwrap(); let mut temp_file = NamedTempFile::new().unwrap();
@@ -1485,11 +1486,11 @@ uri = "tcp://${PEER_HOST}:${PEER_PORT}"
assert!(control.is_no_delete()); assert!(control.is_no_delete());
// 清理环境变量 // 清理环境变量
std::env::remove_var("ET_SECRET"); remove_env_var("ET_SECRET");
std::env::remove_var("PEER_HOST"); remove_env_var("PEER_HOST");
std::env::remove_var("PEER_PORT"); remove_env_var("PEER_PORT");
std::env::remove_var("LISTEN_PORT"); remove_env_var("LISTEN_PORT");
std::env::remove_var("NETWORK_NAME"); remove_env_var("NETWORK_NAME");
} }
/// 带默认值的环境变量 /// 带默认值的环境变量
@@ -1499,8 +1500,8 @@ uri = "tcp://${PEER_HOST}:${PEER_PORT}"
#[tokio::test] #[tokio::test]
async fn test_env_var_with_default_value() { async fn test_env_var_with_default_value() {
// 确保变量未定义 // 确保变量未定义
std::env::remove_var("UNDEFINED_PORT"); remove_env_var("UNDEFINED_PORT");
std::env::remove_var("UNDEFINED_SECRET"); remove_env_var("UNDEFINED_SECRET");
let mut temp_file = NamedTempFile::new().unwrap(); let mut temp_file = NamedTempFile::new().unwrap();
let config_content = r#" let config_content = r#"
@@ -1541,7 +1542,7 @@ network_secret = "${UNDEFINED_SECRET:-default-secret}"
/// - 未定义的环境变量保持原样(shellexpand 的默认行为) /// - 未定义的环境变量保持原样(shellexpand 的默认行为)
#[tokio::test] #[tokio::test]
async fn test_undefined_env_var_without_default() { async fn test_undefined_env_var_without_default() {
std::env::remove_var("COMPLETELY_UNDEFINED"); remove_env_var("COMPLETELY_UNDEFINED");
let mut temp_file = NamedTempFile::new().unwrap(); let mut temp_file = NamedTempFile::new().unwrap();
let config_content = r#" let config_content = r#"
@@ -1571,6 +1572,8 @@ network_secret = "${COMPLETELY_UNDEFINED}"
// 注意:由于没有实际替换发生,控制标记不应因环境变量而设置 // 注意:由于没有实际替换发生,控制标记不应因环境变量而设置
// 但会因为其他原因(如没有 config_dir)被标记为 NO_DELETE // 但会因为其他原因(如没有 config_dir)被标记为 NO_DELETE
// 这里我们主要验证 NO_DELETE 标记的逻辑
// 由于没有 config_dir,文件会被标记为 NO_DELETE,但不是因为环境变量
assert!(control.is_no_delete()); assert!(control.is_no_delete());
} }
@@ -1582,9 +1585,9 @@ network_secret = "${COMPLETELY_UNDEFINED}"
#[tokio::test] #[tokio::test]
async fn test_boolean_type_env_vars() { async fn test_boolean_type_env_vars() {
// 设置布尔类型的环境变量 // 设置布尔类型的环境变量
std::env::set_var("ENABLE_DHCP", "true"); set_env_var("ENABLE_DHCP", "true");
std::env::set_var("ENABLE_ENCRYPTION", "false"); set_env_var("ENABLE_ENCRYPTION", "false");
std::env::set_var("ENABLE_IPV6", "true"); set_env_var("ENABLE_IPV6", "true");
let mut temp_file = NamedTempFile::new().unwrap(); let mut temp_file = NamedTempFile::new().unwrap();
let config_content = r#" let config_content = r#"
@@ -1622,9 +1625,9 @@ enable_ipv6 = ${ENABLE_IPV6}
assert!(control.is_no_delete()); assert!(control.is_no_delete());
// 清理 // 清理
std::env::remove_var("ENABLE_DHCP"); remove_env_var("ENABLE_DHCP");
std::env::remove_var("ENABLE_ENCRYPTION"); remove_env_var("ENABLE_ENCRYPTION");
std::env::remove_var("ENABLE_IPV6"); remove_env_var("ENABLE_IPV6");
} }
/// 数字类型环境变量 /// 数字类型环境变量
@@ -1635,8 +1638,8 @@ enable_ipv6 = ${ENABLE_IPV6}
#[tokio::test] #[tokio::test]
async fn test_numeric_type_env_vars() { async fn test_numeric_type_env_vars() {
// 设置数字类型的环境变量 // 设置数字类型的环境变量
std::env::set_var("MTU_VALUE", "1400"); set_env_var("MTU_VALUE", "1400");
std::env::set_var("THREAD_COUNT", "4"); set_env_var("THREAD_COUNT", "4");
let mut temp_file = NamedTempFile::new().unwrap(); let mut temp_file = NamedTempFile::new().unwrap();
let config_content = r#" let config_content = r#"
@@ -1671,8 +1674,8 @@ multi_thread_count = ${THREAD_COUNT}
assert!(control.is_no_delete()); assert!(control.is_no_delete());
// 清理 // 清理
std::env::remove_var("MTU_VALUE"); remove_env_var("MTU_VALUE");
std::env::remove_var("THREAD_COUNT"); remove_env_var("THREAD_COUNT");
} }
/// 混合类型环境变量 /// 混合类型环境变量
@@ -1684,12 +1687,12 @@ multi_thread_count = ${THREAD_COUNT}
#[tokio::test] #[tokio::test]
async fn test_mixed_type_env_vars() { async fn test_mixed_type_env_vars() {
// 设置不同类型的环境变量 // 设置不同类型的环境变量
std::env::set_var("MIXED_SECRET", "mixed-secret-key"); set_env_var("MIXED_SECRET", "mixed-secret-key");
std::env::set_var("MIXED_NETWORK", "production"); set_env_var("MIXED_NETWORK", "production");
std::env::set_var("MIXED_DHCP", "true"); set_env_var("MIXED_DHCP", "true");
std::env::set_var("MIXED_MTU", "1500"); set_env_var("MIXED_MTU", "1500");
std::env::set_var("MIXED_ENCRYPTION", "false"); set_env_var("MIXED_ENCRYPTION", "false");
std::env::set_var("MIXED_LISTEN_PORT", "12345"); set_env_var("MIXED_LISTEN_PORT", "12345");
let mut temp_file = NamedTempFile::new().unwrap(); let mut temp_file = NamedTempFile::new().unwrap();
let config_content = r#" let config_content = r#"
@@ -1741,11 +1744,11 @@ enable_encryption = ${MIXED_ENCRYPTION}
assert!(control.is_no_delete()); assert!(control.is_no_delete());
// 清理 // 清理
std::env::remove_var("MIXED_SECRET"); remove_env_var("MIXED_SECRET");
std::env::remove_var("MIXED_NETWORK"); remove_env_var("MIXED_NETWORK");
std::env::remove_var("MIXED_DHCP"); remove_env_var("MIXED_DHCP");
std::env::remove_var("MIXED_MTU"); remove_env_var("MIXED_MTU");
std::env::remove_var("MIXED_ENCRYPTION"); remove_env_var("MIXED_ENCRYPTION");
std::env::remove_var("MIXED_LISTEN_PORT"); remove_env_var("MIXED_LISTEN_PORT");
} }
} }
+1 -1
View File
@@ -1,6 +1,6 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use anyhow::Context; use anyhow::Context;
use hickory_proto::runtime::TokioRuntimeProvider; use hickory_proto::runtime::TokioRuntimeProvider;
+21 -20
View File
@@ -42,10 +42,11 @@ pub fn expand_env_vars(text: &str) -> (String, bool) {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::tests::{remove_env_var, set_env_var};
#[test] #[test]
fn test_expand_standard_syntax() { fn test_expand_standard_syntax() {
std::env::set_var("TEST_VAR_STANDARD", "test_value"); set_env_var("TEST_VAR_STANDARD", "test_value");
let (result, changed) = expand_env_vars("secret=${TEST_VAR_STANDARD}"); let (result, changed) = expand_env_vars("secret=${TEST_VAR_STANDARD}");
assert_eq!(result, "secret=test_value"); assert_eq!(result, "secret=test_value");
assert!(changed); assert!(changed);
@@ -53,7 +54,7 @@ mod tests {
#[test] #[test]
fn test_expand_short_syntax() { fn test_expand_short_syntax() {
std::env::set_var("TEST_VAR_SHORT", "short_value"); set_env_var("TEST_VAR_SHORT", "short_value");
let (result, changed) = expand_env_vars("key=$TEST_VAR_SHORT"); let (result, changed) = expand_env_vars("key=$TEST_VAR_SHORT");
assert_eq!(result, "key=short_value"); assert_eq!(result, "key=short_value");
assert!(changed); assert!(changed);
@@ -62,7 +63,7 @@ mod tests {
#[test] #[test]
fn test_expand_with_default() { fn test_expand_with_default() {
// 确保变量未定义 // 确保变量未定义
std::env::remove_var("UNDEFINED_VAR_WITH_DEFAULT"); remove_env_var("UNDEFINED_VAR_WITH_DEFAULT");
let (result, changed) = expand_env_vars("port=${UNDEFINED_VAR_WITH_DEFAULT:-8080}"); let (result, changed) = expand_env_vars("port=${UNDEFINED_VAR_WITH_DEFAULT:-8080}");
assert_eq!(result, "port=8080"); assert_eq!(result, "port=8080");
assert!(changed); assert!(changed);
@@ -84,8 +85,8 @@ mod tests {
#[test] #[test]
fn test_multiple_vars() { fn test_multiple_vars() {
std::env::set_var("VAR1", "value1"); set_env_var("VAR1", "value1");
std::env::set_var("VAR2", "value2"); set_env_var("VAR2", "value2");
let (result, changed) = expand_env_vars("${VAR1} and ${VAR2}"); let (result, changed) = expand_env_vars("${VAR1} and ${VAR2}");
assert_eq!(result, "value1 and value2"); assert_eq!(result, "value1 and value2");
assert!(changed); assert!(changed);
@@ -94,7 +95,7 @@ mod tests {
#[test] #[test]
fn test_undefined_var_without_default() { fn test_undefined_var_without_default() {
// 确保变量未定义 // 确保变量未定义
std::env::remove_var("COMPLETELY_UNDEFINED_VAR"); remove_env_var("COMPLETELY_UNDEFINED_VAR");
let (result, changed) = expand_env_vars("value=${COMPLETELY_UNDEFINED_VAR}"); let (result, changed) = expand_env_vars("value=${COMPLETELY_UNDEFINED_VAR}");
// shellexpand::env 对未定义的变量会保持原样 // shellexpand::env 对未定义的变量会保持原样
assert_eq!(result, "value=${COMPLETELY_UNDEFINED_VAR}"); assert_eq!(result, "value=${COMPLETELY_UNDEFINED_VAR}");
@@ -103,8 +104,8 @@ mod tests {
#[test] #[test]
fn test_complex_toml_config() { fn test_complex_toml_config() {
std::env::set_var("ET_SECRET", "my-secret-key"); set_env_var("ET_SECRET", "my-secret-key");
std::env::set_var("ET_PORT", "11010"); set_env_var("ET_PORT", "11010");
let config = r#" let config = r#"
[network_identity] [network_identity]
@@ -123,7 +124,7 @@ uri = "tcp://127.0.0.1:${ET_PORT}"
#[test] #[test]
fn test_escape_syntax_double_dollar() { fn test_escape_syntax_double_dollar() {
std::env::set_var("ESCAPED_VAR", "should_not_expand"); set_env_var("ESCAPED_VAR", "should_not_expand");
// shellexpand 使用 $$ 作为转义序列,表示字面量的单个 $ // shellexpand 使用 $$ 作为转义序列,表示字面量的单个 $
// $$ 会被转义为单个 $,不会触发变量扩展 // $$ 会被转义为单个 $,不会触发变量扩展
let (result, changed) = expand_env_vars("value=$${ESCAPED_VAR}"); let (result, changed) = expand_env_vars("value=$${ESCAPED_VAR}");
@@ -133,7 +134,7 @@ uri = "tcp://127.0.0.1:${ET_PORT}"
#[test] #[test]
fn test_escape_syntax_backslash() { fn test_escape_syntax_backslash() {
std::env::set_var("ESCAPED_VAR", "should_not_expand"); set_env_var("ESCAPED_VAR", "should_not_expand");
// shellexpand 中反斜杠转义的行为:\$ 会展开为 \<变量值> // shellexpand 中反斜杠转义的行为:\$ 会展开为 \<变量值>
// 这不是推荐的转义方式,此测试仅为记录实际行为 // 这不是推荐的转义方式,此测试仅为记录实际行为
let (result, changed) = expand_env_vars(r"value=\${ESCAPED_VAR}"); let (result, changed) = expand_env_vars(r"value=\${ESCAPED_VAR}");
@@ -143,7 +144,7 @@ uri = "tcp://127.0.0.1:${ET_PORT}"
#[test] #[test]
fn test_multiple_dollar_signs() { fn test_multiple_dollar_signs() {
std::env::set_var("TEST_VAR", "value"); set_env_var("TEST_VAR", "value");
// 测试多个连续的 $ 符号 // 测试多个连续的 $ 符号
let (result1, changed1) = expand_env_vars("$$"); let (result1, changed1) = expand_env_vars("$$");
assert_eq!(result1, "$"); assert_eq!(result1, "$");
@@ -161,7 +162,7 @@ uri = "tcp://127.0.0.1:${ET_PORT}"
#[test] #[test]
fn test_empty_var_value() { fn test_empty_var_value() {
std::env::set_var("EMPTY_VAR", ""); set_env_var("EMPTY_VAR", "");
let (result, changed) = expand_env_vars("value=${EMPTY_VAR}"); let (result, changed) = expand_env_vars("value=${EMPTY_VAR}");
// 变量存在但值为空 // 变量存在但值为空
assert_eq!(result, "value="); assert_eq!(result, "value=");
@@ -170,7 +171,7 @@ uri = "tcp://127.0.0.1:${ET_PORT}"
#[test] #[test]
fn test_default_with_special_chars() { fn test_default_with_special_chars() {
std::env::remove_var("UNDEFINED_SPECIAL"); remove_env_var("UNDEFINED_SPECIAL");
// 测试默认值包含冒号、等号、空格等特殊字符 // 测试默认值包含冒号、等号、空格等特殊字符
let (result, changed) = expand_env_vars("url=${UNDEFINED_SPECIAL:-http://localhost:8080}"); let (result, changed) = expand_env_vars("url=${UNDEFINED_SPECIAL:-http://localhost:8080}");
assert_eq!(result, "url=http://localhost:8080"); assert_eq!(result, "url=http://localhost:8080");
@@ -187,9 +188,9 @@ uri = "tcp://127.0.0.1:${ET_PORT}"
#[test] #[test]
fn test_var_name_with_numbers_underscores() { fn test_var_name_with_numbers_underscores() {
std::env::set_var("VAR_123", "num_value"); set_env_var("VAR_123", "num_value");
std::env::set_var("_VAR", "underscore_prefix"); set_env_var("_VAR", "underscore_prefix");
std::env::set_var("VAR_", "underscore_suffix"); set_env_var("VAR_", "underscore_suffix");
let (result1, changed1) = expand_env_vars("${VAR_123}"); let (result1, changed1) = expand_env_vars("${VAR_123}");
assert_eq!(result1, "num_value"); assert_eq!(result1, "num_value");
@@ -214,7 +215,7 @@ uri = "tcp://127.0.0.1:${ET_PORT}"
// 注意:未闭合的 ${VAR 实际上 shellexpand 会当作普通文本处理 // 注意:未闭合的 ${VAR 实际上 shellexpand 会当作普通文本处理
// 它会尝试查找名为 "VAR" 的环境变量(到字符串末尾) // 它会尝试查找名为 "VAR" 的环境变量(到字符串末尾)
std::env::remove_var("VAR"); remove_env_var("VAR");
let (result2, _changed2) = expand_env_vars("incomplete ${VAR"); let (result2, _changed2) = expand_env_vars("incomplete ${VAR");
// 如果 VAR 未定义,shellexpand 会返回错误或保持原样 // 如果 VAR 未定义,shellexpand 会返回错误或保持原样
assert_eq!(result2, "incomplete ${VAR"); assert_eq!(result2, "incomplete ${VAR");
@@ -224,8 +225,8 @@ uri = "tcp://127.0.0.1:${ET_PORT}"
#[test] #[test]
fn test_mixed_defined_undefined_vars() { fn test_mixed_defined_undefined_vars() {
std::env::set_var("DEFINED_VAR", "defined"); set_env_var("DEFINED_VAR", "defined");
std::env::remove_var("UNDEFINED_VAR"); remove_env_var("UNDEFINED_VAR");
// 混合已定义和未定义的变量 // 混合已定义和未定义的变量
// shellexpand::env 在遇到未定义变量时会返回错误(默认行为) // shellexpand::env 在遇到未定义变量时会返回错误(默认行为)
@@ -237,7 +238,7 @@ uri = "tcp://127.0.0.1:${ET_PORT}"
#[test] #[test]
fn test_nested_braces() { fn test_nested_braces() {
std::env::set_var("OUTER", "outer_value"); set_env_var("OUTER", "outer_value");
// 嵌套的大括号是无效语法,shellexpand::env 会返回错误 // 嵌套的大括号是无效语法,shellexpand::env 会返回错误
let (result, changed) = expand_env_vars("${OUTER} and ${{INNER}}"); let (result, changed) = expand_env_vars("${OUTER} and ${{INNER}}");
// 由于语法错误,整个字符串保持不变 // 由于语法错误,整个字符串保持不变
+2 -2
View File
@@ -1,5 +1,5 @@
use std::{ use std::{
collections::{hash_map::DefaultHasher, HashMap}, collections::{HashMap, hash_map::DefaultHasher},
hash::Hasher, hash::Hasher,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
sync::{Arc, Mutex}, sync::{Arc, Mutex},
@@ -10,11 +10,11 @@ use arc_swap::ArcSwap;
use dashmap::DashMap; use dashmap::DashMap;
use super::{ use super::{
PeerId,
config::{ConfigLoader, Flags}, config::{ConfigLoader, Flags},
netns::NetNS, netns::NetNS,
network::IPCollector, network::IPCollector,
stun::{StunInfoCollector, StunInfoCollectorTrait}, stun::{StunInfoCollector, StunInfoCollectorTrait},
PeerId,
}; };
use crate::{ use crate::{
common::{ common::{
+1 -1
View File
@@ -1,6 +1,6 @@
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use super::{cidr_to_subnet_mask, run_shell_cmd, Error, IfConfiguerTrait}; use super::{Error, IfConfiguerTrait, cidr_to_subnet_mask, run_shell_cmd};
use async_trait::async_trait; use async_trait::async_trait;
use cidr::{Ipv4Inet, Ipv6Inet}; use cidr::{Ipv4Inet, Ipv6Inet};
+6 -6
View File
@@ -10,27 +10,27 @@ use anyhow::Context;
use async_trait::async_trait; use async_trait::async_trait;
use cidr::{IpInet, Ipv4Inet, Ipv6Inet}; use cidr::{IpInet, Ipv4Inet, Ipv6Inet};
use netlink_packet_core::{ use netlink_packet_core::{
NetlinkDeserializable, NetlinkHeader, NetlinkMessage, NetlinkPayload, NetlinkSerializable, NLM_F_ACK, NLM_F_CREATE, NLM_F_DUMP, NLM_F_EXCL, NLM_F_REQUEST, NetlinkDeserializable,
NLM_F_ACK, NLM_F_CREATE, NLM_F_DUMP, NLM_F_EXCL, NLM_F_REQUEST, NetlinkHeader, NetlinkMessage, NetlinkPayload, NetlinkSerializable,
}; };
use netlink_packet_route::{ use netlink_packet_route::{
AddressFamily, RouteNetlinkMessage,
address::{AddressAttribute, AddressMessage}, address::{AddressAttribute, AddressMessage},
route::{ route::{
RouteAddress, RouteAttribute, RouteHeader, RouteMessage, RouteProtocol, RouteScope, RouteAddress, RouteAttribute, RouteHeader, RouteMessage, RouteProtocol, RouteScope,
RouteType, RouteType,
}, },
AddressFamily, RouteNetlinkMessage,
}; };
use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr}; use netlink_sys::{Socket, SocketAddr, protocols::NETLINK_ROUTE};
use nix::{ use nix::{
ifaddrs::getifaddrs, ifaddrs::getifaddrs,
libc::{self, ifreq, ioctl, Ioctl, SIOCGIFFLAGS, SIOCGIFMTU, SIOCSIFFLAGS, SIOCSIFMTU}, libc::{self, Ioctl, SIOCGIFFLAGS, SIOCGIFMTU, SIOCSIFFLAGS, SIOCSIFMTU, ifreq, ioctl},
net::if_::InterfaceFlags, net::if_::InterfaceFlags,
sys::socket::SockaddrLike as _, sys::socket::SockaddrLike as _,
}; };
use pnet::ipnetwork::ip_mask_to_prefix; use pnet::ipnetwork::ip_mask_to_prefix;
use super::{route::Route, Error, IfConfiguerTrait}; use super::{Error, IfConfiguerTrait, route::Route};
pub(crate) fn dummy_socket() -> Result<std::net::UdpSocket, Error> { pub(crate) fn dummy_socket() -> Result<std::net::UdpSocket, Error> {
Ok(std::net::UdpSocket::bind("0:0")?) Ok(std::net::UdpSocket::bind("0:0")?)
+1 -5
View File
@@ -740,10 +740,6 @@ impl InterfaceLuid {
// SAFETY: TODO // SAFETY: TODO
let ret = unsafe { SetIpInterfaceEntry(&mut row) }; let ret = unsafe { SetIpInterfaceEntry(&mut row) };
if NO_ERROR == ret { if NO_ERROR == ret { Ok(()) } else { Err(ret) }
Ok(())
} else {
Err(ret)
}
} }
} }
+5 -5
View File
@@ -10,14 +10,14 @@ use std::{
}; };
use windows_sys::Win32::{ use windows_sys::Win32::{
Foundation::NO_ERROR, Foundation::NO_ERROR,
NetworkManagement::IpHelper::{GetIfEntry, SetIfEntry, MIB_IFROW}, NetworkManagement::IpHelper::{GetIfEntry, MIB_IFROW, SetIfEntry},
System::Diagnostics::Debug::{ System::Diagnostics::Debug::{
FormatMessageW, FORMAT_MESSAGE_FROM_SYSTEM, FORMAT_MESSAGE_IGNORE_INSERTS, FORMAT_MESSAGE_FROM_SYSTEM, FORMAT_MESSAGE_IGNORE_INSERTS, FormatMessageW,
}, },
}; };
use winreg::{ use winreg::{
enums::{HKEY_LOCAL_MACHINE, KEY_READ, KEY_WRITE},
RegKey, RegKey,
enums::{HKEY_LOCAL_MACHINE, KEY_READ, KEY_WRITE},
}; };
use super::{Error, IfConfiguerTrait}; use super::{Error, IfConfiguerTrait};
@@ -331,7 +331,7 @@ impl RegistryManager {
r"SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces\Tcpip_"; r"SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces\Tcpip_";
pub fn reg_delete_obsoleted_items(dev_name: &str) -> io::Result<()> { pub fn reg_delete_obsoleted_items(dev_name: &str) -> io::Result<()> {
use winreg::{enums::HKEY_LOCAL_MACHINE, enums::KEY_ALL_ACCESS, RegKey}; use winreg::{RegKey, enums::HKEY_LOCAL_MACHINE, enums::KEY_ALL_ACCESS};
let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); let hklm = RegKey::predef(HKEY_LOCAL_MACHINE);
let profiles_key = hklm.open_subkey_with_flags( let profiles_key = hklm.open_subkey_with_flags(
"SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\NetworkList\\Profiles", "SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\NetworkList\\Profiles",
@@ -405,7 +405,7 @@ impl RegistryManager {
} }
pub fn reg_change_catrgory_in_profile(dev_name: &str) -> io::Result<()> { pub fn reg_change_catrgory_in_profile(dev_name: &str) -> io::Result<()> {
use winreg::{enums::HKEY_LOCAL_MACHINE, enums::KEY_ALL_ACCESS, RegKey}; use winreg::{RegKey, enums::HKEY_LOCAL_MACHINE, enums::KEY_ALL_ACCESS};
let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); let hklm = RegKey::predef(HKEY_LOCAL_MACHINE);
let profiles_key = hklm.open_subkey_with_flags( let profiles_key = hklm.open_subkey_with_flags(
"SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\NetworkList\\Profiles", "SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\NetworkList\\Profiles",
+2 -2
View File
@@ -9,11 +9,11 @@ use paste::paste;
use regex::Regex; use regex::Regex;
use tracing::level_filters::LevelFilter; use tracing::level_filters::LevelFilter;
use tracing::{Level, Metadata}; use tracing::{Level, Metadata};
use tracing_subscriber::filter::{filter_fn, FilterExt}; use tracing_subscriber::Registry;
use tracing_subscriber::filter::{FilterExt, filter_fn};
use tracing_subscriber::fmt::layer; use tracing_subscriber::fmt::layer;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::Registry;
use tracing_subscriber::{EnvFilter, Layer}; use tracing_subscriber::{EnvFilter, Layer};
macro_rules! __log__ { macro_rules! __log__ {
+6 -6
View File
@@ -41,8 +41,8 @@ pub fn get_logger_timer<F: time::formatting::Formattable>(
tracing_subscriber::fmt::time::OffsetTime::new(local_offset, format) tracing_subscriber::fmt::time::OffsetTime::new(local_offset, format)
} }
pub fn get_logger_timer_rfc3339( pub fn get_logger_timer_rfc3339()
) -> tracing_subscriber::fmt::time::OffsetTime<time::format_description::well_known::Rfc3339> { -> tracing_subscriber::fmt::time::OffsetTime<time::format_description::well_known::Rfc3339> {
get_logger_timer(time::format_description::well_known::Rfc3339) get_logger_timer(time::format_description::well_known::Rfc3339)
} }
@@ -117,10 +117,10 @@ pub fn get_machine_id() -> uuid::Uuid {
.unwrap_or_else(|_| std::path::PathBuf::from("et_machine_id")); .unwrap_or_else(|_| std::path::PathBuf::from("et_machine_id"));
// try load from local file // try load from local file
if let Ok(mid) = std::fs::read_to_string(&machine_id_file) { if let Ok(mid) = std::fs::read_to_string(&machine_id_file)
if let Ok(mid) = uuid::Uuid::parse_str(mid.trim()) { && let Ok(mid) = uuid::Uuid::parse_str(mid.trim())
return mid; {
} return mid;
} }
#[cfg(any( #[cfg(any(
+1 -1
View File
@@ -1,7 +1,7 @@
use futures::Future; use futures::Future;
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
use nix::sched::{setns, CloneFlags}; use nix::sched::{CloneFlags, setns};
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
use std::os::fd::AsFd; use std::os::fd::AsFd;
+17 -7
View File
@@ -374,7 +374,9 @@ impl UnsafeCounter {
/// that no other thread is accessing this counter simultaneously. /// that no other thread is accessing this counter simultaneously.
pub unsafe fn add(&self, delta: u64) { pub unsafe fn add(&self, delta: u64) {
let ptr = self.value.get(); let ptr = self.value.get();
*ptr = (*ptr).saturating_add(delta); unsafe {
*ptr = (*ptr).saturating_add(delta);
}
} }
/// Increment the counter by 1 /// Increment the counter by 1
@@ -382,7 +384,9 @@ impl UnsafeCounter {
/// This method is unsafe because it uses UnsafeCell. The caller must ensure /// This method is unsafe because it uses UnsafeCell. The caller must ensure
/// that no other thread is accessing this counter simultaneously. /// that no other thread is accessing this counter simultaneously.
pub unsafe fn inc(&self) { pub unsafe fn inc(&self) {
self.add(1); unsafe {
self.add(1);
}
} }
/// Get the current value of the counter /// Get the current value of the counter
@@ -391,7 +395,7 @@ impl UnsafeCounter {
/// that no other thread is modifying this counter simultaneously. /// that no other thread is modifying this counter simultaneously.
pub unsafe fn get(&self) -> u64 { pub unsafe fn get(&self) -> u64 {
let ptr = self.value.get(); let ptr = self.value.get();
*ptr unsafe { *ptr }
} }
/// Reset the counter to zero /// Reset the counter to zero
@@ -400,7 +404,9 @@ impl UnsafeCounter {
/// that no other thread is accessing this counter simultaneously. /// that no other thread is accessing this counter simultaneously.
pub unsafe fn reset(&self) { pub unsafe fn reset(&self) {
let ptr = self.value.get(); let ptr = self.value.get();
*ptr = 0; unsafe {
*ptr = 0;
}
} }
/// Set the counter to a specific value /// Set the counter to a specific value
@@ -409,7 +415,9 @@ impl UnsafeCounter {
/// that no other thread is accessing this counter simultaneously. /// that no other thread is accessing this counter simultaneously.
pub unsafe fn set(&self, value: u64) { pub unsafe fn set(&self, value: u64) {
let ptr = self.value.get(); let ptr = self.value.get();
*ptr = value; unsafe {
*ptr = value;
}
} }
} }
@@ -446,7 +454,9 @@ impl MetricData {
/// that no other thread is accessing this timestamp simultaneously. /// that no other thread is accessing this timestamp simultaneously.
unsafe fn touch(&self) { unsafe fn touch(&self) {
let ptr = self.last_updated.get(); let ptr = self.last_updated.get();
*ptr = Instant::now(); unsafe {
*ptr = Instant::now();
}
} }
/// Get the last updated timestamp /// Get the last updated timestamp
@@ -455,7 +465,7 @@ impl MetricData {
/// that no other thread is modifying this timestamp simultaneously. /// that no other thread is modifying this timestamp simultaneously.
unsafe fn get_last_updated(&self) -> Instant { unsafe fn get_last_updated(&self) -> Instant {
let ptr = self.last_updated.get(); let ptr = self.last_updated.get();
*ptr unsafe { *ptr }
} }
} }
+7 -7
View File
@@ -11,8 +11,8 @@ use crossbeam::atomic::AtomicCell;
use rand::seq::IteratorRandom; use rand::seq::IteratorRandom;
use socket2::{SockAddr, SockRef}; use socket2::{SockAddr, SockRef};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{lookup_host, UdpSocket}; use tokio::net::{UdpSocket, lookup_host};
use tokio::sync::{broadcast, Mutex}; use tokio::sync::{Mutex, broadcast};
use tokio::task::JoinSet; use tokio::task::JoinSet;
use tracing::{Instrument, Level}; use tracing::{Instrument, Level};
@@ -1340,7 +1340,7 @@ impl StunInfoCollectorTrait for MockStunInfoCollector {
mod tests { mod tests {
use crate::{ use crate::{
common::scoped_task::ScopedTask, common::scoped_task::ScopedTask,
tunnel::{udp::UdpTunnelListener, TunnelListener}, tunnel::{TunnelListener, udp::UdpTunnelListener},
}; };
use tokio::time::{sleep, timeout}; use tokio::time::{sleep, timeout};
@@ -1404,10 +1404,10 @@ mod tests {
loop { loop {
let ret = detector.detect_nat_type(0).await; let ret = detector.detect_nat_type(0).await;
println!("{:#?}, {:?}", ret, ret.as_ref().map(|x| x.nat_type())); println!("{:#?}, {:?}", ret, ret.as_ref().map(|x| x.nat_type()));
if let Ok(resp) = ret { if let Ok(resp) = ret
if !resp.stun_resps.is_empty() { && !resp.stun_resps.is_empty()
return; {
} return;
} }
sleep(Duration::from_secs(1)).await; sleep(Duration::from_secs(1)).await;
} }
+2 -2
View File
@@ -1,13 +1,13 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use bytecodec::fixnum::{U32beDecoder, U32beEncoder}; use bytecodec::fixnum::{U32beDecoder, U32beEncoder};
use stun_codec::net::{socket_addr_xor, SocketAddrDecoder, SocketAddrEncoder}; use stun_codec::net::{SocketAddrDecoder, SocketAddrEncoder, socket_addr_xor};
use stun_codec::rfc5389::attributes::{ use stun_codec::rfc5389::attributes::{
MappedAddress, Software, XorMappedAddress, XorMappedAddress2, MappedAddress, Software, XorMappedAddress, XorMappedAddress2,
}; };
use stun_codec::rfc5780::attributes::{OtherAddress, ResponseOrigin}; use stun_codec::rfc5780::attributes::{OtherAddress, ResponseOrigin};
use stun_codec::{define_attribute_enums, AttributeType, Message, TransactionId}; use stun_codec::{AttributeType, Message, TransactionId, define_attribute_enums};
use bytecodec::{ByteCount, Decode, Encode, Eos, Result, SizedEncode, TryTaggedDecode}; use bytecodec::{ByteCount, Decode, Encode, Eos, Result, SizedEncode, TryTaggedDecode};
+1 -1
View File
@@ -231,7 +231,7 @@ mod tests {
}; };
use super::*; use super::*;
use tokio::time::{sleep, Duration}; use tokio::time::{Duration, sleep};
/// Test initial state after creation /// Test initial state after creation
#[tokio::test] #[tokio::test]
@@ -57,17 +57,16 @@ impl Default for RollingConditionBase {
impl RollingCondition for RollingConditionBase { impl RollingCondition for RollingConditionBase {
fn should_rollover(&mut self, now: &DateTime<Local>, current_filesize: u64) -> bool { fn should_rollover(&mut self, now: &DateTime<Local>, current_filesize: u64) -> bool {
let mut rollover = false; let mut rollover = false;
if let Some(frequency) = self.frequency_opt.as_ref() { if let Some(frequency) = self.frequency_opt.as_ref()
if let Some(last_write) = self.last_write_opt.as_ref() { && let Some(last_write) = self.last_write_opt.as_ref()
if frequency.equivalent_datetime(now) != frequency.equivalent_datetime(last_write) { && frequency.equivalent_datetime(now) != frequency.equivalent_datetime(last_write)
rollover = true; {
} rollover = true;
}
} }
if let Some(max_size) = self.max_size_opt.as_ref() { if let Some(max_size) = self.max_size_opt.as_ref()
if current_filesize >= *max_size { && current_filesize >= *max_size
rollover = true; {
} rollover = true;
} }
self.last_write_opt = Some(*now); self.last_write_opt = Some(*now);
rollover rollover
@@ -81,11 +81,7 @@ where
/// Determines the final filename, where n==0 indicates the current file /// Determines the final filename, where n==0 indicates the current file
fn filename_for(&self, n: usize) -> String { fn filename_for(&self, n: usize) -> String {
let f = self.filename.clone(); let f = self.filename.clone();
if n > 0 { if n > 0 { format!("{}.{}", f, n) } else { f }
format!("{}.{}", f, n)
} else {
f
}
} }
/// Rotates old files to make room for a new one. /// Rotates old files to make room for a new one.
@@ -145,14 +141,14 @@ where
/// Writes data using the given datetime to calculate the rolling condition /// Writes data using the given datetime to calculate the rolling condition
pub fn write_with_datetime(&mut self, buf: &[u8], now: &DateTime<Local>) -> io::Result<usize> { pub fn write_with_datetime(&mut self, buf: &[u8], now: &DateTime<Local>) -> io::Result<usize> {
if self.condition.should_rollover(now, self.current_filesize) { if self.condition.should_rollover(now, self.current_filesize)
if let Err(e) = self.rollover() { && let Err(e) = self.rollover()
// If we can't rollover, just try to continue writing anyway {
// (better than missing data). // If we can't rollover, just try to continue writing anyway
// This will likely used to implement logging, so // (better than missing data).
// avoid using log::warn and log to stderr directly // This will likely used to implement logging, so
eprintln!("WARNING: Failed to rotate logfile {}: {}", self.filename, e); // avoid using log::warn and log to stderr directly
} eprintln!("WARNING: Failed to rotate logfile {}: {}", self.filename, e);
} }
self.open_writer_if_needed()?; self.open_writer_if_needed()?;
if let Some(writer) = self.writer_opt.as_mut() { if let Some(writer) = self.writer_opt.as_mut() {
+15 -17
View File
@@ -5,16 +5,16 @@ use std::{
net::{IpAddr, Ipv6Addr, SocketAddr}, net::{IpAddr, Ipv6Addr, SocketAddr},
str::FromStr, str::FromStr,
sync::{ sync::{
atomic::{AtomicBool, Ordering},
Arc, Arc,
atomic::{AtomicBool, Ordering},
}, },
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use crate::{ use crate::{
common::{ common::{
dns::socket_addrs, error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId, dns::socket_addrs, error::Error, global_ctx::ArcGlobalCtx,
PeerId, stun::StunInfoCollectorTrait,
}, },
connector::udp_hole_punch::handle_rpc_result, connector::udp_hole_punch::handle_rpc_result,
peers::{ peers::{
@@ -31,7 +31,7 @@ use crate::{
}, },
rpc_types::controller::BaseController, rpc_types::controller::BaseController,
}, },
tunnel::{matches_protocol, udp::UdpTunnelConnector, IpVersion}, tunnel::{IpVersion, matches_protocol, udp::UdpTunnelConnector},
use_global_var, use_global_var,
}; };
@@ -39,7 +39,7 @@ use super::{
create_connector_by_url, should_background_p2p_with_peer, should_try_p2p_with_peer, create_connector_by_url, should_background_p2p_with_peer, should_try_p2p_with_peer,
udp_hole_punch, udp_hole_punch,
}; };
use crate::tunnel::{matches_scheme, FromUrl, IpScheme, TunnelScheme}; use crate::tunnel::{FromUrl, IpScheme, TunnelScheme, matches_scheme};
use anyhow::Context; use anyhow::Context;
use rand::Rng; use rand::Rng;
use socket2::Protocol; use socket2::Protocol;
@@ -769,12 +769,9 @@ mod tests {
let port = if proto == "wg" { 11040 } else { 11041 }; let port = if proto == "wg" { 11040 } else { 11041 };
if !ipv6 { if !ipv6 {
p_c.get_global_ctx().config.set_listeners(vec![format!( p_c.get_global_ctx().config.set_listeners(vec![
"{}://0.0.0.0:{}", format!("{}://0.0.0.0:{}", proto, port).parse().unwrap(),
proto, port ]);
)
.parse()
.unwrap()]);
} else { } else {
p_c.get_global_ctx() p_c.get_global_ctx()
.config .config
@@ -814,11 +811,12 @@ mod tests {
.await .await
.unwrap(); .unwrap();
assert!(data assert!(
.dst_listener_blacklist data.dst_listener_blacklist
.contains(&DstListenerUrlBlackListItem( .contains(&DstListenerUrlBlackListItem(
1, 1,
"tcp://127.0.0.1:10222".parse().unwrap() "tcp://127.0.0.1:10222".parse().unwrap()
))); ))
);
} }
} }
+2 -2
View File
@@ -3,7 +3,7 @@ use std::{net::SocketAddr, sync::Arc};
use super::{create_connector_by_url, http_connector::TunnelWithInfo}; use super::{create_connector_by_url, http_connector::TunnelWithInfo};
use crate::{ use crate::{
common::{ common::{
dns::{resolve_txt_record, RESOLVER}, dns::{RESOLVER, resolve_txt_record},
error::Error, error::Error,
global_ctx::ArcGlobalCtx, global_ctx::ArcGlobalCtx,
log, log,
@@ -14,7 +14,7 @@ use crate::{
use anyhow::Context; use anyhow::Context;
use dashmap::DashSet; use dashmap::DashSet;
use hickory_resolver::proto::rr::rdata::SRV; use hickory_resolver::proto::rr::rdata::SRV;
use rand::{seq::SliceRandom, Rng as _}; use rand::{Rng as _, seq::SliceRandom};
use strum::VariantArray; use strum::VariantArray;
fn weighted_choice<T>(options: &[(T, u64)]) -> Option<&T> { fn weighted_choice<T>(options: &[(T, u64)]) -> Option<&T> {
+2 -2
View File
@@ -10,9 +10,9 @@ use rand::seq::SliceRandom as _;
use url::Url; use url::Url;
use crate::{ use crate::{
VERSION,
common::{error::Error, global_ctx::ArcGlobalCtx}, common::{error::Error, global_ctx::ArcGlobalCtx},
tunnel::{IpVersion, Tunnel, TunnelConnector, TunnelError, ZCPacketSink, ZCPacketStream}, tunnel::{IpVersion, Tunnel, TunnelConnector, TunnelError, ZCPacketSink, ZCPacketStream},
VERSION,
}; };
use crate::proto::common::TunnelInfo; use crate::proto::common::TunnelInfo;
@@ -257,7 +257,7 @@ mod tests {
use crate::{ use crate::{
common::global_ctx::tests::get_mock_global_ctx_with_network, common::global_ctx::tests::get_mock_global_ctx_with_network,
tunnel::{tcp::TcpTunnelListener, TunnelConnector, TunnelListener}, tunnel::{TunnelConnector, TunnelListener, tcp::TcpTunnelListener},
}; };
use super::*; use super::*;
+1 -1
View File
@@ -7,7 +7,7 @@ use dashmap::DashSet;
use tokio::{sync::mpsc, task::JoinSet, time::timeout}; use tokio::{sync::mpsc, task::JoinSet, time::timeout};
use crate::{ use crate::{
common::{dns::socket_addrs, join_joinset_background, PeerId}, common::{PeerId, dns::socket_addrs, join_joinset_background},
peers::peer_conn::PeerConnId, peers::peer_conn::PeerConnId,
proto::{ proto::{
api::instance::{ api::instance::{
+2 -2
View File
@@ -8,8 +8,8 @@ use crate::{
connector::dns_connector::DnsTunnelConnector, connector::dns_connector::DnsTunnelConnector,
proto::common::PeerFeatureFlag, proto::common::PeerFeatureFlag,
tunnel::{ tunnel::{
self, ring::RingTunnelConnector, tcp::TcpTunnelConnector, udp::UdpTunnelConnector, FromUrl, self, FromUrl, IpScheme, IpVersion, TunnelConnector, TunnelError, TunnelScheme,
IpScheme, IpVersion, TunnelConnector, TunnelError, TunnelScheme, ring::RingTunnelConnector, tcp::TcpTunnelConnector, udp::UdpTunnelConnector,
}, },
utils::BoxExt, utils::BoxExt,
}; };
+26 -20
View File
@@ -9,7 +9,7 @@ use rand::Rng as _;
use tokio::task::JoinSet; use tokio::task::JoinSet;
use crate::{ use crate::{
common::{join_joinset_background, stun::StunInfoCollectorTrait, PeerId}, common::{PeerId, join_joinset_background, stun::StunInfoCollectorTrait},
connector::udp_hole_punch::BackOff, connector::udp_hole_punch::BackOff,
peers::{ peers::{
peer_manager::PeerManager, peer_manager::PeerManager,
@@ -24,8 +24,8 @@ use crate::{
rpc_types::{self, controller::BaseController}, rpc_types::{self, controller::BaseController},
}, },
tunnel::{ tunnel::{
tcp::{TcpTunnelConnector, TcpTunnelListener},
TunnelConnector as _, TunnelListener as _, TunnelConnector as _, TunnelListener as _,
tcp::{TcpTunnelConnector, TcpTunnelListener},
}, },
}; };
@@ -719,18 +719,20 @@ mod tests {
tokio::time::sleep(Duration::from_secs(2)).await; tokio::time::sleep(Duration::from_secs(2)).await;
assert!(p_a assert!(
.get_peer_map() p_a.get_peer_map()
.list_peer_conns(p_c.my_peer_id()) .list_peer_conns(p_c.my_peer_id())
.await .await
.map(|c| c.is_empty()) .map(|c| c.is_empty())
.unwrap_or(true)); .unwrap_or(true)
assert!(p_c );
.get_peer_map() assert!(
.list_peer_conns(p_a.my_peer_id()) p_c.get_peer_map()
.await .list_peer_conns(p_a.my_peer_id())
.map(|c| c.is_empty()) .await
.unwrap_or(true)); .map(|c| c.is_empty())
.unwrap_or(true)
);
} }
#[tokio::test] #[tokio::test]
@@ -751,14 +753,18 @@ mod tests {
connect_peer_manager(p_b.clone(), p_c.clone()).await; connect_peer_manager(p_b.clone(), p_c.clone()).await;
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
assert!(!collect_lazy_punch_peers(p_a.clone()) assert!(
.await !collect_lazy_punch_peers(p_a.clone())
.contains(&p_c.my_peer_id())); .await
.contains(&p_c.my_peer_id())
);
p_a.mark_recent_traffic(p_c.my_peer_id()); p_a.mark_recent_traffic(p_c.my_peer_id());
assert!(collect_lazy_punch_peers(p_a.clone()) assert!(
.await collect_lazy_punch_peers(p_a.clone())
.contains(&p_c.my_peer_id())); .await
.contains(&p_c.my_peer_id())
);
} }
} }
@@ -8,9 +8,9 @@ use anyhow::Context;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::{ use crate::{
common::{scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId}, common::{PeerId, scoped_task::ScopedTask, stun::StunInfoCollectorTrait},
connector::udp_hole_punch::common::{ connector::udp_hole_punch::common::{
try_connect_with_socket, UdpHolePunchListener, HOLE_PUNCH_PACKET_BODY_LEN, HOLE_PUNCH_PACKET_BODY_LEN, UdpHolePunchListener, try_connect_with_socket,
}, },
connector::udp_hole_punch::handle_rpc_result, connector::udp_hole_punch::handle_rpc_result,
peers::peer_manager::PeerManager, peers::peer_manager::PeerManager,
@@ -21,7 +21,7 @@ use crate::{
}, },
rpc_types::{self, controller::BaseController}, rpc_types::{self, controller::BaseController},
}, },
tunnel::{udp::new_hole_punch_packet, Tunnel}, tunnel::{Tunnel, udp::new_hole_punch_packet},
}; };
use super::common::{PunchHoleServerCommon, UdpNatType, UdpSocketArray}; use super::common::{PunchHoleServerCommon, UdpNatType, UdpSocketArray};
@@ -340,7 +340,7 @@ impl PunchBothEasySymHoleClient {
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use std::{ use std::{
sync::{atomic::AtomicU32, Arc}, sync::{Arc, atomic::AtomicU32},
time::Duration, time::Duration,
}; };
@@ -349,7 +349,7 @@ pub mod tests {
use crate::connector::udp_hole_punch::RUN_TESTING; use crate::connector::udp_hole_punch::RUN_TESTING;
use crate::{ use crate::{
connector::udp_hole_punch::{ connector::udp_hole_punch::{
tests::create_mock_peer_manager_with_mock_stun, UdpHolePunchConnector, UdpHolePunchConnector, tests::create_mock_peer_manager_with_mock_stun,
}, },
peers::tests::{connect_peer_manager, wait_route_appear}, peers::tests::{connect_peer_manager, wait_route_appear},
proto::common::NatType, proto::common::NatType,
@@ -8,21 +8,21 @@ use crossbeam::atomic::AtomicCell;
use dashmap::{DashMap, DashSet}; use dashmap::{DashMap, DashSet};
use rand::seq::SliceRandom as _; use rand::seq::SliceRandom as _;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet}; use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use tracing::{instrument, Instrument, Level}; use tracing::{Instrument, Level, instrument};
use zerocopy::FromBytes as _; use zerocopy::FromBytes as _;
use crate::{ use crate::{
common::{ common::{
error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS, PeerId, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS,
stun::StunInfoCollectorTrait as _, PeerId, stun::StunInfoCollectorTrait as _,
}, },
defer, defer,
peers::peer_manager::PeerManager, peers::peer_manager::PeerManager,
proto::common::NatType, proto::common::NatType,
tunnel::{ tunnel::{
packet_def::{UDPTunnelHeader, UdpPacketType, UDP_TUNNEL_HEADER_SIZE},
udp::{new_hole_punch_packet, UdpTunnelConnector, UdpTunnelListener},
Tunnel, TunnelConnCounter, TunnelListener as _, Tunnel, TunnelConnCounter, TunnelListener as _,
packet_def::{UDP_TUNNEL_HEADER_SIZE, UDPTunnelHeader, UdpPacketType},
udp::{UdpTunnelConnector, UdpTunnelListener, new_hole_punch_packet},
}, },
}; };
@@ -7,9 +7,9 @@ use anyhow::Context;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use crate::{ use crate::{
common::{scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId}, common::{PeerId, scoped_task::ScopedTask, stun::StunInfoCollectorTrait},
connector::udp_hole_punch::common::{ connector::udp_hole_punch::common::{
try_connect_with_socket, UdpSocketArray, HOLE_PUNCH_PACKET_BODY_LEN, HOLE_PUNCH_PACKET_BODY_LEN, UdpSocketArray, try_connect_with_socket,
}, },
connector::udp_hole_punch::handle_rpc_result, connector::udp_hole_punch::handle_rpc_result,
peers::peer_manager::PeerManager, peers::peer_manager::PeerManager,
@@ -20,7 +20,7 @@ use crate::{
}, },
rpc_types::{self, controller::BaseController}, rpc_types::{self, controller::BaseController},
}, },
tunnel::{udp::new_hole_punch_packet, Tunnel}, tunnel::{Tunnel, udp::new_hole_punch_packet},
}; };
use super::common::PunchHoleServerCommon; use super::common::PunchHoleServerCommon;
@@ -249,7 +249,7 @@ pub mod tests {
use crate::{ use crate::{
connector::udp_hole_punch::{ connector::udp_hole_punch::{
tests::create_mock_peer_manager_with_mock_stun, UdpHolePunchConnector, UdpHolePunchConnector, tests::create_mock_peer_manager_with_mock_stun,
}, },
peers::tests::{connect_peer_manager, wait_route_appear, wait_route_appear_with_cost}, peers::tests::{connect_peer_manager, wait_route_appear, wait_route_appear_with_cost},
proto::common::NatType, proto::common::NatType,
+13 -9
View File
@@ -1,5 +1,5 @@
use std::{ use std::{
sync::{atomic::AtomicBool, Arc}, sync::{Arc, atomic::AtomicBool},
time::{Duration, Instant}, time::{Duration, Instant},
}; };
@@ -13,7 +13,7 @@ use sym_to_cone::{PunchSymToConeHoleClient, PunchSymToConeHoleServer};
use tokio::{sync::Mutex, task::JoinHandle}; use tokio::{sync::Mutex, task::JoinHandle};
use crate::{ use crate::{
common::{stun::StunInfoCollectorTrait, PeerId}, common::{PeerId, stun::StunInfoCollectorTrait},
peers::{ peers::{
peer_manager::PeerManager, peer_manager::PeerManager,
peer_task::{PeerTaskLauncher, PeerTaskManager}, peer_task::{PeerTaskLauncher, PeerTaskManager},
@@ -601,7 +601,7 @@ pub mod tests {
use crate::proto::common::NatType; use crate::proto::common::NatType;
use crate::tunnel::common::tests::wait_for_condition; use crate::tunnel::common::tests::wait_for_condition;
use super::{UdpHolePunchConnector, UdpHolePunchPeerTaskLauncher, RUN_TESTING}; use super::{RUN_TESTING, UdpHolePunchConnector, UdpHolePunchPeerTaskLauncher};
pub fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, udp_nat_type: NatType) { pub fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, udp_nat_type: NatType) {
let collector = Box::new(MockStunInfoCollector { udp_nat_type }); let collector = Box::new(MockStunInfoCollector { udp_nat_type });
@@ -676,14 +676,18 @@ pub mod tests {
connect_peer_manager(p_b.clone(), p_c.clone()).await; connect_peer_manager(p_b.clone(), p_c.clone()).await;
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
assert!(!collect_lazy_punch_peers(p_a.clone()) assert!(
.await !collect_lazy_punch_peers(p_a.clone())
.contains(&p_c.my_peer_id())); .await
.contains(&p_c.my_peer_id())
);
p_a.mark_recent_traffic(p_c.my_peer_id()); p_a.mark_recent_traffic(p_c.my_peer_id());
assert!(collect_lazy_punch_peers(p_a.clone()) assert!(
.await collect_lazy_punch_peers(p_a.clone())
.contains(&p_c.my_peer_id())); .await
.contains(&p_c.my_peer_id())
);
} }
} }
@@ -2,24 +2,24 @@ use std::{
net::Ipv4Addr, net::Ipv4Addr,
ops::{Div, Mul}, ops::{Div, Mul},
sync::{ sync::{
atomic::{AtomicBool, Ordering},
Arc, Arc,
atomic::{AtomicBool, Ordering},
}, },
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use anyhow::Context; use anyhow::Context;
use rand::{seq::SliceRandom, Rng}; use rand::{Rng, seq::SliceRandom};
use tokio::{net::UdpSocket, sync::RwLock}; use tokio::{net::UdpSocket, sync::RwLock};
use tracing::Level; use tracing::Level;
use crate::{ use crate::{
common::{ common::{
global_ctx::ArcGlobalCtx, scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId, PeerId, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask, stun::StunInfoCollectorTrait,
}, },
connector::udp_hole_punch::{ connector::udp_hole_punch::{
common::{ common::{
send_symmetric_hole_punch_packet, try_connect_with_socket, HOLE_PUNCH_PACKET_BODY_LEN, HOLE_PUNCH_PACKET_BODY_LEN, send_symmetric_hole_punch_packet, try_connect_with_socket,
}, },
handle_rpc_result, handle_rpc_result,
}, },
@@ -33,7 +33,7 @@ use crate::{
}, },
rpc_types::{self, controller::BaseController}, rpc_types::{self, controller::BaseController},
}, },
tunnel::{udp::new_hole_punch_packet, Tunnel}, tunnel::{Tunnel, udp::new_hole_punch_packet},
}; };
use super::common::{PunchHoleServerCommon, UdpNatType, UdpSocketArray}; use super::common::{PunchHoleServerCommon, UdpNatType, UdpSocketArray};
@@ -445,16 +445,15 @@ impl PunchSymToConeHoleClient {
))?; ))?;
// try direct connect first // try direct connect first
if self.try_direct_connect.load(Ordering::Relaxed) { if self.try_direct_connect.load(Ordering::Relaxed)
if let Ok(tunnel) = try_connect_with_socket( && let Ok(tunnel) = try_connect_with_socket(
global_ctx.clone(), global_ctx.clone(),
Arc::new(UdpSocket::bind("0.0.0.0:0").await?), Arc::new(UdpSocket::bind("0.0.0.0:0").await?),
remote_mapped_addr.into(), remote_mapped_addr.into(),
) )
.await .await
{ {
return Ok(Some(tunnel)); return Ok(Some(tunnel));
}
} }
let stun_info = global_ctx.get_stun_info_collector().get_stun_info(); let stun_info = global_ctx.get_stun_info_collector().get_stun_info();
@@ -467,7 +466,7 @@ impl PunchSymToConeHoleClient {
return Err(anyhow::anyhow!("failed to get public ips")); return Err(anyhow::anyhow!("failed to get public ips"));
} }
let tid = rand::thread_rng().gen(); let tid = rand::thread_rng().r#gen();
let packet = new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(); let packet = new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes();
udp_array.add_intreast_tid(tid); udp_array.add_intreast_tid(tid);
defer! { udp_array.remove_intreast_tid(tid);} defer! { udp_array.remove_intreast_tid(tid);}
@@ -544,7 +543,7 @@ impl PunchSymToConeHoleClient {
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use std::{ use std::{
sync::{atomic::AtomicU32, Arc}, sync::{Arc, atomic::AtomicU32},
time::Duration, time::Duration,
}; };
@@ -552,7 +551,7 @@ pub mod tests {
use crate::{ use crate::{
connector::udp_hole_punch::{ connector::udp_hole_punch::{
tests::create_mock_peer_manager_with_mock_stun, UdpHolePunchConnector, RUN_TESTING, RUN_TESTING, UdpHolePunchConnector, tests::create_mock_peer_manager_with_mock_stun,
}, },
peers::tests::{connect_peer_manager, wait_route_appear, wait_route_appear_with_cost}, peers::tests::{connect_peer_manager, wait_route_appear, wait_route_appear_with_cost},
proto::common::NatType, proto::common::NatType,
@@ -617,7 +616,7 @@ pub mod tests {
.await .await
.is_ok() .is_ok()
}, },
Duration::from_secs(30), Duration::from_secs(60),
) )
.await; .await;
println!("{:?}", p_a.list_routes().await); println!("{:?}", p_a.list_routes().await);
+27 -26
View File
@@ -4,15 +4,16 @@ use std::{
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
path::PathBuf, path::PathBuf,
process::ExitCode, process::ExitCode,
sync::{atomic::AtomicBool, Arc}, sync::{Arc, atomic::AtomicBool},
}; };
use crate::{ use crate::{
ShellType,
common::{ common::{
config::{ config::{
load_config_from_file, process_secure_mode_cfg, ConfigFileControl, ConfigLoader, ConfigFileControl, ConfigLoader, ConsoleLoggerConfig, EncryptionAlgorithm,
ConsoleLoggerConfig, EncryptionAlgorithm, FileLoggerConfig, LoggingConfigLoader, FileLoggerConfig, LoggingConfigLoader, NetworkIdentity, PeerConfig, PortForwardConfig,
NetworkIdentity, PeerConfig, PortForwardConfig, TomlConfigLoader, VpnPortalConfig, TomlConfigLoader, VpnPortalConfig, load_config_from_file, process_secure_mode_cfg,
}, },
constants::EASYTIER_VERSION, constants::EASYTIER_VERSION,
log, log,
@@ -23,7 +24,7 @@ use crate::{
proto::common::{CompressionAlgoPb, SecureModeConfig}, proto::common::{CompressionAlgoPb, SecureModeConfig},
rpc_service::ApiRpcServer, rpc_service::ApiRpcServer,
utils::setup_panic_handler, utils::setup_panic_handler,
web_client, ShellType, web_client,
}; };
use anyhow::Context; use anyhow::Context;
use cidr::IpCidr; use cidr::IpCidr;
@@ -34,7 +35,7 @@ use tokio::io::AsyncReadExt;
use crate::tunnel::IpScheme; use crate::tunnel::IpScheme;
#[cfg(feature = "jemalloc-prof")] #[cfg(feature = "jemalloc-prof")]
use jemalloc_ctl::{epoch, stats, Access as _, AsName as _}; use jemalloc_ctl::{Access as _, AsName as _, epoch, stats};
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
windows_service::define_windows_service!(ffi_service_main, win_service_main); windows_service::define_windows_service!(ffi_service_main, win_service_main);
@@ -741,17 +742,17 @@ impl Cli {
let origin_listeners = listeners; let origin_listeners = listeners;
let mut listeners: Vec<String> = Vec::new(); let mut listeners: Vec<String> = Vec::new();
if origin_listeners.len() == 1 { if origin_listeners.len() == 1
if let Ok(port) = origin_listeners[0].parse::<u16>() { && let Ok(port) = origin_listeners[0].parse::<u16>()
for proto in IpScheme::VARIANTS { {
listeners.push(format!( for proto in IpScheme::VARIANTS {
"{}://0.0.0.0:{}", listeners.push(format!(
proto, "{}://0.0.0.0:{}",
port + proto.port_offset() proto,
)); port + proto.port_offset()
} ));
return Ok(listeners);
} }
return Ok(listeners);
} }
for l in &origin_listeners { for l in &origin_listeners {
@@ -994,15 +995,15 @@ impl NetworkOptions {
local_public_key: None, local_public_key: None,
}; };
cfg.set_secure_mode(Some(process_secure_mode_cfg(c)?)); cfg.set_secure_mode(Some(process_secure_mode_cfg(c)?));
} else if let Some(secure_mode) = self.secure_mode { } else if let Some(secure_mode) = self.secure_mode
if secure_mode { && secure_mode
let c = SecureModeConfig { {
enabled: secure_mode, let c = SecureModeConfig {
local_private_key: self.local_private_key.clone(), enabled: secure_mode,
local_public_key: self.local_public_key.clone(), local_private_key: self.local_private_key.clone(),
}; local_public_key: self.local_public_key.clone(),
cfg.set_secure_mode(Some(process_secure_mode_cfg(c)?)); };
} cfg.set_secure_mode(Some(process_secure_mode_cfg(c)?));
} }
let mut f = cfg.get_flags(); let mut f = cfg.get_flags();
@@ -1134,7 +1135,7 @@ impl LoggingConfigLoader for &LoggingOptions {
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
fn win_service_set_work_dir(service_name: &std::ffi::OsString) -> anyhow::Result<()> { fn win_service_set_work_dir(service_name: &std::ffi::OsString) -> anyhow::Result<()> {
use crate::common::constants::WIN_SERVICE_WORK_DIR_REG_KEY; use crate::common::constants::WIN_SERVICE_WORK_DIR_REG_KEY;
use winreg::{enums::*, RegKey}; use winreg::{RegKey, enums::*};
let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); let hklm = RegKey::predef(HKEY_LOCAL_MACHINE);
let key = hklm.open_subkey_with_flags(WIN_SERVICE_WORK_DIR_REG_KEY, KEY_READ)?; let key = hklm.open_subkey_with_flags(WIN_SERVICE_WORK_DIR_REG_KEY, KEY_READ)?;
+16 -12
View File
@@ -12,8 +12,8 @@ use std::{
}; };
use anyhow::Context; use anyhow::Context;
use base64::prelude::BASE64_STANDARD;
use base64::Engine as _; use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use cidr::Ipv4Inet; use cidr::Ipv4Inet;
use clap::{Args, CommandFactory, Parser, Subcommand}; use clap::{Args, CommandFactory, Parser, Subcommand};
use dashmap::DashMap; use dashmap::DashMap;
@@ -21,8 +21,8 @@ use easytier::ShellType;
use humansize::format_size; use humansize::format_size;
use rust_i18n::t; use rust_i18n::t;
use service_manager::*; use service_manager::*;
use tabled::settings::{location::ByColumnName, object::Columns, Disable, Modify, Style, Width}; use tabled::settings::{Disable, Modify, Style, Width, location::ByColumnName, object::Columns};
use terminal_size::{terminal_size, Width as TerminalWidth}; use terminal_size::{Width as TerminalWidth, terminal_size};
use unicode_width::UnicodeWidthStr; use unicode_width::UnicodeWidthStr;
use easytier::service_manager::{Service, ServiceInstallOptions}; use easytier::service_manager::{Service, ServiceInstallOptions};
@@ -42,9 +42,7 @@ use easytier::{
InstanceConfigPatch, PatchConfigRequest, PortForwardPatch, StringPatch, UrlPatch, InstanceConfigPatch, PatchConfigRequest, PortForwardPatch, StringPatch, UrlPatch,
}, },
instance::{ instance::{
instance_identifier::{InstanceSelector, Selector}, AclManageRpc, AclManageRpcClientFactory, Connector, ConnectorManageRpc,
list_global_foreign_network_response, list_peer_route_pair, AclManageRpc,
AclManageRpcClientFactory, Connector, ConnectorManageRpc,
ConnectorManageRpcClientFactory, CredentialManageRpc, ConnectorManageRpcClientFactory, CredentialManageRpc,
CredentialManageRpcClientFactory, DumpRouteRequest, ForeignNetworkEntryPb, CredentialManageRpcClientFactory, DumpRouteRequest, ForeignNetworkEntryPb,
GenerateCredentialRequest, GetAclStatsRequest, GetPrometheusStatsRequest, GenerateCredentialRequest, GetAclStatsRequest, GetPrometheusStatsRequest,
@@ -60,6 +58,8 @@ use easytier::{
StatsRpc, StatsRpcClientFactory, TcpProxyEntryState, TcpProxyEntryTransportType, StatsRpc, StatsRpcClientFactory, TcpProxyEntryState, TcpProxyEntryTransportType,
TcpProxyRpc, TcpProxyRpcClientFactory, TrustedKeySourcePb, VpnPortalInfo, TcpProxyRpc, TcpProxyRpcClientFactory, TrustedKeySourcePb, VpnPortalInfo,
VpnPortalRpc, VpnPortalRpcClientFactory, VpnPortalRpc, VpnPortalRpcClientFactory,
instance_identifier::{InstanceSelector, Selector},
list_global_foreign_network_response, list_peer_route_pair,
}, },
logger::{ logger::{
GetLoggerConfigRequest, LogLevel, LoggerRpc, LoggerRpcClientFactory, GetLoggerConfigRequest, LogLevel, LoggerRpc, LoggerRpcClientFactory,
@@ -75,8 +75,8 @@ use easytier::{
rpc_impl::standalone::StandAloneClient, rpc_impl::standalone::StandAloneClient,
rpc_types::controller::BaseController, rpc_types::controller::BaseController,
}, },
tunnel::{tcp::TcpTunnelConnector, TunnelScheme}, tunnel::{TunnelScheme, tcp::TcpTunnelConnector},
utils::{cost_to_str, PeerRoutePair}, utils::{PeerRoutePair, cost_to_str},
}; };
rust_i18n::i18n!("locales", fallback = "en"); rust_i18n::i18n!("locales", fallback = "en");
@@ -1972,7 +1972,12 @@ impl<'a> CommandHandler<'a> {
"info" => LogLevel::Info, "info" => LogLevel::Info,
"debug" => LogLevel::Debug, "debug" => LogLevel::Debug,
"trace" => LogLevel::Trace, "trace" => LogLevel::Trace,
_ => return Err(anyhow::anyhow!("Invalid log level: {}. Valid levels are: disabled, error, warning, info, debug, trace", level)), _ => {
return Err(anyhow::anyhow!(
"Invalid log level: {}. Valid levels are: disabled, error, warning, info, debug, trace",
level
));
}
}; };
let client = self.get_logger_client().await?; let client = self.get_logger_client().await?;
@@ -2497,10 +2502,9 @@ fn header_indices(headers: &[String], names: &[&str]) -> Vec<usize> {
if let Some(index) = headers if let Some(index) = headers
.iter() .iter()
.position(|header| header.eq_ignore_ascii_case(name)) .position(|header| header.eq_ignore_ascii_case(name))
&& !indices.contains(&index)
{ {
if !indices.contains(&index) { indices.push(index);
indices.push(index);
}
} }
} }
indices indices
+2 -2
View File
@@ -13,12 +13,12 @@ static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
#[cfg(feature = "jemalloc-prof")] #[cfg(feature = "jemalloc-prof")]
#[allow(non_upper_case_globals)] #[allow(non_upper_case_globals)]
#[export_name = "malloc_conf"] #[unsafe(export_name = "malloc_conf")]
pub static malloc_conf: &[u8] = b"prof:true,prof_active:true,lg_prof_sample:19,retain:false\0"; pub static malloc_conf: &[u8] = b"prof:true,prof_active:true,lg_prof_sample:19,retain:false\0";
#[cfg(not(feature = "jemalloc-prof"))] #[cfg(not(feature = "jemalloc-prof"))]
#[allow(non_upper_case_globals)] #[allow(non_upper_case_globals)]
#[export_name = "malloc_conf"] #[unsafe(export_name = "malloc_conf")]
pub static malloc_conf: &[u8] = b"retain:false\0"; pub static malloc_conf: &[u8] = b"retain:false\0";
rust_i18n::i18n!("locales", fallback = "en"); rust_i18n::i18n!("locales", fallback = "en");
+1 -1
View File
@@ -46,9 +46,9 @@ use anyhow::Context;
use std::fmt; use std::fmt;
use std::io; use std::io;
use thiserror::Error; use thiserror::Error;
use util::target_addr::read_address;
use util::target_addr::TargetAddr; use util::target_addr::TargetAddr;
use util::target_addr::ToTargetAddr; use util::target_addr::ToTargetAddr;
use util::target_addr::read_address;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
+3 -3
View File
@@ -1,10 +1,10 @@
use super::Socks5Command;
use super::new_udp_header; use super::new_udp_header;
use super::parse_udp_request; use super::parse_udp_request;
use super::read_exact; use super::read_exact;
use super::util::stream::tcp_connect_with_timeout; use super::util::stream::tcp_connect_with_timeout;
use super::util::target_addr::{read_address, TargetAddr}; use super::util::target_addr::{TargetAddr, read_address};
use super::Socks5Command; use super::{AuthenticationMethod, ReplyError, Result, SocksError, consts};
use super::{consts, AuthenticationMethod, ReplyError, Result, SocksError};
use anyhow::Context; use anyhow::Context;
use std::io; use std::io;
use std::net::IpAddr; use std::net::IpAddr;
@@ -1,6 +1,6 @@
use crate::gateway::fast_socks5::SocksError;
use crate::gateway::fast_socks5::consts; use crate::gateway::fast_socks5::consts;
use crate::gateway::fast_socks5::consts::SOCKS5_ADDR_TYPE_IPV4; use crate::gateway::fast_socks5::consts::SOCKS5_ADDR_TYPE_IPV4;
use crate::gateway::fast_socks5::SocksError;
use crate::read_exact; use crate::read_exact;
use anyhow::Context; use anyhow::Context;
@@ -99,7 +99,7 @@ impl TargetAddr {
buf.extend_from_slice(&(addr.ip()).octets()); // ip buf.extend_from_slice(&(addr.ip()).octets()); // ip
buf.extend_from_slice(&addr.port().to_be_bytes()); // port buf.extend_from_slice(&addr.port().to_be_bytes()); // port
} }
TargetAddr::Domain(ref domain, port) => { TargetAddr::Domain(domain, port) => {
debug!("TargetAddr::Domain"); debug!("TargetAddr::Domain");
if domain.len() > u8::MAX as usize { if domain.len() > u8::MAX as usize {
return Err(SocksError::ExceededMaxDomainLen(domain.len()).into()); return Err(SocksError::ExceededMaxDomainLen(domain.len()).into());
+6 -6
View File
@@ -8,29 +8,29 @@ use std::{
use anyhow::Context; use anyhow::Context;
use pnet::packet::{ use pnet::packet::{
icmp::{self, echo_reply::MutableEchoReplyPacket, IcmpCode, IcmpTypes, MutableIcmpPacket}, Packet,
icmp::{self, IcmpCode, IcmpTypes, MutableIcmpPacket, echo_reply::MutableEchoReplyPacket},
ip::IpNextHeaderProtocols, ip::IpNextHeaderProtocols,
ipv4::Ipv4Packet, ipv4::Ipv4Packet,
Packet,
}; };
use socket2::Socket; use socket2::Socket;
use tokio::{ use tokio::{
sync::{mpsc::UnboundedSender, Mutex}, sync::{Mutex, mpsc::UnboundedSender},
task::JoinSet, task::JoinSet,
}; };
use tracing::Instrument; use tracing::Instrument;
use crate::{ use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, common::{PeerId, error::Error, global_ctx::ArcGlobalCtx},
gateway::ip_reassembler::ComposeIpv4PacketArgs, gateway::ip_reassembler::ComposeIpv4PacketArgs,
peers::{peer_manager::PeerManager, PeerPacketFilter}, peers::{PeerPacketFilter, peer_manager::PeerManager},
tunnel::packet_def::{PacketType, ZCPacket}, tunnel::packet_def::{PacketType, ZCPacket},
}; };
use super::{ use super::{
ip_reassembler::{compose_ipv4_packet, IpReassembler},
CidrSet, CidrSet,
ip_reassembler::{IpReassembler, compose_ipv4_packet},
}; };
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+15 -3
View File
@@ -1,7 +1,7 @@
use dashmap::DashMap; use dashmap::DashMap;
use pnet::packet::Packet;
use pnet::packet::ip::IpNextHeaderProtocol; use pnet::packet::ip::IpNextHeaderProtocol;
use pnet::packet::ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet}; use pnet::packet::ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet};
use pnet::packet::Packet;
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@@ -45,13 +45,25 @@ impl IpPacket {
// make sure the fragment doesn't overlap with existing fragments // make sure the fragment doesn't overlap with existing fragments
for f in &self.fragments { for f in &self.fragments {
if f.offset <= fragment.offset && fragment.offset < f.offset + f.data.len() as u16 { if f.offset <= fragment.offset && fragment.offset < f.offset + f.data.len() as u16 {
tracing::trace!("fragment overlap 1, f.offset = {}, fragment.offset = {}, f.data.len() = {}, fragment.data.len() = {}", f.offset, fragment.offset, f.data.len(), fragment.data.len()); tracing::trace!(
"fragment overlap 1, f.offset = {}, fragment.offset = {}, f.data.len() = {}, fragment.data.len() = {}",
f.offset,
fragment.offset,
f.data.len(),
fragment.data.len()
);
return; return;
} }
if fragment.offset <= f.offset if fragment.offset <= f.offset
&& f.offset < fragment.offset + fragment.data.len() as u16 && f.offset < fragment.offset + fragment.data.len() as u16
{ {
tracing::trace!("fragment overlap 2, f.offset = {}, fragment.offset = {}, f.data.len() = {}, fragment.data.len() = {}", f.offset, fragment.offset, f.data.len(), fragment.data.len()); tracing::trace!(
"fragment overlap 2, f.offset = {}, fragment.offset = {}, f.data.len() = {}, fragment.data.len() = {}",
f.offset,
fragment.offset,
f.data.len(),
fragment.data.len()
);
return; return;
} }
} }
+2 -2
View File
@@ -17,8 +17,8 @@ use prost::Message;
use tokio::{select, task::JoinSet}; use tokio::{select, task::JoinSet};
use super::{ use super::{
tcp_proxy::{NatDstConnector, NatDstTcpConnector, TcpProxy},
CidrSet, CidrSet,
tcp_proxy::{NatDstConnector, NatDstTcpConnector, TcpProxy},
}; };
use crate::{ use crate::{
common::{ common::{
@@ -27,7 +27,7 @@ use crate::{
global_ctx::{ArcGlobalCtx, GlobalCtx}, global_ctx::{ArcGlobalCtx, GlobalCtx},
}, },
gateway::wrapped_proxy::{ProxyAclHandler, TcpProxyForWrappedSrcTrait}, gateway::wrapped_proxy::{ProxyAclHandler, TcpProxyForWrappedSrcTrait},
peers::{peer_manager::PeerManager, PeerPacketFilter}, peers::{PeerPacketFilter, peer_manager::PeerManager},
proto::{ proto::{
acl::{ChainType, Protocol}, acl::{ChainType, Protocol},
api::instance::{ api::instance::{
+11 -13
View File
@@ -1,11 +1,11 @@
use crate::common::PeerId;
use crate::common::acl_processor::PacketInfo; use crate::common::acl_processor::PacketInfo;
use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx}; use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx};
use crate::common::PeerId; use crate::gateway::CidrSet;
use crate::gateway::tcp_proxy::{NatDstConnector, TcpProxy}; use crate::gateway::tcp_proxy::{NatDstConnector, TcpProxy};
use crate::gateway::wrapped_proxy::{ProxyAclHandler, TcpProxyForWrappedSrcTrait}; use crate::gateway::wrapped_proxy::{ProxyAclHandler, TcpProxyForWrappedSrcTrait};
use crate::gateway::CidrSet;
use crate::peers::peer_manager::PeerManager;
use crate::peers::PeerPacketFilter; use crate::peers::PeerPacketFilter;
use crate::peers::peer_manager::PeerManager;
use crate::proto::acl::{ChainType, Protocol}; use crate::proto::acl::{ChainType, Protocol};
use crate::proto::api::instance::{ use crate::proto::api::instance::{
ListTcpProxyEntryRequest, ListTcpProxyEntryResponse, TcpProxyEntry, TcpProxyEntryState, ListTcpProxyEntryRequest, ListTcpProxyEntryResponse, TcpProxyEntry, TcpProxyEntryState,
@@ -15,10 +15,10 @@ use crate::proto::peer_rpc::KcpConnData as QuicConnData;
use crate::proto::rpc_types; use crate::proto::rpc_types;
use crate::proto::rpc_types::controller::BaseController; use crate::proto::rpc_types::controller::BaseController;
use crate::tunnel::packet_def::{ use crate::tunnel::packet_def::{
PacketType, PeerManagerHeader, ZCPacket, ZCPacketType, TAIL_RESERVED_SIZE, PacketType, PeerManagerHeader, TAIL_RESERVED_SIZE, ZCPacket, ZCPacketType,
}; };
use crate::tunnel::quic::{client_config, endpoint_config, server_config}; use crate::tunnel::quic::{client_config, endpoint_config, server_config};
use anyhow::{anyhow, Context, Error}; use anyhow::{Context, Error, anyhow};
use atomic_refcell::AtomicRefCell; use atomic_refcell::AtomicRefCell;
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use dashmap::DashMap; use dashmap::DashMap;
@@ -36,11 +36,11 @@ use std::ptr::copy_nonoverlapping;
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
use std::task::Poll; use std::task::Poll;
use std::time::Duration; use std::time::Duration;
use tokio::io::{join, AsyncReadExt, Join}; use tokio::io::{AsyncReadExt, Join, join};
use tokio::sync::mpsc::error::TrySendError; use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::mpsc::{Receiver, Sender, channel};
use tokio::task::JoinSet; use tokio::task::JoinSet;
use tokio::time::{timeout, Instant}; use tokio::time::{Instant, timeout};
use tokio::{join, pin, select}; use tokio::{join, pin, select};
use tokio_util::sync::PollSender; use tokio_util::sync::PollSender;
use tracing::{debug, error, info, instrument, trace, warn}; use tracing::{debug, error, info, instrument, trace, warn};
@@ -174,9 +174,7 @@ impl AsyncUdpSocket for QuicSocket {
} }
trace!( trace!(
"{:?} received {:?} bytes from {:?}", "{:?} received {:?} bytes from {:?}",
self.addr, self.addr, len, packet.addr
len,
packet.addr
); );
buf[0..len].copy_from_slice(&packet.payload); buf[0..len].copy_from_slice(&packet.payload);
*meta = RecvMeta { *meta = RecvMeta {
@@ -193,7 +191,7 @@ impl AsyncUdpSocket for QuicSocket {
return Poll::Ready(Err(std::io::Error::new( return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted, std::io::ErrorKind::ConnectionAborted,
"socket closed", "socket closed",
))) )));
} }
Poll::Pending => break, Poll::Pending => break,
} }
@@ -1250,7 +1248,7 @@ mod tests {
// We agree that the first byte of data is (stream_index % 255) // We agree that the first byte of data is (stream_index % 255)
// This ensures stream data is not mixed // This ensures stream data is not mixed
let expected_byte = data[0] as usize; // Get the actual received marker let expected_byte = data[0] as usize; // Get the actual received marker
// Simple check of head and tail here, CRC can be used in production // Simple check of head and tail here, CRC can be used in production
if data[data.len() - 1] != data[0] { if data[data.len() - 1] != data[0] {
panic!("Stream data corruption"); panic!("Stream data corruption");
} }
+23 -26
View File
@@ -2,8 +2,8 @@ use std::{
any::Any, any::Any,
net::{IpAddr, Ipv4Addr, SocketAddr}, net::{IpAddr, Ipv4Addr, SocketAddr},
sync::{ sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, Weak, Arc, Weak,
atomic::{AtomicBool, AtomicUsize, Ordering},
}, },
time::{Duration, Instant}, time::{Duration, Instant},
}; };
@@ -28,7 +28,7 @@ use crate::{
util::stream::tcp_connect_with_timeout, util::stream::tcp_connect_with_timeout,
}, },
ip_reassembler::IpReassembler, ip_reassembler::IpReassembler,
tokio_smoltcp::{channel_device, BufferSize, Net, NetConfig}, tokio_smoltcp::{BufferSize, Net, NetConfig, channel_device},
}, },
tunnel::{ tunnel::{
common::setup_sokcet2, common::setup_sokcet2,
@@ -38,20 +38,20 @@ use crate::{
use anyhow::Context; use anyhow::Context;
use dashmap::DashMap; use dashmap::DashMap;
use pnet::packet::{ use pnet::packet::{
ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, udp::UdpPacket, Packet, Packet, ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, udp::UdpPacket,
}; };
use tokio::{ use tokio::{
io::{AsyncRead, AsyncWrite}, io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpSocket, UdpSocket}, net::{TcpListener, TcpSocket, UdpSocket},
select, select,
sync::{mpsc, Mutex, Notify}, sync::{Mutex, Notify, mpsc},
task::JoinSet, task::JoinSet,
time::timeout, time::timeout,
}; };
use crate::{ use crate::{
common::{error::Error, global_ctx::GlobalCtx}, common::{error::Error, global_ctx::GlobalCtx},
peers::{peer_manager::PeerManager, PeerPacketFilter}, peers::{PeerPacketFilter, peer_manager::PeerManager},
}; };
#[cfg(feature = "kcp")] #[cfg(feature = "kcp")]
@@ -92,12 +92,10 @@ impl AsyncRead for SocksTcpStream {
buf: &mut tokio::io::ReadBuf<'_>, buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> { ) -> std::task::Poll<std::io::Result<()>> {
match self.get_mut() { match self.get_mut() {
SocksTcpStream::Tcp(ref mut stream) => std::pin::Pin::new(stream).poll_read(cx, buf), SocksTcpStream::Tcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
SocksTcpStream::SmolTcp(ref mut stream) => { SocksTcpStream::SmolTcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
std::pin::Pin::new(stream).poll_read(cx, buf)
}
#[cfg(feature = "kcp")] #[cfg(feature = "kcp")]
SocksTcpStream::Kcp(ref mut stream) => std::pin::Pin::new(stream).poll_read(cx, buf), SocksTcpStream::Kcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
} }
} }
} }
@@ -109,12 +107,10 @@ impl AsyncWrite for SocksTcpStream {
buf: &[u8], buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> { ) -> std::task::Poll<Result<usize, std::io::Error>> {
match self.get_mut() { match self.get_mut() {
SocksTcpStream::Tcp(ref mut stream) => std::pin::Pin::new(stream).poll_write(cx, buf), SocksTcpStream::Tcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
SocksTcpStream::SmolTcp(ref mut stream) => { SocksTcpStream::SmolTcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
std::pin::Pin::new(stream).poll_write(cx, buf)
}
#[cfg(feature = "kcp")] #[cfg(feature = "kcp")]
SocksTcpStream::Kcp(ref mut stream) => std::pin::Pin::new(stream).poll_write(cx, buf), SocksTcpStream::Kcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
} }
} }
@@ -123,10 +119,10 @@ impl AsyncWrite for SocksTcpStream {
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> { ) -> std::task::Poll<Result<(), std::io::Error>> {
match self.get_mut() { match self.get_mut() {
SocksTcpStream::Tcp(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx), SocksTcpStream::Tcp(stream) => std::pin::Pin::new(stream).poll_flush(cx),
SocksTcpStream::SmolTcp(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx), SocksTcpStream::SmolTcp(stream) => std::pin::Pin::new(stream).poll_flush(cx),
#[cfg(feature = "kcp")] #[cfg(feature = "kcp")]
SocksTcpStream::Kcp(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx), SocksTcpStream::Kcp(stream) => std::pin::Pin::new(stream).poll_flush(cx),
} }
} }
@@ -135,10 +131,10 @@ impl AsyncWrite for SocksTcpStream {
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> { ) -> std::task::Poll<Result<(), std::io::Error>> {
match self.get_mut() { match self.get_mut() {
SocksTcpStream::Tcp(ref mut stream) => std::pin::Pin::new(stream).poll_shutdown(cx), SocksTcpStream::Tcp(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
SocksTcpStream::SmolTcp(ref mut stream) => std::pin::Pin::new(stream).poll_shutdown(cx), SocksTcpStream::SmolTcp(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
#[cfg(feature = "kcp")] #[cfg(feature = "kcp")]
SocksTcpStream::Kcp(ref mut stream) => std::pin::Pin::new(stream).poll_shutdown(cx), SocksTcpStream::Kcp(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
} }
} }
} }
@@ -284,10 +280,10 @@ impl AsyncTcpConnector for Socks5AutoConnector {
return Err(anyhow::anyhow!("peer manager is dropped").into()); return Err(anyhow::anyhow!("peer manager is dropped").into());
}; };
if let Some(local_addr) = self.smoltcp_net.as_ref().map(|n| n.get_address()) { if let Some(local_addr) = self.smoltcp_net.as_ref().map(|n| n.get_address())
if local_addr == addr.ip() { && local_addr == addr.ip()
addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), addr.port()); {
} addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), addr.port());
} }
if self.smoltcp_net.is_none() if self.smoltcp_net.is_none()
@@ -805,7 +801,8 @@ impl Socks5Server {
Ok((from_client, from_server)) => { Ok((from_client, from_server)) => {
tracing::info!( tracing::info!(
"port forward connection finished: client->server: {} bytes, server->client: {} bytes", "port forward connection finished: client->server: {} bytes, server->client: {} bytes",
from_client, from_server from_client,
from_server
); );
} }
Err(e) => { Err(e) => {
+10 -10
View File
@@ -3,19 +3,19 @@ use cidr::Ipv4Inet;
use core::panic; use core::panic;
use crossbeam::atomic::AtomicCell; use crossbeam::atomic::AtomicCell;
use dashmap::DashMap; use dashmap::DashMap;
use pnet::packet::ip::IpNextHeaderProtocols;
use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet};
use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket, TcpPacket};
use pnet::packet::MutablePacket; use pnet::packet::MutablePacket;
use pnet::packet::Packet; use pnet::packet::Packet;
use pnet::packet::ip::IpNextHeaderProtocols;
use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet};
use pnet::packet::tcp::{MutableTcpPacket, TcpPacket, ipv4_checksum};
use socket2::{SockRef, TcpKeepalive}; use socket2::{SockRef, TcpKeepalive};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use std::sync::atomic::{AtomicBool, AtomicU16}; use std::sync::atomic::{AtomicBool, AtomicU16};
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::io::{copy_bidirectional, AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, copy_bidirectional};
use tokio::net::{TcpListener, TcpSocket, TcpStream}; use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinSet; use tokio::task::JoinSet;
use tokio::time::timeout; use tokio::time::timeout;
use tracing::Instrument; use tracing::Instrument;
@@ -38,7 +38,7 @@ use crate::tunnel::packet_def::{PacketType, PeerManagerHeader, ZCPacket};
use super::CidrSet; use super::CidrSet;
#[cfg(feature = "smoltcp")] #[cfg(feature = "smoltcp")]
use super::tokio_smoltcp::{self, channel_device, Net, NetConfig}; use super::tokio_smoltcp::{self, Net, NetConfig, channel_device};
#[async_trait::async_trait] #[async_trait::async_trait]
pub(crate) trait NatDstConnector: Send + Sync + Clone + 'static { pub(crate) trait NatDstConnector: Send + Sync + Clone + 'static {
@@ -347,10 +347,10 @@ impl<C: NatDstConnector> PeerPacketFilter for TcpProxy<C> {
if let Err(e) = smoltcp_stack_sender.try_send(packet) { if let Err(e) = smoltcp_stack_sender.try_send(packet) {
tracing::error!("send to smoltcp stack failed: {:?}", e); tracing::error!("send to smoltcp stack failed: {:?}", e);
} }
} else if let Some(peer_manager) = self.get_peer_manager() { } else if let Some(peer_manager) = self.get_peer_manager()
if let Err(e) = peer_manager.get_nic_channel().send(packet).await { && let Err(e) = peer_manager.get_nic_channel().send(packet).await
tracing::error!("send to nic failed: {:?}", e); {
} tracing::error!("send to nic failed: {:?}", e);
} }
return None; return None;
} else { } else {
@@ -5,7 +5,7 @@ use std::{
pin::Pin, pin::Pin,
task::{Context, Poll}, task::{Context, Poll},
}; };
use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::mpsc::{Receiver, Sender, channel};
use tokio_util::sync::{PollSendError, PollSender}; use tokio_util::sync::{PollSendError, PollSender};
use super::device::AsyncDevice; use super::device::AsyncDevice;
+1 -1
View File
@@ -6,8 +6,8 @@ use std::{
io, io,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
sync::{ sync::{
atomic::{AtomicU16, Ordering},
Arc, Arc,
atomic::{AtomicU16, Ordering},
}, },
}; };
@@ -2,7 +2,7 @@ use super::{
device::{BufferDevice, Packet}, device::{BufferDevice, Packet},
socket_allocator::{BufferSize, SocketAlloctor}, socket_allocator::{BufferSize, SocketAlloctor},
}; };
use futures::{stream::iter, FutureExt, SinkExt, StreamExt}; use futures::{FutureExt, SinkExt, StreamExt, stream::iter};
use parking_lot::{MappedMutexGuard, Mutex, MutexGuard}; use parking_lot::{MappedMutexGuard, Mutex, MutexGuard};
use smoltcp::{ use smoltcp::{
iface::{Context, Interface, SocketHandle}, iface::{Context, Interface, SocketHandle},
@@ -92,10 +92,10 @@ async fn run(
// wake up all closed sockets (smoltcp seems have a bug that it doesn't wake up closed sockets) // wake up all closed sockets (smoltcp seems have a bug that it doesn't wake up closed sockets)
for (_, socket) in socket_allocator.sockets().lock().iter_mut() { for (_, socket) in socket_allocator.sockets().lock().iter_mut() {
if let Socket::Tcp(tcp) = socket { if let Socket::Tcp(tcp) = socket
if tcp.state() == smoltcp::socket::tcp::State::Closed { && tcp.state() == smoltcp::socket::tcp::State::Closed
tcp.abort(); {
} tcp.abort();
} }
} }
} }
+1 -1
View File
@@ -1,6 +1,6 @@
use super::{reactor::Reactor, socket_allocator::SocketHandle}; use super::{reactor::Reactor, socket_allocator::SocketHandle};
use futures::future::{self, poll_fn}; use futures::future::{self, poll_fn};
use futures::{ready, Stream}; use futures::{Stream, ready};
pub use smoltcp::socket::tcp; pub use smoltcp::socket::tcp;
use smoltcp::socket::udp; use smoltcp::socket::udp;
use smoltcp::wire::{IpAddress, IpEndpoint}; use smoltcp::wire::{IpAddress, IpEndpoint};
@@ -85,9 +85,8 @@ impl SocketAlloctor {
vec![udp::PacketMetadata::EMPTY; self.buffer_size.udp_tx_meta_size], vec![udp::PacketMetadata::EMPTY; self.buffer_size.udp_tx_meta_size],
vec![0; self.buffer_size.udp_tx_size], vec![0; self.buffer_size.udp_tx_size],
); );
let udp = udp::Socket::new(rx_buffer, tx_buffer);
udp udp::Socket::new(rx_buffer, tx_buffer)
} }
} }
+7 -7
View File
@@ -1,6 +1,6 @@
use std::{ use std::{
net::{Ipv4Addr, SocketAddr, SocketAddrV4}, net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::{atomic::AtomicBool, Arc, Weak}, sync::{Arc, Weak, atomic::AtomicBool},
time::Duration, time::Duration,
}; };
@@ -9,12 +9,12 @@ use cidr::Ipv4Inet;
use crossbeam::atomic::AtomicCell; use crossbeam::atomic::AtomicCell;
use dashmap::DashMap; use dashmap::DashMap;
use pnet::packet::{ use pnet::packet::{
Packet,
ip::IpNextHeaderProtocols, ip::IpNextHeaderProtocols,
ipv4::Ipv4Packet, ipv4::Ipv4Packet,
udp::{self, MutableUdpPacket}, udp::{self, MutableUdpPacket},
Packet,
}; };
use tokio::sync::mpsc::{channel, error::TrySendError, Receiver, Sender}; use tokio::sync::mpsc::{Receiver, Sender, channel, error::TrySendError};
use tokio::{ use tokio::{
net::UdpSocket, net::UdpSocket,
sync::Mutex, sync::Mutex,
@@ -25,16 +25,16 @@ use tokio::{
use tracing::Level; use tracing::Level;
use crate::{ use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask, PeerId}, common::{PeerId, error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask},
gateway::ip_reassembler::{compose_ipv4_packet, ComposeIpv4PacketArgs}, gateway::ip_reassembler::{ComposeIpv4PacketArgs, compose_ipv4_packet},
peers::{peer_manager::PeerManager, PeerPacketFilter}, peers::{PeerPacketFilter, peer_manager::PeerManager},
tunnel::{ tunnel::{
common::{reserve_buf, setup_sokcet2}, common::{reserve_buf, setup_sokcet2},
packet_def::{PacketType, ZCPacket}, packet_def::{PacketType, ZCPacket},
}, },
}; };
use super::{ip_reassembler::IpReassembler, CidrSet}; use super::{CidrSet, ip_reassembler::IpReassembler};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct UdpNatKey { struct UdpNatKey {
+3 -3
View File
@@ -4,19 +4,19 @@ use std::{
}; };
use pnet::packet::{ use pnet::packet::{
Packet as _,
ip::IpNextHeaderProtocols, ip::IpNextHeaderProtocols,
ipv4::Ipv4Packet, ipv4::Ipv4Packet,
tcp::{TcpFlags, TcpPacket}, tcp::{TcpFlags, TcpPacket},
Packet as _,
}; };
use tokio::io::{copy_bidirectional, AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite, copy_bidirectional};
use tokio_util::io::InspectReader; use tokio_util::io::InspectReader;
use crate::tunnel::packet_def::{PacketType, PeerManagerHeader}; use crate::tunnel::packet_def::{PacketType, PeerManagerHeader};
use crate::{ use crate::{
common::{acl_processor::PacketInfo, error::Result}, common::{acl_processor::PacketInfo, error::Result},
gateway::tcp_proxy::{NatDstConnector, TcpProxy}, gateway::tcp_proxy::{NatDstConnector, TcpProxy},
peers::{acl_filter::AclFilter, NicPacketFilter}, peers::{NicPacketFilter, acl_filter::AclFilter},
proto::acl::{Action, ChainType}, proto::acl::{Action, ChainType},
tunnel::packet_def::ZCPacket, tunnel::packet_def::ZCPacket,
}; };
+1 -1
View File
@@ -5,11 +5,11 @@ use hickory_proto::rr::LowerName;
use hickory_resolver::config::ResolverOpts; use hickory_resolver::config::ResolverOpts;
use hickory_resolver::name_server::TokioConnectionProvider; use hickory_resolver::name_server::TokioConnectionProvider;
use hickory_resolver::system_conf::read_system_conf; use hickory_resolver::system_conf::read_system_conf;
use hickory_server::ServerFuture;
use hickory_server::authority::{AuthorityObject, Catalog, ZoneType}; use hickory_server::authority::{AuthorityObject, Catalog, ZoneType};
use hickory_server::server::{Request, RequestHandler, ResponseHandler, ResponseInfo}; use hickory_server::server::{Request, RequestHandler, ResponseHandler, ResponseInfo};
use hickory_server::store::forwarder::ForwardConfig; use hickory_server::store::forwarder::ForwardConfig;
use hickory_server::store::{forwarder::ForwardAuthority, in_memory::InMemoryAuthority}; use hickory_server::store::{forwarder::ForwardAuthority, in_memory::InMemoryAuthority};
use hickory_server::ServerFuture;
use std::io; use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::str::FromStr; use std::str::FromStr;
@@ -7,28 +7,28 @@
// all the clients will exit and let the easytier instance to launch a new server instance. // all the clients will exit and let the easytier instance to launch a new server instance.
use super::{ use super::{
MAGIC_DNS_INSTANCE_ADDR,
config::{GeneralConfigBuilder, RunConfigBuilder}, config::{GeneralConfigBuilder, RunConfigBuilder},
server::Server, server::Server,
system_config::{OSConfig, SystemConfig}, system_config::{OSConfig, SystemConfig},
MAGIC_DNS_INSTANCE_ADDR,
}; };
use crate::{ use crate::{
common::{ common::{
ifcfg::{IfConfiger, IfConfiguerTrait},
PeerId, PeerId,
ifcfg::{IfConfiger, IfConfiguerTrait},
}, },
instance::dns_server::{ instance::dns_server::{
config::{Record, RecordBuilder, RecordType}, config::{Record, RecordBuilder, RecordType},
server::build_authority, server::build_authority,
}, },
peers::{peer_manager::PeerManager, NicPacketFilter}, peers::{NicPacketFilter, peer_manager::PeerManager},
proto::{ proto::{
api::instance::Route, api::instance::Route,
common::{TunnelInfo, Void}, common::{TunnelInfo, Void},
magic_dns::{ magic_dns::{
dns_record::{self},
DnsRecord, DnsRecordA, DnsRecordList, GetDnsRecordResponse, HandshakeRequest, DnsRecord, DnsRecordA, DnsRecordList, GetDnsRecordResponse, HandshakeRequest,
HandshakeResponse, MagicDnsServerRpc, MagicDnsServerRpcServer, UpdateDnsRecordRequest, HandshakeResponse, MagicDnsServerRpc, MagicDnsServerRpcServer, UpdateDnsRecordRequest,
dns_record::{self},
}, },
rpc_impl::standalone::{RpcServerHook, StandAloneServer}, rpc_impl::standalone::{RpcServerHook, StandAloneServer},
rpc_types::controller::{BaseController, Controller}, rpc_types::controller::{BaseController, Controller},
@@ -47,11 +47,10 @@ use pnet::packet::icmp::{IcmpTypes, MutableIcmpPacket};
use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::ipv4::Ipv4Packet;
use pnet::packet::udp::UdpPacket; use pnet::packet::udp::UdpPacket;
use pnet::packet::{ use pnet::packet::{
icmp, MutablePacket, Packet, icmp,
ip::IpNextHeaderProtocols, ip::IpNextHeaderProtocols,
ipv4::{self, MutableIpv4Packet}, ipv4::{self, MutableIpv4Packet},
udp::{self, MutableUdpPacket}, udp::{self, MutableUdpPacket},
MutablePacket, Packet,
}; };
use std::net::{SocketAddr, SocketAddrV4}; use std::net::{SocketAddr, SocketAddrV4};
use std::sync::Mutex; use std::sync::Mutex;
@@ -528,18 +527,18 @@ impl MagicDnsServerInstance {
let mut dns_server = Server::new(dns_config); let mut dns_server = Server::new(dns_config);
dns_server.run().await?; dns_server.run().await?;
if !tun_inet.contains(&fake_ip) { if !tun_inet.contains(&fake_ip)
if let Some(tun_dev_name) = &tun_dev { && let Some(tun_dev_name) = &tun_dev
let cost = if cfg!(target_os = "windows") { {
Some(4) let cost = if cfg!(target_os = "windows") {
} else { Some(4)
None } else {
}; None
let ifcfg = IfConfiger {}; };
ifcfg let ifcfg = IfConfiger {};
.add_ipv4_route(tun_dev_name, fake_ip, 32, cost) ifcfg
.await?; .add_ipv4_route(tun_dev_name, fake_ip, 32, cost)
} .await?;
} }
let data = Arc::new(MagicDnsServerInstanceData { let data = Arc::new(MagicDnsServerInstanceData {
@@ -587,13 +586,13 @@ impl MagicDnsServerInstance {
if let Err(e) = ret { if let Err(e) = ret {
tracing::error!("Failed to close system config: {:?}", e); tracing::error!("Failed to close system config: {:?}", e);
} }
if !self.tun_inet.contains(&self.data.fake_ip) { if !self.tun_inet.contains(&self.data.fake_ip)
if let Some(tun_dev_name) = &self.data.tun_dev { && let Some(tun_dev_name) = &self.data.tun_dev
let ifcfg = IfConfiger {}; {
let _ = ifcfg let ifcfg = IfConfiger {};
.remove_ipv4_route(tun_dev_name, self.data.fake_ip, 32) let _ = ifcfg
.await; .remove_ipv4_route(tun_dev_name, self.data.fake_ip, 32)
} .await;
} }
} }
+37 -34
View File
@@ -17,12 +17,12 @@ use tokio::{sync::oneshot, task::JoinSet};
#[cfg(feature = "magic-dns")] #[cfg(feature = "magic-dns")]
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::common::PeerId;
use crate::common::acl_processor::AclRuleBuilder; use crate::common::acl_processor::AclRuleBuilder;
use crate::common::config::ConfigLoader; use crate::common::config::ConfigLoader;
use crate::common::error::Error; use crate::common::error::Error;
use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx, GlobalCtxEvent}; use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx, GlobalCtxEvent};
use crate::common::scoped_task::ScopedTask; use crate::common::scoped_task::ScopedTask;
use crate::common::PeerId;
use crate::connector::direct::DirectConnectorManager; use crate::connector::direct::DirectConnectorManager;
use crate::connector::manual::{ConnectorManagerRpcService, ManualConnectorManager}; use crate::connector::manual::{ConnectorManagerRpcService, ManualConnectorManager};
use crate::connector::tcp_hole_punch::TcpHolePunchConnector; use crate::connector::tcp_hole_punch::TcpHolePunchConnector;
@@ -40,7 +40,7 @@ use crate::peers::peer_manager::{PeerManager, RouteAlgoType};
#[cfg(feature = "tun")] #[cfg(feature = "tun")]
use crate::peers::recv_packet_from_chan; use crate::peers::recv_packet_from_chan;
use crate::peers::rpc_service::PeerManagerRpcService; use crate::peers::rpc_service::PeerManagerRpcService;
use crate::peers::{create_packet_recv_chan, PacketRecvChanReceiver}; use crate::peers::{PacketRecvChanReceiver, create_packet_recv_chan};
use crate::proto::api::config::{ use crate::proto::api::config::{
ConfigPatchAction, ConfigRpc, GetConfigRequest, GetConfigResponse, PatchConfigRequest, ConfigPatchAction, ConfigRpc, GetConfigRequest, GetConfigResponse, PatchConfigRequest,
PatchConfigResponse, PortForwardPatch, PatchConfigResponse, PortForwardPatch,
@@ -63,7 +63,7 @@ use crate::utils::weak_upgrade;
use crate::vpn_portal::{self, VpnPortal}; use crate::vpn_portal::{self, VpnPortal};
#[cfg(feature = "magic-dns")] #[cfg(feature = "magic-dns")]
use super::dns_server::{runner::DnsRunner, MAGIC_DNS_FAKE_IP}; use super::dns_server::{MAGIC_DNS_FAKE_IP, runner::DnsRunner};
use super::listeners::ListenerManager; use super::listeners::ListenerManager;
#[cfg(feature = "socks5")] #[cfg(feature = "socks5")]
@@ -272,11 +272,11 @@ impl InstanceConfigPatcher {
global_ctx.set_hostname(hostname.clone()); global_ctx.set_hostname(hostname.clone());
global_ctx.config.set_hostname(Some(hostname)); global_ctx.config.set_hostname(Some(hostname));
} }
if let Some(ipv4) = patch.ipv4 { if let Some(ipv4) = patch.ipv4
if !global_ctx.config.get_dhcp() { && !global_ctx.config.get_dhcp()
global_ctx.set_ipv4(Some(ipv4.into())); {
global_ctx.config.set_ipv4(Some(ipv4.into())); global_ctx.set_ipv4(Some(ipv4.into()));
} global_ctx.config.set_ipv4(Some(ipv4.into()));
} }
if let Some(ipv6) = patch.ipv6 { if let Some(ipv6) = patch.ipv6 {
global_ctx.set_ipv6(Some(ipv6.into())); global_ctx.set_ipv6(Some(ipv6.into()));
@@ -667,13 +667,13 @@ impl Instance {
packet_recv: Arc<Mutex<PacketRecvChanReceiver>>, packet_recv: Arc<Mutex<PacketRecvChanReceiver>>,
) { ) {
#[cfg(feature = "magic-dns")] #[cfg(feature = "magic-dns")]
if let Some(old_ctx) = arc_nic_ctx.lock().await.take() { if let Some(old_ctx) = arc_nic_ctx.lock().await.take()
if let Some(dns_runner) = old_ctx.magic_dns { && let Some(dns_runner) = old_ctx.magic_dns
dns_runner.dns_runner_cancel_token.cancel(); {
tracing::debug!("cancelling dns runner task"); dns_runner.dns_runner_cancel_token.cancel();
let ret = dns_runner.dns_runner_task.await; tracing::debug!("cancelling dns runner task");
tracing::debug!("dns runner task cancelled, ret: {:?}", ret); let ret = dns_runner.dns_runner_task.await;
} tracing::debug!("dns runner task cancelled, ret: {:?}", ret);
}; };
let mut tasks = JoinSet::new(); let mut tasks = JoinSet::new();
@@ -772,10 +772,11 @@ impl Instance {
let dhcp_inet = used_ipv4.iter().next().unwrap_or(&default_ipv4_addr); let dhcp_inet = used_ipv4.iter().next().unwrap_or(&default_ipv4_addr);
// if old ip is already in this subnet and not conflicted, use it // if old ip is already in this subnet and not conflicted, use it
if let Some(ip) = current_dhcp_ip { if let Some(ip) = current_dhcp_ip
if ip.network() == dhcp_inet.network() && !used_ipv4.contains(&ip) { && ip.network() == dhcp_inet.network()
continue; && !used_ipv4.contains(&ip)
} {
continue;
} }
// find an available ip in the subnet // find an available ip in the subnet
@@ -1070,7 +1071,9 @@ impl Instance {
self.peer_manager.my_peer_id() self.peer_manager.my_peer_id()
} }
fn get_vpn_portal_rpc_service(&self) -> impl VpnPortalRpc<Controller = BaseController> + Clone { fn get_vpn_portal_rpc_service(
&self,
) -> impl VpnPortalRpc<Controller = BaseController> + Clone + use<> {
#[derive(Clone)] #[derive(Clone)]
struct VpnPortalRpcService { struct VpnPortalRpcService {
peer_mgr: Weak<PeerManager>, peer_mgr: Weak<PeerManager>,
@@ -1115,7 +1118,7 @@ impl Instance {
fn get_mapped_listener_manager_rpc_service( fn get_mapped_listener_manager_rpc_service(
&self, &self,
) -> impl MappedListenerManageRpc<Controller = BaseController> + Clone { ) -> impl MappedListenerManageRpc<Controller = BaseController> + Clone + use<> {
#[derive(Clone)] #[derive(Clone)]
pub struct MappedListenerManagerRpcService(Weak<GlobalCtx>); pub struct MappedListenerManagerRpcService(Weak<GlobalCtx>);
@@ -1146,7 +1149,7 @@ impl Instance {
fn get_port_forward_manager_rpc_service( fn get_port_forward_manager_rpc_service(
&self, &self,
) -> impl PortForwardManageRpc<Controller = BaseController> + Clone { ) -> impl PortForwardManageRpc<Controller = BaseController> + Clone + use<> {
#[derive(Clone)] #[derive(Clone)]
pub struct PortForwardManagerRpcService { pub struct PortForwardManagerRpcService {
global_ctx: Weak<GlobalCtx>, global_ctx: Weak<GlobalCtx>,
@@ -1176,7 +1179,7 @@ impl Instance {
} }
} }
fn get_stats_rpc_service(&self) -> impl StatsRpc<Controller = BaseController> + Clone { fn get_stats_rpc_service(&self) -> impl StatsRpc<Controller = BaseController> + Clone + use<> {
#[derive(Clone)] #[derive(Clone)]
pub struct StatsRpcService { pub struct StatsRpcService {
global_ctx: Weak<GlobalCtx>, global_ctx: Weak<GlobalCtx>,
@@ -1242,7 +1245,7 @@ impl Instance {
} }
} }
fn get_config_service(&self) -> impl ConfigRpc<Controller = BaseController> + Clone { fn get_config_service(&self) -> impl ConfigRpc<Controller = BaseController> + Clone + use<> {
#[derive(Clone)] #[derive(Clone)]
pub struct ConfigRpcService { pub struct ConfigRpcService {
patcher: InstanceConfigPatcher, patcher: InstanceConfigPatcher,
@@ -1285,7 +1288,7 @@ impl Instance {
} }
} }
pub fn get_api_rpc_service(&self) -> impl InstanceRpcService { pub fn get_api_rpc_service(&self) -> impl InstanceRpcService + use<> {
use crate::proto::api::instance::*; use crate::proto::api::instance::*;
#[derive(Clone)] #[derive(Clone)]
@@ -1308,15 +1311,15 @@ impl Instance {
#[async_trait::async_trait] #[async_trait::async_trait]
impl< impl<
A: PeerManageRpc<Controller = BaseController> + Send + Sync, A: PeerManageRpc<Controller = BaseController> + Send + Sync,
B: ConnectorManageRpc<Controller = BaseController> + Send + Sync, B: ConnectorManageRpc<Controller = BaseController> + Send + Sync,
C: MappedListenerManageRpc<Controller = BaseController> + Send + Sync, C: MappedListenerManageRpc<Controller = BaseController> + Send + Sync,
D: VpnPortalRpc<Controller = BaseController> + Send + Sync, D: VpnPortalRpc<Controller = BaseController> + Send + Sync,
E: AclManageRpc<Controller = BaseController> + Send + Sync, E: AclManageRpc<Controller = BaseController> + Send + Sync,
F: PortForwardManageRpc<Controller = BaseController> + Send + Sync, F: PortForwardManageRpc<Controller = BaseController> + Send + Sync,
G: StatsRpc<Controller = BaseController> + Send + Sync, G: StatsRpc<Controller = BaseController> + Send + Sync,
H: ConfigRpc<Controller = BaseController> + Send + Sync, H: ConfigRpc<Controller = BaseController> + Send + Sync,
> InstanceRpcService for ApiRpcServiceImpl<A, B, C, D, E, F, G, H> > InstanceRpcService for ApiRpcServiceImpl<A, B, C, D, E, F, G, H>
{ {
fn get_peer_manage_service(&self) -> &dyn PeerManageRpc<Controller = BaseController> { fn get_peer_manage_service(&self) -> &dyn PeerManageRpc<Controller = BaseController> {
&self.peer_mgr_rpc_service &self.peer_mgr_rpc_service
+3 -3
View File
@@ -17,8 +17,8 @@ use crate::{
}, },
peers::peer_manager::PeerManager, peers::peer_manager::PeerManager,
tunnel::{ tunnel::{
self, ring::RingTunnelListener, tcp::TcpTunnelListener, udp::UdpTunnelListener, IpScheme, self, IpScheme, Tunnel, TunnelListener, TunnelScheme, ring::RingTunnelListener,
Tunnel, TunnelListener, TunnelScheme, tcp::TcpTunnelListener, udp::UdpTunnelListener,
}, },
utils::BoxExt, utils::BoxExt,
}; };
@@ -284,7 +284,7 @@ mod tests {
use crate::{ use crate::{
common::global_ctx::tests::get_mock_global_ctx, common::global_ctx::tests::get_mock_global_ctx,
tunnel::{packet_def::ZCPacket, ring::RingTunnelConnector, TunnelConnector, TunnelError}, tunnel::{TunnelConnector, TunnelError, packet_def::ZCPacket, ring::RingTunnelConnector},
}; };
use super::*; use super::*;
+8 -6
View File
@@ -15,18 +15,18 @@ use crate::{
log, log,
}, },
instance::proxy_cidrs_monitor::ProxyCidrsMonitor, instance::proxy_cidrs_monitor::ProxyCidrsMonitor,
peers::{peer_manager::PeerManager, recv_packet_from_chan, PacketRecvChanReceiver}, peers::{PacketRecvChanReceiver, peer_manager::PeerManager, recv_packet_from_chan},
tunnel::{ tunnel::{
common::{reserve_buf, FramedWriter, TunnelWrapper, ZCPacketToBytes},
packet_def::{ZCPacket, ZCPacketType, TAIL_RESERVED_SIZE},
StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream,
common::{FramedWriter, TunnelWrapper, ZCPacketToBytes, reserve_buf},
packet_def::{TAIL_RESERVED_SIZE, ZCPacket, ZCPacketType},
}, },
}; };
use byteorder::WriteBytesExt as _; use byteorder::WriteBytesExt as _;
use bytes::{BufMut, BytesMut}; use bytes::{BufMut, BytesMut};
use cidr::{Ipv4Inet, Ipv6Inet}; use cidr::{Ipv4Inet, Ipv6Inet};
use futures::{lock::BiLock, ready, SinkExt, Stream, StreamExt}; use futures::{SinkExt, Stream, StreamExt, lock::BiLock, ready};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use pnet::packet::{ipv4::Ipv4Packet, ipv6::Ipv6Packet}; use pnet::packet::{ipv4::Ipv4Packet, ipv6::Ipv6Packet};
use tokio::{ use tokio::{
@@ -530,7 +530,9 @@ impl VirtualNic {
Ok(_) => tracing::info!("add_self_to_firewall_allowlist successful!"), Ok(_) => tracing::info!("add_self_to_firewall_allowlist successful!"),
Err(error) => { Err(error) => {
log::warn!(%error, "Failed to add Easytier to firewall allowlist, Subnet proxy and KCP proxy may not work properly."); log::warn!(%error, "Failed to add Easytier to firewall allowlist, Subnet proxy and KCP proxy may not work properly.");
log::warn!("You can add firewall rules manually, or use --use-smoltcp to run with user-space TCP/IP stack."); log::warn!(
"You can add firewall rules manually, or use --use-smoltcp to run with user-space TCP/IP stack."
);
} }
} }
@@ -768,7 +770,7 @@ impl VirtualNic {
Ok(()) Ok(())
} }
pub fn get_ifcfg(&self) -> impl IfConfiguerTrait { pub fn get_ifcfg(&self) -> impl IfConfiguerTrait + use<> {
IfConfiger {} IfConfiger {}
} }
} }
+56 -44
View File
@@ -85,11 +85,11 @@ impl NetworkInstanceManager {
let _t = instance_event_receiver let _t = instance_event_receiver
.map(|event| ScopedTask::from(handle_event(instance_id, event))); .map(|event| ScopedTask::from(handle_event(instance_id, event)));
instance_stop_notifier.notified().await; instance_stop_notifier.notified().await;
if let Some(instance) = instance_map.get(&instance_id) { if let Some(instance) = instance_map.get(&instance_id)
if let Some(error) = instance.get_latest_error_msg() { && let Some(error) = instance.get_latest_error_msg()
log::error!(%error, "instance {} stopped", instance_id); {
instance_error_messages.insert(instance_id, error); log::error!(%error, "instance {} stopped", instance_id);
} instance_error_messages.insert(instance_id, error);
} }
stop_check_notifier.notify_one(); stop_check_notifier.notify_one();
instance_stop_tasks.remove(&instance_id); instance_stop_tasks.remove(&instance_id);
@@ -543,45 +543,57 @@ mod tests {
let port = crate::utils::find_free_tcp_port(10012..65534).expect("no free tcp port found"); let port = crate::utils::find_free_tcp_port(10012..65534).expect("no free tcp port found");
assert!(manager assert!(
.run_network_instance( manager
TomlConfigLoader::new_from_str(cfg_str).unwrap(), .run_network_instance(
true, TomlConfigLoader::new_from_str(cfg_str).unwrap(),
ConfigFileControl::STATIC_CONFIG true,
) ConfigFileControl::STATIC_CONFIG
.is_err()); )
assert!(manager .is_err()
.run_network_instance( );
TomlConfigLoader::new_from_str(cfg_str).unwrap(), assert!(
true, manager
ConfigFileControl::STATIC_CONFIG .run_network_instance(
) TomlConfigLoader::new_from_str(cfg_str).unwrap(),
.is_err()); true,
assert!(manager ConfigFileControl::STATIC_CONFIG
.run_network_instance( )
TomlConfigLoader::new_from_str(cfg_str) .is_err()
.inspect(|c| { );
c.set_listeners(vec![format!("tcp://0.0.0.0:{}", port).parse().unwrap()]); assert!(
}) manager
.unwrap(), .run_network_instance(
false, TomlConfigLoader::new_from_str(cfg_str)
ConfigFileControl::STATIC_CONFIG .inspect(|c| {
) c.set_listeners(vec![
.is_ok()); format!("tcp://0.0.0.0:{}", port).parse().unwrap(),
assert!(manager ]);
.run_network_instance( })
TomlConfigLoader::new_from_str(cfg_str).unwrap(), .unwrap(),
true, false,
ConfigFileControl::STATIC_CONFIG ConfigFileControl::STATIC_CONFIG
) )
.is_err()); .is_ok()
assert!(manager );
.run_network_instance( assert!(
TomlConfigLoader::new_from_str(cfg_str).unwrap(), manager
false, .run_network_instance(
ConfigFileControl::STATIC_CONFIG TomlConfigLoader::new_from_str(cfg_str).unwrap(),
) true,
.is_ok()); ConfigFileControl::STATIC_CONFIG
)
.is_err()
);
assert!(
manager
.run_network_instance(
TomlConfigLoader::new_from_str(cfg_str).unwrap(),
false,
ConfigFileControl::STATIC_CONFIG
)
.is_ok()
);
std::thread::sleep(std::time::Duration::from_secs(1)); // wait instance actually started std::thread::sleep(std::time::Duration::from_secs(1)); // wait instance actually started
+35 -29
View File
@@ -1,12 +1,12 @@
use crate::common::config::{process_secure_mode_cfg, ConfigFileControl, PortForwardConfig}; use crate::common::config::{ConfigFileControl, PortForwardConfig, process_secure_mode_cfg};
use crate::proto::api::{self, manage}; use crate::proto::api::{self, manage};
use crate::proto::rpc_types::controller::BaseController; use crate::proto::rpc_types::controller::BaseController;
use crate::rpc_service::InstanceRpcService; use crate::rpc_service::InstanceRpcService;
use crate::{ use crate::{
common::{ common::{
config::{ config::{
gen_default_flags, ConfigLoader, NetworkIdentity, PeerConfig, TomlConfigLoader, ConfigLoader, NetworkIdentity, PeerConfig, TomlConfigLoader, VpnPortalConfig,
VpnPortalConfig, gen_default_flags,
}, },
constants::EASYTIER_VERSION, constants::EASYTIER_VERSION,
global_ctx::{EventBusSubscriber, GlobalCtxEvent}, global_ctx::{EventBusSubscriber, GlobalCtxEvent},
@@ -19,7 +19,7 @@ use chrono::{DateTime, Local};
use std::{ use std::{
collections::VecDeque, collections::VecDeque,
net::SocketAddr, net::SocketAddr,
sync::{atomic::AtomicBool, Arc, Mutex, RwLock}, sync::{Arc, Mutex, RwLock, atomic::AtomicBool},
}; };
use tokio::{ use tokio::{
sync::{broadcast, mpsc}, sync::{broadcast, mpsc},
@@ -272,10 +272,10 @@ impl Drop for EasyTierLauncher {
fn drop(&mut self) { fn drop(&mut self) {
self.stop_flag self.stop_flag
.store(true, std::sync::atomic::Ordering::Relaxed); .store(true, std::sync::atomic::Ordering::Relaxed);
if let Some(handle) = self.thread_handle.take() { if let Some(handle) = self.thread_handle.take()
if let Err(e) = handle.join() { && let Err(e) = handle.join()
println!("Error when joining thread: {:?}", e); {
} println!("Error when joining thread: {:?}", e);
} }
} }
} }
@@ -656,12 +656,12 @@ impl NetworkConfig {
cfg.set_exit_nodes(exit_nodes); cfg.set_exit_nodes(exit_nodes);
} }
if self.enable_socks5.unwrap_or_default() { if self.enable_socks5.unwrap_or_default()
if let Some(socks5_port) = self.socks5_port { && let Some(socks5_port) = self.socks5_port
cfg.set_socks5_portal(Some( {
format!("socks5://0.0.0.0:{}", socks5_port).parse().unwrap(), cfg.set_socks5_portal(Some(
)); format!("socks5://0.0.0.0:{}", socks5_port).parse().unwrap(),
} ));
} }
if !self.mapped_listeners.is_empty() { if !self.mapped_listeners.is_empty() {
@@ -909,11 +909,11 @@ impl NetworkConfig {
result.vpn_portal_listen_port = Some(vpn_config.wireguard_listen.port() as i32); result.vpn_portal_listen_port = Some(vpn_config.wireguard_listen.port() as i32);
} }
if let Some(routes) = config.get_routes() { if let Some(routes) = config.get_routes()
if !routes.is_empty() { && !routes.is_empty()
result.enable_manual_routes = Some(true); {
result.routes = routes.iter().map(|r| r.to_string()).collect(); result.enable_manual_routes = Some(true);
} result.routes = routes.iter().map(|r| r.to_string()).collect();
} }
let exit_nodes = config.get_exit_nodes(); let exit_nodes = config.get_exit_nodes();
@@ -986,10 +986,10 @@ impl NetworkConfig {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{ use crate::{
common::config::{process_secure_mode_cfg, ConfigLoader}, common::config::{ConfigLoader, process_secure_mode_cfg},
proto::common::SecureModeConfig, proto::common::SecureModeConfig,
}; };
use base64::prelude::{Engine as _, BASE64_STANDARD}; use base64::prelude::{BASE64_STANDARD, Engine as _};
use rand::Rng; use rand::Rng;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
@@ -1014,9 +1014,12 @@ mod tests {
let generated_config_str = generated_config.dump(); let generated_config_str = generated_config.dump();
assert_eq!( assert_eq!(
config_str, generated_config_str, config_str,
"Generated config does not match original config:\nOriginal:\n{}\n\nGenerated:\n{}\nNetwork Config: {}\n", generated_config_str,
config_str, generated_config_str, serde_json::to_string(&network_config).unwrap() "Generated config does not match original config:\nOriginal:\n{}\n\nGenerated:\n{}\nNetwork Config: {}\n",
config_str,
generated_config_str,
serde_json::to_string(&network_config).unwrap()
); );
Ok(()) Ok(())
} }
@@ -1033,13 +1036,13 @@ mod tests {
config.set_dhcp(rng.gen_bool(0.5)); config.set_dhcp(rng.gen_bool(0.5));
if rng.gen_bool(0.7) { if rng.gen_bool(0.7) {
let hostname = format!("host-{}", rng.gen::<u16>()); let hostname = format!("host-{}", rng.r#gen::<u16>());
config.set_hostname(Some(hostname)); config.set_hostname(Some(hostname));
} }
config.set_network_identity(crate::common::config::NetworkIdentity::new( config.set_network_identity(crate::common::config::NetworkIdentity::new(
format!("network-{}", rng.gen::<u16>()), format!("network-{}", rng.r#gen::<u16>()),
format!("secret-{}", rng.gen::<u64>()), format!("secret-{}", rng.r#gen::<u64>()),
)); ));
config.set_inst_name(config.get_network_identity().network_name.clone()); config.set_inst_name(config.get_network_identity().network_name.clone());
@@ -1251,9 +1254,12 @@ mod tests {
let generated_config_str = generated_config.dump(); let generated_config_str = generated_config.dump();
assert_eq!( assert_eq!(
config_str, generated_config_str, config_str,
generated_config_str,
"Generated config does not match original config:\nOriginal:\n{}\n\nGenerated:\n{}\nNetwork Config: {}\n", "Generated config does not match original config:\nOriginal:\n{}\n\nGenerated:\n{}\nNetwork Config: {}\n",
config_str, generated_config_str, serde_json::to_string(&network_config).unwrap() config_str,
generated_config_str,
serde_json::to_string(&network_config).unwrap()
); );
} }
+8 -8
View File
@@ -12,7 +12,7 @@ use tokio::task::JoinSet;
use tracing::Instrument; use tracing::Instrument;
use crate::{ use crate::{
common::{global_ctx::GlobalCtx, PeerId}, common::{PeerId, global_ctx::GlobalCtx},
peers::{ peers::{
peer_manager::PeerManager, peer_manager::PeerManager,
peer_map::PeerMap, peer_map::PeerMap,
@@ -30,7 +30,7 @@ use crate::{
}, },
}; };
use super::{server::PeerCenterServer, Digest, Error}; use super::{Digest, Error, server::PeerCenterServer};
#[async_trait::async_trait] #[async_trait::async_trait]
#[auto_impl::auto_impl(&, Arc, Box)] #[auto_impl::auto_impl(&, Arc, Box)]
@@ -97,12 +97,12 @@ impl PeerCenterBase {
&self, &self,
job_ctx: T, job_ctx: T,
job_fn: impl Fn( job_fn: impl Fn(
Box<dyn PeerCenterRpc<Controller = BaseController> + Send>, Box<dyn PeerCenterRpc<Controller = BaseController> + Send>,
Arc<PeridicJobCtx<T>>, Arc<PeridicJobCtx<T>>,
) -> Fut ) -> Fut
+ Send + Send
+ Sync + Sync
+ 'static, + 'static,
) { ) {
let my_peer_id = self.my_peer_id; let my_peer_id = self.my_peer_id;
let peer_mgr = self.peer_mgr.clone(); let peer_mgr = self.peer_mgr.clone();
+19 -19
View File
@@ -3,14 +3,14 @@ use std::sync::atomic::Ordering;
use std::time::Instant; use std::time::Instant;
use std::{ use std::{
net::IpAddr, net::IpAddr,
sync::{atomic::AtomicBool, Arc}, sync::{Arc, atomic::AtomicBool},
}; };
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use dashmap::DashMap; use dashmap::DashMap;
use pnet::packet::ipv6::Ipv6Packet; use pnet::packet::ipv6::Ipv6Packet;
use pnet::packet::{ use pnet::packet::{
ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, udp::UdpPacket, Packet as _, Packet as _, ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, udp::UdpPacket,
}; };
use crate::common::scoped_task::ScopedTask; use crate::common::scoped_task::ScopedTask;
@@ -238,23 +238,23 @@ impl AclFilter {
chain_type: ChainType, chain_type: ChainType,
processor: &AclProcessor, processor: &AclProcessor,
) { ) {
if result.should_log { if result.should_log
if let Some(ref log_context) = result.log_context { && let Some(ref log_context) = result.log_context
let log_message = log_context.to_message(); {
tracing::info!( let log_message = log_context.to_message();
src_ip = %packet_info.src_ip, tracing::info!(
dst_ip = %packet_info.dst_ip, src_ip = %packet_info.src_ip,
src_port = packet_info.src_port, dst_ip = %packet_info.dst_ip,
dst_port = packet_info.dst_port, src_port = packet_info.src_port,
src_group = packet_info.src_groups.join(","), dst_port = packet_info.dst_port,
dst_group = packet_info.dst_groups.join(","), src_group = packet_info.src_groups.join(","),
protocol = ?packet_info.protocol, dst_group = packet_info.dst_groups.join(","),
action = ?result.action, protocol = ?packet_info.protocol,
rule = result.matched_rule_str().as_deref().unwrap_or("unknown"), action = ?result.action,
chain_type = ?chain_type, rule = result.matched_rule_str().as_deref().unwrap_or("unknown"),
"ACL: {}", log_message chain_type = ?chain_type,
); "ACL: {}", log_message
} );
} }
// Update global statistics in the ACL processor // Update global statistics in the ACL processor
+15 -14
View File
@@ -5,8 +5,8 @@ use std::{
time::{Duration, SystemTime, UNIX_EPOCH}, time::{Duration, SystemTime, UNIX_EPOCH},
}; };
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use base64::Engine; use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use x25519_dalek::{PublicKey, StaticSecret}; use x25519_dalek::{PublicKey, StaticSecret};
@@ -62,10 +62,10 @@ impl CredentialManager {
.map(|x| x.trim().to_string()) .map(|x| x.trim().to_string())
.filter(|x| !x.is_empty()) .filter(|x| !x.is_empty())
{ {
if let Some(existing) = credentials.get(&id) { if let Some(existing) = credentials.get(&id)
if !existing.secret.is_empty() { && !existing.secret.is_empty()
return (id, existing.secret.clone()); {
} return (id, existing.secret.clone());
} }
id id
} else { } else {
@@ -191,10 +191,10 @@ impl CredentialManager {
return; return;
}; };
let creds = self.credentials.lock().unwrap(); let creds = self.credentials.lock().unwrap();
if let Ok(json) = serde_json::to_string_pretty(&*creds) { if let Ok(json) = serde_json::to_string_pretty(&*creds)
if let Err(e) = std::fs::write(path, json) { && let Err(e) = std::fs::write(path, json)
tracing::warn!(?e, "failed to save credentials to disk"); {
} tracing::warn!(?e, "failed to save credentials to disk");
} }
} }
@@ -386,11 +386,12 @@ mod tests {
); );
assert!(tc.credential.as_ref().unwrap().expiry_unix > 0); assert!(tc.credential.as_ref().unwrap().expiry_unix > 0);
assert!(tc.verify_credential_hmac("sec")); assert!(tc.verify_credential_hmac("sec"));
assert!(tc assert!(
.credential tc.credential
.as_ref() .as_ref()
.map(|x| !x.pubkey.is_empty()) .map(|x| !x.pubkey.is_empty())
.unwrap_or(false)); .unwrap_or(false)
);
let sk: [u8; 32] = BASE64_STANDARD.decode(&secret).unwrap().try_into().unwrap(); let sk: [u8; 32] = BASE64_STANDARD.decode(&secret).unwrap().try_into().unwrap();
let pk = PublicKey::from(&StaticSecret::from(sk)).as_bytes().to_vec(); let pk = PublicKey::from(&StaticSecret::from(sk)).as_bytes().to_vec();
+1 -1
View File
@@ -137,7 +137,7 @@ impl Encryptor for AesGcmCipher {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{ use crate::{
peers::encrypt::{aes_gcm::AesGcmCipher, Encryptor}, peers::encrypt::{Encryptor, aes_gcm::AesGcmCipher},
tunnel::packet_def::{StandardAeadTail, ZCPacket}, tunnel::packet_def::{StandardAeadTail, ZCPacket},
}; };
use zerocopy::FromBytes; use zerocopy::FromBytes;
+1 -1
View File
@@ -155,7 +155,7 @@ impl Encryptor for RingCipher {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{ use crate::{
peers::encrypt::{ring::RingCipher, Encryptor}, peers::encrypt::{Encryptor, ring::RingCipher},
tunnel::packet_def::{StandardAeadTail, ZCPacket}, tunnel::packet_def::{StandardAeadTail, ZCPacket},
}; };
use zerocopy::FromBytes; use zerocopy::FromBytes;
+1 -1
View File
@@ -61,7 +61,7 @@ impl Encryptor for XorCipher {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{ use crate::{
peers::encrypt::{xor::XorCipher, Encryptor}, peers::encrypt::{Encryptor, xor::XorCipher},
tunnel::packet_def::ZCPacket, tunnel::packet_def::ZCPacket,
}; };
+2 -2
View File
@@ -1,11 +1,11 @@
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use crate::{ use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask, PeerId}, common::{PeerId, error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask},
tunnel::packet_def::ZCPacket, tunnel::packet_def::ZCPacket,
}; };
use super::{peer_conn::PeerConn, peer_map::PeerMap, peer_rpc::PeerRpcManager, PacketRecvChan}; use super::{PacketRecvChan, peer_conn::PeerConn, peer_map::PeerMap, peer_rpc::PeerRpcManager};
pub struct ForeignNetworkClient { pub struct ForeignNetworkClient {
global_ctx: ArcGlobalCtx, global_ctx: ArcGlobalCtx,

Some files were not shown because too many files have changed in this diff Show More