
waitgroup

WaitGroup, provided by the sync package of Go's standard library, is used to wait for a group of goroutines to finish.

go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup
	for i := range 10 {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(i)
		}()
	}
	wg.Wait()
}

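This is a very simple program. It starts 10 goroutines that each print one of the numbers 0 to 9, in no particular order because they run concurrently, and then waits for all of them to finish. (Ranging over an integer, as in for i := range 10, requires Go 1.22 or later.) Its usage won't be elaborated further here. Next, let's look at how it works, which is not complicated at all.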

Structure

Its type definition is located in the sync/waitgroup.go file:

go
type WaitGroup struct {
	noCopy noCopy

	state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
	sema  uint32
}

Field definitions are as follows:

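  • noCopy, a zero-size marker type. It has no runtime effect but lets the copylocks check of go vet report code that copies a WaitGroup.
  • state, WaitGroup's state. The high 32 bits hold the counter of goroutines being waited for (the value that Add and Done change), and the low 32 bits hold the number of goroutines blocked in Wait().
  • sema, the semaphore that waiting goroutines block on. Runtime semaphores like this one are used by many primitives in the sync package.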
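To make the packed layout of state concrete, here is a small sketch with hypothetical helpers pack, unpack, and addDelta (they are not part of the sync package). It shows how adding uint64(delta) << 32 changes only the counter half, and how a negative delta works through two's-complement wraparound, which is also how a counter that drops below zero shows up as a negative int32.

```go
package main

import "fmt"

// pack combines a counter (high 32 bits) and a waiter count (low 32 bits)
// into one uint64, mirroring the layout of WaitGroup.state.
func pack(v int32, w uint32) uint64 {
	return uint64(uint32(v))<<32 | uint64(w)
}

// unpack splits a state value the same way Add and Wait do.
func unpack(state uint64) (v int32, w uint32) {
	return int32(state >> 32), uint32(state)
}

// addDelta applies delta to the counter half, like wg.state.Add(uint64(delta) << 32).
// A negative delta wraps around, which decrements the high half and leaves
// the low half untouched.
func addDelta(state uint64, delta int) uint64 {
	return state + uint64(delta)<<32
}

func main() {
	s := pack(3, 2) // counter 3, two waiters
	v, w := unpack(addDelta(s, -1))
	fmt.Println(v, w) // 2 2
	v, w = unpack(addDelta(s, 5))
	fmt.Println(v, w) // 8 2
	v, _ = unpack(addDelta(0, -1))
	fmt.Println(v) // -1: the condition Add checks with v < 0
}
```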

Its core lies in the Add() and Wait() methods, and the mechanism underneath is a semaphore: Wait() acquires the semaphore and blocks, and when the counter drops to zero, Add() releases it once per waiter. This is how M goroutines can wait for a group of N goroutines to finish.

Add

The Add method adds delta, which may be negative, to the counter of goroutines being waited for. When the counter drops to zero, it wakes every goroutine blocked in Wait().

go
func (wg *WaitGroup) Add(delta int) {
	state := wg.state.Add(uint64(delta) << 32)
	v := int32(state >> 32)
	w := uint32(state)
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	if v > 0 || w == 0 {
		return
	}
	if wg.state.Load() != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	wg.state.Store(0)
	for ; w != 0; w-- {
		runtime_Semrelease(&wg.sema, false, 0)
	}
}

The flow is as follows:

  1. It first atomically adds delta to the high 32 bits of wg.state (by shifting delta left 32 bits), then splits the new state into its high 32 bits and low 32 bits, stored in v and w:

    go
    state := wg.state.Add(uint64(delta) << 32)
    v := int32(state >> 32)
    w := uint32(state)
  2. It then checks the new state. v is the counter, and w is the number of goroutines waiting in Wait():

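    1. If v is negative, it panics: the counter has been decremented more times than it was incremented (for example, Done() was called too often), and a negative count has no meaning: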

      go
      if v < 0 {
          panic("sync: negative WaitGroup counter")
      }
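    2. If there are waiters (w is not 0), delta is positive, and the counter now equals delta, then the counter was 0 before this call. That means Add() is starting a new round while goroutines are still in Wait(), i.e. Add() and Wait() were called concurrently, which is incorrect usage: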

      go
      if w != 0 && delta > 0 && v == int32(delta) {
      	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
      }
    3. If v is greater than 0, the counter has not reached zero yet; if w equals 0, no goroutine is waiting. In either case there is nobody to wake, so it returns:

      go
      if v > 0 || w == 0 {
      	return
      }
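  3. Reaching this step means v equals 0 and w is greater than 0: the counter has dropped to zero while goroutines are still blocked in Wait(). Add() first checks that state has not changed since its own atomic update (a change would mean another Add() or Wait() ran concurrently, which is misuse), then resets state to 0 and releases the semaphore once per waiter to wake them all: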

    go
    if wg.state.Load() != state {
    	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    wg.state.Store(0)
    for ; w != 0; w-- {
    	runtime_Semrelease(&wg.sema, false, 0)
    }

The Done() method simply calls Add(-1), so there is nothing more to explain about it.
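Because Done() is Add(-1), calling it once more than Add() accounted for drives the counter below zero and triggers the first panic in Add(). The hypothetical helper doneTooOften below recovers that panic to show its value.

```go
package main

import (
	"fmt"
	"sync"
)

// doneTooOften calls Done one more time than Add allowed and returns the
// value recovered from the resulting panic.
func doneTooOften() (msg any) {
	defer func() { msg = recover() }()
	var wg sync.WaitGroup
	wg.Add(1)
	wg.Done() // counter back to 0
	wg.Done() // Add(-1) again: the counter would become -1, so Add panics
	return nil
}

func main() {
	fmt.Println(doneTooOften()) // sync: negative WaitGroup counter
}
```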

Wait

Wait blocks the calling goroutine until the counter drops to zero. If the counter is already zero, it returns immediately.

go
func (wg *WaitGroup) Wait() {
	for {
		state := wg.state.Load()
		v := int32(state >> 32)
		w := uint32(state) // used only by race-detector code omitted from this excerpt
		if v == 0 {
			return
		}
		// Increment waiters count.
		if wg.state.CompareAndSwap(state, state+1) {
			runtime_Semacquire(&wg.sema)
			if wg.state.Load() != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			return
		}
	}
}

Its flow is just a for loop:

  1. Load the state and split it into the high 32 bits (v, the counter) and the low 32 bits (w, the waiter count). If v is 0, there is nothing to wait for, so return immediately:

    go
    state := wg.state.Load()
    v := int32(state >> 32)
    w := uint32(state)
    if v == 0 {
    	return
    }
  2. Otherwise, use a CAS (compare-and-swap) to add one to the waiter count, which lives in the low 32 bits, hence state+1. If the CAS succeeds, acquire the semaphore, which blocks the goroutine until Add() releases it:

    go
    // Increment waiters count.
    if wg.state.CompareAndSwap(state, state+1) {
    	runtime_Semacquire(&wg.sema)
    	...
    }
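  3. When the waiting goroutine wakes up (because the counter reached zero and Add() released the semaphore), it checks state. Add() resets state to 0 before releasing the semaphore, so a non-zero value here means the WaitGroup was reused, i.e. Add() was called for a new round before this Wait() returned: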

    go
    if wg.state.Load() != 0 {
    	panic("sync: WaitGroup is reused before previous Wait has returned")
    }
    return
  4. If the CAS fails, another goroutine changed state between the Load and the CAS, so the loop reloads the state and tries again.

Summary

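Finally, a reminder: Add calls that raise the counter from zero must happen before Wait is called, which in practice means calling Add before starting a goroutine rather than inside it. A WaitGroup may be reused for another round only after every previous Wait call has returned.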

Golang by www.golangdev.cn