当前位置:网站首页>go源码之sync.Waitgroup

go源码之sync.Waitgroup

2022-08-02 11:31:00 Tomyang_

本文基于Go版本:1.17.8

go version go1.17.8 darwin/amd64 

什么是sync.WaitGroup

官方文档对于sync.WaitGroup的描述是:

  • WaitGroup等待一组 goroutines完成
  • 主goroutine调用 Add来设置
  • goroutines组等待,然后是每一个 goroutine在完成时运行并调用 Done。与此同时,等待可以用来阻塞,直到所有 goroutines完成。
sync.WaitGroup
package main

import (
 "fmt"
 "sync"
)

type Http interface {
 Get(string)
}

type httpPkg struct {}

func (h *httpPkg) Get(url string) {}

var (
 _ Http = (*httpPkg)(nil)
)

func main() {
 var (
  wg   sync.WaitGroup
  urls = []string{
   "http://www.golang.org/",
   "http://www.google.com/",
   "http://www.somestupidname.com/",
  }
 )
 for k := range urls {
  url := urls[k]
  wg.Add(1)
  go func(url string, k int) {
   defer wg.Done()
   pkg := new(httpPkg)
   pkg.Get(url)
   fmt.Println(k)
  }(url, k)
 }
 wg.Wait()
}

首先我们需要声明一个sync.WaitGroup对象,在主goroutine调用Add()方法设置要等待的goroutine数量,每一个goroutine在运行结束时调用Done()方法,同时使用Wait()方法进行阻塞直到所有goroutine完成。

为什么要用sync.waitGroup

在日常开发过程中提高接口响应时间,有一些场景需要在多个goruotine中做到互不影响的业务,这样可以节省出时间,但是需要协调多个goruotine,没有sync.waitGroup的时候,可以使用channel来解决这个问题。 案列:

package main

import "fmt"

func main() {
 exampleWaitGroup()
}

func exampleWaitGroup() {
 done := make(chan struct{})
 count := 10
 for i := 0; i < count; i++ {
  go func(i int) {
   defer func() {
    done <- struct{}{}
   }()
   fmt.Printf("小弟%d收取保护费\n", i)
  }(i)
 }
 for i := 0; i < count; i++ {
  select {
  case <-done:
   fmt.Printf("小弟%d号已经收取完保护费\n", i)
  }
 }
 fmt.Println("所有小弟已经干完活了,开房去了~")
}

虽然这样可以实现,但是每次使用都要保证主goruotine在最后从通道接收的次数需要与之前其它goruotine发送元素的次数相同。在这种场景下我们就可以选用sync.WaitGroup来帮助实现同步。

源码解析

看看sync.WaitGroup的结构

type WaitGroup struct {
 noCopy noCopy
 //state1 分配12个字节, 被设计了三种状态
 // 其中对齐的8个字节作为状态,高32位为计数的数量,低32位为等待goroutine数量
 // 其中的4个字节作为信号量存储
 state1 [3]uint32
}
  • noCopy为了保证该结构体不会被进行拷贝的一种保护机制。
  • state1 主要存储着状态和信号量。 它这里被分配了 12字节:
func main() {
 var (
  state1 [3]uint32
 )
  //内存长度为12个字节
 fmt.Printf("state1:%T,内存长度:%d", state1, unsafe.Sizeof(state1))
  
}
  • 其中对齐的 8个字节作为状态,高 32位为计数的数量,低 32位为等待的 goruotine数量。
  • 其中的 4个字节作为信号量的存储。 源码包提供函数可以 state1字段中取出它的 状态信号量
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
   if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
    return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
   } else {
    return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
   }
}

为什么这么设计呢? 因为64位原子操作是需要64位对齐,但32位编辑器不能保证这一点,所以为了保证waitGroup32位平台上使用的话,就必须保证在任何时候,64位操作不会报错。考虑到字段顺序不同,平台不同,内存对齐也就不同,因此这里采用动态识别当前操作的64位数到底是不是在8字节对齐的位置上。 数组的首地址是处于一个8字节对齐的位置上,那么就将这个数组的前8个字节作为64位值使用表示状态,后4个字节作为32位值表示信号(signal),如果首地址没有处于8字节对齐的位置上,那么将前4个字节作为信号(signal),后8个字节作为64位数值。

Add()

func (wg *WaitGroup) Add(delta int) {
  // 获取状态 和信号量
 statep, semap := wg.state()
 if race.Enabled {
  _ = *statep // trigger nil deref early
  if delta < 0 {
   // 同步等待
   race.ReleaseMerge(unsafe.Pointer(wg))
  }
  race.Disable()
  defer race.Enable()
 }
  // 原子操作, goroutine count 累加
 state := atomic.AddUint64(statep, uint64(delta)<<32)
  // 当前 goroutine count的值(高32位)
 v := int32(state >> 32)
  // 当前wait count的值(低32位)
 w := uint32(state)
 if race.Enabled && delta > 0 && v == int32(delta) {
//第一个增量必须同步等。, 
//需要模型作为一个阅读,因为可以有, 
//几个并发工作组。对抗转换从0。
  race.Read(unsafe.Pointer(semap))
 }
  // goroutine count 是不允许为负数。
 if v < 0 {
  panic("sync: negative WaitGroup counter")
 }
  // 当wait的 goroutine不为0时,累加后的值与delta 相等, 说明Add()和Wait()同时调用,
  // 会触发panic 正确的调用方法,先Add()后Wait(),也就是已经调用Wait()就不允许再添加任务了。
 if w != 0 && delta > 0 && v == int32(delta) {
  panic("sync: WaitGroup misuse: Add called concurrently with Wait")
 }
  // 正常Add()方法后, goroutine 计数大于0或者 Wait 计数等于0时,这时是不需要释放信号量。
 if v > 0 || w == 0 {
  return
 }
  // 当前 goroutine 计数为0 Wait 计数大于0,就会触发panic
 if *statep != state {
  panic("sync: WaitGroup misuse: Add called concurrently with Wait")
 }
 // 重置 状态, 并发出信号量告诉Wait所有任务已经完成。
 *statep = 0
 for ; w != 0; w-- {
  runtime_Semrelease(semap, false0)
 }
}

Wait()

func (wg *WaitGroup) Wait() {
 // 获取状态 和信号量
 statep, semap := wg.state()
 if race.Enabled {
  _ = *statep // trigger nil deref early
  race.Disable()
 }
 for {
  // 使用原子操作读取state, 保证Add中写入操作已完成
  state := atomic.LoadUint64(statep)
  // 获取当前goroutine counter的值(高32位)
  v := int32(state >> 32)
  // 获取当前Wait counter的值(低32位)
  w := uint32(state)
  // 如果没有任务或者任务已经调用Wait方法前已经执行完成,就不用阻塞操作。
  if v == 0 {
   if race.Enabled {
    race.Enable()
    race.Acquire(unsafe.Pointer(wg))
   }
   return
  }
  // 使用CAS(比较与交换)操作 Wait Counter 计数器进行+1操作,外面有for循环保证这里可以进行重试操作
  if atomic.CompareAndSwapUint64(statep, state, state+1) {
   if race.Enabled && w == 0 {
    //等待必须与第一个添加同步
    //需要模型这是写与读比赛添加
    //结果,可以只编写第一个服务
    //否则并发等待会互相竞争
    race.Write(unsafe.Pointer(semap))
   }
   // 使用信号量,是协程进入睡眠状态,与Add()中最后的增加信号量相对应,也就是当最后一个任务调用Done方法
   // 后会调用Add方法对goroutine Counter的值减到0,就会走到最后的增加信号量

   runtime_Semacquire(semap)
   // 在Add方法中增加信号量时已经将statep的值设为0了,如果这里不0,说明Wait之后又调用了Add方法, 使用错误就会触发panic
   if *statep != 0 {
    panic("sync: WaitGroup is reused before previous Wait has returned")
   }
   if race.Enabled {
    race.Enable()
    race.Acquire(unsafe.Pointer(wg))
   }
   return
  }
 }
}

源码总结:

  • Add方法与 Wait方法不可并发调用,Add方法要在Wait方法之前调用
  • Add设置的值必须与实际等待 goroutine个数一致,否则会 panic
  • Done只是对 Add方法的简单封装,可以向 Add方法传入任意负值 (保证计数器非负),可以将计数器归零以唤醒等待的 goroutine
  • WaitGroup结构只能有一份,不可以拷贝給其它变量。 WaitGroup结构有一个 nocopy字段。

nocopy 字段

WaitGroup结构中,有一个nocopy字段,为什么要有nocopy?可以先看看结构体拷贝的:

type User1 struct {
 Name string
 Info *Info
}

type User2 struct {
 Name string
 Info Info
}

type Info struct {
 Age    int
 Number int
}

func main() {
 u := User1{
  Name: "Tom",
  Info: &Info{
   Age:    10,
   Number: 24,
  },
 }

 u1 := u
 u1.Info.Age = 100
 fmt.Printf("user类型:%v  %+v %s", unsafe.Pointer(&u), u.Info, "\n")
 fmt.Printf("user1类型:%v %+v %s", unsafe.Pointer(&u1), u1.Info, "\n")
  //user类型:0xc00000c030  &{Age:100 Number:24} 
  //user1类型:0xc00000c048 &{Age:100 Number:24} 
  
  //无指针结构进行拷贝
 u2 := User2{
  Name: "TomYang",
  Info: Info{
   Age:    10,
   Number: 24,
  },
 }
 u3 := u2
 u3.Info.Age = 110
 fmt.Printf("user2类型:%v %+v %s", unsafe.Pointer(&u2), u2.Info, "\n")
 fmt.Printf("user3类型:%v %+v %s", unsafe.Pointer(&u2), u3.Info, "\n")
  //user2类型:0xc000060020 {Age:10 Number:24} 
  //user3类型:0xc000060020 {Age:110 Number:24} 
}

结构体User1中有两个字段NameInfo结构体,Name是string,Info是指向结构体Info的指针类型,代码中先声明变量u变量,针对它进行复制拷贝到变量u1,在针对u1中两个字段进行改变,可以看到Info.Age字段发生更改。这就是引发了安全问题,如果结构体对象包含指针字段,当该对结构体拷贝时,会使用两个结构体重的指针字段变得不再安全

原网站

版权声明
本文为[Tomyang_]所创,转载请带上原文链接,感谢
https://mdnice.com/writing/64e486d73a7e49f6bd1b3216bdcdc830