partition.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 
00003 // Copyright (C) 2007, 2008, 2009, 2010 Free Software Foundation, Inc.
00004 //
00005 // This file is part of the GNU ISO C++ Library.  This library is free
00006 // software; you can redistribute it and/or modify it under the terms
00007 // of the GNU General Public License as published by the Free Software
00008 // Foundation; either version 3, or (at your option) any later
00009 // version.
00010 
00011 // This library is distributed in the hope that it will be useful, but
00012 // WITHOUT ANY WARRANTY; without even the implied warranty of
00013 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00014 // General Public License for more details.
00015 
00016 // Under Section 7 of GPL version 3, you are granted additional
00017 // permissions described in the GCC Runtime Library Exception, version
00018 // 3.1, as published by the Free Software Foundation.
00019 
00020 // You should have received a copy of the GNU General Public License and
00021 // a copy of the GCC Runtime Library Exception along with this program;
00022 // see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
00023 // <http://www.gnu.org/licenses/>.
00024 
00025 /** @file parallel/partition.h
00026  *  @brief Parallel implementation of std::partition(),
00027  *  std::nth_element(), and std::partial_sort().
00028  *  This file is a GNU parallel extension to the Standard C++ Library.
00029  */
00030 
00031 // Written by Johannes Singler and Felix Putze.
00032 
00033 #ifndef _GLIBCXX_PARALLEL_PARTITION_H
00034 #define _GLIBCXX_PARALLEL_PARTITION_H 1
00035 
00036 #include <parallel/basic_iterator.h>
00037 #include <parallel/sort.h>
00038 #include <parallel/random_number.h>
00039 #include <bits/stl_algo.h>
00040 #include <parallel/parallel.h>
00041 
00042 /** @brief Qualifier for the variables shared between threads in __parallel_partition (expands to volatile; #undef'd at the end of this header). volatile is not a synchronization primitive: ordering there comes from the OpenMP lock, barriers and flush. */
00043 #define _GLIBCXX_VOLATILE volatile
00044 
00045 namespace __gnu_parallel
00046 {
00047   /** @brief Parallel implementation of std::partition.
00048     *  @param __begin Begin iterator of input sequence to split.
00049     *  @param __end End iterator of input sequence to split.
00050     *  @param __pred Partition predicate, possibly including some kind
00051     *         of pivot.
00052     *  @param __num_threads Maximum number of threads to use for this task.
00053     *  @return Number of elements not fulfilling the predicate. */
00054   template<typename _RAIter, typename _Predicate>
00055     typename std::iterator_traits<_RAIter>::difference_type
00056     __parallel_partition(_RAIter __begin, _RAIter __end,
00057              _Predicate __pred, _ThreadIndex __num_threads)
00058     {
00059       typedef std::iterator_traits<_RAIter> _TraitsType;
00060       typedef typename _TraitsType::value_type _ValueType;
00061       typedef typename _TraitsType::difference_type _DifferenceType;
00062 
00063       _DifferenceType __n = __end - __begin;
00064 
00065       _GLIBCXX_CALL(__n)
00066 
00067       const _Settings& __s = _Settings::get();
00068 
00069       // Shared.
00070       _GLIBCXX_VOLATILE _DifferenceType __left = 0, __right = __n - 1;
00071       _GLIBCXX_VOLATILE _DifferenceType __leftover_left, __leftover_right;
00072       _GLIBCXX_VOLATILE _DifferenceType __leftnew, __rightnew;
00073 
00074       bool* __reserved_left = NULL, * __reserved_right = NULL;
00075 
00076       _DifferenceType __chunk_size = __s.partition_chunk_size;
00077 
00078       omp_lock_t __result_lock;
00079       omp_init_lock(&__result_lock);
00080 
00081       //at least two chunks per thread
00082       if (__right - __left + 1 >= 2 * __num_threads * __chunk_size)
00083 #       pragma omp parallel num_threads(__num_threads)
00084     {
00085 #         pragma omp single
00086       {
00087         __num_threads = omp_get_num_threads();
00088         __reserved_left = new bool[__num_threads];
00089         __reserved_right = new bool[__num_threads];
00090 
00091         if (__s.partition_chunk_share > 0.0)
00092           __chunk_size = std::max<_DifferenceType>
00093         (__s.partition_chunk_size, (double)__n 
00094          * __s.partition_chunk_share / (double)__num_threads);
00095         else
00096           __chunk_size = __s.partition_chunk_size;
00097       }
00098 
00099       while (__right - __left + 1 >= 2 * __num_threads * __chunk_size)
00100         {
00101 #             pragma omp single
00102           {
00103         _DifferenceType __num_chunks = ((__right - __left + 1) 
00104                         / __chunk_size);
00105 
00106         for (_ThreadIndex __r = 0; __r < __num_threads; ++__r)
00107           {
00108             __reserved_left[__r] = false;
00109             __reserved_right[__r] = false;
00110           }
00111         __leftover_left = 0;
00112         __leftover_right = 0;
00113           } //implicit barrier
00114 
00115           // Private.
00116           _DifferenceType __thread_left, __thread_left_border,
00117                       __thread_right, __thread_right_border;
00118           __thread_left = __left + 1;
00119 
00120           // Just to satisfy the condition below.
00121           __thread_left_border = __thread_left - 1;
00122           __thread_right = __n - 1;
00123           __thread_right_border = __thread_right + 1;
00124 
00125           bool __iam_finished = false;
00126           while (!__iam_finished)
00127         {
00128           if (__thread_left > __thread_left_border)
00129             {
00130               omp_set_lock(&__result_lock);
00131               if (__left + (__chunk_size - 1) > __right)
00132             __iam_finished = true;
00133               else
00134             {
00135               __thread_left = __left;
00136               __thread_left_border = __left + (__chunk_size - 1);
00137               __left += __chunk_size;
00138             }
00139               omp_unset_lock(&__result_lock);
00140             }
00141 
00142           if (__thread_right < __thread_right_border)
00143             {
00144               omp_set_lock(&__result_lock);
00145               if (__left > __right - (__chunk_size - 1))
00146             __iam_finished = true;
00147               else
00148             {
00149               __thread_right = __right;
00150               __thread_right_border = __right - (__chunk_size - 1);
00151               __right -= __chunk_size;
00152             }
00153               omp_unset_lock(&__result_lock);
00154             }
00155 
00156           if (__iam_finished)
00157             break;
00158 
00159           // Swap as usual.
00160           while (__thread_left < __thread_right)
00161             {
00162               while (__pred(__begin[__thread_left])
00163                  && __thread_left <= __thread_left_border)
00164             ++__thread_left;
00165               while (!__pred(__begin[__thread_right])
00166                  && __thread_right >= __thread_right_border)
00167             --__thread_right;
00168 
00169               if (__thread_left > __thread_left_border
00170               || __thread_right < __thread_right_border)
00171             // Fetch new chunk(__s).
00172             break;
00173 
00174               std::swap(__begin[__thread_left],
00175                 __begin[__thread_right]);
00176               ++__thread_left;
00177               --__thread_right;
00178             }
00179         }
00180 
00181           // Now swap the leftover chunks to the right places.
00182           if (__thread_left <= __thread_left_border)
00183 #               pragma omp atomic
00184         ++__leftover_left;
00185           if (__thread_right >= __thread_right_border)
00186 #               pragma omp atomic
00187         ++__leftover_right;
00188 
00189 #             pragma omp barrier
00190 
00191 #             pragma omp single
00192           {
00193         __leftnew = __left - __leftover_left * __chunk_size;
00194         __rightnew = __right + __leftover_right * __chunk_size;
00195           }
00196 
00197 #             pragma omp barrier
00198 
00199           // <=> __thread_left_border + (__chunk_size - 1) >= __leftnew
00200           if (__thread_left <= __thread_left_border
00201           && __thread_left_border >= __leftnew)
00202         {
00203           // Chunk already in place, reserve spot.
00204         __reserved_left[(__left - (__thread_left_border + 1))
00205                 / __chunk_size] = true;
00206         }
00207 
00208           // <=> __thread_right_border - (__chunk_size - 1) <= __rightnew
00209           if (__thread_right >= __thread_right_border
00210           && __thread_right_border <= __rightnew)
00211         {
00212           // Chunk already in place, reserve spot.
00213           __reserved_right[((__thread_right_border - 1) - __right)
00214                    / __chunk_size] = true;
00215         }
00216 
00217 #             pragma omp barrier
00218 
00219           if (__thread_left <= __thread_left_border
00220           && __thread_left_border < __leftnew)
00221         {
00222           // Find spot and swap.
00223           _DifferenceType __swapstart = -1;
00224           omp_set_lock(&__result_lock);
00225           for (_DifferenceType __r = 0; __r < __leftover_left; ++__r)
00226             if (!__reserved_left[__r])
00227               {
00228             __reserved_left[__r] = true;
00229             __swapstart = __left - (__r + 1) * __chunk_size;
00230             break;
00231               }
00232           omp_unset_lock(&__result_lock);
00233 
00234 #if _GLIBCXX_ASSERTIONS
00235           _GLIBCXX_PARALLEL_ASSERT(__swapstart != -1);
00236 #endif
00237 
00238           std::swap_ranges(__begin + __thread_left_border
00239                    - (__chunk_size - 1),
00240                    __begin + __thread_left_border + 1,
00241                    __begin + __swapstart);
00242         }
00243 
00244           if (__thread_right >= __thread_right_border
00245           && __thread_right_border > __rightnew)
00246         {
00247           // Find spot and swap
00248           _DifferenceType __swapstart = -1;
00249           omp_set_lock(&__result_lock);
00250           for (_DifferenceType __r = 0; __r < __leftover_right; ++__r)
00251             if (!__reserved_right[__r])
00252               {
00253             __reserved_right[__r] = true;
00254             __swapstart = __right + __r * __chunk_size + 1;
00255             break;
00256               }
00257           omp_unset_lock(&__result_lock);
00258 
00259 #if _GLIBCXX_ASSERTIONS
00260           _GLIBCXX_PARALLEL_ASSERT(__swapstart != -1);
00261 #endif
00262 
00263           std::swap_ranges(__begin + __thread_right_border,
00264                    __begin + __thread_right_border
00265                    + __chunk_size, __begin + __swapstart);
00266           }
00267 #if _GLIBCXX_ASSERTIONS
00268 #             pragma omp barrier
00269 
00270 #             pragma omp single
00271           {
00272         for (_DifferenceType __r = 0; __r < __leftover_left; ++__r)
00273           _GLIBCXX_PARALLEL_ASSERT(__reserved_left[__r]);
00274         for (_DifferenceType __r = 0; __r < __leftover_right; ++__r)
00275           _GLIBCXX_PARALLEL_ASSERT(__reserved_right[__r]);
00276           }
00277 
00278 #             pragma omp barrier
00279 #endif
00280 
00281 #             pragma omp barrier
00282 
00283           __left = __leftnew;
00284           __right = __rightnew;
00285         }
00286 
00287 #           pragma omp flush(__left, __right)
00288     } // end "recursion" //parallel
00289 
00290         _DifferenceType __final_left = __left, __final_right = __right;
00291 
00292     while (__final_left < __final_right)
00293       {
00294         // Go right until key is geq than pivot.
00295         while (__pred(__begin[__final_left])
00296            && __final_left < __final_right)
00297           ++__final_left;
00298 
00299         // Go left until key is less than pivot.
00300         while (!__pred(__begin[__final_right])
00301            && __final_left < __final_right)
00302           --__final_right;
00303 
00304         if (__final_left == __final_right)
00305           break;
00306         std::swap(__begin[__final_left], __begin[__final_right]);
00307         ++__final_left;
00308         --__final_right;
00309       }
00310 
00311     // All elements on the left side are < piv, all elements on the
00312     // right are >= piv
00313     delete[] __reserved_left;
00314     delete[] __reserved_right;
00315 
00316     omp_destroy_lock(&__result_lock);
00317 
00318     // Element "between" __final_left and __final_right might not have
00319     // been regarded yet
00320     if (__final_left < __n && !__pred(__begin[__final_left]))
00321       // Really swapped.
00322       return __final_left;
00323     else
00324       return __final_left + 1;
00325     }
00326 
00327   /**
00328     *  @brief Parallel implementation of std::nth_element().
00329     *  @param __begin Begin iterator of input sequence.
00330     *  @param __nth _Iterator of element that must be in position afterwards.
00331     *  @param __end End iterator of input sequence.
00332     *  @param __comp Comparator.
00333     */
00334   template<typename _RAIter, typename _Compare>
00335     void 
00336     __parallel_nth_element(_RAIter __begin, _RAIter __nth, 
00337                _RAIter __end, _Compare __comp)
00338     {
00339       typedef std::iterator_traits<_RAIter> _TraitsType;
00340       typedef typename _TraitsType::value_type _ValueType;
00341       typedef typename _TraitsType::difference_type _DifferenceType;
00342 
00343       _GLIBCXX_CALL(__end - __begin)
00344 
00345       _RAIter __split;
00346       _RandomNumber __rng;
00347 
00348       const _Settings& __s = _Settings::get();
00349       _DifferenceType __minimum_length = std::max<_DifferenceType>(2,
00350         std::max(__s.nth_element_minimal_n, __s.partition_minimal_n));
00351 
00352       // Break if input range to small.
00353       while (static_cast<_SequenceIndex>(__end - __begin) >= __minimum_length)
00354     {
00355           _DifferenceType __n = __end - __begin;
00356 
00357           _RAIter __pivot_pos = __begin + __rng(__n);
00358 
00359           // Swap __pivot_pos value to end.
00360           if (__pivot_pos != (__end - 1))
00361             std::swap(*__pivot_pos, *(__end - 1));
00362           __pivot_pos = __end - 1;
00363 
00364           // _Compare must have first_value_type, second_value_type,
00365           // result_type
00366           // _Compare ==
00367           // __gnu_parallel::_Lexicographic<S, int,
00368       //                                __gnu_parallel::_Less<S, S> >
00369           // __pivot_pos == std::pair<S, int>*
00370           __gnu_parallel::__binder2nd<_Compare, _ValueType, _ValueType, bool>
00371             __pred(__comp, *__pivot_pos);
00372 
00373           // Divide, leave pivot unchanged in last place.
00374           _RAIter __split_pos1, __split_pos2;
00375           __split_pos1 = __begin + __parallel_partition(__begin, __end - 1,
00376                             __pred,
00377                             __get_max_threads());
00378 
00379           // Left side: < __pivot_pos; __right side: >= __pivot_pos
00380 
00381           // Swap pivot back to middle.
00382           if (__split_pos1 != __pivot_pos)
00383             std::swap(*__split_pos1, *__pivot_pos);
00384           __pivot_pos = __split_pos1;
00385 
00386           // In case all elements are equal, __split_pos1 == 0
00387           if ((__split_pos1 + 1 - __begin) < (__n >> 7)
00388               || (__end - __split_pos1) < (__n >> 7))
00389             {
00390               // Very unequal split, one part smaller than one 128th
00391               // elements not strictly larger than the pivot.
00392               __gnu_parallel::__unary_negate<__gnu_parallel::
00393         	__binder1st<_Compare, _ValueType,
00394                     _ValueType, bool>, _ValueType>
00395             __pred(__gnu_parallel::__binder1st<_Compare, _ValueType,
00396                _ValueType, bool>(__comp, *__pivot_pos));
00397 
00398               // Find other end of pivot-equal range.
00399               __split_pos2 = __gnu_sequential::partition(__split_pos1 + 1,
00400                              __end, __pred);
00401             }
00402           else
00403             // Only skip the pivot.
00404             __split_pos2 = __split_pos1 + 1;
00405 
00406           // Compare iterators.
00407           if (__split_pos2 <= __nth)
00408             __begin = __split_pos2;
00409           else if (__nth < __split_pos1)
00410             __end = __split_pos1;
00411           else
00412             break;
00413     }
00414 
00415       // Only at most _Settings::partition_minimal_n __elements __left.
00416       __gnu_sequential::nth_element(__begin, __nth, __end, __comp);
00417     }
00418 
00419   /** @brief Parallel implementation of std::partial_sort().
00420   *  @param __begin Begin iterator of input sequence.
00421   *  @param __middle Sort until this position.
00422   *  @param __end End iterator of input sequence.
00423   *  @param __comp Comparator. */
00424   template<typename _RAIter, typename _Compare>
00425     void
00426     __parallel_partial_sort(_RAIter __begin,
00427                 _RAIter __middle,
00428                 _RAIter __end, _Compare __comp)
00429     {
00430       __parallel_nth_element(__begin, __middle, __end, __comp);
00431       std::sort(__begin, __middle, __comp);
00432     }
00433 
00434 } //namespace __gnu_parallel
00435 
00436 #undef _GLIBCXX_VOLATILE
00437 
00438 #endif /* _GLIBCXX_PARALLEL_PARTITION_H */