diff --git a/README.md b/README.md index dfe669a..79933d8 100644 --- a/README.md +++ b/README.md @@ -194,6 +194,7 @@ The following options are allowed: - `key`: the name of the key to pub/sub events on as prefix (`socket.io`) - `requestsTimeout`: optional, after this timeout the adapter will stop waiting from responses to request (`5000ms`) +- `parser`: optional, parser to use for encoding and decoding messages passed through Redis ([`notepack.io`](https://www.npmjs.com/package/notepack.io)) ### RedisAdapter @@ -205,6 +206,7 @@ that a regular `Adapter` does not - `pubClient` - `subClient` - `requestsTimeout` +- `parser` ### RedisAdapter#sockets(rooms: Set<String>) diff --git a/lib/index.ts b/lib/index.ts index 413d69b..60da9b9 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -37,6 +37,11 @@ interface AckRequest { ack: (...args: any[]) => void; } +interface Parser { + decode: (msg: any) => any; + encode: (msg: any) => any; +} + const isNumeric = (str) => !isNaN(str) && !isNaN(parseFloat(str)); export interface RedisAdapterOptions { @@ -62,6 +67,11 @@ export interface RedisAdapterOptions { * @default false */ publishOnSpecificResponseChannel: boolean; + /** + * The parser to use for encoding and decoding messages sent to Redis. + * This option defaults to using `notepack.io`, a MessagePack implementation. + */ + parser: Parser; } /** @@ -87,6 +97,7 @@ export class RedisAdapter extends Adapter { public readonly uid; public readonly requestsTimeout: number; public readonly publishOnSpecificResponseChannel: boolean; + public readonly parser: Parser; private readonly channel: string; private readonly requestChannel: string; @@ -115,6 +126,7 @@ export class RedisAdapter extends Adapter { this.uid = uid2(6); this.requestsTimeout = opts.requestsTimeout || 5000; this.publishOnSpecificResponseChannel = !!opts.publishOnSpecificResponseChannel; + this.parser = opts.parser || msgpack; const prefix = opts.key || "socket.io"; @@ -181,7 +193,7 @@ export class RedisAdapter extends Adapter { return debug("ignore unknown room %s", room); } - const args = msgpack.decode(msg); + const args = this.parser.decode(msg); const [uid, packet, opts] = args; if (this.uid === uid) return debug("ignore same uid"); @@ -226,7 +238,7 @@ export class RedisAdapter extends Adapter { if (msg[0] === 0x7b) { request = JSON.parse(msg.toString()); } else { - request = msgpack.decode(msg); + request = this.parser.decode(msg); } } catch (err) { debug("ignoring malformed request"); @@ -424,7 +436,7 @@ export class RedisAdapter extends Adapter { this.publishResponse( request, - msgpack.encode({ + this.parser.encode({ type: RequestType.BROADCAST_ACK, requestId: request.requestId, packet: arg, @@ -467,7 +479,7 @@ export class RedisAdapter extends Adapter { if (msg[0] === 0x7b) { response = JSON.parse(msg.toString()); } else { - response = msgpack.decode(msg); + response = this.parser.decode(msg); } } catch (err) { debug("ignoring malformed response"); @@ -596,7 +608,7 @@ export class RedisAdapter extends Adapter { except: [...new Set(opts.except)], flags: opts.flags, }; - const msg = msgpack.encode([this.uid, packet, rawOpts]); + const msg = this.parser.encode([this.uid, packet, rawOpts]); let channel = this.channel; if (opts.rooms && opts.rooms.size === 1) { channel += opts.rooms.keys().next().value + "#"; @@ -626,7 +638,7 @@ export class RedisAdapter extends Adapter { flags: opts.flags, }; - const request = msgpack.encode({ + const request = this.parser.encode({ uid: this.uid, requestId, type: RequestType.BROADCAST, diff --git a/test/custom-parser.ts b/test/custom-parser.ts new file mode 100644 index 0000000..fb7077f --- /dev/null +++ b/test/custom-parser.ts @@ -0,0 +1,85 @@ +import { createServer } from "http"; +import { Server, Socket as ServerSocket } from "socket.io"; +import { io as ioc, Socket as ClientSocket } from "socket.io-client"; +import { createAdapter } from "../lib"; +import { createClient } from "redis"; +import { AddressInfo } from "net"; +import { times } from "./util"; +import expect = require("expect.js"); + +const NODES_COUNT = 3; + +describe("custom parser", () => { + let servers: Server[] = [], + serverSockets: ServerSocket[] = [], + clientSockets: ClientSocket[] = [], + redisClients: any[] = []; + + beforeEach(async () => { + for (let i = 1; i <= NODES_COUNT; i++) { + const httpServer = createServer(); + const pubClient = createClient(); + const subClient = createClient(); + + await Promise.all([pubClient.connect(), subClient.connect()]); + + redisClients.push(pubClient, subClient); + + const io = new Server(httpServer, { + adapter: createAdapter(pubClient, subClient, { + parser: { + decode(msg) { + return JSON.parse(msg); + }, + encode(msg) { + return JSON.stringify(msg); + }, + }, + }), + }); + + await new Promise((resolve) => { + httpServer.listen(() => { + const port = (httpServer.address() as AddressInfo).port; + const clientSocket = ioc(`http://localhost:${port}`); + + io.on("connection", async (socket) => { + clientSockets.push(clientSocket); + serverSockets.push(socket); + servers.push(io); + resolve(); + }); + }); + }); + } + }); + + afterEach(() => { + servers.forEach((server) => { + // @ts-ignore + server.httpServer.close(); + server.of("/").adapter.close(); + }); + clientSockets.forEach((socket) => { + socket.disconnect(); + }); + redisClients.forEach((redisClient) => { + redisClient.disconnect(); + }); + }); + + it("broadcasts", (done) => { + const partialDone = times(3, done); + + clientSockets.forEach((clientSocket) => { + clientSocket.on("test", (arg1, arg2, arg3) => { + expect(arg1).to.eql(1); + expect(arg2).to.eql("2"); + expect(arg3).to.eql([3]); + partialDone(); + }); + }); + + servers[0].emit("test", 1, "2", [3]); + }); +}); diff --git a/test/index.ts b/test/index.ts index 9c39496..78d3cfd 100644 --- a/test/index.ts +++ b/test/index.ts @@ -562,3 +562,5 @@ function cleanup(done) { } done(); } + +require("./custom-parser"); diff --git a/test/util.ts b/test/util.ts index f3844af..5446b2f 100644 --- a/test/util.ts +++ b/test/util.ts @@ -20,3 +20,13 @@ Assertion.prototype.contain = function (...args) { } return contain.apply(this, args); }; + +export function times(count: number, fn: () => void) { + let i = 0; + return () => { + i++; + if (i === count) { + fn(); + } + }; +}