mirror of https://github.com/k3s-io/k3s
799 lines
17 KiB
Go
799 lines
17 KiB
Go
|
/*
Copyright (c) 2014-2018 VMware, Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
||
|
|
||
|
package soap
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"crypto/sha1"
|
||
|
"crypto/tls"
|
||
|
"crypto/x509"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"log"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"net/http/cookiejar"
|
||
|
"net/url"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
"regexp"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/vmware/govmomi/vim25/progress"
|
||
|
"github.com/vmware/govmomi/vim25/types"
|
||
|
"github.com/vmware/govmomi/vim25/xml"
|
||
|
)
|
||
|
|
||
|
type HasFault interface {
|
||
|
Fault() *Fault
|
||
|
}
|
||
|
|
||
|
type RoundTripper interface {
|
||
|
RoundTrip(ctx context.Context, req, res HasFault) error
|
||
|
}
|
||
|
|
||
|
const (
|
||
|
SessionCookieName = "vmware_soap_session"
|
||
|
)
|
||
|
|
||
|
type Client struct {
|
||
|
http.Client
|
||
|
|
||
|
u *url.URL
|
||
|
k bool // Named after curl's -k flag
|
||
|
d *debugContainer
|
||
|
t *http.Transport
|
||
|
|
||
|
hostsMu sync.Mutex
|
||
|
hosts map[string]string
|
||
|
|
||
|
Namespace string // Vim namespace
|
||
|
Version string // Vim version
|
||
|
UserAgent string
|
||
|
|
||
|
cookie string
|
||
|
}
|
||
|
|
||
|
var schemeMatch = regexp.MustCompile(`^\w+://`)
|
||
|
|
||
|
type errInvalidCACertificate struct {
|
||
|
File string
|
||
|
}
|
||
|
|
||
|
func (e errInvalidCACertificate) Error() string {
|
||
|
return fmt.Sprintf(
|
||
|
"invalid certificate '%s', cannot be used as a trusted CA certificate",
|
||
|
e.File,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
// ParseURL is wrapper around url.Parse, where Scheme defaults to "https" and Path defaults to "/sdk"
|
||
|
func ParseURL(s string) (*url.URL, error) {
|
||
|
var err error
|
||
|
var u *url.URL
|
||
|
|
||
|
if s != "" {
|
||
|
// Default the scheme to https
|
||
|
if !schemeMatch.MatchString(s) {
|
||
|
s = "https://" + s
|
||
|
}
|
||
|
|
||
|
u, err = url.Parse(s)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// Default the path to /sdk
|
||
|
if u.Path == "" {
|
||
|
u.Path = "/sdk"
|
||
|
}
|
||
|
|
||
|
if u.User == nil {
|
||
|
u.User = url.UserPassword("", "")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return u, nil
|
||
|
}
|
||
|
|
||
|
func NewClient(u *url.URL, insecure bool) *Client {
|
||
|
c := Client{
|
||
|
u: u,
|
||
|
k: insecure,
|
||
|
d: newDebug(),
|
||
|
}
|
||
|
|
||
|
// Initialize http.RoundTripper on client, so we can customize it below
|
||
|
if t, ok := http.DefaultTransport.(*http.Transport); ok {
|
||
|
c.t = &http.Transport{
|
||
|
Proxy: t.Proxy,
|
||
|
DialContext: t.DialContext,
|
||
|
MaxIdleConns: t.MaxIdleConns,
|
||
|
IdleConnTimeout: t.IdleConnTimeout,
|
||
|
TLSHandshakeTimeout: t.TLSHandshakeTimeout,
|
||
|
ExpectContinueTimeout: t.ExpectContinueTimeout,
|
||
|
}
|
||
|
} else {
|
||
|
c.t = new(http.Transport)
|
||
|
}
|
||
|
|
||
|
c.hosts = make(map[string]string)
|
||
|
c.t.TLSClientConfig = &tls.Config{InsecureSkipVerify: c.k}
|
||
|
// Don't bother setting DialTLS if InsecureSkipVerify=true
|
||
|
if !c.k {
|
||
|
c.t.DialTLS = c.dialTLS
|
||
|
}
|
||
|
|
||
|
c.Client.Transport = c.t
|
||
|
c.Client.Jar, _ = cookiejar.New(nil)
|
||
|
|
||
|
// Remove user information from a copy of the URL
|
||
|
c.u = c.URL()
|
||
|
c.u.User = nil
|
||
|
|
||
|
return &c
|
||
|
}
|
||
|
|
||
|
// NewServiceClient creates a NewClient with the given URL.Path and namespace.
|
||
|
func (c *Client) NewServiceClient(path string, namespace string) *Client {
|
||
|
vc := c.URL()
|
||
|
u, err := url.Parse(path)
|
||
|
if err != nil {
|
||
|
log.Panicf("url.Parse(%q): %s", path, err)
|
||
|
}
|
||
|
if u.Host == "" {
|
||
|
u.Scheme = vc.Scheme
|
||
|
u.Host = vc.Host
|
||
|
}
|
||
|
|
||
|
client := NewClient(u, c.k)
|
||
|
client.Namespace = "urn:" + namespace
|
||
|
if cert := c.Certificate(); cert != nil {
|
||
|
client.SetCertificate(*cert)
|
||
|
}
|
||
|
|
||
|
// Copy the trusted thumbprints
|
||
|
c.hostsMu.Lock()
|
||
|
for k, v := range c.hosts {
|
||
|
client.hosts[k] = v
|
||
|
}
|
||
|
c.hostsMu.Unlock()
|
||
|
|
||
|
// Copy the cookies
|
||
|
client.Client.Jar.SetCookies(u, c.Client.Jar.Cookies(u))
|
||
|
|
||
|
// Set SOAP Header cookie
|
||
|
for _, cookie := range client.Jar.Cookies(u) {
|
||
|
if cookie.Name == SessionCookieName {
|
||
|
client.cookie = cookie.Value
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Copy any query params (e.g. GOVMOMI_TUNNEL_PROXY_PORT used in testing)
|
||
|
client.u.RawQuery = vc.RawQuery
|
||
|
|
||
|
return client
|
||
|
}
|
||
|
|
||
|
// SetRootCAs defines the set of root certificate authorities
|
||
|
// that clients use when verifying server certificates.
|
||
|
// By default TLS uses the host's root CA set.
|
||
|
//
|
||
|
// See: http.Client.Transport.TLSClientConfig.RootCAs
|
||
|
func (c *Client) SetRootCAs(file string) error {
|
||
|
pool := x509.NewCertPool()
|
||
|
|
||
|
for _, name := range filepath.SplitList(file) {
|
||
|
pem, err := ioutil.ReadFile(name)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if ok := pool.AppendCertsFromPEM(pem); !ok {
|
||
|
return errInvalidCACertificate{
|
||
|
File: name,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
c.t.TLSClientConfig.RootCAs = pool
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Add default https port if missing
|
||
|
func hostAddr(addr string) string {
|
||
|
_, port := splitHostPort(addr)
|
||
|
if port == "" {
|
||
|
return addr + ":443"
|
||
|
}
|
||
|
return addr
|
||
|
}
|
||
|
|
||
|
// SetThumbprint sets the known certificate thumbprint for the given host.
|
||
|
// A custom DialTLS function is used to support thumbprint based verification.
|
||
|
// We first try tls.Dial with the default tls.Config, only falling back to thumbprint verification
|
||
|
// if it fails with an x509.UnknownAuthorityError or x509.HostnameError
|
||
|
//
|
||
|
// See: http.Client.Transport.DialTLS
|
||
|
func (c *Client) SetThumbprint(host string, thumbprint string) {
|
||
|
host = hostAddr(host)
|
||
|
|
||
|
c.hostsMu.Lock()
|
||
|
if thumbprint == "" {
|
||
|
delete(c.hosts, host)
|
||
|
} else {
|
||
|
c.hosts[host] = thumbprint
|
||
|
}
|
||
|
c.hostsMu.Unlock()
|
||
|
}
|
||
|
|
||
|
// Thumbprint returns the certificate thumbprint for the given host if known to this client.
|
||
|
func (c *Client) Thumbprint(host string) string {
|
||
|
host = hostAddr(host)
|
||
|
c.hostsMu.Lock()
|
||
|
defer c.hostsMu.Unlock()
|
||
|
return c.hosts[host]
|
||
|
}
|
||
|
|
||
|
// LoadThumbprints from file with the give name.
|
||
|
// If name is empty or name does not exist this function will return nil.
|
||
|
func (c *Client) LoadThumbprints(file string) error {
|
||
|
if file == "" {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
for _, name := range filepath.SplitList(file) {
|
||
|
err := c.loadThumbprints(name)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *Client) loadThumbprints(name string) error {
|
||
|
f, err := os.Open(name)
|
||
|
if err != nil {
|
||
|
if os.IsNotExist(err) {
|
||
|
return nil
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
scanner := bufio.NewScanner(f)
|
||
|
|
||
|
for scanner.Scan() {
|
||
|
e := strings.SplitN(scanner.Text(), " ", 2)
|
||
|
if len(e) != 2 {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
c.SetThumbprint(e[0], e[1])
|
||
|
}
|
||
|
|
||
|
_ = f.Close()
|
||
|
|
||
|
return scanner.Err()
|
||
|
}
|
||
|
|
||
|
// ThumbprintSHA1 returns the SHA-1 digest of cert.Raw as colon-separated,
// upper-case hex byte pairs (e.g. "A9:99:..."). This is the same format used
// by the SDK and Client.SetThumbprint.
//
// See: SSLVerifyFault.Thumbprint, SessionManagerGenericServiceTicket.Thumbprint, HostConnectSpec.SslThumbprint
func ThumbprintSHA1(cert *x509.Certificate) string {
	digest := sha1.Sum(cert.Raw)

	var sb strings.Builder
	for i, b := range digest {
		if i > 0 {
			sb.WriteByte(':')
		}
		fmt.Fprintf(&sb, "%02X", b)
	}
	return sb.String()
}
|
||
|
|
||
|
func (c *Client) dialTLS(network string, addr string) (net.Conn, error) {
|
||
|
// Would be nice if there was a tls.Config.Verify func,
|
||
|
// see tls.clientHandshakeState.doFullHandshake
|
||
|
|
||
|
conn, err := tls.Dial(network, addr, c.t.TLSClientConfig)
|
||
|
|
||
|
if err == nil {
|
||
|
return conn, nil
|
||
|
}
|
||
|
|
||
|
switch err.(type) {
|
||
|
case x509.UnknownAuthorityError:
|
||
|
case x509.HostnameError:
|
||
|
default:
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
thumbprint := c.Thumbprint(addr)
|
||
|
if thumbprint == "" {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
config := &tls.Config{InsecureSkipVerify: true}
|
||
|
conn, err = tls.Dial(network, addr, config)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
cert := conn.ConnectionState().PeerCertificates[0]
|
||
|
peer := ThumbprintSHA1(cert)
|
||
|
if thumbprint != peer {
|
||
|
_ = conn.Close()
|
||
|
|
||
|
return nil, fmt.Errorf("Host %q thumbprint does not match %q", addr, thumbprint)
|
||
|
}
|
||
|
|
||
|
return conn, nil
|
||
|
}
|
||
|
|
||
|
// splitHostPort is similar to net.SplitHostPort, but returns an empty port
// instead of an error when there is no ":port".
//
// A colon only counts as a port separator when it appears after the last
// "]", so bracketed IPv6 literals such as "[::1]" keep their colons.
func splitHostPort(host string) (string, string) {
	colon := strings.LastIndex(host, ":")
	if colon > strings.LastIndex(host, "]") {
		return host[:colon], host[colon+1:]
	}
	return host, ""
}
|
||
|
|
||
|
const sdkTunnel = "sdkTunnel:8089"
|
||
|
|
||
|
func (c *Client) Certificate() *tls.Certificate {
|
||
|
certs := c.t.TLSClientConfig.Certificates
|
||
|
if len(certs) == 0 {
|
||
|
return nil
|
||
|
}
|
||
|
return &certs[0]
|
||
|
}
|
||
|
|
||
|
func (c *Client) SetCertificate(cert tls.Certificate) {
|
||
|
t := c.Client.Transport.(*http.Transport)
|
||
|
|
||
|
// Extension or HoK certificate
|
||
|
t.TLSClientConfig.Certificates = []tls.Certificate{cert}
|
||
|
}
|
||
|
|
||
|
// Tunnel returns a Client configured to proxy requests through vCenter's http port 80,
|
||
|
// to the SDK tunnel virtual host. Use of the SDK tunnel is required by LoginExtensionByCertificate()
|
||
|
// and optional for other methods.
|
||
|
func (c *Client) Tunnel() *Client {
|
||
|
tunnel := c.NewServiceClient(c.u.Path, c.Namespace)
|
||
|
t := tunnel.Client.Transport.(*http.Transport)
|
||
|
// Proxy to vCenter host on port 80
|
||
|
host := tunnel.u.Hostname()
|
||
|
// Should be no reason to change the default port other than testing
|
||
|
key := "GOVMOMI_TUNNEL_PROXY_PORT"
|
||
|
|
||
|
port := tunnel.URL().Query().Get(key)
|
||
|
if port == "" {
|
||
|
port = os.Getenv(key)
|
||
|
}
|
||
|
|
||
|
if port != "" {
|
||
|
host += ":" + port
|
||
|
}
|
||
|
|
||
|
t.Proxy = http.ProxyURL(&url.URL{
|
||
|
Scheme: "http",
|
||
|
Host: host,
|
||
|
})
|
||
|
|
||
|
// Rewrite url Host to use the sdk tunnel, required for a certificate request.
|
||
|
tunnel.u.Host = sdkTunnel
|
||
|
return tunnel
|
||
|
}
|
||
|
|
||
|
func (c *Client) URL() *url.URL {
|
||
|
urlCopy := *c.u
|
||
|
return &urlCopy
|
||
|
}
|
||
|
|
||
|
type marshaledClient struct {
|
||
|
Cookies []*http.Cookie
|
||
|
URL *url.URL
|
||
|
Insecure bool
|
||
|
}
|
||
|
|
||
|
func (c *Client) MarshalJSON() ([]byte, error) {
|
||
|
m := marshaledClient{
|
||
|
Cookies: c.Jar.Cookies(c.u),
|
||
|
URL: c.u,
|
||
|
Insecure: c.k,
|
||
|
}
|
||
|
|
||
|
return json.Marshal(m)
|
||
|
}
|
||
|
|
||
|
func (c *Client) UnmarshalJSON(b []byte) error {
|
||
|
var m marshaledClient
|
||
|
|
||
|
err := json.Unmarshal(b, &m)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
*c = *NewClient(m.URL, m.Insecure)
|
||
|
c.Jar.SetCookies(m.URL, m.Cookies)
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
type kindContext struct{}
|
||
|
|
||
|
func (c *Client) Do(ctx context.Context, req *http.Request, f func(*http.Response) error) error {
|
||
|
if ctx == nil {
|
||
|
ctx = context.Background()
|
||
|
}
|
||
|
// Create debugging context for this round trip
|
||
|
d := c.d.newRoundTrip()
|
||
|
if d.enabled() {
|
||
|
defer d.done()
|
||
|
}
|
||
|
|
||
|
if c.UserAgent != "" {
|
||
|
req.Header.Set(`User-Agent`, c.UserAgent)
|
||
|
}
|
||
|
|
||
|
if d.enabled() {
|
||
|
d.debugRequest(req)
|
||
|
}
|
||
|
|
||
|
tstart := time.Now()
|
||
|
res, err := c.Client.Do(req.WithContext(ctx))
|
||
|
tstop := time.Now()
|
||
|
|
||
|
if d.enabled() {
|
||
|
var name string
|
||
|
if kind, ok := ctx.Value(kindContext{}).(HasFault); ok {
|
||
|
name = fmt.Sprintf("%T", kind)
|
||
|
} else {
|
||
|
name = fmt.Sprintf("%s %s", req.Method, req.URL)
|
||
|
}
|
||
|
d.logf("%6dms (%s)", tstop.Sub(tstart)/time.Millisecond, name)
|
||
|
}
|
||
|
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
defer res.Body.Close()
|
||
|
|
||
|
if d.enabled() {
|
||
|
d.debugResponse(res)
|
||
|
}
|
||
|
|
||
|
return f(res)
|
||
|
}
|
||
|
|
||
|
// Signer can be implemented by soap.Header.Security to sign requests.
|
||
|
// If the soap.Header.Security field is set to an implementation of Signer via WithHeader(),
|
||
|
// then Client.RoundTrip will call Sign() to marshal the SOAP request.
|
||
|
type Signer interface {
|
||
|
Sign(Envelope) ([]byte, error)
|
||
|
}
|
||
|
|
||
|
type headerContext struct{}
|
||
|
|
||
|
// WithHeader can be used to modify the outgoing request soap.Header fields.
|
||
|
func (c *Client) WithHeader(ctx context.Context, header Header) context.Context {
|
||
|
return context.WithValue(ctx, headerContext{}, header)
|
||
|
}
|
||
|
|
||
|
func (c *Client) RoundTrip(ctx context.Context, reqBody, resBody HasFault) error {
|
||
|
var err error
|
||
|
var b []byte
|
||
|
|
||
|
reqEnv := Envelope{Body: reqBody}
|
||
|
resEnv := Envelope{Body: resBody}
|
||
|
|
||
|
h, ok := ctx.Value(headerContext{}).(Header)
|
||
|
if !ok {
|
||
|
h = Header{}
|
||
|
}
|
||
|
|
||
|
// We added support for OperationID before soap.Header was exported.
|
||
|
if id, ok := ctx.Value(types.ID{}).(string); ok {
|
||
|
h.ID = id
|
||
|
}
|
||
|
|
||
|
h.Cookie = c.cookie
|
||
|
if h.Cookie != "" || h.ID != "" || h.Security != nil {
|
||
|
reqEnv.Header = &h // XML marshal header only if a field is set
|
||
|
}
|
||
|
|
||
|
if signer, ok := h.Security.(Signer); ok {
|
||
|
b, err = signer.Sign(reqEnv)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
} else {
|
||
|
b, err = xml.Marshal(reqEnv)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
rawReqBody := io.MultiReader(strings.NewReader(xml.Header), bytes.NewReader(b))
|
||
|
req, err := http.NewRequest("POST", c.u.String(), rawReqBody)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
|
||
|
req.Header.Set(`Content-Type`, `text/xml; charset="utf-8"`)
|
||
|
|
||
|
action := h.Action
|
||
|
if action == "" {
|
||
|
action = fmt.Sprintf("%s/%s", c.Namespace, c.Version)
|
||
|
}
|
||
|
req.Header.Set(`SOAPAction`, action)
|
||
|
|
||
|
return c.Do(context.WithValue(ctx, kindContext{}, resBody), req, func(res *http.Response) error {
|
||
|
switch res.StatusCode {
|
||
|
case http.StatusOK:
|
||
|
// OK
|
||
|
case http.StatusInternalServerError:
|
||
|
// Error, but typically includes a body explaining the error
|
||
|
default:
|
||
|
return errors.New(res.Status)
|
||
|
}
|
||
|
|
||
|
dec := xml.NewDecoder(res.Body)
|
||
|
dec.TypeFunc = types.TypeFunc()
|
||
|
err = dec.Decode(&resEnv)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if f := resBody.Fault(); f != nil {
|
||
|
return WrapSoapFault(f)
|
||
|
}
|
||
|
|
||
|
return err
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (c *Client) CloseIdleConnections() {
|
||
|
c.t.CloseIdleConnections()
|
||
|
}
|
||
|
|
||
|
// ParseURL wraps url.Parse to rewrite the URL.Host field
|
||
|
// In the case of VM guest uploads or NFC lease URLs, a Host
|
||
|
// field with a value of "*" is rewritten to the Client's URL.Host.
|
||
|
func (c *Client) ParseURL(urlStr string) (*url.URL, error) {
|
||
|
u, err := url.Parse(urlStr)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
host, _ := splitHostPort(u.Host)
|
||
|
if host == "*" {
|
||
|
// Also use Client's port, to support port forwarding
|
||
|
u.Host = c.URL().Host
|
||
|
}
|
||
|
|
||
|
return u, nil
|
||
|
}
|
||
|
|
||
|
type Upload struct {
|
||
|
Type string
|
||
|
Method string
|
||
|
ContentLength int64
|
||
|
Headers map[string]string
|
||
|
Ticket *http.Cookie
|
||
|
Progress progress.Sinker
|
||
|
}
|
||
|
|
||
|
var DefaultUpload = Upload{
|
||
|
Type: "application/octet-stream",
|
||
|
Method: "PUT",
|
||
|
}
|
||
|
|
||
|
// Upload PUTs the local file to the given URL
|
||
|
func (c *Client) Upload(ctx context.Context, f io.Reader, u *url.URL, param *Upload) error {
|
||
|
var err error
|
||
|
|
||
|
if param.Progress != nil {
|
||
|
pr := progress.NewReader(ctx, param.Progress, f, param.ContentLength)
|
||
|
f = pr
|
||
|
|
||
|
// Mark progress reader as done when returning from this function.
|
||
|
defer func() {
|
||
|
pr.Done(err)
|
||
|
}()
|
||
|
}
|
||
|
|
||
|
req, err := http.NewRequest(param.Method, u.String(), f)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
req = req.WithContext(ctx)
|
||
|
|
||
|
req.ContentLength = param.ContentLength
|
||
|
req.Header.Set("Content-Type", param.Type)
|
||
|
|
||
|
for k, v := range param.Headers {
|
||
|
req.Header.Add(k, v)
|
||
|
}
|
||
|
|
||
|
if param.Ticket != nil {
|
||
|
req.AddCookie(param.Ticket)
|
||
|
}
|
||
|
|
||
|
res, err := c.Client.Do(req)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
defer res.Body.Close()
|
||
|
|
||
|
switch res.StatusCode {
|
||
|
case http.StatusOK:
|
||
|
case http.StatusCreated:
|
||
|
default:
|
||
|
err = errors.New(res.Status)
|
||
|
}
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// UploadFile PUTs the local file to the given URL
|
||
|
func (c *Client) UploadFile(ctx context.Context, file string, u *url.URL, param *Upload) error {
|
||
|
if param == nil {
|
||
|
p := DefaultUpload // Copy since we set ContentLength
|
||
|
param = &p
|
||
|
}
|
||
|
|
||
|
s, err := os.Stat(file)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
f, err := os.Open(file)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer f.Close()
|
||
|
|
||
|
param.ContentLength = s.Size()
|
||
|
|
||
|
return c.Upload(ctx, f, u, param)
|
||
|
}
|
||
|
|
||
|
type Download struct {
|
||
|
Method string
|
||
|
Headers map[string]string
|
||
|
Ticket *http.Cookie
|
||
|
Progress progress.Sinker
|
||
|
Writer io.Writer
|
||
|
}
|
||
|
|
||
|
var DefaultDownload = Download{
|
||
|
Method: "GET",
|
||
|
}
|
||
|
|
||
|
// DownloadRequest wraps http.Client.Do, returning the http.Response without checking its StatusCode
|
||
|
func (c *Client) DownloadRequest(ctx context.Context, u *url.URL, param *Download) (*http.Response, error) {
|
||
|
req, err := http.NewRequest(param.Method, u.String(), nil)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
req = req.WithContext(ctx)
|
||
|
|
||
|
for k, v := range param.Headers {
|
||
|
req.Header.Add(k, v)
|
||
|
}
|
||
|
|
||
|
if param.Ticket != nil {
|
||
|
req.AddCookie(param.Ticket)
|
||
|
}
|
||
|
|
||
|
return c.Client.Do(req)
|
||
|
}
|
||
|
|
||
|
// Download GETs the remote file from the given URL
|
||
|
func (c *Client) Download(ctx context.Context, u *url.URL, param *Download) (io.ReadCloser, int64, error) {
|
||
|
res, err := c.DownloadRequest(ctx, u, param)
|
||
|
if err != nil {
|
||
|
return nil, 0, err
|
||
|
}
|
||
|
|
||
|
switch res.StatusCode {
|
||
|
case http.StatusOK:
|
||
|
default:
|
||
|
err = errors.New(res.Status)
|
||
|
}
|
||
|
|
||
|
if err != nil {
|
||
|
return nil, 0, err
|
||
|
}
|
||
|
|
||
|
r := res.Body
|
||
|
|
||
|
return r, res.ContentLength, nil
|
||
|
}
|
||
|
|
||
|
func (c *Client) WriteFile(ctx context.Context, file string, src io.Reader, size int64, s progress.Sinker, w io.Writer) error {
|
||
|
var err error
|
||
|
|
||
|
r := src
|
||
|
|
||
|
fh, err := os.Create(file)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if s != nil {
|
||
|
pr := progress.NewReader(ctx, s, src, size)
|
||
|
src = pr
|
||
|
|
||
|
// Mark progress reader as done when returning from this function.
|
||
|
defer func() {
|
||
|
pr.Done(err)
|
||
|
}()
|
||
|
}
|
||
|
|
||
|
if w == nil {
|
||
|
w = fh
|
||
|
} else {
|
||
|
w = io.MultiWriter(w, fh)
|
||
|
}
|
||
|
|
||
|
_, err = io.Copy(w, r)
|
||
|
|
||
|
cerr := fh.Close()
|
||
|
|
||
|
if err == nil {
|
||
|
err = cerr
|
||
|
}
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// DownloadFile GETs the given URL to a local file
|
||
|
func (c *Client) DownloadFile(ctx context.Context, file string, u *url.URL, param *Download) error {
|
||
|
var err error
|
||
|
if param == nil {
|
||
|
param = &DefaultDownload
|
||
|
}
|
||
|
|
||
|
rc, contentLength, err := c.Download(ctx, u, param)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return c.WriteFile(ctx, file, rc, contentLength, param.Progress, param.Writer)
|
||
|
}
|