1 /*
2 * Copyright (c) 2009-2012, Salvatore Sanfilippo <antirez at gmail dot com>
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 * * Redistributions of source code must retain the above copyright notice,
9 * this list of conditions and the following disclaimer.
10 * * Redistributions in binary form must reproduce the above copyright
11 * notice, this list of conditions and the following disclaimer in the
12 * documentation and/or other materials provided with the distribution.
13 * * Neither the name of Redis nor the names of its contributors may be used
14 * to endorse or promote products derived from this software without
15 * specific prior written permission.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 * POSSIBILITY OF SUCH DAMAGE.
28 */
29
30 #include "server.h"
31
32 /*-----------------------------------------------------------------------------
33 * Pubsub low level API
34 *----------------------------------------------------------------------------*/
35
freePubsubPattern(void * p)36 void freePubsubPattern(void *p) {
37 pubsubPattern *pat = p;
38
39 decrRefCount(pat->pattern);
40 zfree(pat);
41 }
42
listMatchPubsubPattern(void * a,void * b)43 int listMatchPubsubPattern(void *a, void *b) {
44 pubsubPattern *pa = a, *pb = b;
45
46 return (pa->client == pb->client) &&
47 (equalStringObjects(pa->pattern,pb->pattern));
48 }
49
50 /* Return the number of channels + patterns a client is subscribed to. */
clientSubscriptionsCount(client * c)51 int clientSubscriptionsCount(client *c) {
52 return dictSize(c->pubsub_channels)+
53 listLength(c->pubsub_patterns);
54 }
55
56 /* Subscribe a client to a channel. Returns 1 if the operation succeeded, or
57 * 0 if the client was already subscribed to that channel. */
pubsubSubscribeChannel(client * c,robj * channel)58 int pubsubSubscribeChannel(client *c, robj *channel) {
59 dictEntry *de;
60 list *clients = NULL;
61 int retval = 0;
62
63 /* Add the channel to the client -> channels hash table */
64 if (dictAdd(c->pubsub_channels,channel,NULL) == DICT_OK) {
65 retval = 1;
66 incrRefCount(channel);
67 /* Add the client to the channel -> list of clients hash table */
68 de = dictFind(server.pubsub_channels,channel);
69 if (de == NULL) {
70 clients = listCreate();
71 dictAdd(server.pubsub_channels,channel,clients);
72 incrRefCount(channel);
73 } else {
74 clients = dictGetVal(de);
75 }
76 listAddNodeTail(clients,c);
77 }
78 /* Notify the client */
79 addReply(c,shared.mbulkhdr[3]);
80 addReply(c,shared.subscribebulk);
81 addReplyBulk(c,channel);
82 addReplyLongLong(c,clientSubscriptionsCount(c));
83 return retval;
84 }
85
86 /* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or
87 * 0 if the client was not subscribed to the specified channel. */
pubsubUnsubscribeChannel(client * c,robj * channel,int notify)88 int pubsubUnsubscribeChannel(client *c, robj *channel, int notify) {
89 dictEntry *de;
90 list *clients;
91 listNode *ln;
92 int retval = 0;
93
94 /* Remove the channel from the client -> channels hash table */
95 incrRefCount(channel); /* channel may be just a pointer to the same object
96 we have in the hash tables. Protect it... */
97 if (dictDelete(c->pubsub_channels,channel) == DICT_OK) {
98 retval = 1;
99 /* Remove the client from the channel -> clients list hash table */
100 de = dictFind(server.pubsub_channels,channel);
101 serverAssertWithInfo(c,NULL,de != NULL);
102 clients = dictGetVal(de);
103 ln = listSearchKey(clients,c);
104 serverAssertWithInfo(c,NULL,ln != NULL);
105 listDelNode(clients,ln);
106 if (listLength(clients) == 0) {
107 /* Free the list and associated hash entry at all if this was
108 * the latest client, so that it will be possible to abuse
109 * Redis PUBSUB creating millions of channels. */
110 dictDelete(server.pubsub_channels,channel);
111 }
112 }
113 /* Notify the client */
114 if (notify) {
115 addReply(c,shared.mbulkhdr[3]);
116 addReply(c,shared.unsubscribebulk);
117 addReplyBulk(c,channel);
118 addReplyLongLong(c,dictSize(c->pubsub_channels)+
119 listLength(c->pubsub_patterns));
120
121 }
122 decrRefCount(channel); /* it is finally safe to release it */
123 return retval;
124 }
125
126 /* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the client was already subscribed to that pattern. */
pubsubSubscribePattern(client * c,robj * pattern)127 int pubsubSubscribePattern(client *c, robj *pattern) {
128 int retval = 0;
129
130 if (listSearchKey(c->pubsub_patterns,pattern) == NULL) {
131 retval = 1;
132 pubsubPattern *pat;
133 listAddNodeTail(c->pubsub_patterns,pattern);
134 incrRefCount(pattern);
135 pat = zmalloc(sizeof(*pat));
136 pat->pattern = getDecodedObject(pattern);
137 pat->client = c;
138 listAddNodeTail(server.pubsub_patterns,pat);
139 }
140 /* Notify the client */
141 addReply(c,shared.mbulkhdr[3]);
142 addReply(c,shared.psubscribebulk);
143 addReplyBulk(c,pattern);
144 addReplyLongLong(c,clientSubscriptionsCount(c));
145 return retval;
146 }
147
148 /* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or
149 * 0 if the client was not subscribed to the specified channel. */
pubsubUnsubscribePattern(client * c,robj * pattern,int notify)150 int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) {
151 listNode *ln;
152 pubsubPattern pat;
153 int retval = 0;
154
155 incrRefCount(pattern); /* Protect the object. May be the same we remove */
156 if ((ln = listSearchKey(c->pubsub_patterns,pattern)) != NULL) {
157 retval = 1;
158 listDelNode(c->pubsub_patterns,ln);
159 pat.client = c;
160 pat.pattern = pattern;
161 ln = listSearchKey(server.pubsub_patterns,&pat);
162 listDelNode(server.pubsub_patterns,ln);
163 }
164 /* Notify the client */
165 if (notify) {
166 addReply(c,shared.mbulkhdr[3]);
167 addReply(c,shared.punsubscribebulk);
168 addReplyBulk(c,pattern);
169 addReplyLongLong(c,dictSize(c->pubsub_channels)+
170 listLength(c->pubsub_patterns));
171 }
172 decrRefCount(pattern);
173 return retval;
174 }
175
176 /* Unsubscribe from all the channels. Return the number of channels the
177 * client was subscribed to. */
pubsubUnsubscribeAllChannels(client * c,int notify)178 int pubsubUnsubscribeAllChannels(client *c, int notify) {
179 dictIterator *di = dictGetSafeIterator(c->pubsub_channels);
180 dictEntry *de;
181 int count = 0;
182
183 while((de = dictNext(di)) != NULL) {
184 robj *channel = dictGetKey(de);
185
186 count += pubsubUnsubscribeChannel(c,channel,notify);
187 }
188 /* We were subscribed to nothing? Still reply to the client. */
189 if (notify && count == 0) {
190 addReply(c,shared.mbulkhdr[3]);
191 addReply(c,shared.unsubscribebulk);
192 addReply(c,shared.nullbulk);
193 addReplyLongLong(c,dictSize(c->pubsub_channels)+
194 listLength(c->pubsub_patterns));
195 }
196 dictReleaseIterator(di);
197 return count;
198 }
199
200 /* Unsubscribe from all the patterns. Return the number of patterns the
201 * client was subscribed from. */
pubsubUnsubscribeAllPatterns(client * c,int notify)202 int pubsubUnsubscribeAllPatterns(client *c, int notify) {
203 listNode *ln;
204 listIter li;
205 int count = 0;
206
207 listRewind(c->pubsub_patterns,&li);
208 while ((ln = listNext(&li)) != NULL) {
209 robj *pattern = ln->value;
210
211 count += pubsubUnsubscribePattern(c,pattern,notify);
212 }
213 if (notify && count == 0) {
214 /* We were subscribed to nothing? Still reply to the client. */
215 addReply(c,shared.mbulkhdr[3]);
216 addReply(c,shared.punsubscribebulk);
217 addReply(c,shared.nullbulk);
218 addReplyLongLong(c,dictSize(c->pubsub_channels)+
219 listLength(c->pubsub_patterns));
220 }
221 return count;
222 }
223
224 /* Publish a message */
pubsubPublishMessage(robj * channel,robj * message)225 int pubsubPublishMessage(robj *channel, robj *message) {
226 int receivers = 0;
227 dictEntry *de;
228 listNode *ln;
229 listIter li;
230
231 /* Send to clients listening for that channel */
232 de = dictFind(server.pubsub_channels,channel);
233 if (de) {
234 list *list = dictGetVal(de);
235 listNode *ln;
236 listIter li;
237
238 listRewind(list,&li);
239 while ((ln = listNext(&li)) != NULL) {
240 client *c = ln->value;
241
242 addReply(c,shared.mbulkhdr[3]);
243 addReply(c,shared.messagebulk);
244 addReplyBulk(c,channel);
245 addReplyBulk(c,message);
246 receivers++;
247 }
248 }
249 /* Send to clients listening to matching channels */
250 if (listLength(server.pubsub_patterns)) {
251 listRewind(server.pubsub_patterns,&li);
252 channel = getDecodedObject(channel);
253 while ((ln = listNext(&li)) != NULL) {
254 pubsubPattern *pat = ln->value;
255
256 if (stringmatchlen((char*)pat->pattern->ptr,
257 sdslen(pat->pattern->ptr),
258 (char*)channel->ptr,
259 sdslen(channel->ptr),0)) {
260 addReply(pat->client,shared.mbulkhdr[4]);
261 addReply(pat->client,shared.pmessagebulk);
262 addReplyBulk(pat->client,pat->pattern);
263 addReplyBulk(pat->client,channel);
264 addReplyBulk(pat->client,message);
265 receivers++;
266 }
267 }
268 decrRefCount(channel);
269 }
270 return receivers;
271 }
272
273 /*-----------------------------------------------------------------------------
274 * Pubsub commands implementation
275 *----------------------------------------------------------------------------*/
276
/* SUBSCRIBE channel [channel ...]
 *
 * Subscribes the calling client to every channel argument. Each call to
 * pubsubSubscribeChannel() sends its own confirmation. Afterwards
 * CLIENT_PUBSUB is set in c->flags. */
void subscribeCommand(client *c) {
    for (int i = 1; i < c->argc; i++) {
        pubsubSubscribeChannel(c,c->argv[i]);
    }
    c->flags |= CLIENT_PUBSUB;
}
284
/* UNSUBSCRIBE [channel ...]
 *
 * With no channel arguments, drops every channel subscription of the client.
 * Otherwise drops only the listed channels, confirming each one.
 * CLIENT_PUBSUB is cleared once the client has neither channel nor pattern
 * subscriptions left. */
void unsubscribeCommand(client *c) {
    if (c->argc != 1) {
        for (int i = 1; i < c->argc; i++) {
            pubsubUnsubscribeChannel(c,c->argv[i],1);
        }
    } else {
        pubsubUnsubscribeAllChannels(c,1);
    }

    if (!clientSubscriptionsCount(c)) c->flags &= ~CLIENT_PUBSUB;
}
296
/* PSUBSCRIBE pattern [pattern ...]
 *
 * Subscribes the calling client to every pattern argument. Each call to
 * pubsubSubscribePattern() sends its own confirmation. Afterwards
 * CLIENT_PUBSUB is set in c->flags. */
void psubscribeCommand(client *c) {
    for (int i = 1; i < c->argc; i++) {
        pubsubSubscribePattern(c,c->argv[i]);
    }
    c->flags |= CLIENT_PUBSUB;
}
304
/* PUNSUBSCRIBE [pattern ...]
 *
 * With no pattern arguments, drops every pattern subscription of the client.
 * Otherwise drops only the listed patterns, confirming each one.
 * CLIENT_PUBSUB is cleared once the client has neither channel nor pattern
 * subscriptions left. */
void punsubscribeCommand(client *c) {
    if (c->argc != 1) {
        for (int i = 1; i < c->argc; i++) {
            pubsubUnsubscribePattern(c,c->argv[i],1);
        }
    } else {
        pubsubUnsubscribeAllPatterns(c,1);
    }

    if (!clientSubscriptionsCount(c)) c->flags &= ~CLIENT_PUBSUB;
}
316
/* PUBLISH channel message
 *
 * Delivers the message to this server's subscribers, then propagates it.
 * In cluster mode it goes through clusterPropagatePublish(); otherwise it
 * goes through forceCommandPropagation() with PROPAGATE_REPL.
 * Replies with the number of local deliveries. */
void publishCommand(client *c) {
    robj *channel = c->argv[1];
    robj *message = c->argv[2];
    int receivers = pubsubPublishMessage(channel,message);

    if (!server.cluster_enabled) {
        forceCommandPropagation(c,PROPAGATE_REPL);
    } else {
        clusterPropagatePublish(channel,message);
    }
    addReplyLongLong(c,receivers);
}
325
326 /* PUBSUB command for Pub/Sub introspection. */
pubsubCommand(client * c)327 void pubsubCommand(client *c) {
328 if (c->argc == 2 && !strcasecmp(c->argv[1]->ptr,"help")) {
329 const char *help[] = {
330 "CHANNELS [<pattern>] -- Return the currently active channels matching a pattern (default: all).",
331 "NUMPAT -- Return number of subscriptions to patterns.",
332 "NUMSUB [channel-1 .. channel-N] -- Returns the number of subscribers for the specified channels (excluding patterns, default: none).",
333 NULL
334 };
335 addReplyHelp(c, help);
336 } else if (!strcasecmp(c->argv[1]->ptr,"channels") &&
337 (c->argc == 2 || c->argc == 3))
338 {
339 /* PUBSUB CHANNELS [<pattern>] */
340 sds pat = (c->argc == 2) ? NULL : c->argv[2]->ptr;
341 dictIterator *di = dictGetIterator(server.pubsub_channels);
342 dictEntry *de;
343 long mblen = 0;
344 void *replylen;
345
346 replylen = addDeferredMultiBulkLength(c);
347 while((de = dictNext(di)) != NULL) {
348 robj *cobj = dictGetKey(de);
349 sds channel = cobj->ptr;
350
351 if (!pat || stringmatchlen(pat, sdslen(pat),
352 channel, sdslen(channel),0))
353 {
354 addReplyBulk(c,cobj);
355 mblen++;
356 }
357 }
358 dictReleaseIterator(di);
359 setDeferredMultiBulkLength(c,replylen,mblen);
360 } else if (!strcasecmp(c->argv[1]->ptr,"numsub") && c->argc >= 2) {
361 /* PUBSUB NUMSUB [Channel_1 ... Channel_N] */
362 int j;
363
364 addReplyMultiBulkLen(c,(c->argc-2)*2);
365 for (j = 2; j < c->argc; j++) {
366 list *l = dictFetchValue(server.pubsub_channels,c->argv[j]);
367
368 addReplyBulk(c,c->argv[j]);
369 addReplyLongLong(c,l ? listLength(l) : 0);
370 }
371 } else if (!strcasecmp(c->argv[1]->ptr,"numpat") && c->argc == 2) {
372 /* PUBSUB NUMPAT */
373 addReplyLongLong(c,listLength(server.pubsub_patterns));
374 } else {
375 addReplySubcommandSyntaxError(c);
376 }
377 }
378