diff --git a/libs/hbb_common/src/socket_client.rs b/libs/hbb_common/src/socket_client.rs index 615579d65..667a5161e 100644 --- a/libs/hbb_common/src/socket_client.rs +++ b/libs/hbb_common/src/socket_client.rs @@ -6,44 +6,18 @@ use crate::{ }; use anyhow::Context; use std::net::SocketAddr; -use std::net::ToSocketAddrs; +use tokio::net::ToSocketAddrs; use tokio_socks::{IntoTargetAddr, TargetAddr}; -fn to_socket_addr(host: T) -> ResultType { - let mut addr_ipv4 = None; - let mut addr_ipv6 = None; - for addr in host.to_socket_addrs()? { - if addr.is_ipv4() && addr_ipv4.is_none() { - addr_ipv4 = Some(addr); - } - if addr.is_ipv6() && addr_ipv6.is_none() { - addr_ipv6 = Some(addr); - } - } - if let Some(addr) = addr_ipv4 { - Ok(addr) - } else { - addr_ipv6.context("Failed to solve") - } -} - -pub fn get_target_addr(host: &str) -> ResultType> { - let addr = match Config::get_network_type() { - NetworkType::Direct => to_socket_addr(&host)?.into_target_addr()?, - NetworkType::ProxySocks => host.into_target_addr()?, - } - .to_owned(); - Ok(addr) -} - pub fn test_if_valid_server(host: &str) -> String { let mut host = host.to_owned(); if !host.contains(":") { host = format!("{}:{}", host, 0); } + use std::net::ToSocketAddrs; match Config::get_network_type() { - NetworkType::Direct => match to_socket_addr(&host) { + NetworkType::Direct => match host.to_socket_addrs() { Err(err) => err.to_string(), Ok(_) => "".to_owned(), }, @@ -54,56 +28,51 @@ pub fn test_if_valid_server(host: &str) -> String { } } -pub trait IntoTargetAddr2<'a> { - /// Converts the value of self to a `TargetAddr`. - fn into_target_addr2(&self) -> ResultType>; +pub trait IsResolvedSocketAddr { + fn resolve(&self) -> Option<&SocketAddr>; } -impl<'a> IntoTargetAddr2<'a> for SocketAddr { - fn into_target_addr2(&self) -> ResultType> { - Ok(TargetAddr::Ip(*self)) +impl IsResolvedSocketAddr for SocketAddr { + fn resolve(&self) -> Option<&SocketAddr> { + Some(&self) } } -impl<'a> IntoTargetAddr2<'a> for TargetAddr<'a> { - fn into_target_addr2(&self) -> ResultType> { - Ok(self.clone()) +impl IsResolvedSocketAddr for String { + fn resolve(&self) -> Option<&SocketAddr> { + None } } -impl<'a> IntoTargetAddr2<'a> for String { - fn into_target_addr2(&self) -> ResultType> { - Ok(to_socket_addr(self)?.into_target_addr()?) +impl IsResolvedSocketAddr for &str { + fn resolve(&self) -> Option<&SocketAddr> { + None } } -impl<'a> IntoTargetAddr2<'a> for &str { - fn into_target_addr2(&self) -> ResultType> { - Ok(to_socket_addr(self)?.into_target_addr()?) - } -} - -pub async fn connect_tcp<'t, T: IntoTargetAddr2<'t> + std::fmt::Debug>( +#[inline] +pub async fn connect_tcp< + 't, + T: IntoTargetAddr<'t> + ToSocketAddrs + IsResolvedSocketAddr + std::fmt::Display, +>( target: T, ms_timeout: u64, ) -> ResultType { - let target_addr = target.into_target_addr2()?; - let local = Config::get_any_listen_addr(is_ipv4(&target_addr)); - connect_tcp_local(target_addr, local, ms_timeout) - .await - .context(format!("Invalid target addr: {:?}", target)) + connect_tcp_local(target, None, ms_timeout).await } -pub async fn connect_tcp_local<'t, T: IntoTargetAddr<'t> + std::fmt::Debug>( +pub async fn connect_tcp_local< + 't, + T: IntoTargetAddr<'t> + ToSocketAddrs + IsResolvedSocketAddr + std::fmt::Display, +>( target: T, - local: SocketAddr, + local: Option, ms_timeout: u64, ) -> ResultType { - let target_addr = target.into_target_addr()?; if let Some(conf) = Config::get_socks() { return FramedStream::connect( conf.proxy.as_str(), - target_addr, + target, local, conf.username.as_str(), conf.password.as_str(), @@ -111,13 +80,15 @@ pub async fn connect_tcp_local<'t, T: IntoTargetAddr<'t> + std::fmt::Debug>( ) .await; } - let mut addr = ToSocketAddrs::to_socket_addrs(&target_addr)? - .next() - .context(format!("Invalid target addr: {:?}", target_addr))?; - if local.is_ipv6() && addr.is_ipv4() { - addr = query_nip_io(&addr)?; + if let Some(target) = target.resolve() { + if let Some(local) = local { + if local.is_ipv6() && target.is_ipv4() { + let target = query_nip_io(&target).await?; + return Ok(FramedStream::new(target, Some(local), ms_timeout).await?); + } + } } - Ok(FramedStream::new(addr, local, ms_timeout).await?) + Ok(FramedStream::new(target, local, ms_timeout).await?) } #[inline] @@ -129,8 +100,12 @@ pub fn is_ipv4(target: &TargetAddr<'_>) -> bool { } #[inline] -pub fn query_nip_io(addr: &SocketAddr) -> ResultType { - to_socket_addr(format!("{}.nip.io:{}", addr.ip(), addr.port())) +pub async fn query_nip_io(addr: &SocketAddr) -> ResultType { + tokio::net::lookup_host(format!("{}.nip.io:{}", addr.ip(), addr.port())) + .await? + .filter(|x| x.is_ipv6()) + .next() + .context("Failed to get ipv6 from nip.io") } #[inline] @@ -143,17 +118,29 @@ pub fn ipv4_to_ipv6(addr: String, ipv4: bool) -> String { addr } -pub async fn new_udp_for(target: &TargetAddr<'_>, ms_timeout: u64) -> ResultType { - new_udp(Config::get_any_listen_addr(is_ipv4(target)), ms_timeout).await +async fn test_is_ipv4(target: &str) -> bool { + if let Ok(Ok(s)) = super::timeout(1000, tokio::net::TcpStream::connect(target)).await { + return s.local_addr().map(|x| x.is_ipv4()).unwrap_or(true); + } + true +} + +#[inline] +pub async fn new_udp_for(target: &str, ms_timeout: u64) -> ResultType { + new_udp( + Config::get_any_listen_addr(test_is_ipv4(target).await), + ms_timeout, + ) + .await } async fn new_udp(local: T, ms_timeout: u64) -> ResultType { match Config::get_socks() { - None => Ok(FramedSocket::new(to_socket_addr(&local)?).await?), + None => Ok(FramedSocket::new(local).await?), Some(conf) => { let socket = FramedSocket::new_proxy( conf.proxy.as_str(), - to_socket_addr(local)?, + local, conf.username.as_str(), conf.password.as_str(), ms_timeout, @@ -164,10 +151,10 @@ async fn new_udp(local: T, ms_timeout: u64) -> ResultType) -> ResultType> { +pub async fn rebind_udp_for(target: &str) -> ResultType> { match Config::get_network_type() { NetworkType::Direct => Ok(Some( - FramedSocket::new(Config::get_any_listen_addr(is_ipv4(target))).await?, + FramedSocket::new(Config::get_any_listen_addr(test_is_ipv4(target).await)).await?, )), _ => Ok(None), } @@ -175,19 +162,17 @@ pub async fn rebind_udp_for(target: &TargetAddr<'_>) -> ResultType Result( - remote_addr: T1, - local_addr: T2, + pub async fn new( + remote_addr: T, + local_addr: Option, ms_timeout: u64, ) -> ResultType { - for local_addr in lookup_host(&local_addr).await? { - for remote_addr in lookup_host(&remote_addr).await? { - let stream = super::timeout( - ms_timeout, - new_socket(local_addr, true)?.connect(remote_addr), - ) - .await??; - stream.set_nodelay(true).ok(); - let addr = stream.local_addr()?; - return Ok(Self( - Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), - addr, - None, - 0, - )); + for remote_addr in lookup_host(&remote_addr).await? { + let local = if let Some(addr) = local_addr { + addr + } else { + crate::config::Config::get_any_listen_addr(remote_addr.is_ipv4()) + }; + if let Ok(socket) = new_socket(local, true) { + if let Ok(Ok(stream)) = + super::timeout(ms_timeout, socket.connect(remote_addr)).await + { + stream.set_nodelay(true).ok(); + let addr = stream.local_addr()?; + return Ok(Self( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + )); + } } } - bail!("could not resolve to any address"); + bail!(format!("Failed to connect to {}", remote_addr)); } - pub async fn connect<'a, 't, P, T1, T2>( + pub async fn connect<'a, 't, P, T>( proxy: P, - target: T1, - local: T2, + target: T, + local_addr: Option, username: &'a str, password: &'a str, ms_timeout: u64, ) -> ResultType where P: ToProxyAddrs, - T1: IntoTargetAddr<'t>, - T2: ToSocketAddrs, + T: IntoTargetAddr<'t>, { - if let Some(local) = lookup_host(&local).await?.next() { - if let Some(proxy) = proxy.to_proxy_addrs().next().await { - let stream = - super::timeout(ms_timeout, new_socket(local, true)?.connect(proxy?)).await??; - stream.set_nodelay(true).ok(); - let stream = if username.trim().is_empty() { - super::timeout( - ms_timeout, - Socks5Stream::connect_with_socket(stream, target), - ) - .await?? - } else { - super::timeout( - ms_timeout, - Socks5Stream::connect_with_password_and_socket( - stream, target, username, password, - ), - ) - .await?? - }; - let addr = stream.local_addr()?; - return Ok(Self( - Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), - addr, - None, - 0, - )); + if let Some(Ok(proxy)) = proxy.to_proxy_addrs().next().await { + let local = if let Some(addr) = local_addr { + addr + } else { + crate::config::Config::get_any_listen_addr(proxy.is_ipv4()) }; - }; + let stream = + super::timeout(ms_timeout, new_socket(local, true)?.connect(proxy)).await??; + stream.set_nodelay(true).ok(); + let stream = if username.trim().is_empty() { + super::timeout( + ms_timeout, + Socks5Stream::connect_with_socket(stream, target), + ) + .await?? + } else { + super::timeout( + ms_timeout, + Socks5Stream::connect_with_password_and_socket( + stream, target, username, password, + ), + ) + .await?? + }; + let addr = stream.local_addr()?; + return Ok(Self( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + )); + } bail!("could not resolve to any address"); } diff --git a/libs/hbb_common/src/udp.rs b/libs/hbb_common/src/udp.rs index 1f5bf2637..38121a4e1 100644 --- a/libs/hbb_common/src/udp.rs +++ b/libs/hbb_common/src/udp.rs @@ -164,4 +164,13 @@ impl FramedSocket { None } } + + pub fn is_ipv4(&self) -> bool { + if let FramedSocket::Direct(x) = self { + if let Ok(v) = x.get_ref().local_addr() { + return v.is_ipv4(); + } + } + true + } } diff --git a/src/cli.rs b/src/cli.rs index b4f552d5f..2b2cae320 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -102,6 +102,8 @@ pub async fn connect_test( handler, ).await { log::error!("Failed to connect {}: {}", &id, err); + } else { + // rpassword::prompt_password("Input anything to exit").ok(); } } diff --git a/src/client.rs b/src/client.rs index 1c6f63f36..fe9d9dac0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -376,7 +376,7 @@ impl Client { log::info!("peer address: {}, timeout: {}", peer, connect_timeout); let start = std::time::Instant::now(); // NOTICE: Socks5 is be used event in intranet. Which may be not a good way. - let mut conn = socket_client::connect_tcp_local(peer, local_addr, connect_timeout).await; + let mut conn = socket_client::connect_tcp_local(peer, Some(local_addr), connect_timeout).await; let mut direct = !conn.is_err(); if interface.is_force_relay() || conn.is_err() { if !relay_server.is_empty() { diff --git a/src/common.rs b/src/common.rs index f83fbc69d..1ae9b7dbe 100644 --- a/src/common.rs +++ b/src/common.rs @@ -310,15 +310,9 @@ async fn test_nat_type_() -> ResultType { }); let mut port1 = 0; let mut port2 = 0; - let server1 = socket_client::get_target_addr(&server1)?; - let server2 = socket_client::get_target_addr(&server2)?; for i in 0..2 { let mut socket = socket_client::connect_tcp( - if i == 0 { - server1.clone() - } else { - server2.clone() - }, + if i == 0 { &*server1 } else { &*server2 }, RENDEZVOUS_TIMEOUT, ) .await?; @@ -525,8 +519,7 @@ pub fn check_software_update() { async fn check_software_update_() -> hbb_common::ResultType<()> { sleep(3.).await; - let rendezvous_server = - socket_client::get_target_addr(&format!("rs-sg.rustdesk.com:{}", config::RENDEZVOUS_PORT))?; + let rendezvous_server = format!("rs-sg.rustdesk.com:{}", config::RENDEZVOUS_PORT); let mut socket = socket_client::new_udp_for(&rendezvous_server, RENDEZVOUS_TIMEOUT).await?; let mut msg_out = RendezvousMessage::new(); diff --git a/src/rendezvous_mediator.rs b/src/rendezvous_mediator.rs index f962810a2..ca08172d3 100644 --- a/src/rendezvous_mediator.rs +++ b/src/rendezvous_mediator.rs @@ -38,10 +38,11 @@ static SHOULD_EXIT: AtomicBool = AtomicBool::new(false); #[derive(Clone)] pub struct RendezvousMediator { - addr: TargetAddr<'static>, + addr: String, host: String, host_prefix: String, last_id_pk_registry: String, + is_ipv4: bool, } impl RendezvousMediator { @@ -111,13 +112,15 @@ impl RendezvousMediator { }) .unwrap_or(host.to_owned()); let mut rz = Self { - addr: socket_client::get_target_addr(&crate::check_port(&host, RENDEZVOUS_PORT))?, + addr: crate::check_port(&host, RENDEZVOUS_PORT), + is_ipv4: false, host: host.clone(), host_prefix, last_id_pk_registry: "".to_owned(), }; let mut socket = socket_client::new_udp_for(&rz.addr, RENDEZVOUS_TIMEOUT).await?; + rz.is_ipv4 = socket.is_ipv4(); const TIMER_OUT: Duration = Duration::from_secs(1); let mut timer = interval(TIMER_OUT); @@ -248,11 +251,11 @@ impl RendezvousMediator { Config::update_latency(&host, -1); old_latency = 0; if last_dns_check.elapsed().as_millis() as i64 > DNS_INTERVAL { - rz.addr = socket_client::get_target_addr(&crate::check_port(&host, RENDEZVOUS_PORT))?; // in some case of network reconnect (dial IP network), // old UDP socket not work any more after network recover if let Some(s) = socket_client::rebind_udp_for(&rz.addr).await? { socket = s; + rz.is_ipv4 = socket.is_ipv4(); } last_dns_check = Instant::now(); } @@ -314,14 +317,14 @@ impl RendezvousMediator { } msg_out.set_relay_response(rr); socket.send(&msg_out).await?; - let v4 = socket_client::is_ipv4(&self.addr); - crate::create_relay_connection(server, relay_server, uuid, peer_addr, secure, v4).await; + crate::create_relay_connection(server, relay_server, uuid, peer_addr, secure, self.is_ipv4) + .await; Ok(()) } async fn handle_intranet(&self, fla: FetchLocalAddr, server: ServerPtr) -> ResultType<()> { let relay_server = self.get_relay_server(fla.relay_server); - if !socket_client::is_ipv4(&self.addr) { + if !self.is_ipv4 { // nat64, go relay directly, because current hbbs will crash if demangle ipv6 address let uuid = Uuid::new_v4().to_string(); return self @@ -382,7 +385,7 @@ impl RendezvousMediator { let local_addr = socket.local_addr(); // key important here for punch hole to tell my gateway incoming peer is safe. // it can not be async here, because local_addr can not be reused, we must close the connection before use it again. - allow_err!(socket_client::connect_tcp_local(peer_addr, local_addr, 30).await); + allow_err!(socket_client::connect_tcp_local(peer_addr, Some(local_addr), 30).await); socket }; let mut msg_out = Message::new(); @@ -655,8 +658,7 @@ async fn create_online_stream() -> ResultType { bail!("Invalid server address: {}", rendezvous_server); } let online_server = format!("{}:{}", tmp[0], port - 1); - let server_addr = socket_client::get_target_addr(&online_server)?; - socket_client::connect_tcp(server_addr, RENDEZVOUS_TIMEOUT).await + socket_client::connect_tcp(online_server, RENDEZVOUS_TIMEOUT).await } async fn query_online_states_(