diff --git a/modules/rpc/client/client.go b/modules/rpc/client/client.go index a9c2b6e..86ca5ec 100644 --- a/modules/rpc/client/client.go +++ b/modules/rpc/client/client.go @@ -7,16 +7,21 @@ import ( "time" "errors" "github.com/ouqiang/gocron/modules/rpc/grpcpool" + "google.golang.org/grpc/codes" + "google.golang.org/grpc" ) func Exec(ip string, port int, taskReq *pb.TaskRequest) (string, error) { - addr := fmt.Sprintf("%s:%d", ip, port); + addr := fmt.Sprintf("%s:%d", ip, port) conn, err := grpcpool.Pool.Get(addr) if err != nil { return "", err } + isConnClosed := false defer func() { - grpcpool.Pool.Put(addr, conn) + if !isConnClosed { + grpcpool.Pool.Put(addr, conn) + } }() c := pb.NewTaskClient(conn) if taskReq.Timeout <= 0 || taskReq.Timeout > 86400 { @@ -27,7 +32,7 @@ func Exec(ip string, port int, taskReq *pb.TaskRequest) (string, error) { defer cancel() resp, err := c.Run(ctx, taskReq) if err != nil { - return "", err + return parseGRPCError(err, conn, &isConnClosed) } if resp.Error == "" { @@ -35,4 +40,16 @@ func Exec(ip string, port int, taskReq *pb.TaskRequest) (string, error) { } return resp.Output, errors.New(resp.Error) -} \ No newline at end of file +} + +func parseGRPCError(err error, conn *grpc.ClientConn, connClosed *bool) (string, error) { + switch grpc.Code(err) { + case codes.Unavailable: + conn.Close() + *connClosed = true + return "", errors.New("无法连接远程服务器") + case codes.DeadlineExceeded: + return "", errors.New("执行超时, 强制结束") + } + return "", err +}