diff --git a/packages/ws/README.md b/packages/ws/README.md index c1d3b70a1..61be4e935 100644 --- a/packages/ws/README.md +++ b/packages/ws/README.md @@ -50,7 +50,10 @@ const manager = new WebSocketManager({ intents: 0, // for no intents rest, // uncomment if you have zlib-sync installed and want to use compression - // compression: CompressionMethod.ZlibStream, + // compression: CompressionMethod.ZlibSync, + + // alternatively, we support compression using node's native `node:zlib` module: + // compression: CompressionMethod.ZlibNative, }); manager.on(WebSocketShardEvents.Dispatch, (event) => { diff --git a/packages/ws/src/utils/constants.ts b/packages/ws/src/utils/constants.ts index c901c23b2..2917090ea 100644 --- a/packages/ws/src/utils/constants.ts +++ b/packages/ws/src/utils/constants.ts @@ -18,13 +18,19 @@ export enum Encoding { * Valid compression methods */ export enum CompressionMethod { - ZlibStream = 'zlib-stream', + ZlibNative, + ZlibSync, } export const DefaultDeviceProperty = `@discordjs/ws [VI]{{inject}}[/VI]` as `@discordjs/ws ${string}`; const getDefaultSessionStore = lazy(() => new Collection()); +export const CompressionParameterMap = { + [CompressionMethod.ZlibNative]: 'zlib-stream', + [CompressionMethod.ZlibSync]: 'zlib-stream', +} as const satisfies Record; + /** * Default options used by the manager */ @@ -46,6 +52,7 @@ export const DefaultWebSocketManagerOptions = { version: APIVersion, encoding: Encoding.JSON, compression: null, + useIdentifyCompression: false, retrieveSessionInfo(shardId) { const store = getDefaultSessionStore(); return store.get(shardId) ?? null; diff --git a/packages/ws/src/ws/WebSocketManager.ts b/packages/ws/src/ws/WebSocketManager.ts index f4a80bec7..2bc8a601c 100644 --- a/packages/ws/src/ws/WebSocketManager.ts +++ b/packages/ws/src/ws/WebSocketManager.ts @@ -96,9 +96,9 @@ export interface OptionalWebSocketManagerOptions { */ buildStrategy(manager: WebSocketManager): IShardingStrategy; /** - * The compression method to use + * The transport compression method to use - mutually exclusive with `useIdentifyCompression` * - * @defaultValue `null` (no compression) + * @defaultValue `null` (no transport compression) */ compression: CompressionMethod | null; /** @@ -176,6 +176,12 @@ export interface OptionalWebSocketManagerOptions { * Function used to store session information for a given shard */ updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable; + /** + * Whether to use the `compress` option when identifying + * + * @defaultValue `false` + */ + useIdentifyCompression: boolean; /** * The gateway version to use * diff --git a/packages/ws/src/ws/WebSocketShard.ts b/packages/ws/src/ws/WebSocketShard.ts index 5af2d2597..0d183b90c 100644 --- a/packages/ws/src/ws/WebSocketShard.ts +++ b/packages/ws/src/ws/WebSocketShard.ts @@ -1,11 +1,10 @@ -/* eslint-disable id-length */ import { Buffer } from 'node:buffer'; import { once } from 'node:events'; import { clearInterval, clearTimeout, setInterval, setTimeout } from 'node:timers'; import { setTimeout as sleep } from 'node:timers/promises'; import { URLSearchParams } from 'node:url'; import { TextDecoder } from 'node:util'; -import { inflate } from 'node:zlib'; +import type * as nativeZlib from 'node:zlib'; import { Collection } from '@discordjs/collection'; import { lazy, shouldUseGlobalFetchAndWebSocket } from '@discordjs/util'; import { AsyncQueue } from '@sapphire/async-queue'; @@ -21,13 +20,20 @@ import { type GatewaySendPayload, } from 'discord-api-types/v10'; import { WebSocket, type Data } from 'ws'; -import type { Inflate } from 'zlib-sync'; -import type { IContextFetchingStrategy } from '../strategies/context/IContextFetchingStrategy.js'; -import { ImportantGatewayOpcodes, getInitialSendRateLimitState } from '../utils/constants.js'; +import type * as ZlibSync from 'zlib-sync'; +import type { IContextFetchingStrategy } from '../strategies/context/IContextFetchingStrategy'; +import { + CompressionMethod, + CompressionParameterMap, + ImportantGatewayOpcodes, + getInitialSendRateLimitState, +} from '../utils/constants.js'; import type { SessionInfo } from './WebSocketManager.js'; -// eslint-disable-next-line promise/prefer-await-to-then +/* eslint-disable promise/prefer-await-to-then */ const getZlibSync = lazy(async () => import('zlib-sync').then((mod) => mod.default).catch(() => null)); +const getNativeZlib = lazy(async () => import('node:zlib').then((mod) => mod).catch(() => null)); +/* eslint-enable promise/prefer-await-to-then */ export enum WebSocketShardEvents { Closed = 'closed', @@ -86,9 +92,9 @@ const WebSocketConstructor: typeof WebSocket = shouldUseGlobalFetchAndWebSocket( export class WebSocketShard extends AsyncEventEmitter { private connection: WebSocket | null = null; - private useIdentifyCompress = false; + private nativeInflate: nativeZlib.Inflate | null = null; - private inflate: Inflate | null = null; + private zLibSyncInflate: ZlibSync.Inflate | null = null; private readonly textDecoder = new TextDecoder(); @@ -120,6 +126,18 @@ export class WebSocketShard extends AsyncEventEmitter { #status: WebSocketShardStatus = WebSocketShardStatus.Idle; + private identifyCompressionEnabled = false; + + /** + * @privateRemarks + * + * This is needed because `this.strategy.options.compression` is not an actual reflection of the compression method + * used, but rather the compression method that the user wants to use. This is because the libraries could just be missing. + */ + private get transportCompressionEnabled() { + return this.strategy.options.compression !== null && (this.nativeInflate ?? this.zLibSyncInflate) !== null; + } + public get status(): WebSocketShardStatus { return this.#status; } @@ -161,21 +179,63 @@ export class WebSocketShard extends AsyncEventEmitter { throw new Error("Tried to connect a shard that wasn't idle"); } - const { version, encoding, compression } = this.strategy.options; + const { version, encoding, compression, useIdentifyCompression } = this.strategy.options; + this.identifyCompressionEnabled = useIdentifyCompression; + + // eslint-disable-next-line id-length const params = new URLSearchParams({ v: version, encoding }); - if (compression) { - const zlib = await getZlibSync(); - if (zlib) { - params.append('compress', compression); - this.inflate = new zlib.Inflate({ - chunkSize: 65_535, - to: 'string', - }); - } else if (!this.useIdentifyCompress) { - this.useIdentifyCompress = true; - console.warn( - 'WebSocketShard: Compression is enabled but zlib-sync is not installed, falling back to identify compress', - ); + if (compression !== null) { + if (useIdentifyCompression) { + console.warn('WebSocketShard: transport compression is enabled, disabling identify compression'); + this.identifyCompressionEnabled = false; + } + + params.append('compress', CompressionParameterMap[compression]); + + switch (compression) { + case CompressionMethod.ZlibNative: { + const zlib = await getNativeZlib(); + if (zlib) { + const inflate = zlib.createInflate({ + chunkSize: 65_535, + flush: zlib.constants.Z_SYNC_FLUSH, + }); + + inflate.on('error', (error) => { + this.emit(WebSocketShardEvents.Error, { error }); + }); + + this.nativeInflate = inflate; + } else { + console.warn('WebSocketShard: Compression is set to native but node:zlib is not available.'); + params.delete('compress'); + } + + break; + } + + case CompressionMethod.ZlibSync: { + const zlib = await getZlibSync(); + if (zlib) { + this.zLibSyncInflate = new zlib.Inflate({ + chunkSize: 65_535, + to: 'string', + }); + } else { + console.warn('WebSocketShard: Compression is set to zlib-sync, but it is not installed.'); + params.delete('compress'); + } + + break; + } + } + } + + if (this.identifyCompressionEnabled) { + const zlib = await getNativeZlib(); + if (!zlib) { + console.warn('WebSocketShard: Identify compression is enabled, but node:zlib is not available.'); + this.identifyCompressionEnabled = false; } } @@ -451,28 +511,29 @@ export class WebSocketShard extends AsyncEventEmitter { `shard id: ${this.id.toString()}`, `shard count: ${this.strategy.options.shardCount}`, `intents: ${this.strategy.options.intents}`, - `compression: ${this.inflate ? 'zlib-stream' : this.useIdentifyCompress ? 'identify' : 'none'}`, + `compression: ${this.transportCompressionEnabled ? CompressionParameterMap[this.strategy.options.compression!] : this.identifyCompressionEnabled ? 'identify' : 'none'}`, ]); - const d: GatewayIdentifyData = { + const data: GatewayIdentifyData = { token: this.strategy.options.token, properties: this.strategy.options.identifyProperties, intents: this.strategy.options.intents, - compress: this.useIdentifyCompress, + compress: this.identifyCompressionEnabled, shard: [this.id, this.strategy.options.shardCount], }; if (this.strategy.options.largeThreshold) { - d.large_threshold = this.strategy.options.largeThreshold; + data.large_threshold = this.strategy.options.largeThreshold; } if (this.strategy.options.initialPresence) { - d.presence = this.strategy.options.initialPresence; + data.presence = this.strategy.options.initialPresence; } await this.send({ op: GatewayOpcodes.Identify, - d, + // eslint-disable-next-line id-length + d: data, }); await this.waitForEvent(WebSocketShardEvents.Ready, this.strategy.options.readyTimeout); @@ -490,6 +551,7 @@ export class WebSocketShard extends AsyncEventEmitter { this.replayedEvents = 0; return this.send({ op: GatewayOpcodes.Resume, + // eslint-disable-next-line id-length d: { token: this.strategy.options.token, seq: session.sequence, @@ -507,6 +569,7 @@ export class WebSocketShard extends AsyncEventEmitter { await this.send({ op: GatewayOpcodes.Heartbeat, + // eslint-disable-next-line id-length d: session?.sequence ?? null, }); @@ -514,6 +577,14 @@ export class WebSocketShard extends AsyncEventEmitter { this.isAck = false; } + private parseInflateResult(result: any): GatewayReceivePayload | null { + if (!result) { + return null; + } + + return JSON.parse(typeof result === 'string' ? result : this.textDecoder.decode(result)) as GatewayReceivePayload; + } + private async unpackMessage(data: Data, isBinary: boolean): Promise { // Deal with no compression if (!isBinary) { @@ -528,10 +599,12 @@ export class WebSocketShard extends AsyncEventEmitter { const decompressable = new Uint8Array(data as ArrayBuffer); // Deal with identify compress - if (this.useIdentifyCompress) { - return new Promise((resolve, reject) => { + if (this.identifyCompressionEnabled) { + // eslint-disable-next-line no-async-promise-executor + return new Promise(async (resolve, reject) => { + const zlib = (await getNativeZlib())!; // eslint-disable-next-line promise/prefer-await-to-callbacks - inflate(decompressable, { chunkSize: 65_535 }, (err, result) => { + zlib.inflate(decompressable, { chunkSize: 65_535 }, (err, result) => { if (err) { reject(err); return; @@ -542,42 +615,50 @@ export class WebSocketShard extends AsyncEventEmitter { }); } - // Deal with gw wide zlib-stream compression - if (this.inflate) { - const l = decompressable.length; + // Deal with transport compression + if (this.transportCompressionEnabled) { const flush = - l >= 4 && - decompressable[l - 4] === 0x00 && - decompressable[l - 3] === 0x00 && - decompressable[l - 2] === 0xff && - decompressable[l - 1] === 0xff; + decompressable.length >= 4 && + decompressable.at(-4) === 0x00 && + decompressable.at(-3) === 0x00 && + decompressable.at(-2) === 0xff && + decompressable.at(-1) === 0xff; - const zlib = (await getZlibSync())!; - this.inflate.push(Buffer.from(decompressable), flush ? zlib.Z_SYNC_FLUSH : zlib.Z_NO_FLUSH); + if (this.nativeInflate) { + this.nativeInflate.write(decompressable, 'binary'); - if (this.inflate.err) { - this.emit(WebSocketShardEvents.Error, { - error: new Error(`${this.inflate.err}${this.inflate.msg ? `: ${this.inflate.msg}` : ''}`), - }); + if (!flush) { + return null; + } + + const [result] = await once(this.nativeInflate, 'data'); + return this.parseInflateResult(result); + } else if (this.zLibSyncInflate) { + const zLibSync = (await getZlibSync())!; + this.zLibSyncInflate.push(Buffer.from(decompressable), flush ? zLibSync.Z_SYNC_FLUSH : zLibSync.Z_NO_FLUSH); + + if (this.zLibSyncInflate.err) { + this.emit(WebSocketShardEvents.Error, { + error: new Error( + `${this.zLibSyncInflate.err}${this.zLibSyncInflate.msg ? `: ${this.zLibSyncInflate.msg}` : ''}`, + ), + }); + } + + if (!flush) { + return null; + } + + const { result } = this.zLibSyncInflate; + return this.parseInflateResult(result); } - - if (!flush) { - return null; - } - - const { result } = this.inflate; - if (!result) { - return null; - } - - return JSON.parse(typeof result === 'string' ? result : this.textDecoder.decode(result)) as GatewayReceivePayload; } this.debug([ 'Received a message we were unable to decompress', `isBinary: ${isBinary.toString()}`, - `useIdentifyCompress: ${this.useIdentifyCompress.toString()}`, - `inflate: ${Boolean(this.inflate).toString()}`, + `identifyCompressionEnabled: ${this.identifyCompressionEnabled.toString()}`, + `inflate: ${this.transportCompressionEnabled ? CompressionMethod[this.strategy.options.compression!] : 'none'}`, ]); return null; @@ -838,7 +919,7 @@ export class WebSocketShard extends AsyncEventEmitter { messages.length > 1 ? `\n${messages .slice(1) - .map((m) => ` ${m}`) + .map((message) => ` ${message}`) .join('\n')}` : '' }`;