# Copyright (c) 2016-2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Based on the software developed by:
# Copyright (c) 2008,2016 david decotigny (Pool of threads)
# Copyright (c) 2006-2008, R Oudkerk (multiprocessing.Pool)
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
# 3. Neither the name of author nor the names of any contributors may be
#    used to endorse or promote products derived from this software
#    without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.
#

# @brief Python Pool implementation based on TBB with monkey-patching
#
# See http://docs.python.org/dev/library/multiprocessing.html
# Differences: imap_async and imap_unordered_async are added, and terminate()
# has to be called explicitly (it is not registered by atexit).
#
# The general idea is that work is submitted to a work queue, either as
# single Jobs (one function to call) or as JobSequences (batches of
# Jobs). Each Job is associated with an ApplyResult object which has two
# states: waiting for the Job to complete, or ready. Instead of
# waiting for the jobs to finish, we wait for their ApplyResult objects
# to become ready: an event mechanism is used for that.
# When a function is applied to several arguments in "parallel", we need
# a way to wait for all or part of the Jobs to be processed: that is what
# "collectors" are for; they group and wait for a set of ApplyResult
# objects. Once a collector is ready to be used, a CollectorIterator
# can be used to iterate over the result values it is collecting.
#
# The methods of a Pool object use all these concepts and expose
# them to their caller in a very simple way.
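#
# A minimal sketch of that flow (illustrative only; it mirrors what
# Pool.apply_async() does with the classes defined below):
#
#     tasks = task_group()
#     result = ApplyResult()               # two-state handle for one Job
#     job = Job(pow, (2, 10), {}, result)  # one function call to execute
#     tasks.run(job)                       # submit the work to oneTBB
#     print(result.get(timeout=60))        # blocks on the internal event
#     tasks.wait()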

import sys
import threading
import traceback
from .api import *

__all__ = ["Pool", "TimeoutError"]
__doc__ = """
Standard Python Pool implementation based on Python API
for Intel(R) oneAPI Threading Building Blocks (oneTBB)
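
A minimal usage sketch (illustrative; it assumes the package is importable
as tbb and that square is a plain callable defined by the user):

    from tbb.pool import Pool

    def square(x):
        return x * x

    pool = Pool()
    print(pool.map(square, range(8)))      # blocking, ordered results
    res = pool.apply_async(square, (3,))   # non-blocking single call
    print(res.get(timeout=10))
    pool.terminate()                       # not registered with atexit
    pool.join()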
"""


class TimeoutError(Exception):
    """Raised when a result is not available within the given timeout"""
    pass


class Pool(object):
    """
    The Pool class provides the standard multiprocessing.Pool interface,
    mapped onto oneTBB tasks executing in its thread pool
    """

    def __init__(self, nworkers=0, name="Pool"):
        """
        \param nworkers (integer) number of worker threads to start
        \param name (string) prefix for the worker threads' name

        Both parameters are kept for interface compatibility only; the
        actual number of workers is managed by the oneTBB thread pool.
        """
        self._closed = False
        self._tasks = task_group()
        self._pool = [None,]*default_num_threads()  # Dask asks for len(_pool)

    def apply(self, func, args=(), kwds=dict()):
        """Equivalent of the apply() builtin function. It blocks until
        the result is ready."""
        return self.apply_async(func, args, kwds).get()

    def map(self, func, iterable, chunksize=None):
        """A parallel equivalent of the map() builtin function. It
        blocks until the result is ready.

        This method chops the iterable into a number of chunks which
        it submits to the pool as separate tasks. The (approximate)
        size of these chunks can be specified by setting chunksize
        to a positive integer."""
        return self.map_async(func, iterable, chunksize).get()

    def imap(self, func, iterable, chunksize=1):
        """
        An equivalent of itertools.imap(): a lazy map() whose results
        are returned through an iterator.

        The chunksize argument is the same as the one used by the
        map() method. For very long iterables using a large value for
        chunksize can make the job complete much faster than
        using the default value of 1.

        Also if chunksize is 1 then the next() method of the iterator
        returned by the imap() method has an optional timeout
        parameter: next(timeout) will raise TimeoutError if
        the result cannot be returned within timeout seconds.
        """
        collector = OrderedResultCollector(as_iterator=True)
        self._create_sequences(func, iterable, chunksize, collector)
        return iter(collector)
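
    # Illustrative sketch (kept as a comment): consuming imap() results with
    # a per-item timeout, as described above; "slow" is a placeholder callable.
    #
    #     it = pool.imap(slow, range(4))   # chunksize=1 by default
    #     try:
    #         first = it.next(timeout=5)   # extended next() with a timeout
    #     except TimeoutError:
    #         first = it.next()            # wait without a deadline instead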

    def imap_unordered(self, func, iterable, chunksize=1):
        """The same as imap() except that the ordering of the results
        from the returned iterator should be considered
        arbitrary. (Only with a single worker is the order guaranteed
        to be "correct".)"""
        collector = UnorderedResultCollector()
        self._create_sequences(func, iterable, chunksize, collector)
        return iter(collector)

    def apply_async(self, func, args=(), kwds=dict(), callback=None):
        """A variant of the apply() method which returns an
        ApplyResult object.

        If callback is specified then it should be a callable which
        accepts a single argument. When the result becomes ready,
        callback is applied to it (unless the call failed). callback
        should complete immediately since otherwise the thread which
        handles the results will get blocked."""
        assert not self._closed  # No lock here. We assume it's atomic...
        apply_result = ApplyResult(callback=callback)
        job = Job(func, args, kwds, apply_result)
        self._tasks.run(job)
        return apply_result
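
    # Illustrative sketch (kept as a comment): non-blocking submission with a
    # callback; "square" and "on_done" are placeholder callables.
    #
    #     def on_done(value):
    #         print("got", value)          # must return quickly
    #
    #     res = pool.apply_async(square, (7,), callback=on_done)
    #     ...                              # do other work in the meantime
    #     print(res.get(timeout=30))       # same value; re-raises job errors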

    def map_async(self, func, iterable, chunksize=None, callback=None):
        """A variant of the map() method which returns an ApplyResult
        object.

        If callback is specified then it should be a callable which
        accepts a single argument. When the result becomes ready,
        callback is applied to it (unless the call failed). callback
        should complete immediately since otherwise the thread which
        handles the results will get blocked."""
        apply_result = ApplyResult(callback=callback)
        collector    = OrderedResultCollector(apply_result, as_iterator=False)
        if not self._create_sequences(func, iterable, chunksize, collector):
            apply_result._set_value([])
        return apply_result

    def imap_async(self, func, iterable, chunksize=None, callback=None):
        """A variant of the imap() method which returns an ApplyResult
        object that provides an iterator (its next() method accepts an
        optional timeout argument).

        If callback is specified then it should be a callable which
        accepts a single argument. When the resulting iterator becomes
        ready, callback is applied to it (unless the call
        failed). callback should complete immediately since otherwise
        the thread which handles the results will get blocked."""
        apply_result = ApplyResult(callback=callback)
        collector    = OrderedResultCollector(apply_result, as_iterator=True)
        if not self._create_sequences(func, iterable, chunksize, collector):
            apply_result._set_value(iter([]))
        return apply_result

    def imap_unordered_async(self, func, iterable, chunksize=None,
                             callback=None):
        """A variant of the imap_unordered() method which returns an
        ApplyResult object that provides an iterator (its next() method
        accepts an optional timeout argument).

        If callback is specified then it should be a callable which
        accepts a single argument. When the resulting iterator becomes
        ready, callback is applied to it (unless the call
        failed). callback should complete immediately since otherwise
        the thread which handles the results will get blocked."""
        apply_result = ApplyResult(callback=callback)
        collector    = UnorderedResultCollector(apply_result)
        if not self._create_sequences(func, iterable, chunksize, collector):
            apply_result._set_value(iter([]))
        return apply_result

    def close(self):
        """Prevents any more tasks from being submitted to the
        pool. Once all the tasks have been completed the workers
        will exit."""
        # No lock here. We assume it's sufficiently atomic...
        self._closed = True

    def terminate(self):
        """Stops the workers immediately without completing
        outstanding work. When the pool object is garbage collected
        terminate() will be called immediately."""
        self.close()
        self._tasks.cancel()

    def join(self):
        """Wait for the workers to exit. One must call
        close() or terminate() before using join()."""
        self._tasks.wait()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.join()
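
    # Illustrative sketch (kept as a comment): using the Pool as a context
    # manager. __exit__() only joins; terminate() still has to be called
    # explicitly if outstanding work should be cancelled.
    #
    #     with Pool() as pool:
    #         squares = pool.map(square, range(8))   # "square" is a placeholder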

    def __del__(self):
        self.terminate()
        self.join()

    def _create_sequences(self, func, iterable, chunksize, collector):
        """
        Creates callable objects to process the iterable and pushes
        them onto the work queue. Each work unit is meant to process a
        slice of iterable of size chunksize. If collector is specified,
        then the ApplyResult objects associated with the jobs will
        notify collector when their result becomes ready.

        \return the list of callable objects (basically: JobSequences)
        pushed onto the work queue
        """
        assert not self._closed  # No lock here. We assume it's atomic...
        it_ = iter(iterable)
        exit_loop = False
        sequences = []
        while not exit_loop:
            seq = []
            for _ in range(chunksize or 1):
                try:
                    arg = next(it_)
                except StopIteration:
                    exit_loop = True
                    break
                apply_result = ApplyResult(collector)
                job = Job(func, (arg,), {}, apply_result)
                seq.append(job)
            if seq:
                sequences.append(JobSequence(seq))
        for t in sequences:
            self._tasks.run(t)
        return sequences
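
    # Illustrative note: with iterable = range(5) and chunksize = 2, the loop
    # above packs the jobs into three work units, roughly:
    #
    #     [JobSequence([Job(func, (0,), {}, r0), Job(func, (1,), {}, r1)]),
    #      JobSequence([Job(func, (2,), {}, r2), Job(func, (3,), {}, r3)]),
    #      JobSequence([Job(func, (4,), {}, r4)])]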


class Job:
    """A work unit that corresponds to the execution of a single function"""

    def __init__(self, func, args, kwds, apply_result):
        """
        \param func/args/kwds used to call the function
        \param apply_result ApplyResult object that holds the result
        of the function call
        """
        self._func = func
        self._args = args
        self._kwds = kwds
        self._result = apply_result

    def __call__(self):
        """
        Call the function with the args/kwds and tell the ApplyResult
        that its result is ready. Correctly handles the exceptions
        happening during the execution of the function
        """
        try:
            result = self._func(*self._args, **self._kwds)
        except:
            self._result._set_exception()
        else:
            self._result._set_value(result)


class JobSequence:
    """A work unit that corresponds to the processing of a contiguous
    sequence of Job objects"""

    def __init__(self, jobs):
        self._jobs = jobs

    def __call__(self):
        """
        Call all the Job objects that have been specified
        """
        for job in self._jobs:
            job()


class ApplyResult(object):
    """An object associated with a Job object that holds its result:
    it's available during the whole life of the Job and afterwards, even
    before the Job has been processed. It's possible to use this object
    to wait for the result/exception of the job to be available.

    The result objects returned by the Pool::*_async() methods are of
    this type"""

    def __init__(self, collector=None, callback=None):
        """
        \param collector when not None, the notify_ready() method of
        the collector will be called when the result from the Job is
        ready
        \param callback when not None, function to call when the
        result becomes available (this is the parameter passed to the
        Pool::*_async() methods)
        """
        self._success = False
        self._event = threading.Event()
        self._data = None
        self._collector = None
        self._callback = callback

        if collector is not None:
            collector.register_result(self)
            self._collector = collector

    def get(self, timeout=None):
        """
        Returns the result when it arrives. If timeout is not None and
        the result does not arrive within timeout seconds then
        TimeoutError is raised. If the call raised an exception
        then that exception will be reraised by get().
        """
        if not self.wait(timeout):
            raise TimeoutError("Result not available within %fs" % timeout)
        if self._success:
            return self._data
        raise self._data[0](self._data[1]).with_traceback(self._data[2])
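
    # Illustrative sketch (kept as a comment): the three outcomes of get()
    # described above, for an ApplyResult "res" returned by Pool.apply_async().
    #
    #     try:
    #         value = res.get(timeout=1.0)
    #     except TimeoutError:
    #         ...                          # job not finished within 1 second
    #     except Exception:
    #         ...                          # job raised; get() re-raised it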

    def wait(self, timeout=None):
        """Waits until the result is available or until timeout
        seconds pass."""
        self._event.wait(timeout)
        return self._event.is_set()

    def ready(self):
        """Returns whether the call has completed."""
        return self._event.is_set()

    def successful(self):
        """Returns whether the call completed without raising an
        exception. Will raise AssertionError if the result is not
        ready."""
        assert self.ready()
        return self._success

    def _set_value(self, value):
        """Called by a Job object to tell that the result is ready, and
        provides the value of this result. The object will become
        ready and successful. The collector's notify_ready() method
        will be called, and the callback will be invoked too"""
        assert not self.ready()
        self._data = value
        self._success = True
        self._event.set()
        if self._collector is not None:
            self._collector.notify_ready(self)
        if self._callback is not None:
            try:
                self._callback(value)
            except:
                traceback.print_exc()

    def _set_exception(self):
        """Called by a Job object to tell that an exception occurred
        during the processing of the function. The object will become
        ready but not successful. The collector's notify_ready()
        method will be called, but NOT the callback method"""
        # traceback.print_exc()
        assert not self.ready()
        self._data = sys.exc_info()
        self._success = False
        self._event.set()
        if self._collector is not None:
            self._collector.notify_ready(self)


class AbstractResultCollector(object):
    """ABC to define the interface of a ResultCollector object. It is
    basically an object which knows which results it's waiting for,
    and which gets notified when they become available. It is
    also able to provide an iterator over the results when they are
    available"""

    def __init__(self, to_notify):
        """
        \param to_notify ApplyResult object to notify when all the
        results we're waiting for become available. Can be None.
        """
        self._to_notify = to_notify

    def register_result(self, apply_result):
        """Used to identify which results we're waiting for. Will
        always be called BEFORE the Jobs get submitted to the work
        queue, and BEFORE the __iter__ and _get_result() methods can
        be called
        \param apply_result ApplyResult object to add to our collection
        """
        raise NotImplementedError("Subclasses must implement it")

    def notify_ready(self, apply_result):
        """Called by the ApplyResult object (already registered via
        register_result()) that it is now ready (i.e. the Job's result
        is available or an exception has been raised).
        \param apply_result ApplyResult object telling us that the job
        has been processed
        """
        raise NotImplementedError("Subclasses must implement it")

    def _get_result(self, idx, timeout=None):
        """Called by the CollectorIterator object to retrieve the
        result values one after another (order defined by the
        implementation)
        \param idx The index of the result we want, wrt collector's order
        \param timeout integer telling how long to wait (in seconds)
        for the result at index idx to be available, or None (wait
        forever)
        """
        raise NotImplementedError("Subclasses must implement it")

    def __iter__(self):
        """Return a new CollectorIterator object for this collector"""
        return CollectorIterator(self)


class CollectorIterator(object):
    """An iterator over the result values available in the given
    collector object. Equipped with an extended next() method accepting
    a timeout argument. Created by the
    AbstractResultCollector::__iter__() method"""

    def __init__(self, collector):
        """\param collector AbstractResultCollector instance"""
        self._collector = collector
        self._idx = 0

    def __iter__(self):
        return self

    def next(self, timeout=None):
        """Return the next result value in the sequence. Raise
        StopIteration at the end. Can raise the exception raised by
        the Job"""
        try:
            apply_result = self._collector._get_result(self._idx, timeout)
        except IndexError:
            # Reset for next time
            self._idx = 0
            raise StopIteration
        except:
            self._idx = 0
            raise
        self._idx += 1
        assert apply_result.ready()
        return apply_result.get(0)

    def __next__(self):
        return self.next()


class UnorderedResultCollector(AbstractResultCollector):
    """An AbstractResultCollector implementation that collects the
    values of the ApplyResult objects in the order they become ready. The
    CollectorIterator object returned by __iter__() will iterate over
    them in the order they become ready"""

    def __init__(self, to_notify=None):
        """
        \param to_notify ApplyResult object to notify when all the
        results we're waiting for become available. Can be None.
        """
        AbstractResultCollector.__init__(self, to_notify)
        self._cond = threading.Condition()
        self._collection = []
        self._expected = 0

    def register_result(self, apply_result):
        """Used to identify which results we're waiting for. Will
        always be called BEFORE the Jobs get submitted to the work
        queue, and BEFORE the __iter__ and _get_result() methods can
        be called
        \param apply_result ApplyResult object to add to our collection
        """
        self._expected += 1

    def _get_result(self, idx, timeout=None):
        """Called by the CollectorIterator object to retrieve the
        result values one after another, in the order the results have
        become available.
        \param idx The index of the result we want, wrt collector's order
        \param timeout integer telling how long to wait (in seconds)
        for the result at index idx to be available, or None (wait
        forever)
        """
        self._cond.acquire()
        try:
            if idx >= self._expected:
                raise IndexError
            elif idx < len(self._collection):
                return self._collection[idx]
            elif idx != len(self._collection):
                # Violation of the sequence protocol
                raise IndexError()
            else:
                self._cond.wait(timeout=timeout)
                try:
                    return self._collection[idx]
                except IndexError:
                    # Still not added!
                    raise TimeoutError("Timeout while waiting for results")
        finally:
            self._cond.release()

    def notify_ready(self, apply_result=None):
        """Called by the ApplyResult object (already registered via
        register_result()) that it is now ready (i.e. the Job's result
        is available or an exception has been raised).
        \param apply_result ApplyResult object telling us that the job
        has been processed
        """
        first_item = False
        self._cond.acquire()
        try:
            self._collection.append(apply_result)
            first_item = (len(self._collection) == 1)

            self._cond.notify_all()
        finally:
            self._cond.release()

        if first_item and self._to_notify is not None:
            self._to_notify._set_value(iter(self))


class OrderedResultCollector(AbstractResultCollector):
    """An AbstractResultCollector implementation that collects the
    values of the ApplyResult objects in the order they have been
    submitted. The CollectorIterator object returned by __iter__()
    will iterate over them in the order they have been submitted"""

    def __init__(self, to_notify=None, as_iterator=True):
        """
        \param to_notify ApplyResult object to notify when all the
        results we're waiting for become available. Can be None.
        \param as_iterator boolean telling whether the result value
        set on to_notify should be an iterator (available as soon as
        one result arrived) or a list (available only after the last
        result arrived)
        """
        AbstractResultCollector.__init__(self, to_notify)
        self._results = []
        self._lock = threading.Lock()
        self._remaining = 0
        self._as_iterator = as_iterator

    def register_result(self, apply_result):
        """Used to identify which results we're waiting for. Will
        always be called BEFORE the Jobs get submitted to the work
        queue, and BEFORE the __iter__ and _get_result() methods can
        be called
        \param apply_result ApplyResult object to add to our collection
        """
        self._results.append(apply_result)
        self._remaining += 1

    def _get_result(self, idx, timeout=None):
        """Called by the CollectorIterator object to retrieve the
        result values one after another, in submission order.
        \param idx The index of the result we want, wrt collector's order
        \param timeout integer telling how long to wait (in seconds)
        for the result at index idx to be available, or None (wait
        forever); raises TimeoutError if the result is not available
        in time
        """
        res = self._results[idx]
        if not res.wait(timeout):
            raise TimeoutError("Timeout while waiting for results")
        return res

    def notify_ready(self, apply_result):
        """Called by the ApplyResult object (already registered via
        register_result()) that it is now ready (i.e. the Job's result
        is available or an exception has been raised).
        \param apply_result ApplyResult object telling us that the job
        has been processed
        """
        got_first = False
        got_last = False
        self._lock.acquire()
        try:
            assert self._remaining > 0
            got_first = (len(self._results) == self._remaining)
            self._remaining -= 1
            got_last = (self._remaining == 0)
        finally:
            self._lock.release()

        if self._to_notify is not None:
            if self._as_iterator and got_first:
                self._to_notify._set_value(iter(self))
            elif not self._as_iterator and got_last:
                try:
                    lst = [r.get(0) for r in self._results]
                except:
                    self._to_notify._set_exception()
                else:
                    self._to_notify._set_value(lst)