xref: /f-stack/app/redis-5.0.5/src/t_set.c (revision 572c4311)
1 /*
2  * Copyright (c) 2009-2012, Salvatore Sanfilippo <antirez at gmail dot com>
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *
8  *   * Redistributions of source code must retain the above copyright notice,
9  *     this list of conditions and the following disclaimer.
10  *   * Redistributions in binary form must reproduce the above copyright
11  *     notice, this list of conditions and the following disclaimer in the
12  *     documentation and/or other materials provided with the distribution.
13  *   * Neither the name of Redis nor the names of its contributors may be used
14  *     to endorse or promote products derived from this software without
15  *     specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27  * POSSIBILITY OF SUCH DAMAGE.
28  */
29 
#include "server.h"
#include <limits.h>
31 
32 /*-----------------------------------------------------------------------------
33  * Set Commands
34  *----------------------------------------------------------------------------*/
35 
36 void sunionDiffGenericCommand(client *c, robj **setkeys, int setnum,
37                               robj *dstkey, int op);
38 
39 /* Factory method to return a set that *can* hold "value". When the object has
40  * an integer-encodable value, an intset will be returned. Otherwise a regular
41  * hash table. */
setTypeCreate(sds value)42 robj *setTypeCreate(sds value) {
43     if (isSdsRepresentableAsLongLong(value,NULL) == C_OK)
44         return createIntsetObject();
45     return createSetObject();
46 }
47 
48 /* Add the specified value into a set.
49  *
50  * If the value was already member of the set, nothing is done and 0 is
51  * returned, otherwise the new element is added and 1 is returned. */
setTypeAdd(robj * subject,sds value)52 int setTypeAdd(robj *subject, sds value) {
53     long long llval;
54     if (subject->encoding == OBJ_ENCODING_HT) {
55         dict *ht = subject->ptr;
56         dictEntry *de = dictAddRaw(ht,value,NULL);
57         if (de) {
58             dictSetKey(ht,de,sdsdup(value));
59             dictSetVal(ht,de,NULL);
60             return 1;
61         }
62     } else if (subject->encoding == OBJ_ENCODING_INTSET) {
63         if (isSdsRepresentableAsLongLong(value,&llval) == C_OK) {
64             uint8_t success = 0;
65             subject->ptr = intsetAdd(subject->ptr,llval,&success);
66             if (success) {
67                 /* Convert to regular set when the intset contains
68                  * too many entries. */
69                 if (intsetLen(subject->ptr) > server.set_max_intset_entries)
70                     setTypeConvert(subject,OBJ_ENCODING_HT);
71                 return 1;
72             }
73         } else {
74             /* Failed to get integer from object, convert to regular set. */
75             setTypeConvert(subject,OBJ_ENCODING_HT);
76 
77             /* The set *was* an intset and this value is not integer
78              * encodable, so dictAdd should always work. */
79             serverAssert(dictAdd(subject->ptr,sdsdup(value),NULL) == DICT_OK);
80             return 1;
81         }
82     } else {
83         serverPanic("Unknown set encoding");
84     }
85     return 0;
86 }
87 
/* Remove "value" from the set "setobj".
 *
 * Returns 1 if the element was found and removed, 0 otherwise. For a hash
 * table encoded set the table is resized after a successful removal when
 * htNeedsResize() asks for it. For an intset, a value that is not an
 * integer cannot be a member, so 0 is returned without touching the set. */
int setTypeRemove(robj *setobj, sds value) {
    switch (setobj->encoding) {
    case OBJ_ENCODING_HT:
        if (dictDelete(setobj->ptr,value) != DICT_OK) return 0;
        if (htNeedsResize(setobj->ptr)) dictResize(setobj->ptr);
        return 1;
    case OBJ_ENCODING_INTSET: {
        long long number;
        int removed = 0;

        if (isSdsRepresentableAsLongLong(value,&number) != C_OK) return 0;
        /* intsetRemove() may hand back a different pointer. */
        setobj->ptr = intsetRemove(setobj->ptr,number,&removed);
        return removed ? 1 : 0;
    }
    default:
        serverPanic("Unknown set encoding");
    }
    return 0;
}
106 
/* Test whether "value" is a member of the set "subject".
 *
 * Returns non-zero when the element is present, 0 otherwise. For an intset
 * encoded set a value that is not an integer is never a member. */
int setTypeIsMember(robj *subject, sds value) {
    long long number;

    switch (subject->encoding) {
    case OBJ_ENCODING_HT:
        return dictFind((dict*)subject->ptr,value) != NULL;
    case OBJ_ENCODING_INTSET:
        if (isSdsRepresentableAsLongLong(value,&number) != C_OK) return 0;
        return intsetFind((intset*)subject->ptr,number);
    default:
        serverPanic("Unknown set encoding");
    }
    return 0;
}
120 
setTypeInitIterator(robj * subject)121 setTypeIterator *setTypeInitIterator(robj *subject) {
122     setTypeIterator *si = zmalloc(sizeof(setTypeIterator));
123     si->subject = subject;
124     si->encoding = subject->encoding;
125     if (si->encoding == OBJ_ENCODING_HT) {
126         si->di = dictGetIterator(subject->ptr);
127     } else if (si->encoding == OBJ_ENCODING_INTSET) {
128         si->ii = 0;
129     } else {
130         serverPanic("Unknown set encoding");
131     }
132     return si;
133 }
134 
/* Free an iterator obtained from setTypeInitIterator(). The nested dict
 * iterator exists only for hash table encoded sets, so it is released only
 * in that case. */
void setTypeReleaseIterator(setTypeIterator *si) {
    int has_dict_iterator = (si->encoding == OBJ_ENCODING_HT);

    if (has_dict_iterator) dictReleaseIterator(si->di);
    zfree(si);
}
140 
141 /* Move to the next entry in the set. Returns the object at the current
142  * position.
143  *
144  * Since set elements can be internally be stored as SDS strings or
145  * simple arrays of integers, setTypeNext returns the encoding of the
146  * set object you are iterating, and will populate the appropriate pointer
147  * (sdsele) or (llele) accordingly.
148  *
149  * Note that both the sdsele and llele pointers should be passed and cannot
150  * be NULL since the function will try to defensively populate the non
151  * used field with values which are easy to trap if misused.
152  *
153  * When there are no longer elements -1 is returned. */
setTypeNext(setTypeIterator * si,sds * sdsele,int64_t * llele)154 int setTypeNext(setTypeIterator *si, sds *sdsele, int64_t *llele) {
155     if (si->encoding == OBJ_ENCODING_HT) {
156         dictEntry *de = dictNext(si->di);
157         if (de == NULL) return -1;
158         *sdsele = dictGetKey(de);
159         *llele = -123456789; /* Not needed. Defensive. */
160     } else if (si->encoding == OBJ_ENCODING_INTSET) {
161         if (!intsetGet(si->subject->ptr,si->ii++,llele))
162             return -1;
163         *sdsele = NULL; /* Not needed. Defensive. */
164     } else {
165         serverPanic("Wrong set encoding in setTypeNext");
166     }
167     return si->encoding;
168 }
169 
170 /* The not copy on write friendly version but easy to use version
171  * of setTypeNext() is setTypeNextObject(), returning new SDS
172  * strings. So if you don't retain a pointer to this object you should call
173  * sdsfree() against it.
174  *
175  * This function is the way to go for write operations where COW is not
176  * an issue. */
setTypeNextObject(setTypeIterator * si)177 sds setTypeNextObject(setTypeIterator *si) {
178     int64_t intele;
179     sds sdsele;
180     int encoding;
181 
182     encoding = setTypeNext(si,&sdsele,&intele);
183     switch(encoding) {
184         case -1:    return NULL;
185         case OBJ_ENCODING_INTSET:
186             return sdsfromlonglong(intele);
187         case OBJ_ENCODING_HT:
188             return sdsdup(sdsele);
189         default:
190             serverPanic("Unsupported encoding");
191     }
192     return NULL; /* just to suppress warnings */
193 }
194 
195 /* Return random element from a non empty set.
196  * The returned element can be a int64_t value if the set is encoded
197  * as an "intset" blob of integers, or an SDS string if the set
198  * is a regular set.
199  *
200  * The caller provides both pointers to be populated with the right
201  * object. The return value of the function is the object->encoding
202  * field of the object and is used by the caller to check if the
203  * int64_t pointer or the redis object pointer was populated.
204  *
205  * Note that both the sdsele and llele pointers should be passed and cannot
206  * be NULL since the function will try to defensively populate the non
207  * used field with values which are easy to trap if misused. */
setTypeRandomElement(robj * setobj,sds * sdsele,int64_t * llele)208 int setTypeRandomElement(robj *setobj, sds *sdsele, int64_t *llele) {
209     if (setobj->encoding == OBJ_ENCODING_HT) {
210         dictEntry *de = dictGetRandomKey(setobj->ptr);
211         *sdsele = dictGetKey(de);
212         *llele = -123456789; /* Not needed. Defensive. */
213     } else if (setobj->encoding == OBJ_ENCODING_INTSET) {
214         *llele = intsetRandom(setobj->ptr);
215         *sdsele = NULL; /* Not needed. Defensive. */
216     } else {
217         serverPanic("Unknown set encoding");
218     }
219     return setobj->encoding;
220 }
221 
setTypeSize(const robj * subject)222 unsigned long setTypeSize(const robj *subject) {
223     if (subject->encoding == OBJ_ENCODING_HT) {
224         return dictSize((const dict*)subject->ptr);
225     } else if (subject->encoding == OBJ_ENCODING_INTSET) {
226         return intsetLen((const intset*)subject->ptr);
227     } else {
228         serverPanic("Unknown set encoding");
229     }
230 }
231 
232 /* Convert the set to specified encoding. The resulting dict (when converting
233  * to a hash table) is presized to hold the number of elements in the original
234  * set. */
setTypeConvert(robj * setobj,int enc)235 void setTypeConvert(robj *setobj, int enc) {
236     setTypeIterator *si;
237     serverAssertWithInfo(NULL,setobj,setobj->type == OBJ_SET &&
238                              setobj->encoding == OBJ_ENCODING_INTSET);
239 
240     if (enc == OBJ_ENCODING_HT) {
241         int64_t intele;
242         dict *d = dictCreate(&setDictType,NULL);
243         sds element;
244 
245         /* Presize the dict to avoid rehashing */
246         dictExpand(d,intsetLen(setobj->ptr));
247 
248         /* To add the elements we extract integers and create redis objects */
249         si = setTypeInitIterator(setobj);
250         while (setTypeNext(si,&element,&intele) != -1) {
251             element = sdsfromlonglong(intele);
252             serverAssert(dictAdd(d,element,NULL) == DICT_OK);
253         }
254         setTypeReleaseIterator(si);
255 
256         setobj->encoding = OBJ_ENCODING_HT;
257         zfree(setobj->ptr);
258         setobj->ptr = d;
259     } else {
260         serverPanic("Unsupported set conversion");
261     }
262 }
263 
/* SADD key member [member ...]
 *
 * Adds the given members to the set stored at key, creating the set when
 * the key does not exist. Replies with the number of members that were
 * actually added. If the key holds a value that is not a set, the wrong
 * type error is sent and nothing is modified. */
void saddCommand(client *c) {
    int added = 0;
    robj *set = lookupKeyWrite(c->db,c->argv[1]);

    if (set != NULL && set->type != OBJ_SET) {
        addReply(c,shared.wrongtypeerr);
        return;
    }
    if (set == NULL) {
        /* The encoding of the new set is chosen from the first member. */
        set = setTypeCreate(c->argv[2]->ptr);
        dbAdd(c->db,c->argv[1],set);
    }

    for (int j = 2; j < c->argc; j++)
        added += (setTypeAdd(set,c->argv[j]->ptr) ? 1 : 0);

    /* Signal and notify only when the set really changed. */
    if (added != 0) {
        signalModifiedKey(c->db,c->argv[1]);
        notifyKeyspaceEvent(NOTIFY_SET,"sadd",c->argv[1],c->db->id);
    }
    server.dirty += added;
    addReplyLongLong(c,added);
}
289 
/* SREM key member [member ...]
 *
 * Removes the given members from the set stored at key and replies with
 * the number of members actually removed. When the key does not exist the
 * reply is zero. When the last element is removed the key is deleted from
 * the database and the remaining arguments are not processed. */
void sremCommand(client *c) {
    robj *set = lookupKeyWriteOrReply(c,c->argv[1],shared.czero);

    if (set == NULL) return;
    if (checkType(c,set,OBJ_SET)) return;

    int deleted = 0, keyremoved = 0;
    for (int j = 2; j < c->argc && !keyremoved; j++) {
        if (!setTypeRemove(set,c->argv[j]->ptr)) continue;
        deleted++;
        if (setTypeSize(set) == 0) {
            /* The set is now empty: drop the key and stop. */
            dbDelete(c->db,c->argv[1]);
            keyremoved = 1;
        }
    }

    if (deleted != 0) {
        signalModifiedKey(c->db,c->argv[1]);
        notifyKeyspaceEvent(NOTIFY_SET,"srem",c->argv[1],c->db->id);
        if (keyremoved)
            notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],
                                c->db->id);
        server.dirty += deleted;
    }
    addReplyLongLong(c,deleted);
}
317 
/* SMOVE source destination member
 *
 * Moves "member" from the source set to the destination set. Replies with
 * one when the element was removed from the source, zero when the source
 * key does not exist or does not contain the element. The destination set
 * is created when missing, and the source key is deleted when it becomes
 * empty. */
void smoveCommand(client *c) {
    /* Both keys are looked up before any validation takes place. */
    robj *src = lookupKeyWrite(c->db,c->argv[1]);
    robj *dst = lookupKeyWrite(c->db,c->argv[2]);
    robj *member = c->argv[3];

    /* A missing source key means there is nothing to move. */
    if (src == NULL) {
        addReply(c,shared.czero);
        return;
    }

    /* Stop when the source, or an existing destination, fails the set
     * type check. */
    if (checkType(c,src,OBJ_SET)) return;
    if (dst != NULL && checkType(c,dst,OBJ_SET)) return;

    /* Source and destination are the same object: nothing is moved, the
     * reply only tells whether the element is a member. */
    if (src == dst) {
        if (setTypeIsMember(src,member->ptr))
            addReply(c,shared.cone);
        else
            addReply(c,shared.czero);
        return;
    }

    /* The element is not in the source set: reply zero. */
    if (!setTypeRemove(src,member->ptr)) {
        addReply(c,shared.czero);
        return;
    }
    notifyKeyspaceEvent(NOTIFY_SET,"srem",c->argv[1],c->db->id);

    /* Drop the source key if the removal emptied the set. */
    if (setTypeSize(src) == 0) {
        dbDelete(c->db,c->argv[1]);
        notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],c->db->id);
    }

    /* Create the destination set on demand. */
    if (dst == NULL) {
        dst = setTypeCreate(member->ptr);
        dbAdd(c->db,c->argv[2],dst);
    }

    signalModifiedKey(c->db,c->argv[1]);
    signalModifiedKey(c->db,c->argv[2]);
    server.dirty++;

    /* The destination counts as changed only if the element was new. */
    if (setTypeAdd(dst,member->ptr)) {
        server.dirty++;
        notifyKeyspaceEvent(NOTIFY_SET,"sadd",c->argv[2],c->db->id);
    }
    addReply(c,shared.cone);
}
372 
/* SISMEMBER key member
 *
 * Replies with one when "member" belongs to the set stored at key, zero
 * when it does not or when the key does not exist. */
void sismemberCommand(client *c) {
    robj *set = lookupKeyReadOrReply(c,c->argv[1],shared.czero);

    if (set == NULL || checkType(c,set,OBJ_SET)) return;

    addReply(c,setTypeIsMember(set,c->argv[2]->ptr) ?
        shared.cone : shared.czero);
}
384 
/* SCARD key
 *
 * Replies with the number of elements of the set stored at key, or zero
 * when the key does not exist. */
void scardCommand(client *c) {
    robj *set = lookupKeyReadOrReply(c,c->argv[1],shared.czero);

    if (set == NULL) return;
    if (checkType(c,set,OBJ_SET)) return;

    addReplyLongLong(c,setTypeSize(set));
}
393 
394 /* Handle the "SPOP key <count>" variant. The normal version of the
395  * command is handled by the spopCommand() function itself. */
396 
397 /* How many times bigger should be the set compared to the remaining size
398  * for us to use the "create new set" strategy? Read later in the
399  * implementation for more info. */
400 #define SPOP_MOVE_STRATEGY_MUL 5
401 
spopWithCountCommand(client * c)402 void spopWithCountCommand(client *c) {
403     long l;
404     unsigned long count, size;
405     robj *set;
406 
407     /* Get the count argument */
408     if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != C_OK) return;
409     if (l >= 0) {
410         count = (unsigned long) l;
411     } else {
412         addReply(c,shared.outofrangeerr);
413         return;
414     }
415 
416     /* Make sure a key with the name inputted exists, and that it's type is
417      * indeed a set. Otherwise, return nil */
418     if ((set = lookupKeyReadOrReply(c,c->argv[1],shared.emptymultibulk))
419         == NULL || checkType(c,set,OBJ_SET)) return;
420 
421     /* If count is zero, serve an empty multibulk ASAP to avoid special
422      * cases later. */
423     if (count == 0) {
424         addReply(c,shared.emptymultibulk);
425         return;
426     }
427 
428     size = setTypeSize(set);
429 
430     /* Generate an SPOP keyspace notification */
431     notifyKeyspaceEvent(NOTIFY_SET,"spop",c->argv[1],c->db->id);
432     server.dirty += count;
433 
434     /* CASE 1:
435      * The number of requested elements is greater than or equal to
436      * the number of elements inside the set: simply return the whole set. */
437     if (count >= size) {
438         /* We just return the entire set */
439         sunionDiffGenericCommand(c,c->argv+1,1,NULL,SET_OP_UNION);
440 
441         /* Delete the set as it is now empty */
442         dbDelete(c->db,c->argv[1]);
443         notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],c->db->id);
444 
445         /* Propagate this command as an DEL operation */
446         rewriteClientCommandVector(c,2,shared.del,c->argv[1]);
447         signalModifiedKey(c->db,c->argv[1]);
448         server.dirty++;
449         return;
450     }
451 
452     /* Case 2 and 3 require to replicate SPOP as a set of SREM commands.
453      * Prepare our replication argument vector. Also send the array length
454      * which is common to both the code paths. */
455     robj *propargv[3];
456     propargv[0] = createStringObject("SREM",4);
457     propargv[1] = c->argv[1];
458     addReplyMultiBulkLen(c,count);
459 
460     /* Common iteration vars. */
461     sds sdsele;
462     robj *objele;
463     int encoding;
464     int64_t llele;
465     unsigned long remaining = size-count; /* Elements left after SPOP. */
466 
467     /* If we are here, the number of requested elements is less than the
468      * number of elements inside the set. Also we are sure that count < size.
469      * Use two different strategies.
470      *
471      * CASE 2: The number of elements to return is small compared to the
472      * set size. We can just extract random elements and return them to
473      * the set. */
474     if (remaining*SPOP_MOVE_STRATEGY_MUL > count) {
475         while(count--) {
476             /* Emit and remove. */
477             encoding = setTypeRandomElement(set,&sdsele,&llele);
478             if (encoding == OBJ_ENCODING_INTSET) {
479                 addReplyBulkLongLong(c,llele);
480                 objele = createStringObjectFromLongLong(llele);
481                 set->ptr = intsetRemove(set->ptr,llele,NULL);
482             } else {
483                 addReplyBulkCBuffer(c,sdsele,sdslen(sdsele));
484                 objele = createStringObject(sdsele,sdslen(sdsele));
485                 setTypeRemove(set,sdsele);
486             }
487 
488             /* Replicate/AOF this command as an SREM operation */
489             propargv[2] = objele;
490             alsoPropagate(server.sremCommand,c->db->id,propargv,3,
491                 PROPAGATE_AOF|PROPAGATE_REPL);
492             decrRefCount(objele);
493         }
494     } else {
495     /* CASE 3: The number of elements to return is very big, approaching
496      * the size of the set itself. After some time extracting random elements
497      * from such a set becomes computationally expensive, so we use
498      * a different strategy, we extract random elements that we don't
499      * want to return (the elements that will remain part of the set),
500      * creating a new set as we do this (that will be stored as the original
501      * set). Then we return the elements left in the original set and
502      * release it. */
503         robj *newset = NULL;
504 
505         /* Create a new set with just the remaining elements. */
506         while(remaining--) {
507             encoding = setTypeRandomElement(set,&sdsele,&llele);
508             if (encoding == OBJ_ENCODING_INTSET) {
509                 sdsele = sdsfromlonglong(llele);
510             } else {
511                 sdsele = sdsdup(sdsele);
512             }
513             if (!newset) newset = setTypeCreate(sdsele);
514             setTypeAdd(newset,sdsele);
515             setTypeRemove(set,sdsele);
516             sdsfree(sdsele);
517         }
518 
519         /* Transfer the old set to the client. */
520         setTypeIterator *si;
521         si = setTypeInitIterator(set);
522         while((encoding = setTypeNext(si,&sdsele,&llele)) != -1) {
523             if (encoding == OBJ_ENCODING_INTSET) {
524                 addReplyBulkLongLong(c,llele);
525                 objele = createStringObjectFromLongLong(llele);
526             } else {
527                 addReplyBulkCBuffer(c,sdsele,sdslen(sdsele));
528                 objele = createStringObject(sdsele,sdslen(sdsele));
529             }
530 
531             /* Replicate/AOF this command as an SREM operation */
532             propargv[2] = objele;
533             alsoPropagate(server.sremCommand,c->db->id,propargv,3,
534                 PROPAGATE_AOF|PROPAGATE_REPL);
535             decrRefCount(objele);
536         }
537         setTypeReleaseIterator(si);
538 
539         /* Assign the new set as the key value. */
540         dbOverwrite(c->db,c->argv[1],newset);
541     }
542 
543     /* Don't propagate the command itself even if we incremented the
544      * dirty counter. We don't want to propagate an SPOP command since
545      * we propagated the command as a set of SREMs operations using
546      * the alsoPropagate() API. */
547     decrRefCount(propargv[0]);
548     preventCommandPropagation(c);
549     signalModifiedKey(c->db,c->argv[1]);
550     server.dirty++;
551 }
552 
/* SPOP key [count]
 *
 * Without a count, removes one random element from the set and replies
 * with it (or with a null bulk when the key does not exist). With a count
 * the work is delegated to spopWithCountCommand(). More than one extra
 * argument is a syntax error.
 *
 * The command vector is rewritten as "SREM key element", as the popped
 * element is chosen at random. */
void spopCommand(client *c) {
    if (c->argc == 3) {
        spopWithCountCommand(c);
        return;
    }
    if (c->argc > 3) {
        addReply(c,shared.syntaxerr);
        return;
    }

    /* The key must exist and hold a set. */
    robj *set = lookupKeyWriteOrReply(c,c->argv[1],shared.nullbulk);
    if (set == NULL || checkType(c,set,OBJ_SET)) return;

    /* Pick a random element and take it out of the set. */
    sds strval;
    int64_t intval;
    robj *popped;
    if (setTypeRandomElement(set,&strval,&intval) == OBJ_ENCODING_INTSET) {
        popped = createStringObjectFromLongLong(intval);
        set->ptr = intsetRemove(set->ptr,intval,NULL);
    } else {
        /* Copy the element first: strval is the key stored in the set. */
        popped = createStringObject(strval,sdslen(strval));
        setTypeRemove(set,popped->ptr);
    }

    notifyKeyspaceEvent(NOTIFY_SET,"spop",c->argv[1],c->db->id);

    /* Replicate/AOF this command as an SREM operation */
    robj *srem = createStringObject("SREM",4);
    rewriteClientCommandVector(c,3,srem,c->argv[1],popped);
    decrRefCount(srem);

    /* Send the element to the client, then drop our reference. */
    addReplyBulk(c,popped);
    decrRefCount(popped);

    /* An empty set must not be left in the database. */
    if (setTypeSize(set) == 0) {
        dbDelete(c->db,c->argv[1]);
        notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],c->db->id);
    }

    /* Set has been modified */
    signalModifiedKey(c->db,c->argv[1]);
    server.dirty++;
}
605 
606 /* handle the "SRANDMEMBER key <count>" variant. The normal version of the
607  * command is handled by the srandmemberCommand() function itself. */
608 
609 /* How many times bigger should be the set compared to the requested size
610  * for us to don't use the "remove elements" strategy? Read later in the
611  * implementation for more info. */
612 #define SRANDMEMBER_SUB_STRATEGY_MUL 3
613 
srandmemberWithCountCommand(client * c)614 void srandmemberWithCountCommand(client *c) {
615     long l;
616     unsigned long count, size;
617     int uniq = 1;
618     robj *set;
619     sds ele;
620     int64_t llele;
621     int encoding;
622 
623     dict *d;
624 
625     if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != C_OK) return;
626     if (l >= 0) {
627         count = (unsigned long) l;
628     } else {
629         /* A negative count means: return the same elements multiple times
630          * (i.e. don't remove the extracted element after every extraction). */
631         count = -l;
632         uniq = 0;
633     }
634 
635     if ((set = lookupKeyReadOrReply(c,c->argv[1],shared.emptymultibulk))
636         == NULL || checkType(c,set,OBJ_SET)) return;
637     size = setTypeSize(set);
638 
639     /* If count is zero, serve it ASAP to avoid special cases later. */
640     if (count == 0) {
641         addReply(c,shared.emptymultibulk);
642         return;
643     }
644 
645     /* CASE 1: The count was negative, so the extraction method is just:
646      * "return N random elements" sampling the whole set every time.
647      * This case is trivial and can be served without auxiliary data
648      * structures. */
649     if (!uniq) {
650         addReplyMultiBulkLen(c,count);
651         while(count--) {
652             encoding = setTypeRandomElement(set,&ele,&llele);
653             if (encoding == OBJ_ENCODING_INTSET) {
654                 addReplyBulkLongLong(c,llele);
655             } else {
656                 addReplyBulkCBuffer(c,ele,sdslen(ele));
657             }
658         }
659         return;
660     }
661 
662     /* CASE 2:
663      * The number of requested elements is greater than the number of
664      * elements inside the set: simply return the whole set. */
665     if (count >= size) {
666         sunionDiffGenericCommand(c,c->argv+1,1,NULL,SET_OP_UNION);
667         return;
668     }
669 
670     /* For CASE 3 and CASE 4 we need an auxiliary dictionary. */
671     d = dictCreate(&objectKeyPointerValueDictType,NULL);
672 
673     /* CASE 3:
674      * The number of elements inside the set is not greater than
675      * SRANDMEMBER_SUB_STRATEGY_MUL times the number of requested elements.
676      * In this case we create a set from scratch with all the elements, and
677      * subtract random elements to reach the requested number of elements.
678      *
679      * This is done because if the number of requsted elements is just
680      * a bit less than the number of elements in the set, the natural approach
681      * used into CASE 3 is highly inefficient. */
682     if (count*SRANDMEMBER_SUB_STRATEGY_MUL > size) {
683         setTypeIterator *si;
684 
685         /* Add all the elements into the temporary dictionary. */
686         si = setTypeInitIterator(set);
687         while((encoding = setTypeNext(si,&ele,&llele)) != -1) {
688             int retval = DICT_ERR;
689 
690             if (encoding == OBJ_ENCODING_INTSET) {
691                 retval = dictAdd(d,createStringObjectFromLongLong(llele),NULL);
692             } else {
693                 retval = dictAdd(d,createStringObject(ele,sdslen(ele)),NULL);
694             }
695             serverAssert(retval == DICT_OK);
696         }
697         setTypeReleaseIterator(si);
698         serverAssert(dictSize(d) == size);
699 
700         /* Remove random elements to reach the right count. */
701         while(size > count) {
702             dictEntry *de;
703 
704             de = dictGetRandomKey(d);
705             dictDelete(d,dictGetKey(de));
706             size--;
707         }
708     }
709 
710     /* CASE 4: We have a big set compared to the requested number of elements.
711      * In this case we can simply get random elements from the set and add
712      * to the temporary set, trying to eventually get enough unique elements
713      * to reach the specified count. */
714     else {
715         unsigned long added = 0;
716         robj *objele;
717 
718         while(added < count) {
719             encoding = setTypeRandomElement(set,&ele,&llele);
720             if (encoding == OBJ_ENCODING_INTSET) {
721                 objele = createStringObjectFromLongLong(llele);
722             } else {
723                 objele = createStringObject(ele,sdslen(ele));
724             }
725             /* Try to add the object to the dictionary. If it already exists
726              * free it, otherwise increment the number of objects we have
727              * in the result dictionary. */
728             if (dictAdd(d,objele,NULL) == DICT_OK)
729                 added++;
730             else
731                 decrRefCount(objele);
732         }
733     }
734 
735     /* CASE 3 & 4: send the result to the user. */
736     {
737         dictIterator *di;
738         dictEntry *de;
739 
740         addReplyMultiBulkLen(c,count);
741         di = dictGetIterator(d);
742         while((de = dictNext(di)) != NULL)
743             addReplyBulk(c,dictGetKey(de));
744         dictReleaseIterator(di);
745         dictRelease(d);
746     }
747 }
748 
/* SRANDMEMBER key [count]
 *
 * Without a count, replies with one random element of the set (or with a
 * null bulk when the key does not exist), leaving the set untouched. With
 * a count the work is delegated to srandmemberWithCountCommand(). More
 * than one extra argument is a syntax error. */
void srandmemberCommand(client *c) {
    if (c->argc == 3) {
        srandmemberWithCountCommand(c);
        return;
    }
    if (c->argc > 3) {
        addReply(c,shared.syntaxerr);
        return;
    }

    robj *set = lookupKeyReadOrReply(c,c->argv[1],shared.nullbulk);
    if (set == NULL || checkType(c,set,OBJ_SET)) return;

    sds strval;
    int64_t intval;
    if (setTypeRandomElement(set,&strval,&intval) == OBJ_ENCODING_INTSET)
        addReplyBulkLongLong(c,intval);
    else
        addReplyBulkCBuffer(c,strval,sdslen(strval));
}
773 
qsortCompareSetsByCardinality(const void * s1,const void * s2)774 int qsortCompareSetsByCardinality(const void *s1, const void *s2) {
775     if (setTypeSize(*(robj**)s1) > setTypeSize(*(robj**)s2)) return 1;
776     if (setTypeSize(*(robj**)s1) < setTypeSize(*(robj**)s2)) return -1;
777     return 0;
778 }
779 
780 /* This is used by SDIFF and in this case we can receive NULL that should
781  * be handled as empty sets. */
qsortCompareSetsByRevCardinality(const void * s1,const void * s2)782 int qsortCompareSetsByRevCardinality(const void *s1, const void *s2) {
783     robj *o1 = *(robj**)s1, *o2 = *(robj**)s2;
784     unsigned long first = o1 ? setTypeSize(o1) : 0;
785     unsigned long second = o2 ? setTypeSize(o2) : 0;
786 
787     if (first < second) return 1;
788     if (first > second) return -1;
789     return 0;
790 }
791 
/* Generic implementation shared by SINTER and SINTERSTORE.
 *
 * c       - client issuing the command; every reply is appended to it.
 * setkeys - array of 'setnum' key names whose sets must be intersected.
 * setnum  - number of entries in 'setkeys'.
 * dstkey  - NULL to send the resulting members to the client, otherwise
 *           the key where the resulting set has to be stored (STORE mode).
 *
 * Source keys are looked up with lookupKeyWrite() in STORE mode and with
 * lookupKeyRead() otherwise. If one of them is missing the intersection is
 * empty: the function replies at once and, in STORE mode, also removes the
 * old destination key. If checkType() rejects one of the values the
 * function returns without touching the destination. */
void sinterGenericCommand(client *c, robj **setkeys,
                          unsigned long setnum, robj *dstkey) {
    robj **sets = zmalloc(sizeof(robj*)*setnum);
    setTypeIterator *si;
    robj *dstset = NULL;
    sds elesds;
    int64_t intobj;
    void *replylen = NULL;
    unsigned long j, cardinality = 0;
    int encoding;

    for (j = 0; j < setnum; j++) {
        robj *setobj = dstkey ?
            lookupKeyWrite(c->db,setkeys[j]) :
            lookupKeyRead(c->db,setkeys[j]);
        if (!setobj) {
            /* A missing key is an empty set, so the result is empty too. */
            zfree(sets);
            if (dstkey) {
                if (dbDelete(c->db,dstkey)) {
                    signalModifiedKey(c->db,dstkey);
                    /* The old destination is gone: emit the same "del"
                     * event that is emitted at the end of this function
                     * when an empty result replaces an existing key. */
                    notifyKeyspaceEvent(NOTIFY_GENERIC,"del",
                        dstkey,c->db->id);
                    server.dirty++;
                }
                addReply(c,shared.czero);
            } else {
                addReply(c,shared.emptymultibulk);
            }
            return;
        }
        if (checkType(c,setobj,OBJ_SET)) {
            zfree(sets);
            return;
        }
        sets[j] = setobj;
    }
    /* Sort sets from the smallest to largest, this will improve our
     * algorithm's performance: we iterate only the first (smallest) one. */
    qsort(sets,setnum,sizeof(robj*),qsortCompareSetsByCardinality);

    /* The first thing we should output is the total number of elements...
     * since this is a multi-bulk write, but at this stage we don't know
     * the intersection set size, so we use a trick: reserve a deferred
     * length in the output and save the handle to later set it to the
     * right value. */
    if (!dstkey) {
        replylen = addDeferredMultiBulkLength(c);
    } else {
        /* If we have a target key where to store the resulting set
         * create an empty set that will collect the members. */
        dstset = createIntsetObject();
    }

    /* Iterate all the elements of the first (smallest) set, and test
     * the element against all the other sets, if at least one set does
     * not include the element it is discarded. */
    si = setTypeInitIterator(sets[0]);
    while((encoding = setTypeNext(si,&elesds,&intobj)) != -1) {
        for (j = 1; j < setnum; j++) {
            /* The same object listed twice trivially contains the element. */
            if (sets[j] == sets[0]) continue;
            if (encoding == OBJ_ENCODING_INTSET) {
                /* intset with intset is simple... and fast */
                if (sets[j]->encoding == OBJ_ENCODING_INTSET &&
                    !intsetFind((intset*)sets[j]->ptr,intobj))
                {
                    break;
                /* in order to compare an integer with a hash table
                 * encoded set we have to use the generic function,
                 * creating a temporary string for this */
                } else if (sets[j]->encoding == OBJ_ENCODING_HT) {
                    elesds = sdsfromlonglong(intobj);
                    if (!setTypeIsMember(sets[j],elesds)) {
                        sdsfree(elesds);
                        break;
                    }
                    sdsfree(elesds);
                }
            } else if (encoding == OBJ_ENCODING_HT) {
                if (!setTypeIsMember(sets[j],elesds)) {
                    break;
                }
            }
        }

        /* Only take action when all sets contain the member, that is,
         * when the loop above was not interrupted by a break. */
        if (j == setnum) {
            if (!dstkey) {
                if (encoding == OBJ_ENCODING_HT)
                    addReplyBulkCBuffer(c,elesds,sdslen(elesds));
                else
                    addReplyBulkLongLong(c,intobj);
                cardinality++;
            } else {
                if (encoding == OBJ_ENCODING_INTSET) {
                    /* setTypeAdd() wants a string: build a temporary one. */
                    elesds = sdsfromlonglong(intobj);
                    setTypeAdd(dstset,elesds);
                    sdsfree(elesds);
                } else {
                    setTypeAdd(dstset,elesds);
                }
            }
        }
    }
    setTypeReleaseIterator(si);

    if (dstkey) {
        /* Store the resulting set into the target, if the intersection
         * is not an empty set. The old value is removed in both cases. */
        int deleted = dbDelete(c->db,dstkey);
        if (setTypeSize(dstset) > 0) {
            dbAdd(c->db,dstkey,dstset);
            addReplyLongLong(c,setTypeSize(dstset));
            notifyKeyspaceEvent(NOTIFY_SET,"sinterstore",
                dstkey,c->db->id);
        } else {
            decrRefCount(dstset);
            addReply(c,shared.czero);
            if (deleted)
                notifyKeyspaceEvent(NOTIFY_GENERIC,"del",
                    dstkey,c->db->id);
        }
        signalModifiedKey(c->db,dstkey);
        server.dirty++;
    } else {
        /* Now the number of emitted members is known: fix the length. */
        setDeferredMultiBulkLength(c,replylen,cardinality);
    }
    zfree(sets);
}
918 
/* SINTER: every argument after the command name is a source key. The
 * result is sent to the client (no destination key). */
void sinterCommand(client *c) {
    robj **keys = c->argv + 1;
    unsigned long numkeys = c->argc - 1;

    sinterGenericCommand(c, keys, numkeys, NULL);
}
922 
/* SINTERSTORE: the first argument is the destination key, all the
 * following ones are the source keys to intersect. */
void sinterstoreCommand(client *c) {
    robj *destination = c->argv[1];
    robj **keys = c->argv + 2;
    unsigned long numkeys = c->argc - 2;

    sinterGenericCommand(c, keys, numkeys, destination);
}
926 
/* Operation selectors passed as the 'op' argument of the generic
 * union/difference implementation. */
#define SET_OP_UNION    0   /* Union of all the input sets. */
#define SET_OP_DIFF     1   /* First set minus all the following sets. */
#define SET_OP_INTER    2   /* Intersection. */
930 
/* Generic implementation shared by SUNION, SUNIONSTORE, SDIFF and
 * SDIFFSTORE.
 *
 * c       - client issuing the command; every reply is appended to it.
 * setkeys - array of 'setnum' key names to operate on.
 * setnum  - number of entries in 'setkeys'.
 * dstkey  - NULL to send the resulting members to the client, otherwise
 *           the key where the resulting set has to be stored (STORE mode).
 * op      - SET_OP_UNION or SET_OP_DIFF. With any other value no element
 *           is ever added to the result.
 *
 * Missing keys are handled as empty sets. If checkType() rejects one of
 * the values the function returns before computing anything. */
void sunionDiffGenericCommand(client *c, robj **setkeys, int setnum,
                              robj *dstkey, int op) {
    robj **srcsets = zmalloc(sizeof(robj*)*setnum);
    robj *result = NULL;
    setTypeIterator *iter;
    sds member;
    int idx;
    int cardinality = 0;
    int diff_algo = 1;
    int is_diff;

    /* Resolve every key to its set object; NULL marks a missing key. */
    for (idx = 0; idx < setnum; idx++) {
        robj *src;

        if (dstkey)
            src = lookupKeyWrite(c->db,setkeys[idx]);
        else
            src = lookupKeyRead(c->db,setkeys[idx]);

        if (src != NULL && checkType(c,src,OBJ_SET)) {
            zfree(srcsets);
            return;
        }
        srcsets[idx] = src;
    }

    /* A difference has something to compute only if the first set exists.
     * The first slot is never moved by the sorting step below, so this
     * flag stays valid for the rest of the function. */
    is_diff = (op == SET_OP_DIFF && srcsets[0] != NULL);

    /* Select what DIFF algorithm to use.
     *
     * Algorithm 1 is O(N*M) where N is the size of the first set and M
     * the total number of sets.
     *
     * Algorithm 2 is O(N) where N is the total number of elements in all
     * the sets.
     *
     * Estimate the work of both with the current input and pick the
     * cheaper one. */
    if (is_diff) {
        unsigned long first_size = setTypeSize(srcsets[0]);
        long long algo_one_work = 0;
        long long algo_two_work = 0;

        for (idx = 0; idx < setnum; idx++) {
            if (srcsets[idx] == NULL) continue;

            algo_one_work += first_size;
            algo_two_work += setTypeSize(srcsets[idx]);
        }

        /* Algorithm 1 has better constant times and performs less
         * operations if there are elements in common: give it some
         * advantage by halving its estimated cost. */
        algo_one_work /= 2;
        if (algo_one_work > algo_two_work) diff_algo = 2;

        /* With algorithm 1 it is better to visit the sets to subtract by
         * decreasing size, so that a duplicated element is more likely to
         * be found ASAP. */
        if (diff_algo == 1 && setnum > 1) {
            qsort(srcsets+1,setnum-1,sizeof(robj*),
                qsortCompareSetsByRevCardinality);
        }
    }

    /* Temporary set collecting the result. In STORE mode this very object
     * is what gets stored into the target key. */
    result = createIntsetObject();

    if (op == SET_OP_UNION) {
        /* Union is trivial: add every element of every set to the
         * temporary set, counting only the elements actually inserted. */
        for (idx = 0; idx < setnum; idx++) {
            if (srcsets[idx] == NULL) continue; /* missing key: empty set */

            iter = setTypeInitIterator(srcsets[idx]);
            for (member = setTypeNextObject(iter);
                 member != NULL;
                 member = setTypeNextObject(iter))
            {
                if (setTypeAdd(result,member)) cardinality++;
                sdsfree(member);
            }
            setTypeReleaseIterator(iter);
        }
    } else if (is_diff && diff_algo == 1) {
        /* DIFF Algorithm 1:
         *
         * Iterate all the elements of the first set, and add an element
         * to the target set only if it is not found in any other set.
         *
         * This way we perform at max N*M operations, where N is the size
         * of the first set, and M the number of sets. */
        iter = setTypeInitIterator(srcsets[0]);
        for (member = setTypeNextObject(iter);
             member != NULL;
             member = setTypeNextObject(iter))
        {
            for (idx = 1; idx < setnum; idx++) {
                robj *other = srcsets[idx];

                if (other == NULL) continue; /* missing key: empty set */
                /* Stop at the first set holding the element; the first
                 * set itself, listed again, always does. */
                if (other == srcsets[0] || setTypeIsMember(other,member))
                    break;
            }
            if (idx == setnum) {
                /* No other set has this element: it is part of the diff. */
                setTypeAdd(result,member);
                cardinality++;
            }
            sdsfree(member);
        }
        setTypeReleaseIterator(iter);
    } else if (is_diff && diff_algo == 2) {
        /* DIFF Algorithm 2:
         *
         * Add all the elements of the first set to the auxiliary set,
         * then remove from it the elements of all the next sets.
         *
         * This is O(N) where N is the sum of all the elements in every
         * set. */
        for (idx = 0; idx < setnum; idx++) {
            if (srcsets[idx] == NULL) continue; /* missing key: empty set */

            iter = setTypeInitIterator(srcsets[idx]);
            for (member = setTypeNextObject(iter);
                 member != NULL;
                 member = setTypeNextObject(iter))
            {
                if (idx == 0) {
                    if (setTypeAdd(result,member)) cardinality++;
                } else if (setTypeRemove(result,member)) {
                    cardinality--;
                }
                sdsfree(member);
            }
            setTypeReleaseIterator(iter);

            /* Once the result is empty any further removal has no
             * effect, so stop here. */
            if (cardinality == 0) break;
        }
    }

    if (dstkey) {
        /* STORE mode: replace the target key with the result set, or just
         * remove the target when the result is empty. */
        char *event = (op == SET_OP_UNION) ? "sunionstore" : "sdiffstore";
        int deleted = dbDelete(c->db,dstkey);

        if (setTypeSize(result) > 0) {
            dbAdd(c->db,dstkey,result);
            addReplyLongLong(c,setTypeSize(result));
            notifyKeyspaceEvent(NOTIFY_SET,event,dstkey,c->db->id);
        } else {
            decrRefCount(result);
            addReply(c,shared.czero);
            if (deleted) {
                notifyKeyspaceEvent(NOTIFY_GENERIC,"del",
                    dstkey,c->db->id);
            }
        }
        signalModifiedKey(c->db,dstkey);
        server.dirty++;
    } else {
        /* Not in STORE mode: output the content of the resulting set and
         * drop the temporary object. */
        addReplyMultiBulkLen(c,cardinality);
        iter = setTypeInitIterator(result);
        for (member = setTypeNextObject(iter);
             member != NULL;
             member = setTypeNextObject(iter))
        {
            addReplyBulkCBuffer(c,member,sdslen(member));
            sdsfree(member);
        }
        setTypeReleaseIterator(iter);
        decrRefCount(result);
    }
    zfree(srcsets);
}
1090 
/* SUNION: every argument after the command name is a source key. The
 * union is sent to the client (no destination key). */
void sunionCommand(client *c) {
    robj **keys = c->argv + 1;
    int numkeys = c->argc - 1;

    sunionDiffGenericCommand(c, keys, numkeys, NULL, SET_OP_UNION);
}
1094 
/* SUNIONSTORE: the first argument is the destination key, all the
 * following ones are the source keys whose union is stored. */
void sunionstoreCommand(client *c) {
    robj *destination = c->argv[1];
    robj **keys = c->argv + 2;
    int numkeys = c->argc - 2;

    sunionDiffGenericCommand(c, keys, numkeys, destination, SET_OP_UNION);
}
1098 
/* SDIFF: every argument after the command name is a source key. The
 * difference between the first set and all the following ones is sent
 * to the client (no destination key). */
void sdiffCommand(client *c) {
    robj **keys = c->argv + 1;
    int numkeys = c->argc - 1;

    sunionDiffGenericCommand(c, keys, numkeys, NULL, SET_OP_DIFF);
}
1102 
/* SDIFFSTORE: the first argument is the destination key, all the
 * following ones are the source keys whose difference is stored. */
void sdiffstoreCommand(client *c) {
    robj *destination = c->argv[1];
    robj **keys = c->argv + 2;
    int numkeys = c->argc - 2;

    sunionDiffGenericCommand(c, keys, numkeys, destination, SET_OP_DIFF);
}
1106 
/* SSCAN: the first argument is the key of the set, the second one is the
 * cursor. The cursor is parsed before the key is looked up; when the key
 * is missing lookupKeyReadOrReply() is given shared.emptyscan as the
 * reply to send. */
void sscanCommand(client *c) {
    unsigned long cursor;
    robj *set;

    if (parseScanCursorOrReply(c,c->argv[2],&cursor) == C_ERR) return;

    set = lookupKeyReadOrReply(c,c->argv[1],shared.emptyscan);
    if (set == NULL) return;
    if (checkType(c,set,OBJ_SET)) return;

    scanGenericCommand(c,set,cursor);
}
1116