Skip to content

waitgroup

WaitGroup, предоставляемый стандартной библиотекой Go, используется для ожидания завершения группы goroutine.

go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup
	for i := range 10 {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(i)
		}()
	}
	wg.Wait()
}

Это очень простой код. Его функция — запустить 10 goroutine для печати 0-9 и ждать их завершения. Его использование не будет здесь подробно рассматриваться. Далее давайте разберёмся с его базовым принципом работы, который совсем не сложен.

Структура

Определение его типа расположено в файле sync/waitgroup.go:

go
type WaitGroup struct {
	noCopy noCopy

	state atomic.Uint64 // высокие 32 бита — счётчик, низкие 32 бита — счётчик ожидающих.
	sema  uint32
}

Определения полей следующие:

  • state, представляет состояние WaitGroup. Высокие 32 бита используются для подсчёта количества goroutine, которых ожидают, а низкие 32 бита используются для подсчёта количества goroutine, ожидающих завершения wg.
  • sema, семафор, который почти повсеместен в стандартной библиотеке sync.

Его ядро лежит в методах Add() и Wait(). Базовый принцип работы — семафор. Метод Wait() пытается приобрести семафор, а метод Add() освобождает семафор для реализации ожидания M goroutine завершения группы N goroutine.

Add

Метод Add увеличивает количество goroutine, которые должны быть ожиданы.

go
func (wg *WaitGroup) Add(delta int) {
    state := wg.state.Add(uint64(delta) << 32)
    v := int32(state >> 32)
    w := uint32(state)
    if v < 0 {
       panic("sync: negative WaitGroup counter")
    }
    if w != 0 && delta > 0 && v == int32(delta) {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    if v > 0 || w == 0 {
       return
    }
    if wg.state.Load() != state {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    wg.state.Store(0)
    for ; w != 0; w-- {
       runtime_Semrelease(&wg.sema, false, 0)
    }
}

Поток следующий:

  1. Сначала выполняет операции сдвига на wg.state для получения высоких 32 бит и низких 32 бит, соответствующих переменным v и w:

    go
    state := wg.state.Add(uint64(delta) << 32)
    v := int32(state >> 32)
    w := uint32(state)
  2. Затем начинает делать проверки. v представляет счётчик wg, а w представляет количество goroutine, ожидающих завершения wg:

    1. Если v меньше 0, напрямую panic. Отрицательные числа не имеют смысла:

      go
      if v < 0 {
          panic("sync: negative WaitGroup counter")
      }
    2. Если w не равно 0, и delta равно v, это означает, что метод Wait() и метод Add() вызываются одновременно, что является неправильным использованием:

      go
      if w != 0 && delta > 0 && v == int32(delta) {
      	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
      }
    3. Если v больше 0, или w равно 0, это означает, что нет goroutine, ожидающих завершения wg, поэтому напрямую возвращается:

      go
      if v > 0 || w == 0 {
      	return
      }
  3. Достижение этого шага означает, что v равно 0 и w больше 0, т.е. в настоящее время нет работающих goroutine, но есть goroutine, ожидающие завершения wg. Поэтому нужно освободить семафор и пробудить эти goroutine:

    go
    if wg.state.Load() != state {
    	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    wg.state.Store(0)
    for ; w != 0; w-- {
    	runtime_Semrelease(&wg.sema, false, 0)
    }

Метод Done() на самом деле просто Add(-1), больше нечего объяснять.

Wait

Если есть другие goroutine, которые должны ждать завершения, вызов метода Wait заставит текущую goroutine стать заблокированной.

go
func (wg *WaitGroup) Wait() {
    for {
       state := wg.state.Load()
       v := int32(state >> 32)
       w := uint32(state)
       if v == 0 {
          return
       }
       // Увеличить счётчик ожидающих.
       if wg.state.CompareAndSwap(state, state+1) {
          runtime_Semacquire(&wg.sema)
          if wg.state.Load() != 0 {
             panic("sync: WaitGroup is reused before previous Wait has returned")
          }
          return
       }
    }
}

Его поток — это просто цикл for:

  1. Читает высокие 32 бита и низкие 32 бита для получения количества goroutine, которых нужно ожидать, и количества ожидающих goroutine. Если нет goroutine, которых нужно ожидать, напрямую возвращается:

    go
    state := wg.state.Load()
    v := int32(state >> 32)
    w := uint32(state)
    if v == 0 {
    	return
    }
  2. В противном случае использует операцию CAS для инкрементирования счётчика ожидающих goroutine на один, затем пытается приобрести семафор и входит в очередь блокирующего ожидания:

    go
    // Увеличить счётчик ожидающих.
    if wg.state.CompareAndSwap(state, state+1) {
    	runtime_Semacquire(&wg.sema)
    	...
    }
  3. Когда ожидающая goroutine пробуждена (потому что все ожидаемые goroutine завершены и освободили семафор), проверяет state. Если он не равен 0, это означает, что Wait() и Add() используются одновременно:

    go
    if wg.state.Load() != 0 {
    	panic("sync: WaitGroup is reused before previous Wait has returned")
    }
    return
  4. Если обновление CAS не удалось, продолжает циклически выполнять.

Итоги

Наконец, напоминание: при использовании WaitGroup не вызывайте Add и Wait одновременно.

Golang by www.golangdev.cn edit