Skip to content

waitgroup

WaitGroup,由 Go 标准库提供,它的功能就是用于等待一组协程运行完毕。

go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup
	for i := range 10 {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(i)
		}()
	}
	wg.Wait()
}

这是一段非常简单的代码,它的功能就是开启 10 个协程打印 0-9,并等待它们运行完毕。它的用法不再赘述,接下来我们来了解下它的基本工作原理,一点也不复杂。

结构

它的类型定义位于sync/waitgroup.go文件中

go
type WaitGroup struct {
	noCopy noCopy

	state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
	sema  uint32
}

字段释义如下:

  • state,表示 WaitGroup 的状态,高 32 位用于统计被等待协程的数量,低 32 位用于统计等待 wg 完成的协程数量。
  • sema,信号量,在sync标准库里它几乎无处不在。

它的核心就在于Add()Wait()这两个方法,基本工作原理就是信号量,Wait()方法尝试获取信号量,Add()方法释放信号量,来实现 M 个协程等待一组 N 个协程运行完毕。

Add

Add 方法就是增加需要等待协程的数量。

go
func (wg *WaitGroup) Add(delta int) {
    state := wg.state.Add(uint64(delta) << 32)
    v := int32(state >> 32)
    w := uint32(state)
    if v < 0 {
       panic("sync: negative WaitGroup counter")
    }
    if w != 0 && delta > 0 && v == int32(delta) {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    if v > 0 || w == 0 {
       return
    }
    if wg.state.Load() != state {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    wg.state.Store(0)
    for ; w != 0; w-- {
       runtime_Semrelease(&wg.sema, false, 0)
    }
}

流程如下:

  1. 它首先会对wg.state进行移位操作,分别获取高 32 位和低 32 位,对应变量vw

    go
    state := wg.state.Add(uint64(delta) << 32)
    v := int32(state >> 32)
    w := uint32(state)
  2. 然后开始判断,v代表的是 wg 计数,w代表的等待 wg 完成的协程数量

    1. 如果v小于 0,直接panic,负数没有任何意义

      go
      if v < 0 {
          panic("sync: negative WaitGroup counter")
      }
    2. w不为 0,且deltav相等,表示Wait()方法与Add()方法被并发地调用,这是错误的使用方式

      go
      if w != 0 && delta > 0 && v == int32(delta) {
      	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
      }
    3. 如果v大于 0,或者w等于 0,表示现在没有等待 wg 完成的协程,可以直接返回

      go
      if v > 0 || w == 0 {
      	return
      }
  3. 走到这一步说明v等于 0,且w大于 0,即当前没有协程运行,但是有协程正在等待 wg 完成,所以就需要释放信号量,唤醒这些协程。

    go
    if wg.state.Load() != state {
    	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    wg.state.Store(0)
    for ; w != 0; w-- {
    	runtime_Semrelease(&wg.sema, false, 0)
    }

Done()方法其实就是Add(-1),没有什么要讲的。

Wait

如果当前有其它协程需要等待运行完成,Wait方法的调用会使当前协程陷入阻塞。

go
func (wg *WaitGroup) Wait() {
    for {
       state := wg.state.Load()
       v := int32(state >> 32)
       w := uint32(state)
       if v == 0 {
          return
       }
       // Increment waiters count.
       if wg.state.CompareAndSwap(state, state+1) {
          runtime_Semacquire(&wg.sema)
          if wg.state.Load() != 0 {
             panic("sync: WaitGroup is reused before previous Wait has returned")
          }
          return
       }
    }
}

它的流程就是一个 for 循环

  1. 读取高 32 位和低 32 位,得到需要被等待协程的数量,和等待协程的数量,如果没有协程需要等待,就直接返回

    go
    state := wg.state.Load()
    v := int32(state >> 32)
    w := uint32(state)
    if v == 0 {
    	return
    }
  2. 否则就通过 CAS 操作将等待协程数量加一,然后尝试获取信号量,进入阻塞等待队列

    go
    // Increment waiters count.
    if wg.state.CompareAndSwap(state, state+1) {
    	runtime_Semacquire(&wg.sema)
    	...
    }
  3. 当等待协程被唤醒后(因为所有被等待的协程都运行完毕了,释放了信号量),检查state ,如果不为 0,表示在Wait()Add() 又被并发的使用了

    go
    if wg.state.Load() != 0 {
    	panic("sync: WaitGroup is reused before previous Wait has returned")
    }
    return
  4. 如果 CAS 没有更新成功,则继续循环

小结

最后要提醒下,在使用WaitGroup时,AddWait不要并发的调用。

Golang学习网由www.golangdev.cn整理维护