Refactor socket.io controller

This commit is contained in:
Calvin Montgomery 2017-08-01 19:29:11 -07:00
parent 107155a661
commit 0118a6fb15
10 changed files with 480 additions and 253 deletions

View file

@ -1,17 +1,12 @@
var sio = require("socket.io");
var db = require("../database");
var User = require("../user");
var Server = require("../server");
var Config = require("../config");
var cookieParser = require("cookie-parser")(Config.get("http.cookie-secret"));
var $util = require("../utilities");
var Flags = require("../flags");
var typecheck = require("json-typecheck");
var net = require("net");
var util = require("../utilities");
var crypto = require("crypto");
var isTorExit = require("../tor").isTorExit;
var session = require("../session");
import sio from 'socket.io';
import db from '../database';
import User from '../user';
import Server from '../server';
import Config from '../config';
const cookieParser = require("cookie-parser")(Config.get("http.cookie-secret"));
import typecheck from 'json-typecheck';
import { isTorExit } from '../tor';
import session from '../session';
import counters from '../counters';
import { verifyIPSessionCookie } from '../web/middleware/ipsessioncookie';
import Promise from 'bluebird';
@ -21,118 +16,181 @@ import { CachingGlobalBanlist } from './globalban';
import proxyaddr from 'proxy-addr';
import { Counter, Gauge } from 'prom-client';
import Socket from 'socket.io/lib/socket';
import { TokenBucket } from '../util/token-bucket';
import http from 'http';
const LOGGER = require('@calzoneman/jsli')('ioserver');
var CONNECT_RATE = {
burst: 5,
sustained: 0.1
};
// WIP, not in use yet
class IOServer {
constructor(options = {
proxyTrustFn: proxyaddr.compile('127.0.0.1')
}) {
({
proxyTrustFn: this.proxyTrustFn
} = options);
var ipThrottle = {};
// Keep track of number of connections per IP
var ipCount = {};
function parseCookies(socket, accept) {
var req = socket.request;
if (req.headers.cookie) {
cookieParser(req, null, () => {
accept(null, true);
});
} else {
req.cookies = {};
req.signedCookies = {};
accept(null, true);
}
}
/**
* Called before an incoming socket.io connection is accepted.
*/
function handleAuth(socket, accept) {
socket.user = null;
socket.aliases = [];
const promises = [];
const auth = socket.request.signedCookies.auth;
if (auth) {
promises.push(verifySession(auth).then(user => {
socket.user = Object.assign({}, user);
}).catch(error => {
// Do nothing
}));
this.ipThrottle = new Map();
this.ipCount = new Map();
}
promises.push(getAliases(socket._realip).then(aliases => {
socket.aliases = aliases;
}).catch(error => {
// Do nothing
}));
Promise.all(promises).then(() => {
accept(null, true);
});
}
function handleIPSessionCookie(socket, accept) {
var cookie = socket.request.signedCookies['ip-session'];
if (!cookie) {
socket.ipSessionFirstSeen = new Date();
return accept(null, true);
}
var sessionMatch = verifyIPSessionCookie(socket._realip, cookie);
if (sessionMatch) {
socket.ipSessionFirstSeen = sessionMatch.date;
} else {
socket.ipSessionFirstSeen = new Date();
}
accept(null, true);
}
function throttleIP(sock) {
var ip = sock._realip;
if (!(ip in ipThrottle)) {
ipThrottle[ip] = $util.newRateLimiter();
}
if (ipThrottle[ip].throttle(CONNECT_RATE)) {
LOGGER.warn("IP throttled: " + ip);
sock.emit("kick", {
reason: "Your IP address is connecting too quickly. Please "+
"wait 10 seconds before joining again."
});
sock.disconnect();
return true;
}
return false;
}
function ipLimitReached(sock) {
var ip = sock._realip;
sock.on("disconnect", function () {
counters.add("socket.io:disconnect", 1);
ipCount[ip]--;
if (ipCount[ip] === 0) {
/* Clear out unnecessary counters to save memory */
delete ipCount[ip];
// Map proxied sockets to the real IP address via X-Forwarded-For
// If the resulting address is a known Tor exit, flag it as such
ipProxyMiddleware(socket, next) {
if (!socket.context) socket.context = {};
socket.context.ipAddress = proxyaddr(socket.client.request, this.proxyTrustFn);
if (isTorExit(socket.context.ipAddress)) {
socket.context.torConnection = true;
}
});
if (!(ip in ipCount)) {
ipCount[ip] = 0;
next();
}
ipCount[ip]++;
if (ipCount[ip] > Config.get("io.ip-connection-limit")) {
sock.emit("kick", {
reason: "Too many connections from your IP address"
// Reject global banned IP addresses
ipBanMiddleware(socket, next) {
if (isIPGlobalBanned(socket.context.ipAddress)) {
LOGGER.info('Rejecting %s - banned',
socket.context.ipAddress);
next(new Error('You are banned from the server'));
return;
}
next();
}
// Rate limit connection attempts by IP address
ipThrottleMiddleware(socket, next) {
if (!this.ipThrottle.has(socket.context.ipAddress)) {
this.ipThrottle.set(socket.context.ipAddress, new TokenBucket(5, 0.1));
}
const bucket = this.ipThrottle.get(socket.context.ipAddress);
if (bucket.throttle()) {
LOGGER.info('Rejecting %s - exceeded connection rate limit',
socket.context.ipAddress);
next(new Error('Rate limit exceeded'));
return;
}
next();
}
ipConnectionLimitMiddleware(socket, next) {
const ip = socket.context.ipAddress;
const count = this.ipCount.get(ip) || 0;
if (count >= Config.get('io.ip-connection-limit')) {
// TODO: better error message would be nice
next(new Error('Too many connections from your IP address'));
return;
}
this.ipCount.set(ip, count + 1);
socket.once('disconnect', () => {
this.ipCount.set(ip, this.ipCount.get(ip) - 1);
});
next();
}
// Parse cookies
cookieParsingMiddleware(socket, next) {
const req = socket.request;
if (req.headers.cookie) {
cookieParser(req, null, () => next());
} else {
req.cookies = {};
req.signedCookies = {};
next();
}
}
// Determine session age from ip-session cookie
// (Used for restricting chat)
ipSessionCookieMiddleware(socket, next) {
const cookie = socket.request.signedCookies['ip-session'];
if (!cookie) {
socket.context.ipSessionFirstSeen = new Date();
next();
return;
}
const sessionMatch = verifyIPSessionCookie(socket.context.ipAddress, cookie);
if (sessionMatch) {
socket.context.ipSessionFirstSeen = sessionMatch.date;
} else {
socket.context.ipSessionFirstSeen = new Date();
}
next();
}
// Match login cookie against the DB, look up aliases
authUserMiddleware(socket, next) {
socket.context.aliases = [];
const promises = [];
const auth = socket.request.signedCookies.auth;
if (auth) {
promises.push(verifySession(auth).then(user => {
socket.context.user = Object.assign({}, user);
}).catch(error => {
LOGGER.warn('Unable to verify session for %s - ignoring auth',
socket.context.ipAddress);
}));
}
promises.push(getAliases(socket.context.ipAddress).then(aliases => {
socket.context.aliases = aliases;
}).catch(error => {
LOGGER.warn('Unable to load aliases for %s',
socket.context.ipAddress);
}));
Promise.all(promises).then(() => next());
}
metricsEmittingMiddleware(socket, next) {
emitMetrics(socket);
next();
}
handleConnection(socket) {
LOGGER.info('Accepted socket from %s', socket.context.ipAddress);
counters.add('socket.io:accept', 1);
socket.once('disconnect', () => counters.add('socket.io:disconnect', 1));
const user = new User(socket, socket.context.ipAddress, socket.context.user);
if (socket.context.user) {
db.recordVisit(socket.context.ipAddress, user.getName());
}
const announcement = Server.getServer().announcement;
if (announcement !== null) {
socket.emit('announcement', announcement);
}
}
initSocketIO() {
patchTypecheckedFunctions();
const io = this.io = sio.instance = sio();
io.use(this.ipProxyMiddleware.bind(this));
io.use(this.ipBanMiddleware.bind(this));
io.use(this.ipThrottleMiddleware.bind(this));
io.use(this.ipConnectionLimitMiddleware.bind(this));
io.use(this.cookieParsingMiddleware.bind(this));
io.use(this.ipSessionCookieMiddleware.bind(this));
io.use(this.authUserMiddleware.bind(this));
io.use(this.metricsEmittingMiddleware.bind(this));
io.on('connection', this.handleConnection.bind(this));
}
bindTo(servers) {
if (!this.io) {
throw new Error('Cannot bind: socket.io has not been initialized yet');
}
servers.forEach(server => {
this.io.attach(server);
});
sock.disconnect();
return;
}
}
@ -167,18 +225,6 @@ function patchTypecheckedFunctions() {
};
}
function ipForwardingMiddleware(webConfig) {
const trustFn = proxyaddr.compile(webConfig.getTrustedProxies());
return function (socket, accept) {
LOGGER.debug('ip = %s', socket.client.request.connection.remoteAddress);
//socket.client.request.ip = socket.client.conn.remoteAddress;
socket._realip = proxyaddr(socket.client.request, trustFn);
LOGGER.debug('socket._realip: %s', socket._realip);
accept(null, true);
}
}
let globalIPBanlist = null;
function isIPGlobalBanned(ip) {
if (globalIPBanlist === null) {
@ -219,7 +265,7 @@ function emitMetrics(sock) {
}
} catch (error) {
LOGGER.error('Error emitting transport upgrade metrics for socket (ip=%s): %s',
sock._realip, error.stack);
sock.context.ipAddress, error.stack);
}
});
@ -229,130 +275,74 @@ function emitMetrics(sock) {
promSocketDisconnect.inc(1, new Date());
} catch (error) {
LOGGER.error('Error emitting disconnect metrics for socket (ip=%s): %s',
sock._realip, error.stack);
sock.context.ipAddress, error.stack);
}
});
} catch (error) {
LOGGER.error('Error emitting metrics for socket (ip=%s): %s',
sock._realip, error.stack);
sock.context.ipAddress, error.stack);
}
}
/**
* Called after a connection is accepted
*/
function handleConnection(sock) {
var ip = sock._realip;
if (!ip) {
sock.emit("kick", {
reason: "Your IP address could not be determined from the socket connection. See https://github.com/Automattic/socket.io/issues/1737 for details"
});
return;
}
if (net.isIPv6(ip)) {
ip = util.expandIPv6(ip);
sock._realip = ip;
}
if (isTorExit(ip)) {
sock._isUsingTor = true;
}
var srv = Server.getServer();
if (throttleIP(sock)) {
return;
}
// Check for global ban on the IP
if (isIPGlobalBanned(ip)) {
LOGGER.info("Rejecting " + ip + " - global banned");
sock.emit("kick", { reason: "Your IP is globally banned." });
sock.disconnect();
return;
}
if (ipLimitReached(sock)) {
return;
}
emitMetrics(sock);
LOGGER.info("Accepted socket from " + ip);
counters.add("socket.io:accept", 1);
const user = new User(sock, ip, sock.user);
if (sock.user) {
db.recordVisit(ip, user.getName());
}
const announcement = srv.announcement;
if (announcement != null) {
sock.emit("announcement", announcement);
}
}
let instance = null;
module.exports = {
init: function (srv, webConfig) {
patchTypecheckedFunctions();
var bound = {};
const ioOptions = {
perMessageDeflate: Config.get("io.per-message-deflate")
};
var io = sio.instance = sio();
if (instance !== null) {
throw new Error('ioserver.init: already initialized');
}
io.use(ipForwardingMiddleware(webConfig));
io.use(parseCookies);
io.use(handleIPSessionCookie);
io.use(handleAuth);
io.on("connection", handleConnection);
const ioServer = instance = new IOServer({
proxyTrustFn: proxyaddr.compile(webConfig.getTrustedProxies())
});
ioServer.initSocketIO();
const uniqueListenAddresses = new Set();
const servers = [];
Config.get("listen").forEach(function (bind) {
if (!bind.io) {
return;
}
var id = bind.ip + ":" + bind.port;
if (id in bound) {
LOGGER.warn("Ignoring duplicate listen address " + id);
const id = bind.ip + ":" + bind.port;
if (uniqueListenAddresses.has(id)) {
LOGGER.warn("Ignoring duplicate listen address %s", id);
return;
}
if (id in srv.servers) {
io.attach(srv.servers[id], ioOptions);
if (srv.servers.hasOwnProperty(id)) {
servers.push(srv.servers[id]);
} else {
var server = require("http").createServer().listen(bind.port, bind.ip);
server.on("clientError", function (err, socket) {
try {
socket.destroy();
} catch (e) {
}
});
io.attach(server, ioOptions);
const server = http.createServer().listen(bind.port, bind.ip);
servers.push(server);
}
bound[id] = null;
uniqueListenAddresses.add(id);
});
ioServer.bindTo(servers);
},
handleConnection: handleConnection
IOServer: IOServer
};
/* Clean out old rate limiters */
setInterval(function () {
for (var ip in ipThrottle) {
if (ipThrottle[ip].lastTime < Date.now() - 60 * 1000) {
var obj = ipThrottle[ip];
/* Not strictly necessary, but seems to help the GC out a bit */
for (var key in obj) {
delete obj[key];
}
delete ipThrottle[ip];
if (instance == null) return;
let cleaned = 0;
const keys = instance.ipThrottle.keys();
for (const key of keys) {
if (instance.ipThrottle.get(key).lastRefill < Date.now() - 60000) {
const bucket = instance.ipThrottle.delete(key);
for (const k in bucket) delete bucket[k];
cleaned++;
}
}
if (Config.get("aggressive-gc") && global && global.gc) {
global.gc();
if (cleaned > 0) {
LOGGER.info('Cleaned up %d stale IP throttle token buckets', cleaned);
}
}, 5 * 60 * 1000);