fix: 解决 UFW 防火墙并发导致的批量操作失败的问题 (#2483)

pull/2485/head
ssongliu 2023-10-09 19:38:31 +08:00 committed by GitHub
parent a4bd9362ed
commit 150b1a2590
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 50 additions and 75 deletions

View File

@ -212,7 +212,6 @@ func (u *FirewallService) OperatePortRule(req dto.PortRuleOperate, reload bool)
protos := strings.Split(req.Protocol, "/")
itemAddress := strings.Split(strings.TrimSuffix(req.Address, ","), ",")
var wg sync.WaitGroup
if client.Name() == "ufw" {
if strings.Contains(req.Port, ",") || strings.Contains(req.Port, "-") {
for _, proto := range protos {
@ -223,17 +222,12 @@ func (u *FirewallService) OperatePortRule(req dto.PortRuleOperate, reload bool)
req.Address = addr
req.Port = strings.ReplaceAll(req.Port, "-", ":")
req.Protocol = proto
wg.Add(1)
go func(req dto.PortRuleOperate) {
defer wg.Done()
if err := u.operatePort(client, req); err != nil {
global.LOG.Errorf("%s port %s/%s failed (strategy: %s, address: %s), err: %v", req.Operation, req.Port, req.Protocol, req.Strategy, req.Address, err)
}
_ = u.addPortRecord(req)
}(req)
}
}
wg.Wait()
return nil
}
if req.Protocol == "tcp/udp" {
@ -244,16 +238,11 @@ func (u *FirewallService) OperatePortRule(req dto.PortRuleOperate, reload bool)
addr = "Anywhere"
}
req.Address = addr
wg.Add(1)
go func(req dto.PortRuleOperate) {
defer wg.Done()
if err := u.operatePort(client, req); err != nil {
global.LOG.Errorf("%s port %s/%s failed (strategy: %s, address: %s), err: %v", req.Operation, req.Port, req.Protocol, req.Strategy, req.Address, err)
}
_ = u.addPortRecord(req)
}(req)
}
wg.Wait()
return nil
}
@ -263,14 +252,10 @@ func (u *FirewallService) OperatePortRule(req dto.PortRuleOperate, reload bool)
for _, addr := range itemAddress {
req.Protocol = proto
req.Address = addr
wg.Add(1)
go func(req dto.PortRuleOperate) {
defer wg.Done()
if err := u.operatePort(client, req); err != nil {
global.LOG.Errorf("%s port %s/%s failed (strategy: %s, address: %s), err: %v", req.Operation, req.Port, req.Protocol, req.Strategy, req.Address, err)
}
_ = u.addPortRecord(req)
}(req)
}
} else {
ports := strings.Split(itemPorts, ",")
@ -282,20 +267,15 @@ func (u *FirewallService) OperatePortRule(req dto.PortRuleOperate, reload bool)
req.Address = addr
req.Port = port
req.Protocol = proto
wg.Add(1)
go func(req dto.PortRuleOperate) {
defer wg.Done()
if err := u.operatePort(client, req); err != nil {
global.LOG.Errorf("%s port %s/%s failed (strategy: %s, address: %s), err: %v", req.Operation, req.Port, req.Protocol, req.Strategy, req.Address, err)
}
_ = u.addPortRecord(req)
}(req)
}
}
}
}
wg.Wait()
if reload {
return client.Reload()
}
@ -313,24 +293,18 @@ func (u *FirewallService) OperateAddressRule(req dto.AddrRuleOperate, reload boo
return err
}
var wg sync.WaitGroup
addressList := strings.Split(req.Address, ",")
for i := 0; i < len(addressList); i++ {
if len(addressList[i]) == 0 {
continue
}
wg.Add(1)
go func(addr string) {
defer wg.Done()
fireInfo.Address = addr
fireInfo.Address = addressList[i]
if err := client.RichRules(fireInfo, req.Operation); err != nil {
global.LOG.Errorf("%s address %s failed (strategy: %s), err: %v", req.Operation, addr, req.Strategy, err)
global.LOG.Errorf("%s address %s failed (strategy: %s), err: %v", req.Operation, addressList[i], req.Strategy, err)
}
req.Address = addr
req.Address = addressList[i]
_ = u.addAddressRecord(req)
}(addressList[i])
}
wg.Wait()
if reload {
return client.Reload()
}
@ -378,27 +352,16 @@ func (u *FirewallService) BatchOperateRule(req dto.BatchRuleOperate) error {
if err != nil {
return err
}
var wgBatch sync.WaitGroup
if req.Type == "port" {
for _, rule := range req.Rules {
wgBatch.Add(1)
go func(item dto.PortRuleOperate) {
defer wgBatch.Done()
_ = u.OperatePortRule(item, false)
}(rule)
_ = u.OperatePortRule(rule, false)
}
wgBatch.Wait()
return client.Reload()
}
for _, rule := range req.Rules {
itemRule := dto.AddrRuleOperate{Operation: rule.Operation, Address: rule.Address, Strategy: rule.Strategy}
wgBatch.Add(1)
go func(item dto.AddrRuleOperate) {
defer wgBatch.Done()
_ = u.OperateAddressRule(item, false)
}(itemRule)
_ = u.OperateAddressRule(itemRule, false)
}
wgBatch.Wait()
return client.Reload()
}
@ -615,6 +578,8 @@ func listIpRules(strategy string) ([]string, error) {
}
func checkPortUsed(ports, proto string, apps []portOfApp) string {
var portList []int
if strings.Contains(ports, "-") || strings.Contains(ports, ",") {
if strings.Contains(ports, "-") {
port1, err := strconv.Atoi(strings.Split(ports, "-")[0])
if err != nil {
@ -626,10 +591,20 @@ func checkPortUsed(ports, proto string, apps []portOfApp) string {
global.LOG.Errorf(" convert string %s to int failed, err: %v", strings.Split(ports, "-")[1], err)
return ""
}
for i := port1; i <= port2; i++ {
portList = append(portList, i)
}
} else {
portLists := strings.Split(ports, ",")
for _, item := range portLists {
portItem, _ := strconv.Atoi(item)
portList = append(portList, portItem)
}
}
var usedPorts []string
for i := port1; i <= port2; i++ {
portItem := fmt.Sprintf("%v", i)
for _, port := range portList {
portItem := fmt.Sprintf("%v", port)
isUsedByApp := false
for _, app := range apps {
if app.HttpPort == portItem || app.HttpsPort == portItem {
@ -638,8 +613,8 @@ func checkPortUsed(ports, proto string, apps []portOfApp) string {
break
}
}
if !isUsedByApp && common.ScanPortWithProto(i, proto) {
usedPorts = append(usedPorts, fmt.Sprintf("%v", i))
if !isUsedByApp && common.ScanPortWithProto(port, proto) {
usedPorts = append(usedPorts, fmt.Sprintf("%v", port))
}
}
return strings.Join(usedPorts, ",")