找回密码
 立即注册
首页 业界区 业界 记录一下 简单udp和sni 代理 done

记录一下 简单udp和sni 代理 done

痨砖 2025-6-4 20:32:17
由于之前借鉴 Kestrel 了非常多抽象和优化实现,对于后续的扩展非常便利,
实现 简单udp和sni 代理 两个功能比预期快了超多(当然也有偷懒因素)
(PS 大家有空的话,能否在 GitHub https://github.com/fs7744/NZOrz 点个 star 呢?毕竟借鉴代码也不易呀 哈哈哈哈哈)
简单udp代理

这里的udp 代理功能比较简单:代理程序收到任何 udp 包都会通过路由匹配找 upstream ,然后转发给upstream
udp proxy 使用配置

基本格式和之前 tcp proxy 一致,
只是Protocols得选择UDP, 然后多了UdpResponses 允许 upstream 返回多少个 udp 包给请求者, 默认为0,即不返回任何包
  1. {
  2.   "Logging": {
  3.     "LogLevel": {
  4.       "Default": "Information"
  5.     }
  6.   },
  7.   "ReverseProxy": {
  8.     "Routes": {
  9.       "udpTest": {
  10.         "Protocols": [ "UDP" ],
  11.         "Match": {
  12.           "Hosts": [ "*:5000" ]
  13.         },
  14.         "ClusterId": "udpTest",
  15.         "RetryCount": 1,
  16.         "UdpResponses": 1,
  17.         "Timeout": "00:00:11"
  18.       }
  19.     },
  20.     "Clusters": {
  21.       "udpTest": {
  22.         "LoadBalancingPolicy": "RoundRobin",
  23.         "HealthCheck": {
  24.           "Passive": {
  25.             "Enable": true
  26.           }
  27.         },
  28.         "Destinations": [
  29.           {
  30.             "Address": "127.0.0.1:11000"
  31.           }
  32.         ]
  33.       }
  34.     }
  35.   }
  36. }
复制代码
实现

这里列举一下,表明有多简单
ps: 由于要实现的是非常简单udp代理,所以不基于IMultiplexedConnectionListener ,而基于 IConnectionListener 方式 (对,就是俺偷懒了)
1. 实现 UdpConnectionContext

偷懒就直接把udp 包数据放 context 上了,不放 Parameters 上,减少字典实例和内存使用
  1. public sealed class UdpConnectionContext : TransportConnection
  2. {
  3.     private readonly IMemoryOwner<byte> memory;
  4.     public Socket Socket { get; }
  5.     public int ReceivedBytesCount { get; }
  6.     public Memory<byte> ReceivedBytes => memory.Memory.Slice(0, ReceivedBytesCount);
  7.     public UdpConnectionContext(Socket socket, UdpReceiveFromResult result)
  8.     {
  9.         Socket = socket;
  10.         ReceivedBytesCount = result.ReceivedBytesCount;
  11.         this.memory = result.Buffer;
  12.         LocalEndPoint = socket.LocalEndPoint;
  13.         RemoteEndPoint = result.RemoteEndPoint;
  14.     }
  15.     public UdpConnectionContext(Socket socket, EndPoint remoteEndPoint, int receivedBytes, IMemoryOwner<byte> memory)
  16.     {
  17.         Socket = socket;
  18.         ReceivedBytesCount = receivedBytes;
  19.         this.memory = memory;
  20.         LocalEndPoint = socket.LocalEndPoint;
  21.         RemoteEndPoint = remoteEndPoint;
  22.     }
  23.     public override ValueTask DisposeAsync()
  24.     {
  25.         memory.Dispose();
  26.         return default;
  27.     }
  28. }
复制代码
2. 实现 IConnectionListener
  1. internal sealed class UdpConnectionListener : IConnectionListener
  2. {
  3.     private EndPoint? udpEndPoint;
  4.     private readonly GatewayProtocols protocols;
  5.     private OrzLogger _logger;
  6.     private readonly IUdpConnectionFactory connectionFactory;
  7.     private readonly Func<EndPoint, GatewayProtocols, Socket> createBoundListenSocket;
  8.     private Socket? _listenSocket;
  9.     public UdpConnectionListener(EndPoint? udpEndPoint, GatewayProtocols protocols, IRouteContractor contractor, OrzLogger logger, IUdpConnectionFactory connectionFactory)
  10.     {
  11.         this.udpEndPoint = udpEndPoint;
  12.         this.protocols = protocols;
  13.         _logger = logger;
  14.         this.connectionFactory = connectionFactory;
  15.         createBoundListenSocket = contractor.GetSocketTransportOptions().CreateBoundListenSocket;
  16.     }
  17.     public EndPoint EndPoint => udpEndPoint;
  18.     internal void Bind()
  19.     {
  20.         if (_listenSocket != null)
  21.         {
  22.             throw new InvalidOperationException("Transport is already bound.");
  23.         }
  24.         Socket listenSocket;
  25.         try
  26.         {
  27.             listenSocket = createBoundListenSocket(EndPoint, protocols);
  28.         }
  29.         catch (SocketException e) when (e.SocketErrorCode == SocketError.AddressAlreadyInUse)
  30.         {
  31.             throw new AddressInUseException(e.Message, e);
  32.         }
  33.         Debug.Assert(listenSocket.LocalEndPoint != null);
  34.         _listenSocket = listenSocket;
  35.     }
  36.     public async ValueTask<ConnectionContext?> AcceptAsync(CancellationToken cancellationToken = default)
  37.     {
  38.         while (true)
  39.         {
  40.             try
  41.             {
  42.                 Debug.Assert(_listenSocket != null, "Bind must be called first.");
  43.                 var r = await connectionFactory.ReceiveAsync(_listenSocket, cancellationToken);
  44.                 return new UdpConnectionContext(_listenSocket, r);
  45.             }
  46.             catch (ObjectDisposedException)
  47.             {
  48.                 // A call was made to UnbindAsync/DisposeAsync just return null which signals we're done
  49.                 return null;
  50.             }
  51.             catch (SocketException e) when (e.SocketErrorCode == SocketError.OperationAborted)
  52.             {
  53.                 // A call was made to UnbindAsync/DisposeAsync just return null which signals we're done
  54.                 return null;
  55.             }
  56.             catch (SocketException)
  57.             {
  58.                 // The connection got reset while it was in the backlog, so we try again.
  59.                 _logger.ConnectionReset("(null)");
  60.             }
  61.         }
  62.     }
  63.     public ValueTask DisposeAsync()
  64.     {
  65.         _listenSocket?.Dispose();
  66.         return default;
  67.     }
  68.     public ValueTask UnbindAsync(CancellationToken cancellationToken = default)
  69.     {
  70.         _listenSocket?.Dispose();
  71.         return default;
  72.     }
  73. }
复制代码
3. 实现 IConnectionListenerFactory
  1. public sealed class UdpTransportFactory : IConnectionListenerFactory, IConnectionListenerFactorySelector
  2. {
  3.     private readonly IRouteContractor contractor;
  4.     private readonly OrzLogger logger;
  5.     private readonly IUdpConnectionFactory connectionFactory;
  6.     public UdpTransportFactory(
  7.         IRouteContractor contractor,
  8.         OrzLogger logger,
  9.         IUdpConnectionFactory connectionFactory)
  10.     {
  11.         ArgumentNullException.ThrowIfNull(contractor);
  12.         ArgumentNullException.ThrowIfNull(logger);
  13.         this.contractor = contractor;
  14.         this.logger = logger;
  15.         this.connectionFactory = connectionFactory;
  16.     }
  17.     public ValueTask<IConnectionListener> BindAsync(EndPoint endpoint, GatewayProtocols protocols, CancellationToken cancellationToken = default)
  18.     {
  19.         var transport = new UdpConnectionListener(endpoint, GatewayProtocols.UDP, contractor, logger, connectionFactory);
  20.         transport.Bind();
  21.         return new ValueTask<IConnectionListener>(transport);
  22.     }
  23.     public bool CanBind(EndPoint endpoint, GatewayProtocols protocols)
  24.     {
  25.         if (!protocols.HasFlag(GatewayProtocols.UDP)) return false;
  26.         return endpoint switch
  27.         {
  28.             IPEndPoint _ => true,
  29.             _ => false
  30.         };
  31.     }
  32. }
复制代码
4. 在 L4ProxyMiddleware 实现udp proxy 具体逻辑

路由和之前tcp的公用,这里就不列举了
  1. public class L4ProxyMiddleware : IOrderMiddleware
  2. {   
  3.     public async Task Invoke(ConnectionContext context, ConnectionDelegate next)
  4.     {
  5.         try
  6.         {
  7.             if (context.Protocols == GatewayProtocols.SNI)
  8.             {
  9.                 await SNIProxyAsync(context);
  10.             }
  11.             else
  12.             {
  13.                 var route = await router.MatchAsync(context);
  14.                 if (route is null)
  15.                 {
  16.                     logger.NotFoundRouteL4(context.LocalEndPoint);
  17.                 }
  18.                 else
  19.                 {
  20.                     context.Route = route;
  21.                     logger.ProxyBegin(route.RouteId);
  22.                     if (context.Protocols == GatewayProtocols.TCP)
  23.                     {
  24.                         await TcpProxyAsync(context, route);
  25.                     }
  26.                     else
  27.                     {
  28.                         await UdpProxyAsync((UdpConnectionContext)context, route);
  29.                     }
  30.                     logger.ProxyEnd(route.RouteId);
  31.                 }
  32.             }
  33.         }
  34.         catch (Exception ex)
  35.         {
  36.             logger.UnexpectedException(ex.Message, ex);
  37.         }
  38.         finally
  39.         {
  40.             await next(context);
  41.         }
  42.     }
  43.     private async Task UdpProxyAsync(UdpConnectionContext context, RouteConfig route)
  44.     {
  45.         try
  46.         {
  47.             var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
  48.             var cts = route.CreateTimeoutTokenSource(cancellationTokenSourcePool);
  49.             var token = cts.Token;
  50.             if (await DoUdpSendToAsync(socket, context, route, route.RetryCount, await reqUdp(context, context.ReceivedBytes, token), token))
  51.             {
  52.                 var c = route.UdpResponses;
  53.                 while (c > 0)
  54.                 {
  55.                     var r = await udp.ReceiveAsync(socket, token);
  56.                     c--;
  57.                     await udp.SendToAsync(context.Socket, context.RemoteEndPoint, await respUdp(context, r.GetReceivedBytes(), token), token);
  58.                 }
  59.             }
  60.             else
  61.             {
  62.                 logger.NotFoundAvailableUpstream(route.ClusterId);
  63.             }
  64.         }
  65.         catch (OperationCanceledException)
  66.         {
  67.             logger.ConnectUpstreamTimeout(route.RouteId);
  68.         }
  69.         catch (Exception ex)
  70.         {
  71.             logger.UnexpectedException(nameof(UdpProxyAsync), ex);
  72.         }
  73.         finally
  74.         {
  75.             context.SelectedDestination?.ConcurrencyCounter.Decrement();
  76.         }
  77.     }
复制代码
所以是不是真的简单, 理论上基于 Kestrel 也是一个样子哦
优化

当然参考于 Kestrel 的 tcp socket 处理,也是有些简单优化的, 比如

  • 不使用 UdpClient (ps 不是因为实现烂哈,而是其比较公用,没有机会让我们改变里面的内容)
  • 基于 SocketAsyncEventArgs, IValueTaskSource 和 SocketAsyncEventArgs, IValueTaskSource 实现 将异步读写交予 PipeScheduler 的逻辑
  • 基于 ConcurrentQueue 实现简单的 udp发送对象池,加强对象复用,稍稍稍微减少内存占用
  • 基于 ConcurrentQueue 实现简单的 CancellationTokenSource对象池,加强对象复用,稍稍稍微减少内存占用
sni代理

除了 tcp 和 udp 的基本代理, 也尝试实现了一个 对tcp的 sni 代理,(比如 http1 和 http2 的 https)
不过目前只实现了代理不做ssl加密解密,upstream自己处理的pass 模式,如果代理要实现ssl加密解密,理论上基于现成的 sslstream
sni proxy 使用配置

只需配置Listen 中 公用的 sni 监听端口
然后不同 sni 配置自己的路由和upstream就好
同时每个route 可以通过SupportSslProtocols限制 tls 版本
举个栗子
  1. {
  2.   "Logging": {
  3.     "LogLevel": {
  4.       "Default": "Information"
  5.     }
  6.   },
  7.   "ReverseProxy": {
  8.     "Listen": {
  9.       "snitest": {
  10.         "Protocols": "SNI",
  11.         "Address": [ "*:444" ]
  12.       }
  13.     },
  14.     "Routes": {
  15.       "snitestroute": {
  16.         "Protocols": "SNI",
  17.         "SupportSslProtocols": [ "Tls13", "Tls12" ],
  18.         "Match": {
  19.           "Hosts": [ "*google.com" ]
  20.         },
  21.         "ClusterId": "apidemo"
  22.       },
  23.       "snitestroute2": {
  24.         "Protocols": "Tcp",
  25.         "Match": {
  26.           "Hosts": [ "*:448" ]
  27.         },
  28.         "ClusterId": "apidemo"
  29.       }
  30.     },
  31.     "Clusters": {
  32.       "apidemo": {
  33.         "LoadBalancingPolicy": "RoundRobin",
  34.         "HealthCheck": {
  35.           "Active": {
  36.             "Enable": true,
  37.             "Policy": "Connect"
  38.           }
  39.         },
  40.         "Destinations": [
  41.           {
  42.             "Address": "https://www.google.com"
  43.           }
  44.         ]
  45.       }
  46.     }
  47.   }
  48. }
复制代码
实现

核心实现其实只有 路由 处理 ,proxy 代理和 tcp 代理一模一样(在请求 和 upstream 间搬运 tcp数据而已)
路由处理

通过 ClientHello 找到要访问的 域名, 然后通过域名匹配路由找到 upstream, 最后搬运 tcp数据
ClientHello 解析就直接搬运自TlsFrameHelper
  1.     /// 路由匹配
  2.     public async ValueTask<(RouteConfig, ReadResult)> MatchSNIAsync(ConnectionContext context, CancellationToken token)
  3.     {
  4.         if (sniRoute is null) return (null, default);
  5.         var (hello, rr) = await TryGetClientHelloAsync(context, token);
  6.         if (hello.HasValue)
  7.         {
  8.             var h = hello.Value;
  9.             var r = await sniRoute.MatchAsync(h.TargetName.Reverse(), h, MatchSNI);
  10.             if (r is null)
  11.             {
  12.                 logger.NotFoundRouteSni(h.TargetName);
  13.             }
  14.             return (r, rr);
  15.         }
  16.         else
  17.         {
  18.             logger.NotFoundRouteSni("client hello failed");
  19.             return (null, rr);
  20.         }
  21.     }
  22.     /// 匹配 tls 版本
  23.     private bool MatchSNI(RouteConfig config, TlsFrameInfo info)
  24.     {
  25.         if (!config.SupportSslProtocols.HasValue) return true;
  26.         var v = config.SupportSslProtocols.Value;
  27.         if (v == SslProtocols.None) return true;
  28.         var t = info.SupportedVersions;
  29.         if (v.HasFlag(SslProtocols.Tls13) && t.HasFlag(SslProtocols.Tls13)) return true;
  30.         else if (v.HasFlag(SslProtocols.Tls12) && t.HasFlag(SslProtocols.Tls12)) return true;
  31.         else if (v.HasFlag(SslProtocols.Tls11) && t.HasFlag(SslProtocols.Tls11)) return true;
  32.         else if (v.HasFlag(SslProtocols.Tls) && t.HasFlag(SslProtocols.Tls)) return true;
  33.         else if (v.HasFlag(SslProtocols.Ssl3) && t.HasFlag(SslProtocols.Ssl3)) return true;
  34.         else if (v.HasFlag(SslProtocols.Ssl2) && t.HasFlag(SslProtocols.Ssl2)) return true;
  35.         else if (v.HasFlag(SslProtocols.Default) && t.HasFlag(SslProtocols.Default)) return true;
  36.         else return false;
  37.     }
  38.     /// 解析 ClientHello
  39.     private static async ValueTask<(TlsFrameInfo?, ReadResult)> TryGetClientHelloAsync(ConnectionContext context, CancellationToken token)
  40.     {
  41.         var input = context.Transport.Input;
  42.         TlsFrameInfo info = default;
  43.         while (true)
  44.         {
  45.             var f = await input.ReadAsync(token).ConfigureAwait(false);
  46.             if (f.IsCompleted)
  47.             {
  48.                 return (null, f);
  49.             }
  50.             var buffer = f.Buffer;
  51.             if (buffer.Length == 0)
  52.             {
  53.                 continue;
  54.             }
  55.             var data = buffer.IsSingleSegment ? buffer.First.Span : buffer.ToArray();
  56.             if (TlsFrameHelper.TryGetFrameInfo(data, ref info))
  57.             {
  58.                 return (info, f);
  59.             }
  60.             else
  61.             {
  62.                 input.AdvanceTo(buffer.Start, buffer.End);
  63.                 continue;
  64.             }
  65.         }
  66.     }
复制代码
搬运 tcp数据
  1. private async Task SNIProxyAsync(ConnectionContext context)
  2. {
  3.     var c = cancellationTokenSourcePool.Rent();
  4.     c.CancelAfter(options.ConnectionTimeout);
  5.     var (route, r) = await router.MatchSNIAsync(context, c.Token);
  6.     if (route is not null)
  7.     {
  8.         context.Route = route;
  9.         logger.ProxyBegin(route.RouteId);
  10.         ConnectionContext upstream = null;
  11.         try
  12.         {
  13.             upstream = await DoConnectionAsync(context, route, route.RetryCount);
  14.             if (upstream is null)
  15.             {
  16.                 logger.NotFoundAvailableUpstream(route.ClusterId);
  17.             }
  18.             else
  19.             {
  20.                 context.SelectedDestination?.ConcurrencyCounter.Increment();
  21.                 var cts = route.CreateTimeoutTokenSource(cancellationTokenSourcePool);
  22.                 var t = cts.Token;
  23.                 await r.CopyToAsync(upstream.Transport.Output, t); // 和tcp 代理搬运数据唯一不同, 要先发送 ClientHello 数据,因为已经被我们读取了
  24.                 context.Transport.Input.AdvanceTo(r.Buffer.End);
  25.                 var task = hasMiddlewareTcp ?
  26.                         await Task.WhenAny(
  27.                         context.Transport.Input.CopyToAsync(new MiddlewarePipeWriter(upstream.Transport.Output, context, reqTcp), t)
  28.                         , upstream.Transport.Input.CopyToAsync(new MiddlewarePipeWriter(context.Transport.Output, context, respTcp), t))
  29.                         : await Task.WhenAny(
  30.                         context.Transport.Input.CopyToAsync(upstream.Transport.Output, t)
  31.                         , upstream.Transport.Input.CopyToAsync(context.Transport.Output, t));
  32.                 if (task.IsCanceled)
  33.                 {
  34.                     logger.ProxyTimeout(route.RouteId, route.Timeout);
  35.                 }
  36.             }
  37.         }
  38.         catch (OperationCanceledException)
  39.         {
  40.             logger.ConnectUpstreamTimeout(route.RouteId);
  41.         }
  42.         catch (Exception ex)
  43.         {
  44.             logger.UnexpectedException(nameof(TcpProxyAsync), ex);
  45.         }
  46.         finally
  47.         {
  48.             context.SelectedDestination?.ConcurrencyCounter.Decrement();
  49.             upstream?.Abort();
  50.         }
  51.         logger.ProxyEnd(route.RouteId);
  52.     }
  53. }
复制代码
组件各部分都是可替换或者可增加的

因为整体都是基于ioc的,所以组件各部分都是可替换或者可增加的, 客制化扩展还是很高的哦
目前暴露的列表可在 代码这里面查看
  1. internal static HostApplicationBuilder UseOrzDefaults(this HostApplicationBuilder builder)
  2. {
  3.     var services = builder.Services;
  4.     services.AddSingleton<IHostedService, HostedService>();
  5.     services.AddSingleton(TimeProvider.System);
  6.     services.AddSingleton<IMeterFactory, DummyMeterFactory>();
  7.     services.AddSingleton<IServer, OrzServer>();
  8.     services.AddSingleton<OrzLogger>();
  9.     services.AddSingleton<OrzMetrics>();
  10.     services.AddSingleton<IConnectionListenerFactory, SocketTransportFactory>();
  11.     services.AddSingleton<IConnectionListenerFactory, UdpTransportFactory>();
  12.     services.AddSingleton<IUdpConnectionFactory, UdpConnectionFactory>();
  13.     services.AddSingleton<IConnectionFactory, SocketConnectionFactory>();
  14.     services.AddSingleton<IRouteContractorValidator, RouteContractorValidator>();
  15.     services.AddSingleton<IEndPointConvertor, CommonEndPointConvertor>();
  16.     services.AddSingleton<IL4Router, L4Router>();
  17.     services.AddSingleton<IOrderMiddleware, L4ProxyMiddleware>();
  18.     services.AddSingleton<ILoadBalancingPolicyFactory, LoadBalancingPolicy>();
  19.     services.AddSingleton<IClusterConfigValidator, ClusterConfigValidator>();
  20.     services.AddSingleton<IDestinationResolver, DnsDestinationResolver>();
  21.     services.AddSingleton<ILoadBalancingPolicy, RandomLoadBalancingPolicy>();
  22.     services.AddSingleton<ILoadBalancingPolicy, RoundRobinLoadBalancingPolicy>();
  23.     services.AddSingleton<ILoadBalancingPolicy, LeastRequestsLoadBalancingPolicy>();
  24.     services.AddSingleton<ILoadBalancingPolicy, PowerOfTwoChoicesLoadBalancingPolicy>();
  25.     services.AddSingleton<IHealthReporter, PassiveHealthReporter>();
  26.     services.AddSingleton<IHealthUpdater, HealthyAndUnknownDestinationsUpdater>();
  27.     services.AddSingleton<IActiveHealthCheckMonitor, ActiveHealthCheckMonitor>();
  28.     services.AddSingleton<IActiveHealthChecker, ConnectionActiveHealthChecker>();
  29.     return builder;
  30. }
复制代码
比如要添加 负载均衡策略,就可以实现
  1. public interface ILoadBalancingPolicy
  2. {
  3.     string Name { get; }
  4.     DestinationState? PickDestination(ConnectionContext context, IReadOnlyList<DestinationState> availableDestinations);
  5. }
复制代码
如果对全部已有负载均衡策略都不满意,那就可以直接替换 ILoadBalancingPolicyFactory
  1. public interface ILoadBalancingPolicyFactory
  2. {
  3.     DestinationState? PickDestination(ConnectionContext context, RouteConfig route);
  4. }
复制代码
比如你就可以通过sni将开发环境(或者其他环境)无法访问的请求在一台有其他访问权限的机器进行转发
差不多就做了这些,造轮子还是挺好玩的,当然大家如果在 GitHub https://github.com/fs7744/NZOrz 点个 star, 就更好玩了

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
您需要登录后才可以回帖 登录 | 立即注册