#!/usr/bin/env python3
#
# Copyright (c) 2016-2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Based on the software developed by:
# Copyright (c) 2008,2016 david decotigny (Pool of threads)
# Copyright (c) 2006-2008, R Oudkerk (multiprocessing.Pool)
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
# 3. Neither the name of author nor the names of any contributors may be
#    used to endorse or promote products derived from this software
#    without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.
#

# @brief Python Pool implementation based on TBB with monkey-patching
#
# See http://docs.python.org/dev/library/multiprocessing.html
# Differences: added imap_async and imap_unordered_async, and terminate()
# has to be called explicitly (it's not registered by atexit).
#
# The general idea is that we submit work items to a work queue, either as
# single Jobs (one function to call), or JobSequences (batches of
# Jobs). Each Job is associated with an ApplyResult object which has 2
# states: waiting for the Job to complete, or Ready. Instead of
# waiting for the jobs to finish, we wait for their ApplyResult object
# to become ready: an event mechanism is used for that.
# When we apply a function to several arguments in "parallel", we need
# a way to wait for all/part of the Jobs to be processed: that's what
# "collectors" are for; they group and wait for a set of ApplyResult
# objects. Once a collector is ready to be used, we can use a
# CollectorIterator to iterate over the result values it is collecting.
#
# The methods of a Pool object use all these concepts and expose
# them to their caller in a very simple way.
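#
# A minimal usage sketch (illustrative only): `double` is a hypothetical user
# function, and the import path assumes the package is importable as `tbb`.
#
#     from tbb.pool import Pool
#
#     def double(x):
#         return 2 * x
#
#     pool = Pool()
#     print(pool.map(double, range(10)))     # blocks until all results are ready
#     res = pool.apply_async(double, (21,))  # ApplyResult, returned immediately
#     print(res.get())                       # 42
#     pool.terminate()                       # terminate() is not called by atexit
#     pool.join()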

import sys
import threading
import traceback
from .api import *

__all__ = ["Pool", "TimeoutError"]
__doc__ = """
Standard Python Pool implementation based on Python API
for Intel(R) oneAPI Threading Building Blocks (oneTBB)
"""


class TimeoutError(Exception):
    """Raised when a result is not available within the given timeout"""
    pass


class Pool(object):
    """
    The Pool class provides the standard multiprocessing.Pool interface,
    which is mapped onto oneTBB tasks executing in its thread pool
    """

    def __init__(self, nworkers=0, name="Pool"):
        """
        \param nworkers (integer) number of worker threads to start
        \param name (string) prefix for the worker threads' name

        Both parameters are accepted for interface compatibility only;
        the worker threads are managed by oneTBB itself.
        """
        self._closed = False
        self._tasks = task_group()
        self._pool = [None,]*default_num_threads()  # Dask asks for len(_pool)

    def apply(self, func, args=(), kwds=dict()):
        """Equivalent of the apply() builtin function. It blocks until
        the result is ready."""
        return self.apply_async(func, args, kwds).get()

    def map(self, func, iterable, chunksize=None):
        """A parallel equivalent of the map() builtin function. It
        blocks until the result is ready.

        This method chops the iterable into a number of chunks which
        it submits to the pool as separate tasks. The
        (approximate) size of these chunks can be specified by setting
        chunksize to a positive integer."""
        return self.map_async(func, iterable, chunksize).get()

    def imap(self, func, iterable, chunksize=1):
        """
        An equivalent of itertools.imap().

        The chunksize argument is the same as the one used by the
        map() method. For very long iterables using a large value for
        chunksize can make the job complete much faster than
        using the default value of 1.

        Also if chunksize is 1 then the next() method of the iterator
        returned by the imap() method has an optional timeout
        parameter: next(timeout) will raise TimeoutError if
        the result cannot be returned within timeout seconds.
        """
        collector = OrderedResultCollector(as_iterator=True)
        self._create_sequences(func, iterable, chunksize, collector)
        return iter(collector)
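
    # Illustrative sketch (not part of the API): consuming imap() lazily,
    # given a Pool instance `pool` and a hypothetical user function `slow`.
    #
    #     it = pool.imap(slow, range(100))
    #     try:
    #         first = it.next(timeout=1.0)   # may raise TimeoutError
    #     except TimeoutError:
    #         ...                            # first result not ready yet
    #     for value in it:                   # remaining results, in order
    #         ...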

    def imap_unordered(self, func, iterable, chunksize=1):
        """The same as imap() except that the ordering of the results
        from the returned iterator should be considered
        arbitrary. (Only when there is only one worker process is the
        order guaranteed to be "correct".)"""
        collector = UnorderedResultCollector()
        self._create_sequences(func, iterable, chunksize, collector)
        return iter(collector)

    def apply_async(self, func, args=(), kwds=dict(), callback=None):
        """A variant of the apply() method which returns an
        ApplyResult object.

        If callback is specified then it should be a callable which
        accepts a single argument. When the result becomes ready,
        callback is applied to it (unless the call failed). callback
        should complete immediately since otherwise the thread which
        handles the results will get blocked."""
        assert not self._closed  # No lock here. We assume it's atomic...
        apply_result = ApplyResult(callback=callback)
        job = Job(func, args, kwds, apply_result)
        self._tasks.run(job)
        return apply_result
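
    # Illustrative sketch (not part of the API): apply_async() with a callback,
    # given a Pool instance `pool` and a hypothetical user function `compute`.
    # The callback runs on the worker that produced the result, so it must
    # return quickly.
    #
    #     results = []
    #     r = pool.apply_async(compute, (data,), callback=results.append)
    #     r.wait()   # or r.get(timeout) to re-raise a worker exception here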

    def map_async(self, func, iterable, chunksize=None, callback=None):
        """A variant of the map() method which returns an ApplyResult
        object.

        If callback is specified then it should be a callable which
        accepts a single argument. When the result becomes ready,
        callback is applied to it (unless the call failed). callback
        should complete immediately since otherwise the thread which
        handles the results will get blocked."""
        apply_result = ApplyResult(callback=callback)
        collector    = OrderedResultCollector(apply_result, as_iterator=False)
        if not self._create_sequences(func, iterable, chunksize, collector):
            apply_result._set_value([])
        return apply_result

    def imap_async(self, func, iterable, chunksize=None, callback=None):
        """A variant of the imap() method which returns an ApplyResult
        object that provides an iterator (whose next() method accepts
        a timeout argument).

        If callback is specified then it should be a callable which
        accepts a single argument. When the resulting iterator becomes
        ready, callback is applied to it (unless the call
        failed). callback should complete immediately since otherwise
        the thread which handles the results will get blocked."""
        apply_result = ApplyResult(callback=callback)
        collector    = OrderedResultCollector(apply_result, as_iterator=True)
        if not self._create_sequences(func, iterable, chunksize, collector):
            apply_result._set_value(iter([]))
        return apply_result

    def imap_unordered_async(self, func, iterable, chunksize=None,
                             callback=None):
        """A variant of the imap_unordered() method which returns an
        ApplyResult object that provides an iterator (whose next()
        method accepts a timeout argument).

        If callback is specified then it should be a callable which
        accepts a single argument. When the resulting iterator becomes
        ready, callback is applied to it (unless the call
        failed). callback should complete immediately since otherwise
        the thread which handles the results will get blocked."""
        apply_result = ApplyResult(callback=callback)
        collector    = UnorderedResultCollector(apply_result)
        if not self._create_sequences(func, iterable, chunksize, collector):
            apply_result._set_value(iter([]))
        return apply_result

    def close(self):
        """Prevents any more tasks from being submitted to the
        pool. Once all the tasks have been completed the worker
        processes will exit."""
        # No lock here. We assume it's sufficiently atomic...
        self._closed = True

    def terminate(self):
        """Stops the worker processes immediately without completing
        outstanding work. When the pool object is garbage collected
        terminate() will be called immediately."""
        self.close()
        self._tasks.cancel()

    def join(self):
        """Wait for the worker processes to exit. One must call
        close() or terminate() before using join()."""
        self._tasks.wait()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.join()

    def __del__(self):
        self.terminate()
        self.join()
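
    # Illustrative sketch (not part of the API): the shutdown sequence.  As
    # noted in the header comment, terminate() is not registered with atexit,
    # so shut the pool down explicitly (or rely on __del__ / the context
    # manager).
    #
    #     pool.close()       # no new submissions
    #     pool.join()        # wait for outstanding tasks to finish
    #     # or: pool.terminate() to cancel outstanding work, then pool.join()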

    def _create_sequences(self, func, iterable, chunksize, collector):
        """
        Creates callable objects to process iterable and pushes them onto
        the work queue. Each work unit is meant to process a slice of
        iterable of size chunksize. If collector is specified, then
        the ApplyResult objects associated with the jobs will notify
        collector when their result becomes ready.

        \return the list of callable objects (basically: JobSequences)
        pushed onto the work queue
        """
        assert not self._closed  # No lock here. We assume it's atomic...
        it_ = iter(iterable)
        exit_loop = False
        sequences = []
        while not exit_loop:
            seq = []
            for _ in range(chunksize or 1):
                try:
                    arg = next(it_)
                except StopIteration:
                    exit_loop = True
                    break
                apply_result = ApplyResult(collector)
                job = Job(func, (arg,), {}, apply_result)
                seq.append(job)
            if seq:
                sequences.append(JobSequence(seq))
        for t in sequences:
            self._tasks.run(t)
        return sequences
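
    # Illustrative sketch (not part of the API): with chunksize=2, an iterable
    # of five arguments is grouped into three JobSequences, each submitted to
    # the task group as a single unit of work:
    #
    #     args = [a0, a1, a2, a3, a4]
    #     ->  JobSequence([Job(f, a0), Job(f, a1)])
    #         JobSequence([Job(f, a2), Job(f, a3)])
    #         JobSequence([Job(f, a4)])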


class Job:
    """A work unit that corresponds to the execution of a single function"""

    def __init__(self, func, args, kwds, apply_result):
        """
        \param func/args/kwds used to call the function
        \param apply_result ApplyResult object that holds the result
        of the function call
        """
        self._func = func
        self._args = args
        self._kwds = kwds
        self._result = apply_result

    def __call__(self):
        """
        Call the function with the args/kwds and tell the ApplyResult
        that its result is ready. Correctly handles the exceptions
        happening during the execution of the function
        """
        try:
            result = self._func(*self._args, **self._kwds)
        except:
            self._result._set_exception()
        else:
            self._result._set_value(result)


class JobSequence:
    """A work unit that corresponds to the processing of a contiguous
    sequence of Job objects"""

    def __init__(self, jobs):
        self._jobs = jobs

    def __call__(self):
        """
        Call all the Job objects that have been specified
        """
        for job in self._jobs:
            job()


class ApplyResult(object):
    """An object associated with a Job object that holds its result:
    it's available during the whole life of the Job and afterwards, even
    before the Job has been processed. It's possible to use this object
    to wait for the result/exception of the job to become available.

    The result objects returned by the Pool::*_async() methods are of
    this type"""

    def __init__(self, collector=None, callback=None):
        """
        \param collector when not None, the notify_ready() method of
        the collector will be called when the result from the Job is
        ready
        \param callback when not None, function to call when the
        result becomes available (this is the parameter passed to the
        Pool::*_async() methods).
        """
        self._success = False
        self._event = threading.Event()
        self._data = None
        self._collector = None
        self._callback = callback

        if collector is not None:
            collector.register_result(self)
            self._collector = collector

    def get(self, timeout=None):
        """
        Returns the result when it arrives. If timeout is not None and
        the result does not arrive within timeout seconds then
        TimeoutError is raised. If the remote call raised an exception
        then that exception will be reraised by get().
        """
        if not self.wait(timeout):
            raise TimeoutError("Result not available within %fs" % timeout)
        if self._success:
            return self._data
        raise self._data[0](self._data[1]).with_traceback(self._data[2])

    def wait(self, timeout=None):
        """Waits until the result is available or until timeout
        seconds pass."""
        self._event.wait(timeout)
        return self._event.is_set()

    def ready(self):
        """Returns whether the call has completed."""
        return self._event.is_set()

    def successful(self):
        """Returns whether the call completed without raising an
        exception. Will raise AssertionError if the result is not
        ready."""
        assert self.ready()
        return self._success

    def _set_value(self, value):
        """Called by a Job object to signal that the result is ready,
        and to provide its value. The object will become
        ready and successful. The collector's notify_ready() method
        will be called, and the callback method too"""
        assert not self.ready()
        self._data = value
        self._success = True
        self._event.set()
        if self._collector is not None:
            self._collector.notify_ready(self)
        if self._callback is not None:
            try:
                self._callback(value)
            except:
                traceback.print_exc()

    def _set_exception(self):
        """Called by a Job object to tell that an exception occurred
        during the processing of the function. The object will become
        ready but not successful. The collector's notify_ready()
        method will be called, but NOT the callback method"""
        # traceback.print_exc()
        assert not self.ready()
        self._data = sys.exc_info()
        self._success = False
        self._event.set()
        if self._collector is not None:
            self._collector.notify_ready(self)
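
    # Illustrative sketch (not part of the API): inspecting an ApplyResult,
    # given a Pool instance `pool` and a hypothetical user function `may_fail`.
    #
    #     r = pool.apply_async(may_fail, (x,))
    #     if r.wait(timeout=5.0):      # True once the result is ready
    #         if r.successful():
    #             value = r.get()
    #         else:
    #             r.get()              # re-raises the worker's exception
    #     else:
    #         ...                      # still not ready after 5 seconds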


class AbstractResultCollector(object):
    """ABC to define the interface of a ResultCollector object. It is
    basically an object which knows which results it's waiting for,
    and which is able to get notified when they become available. It is
    also able to provide an iterator over the results when they are
    available"""

    def __init__(self, to_notify):
        """
        \param to_notify ApplyResult object to notify when all the
        results we're waiting for become available. Can be None.
        """
        self._to_notify = to_notify

    def register_result(self, apply_result):
        """Used to identify which results we're waiting for. Will
        always be called BEFORE the Jobs get submitted to the work
        queue, and BEFORE the __iter__ and _get_result() methods can
        be called
        \param apply_result ApplyResult object to add in our collection
        """
        raise NotImplementedError("Subclasses must implement it")

    def notify_ready(self, apply_result):
        """Called by the ApplyResult object (already registered via
        register_result()) that it is now ready (ie. the Job's result
        is available or an exception has been raised).
        \param apply_result ApplyResult object telling us that the job
        has been processed
        """
        raise NotImplementedError("Subclasses must implement it")

    def _get_result(self, idx, timeout=None):
        """Called by the CollectorIterator object to retrieve the
        result values one after another (order defined by the
        implementation)
        \param idx The index of the result we want, wrt collector's order
        \param timeout integer telling how long to wait (in seconds)
        for the result at index idx to be available, or None (wait
        forever)
        """
        raise NotImplementedError("Subclasses must implement it")

    def __iter__(self):
        """Return a new CollectorIterator object for this collector"""
        return CollectorIterator(self)


class CollectorIterator(object):
    """An iterator over the result values available in the given
    collector object, equipped with an extended next() method accepting
    a timeout argument. Created by the
    AbstractResultCollector::__iter__() method"""

    def __init__(self, collector):
        """\param collector AbstractResultCollector instance"""
        self._collector = collector
        self._idx = 0

    def __iter__(self):
        return self

    def next(self, timeout=None):
        """Return the next result value in the sequence. Raise
        StopIteration at the end. Can raise the exception raised by
        the Job"""
        try:
            apply_result = self._collector._get_result(self._idx, timeout)
        except IndexError:
            # Reset for next time
            self._idx = 0
            raise StopIteration
        except:
            self._idx = 0
            raise
        self._idx += 1
        assert apply_result.ready()
        return apply_result.get(0)

    def __next__(self):
        return self.next()


class UnorderedResultCollector(AbstractResultCollector):
    """An AbstractResultCollector implementation that collects the
    values of the ApplyResult objects in the order they become ready. The
    CollectorIterator object returned by __iter__() will iterate over
    them in the order they become ready"""

    def __init__(self, to_notify=None):
        """
        \param to_notify ApplyResult object to notify when all the
        results we're waiting for become available. Can be None.
        """
        AbstractResultCollector.__init__(self, to_notify)
        self._cond = threading.Condition()
        self._collection = []
        self._expected = 0

    def register_result(self, apply_result):
        """Used to identify which results we're waiting for. Will
        always be called BEFORE the Jobs get submitted to the work
        queue, and BEFORE the __iter__ and _get_result() methods can
        be called
        \param apply_result ApplyResult object to add in our collection
        """
        self._expected += 1

    def _get_result(self, idx, timeout=None):
        """Called by the CollectorIterator object to retrieve the
        result values one after another, in the order the results have
        become available.
        \param idx The index of the result we want, wrt collector's order
        \param timeout integer telling how long to wait (in seconds)
        for the result at index idx to be available, or None (wait
        forever)
        """
        self._cond.acquire()
        try:
            if idx >= self._expected:
                raise IndexError
            elif idx < len(self._collection):
                return self._collection[idx]
            elif idx != len(self._collection):
                # Violation of the sequence protocol
                raise IndexError()
            else:
                self._cond.wait(timeout=timeout)
                try:
                    return self._collection[idx]
                except IndexError:
                    # Still not added!
                    raise TimeoutError("Timeout while waiting for results")
        finally:
            self._cond.release()

    def notify_ready(self, apply_result=None):
        """Called by the ApplyResult object (already registered via
        register_result()) that it is now ready (ie. the Job's result
        is available or an exception has been raised).
        \param apply_result ApplyResult object telling us that the job
        has been processed
        """
        first_item = False
        self._cond.acquire()
        try:
            self._collection.append(apply_result)
            first_item = (len(self._collection) == 1)

            self._cond.notify_all()
        finally:
            self._cond.release()

        if first_item and self._to_notify is not None:
            self._to_notify._set_value(iter(self))
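
    # Illustrative sketch (not part of the API): with this collector the
    # iterator yields results in completion order, not submission order, and
    # the ApplyResult passed to *_async() becomes ready as soon as the first
    # result arrives. Given a Pool instance `pool` and a hypothetical user
    # function `f`:
    #
    #     r = pool.imap_unordered_async(f, [a0, a1, a2])
    #     for value in r.get():   # iterator available after the first result
    #         ...                 # order depends on which Job finished first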


class OrderedResultCollector(AbstractResultCollector):
    """An AbstractResultCollector implementation that collects the
    values of the ApplyResult objects in the order they have been
    submitted. The CollectorIterator object returned by __iter__()
    will iterate over them in the order they have been submitted"""

    def __init__(self, to_notify=None, as_iterator=True):
        """
        \param to_notify ApplyResult object to notify when all the
        results we're waiting for become available. Can be None.
        \param as_iterator boolean telling whether the result value
        set on to_notify should be an iterator (available as soon as
        the first result has arrived) or a list (available only after
        the last result has arrived)
        """
        AbstractResultCollector.__init__(self, to_notify)
        self._results = []
        self._lock = threading.Lock()
        self._remaining = 0
        self._as_iterator = as_iterator

    def register_result(self, apply_result):
        """Used to identify which results we're waiting for. Will
        always be called BEFORE the Jobs get submitted to the work
        queue, and BEFORE the __iter__ and _get_result() methods can
        be called
        \param apply_result ApplyResult object to add in our collection
        """
        self._results.append(apply_result)
        self._remaining += 1

    def _get_result(self, idx, timeout=None):
        """Called by the CollectorIterator object to retrieve the
        result values one after another (order defined by the
        implementation)
        \param idx The index of the result we want, wrt collector's order
        \param timeout integer telling how long to wait (in seconds)
        for the result at index idx to be available, or None (wait
        forever)
        """
        res = self._results[idx]
        if not res.wait(timeout):
            # Keep the documented imap()/next(timeout) contract and mirror
            # UnorderedResultCollector: signal the timeout instead of
            # returning a result that is not ready yet.
            raise TimeoutError("Timeout while waiting for results")
        return res

    def notify_ready(self, apply_result):
        """Called by the ApplyResult object (already registered via
        register_result()) that it is now ready (ie. the Job's result
        is available or an exception has been raised).
        \param apply_result ApplyResult object telling us that the job
        has been processed
        """
        got_first = False
        got_last = False
        self._lock.acquire()
        try:
            assert self._remaining > 0
            # Nothing delivered yet means this is the first notification
            got_first = (len(self._results) == self._remaining)
            self._remaining -= 1
            got_last = (self._remaining == 0)
        finally:
            self._lock.release()

        if self._to_notify is not None:
            if self._as_iterator and got_first:
                self._to_notify._set_value(iter(self))
            elif not self._as_iterator and got_last:
                try:
                    lst = [r.get(0) for r in self._results]
                except:
                    self._to_notify._set_exception()
                else:
                    self._to_notify._set_value(lst)
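
    # Illustrative sketch (not part of the API): the as_iterator flag is what
    # distinguishes imap_async() (iterator, ready after the first result) from
    # map_async() (plain list, ready only after the last result). Given a Pool
    # instance `pool` and a hypothetical user function `f`:
    #
    #     pool.map_async(f, args).get()        # -> [f(a) for a in args]
    #     it = pool.imap_async(f, args).get()  # -> iterator, submission order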
629