waitgroup
WaitGroup,由 Go 標准庫提供,它的功能就是用於等待一組協程運行完畢。
package main
import (
"fmt"
"sync"
)
func main() {
var wg sync.WaitGroup
for i := range 10 {
wg.Add(1)
go func() {
defer wg.Done()
fmt.Println(i)
}()
}
wg.Wait()
}這是一段非常簡單的代碼,它的功能就是開啟 10 個協程打印 0-9,並等待它們運行完畢。它的用法不再贅述,接下來我們來了解下它的基本工作原理,一點也不復雜。
結構
它的類型定義位於sync/waitgroup.go文件中
type WaitGroup struct {
noCopy noCopy
state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
sema uint32
}字段釋義如下:
state,表示 WaitGroup 的狀態,高 32 位用於統計被等待協程的數量,低 32 位用於統計等待 wg 完成的協程數量。sema,信號量,在sync標准庫裡它幾乎無處不在。
它的核心就在於Add()和Wait()這兩個方法,基本工作原理就是信號量,Wait()方法嘗試獲取信號量,Add()方法釋放信號量,來實現 M 個協程等待一組 N 個協程運行完畢。
Add
Add 方法就是增加需要等待協程的數量。
func (wg *WaitGroup) Add(delta int) {
state := wg.state.Add(uint64(delta) << 32)
v := int32(state >> 32)
w := uint32(state)
if v < 0 {
panic("sync: negative WaitGroup counter")
}
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
if v > 0 || w == 0 {
return
}
if wg.state.Load() != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
wg.state.Store(0)
for ; w != 0; w-- {
runtime_Semrelease(&wg.sema, false, 0)
}
}流程如下:
它首先會對
wg.state進行移位操作,分別獲取高 32 位和低 32 位,對應變量v和wgostate := wg.state.Add(uint64(delta) << 32) v := int32(state >> 32) w := uint32(state)然後開始判斷,
v代表的是 wg 計數,w代表的等待 wg 完成的協程數量如果
v小於 0,直接panic,負數沒有任何意義goif v < 0 { panic("sync: negative WaitGroup counter") }w不為 0,且delta與v相等,表示Wait()方法與Add()方法被並發地調用,這是錯誤的使用方式goif w != 0 && delta > 0 && v == int32(delta) { panic("sync: WaitGroup misuse: Add called concurrently with Wait") }如果
v大於 0,或者w等於 0,表示現在沒有等待 wg 完成的協程,可以直接返回goif v > 0 || w == 0 { return }
走到這一步說明
v等於 0,且w大於 0,即當前沒有協程運行,但是有協程正在等待 wg 完成,所以就需要釋放信號量,喚醒這些協程。goif wg.state.Load() != state { panic("sync: WaitGroup misuse: Add called concurrently with Wait") } wg.state.Store(0) for ; w != 0; w-- { runtime_Semrelease(&wg.sema, false, 0) }
Done()方法其實就是Add(-1),沒有什麼要講的。
Wait
如果當前有其它協程需要等待運行完成,Wait方法的調用會使當前協程陷入阻塞。
func (wg *WaitGroup) Wait() {
for {
state := wg.state.Load()
v := int32(state >> 32)
w := uint32(state)
if v == 0 {
return
}
// Increment waiters count.
if wg.state.CompareAndSwap(state, state+1) {
runtime_Semacquire(&wg.sema)
if wg.state.Load() != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
return
}
}
}它的流程就是一個 for 循環
讀取高 32 位和低 32 位,得到需要被等待協程的數量,和等待協程的數量,如果沒有協程需要等待,就直接返回
gostate := wg.state.Load() v := int32(state >> 32) w := uint32(state) if v == 0 { return }否則就通過 CAS 操作將等待協程數量加一,然後嘗試獲取信號量,進入阻塞等待隊列
go// Increment waiters count. if wg.state.CompareAndSwap(state, state+1) { runtime_Semacquire(&wg.sema) ... }當等待協程被喚醒後(因為所有被等待的協程都運行完畢了,釋放了信號量),檢查
state,如果不為 0,表示在Wait()和Add()又被並發的使用了goif wg.state.Load() != 0 { panic("sync: WaitGroup is reused before previous Wait has returned") } return如果 CAS 沒有更新成功,則繼續循環
小結
最後要提醒下,在使用WaitGroup時,Add和Wait不要並發的調用。
