Skip to content

waitgroup

WaitGroup,由 Go 標准庫提供,它的功能就是用於等待一組協程運行完畢。

go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup
	for i := range 10 {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(i)
		}()
	}
	wg.Wait()
}

這是一段非常簡單的代碼,它的功能就是開啟 10 個協程打印 0-9,並等待它們運行完畢。它的用法不再贅述,接下來我們來了解下它的基本工作原理,一點也不復雜。

結構

它的類型定義位於sync/waitgroup.go文件中

go
type WaitGroup struct {
	noCopy noCopy

	state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
	sema  uint32
}

字段釋義如下:

  • state,表示 WaitGroup 的狀態,高 32 位用於統計被等待協程的數量,低 32 位用於統計等待 wg 完成的協程數量。
  • sema,信號量,在sync標准庫裡它幾乎無處不在。

它的核心就在於Add()Wait()這兩個方法,基本工作原理就是信號量,Wait()方法嘗試獲取信號量,Add()方法釋放信號量,來實現 M 個協程等待一組 N 個協程運行完畢。

Add

Add 方法就是增加需要等待協程的數量。

go
func (wg *WaitGroup) Add(delta int) {
    state := wg.state.Add(uint64(delta) << 32)
    v := int32(state >> 32)
    w := uint32(state)
    if v < 0 {
       panic("sync: negative WaitGroup counter")
    }
    if w != 0 && delta > 0 && v == int32(delta) {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    if v > 0 || w == 0 {
       return
    }
    if wg.state.Load() != state {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    wg.state.Store(0)
    for ; w != 0; w-- {
       runtime_Semrelease(&wg.sema, false, 0)
    }
}

流程如下:

  1. 它首先會對wg.state進行移位操作,分別獲取高 32 位和低 32 位,對應變量vw

    go
    state := wg.state.Add(uint64(delta) << 32)
    v := int32(state >> 32)
    w := uint32(state)
  2. 然後開始判斷,v代表的是 wg 計數,w代表的等待 wg 完成的協程數量

    1. 如果v小於 0,直接panic,負數沒有任何意義

      go
      if v < 0 {
          panic("sync: negative WaitGroup counter")
      }
    2. w不為 0,且deltav相等,表示Wait()方法與Add()方法被並發地調用,這是錯誤的使用方式

      go
      if w != 0 && delta > 0 && v == int32(delta) {
      	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
      }
    3. 如果v大於 0,或者w等於 0,表示現在沒有等待 wg 完成的協程,可以直接返回

      go
      if v > 0 || w == 0 {
      	return
      }
  3. 走到這一步說明v等於 0,且w大於 0,即當前沒有協程運行,但是有協程正在等待 wg 完成,所以就需要釋放信號量,喚醒這些協程。

    go
    if wg.state.Load() != state {
    	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    wg.state.Store(0)
    for ; w != 0; w-- {
    	runtime_Semrelease(&wg.sema, false, 0)
    }

Done()方法其實就是Add(-1),沒有什麼要講的。

Wait

如果當前有其它協程需要等待運行完成,Wait方法的調用會使當前協程陷入阻塞。

go
func (wg *WaitGroup) Wait() {
    for {
       state := wg.state.Load()
       v := int32(state >> 32)
       w := uint32(state)
       if v == 0 {
          return
       }
       // Increment waiters count.
       if wg.state.CompareAndSwap(state, state+1) {
          runtime_Semacquire(&wg.sema)
          if wg.state.Load() != 0 {
             panic("sync: WaitGroup is reused before previous Wait has returned")
          }
          return
       }
    }
}

它的流程就是一個 for 循環

  1. 讀取高 32 位和低 32 位,得到需要被等待協程的數量,和等待協程的數量,如果沒有協程需要等待,就直接返回

    go
    state := wg.state.Load()
    v := int32(state >> 32)
    w := uint32(state)
    if v == 0 {
    	return
    }
  2. 否則就通過 CAS 操作將等待協程數量加一,然後嘗試獲取信號量,進入阻塞等待隊列

    go
    // Increment waiters count.
    if wg.state.CompareAndSwap(state, state+1) {
    	runtime_Semacquire(&wg.sema)
    	...
    }
  3. 當等待協程被喚醒後(因為所有被等待的協程都運行完畢了,釋放了信號量),檢查state ,如果不為 0,表示在Wait()Add() 又被並發的使用了

    go
    if wg.state.Load() != 0 {
    	panic("sync: WaitGroup is reused before previous Wait has returned")
    }
    return
  4. 如果 CAS 沒有更新成功,則繼續循環

小結

最後要提醒下,在使用WaitGroup時,AddWait不要並發的調用。

Golang學習網由www.golangdev.cn整理維護