前言

sync.waitgroup 是一种用于同步的工具,它允许一组goroutine在继续执行之前等待其他goroutine完成。WaitGroup.Done方法和Wait方法共同工作来实现这种同步。

  • WaitGroup.Add方法用于初始化计数器,表示需要等待的goroutine的数量。
  • 每个goroutine完成时调用WaitGroup.Done方法,减少计数器。
  • WaitGroup.Wait方法阻塞调用它的goroutine,直到WaitGroup维护的计数器为零。(state字段)
    官方文档中有这样一段注释:
    In the terminology of the Go memory model, a call to WaitGroup.Done “synchronizes before” the return of any Wait call that it unblocks.
    synchronizes before这个术语是Go内存模型(Go Memory Model, GMM)的一部分,它描述了程序中不同部分之间的内存可见性。
    在Go内存模型中,如果一个操作synchronizes before另一个操作,那么第一个操作对内存的修改对第二个操作是可见的。这意味着,如果一个goroutine执行了Done方法,那么任何被这个Done调用解除阻塞的Wait调用都能看到Done之前的所有内存修改。
    也即: 当你在WaitGroup上调用Done时,WaitGroup计数器减一。如果WaitGroup的计数器因此变为零,那么所有因为WaitGroup计数器不为零而阻塞在Wait调用上的goroutine将被解除阻塞。根据Go内存模型,这些goroutineWait返回后能够看到Done调用之前的所有内存状态,包括对共享变量的修改。

使用

源码解析

sync.WaitGroup的定义如下,提供了Add|Wait|Done三个接口。

1
2
3
4
5
6
type WaitGroup struct {
noCopy noCopy

state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
sema uint32
}

WaitGroup中包含三个字段:

  • noCopy用以标识WaitGroup被使用后不允许被copy
  • state是64位无符号整形数,其中高位32表示计数器,低32位表示等待的个数。后续代码中可以看到这样的操作v := int32(state >> 32)w := uint32(state);其中>> 表示算数右移。以此获取到state的高32位,也即计数器值。而低w则获取其低32位表示阻塞的协程数量。
    当counter计数器为0时,wait操作不会增加等待的协程数量。
  • sema: 32位无符号数

WaitGroup.Add

WaitGroup.Add用来增加或者减少计数器,delta可以是整数也可以是负数。如果调用Add后,计数器变为零,那么所有因为 Wait 方法而阻塞的 goroutine 将被释放.
如果Add方法使得计数器为0,则panic
调用时机:

  • 当计数器为零时,如果 delta是正数,那么这个 Add调用必须在调用Wait方法之前发生。因为如果计数器为0时, Wait方法会立即返回。
  • 当计数器大于0,或者delta为负数,则可以在任意时刻调用Add
    所以,通常而言,Add方法的调用应该在创建goroutine或等待其他事件之前执行。这样可以确保WaitGroup正确地跟踪并发操作的数量.

WaitGroup可以被重用,但是前提是所有wait返回后才能被新的Add调用。

如下是WaitGroup.Add的源码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
func (wg *WaitGroup) Add(delta int) {
if race.Enabled {
if delta < 0 {
race.ReleaseMerge(unsafe.Pointer(wg))
}
race.Disable()
defer race.Enable()
}
state := wg.state.Add(uint64(delta) << 32)
v := int32(state >> 32)
w := uint32(state)
if race.Enabled && delta > 0 && v == int32(delta) {
// The first increment must be synchronized with Wait.
// Need to model this as a read, because there can be
// several concurrent wg.counter transitions from 0.
race.Read(unsafe.Pointer(&wg.sema))
}
if v < 0 {
panic("sync: negative WaitGroup counter")
}
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
if v > 0 || w == 0 {
return
}
// This goroutine has set counter to 0 when waiters > 0.
// Now there can't be concurrent mutations of state:
// - Adds must not happen concurrently with Wait,
// - Wait does not increment waiters if it sees counter == 0.
// Still do a cheap sanity check to detect WaitGroup misuse.
if wg.state.Load() != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// Reset waiters count to 0.
wg.state.Store(0)
for ; w != 0; w-- {
runtime_Semrelease(&wg.sema, false, 0)
}
}

1-8行,判断当前程序是否允许竞争,如果允许且当前传入的delta小于0,则调用race.ReleaseMerge传入当前WaitGroup的指针。
ReleaseMerge的作用是告知race detector,当前的减少计数器操作(Release)与 Wait 方法中的等待操作是相关的,需要同步。这样可以确保race detector能够正确地检测到可能的数据竞争
第6行,关闭数据竞争。
第7行,在函数返回时重新启用数据竞争

第9行, 使用 wg.state.Add(uint64(delta) << 32)增加计数器的值
第10-11行,分别用v和w描述计数器和等待的协程数。
第12-14行,如果开启了数据竞争,且delta大于0,且当前计数器的值等于delta,则调用race.Read并传入WaitGroup实例sema字段的指针。sema是用于同步的信号量。
race.Read函数用于告诉race detector,当前的增加计数器操作(Add 方法)需要与Wait方法中的等待操作同步。这是因为可能有多个goroutine同时尝试将WaitGroup 的计数器从0增加到正数,这种情况下需要确保内存的可见性。

第15-21行,如果 计数器小于0,则panic;如果等待协程数量不为0,且delta>0并且计数器的值和delta相同,则panic。
第22行,如果计数器大于0,但是等待协程数为0,则返回。
第23行 检查waitGroup.state的状态是否符合预期,如果不符合预期,说明有其他goroutine正在并发地调用Add方法,这会导致数据竞争将会panic; ;
第24行将WaitGroup的状态重置为 0。Store方法用于安全地更新atomic.Value类型的变量。此时所有等待的goroutine都将被释放。
低25-27行 使用 runtime_Semrelease函数释放等待的goroutineruntime_Semreleaseruntime包中用于释放信号量的函数。&wg.semaWaitGroup的信号量字段的地址,false 表示这是一个正常的释放操作,0 是一个标志,表示释放的数量。循环会一直执行,直到 w 减到 0 为止。

WaitGroup.Done

通常在使用的时候都是调用Done()减少计数器,其源码如下,实际上就是调用Add方法,传入delta=-1

1
2
3
func (wg *WaitGroup) Done() {
wg.Add(-1)
}

WaitGroup.Wait

此方法的调用者将会一直阻塞直至WaitGroup的计数器为0

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
func (wg *WaitGroup) Wait() {
if race.Enabled {
race.Disable()
}
for {
state := wg.state.Load()
v := int32(state >> 32)
w := uint32(state)
if v == 0 {
// Counter is 0, no need to wait.
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
// Increment waiters count.
if wg.state.CompareAndSwap(state, state+1) {
if race.Enabled && w == 0 {
// Wait must be synchronized with the first Add.
// Need to model this is as a write to race with the read in Add.
// As a consequence, can do the write only for the first waiter,
// otherwise concurrent Waits will race with each other.
race.Write(unsafe.Pointer(&wg.sema))
}
runtime_Semacquire(&wg.sema)
if wg.state.Load() != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
}
}

源代码和Add非常类似,使用位运算操作,获取WaitGroup的计数器和等待的协程个数。如果当前计数器为0,且开启数据竞争,则开启当前线程的数据竞争检测
使用race.Acquire标记WaitGroup对象的获取。如果此时计数器的值为0,则直接返回。
否则,尝试将等待计数器的个数加一,如果开启了数据竞争检测,且当前协程是第一个等待协程吗,则使用race.Write标记信号量sema的写入,以同步第一个添加操作.
调用runtime包的Semacquire函数,阻塞当前goroutine直到信号量sema可用.
检查WaitGroup的计数器个数是否为0,如果不为0,则表示当前WaitGroup被重用,抛出panic
最后,如果启用了数据竞争检测,则重新开启数据竞争,再次使用race.Acquire标记WaitGroup对象的获取。