From c7b2689ae1df675f1828e83cad6549c41e1ccb87 Mon Sep 17 00:00:00 2001 From: Laci0503 Date: Tue, 23 Jul 2024 01:20:30 +0200 Subject: [PATCH] Split the gateway into two classes --- Discord.API/Gateway/AbstractGateway.cs | 99 ++++++++++++++++++++++++++ Discord.API/Gateway/GatewayClient.cs | 77 +++++++------------- 2 files changed, 125 insertions(+), 51 deletions(-) create mode 100644 Discord.API/Gateway/AbstractGateway.cs diff --git a/Discord.API/Gateway/AbstractGateway.cs b/Discord.API/Gateway/AbstractGateway.cs new file mode 100644 index 0000000..ec94772 --- /dev/null +++ b/Discord.API/Gateway/AbstractGateway.cs @@ -0,0 +1,99 @@ +using System.Reactive.Linq; +using Serilog; +using Websocket.Client; + +namespace Discord.API; + +public abstract class AbstractGateway { + protected readonly IWebsocketClient WebsocketClient; + protected readonly TimeProvider TimeProvider; + + public AbstractGateway(IWebsocketClient client, TimeProvider time_provider){ + this.WebsocketClient = client; + this.TimeProvider = time_provider; + WebsocketClient.DisconnectionHappened.Subscribe(DisconnectHandlerInternal); + WebsocketClient.ReconnectionHappened.Subscribe(ReconnectHandler); + WebsocketClient.MessageReceived.Subscribe(MessageReceivedHandler); + WebsocketClient.Start(); + } + + private void DisconnectHandlerInternal(DisconnectionInfo info){ + StopHeartbeat(); + DisconnectHandler(info); + } + + protected abstract void DisconnectHandler(DisconnectionInfo info); + + protected abstract void ReconnectHandler(ReconnectionInfo info); + + protected abstract void MessageReceivedHandler(ResponseMessage message); + + #region Heartbeat + + private CancellationTokenSource? HeartbeatCts; + private CancellationTokenSource? InstantHeartbeatCts; + private DateTime? HeartbeatSent; + protected int HeartbeatPing {get; private set;} + private bool HeartbeatAcked; + + /// + /// Sends a heartbeat packet to the remote + /// + protected abstract Task SendHeartbeat(); + /// + /// Handler for when the heartbeat ack was not received + /// + /// Indicates if the heartbeat loop should exit + protected abstract Task MissingHeartbeatAckHandler(); + + private async Task HeartbeatTask(int heartbeat_interval){ + CancellationToken ct = HeartbeatCts!.Token; + InstantHeartbeatCts = new(); + HeartbeatSent = null; + await Task.Delay( + TimeSpan.FromMilliseconds(Random.Shared.Next(1, heartbeat_interval)), + TimeProvider, + InstantHeartbeatCts.Token); + + if(ct.IsCancellationRequested) return; // Check if the task was cancelled when we were waiting + + using PeriodicTimer pd = new(TimeSpan.FromMilliseconds(heartbeat_interval), TimeProvider); + HeartbeatAcked = true; + do{ + if(!HeartbeatAcked){ + if(await MissingHeartbeatAckHandler()){ + break; + } + } + await SendHeartbeat(); + HeartbeatSent = TimeProvider.GetUtcNow().DateTime; + HeartbeatAcked = false; + + // Create new InstantHeartbeatCts if it was used + if(InstantHeartbeatCts.IsCancellationRequested) InstantHeartbeatCts = new(); + } while(await pd.WaitForNextTickAsync(InstantHeartbeatCts.Token) && !ct.IsCancellationRequested); + } + + protected int HeartbeatAckReceived(){ + HeartbeatAcked = true; + if(HeartbeatSent == null){ + Log.Warning("GATEWAY(abstract): HeartbeatAck received before heartbeat was sent"); + return -1; + } + HeartbeatPing = (int)(TimeProvider.GetUtcNow().DateTime - HeartbeatSent.Value).TotalMilliseconds; + return HeartbeatPing; + } + + protected void StartHeartbeat(int heartbeat_interval){ + StopHeartbeat(); + HeartbeatCts = new(); + Task.Run(() => HeartbeatTask(heartbeat_interval), HeartbeatCts.Token); + } + + protected void StopHeartbeat(){ + HeartbeatCts?.Cancel(); + InstantHeartbeatCts?.Cancel(); + } + + #endregion +} \ No newline at end of file diff --git a/Discord.API/Gateway/GatewayClient.cs b/Discord.API/Gateway/GatewayClient.cs index 0aaac96..fdf76c6 100644 --- a/Discord.API/Gateway/GatewayClient.cs +++ b/Discord.API/Gateway/GatewayClient.cs @@ -1,51 +1,43 @@ using System.Collections.Specialized; -using System.Security.Cryptography; using System.Text.Json; using Serilog; using Websocket.Client; namespace Discord.API; -public class GatewayClient { +public class GatewayClient : AbstractGateway { private const int ApiVersion = 10; private string ApiKey; - private IWebsocketClient Websocket; - private CancellationTokenSource? HeartbeatCts; private ulong? Sequence = null; - private DateTime? HeartbeatAckReceived = null; - private TimeProvider timeProvider; public GatewayClient(string url, string api_key, TimeProvider time_provider) : this(url, api_key, time_provider, uri => new WebsocketClient(uri)) { } - internal GatewayClient(string url, string api_key, TimeProvider time_provider, Func websocket_client_factory){ - this.timeProvider=time_provider; + internal GatewayClient(string url, string api_key, TimeProvider time_provider, Func websocket_client_factory) + : base(websocket_client_factory.Invoke(BuildUrl(url)), time_provider){ this.ApiKey=api_key; + Log.Debug("GATEWAY: Created new gateway, with url: {url}", WebsocketClient.Url); + } + + private static Uri BuildUrl(string url){ UriBuilder uriBuilder = new(url); NameValueCollection query = System.Web.HttpUtility.ParseQueryString(""); query.Add("v", ApiVersion.ToString()); query.Add("encoding", "json"); uriBuilder.Query=query.ToString(); - Websocket = websocket_client_factory.Invoke(uriBuilder.Uri); - Websocket.DisconnectionHappened.Subscribe(DisconnectionHandler); - Websocket.ReconnectionHappened.Subscribe(ReconnectionHandler); - Websocket.MessageReceived.Subscribe(MessageReceivedHandler); - Websocket.Start(); - - Log.Debug("GATEWAY: Created new gateway, with url: {url}", uriBuilder.ToString()); + return uriBuilder.Uri; } - private void DisconnectionHandler(DisconnectionInfo info){ - StopHeartbeat(); + protected override void DisconnectHandler(DisconnectionInfo info){ Log.Information("GATEWAY: Disconnected. Type: {DisconnectionType}", info.Type); } - private void ReconnectionHandler(ReconnectionInfo info){ - Log.Information("GATEWAY: (Re)Connected to server. Url: {url}, Type: {Type}", Websocket.Url, info.Type); + protected override void ReconnectHandler(ReconnectionInfo info){ + Log.Information("GATEWAY: (Re)Connected to server. Url: {url}, Type: {Type}", WebsocketClient.Url, info.Type); } - private void MessageReceivedHandler(ResponseMessage msg){ + protected override void MessageReceivedHandler(ResponseMessage msg){ if(msg.MessageType != System.Net.WebSockets.WebSocketMessageType.Text) return; try{ GatewayPacket packet = JsonSerializer.Deserialize(msg.Text!, SourceGenerationContext.Default.GatewayPacket) @@ -75,41 +67,24 @@ public class GatewayClient { } private void HeartbeatAckHandler(){ - HeartbeatAckReceived = timeProvider.GetUtcNow().DateTime; Log.Debug("GATEWAY: Heartbeat ACK received"); + Log.Information("GATEWAY: Heartbeat ping: {ping_ms} ms", HeartbeatAckReceived()); } - private void StartHeartbeat(int heartbeat_interval){ - HeartbeatCts?.Cancel(); - HeartbeatCts = new CancellationTokenSource(); - CancellationToken ct = HeartbeatCts.Token; - Task.Run(async ()=>{ - await Task.Delay(Random.Shared.Next(1, heartbeat_interval)); - using PeriodicTimer pd = new(TimeSpan.FromMilliseconds(heartbeat_interval), timeProvider); - HeartbeatAckReceived = timeProvider.GetUtcNow().DateTime; - DateTime HeartbeatSent = timeProvider.GetUtcNow().DateTime; - do{ - if(HeartbeatAckReceived == null){ - Log.Debug("GATEWAY: Heartbeat ack not received. Reconnecting."); - _ = Websocket.Reconnect(); - break; - } - Log.Information("GATEWAY: Heartbeat ping time is {time_ms} ms", (HeartbeatAckReceived.Value - HeartbeatSent).TotalMilliseconds); - HeartbeatPacket packet = new HeartbeatPacket(){ - Sequence=this.Sequence - }; - if(Websocket.IsRunning) - if(!Websocket.Send(JsonSerializer.Serialize(packet, SourceGenerationContext.Default.HeartbeatPacket))){ - Log.Warning("GATEWAY: Failed to queue heartbeat message"); - } - HeartbeatSent = timeProvider.GetUtcNow().DateTime; - HeartbeatAckReceived = null; - Log.Debug("GATEWAY: Heartbeat sent"); - }while(await pd.WaitForNextTickAsync(ct) && !ct.IsCancellationRequested); - }); + protected override Task SendHeartbeat() + { + HeartbeatPacket packet = new(){ + Sequence = Sequence + }; + WebsocketClient.Send(JsonSerializer.Serialize(packet, SourceGenerationContext.Default.HeartbeatPacket)); + Log.Debug("GATEWAY: Heartbeat sent"); + return Task.CompletedTask; } - private void StopHeartbeat(){ - HeartbeatCts?.Cancel(); + protected override Task MissingHeartbeatAckHandler() + { + _ = WebsocketClient.Reconnect(); + Log.Debug("GATEWAY: Heartbeat ack missed. Reconnecting"); + return Task.FromResult(true); } }