xref: /redis-3.2.3/src/t_set.c (revision 60323407)
1 /*
2  * Copyright (c) 2009-2012, Salvatore Sanfilippo <antirez at gmail dot com>
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *
8  *   * Redistributions of source code must retain the above copyright notice,
9  *     this list of conditions and the following disclaimer.
10  *   * Redistributions in binary form must reproduce the above copyright
11  *     notice, this list of conditions and the following disclaimer in the
12  *     documentation and/or other materials provided with the distribution.
13  *   * Neither the name of Redis nor the names of its contributors may be used
14  *     to endorse or promote products derived from this software without
15  *     specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27  * POSSIBILITY OF SUCH DAMAGE.
28  */
29 
30 #include "server.h"
31 
32 /*-----------------------------------------------------------------------------
33  * Set Commands
34  *----------------------------------------------------------------------------*/
35 
36 void sunionDiffGenericCommand(client *c, robj **setkeys, int setnum,
37                               robj *dstkey, int op);
38 
39 /* Factory method to return a set that *can* hold "value". When the object has
40  * an integer-encodable value, an intset will be returned. Otherwise a regular
41  * hash table. */
setTypeCreate(robj * value)42 robj *setTypeCreate(robj *value) {
43     if (isObjectRepresentableAsLongLong(value,NULL) == C_OK)
44         return createIntsetObject();
45     return createSetObject();
46 }
47 
48 /* Add the specified value into a set. The function takes care of incrementing
49  * the reference count of the object if needed in order to retain a copy.
50  *
51  * If the value was already member of the set, nothing is done and 0 is
52  * returned, otherwise the new element is added and 1 is returned. */
setTypeAdd(robj * subject,robj * value)53 int setTypeAdd(robj *subject, robj *value) {
54     long long llval;
55     if (subject->encoding == OBJ_ENCODING_HT) {
56         if (dictAdd(subject->ptr,value,NULL) == DICT_OK) {
57             incrRefCount(value);
58             return 1;
59         }
60     } else if (subject->encoding == OBJ_ENCODING_INTSET) {
61         if (isObjectRepresentableAsLongLong(value,&llval) == C_OK) {
62             uint8_t success = 0;
63             subject->ptr = intsetAdd(subject->ptr,llval,&success);
64             if (success) {
65                 /* Convert to regular set when the intset contains
66                  * too many entries. */
67                 if (intsetLen(subject->ptr) > server.set_max_intset_entries)
68                     setTypeConvert(subject,OBJ_ENCODING_HT);
69                 return 1;
70             }
71         } else {
72             /* Failed to get integer from object, convert to regular set. */
73             setTypeConvert(subject,OBJ_ENCODING_HT);
74 
75             /* The set *was* an intset and this value is not integer
76              * encodable, so dictAdd should always work. */
77             serverAssertWithInfo(NULL,value,
78                                 dictAdd(subject->ptr,value,NULL) == DICT_OK);
79             incrRefCount(value);
80             return 1;
81         }
82     } else {
83         serverPanic("Unknown set encoding");
84     }
85     return 0;
86 }
87 
/* Remove "value" from the set "setobj".
 *
 * Returns 1 when the element was a member and has been removed, 0 otherwise.
 * Hash table backed sets are resized after a removal when htNeedsResize()
 * says so. */
int setTypeRemove(robj *setobj, robj *value) {
    long long llval;

    switch (setobj->encoding) {
    case OBJ_ENCODING_HT:
        if (dictDelete(setobj->ptr,value) != DICT_OK) return 0;
        if (htNeedsResize(setobj->ptr)) dictResize(setobj->ptr);
        return 1;
    case OBJ_ENCODING_INTSET: {
        int removed;

        /* A value that is not integer encodable can't be in an intset. */
        if (isObjectRepresentableAsLongLong(value,&llval) != C_OK) return 0;
        setobj->ptr = intsetRemove(setobj->ptr,llval,&removed);
        return removed ? 1 : 0;
    }
    default:
        serverPanic("Unknown set encoding");
    }
    return 0;
}
106 
/* Return non zero when "value" is a member of the set "subject", 0 when it
 * is not. */
int setTypeIsMember(robj *subject, robj *value) {
    long long llval;

    switch (subject->encoding) {
    case OBJ_ENCODING_HT:
        return dictFind((dict*)subject->ptr,value) != NULL;
    case OBJ_ENCODING_INTSET:
        /* Only integer encodable values can live inside an intset. */
        if (isObjectRepresentableAsLongLong(value,&llval) != C_OK) return 0;
        return intsetFind((intset*)subject->ptr,llval);
    default:
        serverPanic("Unknown set encoding");
    }
    return 0;
}
120 
setTypeInitIterator(robj * subject)121 setTypeIterator *setTypeInitIterator(robj *subject) {
122     setTypeIterator *si = zmalloc(sizeof(setTypeIterator));
123     si->subject = subject;
124     si->encoding = subject->encoding;
125     if (si->encoding == OBJ_ENCODING_HT) {
126         si->di = dictGetIterator(subject->ptr);
127     } else if (si->encoding == OBJ_ENCODING_INTSET) {
128         si->ii = 0;
129     } else {
130         serverPanic("Unknown set encoding");
131     }
132     return si;
133 }
134 
/* Free an iterator returned by setTypeInitIterator(). For hash table
 * encoded sets the underlying dict iterator is released as well. */
void setTypeReleaseIterator(setTypeIterator *si) {
    int is_hashtable = si->encoding == OBJ_ENCODING_HT;

    if (is_hashtable) {
        dictReleaseIterator(si->di);
    }
    zfree(si);
}
140 
141 /* Move to the next entry in the set. Returns the object at the current
142  * position.
143  *
144  * Since set elements can be internally be stored as redis objects or
145  * simple arrays of integers, setTypeNext returns the encoding of the
146  * set object you are iterating, and will populate the appropriate pointer
147  * (objele) or (llele) accordingly.
148  *
149  * Note that both the objele and llele pointers should be passed and cannot
150  * be NULL since the function will try to defensively populate the non
151  * used field with values which are easy to trap if misused.
152  *
153  * When there are no longer elements -1 is returned.
154  * Returned objects ref count is not incremented, so this function is
155  * copy on write friendly. */
setTypeNext(setTypeIterator * si,robj ** objele,int64_t * llele)156 int setTypeNext(setTypeIterator *si, robj **objele, int64_t *llele) {
157     if (si->encoding == OBJ_ENCODING_HT) {
158         dictEntry *de = dictNext(si->di);
159         if (de == NULL) return -1;
160         *objele = dictGetKey(de);
161         *llele = -123456789; /* Not needed. Defensive. */
162     } else if (si->encoding == OBJ_ENCODING_INTSET) {
163         if (!intsetGet(si->subject->ptr,si->ii++,llele))
164             return -1;
165         *objele = NULL; /* Not needed. Defensive. */
166     } else {
167         serverPanic("Wrong set encoding in setTypeNext");
168     }
169     return si->encoding;
170 }
171 
172 /* The not copy on write friendly version but easy to use version
173  * of setTypeNext() is setTypeNextObject(), returning new objects
174  * or incrementing the ref count of returned objects. So if you don't
175  * retain a pointer to this object you should call decrRefCount() against it.
176  *
177  * This function is the way to go for write operations where COW is not
178  * an issue as the result will be anyway of incrementing the ref count. */
setTypeNextObject(setTypeIterator * si)179 robj *setTypeNextObject(setTypeIterator *si) {
180     int64_t intele;
181     robj *objele;
182     int encoding;
183 
184     encoding = setTypeNext(si,&objele,&intele);
185     switch(encoding) {
186         case -1:    return NULL;
187         case OBJ_ENCODING_INTSET:
188             return createStringObjectFromLongLong(intele);
189         case OBJ_ENCODING_HT:
190             incrRefCount(objele);
191             return objele;
192         default:
193             serverPanic("Unsupported encoding");
194     }
195     return NULL; /* just to suppress warnings */
196 }
197 
198 /* Return random element from a non empty set.
199  * The returned element can be a int64_t value if the set is encoded
200  * as an "intset" blob of integers, or a redis object if the set
201  * is a regular set.
202  *
203  * The caller provides both pointers to be populated with the right
204  * object. The return value of the function is the object->encoding
205  * field of the object and is used by the caller to check if the
206  * int64_t pointer or the redis object pointer was populated.
207  *
208  * Note that both the objele and llele pointers should be passed and cannot
209  * be NULL since the function will try to defensively populate the non
210  * used field with values which are easy to trap if misused.
211  *
212  * When an object is returned (the set was a real set) the ref count
213  * of the object is not incremented so this function can be considered
214  * copy on write friendly. */
setTypeRandomElement(robj * setobj,robj ** objele,int64_t * llele)215 int setTypeRandomElement(robj *setobj, robj **objele, int64_t *llele) {
216     if (setobj->encoding == OBJ_ENCODING_HT) {
217         dictEntry *de = dictGetRandomKey(setobj->ptr);
218         *objele = dictGetKey(de);
219         *llele = -123456789; /* Not needed. Defensive. */
220     } else if (setobj->encoding == OBJ_ENCODING_INTSET) {
221         *llele = intsetRandom(setobj->ptr);
222         *objele = NULL; /* Not needed. Defensive. */
223     } else {
224         serverPanic("Unknown set encoding");
225     }
226     return setobj->encoding;
227 }
228 
/* Return the number of members stored in the set "subject". */
unsigned long setTypeSize(robj *subject) {
    switch (subject->encoding) {
    case OBJ_ENCODING_HT:
        return dictSize((dict*)subject->ptr);
    case OBJ_ENCODING_INTSET:
        return intsetLen((intset*)subject->ptr);
    default:
        serverPanic("Unknown set encoding");
    }
}
238 
239 /* Convert the set to specified encoding. The resulting dict (when converting
240  * to a hash table) is presized to hold the number of elements in the original
241  * set. */
setTypeConvert(robj * setobj,int enc)242 void setTypeConvert(robj *setobj, int enc) {
243     setTypeIterator *si;
244     serverAssertWithInfo(NULL,setobj,setobj->type == OBJ_SET &&
245                              setobj->encoding == OBJ_ENCODING_INTSET);
246 
247     if (enc == OBJ_ENCODING_HT) {
248         int64_t intele;
249         dict *d = dictCreate(&setDictType,NULL);
250         robj *element;
251 
252         /* Presize the dict to avoid rehashing */
253         dictExpand(d,intsetLen(setobj->ptr));
254 
255         /* To add the elements we extract integers and create redis objects */
256         si = setTypeInitIterator(setobj);
257         while (setTypeNext(si,&element,&intele) != -1) {
258             element = createStringObjectFromLongLong(intele);
259             serverAssertWithInfo(NULL,element,
260                                 dictAdd(d,element,NULL) == DICT_OK);
261         }
262         setTypeReleaseIterator(si);
263 
264         setobj->encoding = OBJ_ENCODING_HT;
265         zfree(setobj->ptr);
266         setobj->ptr = d;
267     } else {
268         serverPanic("Unsupported set conversion");
269     }
270 }
271 
/* SADD key member [member ...]
 *
 * Adds the members to the set stored at key. When the key does not exist,
 * the set is created with an encoding chosen from the first member. Replies
 * with the number of members that were not already present. */
void saddCommand(client *c) {
    robj *set = lookupKeyWrite(c->db,c->argv[1]);
    int added = 0;

    if (set != NULL && set->type != OBJ_SET) {
        addReply(c,shared.wrongtypeerr);
        return;
    }
    if (set == NULL) {
        set = setTypeCreate(c->argv[2]);
        dbAdd(c->db,c->argv[1],set);
    }

    for (int j = 2; j < c->argc; j++) {
        c->argv[j] = tryObjectEncoding(c->argv[j]);
        added += setTypeAdd(set,c->argv[j]); /* 1 if new, 0 otherwise. */
    }

    if (added) {
        signalModifiedKey(c->db,c->argv[1]);
        notifyKeyspaceEvent(NOTIFY_SET,"sadd",c->argv[1],c->db->id);
    }
    server.dirty += added;
    addReplyLongLong(c,added);
}
298 
/* SREM key member [member ...]
 *
 * Removes the members from the set stored at key and replies with how many
 * were actually removed (0 for a missing key). When the set becomes empty
 * the key is deleted and a "del" keyspace event is emitted alongside "srem". */
void sremCommand(client *c) {
    robj *set;
    int deleted = 0, keyremoved = 0;

    set = lookupKeyWriteOrReply(c,c->argv[1],shared.czero);
    if (set == NULL || checkType(c,set,OBJ_SET)) return;

    for (int j = 2; j < c->argc && !keyremoved; j++) {
        if (!setTypeRemove(set,c->argv[j])) continue;
        deleted++;
        /* Empty sets are not kept: delete the key and stop, since the set
         * object is no longer part of the database. */
        if (setTypeSize(set) == 0) {
            dbDelete(c->db,c->argv[1]);
            keyremoved = 1;
        }
    }

    if (deleted) {
        signalModifiedKey(c->db,c->argv[1]);
        notifyKeyspaceEvent(NOTIFY_SET,"srem",c->argv[1],c->db->id);
        if (keyremoved) {
            notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],
                                c->db->id);
        }
        server.dirty += deleted;
    }
    addReplyLongLong(c,deleted);
}
326 
/* SMOVE source destination member
 *
 * Moves "member" from the set at source to the set at destination, creating
 * the destination when missing and deleting the source when it becomes
 * empty. Replies 1 when the member was removed from source, 0 otherwise.
 * When source and destination are the same set, only membership is
 * reported. */
void smoveCommand(client *c) {
    robj *src = lookupKeyWrite(c->db,c->argv[1]);
    robj *dst = lookupKeyWrite(c->db,c->argv[2]);
    robj *member;

    c->argv[3] = tryObjectEncoding(c->argv[3]);
    member = c->argv[3];

    /* Nothing to move out of a missing source key. */
    if (src == NULL) {
        addReply(c,shared.czero);
        return;
    }

    /* Both keys, when they exist, must hold sets. */
    if (checkType(c,src,OBJ_SET)) return;
    if (dst != NULL && checkType(c,dst,OBJ_SET)) return;

    /* Moving onto the very same set is a no-op: just report membership. */
    if (src == dst) {
        int ismember = setTypeIsMember(src,member);

        addReply(c,ismember ? shared.cone : shared.czero);
        return;
    }

    if (!setTypeRemove(src,member)) {
        addReply(c,shared.czero);
        return;
    }
    notifyKeyspaceEvent(NOTIFY_SET,"srem",c->argv[1],c->db->id);

    /* Never leave an empty set in the keyspace. */
    if (setTypeSize(src) == 0) {
        dbDelete(c->db,c->argv[1]);
        notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],c->db->id);
    }

    if (dst == NULL) {
        dst = setTypeCreate(member);
        dbAdd(c->db,c->argv[2],dst);
    }

    signalModifiedKey(c->db,c->argv[1]);
    signalModifiedKey(c->db,c->argv[2]);
    server.dirty++;

    /* The destination counts as one more change only when the member was
     * not already there. */
    if (setTypeAdd(dst,member)) {
        server.dirty++;
        notifyKeyspaceEvent(NOTIFY_SET,"sadd",c->argv[2],c->db->id);
    }
    addReply(c,shared.cone);
}
380 
/* SISMEMBER key member
 *
 * Replies 1 when member belongs to the set at key, 0 otherwise (including
 * when the key does not exist). */
void sismemberCommand(client *c) {
    robj *set = lookupKeyReadOrReply(c,c->argv[1],shared.czero);

    if (set == NULL || checkType(c,set,OBJ_SET)) return;

    c->argv[2] = tryObjectEncoding(c->argv[2]);
    addReply(c,setTypeIsMember(set,c->argv[2]) ? shared.cone : shared.czero);
}
393 
/* SCARD key
 *
 * Replies with the number of members of the set at key, or 0 when the key
 * does not exist. */
void scardCommand(client *c) {
    robj *set = lookupKeyReadOrReply(c,c->argv[1],shared.czero);

    if (set == NULL || checkType(c,set,OBJ_SET)) return;
    addReplyLongLong(c,setTypeSize(set));
}
402 
403 /* Handle the "SPOP key <count>" variant. The normal version of the
404  * command is handled by the spopCommand() function itself. */
405 
406 /* How many times bigger should be the set compared to the remaining size
407  * for us to use the "create new set" strategy? Read later in the
408  * implementation for more info. */
409 #define SPOP_MOVE_STRATEGY_MUL 5
410 
spopWithCountCommand(client * c)411 void spopWithCountCommand(client *c) {
412     long l;
413     unsigned long count, size;
414     robj *set;
415 
416     /* Get the count argument */
417     if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != C_OK) return;
418     if (l >= 0) {
419         count = (unsigned) l;
420     } else {
421         addReply(c,shared.outofrangeerr);
422         return;
423     }
424 
425     /* Make sure a key with the name inputted exists, and that it's type is
426      * indeed a set. Otherwise, return nil */
427     if ((set = lookupKeyReadOrReply(c,c->argv[1],shared.emptymultibulk))
428         == NULL || checkType(c,set,OBJ_SET)) return;
429 
430     /* If count is zero, serve an empty multibulk ASAP to avoid special
431      * cases later. */
432     if (count == 0) {
433         addReply(c,shared.emptymultibulk);
434         return;
435     }
436 
437     size = setTypeSize(set);
438 
439     /* Generate an SPOP keyspace notification */
440     notifyKeyspaceEvent(NOTIFY_SET,"spop",c->argv[1],c->db->id);
441     server.dirty += count;
442 
443     /* CASE 1:
444      * The number of requested elements is greater than or equal to
445      * the number of elements inside the set: simply return the whole set. */
446     if (count >= size) {
447         /* We just return the entire set */
448         sunionDiffGenericCommand(c,c->argv+1,1,NULL,SET_OP_UNION);
449 
450         /* Delete the set as it is now empty */
451         dbDelete(c->db,c->argv[1]);
452         notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],c->db->id);
453 
454         /* Propagate this command as an DEL operation */
455         rewriteClientCommandVector(c,2,shared.del,c->argv[1]);
456         signalModifiedKey(c->db,c->argv[1]);
457         server.dirty++;
458         return;
459     }
460 
461     /* Case 2 and 3 require to replicate SPOP as a set of SERM commands.
462      * Prepare our replication argument vector. Also send the array length
463      * which is common to both the code paths. */
464     robj *propargv[3];
465     propargv[0] = createStringObject("SREM",4);
466     propargv[1] = c->argv[1];
467     addReplyMultiBulkLen(c,count);
468 
469     /* Common iteration vars. */
470     robj *objele;
471     int encoding;
472     int64_t llele;
473     unsigned long remaining = size-count; /* Elements left after SPOP. */
474 
475     /* If we are here, the number of requested elements is less than the
476      * number of elements inside the set. Also we are sure that count < size.
477      * Use two different strategies.
478      *
479      * CASE 2: The number of elements to return is small compared to the
480      * set size. We can just extract random elements and return them to
481      * the set. */
482     if (remaining*SPOP_MOVE_STRATEGY_MUL > count) {
483         while(count--) {
484             encoding = setTypeRandomElement(set,&objele,&llele);
485             if (encoding == OBJ_ENCODING_INTSET) {
486                 objele = createStringObjectFromLongLong(llele);
487             } else {
488                 incrRefCount(objele);
489             }
490 
491             /* Return the element to the client and remove from the set. */
492             addReplyBulk(c,objele);
493             setTypeRemove(set,objele);
494 
495             /* Replicate/AOF this command as an SREM operation */
496             propargv[2] = objele;
497             alsoPropagate(server.sremCommand,c->db->id,propargv,3,
498                 PROPAGATE_AOF|PROPAGATE_REPL);
499             decrRefCount(objele);
500         }
501     } else {
502     /* CASE 3: The number of elements to return is very big, approaching
503      * the size of the set itself. After some time extracting random elements
504      * from such a set becomes computationally expensive, so we use
505      * a different strategy, we extract random elements that we don't
506      * want to return (the elements that will remain part of the set),
507      * creating a new set as we do this (that will be stored as the original
508      * set). Then we return the elements left in the original set and
509      * release it. */
510         robj *newset = NULL;
511 
512         /* Create a new set with just the remaining elements. */
513         while(remaining--) {
514             encoding = setTypeRandomElement(set,&objele,&llele);
515             if (encoding == OBJ_ENCODING_INTSET) {
516                 objele = createStringObjectFromLongLong(llele);
517             } else {
518                 incrRefCount(objele);
519             }
520             if (!newset) newset = setTypeCreate(objele);
521             setTypeAdd(newset,objele);
522             setTypeRemove(set,objele);
523             decrRefCount(objele);
524         }
525 
526         /* Assign the new set as the key value. */
527         incrRefCount(set); /* Protect the old set value. */
528         dbOverwrite(c->db,c->argv[1],newset);
529 
530         /* Tranfer the old set to the client and release it. */
531         setTypeIterator *si;
532         si = setTypeInitIterator(set);
533         while((encoding = setTypeNext(si,&objele,&llele)) != -1) {
534             if (encoding == OBJ_ENCODING_INTSET) {
535                 objele = createStringObjectFromLongLong(llele);
536             } else {
537                 incrRefCount(objele);
538             }
539             addReplyBulk(c,objele);
540 
541             /* Replicate/AOF this command as an SREM operation */
542             propargv[2] = objele;
543             alsoPropagate(server.sremCommand,c->db->id,propargv,3,
544                 PROPAGATE_AOF|PROPAGATE_REPL);
545 
546             decrRefCount(objele);
547         }
548         setTypeReleaseIterator(si);
549         decrRefCount(set);
550     }
551 
552     /* Don't propagate the command itself even if we incremented the
553      * dirty counter. We don't want to propagate an SPOP command since
554      * we propagated the command as a set of SREMs operations using
555      * the alsoPropagate() API. */
556     decrRefCount(propargv[0]);
557     preventCommandPropagation(c);
558     signalModifiedKey(c->db,c->argv[1]);
559     server.dirty++;
560 }
561 
/* SPOP key [count]
 *
 * Without a count, removes and replies with one random member of the set
 * at key (nil reply for a missing key). The command is rewritten as
 * "SREM key member" for replication/AOF, and the key is deleted when the set
 * becomes empty. The count form is delegated to spopWithCountCommand(). */
void spopCommand(client *c) {
    robj *set, *ele, *srem;
    int64_t llele;

    if (c->argc > 3) {
        addReply(c,shared.syntaxerr);
        return;
    }
    if (c->argc == 3) {
        spopWithCountCommand(c);
        return;
    }

    set = lookupKeyWriteOrReply(c,c->argv[1],shared.nullbulk);
    if (set == NULL || checkType(c,set,OBJ_SET)) return;

    /* Pick a random member and take it out of the set. In both branches we
     * end up holding one reference to "ele". */
    if (setTypeRandomElement(set,&ele,&llele) == OBJ_ENCODING_INTSET) {
        ele = createStringObjectFromLongLong(llele);
        set->ptr = intsetRemove(set->ptr,llele,NULL);
    } else {
        incrRefCount(ele);
        setTypeRemove(set,ele);
    }

    notifyKeyspaceEvent(NOTIFY_SET,"spop",c->argv[1],c->db->id);

    /* Replicate/AOF as SREM. "ele" is still used for the reply after our
     * reference is dropped, which presumably relies on the rewritten argv
     * holding its own reference. */
    srem = createStringObject("SREM",4);
    rewriteClientCommandVector(c,3,srem,c->argv[1],ele);
    decrRefCount(ele);
    decrRefCount(srem);

    addReplyBulk(c,ele);

    if (setTypeSize(set) == 0) {
        dbDelete(c->db,c->argv[1]);
        notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],c->db->id);
    }

    signalModifiedKey(c->db,c->argv[1]);
    server.dirty++;
}
613 
614 /* handle the "SRANDMEMBER key <count>" variant. The normal version of the
615  * command is handled by the srandmemberCommand() function itself. */
616 
617 /* How many times bigger should be the set compared to the requested size
618  * for us to don't use the "remove elements" strategy? Read later in the
619  * implementation for more info. */
620 #define SRANDMEMBER_SUB_STRATEGY_MUL 3
621 
srandmemberWithCountCommand(client * c)622 void srandmemberWithCountCommand(client *c) {
623     long l;
624     unsigned long count, size;
625     int uniq = 1;
626     robj *set, *ele;
627     int64_t llele;
628     int encoding;
629 
630     dict *d;
631 
632     if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != C_OK) return;
633     if (l >= 0) {
634         count = (unsigned) l;
635     } else {
636         /* A negative count means: return the same elements multiple times
637          * (i.e. don't remove the extracted element after every extraction). */
638         count = -l;
639         uniq = 0;
640     }
641 
642     if ((set = lookupKeyReadOrReply(c,c->argv[1],shared.emptymultibulk))
643         == NULL || checkType(c,set,OBJ_SET)) return;
644     size = setTypeSize(set);
645 
646     /* If count is zero, serve it ASAP to avoid special cases later. */
647     if (count == 0) {
648         addReply(c,shared.emptymultibulk);
649         return;
650     }
651 
652     /* CASE 1: The count was negative, so the extraction method is just:
653      * "return N random elements" sampling the whole set every time.
654      * This case is trivial and can be served without auxiliary data
655      * structures. */
656     if (!uniq) {
657         addReplyMultiBulkLen(c,count);
658         while(count--) {
659             encoding = setTypeRandomElement(set,&ele,&llele);
660             if (encoding == OBJ_ENCODING_INTSET) {
661                 addReplyBulkLongLong(c,llele);
662             } else {
663                 addReplyBulk(c,ele);
664             }
665         }
666         return;
667     }
668 
669     /* CASE 2:
670      * The number of requested elements is greater than the number of
671      * elements inside the set: simply return the whole set. */
672     if (count >= size) {
673         sunionDiffGenericCommand(c,c->argv+1,1,NULL,SET_OP_UNION);
674         return;
675     }
676 
677     /* For CASE 3 and CASE 4 we need an auxiliary dictionary. */
678     d = dictCreate(&setDictType,NULL);
679 
680     /* CASE 3:
681      * The number of elements inside the set is not greater than
682      * SRANDMEMBER_SUB_STRATEGY_MUL times the number of requested elements.
683      * In this case we create a set from scratch with all the elements, and
684      * subtract random elements to reach the requested number of elements.
685      *
686      * This is done because if the number of requsted elements is just
687      * a bit less than the number of elements in the set, the natural approach
688      * used into CASE 3 is highly inefficient. */
689     if (count*SRANDMEMBER_SUB_STRATEGY_MUL > size) {
690         setTypeIterator *si;
691 
692         /* Add all the elements into the temporary dictionary. */
693         si = setTypeInitIterator(set);
694         while((encoding = setTypeNext(si,&ele,&llele)) != -1) {
695             int retval = DICT_ERR;
696 
697             if (encoding == OBJ_ENCODING_INTSET) {
698                 retval = dictAdd(d,createStringObjectFromLongLong(llele),NULL);
699             } else {
700                 retval = dictAdd(d,dupStringObject(ele),NULL);
701             }
702             serverAssert(retval == DICT_OK);
703         }
704         setTypeReleaseIterator(si);
705         serverAssert(dictSize(d) == size);
706 
707         /* Remove random elements to reach the right count. */
708         while(size > count) {
709             dictEntry *de;
710 
711             de = dictGetRandomKey(d);
712             dictDelete(d,dictGetKey(de));
713             size--;
714         }
715     }
716 
717     /* CASE 4: We have a big set compared to the requested number of elements.
718      * In this case we can simply get random elements from the set and add
719      * to the temporary set, trying to eventually get enough unique elements
720      * to reach the specified count. */
721     else {
722         unsigned long added = 0;
723 
724         while(added < count) {
725             encoding = setTypeRandomElement(set,&ele,&llele);
726             if (encoding == OBJ_ENCODING_INTSET) {
727                 ele = createStringObjectFromLongLong(llele);
728             } else {
729                 ele = dupStringObject(ele);
730             }
731             /* Try to add the object to the dictionary. If it already exists
732              * free it, otherwise increment the number of objects we have
733              * in the result dictionary. */
734             if (dictAdd(d,ele,NULL) == DICT_OK)
735                 added++;
736             else
737                 decrRefCount(ele);
738         }
739     }
740 
741     /* CASE 3 & 4: send the result to the user. */
742     {
743         dictIterator *di;
744         dictEntry *de;
745 
746         addReplyMultiBulkLen(c,count);
747         di = dictGetIterator(d);
748         while((de = dictNext(di)) != NULL)
749             addReplyBulk(c,dictGetKey(de));
750         dictReleaseIterator(di);
751         dictRelease(d);
752     }
753 }
754 
/* SRANDMEMBER key [count]
 *
 * Without a count, replies with one random member of the set at key without
 * removing it (nil reply for a missing key). The count form is handled by
 * srandmemberWithCountCommand(). */
void srandmemberCommand(client *c) {
    robj *set, *ele;
    int64_t llele;

    if (c->argc > 3) {
        addReply(c,shared.syntaxerr);
        return;
    }
    if (c->argc == 3) {
        srandmemberWithCountCommand(c);
        return;
    }

    set = lookupKeyReadOrReply(c,c->argv[1],shared.nullbulk);
    if (set == NULL || checkType(c,set,OBJ_SET)) return;

    if (setTypeRandomElement(set,&ele,&llele) == OBJ_ENCODING_INTSET)
        addReplyBulkLongLong(c,llele);
    else
        addReplyBulk(c,ele);
}
778 
qsortCompareSetsByCardinality(const void * s1,const void * s2)779 int qsortCompareSetsByCardinality(const void *s1, const void *s2) {
780     return setTypeSize(*(robj**)s1)-setTypeSize(*(robj**)s2);
781 }
782 
783 /* This is used by SDIFF and in this case we can receive NULL that should
784  * be handled as empty sets. */
qsortCompareSetsByRevCardinality(const void * s1,const void * s2)785 int qsortCompareSetsByRevCardinality(const void *s1, const void *s2) {
786     robj *o1 = *(robj**)s1, *o2 = *(robj**)s2;
787 
788     return  (o2 ? setTypeSize(o2) : 0) - (o1 ? setTypeSize(o1) : 0);
789 }
790 
/* Return 1 when the current element of sets[0] is present in every other
 * set in sets[1..setnum-1], and 0 as soon as one set lacks it.
 *
 * The element comes from setTypeNext():
 * - When 'encoding' is OBJ_ENCODING_INTSET it is the integer 'llele'.
 * - When 'encoding' is OBJ_ENCODING_HT it is the object 'ele'.
 *
 * Entries that are the very same object as sets[0] are skipped. */
static int sinterMemberOfAllOtherSets(robj **sets, unsigned long setnum,
                                      int encoding, robj *ele, int64_t llele)
{
    unsigned long k;

    for (k = 1; k < setnum; k++) {
        robj *other = sets[k];

        if (other == sets[0]) continue;

        if (encoding == OBJ_ENCODING_INTSET) {
            if (other->encoding == OBJ_ENCODING_INTSET) {
                /* intset vs intset: direct integer lookup. */
                if (!intsetFind((intset*)other->ptr,llele)) return 0;
            } else if (other->encoding == OBJ_ENCODING_HT) {
                /* A hash table needs an object to look up, so wrap the
                 * integer in a temporary string object. */
                robj *tmp = createStringObjectFromLongLong(llele);
                int found = setTypeIsMember(other,tmp);

                decrRefCount(tmp);
                if (!found) return 0;
            }
        } else if (encoding == OBJ_ENCODING_HT) {
            /* Fast path: an integer-encoded object probed against an
             * intset can use intsetFind() directly. */
            if (ele->encoding == OBJ_ENCODING_INT &&
                other->encoding == OBJ_ENCODING_INTSET &&
                !intsetFind((intset*)other->ptr,(long)ele->ptr))
                return 0;
            /* General case: the encoding-agnostic membership test. */
            if (!setTypeIsMember(other,ele)) return 0;
        }
    }
    return 1;
}

/* Shared implementation of SINTER (dstkey == NULL) and SINTERSTORE.
 *
 * 'setkeys' holds 'setnum' key names.
 * - If any key is missing the intersection is empty. SINTER then replies with
 *   an empty multi-bulk. SINTERSTORE deletes 'dstkey' (if present) and
 *   replies 0.
 * - A key of the wrong type aborts with the error sent by checkType().
 *
 * The sets are sorted by increasing cardinality. The smallest one is walked,
 * and each of its elements is kept only if every other set contains it.
 * - SINTER streams the kept members through a deferred multi-bulk length.
 * - SINTERSTORE collects them in a new set that replaces 'dstkey'. An empty
 *   result only deletes 'dstkey'. The reply is the stored cardinality. */
void sinterGenericCommand(client *c, robj **setkeys,
                          unsigned long setnum, robj *dstkey) {
    robj **sets = zmalloc(sizeof(robj*)*setnum);
    robj *dstset = NULL;
    void *replylen = NULL;
    unsigned long i, emitted = 0;
    setTypeIterator *it;
    robj *ele = NULL;
    int64_t llele = 0;
    int encoding;

    for (i = 0; i < setnum; i++) {
        robj *setobj = dstkey ? lookupKeyWrite(c->db,setkeys[i])
                              : lookupKeyRead(c->db,setkeys[i]);

        if (setobj == NULL) {
            /* A missing key makes the intersection empty. */
            zfree(sets);
            if (!dstkey) {
                addReply(c,shared.emptymultibulk);
                return;
            }
            if (dbDelete(c->db,dstkey)) {
                signalModifiedKey(c->db,dstkey);
                server.dirty++;
            }
            addReply(c,shared.czero);
            return;
        }
        if (checkType(c,setobj,OBJ_SET)) {
            zfree(sets);
            return;
        }
        sets[i] = setobj;
    }

    /* Smallest set first: it bounds the number of outer iterations. */
    qsort(sets,setnum,sizeof(robj*),qsortCompareSetsByCardinality);

    if (dstkey) {
        /* Start from an empty set that will become the stored result. */
        dstset = createIntsetObject();
    } else {
        /* The result size is unknown yet, so reserve the multi-bulk header
         * and patch it once all members have been emitted. */
        replylen = addDeferredMultiBulkLength(c);
    }

    it = setTypeInitIterator(sets[0]);
    while ((encoding = setTypeNext(it,&ele,&llele)) != -1) {
        if (!sinterMemberOfAllOtherSets(sets,setnum,encoding,ele,llele))
            continue;

        if (dstkey) {
            if (encoding == OBJ_ENCODING_INTSET) {
                robj *tmp = createStringObjectFromLongLong(llele);

                setTypeAdd(dstset,tmp);
                decrRefCount(tmp);
            } else {
                setTypeAdd(dstset,ele);
            }
        } else {
            if (encoding == OBJ_ENCODING_HT)
                addReplyBulk(c,ele);
            else
                addReplyBulkLongLong(c,llele);
            emitted++;
        }
    }
    setTypeReleaseIterator(it);

    if (!dstkey) {
        setDeferredMultiBulkLength(c,replylen,emitted);
    } else {
        int deleted = dbDelete(c->db,dstkey);

        if (setTypeSize(dstset) > 0) {
            dbAdd(c->db,dstkey,dstset);
            addReplyLongLong(c,setTypeSize(dstset));
            notifyKeyspaceEvent(NOTIFY_SET,"sinterstore",
                dstkey,c->db->id);
        } else {
            decrRefCount(dstset);
            addReply(c,shared.czero);
            if (deleted)
                notifyKeyspaceEvent(NOTIFY_GENERIC,"del",
                    dstkey,c->db->id);
        }
        signalModifiedKey(c->db,dstkey);
        server.dirty++;
    }
    zfree(sets);
}
926 
/* SINTER key [key ...]: reply with the members shared by all given sets. */
void sinterCommand(client *c) {
    unsigned long numkeys = c->argc - 1;

    sinterGenericCommand(c,&c->argv[1],numkeys,NULL);
}
930 
/* SINTERSTORE destination key [key ...]: store the intersection of the
 * source sets into the key named by argv[1]. */
void sinterstoreCommand(client *c) {
    robj *dstkey = c->argv[1];

    sinterGenericCommand(c,&c->argv[2],c->argc - 2,dstkey);
}
934 
/* Operation selector passed as 'op' to sunionDiffGenericCommand(). */
enum {
    SET_OP_UNION = 0,
    SET_OP_DIFF  = 1,
    SET_OP_INTER = 2
};
938 
/* Shared implementation of SUNION, SUNIONSTORE, SDIFF and SDIFFSTORE.
 *
 * 'setkeys' holds 'setnum' key names.
 * - Missing keys behave as empty sets.
 * - A key of the wrong type aborts the command with the error sent by
 *   checkType().
 *
 * 'op' selects SET_OP_UNION or SET_OP_DIFF. No SET_OP_INTER branch exists
 * here, so that value would produce an empty result.
 *
 * With dstkey == NULL the result members are sent as a multi-bulk reply.
 * Otherwise the result replaces 'dstkey' (an empty result only deletes it)
 * and the reply is the stored cardinality. */
void sunionDiffGenericCommand(client *c, robj **setkeys, int setnum,
                              robj *dstkey, int op) {
    robj **sets = zmalloc(sizeof(robj*)*setnum);
    robj *result, *member;
    setTypeIterator *it;
    int i, count = 0;
    int algo = 1;

    /* Resolve the keys; a NULL entry stands for a missing key. */
    for (i = 0; i < setnum; i++) {
        robj *o = dstkey ? lookupKeyWrite(c->db,setkeys[i])
                         : lookupKeyRead(c->db,setkeys[i]);

        if (o != NULL && checkType(c,o,OBJ_SET)) {
            zfree(sets);
            return;
        }
        sets[i] = o;
    }

    /* Choose the DIFF strategy:
     * - Algorithm 1 costs about N*M, where N is the size of the first set
     *   and M the number of (existing) sets.
     * - Algorithm 2 costs about the total number of elements in all sets. */
    if (op == SET_OP_DIFF && sets[0] != NULL) {
        long long first_size = setTypeSize(sets[0]);
        long long algo1_cost = 0, algo2_cost = 0;

        for (i = 0; i < setnum; i++) {
            if (sets[i] == NULL) continue;
            algo1_cost += first_size;
            algo2_cost += setTypeSize(sets[i]);
        }

        /* Algorithm 1 has better constant times and performs fewer
         * operations when the sets share elements: give it an edge. */
        algo1_cost /= 2;
        algo = (algo1_cost <= algo2_cost) ? 1 : 2;

        /* Algorithm 1 benefits from probing the biggest sets first, as a
         * duplicate is then likely to be found sooner. */
        if (algo == 1 && setnum > 1)
            qsort(sets+1,setnum-1,sizeof(robj*),
                qsortCompareSetsByRevCardinality);
    }

    /* Scratch set holding the result; in STORE mode it becomes the value
     * written to dstkey. */
    result = createIntsetObject();

    if (op == SET_OP_UNION) {
        for (i = 0; i < setnum; i++) {
            if (sets[i] == NULL) continue;

            it = setTypeInitIterator(sets[i]);
            while ((member = setTypeNextObject(it)) != NULL) {
                if (setTypeAdd(result,member)) count++;
                decrRefCount(member);
            }
            setTypeReleaseIterator(it);
        }
    } else if (op == SET_OP_DIFF && sets[0] != NULL) {
        if (algo == 1) {
            /* Keep each member of the first set that no later set has. */
            it = setTypeInitIterator(sets[0]);
            while ((member = setTypeNextObject(it)) != NULL) {
                int k;

                for (k = 1; k < setnum; k++) {
                    if (sets[k] == NULL) continue;
                    /* The first set itself (same key) removes everything. */
                    if (sets[k] == sets[0] || setTypeIsMember(sets[k],member))
                        break;
                }
                if (k == setnum) {
                    setTypeAdd(result,member);
                    count++;
                }
                decrRefCount(member);
            }
            setTypeReleaseIterator(it);
        } else {
            /* Copy the first set, then strip the members of every later
             * set from the copy. */
            for (i = 0; i < setnum; i++) {
                if (sets[i] == NULL) continue;

                it = setTypeInitIterator(sets[i]);
                while ((member = setTypeNextObject(it)) != NULL) {
                    if (i == 0) {
                        if (setTypeAdd(result,member)) count++;
                    } else {
                        if (setTypeRemove(result,member)) count--;
                    }
                    decrRefCount(member);
                }
                setTypeReleaseIterator(it);

                /* Nothing left to remove from: stop early. */
                if (count == 0) break;
            }
        }
    }

    if (dstkey == NULL) {
        addReplyMultiBulkLen(c,count);
        it = setTypeInitIterator(result);
        while ((member = setTypeNextObject(it)) != NULL) {
            addReplyBulk(c,member);
            decrRefCount(member);
        }
        setTypeReleaseIterator(it);
        decrRefCount(result);
    } else {
        int deleted = dbDelete(c->db,dstkey);

        if (setTypeSize(result) > 0) {
            dbAdd(c->db,dstkey,result);
            addReplyLongLong(c,setTypeSize(result));
            notifyKeyspaceEvent(NOTIFY_SET,
                op == SET_OP_UNION ? "sunionstore" : "sdiffstore",
                dstkey,c->db->id);
        } else {
            decrRefCount(result);
            addReply(c,shared.czero);
            if (deleted)
                notifyKeyspaceEvent(NOTIFY_GENERIC,"del",
                    dstkey,c->db->id);
        }
        signalModifiedKey(c->db,dstkey);
        server.dirty++;
    }
    zfree(sets);
}
1097 
/* SUNION key [key ...]: reply with the union of the given sets. */
void sunionCommand(client *c) {
    int numkeys = c->argc - 1;

    sunionDiffGenericCommand(c,&c->argv[1],numkeys,NULL,SET_OP_UNION);
}
1101 
/* SUNIONSTORE destination key [key ...]: store the union of the source sets
 * into the key named by argv[1]. */
void sunionstoreCommand(client *c) {
    robj *dstkey = c->argv[1];

    sunionDiffGenericCommand(c,&c->argv[2],c->argc - 2,dstkey,SET_OP_UNION);
}
1105 
/* SDIFF key [key ...]: reply with the members of the first set that are
 * absent from all of the following sets. */
void sdiffCommand(client *c) {
    int numkeys = c->argc - 1;

    sunionDiffGenericCommand(c,&c->argv[1],numkeys,NULL,SET_OP_DIFF);
}
1109 
/* SDIFFSTORE destination key [key ...]: store the difference of the source
 * sets into the key named by argv[1]. */
void sdiffstoreCommand(client *c) {
    robj *dstkey = c->argv[1];

    sunionDiffGenericCommand(c,&c->argv[2],c->argc - 2,dstkey,SET_OP_DIFF);
}
1113 
/* SSCAN key cursor ...
 *
 * The cursor is validated first. Then the set at 'key' is looked up: a
 * missing key gets shared.emptyscan and a wrong type gets checkType()'s
 * error. The scan itself, including any further arguments, is left to
 * scanGenericCommand(). */
void sscanCommand(client *c) {
    unsigned long cursor;
    robj *setobj;

    if (parseScanCursorOrReply(c,c->argv[2],&cursor) == C_ERR) return;

    setobj = lookupKeyReadOrReply(c,c->argv[1],shared.emptyscan);
    if (setobj == NULL || checkType(c,setobj,OBJ_SET)) return;

    scanGenericCommand(c,setobj,cursor);
}
1123