Skip to content

waitgroup

WaitGroup, do thư viện chuẩn Go cung cấp, chức năng của nó là dùng để chờ một nhóm coroutine chạy xong.

go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup
	for i := range 10 {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(i)
		}()
	}
	wg.Wait()
}

Đây là một đoạn code rất đơn giản, chức năng của nó là khởi tạo 10 coroutine in 0-9, và chờ chúng chạy xong. Cách sử dụng của nó sẽ không được nhắc lại ở đây, tiếp theo chúng ta sẽ tìm hiểu nguyên lý làm việc cơ bản của nó, hoàn toàn không phức tạp.

Cấu trúc

Định nghĩa loại của nó nằm trong file sync/waitgroup.go

go
type WaitGroup struct {
	noCopy noCopy

	state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
	sema  uint32
}

Giải thích các trường như sau:

  • state, biểu thị trạng thái của WaitGroup, 32 bit cao dùng để thống kê số lượng coroutine được chờ, 32 bit thấp dùng để thống kê số lượng coroutine chờ wg hoàn thành.
  • sema, semaphore, trong thư viện chuẩn sync nó hầu như xuất hiện ở khắp mọi nơi.

Cốt lõi của nó nằm ở hai phương thức Add()Wait(), nguyên lý làm việc cơ bản là semaphore, phương thức Wait() thử lấy semaphore, phương thức Add() giải phóng semaphore, để thực hiện M coroutine chờ một nhóm N coroutine chạy xong.

Add

Phương thức Add là tăng số lượng coroutine cần chờ.

go
func (wg *WaitGroup) Add(delta int) {
    state := wg.state.Add(uint64(delta) << 32)
    v := int32(state >> 32)
    w := uint32(state)
    if v < 0 {
       panic("sync: negative WaitGroup counter")
    }
    if w != 0 && delta > 0 && v == int32(delta) {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    if v > 0 || w == 0 {
       return
    }
    if wg.state.Load() != state {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    wg.state.Store(0)
    for ; w != 0; w-- {
       runtime_Semrelease(&wg.sema, false, 0)
    }
}

Quy trình như sau:

  1. Trước tiên nó sẽ thực hiện thao tác shift trên wg.state, lần lượt lấy 32 bit cao và 32 bit thấp, tương ứng với biến vw

    go
    state := wg.state.Add(uint64(delta) << 32)
    v := int32(state >> 32)
    w := uint32(state)
  2. Sau đó bắt đầu phán đoán, v đại biểu cho wg count, w đại biểu cho số lượng coroutine chờ wg hoàn thành

    1. Nếu v nhỏ hơn 0, trực tiếp panic, số âm không có ý nghĩa gì

      go
      if v < 0 {
          panic("sync: negative WaitGroup counter")
      }
    2. w khác 0, và delta bằng với v, biểu thị phương thức Wait() và phương thức Add() được gọi đồng thời, đây là cách sử dụng sai

      go
      if w != 0 && delta > 0 && v == int32(delta) {
      	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
      }
    3. Nếu v lớn hơn 0, hoặc w bằng 0, biểu thị hiện tại không có coroutine chờ wg hoàn thành, có thể trực tiếp trả về

      go
      if v > 0 || w == 0 {
      	return
      }
  3. Đi đến bước này biểu thị v bằng 0, và w lớn hơn 0, tức hiện tại không có coroutine chạy, nhưng có coroutine đang chờ wg hoàn thành, nên cần giải phóng semaphore, đánh thức các coroutine này.

    go
    if wg.state.Load() != state {
    	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    wg.state.Store(0)
    for ; w != 0; w-- {
    	runtime_Semrelease(&wg.sema, false, 0)
    }

Phương thức Done() thực ra chính là Add(-1), không có gì để nói.

Wait

Nếu hiện tại có coroutine khác cần chờ chạy xong, việc gọi phương thức Wait sẽ khiến coroutine hiện tại rơi vào trạng thái chặn.

go
func (wg *WaitGroup) Wait() {
    for {
       state := wg.state.Load()
       v := int32(state >> 32)
       w := uint32(state)
       if v == 0 {
          return
       }
       // Increment waiters count.
       if wg.state.CompareAndSwap(state, state+1) {
          runtime_Semacquire(&wg.sema)
          if wg.state.Load() != 0 {
             panic("sync: WaitGroup is reused before previous Wait has returned")
          }
          return
       }
    }
}

Quy trình của nó là một vòng lặp for

  1. Đọc 32 bit cao và 32 bit thấp, lấy số lượng coroutine cần được chờ, và số lượng coroutine chờ, nếu không có coroutine cần chờ, thì trực tiếp trả về

    go
    state := wg.state.Load()
    v := int32(state >> 32)
    w := uint32(state)
    if v == 0 {
    	return
    }
  2. Nếu không thì thông qua thao tác CAS tăng số lượng coroutine chờ lên một, sau đó thử lấy semaphore, đi vào blocking wait queue

    go
    // Increment waiters count.
    if wg.state.CompareAndSwap(state, state+1) {
    	runtime_Semacquire(&wg.sema)
    	...
    }
  3. Khi coroutine chờ được đánh thức (vì tất cả coroutine được chờ đã chạy xong, giải phóng semaphore), kiểm tra state, nếu khác 0, biểu thị trong Wait()Add() lại được sử dụng đồng thời

    go
    if wg.state.Load() != 0 {
    	panic("sync: WaitGroup is reused before previous Wait has returned")
    }
    return
  4. Nếu CAS không cập nhật thành công, thì tiếp tục vòng lặp

Tóm tắt

Cuối cùng cần nhắc nhở, khi sử dụng WaitGroup, AddWait không nên được gọi đồng thời.

Golang by www.golangdev.cn edit