diff --git a/Discord.API/Gateway/AbstractGateway.cs b/Discord.API/Gateway/AbstractGateway.cs
new file mode 100644
index 0000000..ec94772
--- /dev/null
+++ b/Discord.API/Gateway/AbstractGateway.cs
@@ -0,0 +1,99 @@
+using System.Reactive.Linq;
+using Serilog;
+using Websocket.Client;
+
+namespace Discord.API;
+
+public abstract class AbstractGateway {
+ protected readonly IWebsocketClient WebsocketClient;
+ protected readonly TimeProvider TimeProvider;
+
+ public AbstractGateway(IWebsocketClient client, TimeProvider time_provider){
+ this.WebsocketClient = client;
+ this.TimeProvider = time_provider;
+ WebsocketClient.DisconnectionHappened.Subscribe(DisconnectHandlerInternal);
+ WebsocketClient.ReconnectionHappened.Subscribe(ReconnectHandler);
+ WebsocketClient.MessageReceived.Subscribe(MessageReceivedHandler);
+ WebsocketClient.Start();
+ }
+
+ private void DisconnectHandlerInternal(DisconnectionInfo info){
+ StopHeartbeat();
+ DisconnectHandler(info);
+ }
+
+ protected abstract void DisconnectHandler(DisconnectionInfo info);
+
+ protected abstract void ReconnectHandler(ReconnectionInfo info);
+
+ protected abstract void MessageReceivedHandler(ResponseMessage message);
+
+ #region Heartbeat
+
+ private CancellationTokenSource? HeartbeatCts;
+ private CancellationTokenSource? InstantHeartbeatCts;
+ private DateTime? HeartbeatSent;
+ protected int HeartbeatPing {get; private set;}
+ private bool HeartbeatAcked;
+
+ ///
+ /// Sends a heartbeat packet to the remote
+ ///
+ protected abstract Task SendHeartbeat();
+ ///
+ /// Handler for when the heartbeat ack was not received
+ ///
+ /// Indicates if the heartbeat loop should exit
+ protected abstract Task MissingHeartbeatAckHandler();
+
+ private async Task HeartbeatTask(int heartbeat_interval){
+ CancellationToken ct = HeartbeatCts!.Token;
+ InstantHeartbeatCts = new();
+ HeartbeatSent = null;
+ await Task.Delay(
+ TimeSpan.FromMilliseconds(Random.Shared.Next(1, heartbeat_interval)),
+ TimeProvider,
+ InstantHeartbeatCts.Token);
+
+ if(ct.IsCancellationRequested) return; // Check if the task was cancelled when we were waiting
+
+ using PeriodicTimer pd = new(TimeSpan.FromMilliseconds(heartbeat_interval), TimeProvider);
+ HeartbeatAcked = true;
+ do{
+ if(!HeartbeatAcked){
+ if(await MissingHeartbeatAckHandler()){
+ break;
+ }
+ }
+ await SendHeartbeat();
+ HeartbeatSent = TimeProvider.GetUtcNow().DateTime;
+ HeartbeatAcked = false;
+
+ // Create new InstantHeartbeatCts if it was used
+ if(InstantHeartbeatCts.IsCancellationRequested) InstantHeartbeatCts = new();
+ } while(await pd.WaitForNextTickAsync(InstantHeartbeatCts.Token) && !ct.IsCancellationRequested);
+ }
+
+ protected int HeartbeatAckReceived(){
+ HeartbeatAcked = true;
+ if(HeartbeatSent == null){
+ Log.Warning("GATEWAY(abstract): HeartbeatAck received before heartbeat was sent");
+ return -1;
+ }
+ HeartbeatPing = (int)(TimeProvider.GetUtcNow().DateTime - HeartbeatSent.Value).TotalMilliseconds;
+ return HeartbeatPing;
+ }
+
+ protected void StartHeartbeat(int heartbeat_interval){
+ StopHeartbeat();
+ HeartbeatCts = new();
+ Task.Run(() => HeartbeatTask(heartbeat_interval), HeartbeatCts.Token);
+ }
+
+ protected void StopHeartbeat(){
+ HeartbeatCts?.Cancel();
+ InstantHeartbeatCts?.Cancel();
+ }
+
+ #endregion
+}
\ No newline at end of file
diff --git a/Discord.API/Gateway/GatewayClient.cs b/Discord.API/Gateway/GatewayClient.cs
index 0aaac96..fdf76c6 100644
--- a/Discord.API/Gateway/GatewayClient.cs
+++ b/Discord.API/Gateway/GatewayClient.cs
@@ -1,51 +1,43 @@
using System.Collections.Specialized;
-using System.Security.Cryptography;
using System.Text.Json;
using Serilog;
using Websocket.Client;
namespace Discord.API;
-public class GatewayClient {
+public class GatewayClient : AbstractGateway {
private const int ApiVersion = 10;
private string ApiKey;
- private IWebsocketClient Websocket;
- private CancellationTokenSource? HeartbeatCts;
private ulong? Sequence = null;
- private DateTime? HeartbeatAckReceived = null;
- private TimeProvider timeProvider;
public GatewayClient(string url, string api_key, TimeProvider time_provider)
: this(url, api_key, time_provider, uri => new WebsocketClient(uri))
{ }
- internal GatewayClient(string url, string api_key, TimeProvider time_provider, Func websocket_client_factory){
- this.timeProvider=time_provider;
+ internal GatewayClient(string url, string api_key, TimeProvider time_provider, Func websocket_client_factory)
+ : base(websocket_client_factory.Invoke(BuildUrl(url)), time_provider){
this.ApiKey=api_key;
+ Log.Debug("GATEWAY: Created new gateway, with url: {url}", WebsocketClient.Url);
+ }
+
+ private static Uri BuildUrl(string url){
UriBuilder uriBuilder = new(url);
NameValueCollection query = System.Web.HttpUtility.ParseQueryString("");
query.Add("v", ApiVersion.ToString());
query.Add("encoding", "json");
uriBuilder.Query=query.ToString();
- Websocket = websocket_client_factory.Invoke(uriBuilder.Uri);
- Websocket.DisconnectionHappened.Subscribe(DisconnectionHandler);
- Websocket.ReconnectionHappened.Subscribe(ReconnectionHandler);
- Websocket.MessageReceived.Subscribe(MessageReceivedHandler);
- Websocket.Start();
-
- Log.Debug("GATEWAY: Created new gateway, with url: {url}", uriBuilder.ToString());
+ return uriBuilder.Uri;
}
- private void DisconnectionHandler(DisconnectionInfo info){
- StopHeartbeat();
+ protected override void DisconnectHandler(DisconnectionInfo info){
Log.Information("GATEWAY: Disconnected. Type: {DisconnectionType}", info.Type);
}
- private void ReconnectionHandler(ReconnectionInfo info){
- Log.Information("GATEWAY: (Re)Connected to server. Url: {url}, Type: {Type}", Websocket.Url, info.Type);
+ protected override void ReconnectHandler(ReconnectionInfo info){
+ Log.Information("GATEWAY: (Re)Connected to server. Url: {url}, Type: {Type}", WebsocketClient.Url, info.Type);
}
- private void MessageReceivedHandler(ResponseMessage msg){
+ protected override void MessageReceivedHandler(ResponseMessage msg){
if(msg.MessageType != System.Net.WebSockets.WebSocketMessageType.Text) return;
try{
GatewayPacket packet = JsonSerializer.Deserialize(msg.Text!, SourceGenerationContext.Default.GatewayPacket)
@@ -75,41 +67,24 @@ public class GatewayClient {
}
private void HeartbeatAckHandler(){
- HeartbeatAckReceived = timeProvider.GetUtcNow().DateTime;
Log.Debug("GATEWAY: Heartbeat ACK received");
+ Log.Information("GATEWAY: Heartbeat ping: {ping_ms} ms", HeartbeatAckReceived());
}
- private void StartHeartbeat(int heartbeat_interval){
- HeartbeatCts?.Cancel();
- HeartbeatCts = new CancellationTokenSource();
- CancellationToken ct = HeartbeatCts.Token;
- Task.Run(async ()=>{
- await Task.Delay(Random.Shared.Next(1, heartbeat_interval));
- using PeriodicTimer pd = new(TimeSpan.FromMilliseconds(heartbeat_interval), timeProvider);
- HeartbeatAckReceived = timeProvider.GetUtcNow().DateTime;
- DateTime HeartbeatSent = timeProvider.GetUtcNow().DateTime;
- do{
- if(HeartbeatAckReceived == null){
- Log.Debug("GATEWAY: Heartbeat ack not received. Reconnecting.");
- _ = Websocket.Reconnect();
- break;
- }
- Log.Information("GATEWAY: Heartbeat ping time is {time_ms} ms", (HeartbeatAckReceived.Value - HeartbeatSent).TotalMilliseconds);
- HeartbeatPacket packet = new HeartbeatPacket(){
- Sequence=this.Sequence
- };
- if(Websocket.IsRunning)
- if(!Websocket.Send(JsonSerializer.Serialize(packet, SourceGenerationContext.Default.HeartbeatPacket))){
- Log.Warning("GATEWAY: Failed to queue heartbeat message");
- }
- HeartbeatSent = timeProvider.GetUtcNow().DateTime;
- HeartbeatAckReceived = null;
- Log.Debug("GATEWAY: Heartbeat sent");
- }while(await pd.WaitForNextTickAsync(ct) && !ct.IsCancellationRequested);
- });
+ protected override Task SendHeartbeat()
+ {
+ HeartbeatPacket packet = new(){
+ Sequence = Sequence
+ };
+ WebsocketClient.Send(JsonSerializer.Serialize(packet, SourceGenerationContext.Default.HeartbeatPacket));
+ Log.Debug("GATEWAY: Heartbeat sent");
+ return Task.CompletedTask;
}
- private void StopHeartbeat(){
- HeartbeatCts?.Cancel();
+ protected override Task MissingHeartbeatAckHandler()
+ {
+ _ = WebsocketClient.Reconnect();
+ Log.Debug("GATEWAY: Heartbeat ack missed. Reconnecting");
+ return Task.FromResult(true);
}
}