1 |
|
2 |
|
3 | import EventEmitter from 'events'
|
4 | import autoBind from 'auto-bind'
|
5 |
|
6 | import AsyncQueue, { QueueClosedError } from './AsyncQueue'
|
7 | import { OneTimeBroadcastEvent, Semaphore } from './asyncSyncUtils'
|
8 | import { ExtendedError } from './errorUtils'
|
9 |
|
10 |
|
11 | type ThreadPoolOptions<T, R> = {
|
12 | name?: ?string,
|
13 | threads?: ?number,
|
14 | items?: ?Array<T>,
|
15 | task: (item: T) => Promise<R>,
|
16 | errorHandler?: ?(err: Error) => void,
|
17 | queueMaxSize?: ?number,
|
18 | }
|
19 |
|
20 | type Task<R> = {
|
21 | func: () => Promise<R>,
|
22 | index: number,
|
23 | }
|
24 |
|
25 |
|
26 |
|
27 |
|
28 |
|
29 |
|
30 |
|
31 |
|
32 | export default class ThreadPool<T, R> extends EventEmitter {
|
33 |
|
34 | queuedCount: number = 0
|
35 | startedCount: number = 0
|
36 | endedCount: number = 0
|
37 |
|
38 | _options: ThreadPoolOptions<T, R>
|
39 | _errorHandler: (err: Error) => void
|
40 | _uncaughtErrors: Array<Error> = []
|
41 | _closed: boolean = false
|
42 | _queuedTasks: AsyncQueue<{ func: () => Promise<R>, index: number }>
|
43 | _threadsSemaphore: Semaphore
|
44 | _completeEvent: OneTimeBroadcastEvent
|
45 | _allTasksCompleteOrSomeFailedEvent: OneTimeBroadcastEvent
|
46 | _res: Array<R> = []
|
47 |
|
48 |
|
49 | static async run(options: ThreadPoolOptions<T, R>): Promise<Array<R>> {
|
50 | options.errorHandler = options.errorHandler || ((err) => { throw err })
|
51 | const tp = new ThreadPool(options)
|
52 | if (options.items) {
|
53 | await tp.queueItems(options.items)
|
54 | }
|
55 | return await tp.runAllQueued()
|
56 | }
|
57 |
|
58 | static async all(items: Array<T>, task: (item: T) => Promise<R>, threads: ?number): Promise<Array<R>> {
|
59 | const tp = new ThreadPool({
|
60 | task: task,
|
61 | threads: threads,
|
62 | items: items,
|
63 | })
|
64 | if (items) {
|
65 | await tp.queueItems(items)
|
66 | }
|
67 | return await tp.runAllQueued()
|
68 | }
|
69 |
|
70 |
|
71 | |
72 |
|
73 |
|
74 | constructor(options: ThreadPoolOptions<T, R>) {
|
75 | super()
|
76 |
|
77 | this._options = options
|
78 | this._options.threads = this._options.threads || Infinity
|
79 |
|
80 | this._errorHandler = this._options.errorHandler || ((err) => { throw err })
|
81 |
|
82 | this._queuedTasks = new AsyncQueue({
|
83 | name: options.name,
|
84 | maxSize: options.queueMaxSize,
|
85 | })
|
86 | this._threadsSemaphore = new Semaphore(this._options.threads)
|
87 | this._completeEvent = new OneTimeBroadcastEvent(false)
|
88 | this._allTasksCompleteOrSomeFailedEvent = new OneTimeBroadcastEvent(false)
|
89 |
|
90 | autoBind(this)
|
91 | }
|
92 |
|
93 | |
94 |
|
95 |
|
96 | async queueItem(item: T): Promise<void> {
|
97 | if (this._closed)
|
98 | throw new Error(`Trying to queue a job to a closed ThreadPool`)
|
99 |
|
100 | const index = this.queuedCount++
|
101 | await this._queuedTasks.enqueue({
|
102 | func: async () => this._res[index] = await this._options.task(item),
|
103 | index: index,
|
104 | })
|
105 | }
|
106 |
|
107 | |
108 |
|
109 |
|
110 | async queueItems(queueItem: Array<T>): Promise<void> {
|
111 | for (const item of queueItem) {
|
112 | await this.queueItem(item)
|
113 | }
|
114 | }
|
115 |
|
116 | |
117 |
|
118 |
|
119 |
|
120 |
|
121 | async run() {
|
122 | try {
|
123 | while (true) {
|
124 |
|
125 |
|
126 |
|
127 | let task
|
128 | try {
|
129 | task = await this._queuedTasks.dequeue()
|
130 | } catch (err) {
|
131 | if (err instanceof QueueClosedError)
|
132 | break
|
133 | throw err
|
134 | }
|
135 |
|
136 |
|
137 | await this._threadsSemaphore.enter()
|
138 | if (this._uncaughtErrors.length) {
|
139 | break
|
140 | }
|
141 |
|
142 | this._allTasksCompleteOrSomeFailedEvent.reset()
|
143 |
|
144 |
|
145 | this._runTask(task)
|
146 | }
|
147 |
|
148 |
|
149 | if (this.startedCount > 0) {
|
150 | await this._allTasksCompleteOrSomeFailedEvent.wait()
|
151 | }
|
152 |
|
153 | this._throwUncaughtErrors()
|
154 | } finally {
|
155 | this._completeEvent.signal()
|
156 | }
|
157 | }
|
158 |
|
159 | |
160 |
|
161 |
|
162 |
|
163 |
|
164 | startRun() {
|
165 | (async () => {
|
166 | try {
|
167 | await this.run()
|
168 | } catch (err) {}
|
169 | })()
|
170 | }
|
171 |
|
172 | |
173 |
|
174 |
|
175 | close() {
|
176 | this._closed = true
|
177 | this._queuedTasks.close()
|
178 |
|
179 | }
|
180 |
|
181 | |
182 |
|
183 |
|
184 | async runAllQueued(): Promise<Array<R>> {
|
185 | this.close()
|
186 | await this.run()
|
187 | return this._res
|
188 | }
|
189 |
|
190 | |
191 |
|
192 |
|
193 | async waitComplete(): Promise<Array<R>> {
|
194 | await this._completeEvent.wait()
|
195 | this._throwUncaughtErrors()
|
196 | return this._res
|
197 | }
|
198 |
|
199 | |
200 |
|
201 |
|
202 | async closeAndWaitComplete(): Promise<Array<R>> {
|
203 | this.close()
|
204 | return await this.waitComplete()
|
205 | }
|
206 |
|
207 |
|
208 |
|
209 | _throwUncaughtErrors() {
|
210 | if (this._uncaughtErrors.length > 0) {
|
211 | throw new ExtendedError(`Errors were thrown during execution of ThreadPool`, {
|
212 | threadPoolName: this._options.name,
|
213 | errorCount: this._uncaughtErrors.length,
|
214 | errorMessages: this._uncaughtErrors.map(e => e.message),
|
215 | uncaughtErrors: this._uncaughtErrors,
|
216 | })
|
217 | }
|
218 | }
|
219 |
|
220 | async _runTask(task: Task<R>) {
|
221 | try {
|
222 | this.startedCount++
|
223 | await task.func()
|
224 | } catch (err) {
|
225 | try {
|
226 | this._errorHandler(err)
|
227 | } catch (err2) {
|
228 | this._uncaughtErrors.push(err2)
|
229 | await this._allTasksCompleteOrSomeFailedEvent.signal()
|
230 | }
|
231 | } finally {
|
232 | this.endedCount++
|
233 |
|
234 | this.emit('progress', { endedCount: this.endedCount })
|
235 | this._threadsSemaphore.exit()
|
236 | if (this._threadsSemaphore.takenCount === 0) {
|
237 | this._allTasksCompleteOrSomeFailedEvent.signal()
|
238 | }
|
239 | }
|
240 | }
|
241 | }
|