1 /* 2 * Copyright (c) 2009-2012, Salvatore Sanfilippo <antirez at gmail dot com> 3 * All rights reserved. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are met: 7 * 8 * * Redistributions of source code must retain the above copyright notice, 9 * this list of conditions and the following disclaimer. 10 * * Redistributions in binary form must reproduce the above copyright 11 * notice, this list of conditions and the following disclaimer in the 12 * documentation and/or other materials provided with the distribution. 13 * * Neither the name of Redis nor the names of its contributors may be used 14 * to endorse or promote products derived from this software without 15 * specific prior written permission. 16 * 17 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 21 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 * POSSIBILITY OF SUCH DAMAGE. 
28 */ 29 30 #include "server.h" 31 32 /*----------------------------------------------------------------------------- 33 * Pubsub low level API 34 *----------------------------------------------------------------------------*/ 35 36 void freePubsubPattern(void *p) { 37 pubsubPattern *pat = p; 38 39 decrRefCount(pat->pattern); 40 zfree(pat); 41 } 42 43 int listMatchPubsubPattern(void *a, void *b) { 44 pubsubPattern *pa = a, *pb = b; 45 46 return (pa->client == pb->client) && 47 (equalStringObjects(pa->pattern,pb->pattern)); 48 } 49 50 /* Return the number of channels + patterns a client is subscribed to. */ 51 int clientSubscriptionsCount(client *c) { 52 return dictSize(c->pubsub_channels)+ 53 listLength(c->pubsub_patterns); 54 } 55 56 /* Subscribe a client to a channel. Returns 1 if the operation succeeded, or 57 * 0 if the client was already subscribed to that channel. */ 58 int pubsubSubscribeChannel(client *c, robj *channel) { 59 dictEntry *de; 60 list *clients = NULL; 61 int retval = 0; 62 63 /* Add the channel to the client -> channels hash table */ 64 if (dictAdd(c->pubsub_channels,channel,NULL) == DICT_OK) { 65 retval = 1; 66 incrRefCount(channel); 67 /* Add the client to the channel -> list of clients hash table */ 68 de = dictFind(server.pubsub_channels,channel); 69 if (de == NULL) { 70 clients = listCreate(); 71 dictAdd(server.pubsub_channels,channel,clients); 72 incrRefCount(channel); 73 } else { 74 clients = dictGetVal(de); 75 } 76 listAddNodeTail(clients,c); 77 } 78 /* Notify the client */ 79 addReply(c,shared.mbulkhdr[3]); 80 addReply(c,shared.subscribebulk); 81 addReplyBulk(c,channel); 82 addReplyLongLong(c,clientSubscriptionsCount(c)); 83 return retval; 84 } 85 86 /* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or 87 * 0 if the client was not subscribed to the specified channel. 
*/
/* Details: the channel is removed from the client's own dict
 * (c->pubsub_channels) and the client is removed from the server-side
 * channel -> clients list (server.pubsub_channels). When that list becomes
 * empty the server-side entry is deleted too. If 'notify' is non-zero an
 * UNSUBSCRIBE reply is sent whether or not the client was subscribed; it
 * reports the client's remaining number of subscriptions (channels +
 * patterns). */
int pubsubUnsubscribeChannel(client *c, robj *channel, int notify) {
    dictEntry *de;
    list *clients;
    listNode *ln;
    int retval = 0;

    /* 'channel' may be the very object stored as a key in the hash tables
     * (pubsubUnsubscribeAllChannels() passes dictGetKey() of the client's
     * dict). Hold an extra reference so it stays valid if the dictDelete()
     * calls below release the stored key. */
    incrRefCount(channel);
    /* Remove the channel from the client -> channels hash table */
    if (dictDelete(c->pubsub_channels,channel) == DICT_OK) {
        retval = 1;
        /* Remove the client from the channel -> clients list hash table.
         * Both lookups must succeed: pubsubSubscribeChannel() always
         * updates the two tables together. */
        de = dictFind(server.pubsub_channels,channel);
        serverAssertWithInfo(c,NULL,de != NULL);
        clients = dictGetVal(de);
        ln = listSearchKey(clients,c);
        serverAssertWithInfo(c,NULL,ln != NULL);
        listDelNode(clients,ln);
        if (listLength(clients) == 0) {
            /* Last subscriber gone: free the list and its hash entry so
             * that channels nobody listens to anymore do not keep using
             * memory, even when clients create millions of channels. */
            dictDelete(server.pubsub_channels,channel);
        }
    }
    /* Notify the client */
    if (notify) {
        addReply(c,shared.mbulkhdr[3]);
        addReply(c,shared.unsubscribebulk);
        addReplyBulk(c,channel);
        addReplyLongLong(c,clientSubscriptionsCount(c));
    }
    decrRefCount(channel); /* it is finally safe to release it */
    return retval;
}

/* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or
 * 0 if the client was already subscribed to that pattern.
*/ 127 int pubsubSubscribePattern(client *c, robj *pattern) { 128 int retval = 0; 129 130 if (listSearchKey(c->pubsub_patterns,pattern) == NULL) { 131 retval = 1; 132 pubsubPattern *pat; 133 listAddNodeTail(c->pubsub_patterns,pattern); 134 incrRefCount(pattern); 135 pat = zmalloc(sizeof(*pat)); 136 pat->pattern = getDecodedObject(pattern); 137 pat->client = c; 138 listAddNodeTail(server.pubsub_patterns,pat); 139 } 140 /* Notify the client */ 141 addReply(c,shared.mbulkhdr[3]); 142 addReply(c,shared.psubscribebulk); 143 addReplyBulk(c,pattern); 144 addReplyLongLong(c,clientSubscriptionsCount(c)); 145 return retval; 146 } 147 148 /* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or 149 * 0 if the client was not subscribed to the specified channel. */ 150 int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) { 151 listNode *ln; 152 pubsubPattern pat; 153 int retval = 0; 154 155 incrRefCount(pattern); /* Protect the object. May be the same we remove */ 156 if ((ln = listSearchKey(c->pubsub_patterns,pattern)) != NULL) { 157 retval = 1; 158 listDelNode(c->pubsub_patterns,ln); 159 pat.client = c; 160 pat.pattern = pattern; 161 ln = listSearchKey(server.pubsub_patterns,&pat); 162 listDelNode(server.pubsub_patterns,ln); 163 } 164 /* Notify the client */ 165 if (notify) { 166 addReply(c,shared.mbulkhdr[3]); 167 addReply(c,shared.punsubscribebulk); 168 addReplyBulk(c,pattern); 169 addReplyLongLong(c,dictSize(c->pubsub_channels)+ 170 listLength(c->pubsub_patterns)); 171 } 172 decrRefCount(pattern); 173 return retval; 174 } 175 176 /* Unsubscribe from all the channels. Return the number of channels the 177 * client was subscribed to. 
*/ 178 int pubsubUnsubscribeAllChannels(client *c, int notify) { 179 dictIterator *di = dictGetSafeIterator(c->pubsub_channels); 180 dictEntry *de; 181 int count = 0; 182 183 while((de = dictNext(di)) != NULL) { 184 robj *channel = dictGetKey(de); 185 186 count += pubsubUnsubscribeChannel(c,channel,notify); 187 } 188 /* We were subscribed to nothing? Still reply to the client. */ 189 if (notify && count == 0) { 190 addReply(c,shared.mbulkhdr[3]); 191 addReply(c,shared.unsubscribebulk); 192 addReply(c,shared.nullbulk); 193 addReplyLongLong(c,dictSize(c->pubsub_channels)+ 194 listLength(c->pubsub_patterns)); 195 } 196 dictReleaseIterator(di); 197 return count; 198 } 199 200 /* Unsubscribe from all the patterns. Return the number of patterns the 201 * client was subscribed from. */ 202 int pubsubUnsubscribeAllPatterns(client *c, int notify) { 203 listNode *ln; 204 listIter li; 205 int count = 0; 206 207 listRewind(c->pubsub_patterns,&li); 208 while ((ln = listNext(&li)) != NULL) { 209 robj *pattern = ln->value; 210 211 count += pubsubUnsubscribePattern(c,pattern,notify); 212 } 213 if (notify && count == 0) { 214 /* We were subscribed to nothing? Still reply to the client. 
*/ 215 addReply(c,shared.mbulkhdr[3]); 216 addReply(c,shared.punsubscribebulk); 217 addReply(c,shared.nullbulk); 218 addReplyLongLong(c,dictSize(c->pubsub_channels)+ 219 listLength(c->pubsub_patterns)); 220 } 221 return count; 222 } 223 224 /* Publish a message */ 225 int pubsubPublishMessage(robj *channel, robj *message) { 226 int receivers = 0; 227 dictEntry *de; 228 listNode *ln; 229 listIter li; 230 231 /* Send to clients listening for that channel */ 232 de = dictFind(server.pubsub_channels,channel); 233 if (de) { 234 list *list = dictGetVal(de); 235 listNode *ln; 236 listIter li; 237 238 listRewind(list,&li); 239 while ((ln = listNext(&li)) != NULL) { 240 client *c = ln->value; 241 242 addReply(c,shared.mbulkhdr[3]); 243 addReply(c,shared.messagebulk); 244 addReplyBulk(c,channel); 245 addReplyBulk(c,message); 246 receivers++; 247 } 248 } 249 /* Send to clients listening to matching channels */ 250 if (listLength(server.pubsub_patterns)) { 251 listRewind(server.pubsub_patterns,&li); 252 channel = getDecodedObject(channel); 253 while ((ln = listNext(&li)) != NULL) { 254 pubsubPattern *pat = ln->value; 255 256 if (stringmatchlen((char*)pat->pattern->ptr, 257 sdslen(pat->pattern->ptr), 258 (char*)channel->ptr, 259 sdslen(channel->ptr),0)) { 260 addReply(pat->client,shared.mbulkhdr[4]); 261 addReply(pat->client,shared.pmessagebulk); 262 addReplyBulk(pat->client,pat->pattern); 263 addReplyBulk(pat->client,channel); 264 addReplyBulk(pat->client,message); 265 receivers++; 266 } 267 } 268 decrRefCount(channel); 269 } 270 return receivers; 271 } 272 273 /*----------------------------------------------------------------------------- 274 * Pubsub commands implementation 275 *----------------------------------------------------------------------------*/ 276 277 void subscribeCommand(client *c) { 278 int j; 279 280 for (j = 1; j < c->argc; j++) 281 pubsubSubscribeChannel(c,c->argv[j]); 282 c->flags |= CLIENT_PUBSUB; 283 } 284 285 void unsubscribeCommand(client *c) { 286 if 
(c->argc == 1) { 287 pubsubUnsubscribeAllChannels(c,1); 288 } else { 289 int j; 290 291 for (j = 1; j < c->argc; j++) 292 pubsubUnsubscribeChannel(c,c->argv[j],1); 293 } 294 if (clientSubscriptionsCount(c) == 0) c->flags &= ~CLIENT_PUBSUB; 295 } 296 297 void psubscribeCommand(client *c) { 298 int j; 299 300 for (j = 1; j < c->argc; j++) 301 pubsubSubscribePattern(c,c->argv[j]); 302 c->flags |= CLIENT_PUBSUB; 303 } 304 305 void punsubscribeCommand(client *c) { 306 if (c->argc == 1) { 307 pubsubUnsubscribeAllPatterns(c,1); 308 } else { 309 int j; 310 311 for (j = 1; j < c->argc; j++) 312 pubsubUnsubscribePattern(c,c->argv[j],1); 313 } 314 if (clientSubscriptionsCount(c) == 0) c->flags &= ~CLIENT_PUBSUB; 315 } 316 317 void publishCommand(client *c) { 318 int receivers = pubsubPublishMessage(c->argv[1],c->argv[2]); 319 if (server.cluster_enabled) 320 clusterPropagatePublish(c->argv[1],c->argv[2]); 321 else 322 forceCommandPropagation(c,PROPAGATE_REPL); 323 addReplyLongLong(c,receivers); 324 } 325 326 /* PUBSUB command for Pub/Sub introspection. */ 327 void pubsubCommand(client *c) { 328 if (c->argc == 2 && !strcasecmp(c->argv[1]->ptr,"help")) { 329 const char *help[] = { 330 "CHANNELS [<pattern>] -- Return the currently active channels matching a pattern (default: all).", 331 "NUMPAT -- Return number of subscriptions to patterns.", 332 "NUMSUB [channel-1 .. channel-N] -- Returns the number of subscribers for the specified channels (excluding patterns, default: none).", 333 NULL 334 }; 335 addReplyHelp(c, help); 336 } else if (!strcasecmp(c->argv[1]->ptr,"channels") && 337 (c->argc == 2 || c->argc == 3)) 338 { 339 /* PUBSUB CHANNELS [<pattern>] */ 340 sds pat = (c->argc == 2) ? 
NULL : c->argv[2]->ptr; 341 dictIterator *di = dictGetIterator(server.pubsub_channels); 342 dictEntry *de; 343 long mblen = 0; 344 void *replylen; 345 346 replylen = addDeferredMultiBulkLength(c); 347 while((de = dictNext(di)) != NULL) { 348 robj *cobj = dictGetKey(de); 349 sds channel = cobj->ptr; 350 351 if (!pat || stringmatchlen(pat, sdslen(pat), 352 channel, sdslen(channel),0)) 353 { 354 addReplyBulk(c,cobj); 355 mblen++; 356 } 357 } 358 dictReleaseIterator(di); 359 setDeferredMultiBulkLength(c,replylen,mblen); 360 } else if (!strcasecmp(c->argv[1]->ptr,"numsub") && c->argc >= 2) { 361 /* PUBSUB NUMSUB [Channel_1 ... Channel_N] */ 362 int j; 363 364 addReplyMultiBulkLen(c,(c->argc-2)*2); 365 for (j = 2; j < c->argc; j++) { 366 list *l = dictFetchValue(server.pubsub_channels,c->argv[j]); 367 368 addReplyBulk(c,c->argv[j]); 369 addReplyLongLong(c,l ? listLength(l) : 0); 370 } 371 } else if (!strcasecmp(c->argv[1]->ptr,"numpat") && c->argc == 2) { 372 /* PUBSUB NUMPAT */ 373 addReplyLongLong(c,listLength(server.pubsub_patterns)); 374 } else { 375 addReplySubcommandSyntaxError(c); 376 } 377 } 378