using System;
using System.IO;
using System.Net;
#if !VNEXT
using System.Reflection;
#endif
using System.Text;
#if VNEXT
using System.Threading.Tasks;
#endif
using System.Xml;
using log4net;
using winsw.Util;

namespace winsw
{
    /// <summary>
    /// Specify the download activities prior to the launch.
    /// This enables self-updating services.
    /// </summary>
    public class Download
    {
        /// <summary>
        /// Authentication scheme used when requesting <see cref="From"/>.
        /// </summary>
        public enum AuthType
        {
            /// <summary>No authentication.</summary>
            none = 0,

            /// <summary>Sends <see cref="CredentialCache.DefaultCredentials"/> (the security context the process runs under).</summary>
            sspi,

            /// <summary>HTTP Basic authentication using <see cref="Username"/> and <see cref="Password"/>.</summary>
            basic
        }

        private static readonly ILog Logger = LogManager.GetLogger(typeof(Download));

        /// <summary>Source URL of the download.</summary>
        public readonly string From;

        /// <summary>Destination file path.</summary>
        public readonly string To;

        /// <summary>Authentication scheme used for the request.</summary>
        public readonly AuthType Auth;

        /// <summary>User name for <see cref="AuthType.basic"/>.</summary>
        public readonly string? Username;

        /// <summary>Password for <see cref="AuthType.basic"/>.</summary>
        public readonly string? Password;

        /// <summary>Permits Basic authentication for a source URL that is not HTTPS.</summary>
        public readonly bool UnsecureAuth;

        /// <summary>
        /// Whether a download failure should be fatal. Not evaluated by this class; the caller of
        /// <c>Perform</c>/<c>PerformAsync</c> is expected to act on it.
        /// </summary>
        public readonly bool FailOnError;

        /// <summary>
        /// Optional proxy, optionally with embedded credentials
        /// (e.g. <c>http://user:password@proxy:8080/</c>); parsed by <see cref="CustomProxyInformation"/>.
        /// </summary>
        public readonly string? Proxy;

        /// <summary>Short human-readable identifier used in log and error messages.</summary>
        public string ShortId => $"(download from {From})";

        static Download()
        {
#if NET461
            // If your app runs on .NET Framework 4.7 or later versions, but targets an earlier version
            AppContext.SetSwitch("Switch.System.Net.DontEnableSystemDefaultTlsVersions", false);
#elif !VNEXT
            // If your app runs on .NET Framework 4.6, but targets an earlier version
            Type.GetType("System.AppContext")?.InvokeMember("SetSwitch", BindingFlags.InvokeMethod | BindingFlags.Public | BindingFlags.Static, null, null, new object[] { "Switch.System.Net.DontEnableSchUseStrongCrypto", false });

            // Raw enum values (Tls12 = 3072, Tls11 = 768); presumably because the targeted
            // framework's SecurityProtocolType lacks these named members.
            const SecurityProtocolType Tls12 = (SecurityProtocolType)0x00000C00;
            const SecurityProtocolType Tls11 = (SecurityProtocolType)0x00000300;

            // Windows 7 and Windows Server 2008 R2
            if (Environment.OSVersion.Version.Major == 6 && Environment.OSVersion.Version.Minor == 1)
            {
                try
                {
                    ServicePointManager.SecurityProtocol |= Tls11 | Tls12;
                    Logger.Info("TLS 1.1/1.2 enabled");
                }
                catch (NotSupportedException)
                {
                    // The runtime rejected the protocol flags; continue without them.
                    Logger.Info("TLS 1.1/1.2 disabled");
                }
            }
#endif
        }

        // internal
        /// <summary>
        /// Creates download settings programmatically. Unlike the XML constructor, this performs
        /// no validation of the Basic-auth settings.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the "// internal" marker suggests this constructor is meant to be internal
        /// (presumably public only for tests) — confirm.
        /// </remarks>
        public Download(
            string from,
            string to,
            bool failOnError = false,
            AuthType auth = AuthType.none,
            string? username = null,
            string? password = null,
            bool unsecureAuth = false,
            string? proxy = null)
        {
            From = from;
            To = to;
            FailOnError = failOnError;
            Proxy = proxy;
            Auth = auth;
            Username = username;
            Password = password;
            UnsecureAuth = unsecureAuth;
        }

        /// <summary>
        /// Constructs the download settings from the XML entry.
        /// </summary>
        /// <param name="n">XML element</param>
        /// <exception cref="InvalidDataException">The required attribute is missing or the configuration is invalid</exception>
        internal Download(XmlElement n)
        {
            From = XmlHelper.SingleAttribute(n, "from");
            To = XmlHelper.SingleAttribute(n, "to");

            // All arguments below are optional
            FailOnError = XmlHelper.SingleAttribute(n, "failOnError", false);
            Proxy = XmlHelper.SingleAttribute(n, "proxy", null);

            Auth = XmlHelper.EnumAttribute(n, "auth", AuthType.none);
            Username = XmlHelper.SingleAttribute(n, "user", null);
            Password = XmlHelper.SingleAttribute(n, "password", null);
            UnsecureAuth = XmlHelper.SingleAttribute(n, "unsecureAuth", false);

            if (Auth == AuthType.basic)
            {
                // Allow it only for HTTPS or for UnsecureAuth.
                // URI schemes are case-insensitive, so "HTTPS:" counts as secure as well.
                if (!From.StartsWith("https:", StringComparison.OrdinalIgnoreCase) && !UnsecureAuth)
                {
                    throw new InvalidDataException("Warning: you're sending your credentials in clear text to the server " + ShortId + ". If you really want this you must enable 'unsecureAuth' in the configuration");
                }

                // Also fail if there is no user/password
                if (Username is null)
                {
                    throw new InvalidDataException("Basic Auth is enabled, but username is not specified " + ShortId);
                }

                if (Password is null)
                {
                    throw new InvalidDataException("Basic Auth is enabled, but password is not specified " + ShortId);
                }
            }
        }

        /// <summary>
        /// Writes the Basic <c>Authorization</c> header directly onto the request so the credentials
        /// are sent with the first request rather than only after a challenge.
        /// </summary>
        // Source: http://stackoverflow.com/questions/2764577/forcing-basic-authentication-in-webrequest
        private static void SetBasicAuthHeader(WebRequest request, string username, string password)
        {
            string authInfo = username + ":" + password;
            authInfo = Convert.ToBase64String(Encoding.GetEncoding("ISO-8859-1").GetBytes(authInfo));
            request.Headers["Authorization"] = "Basic " + authInfo;
        }

        /// <summary>
        /// Downloads the requested file and saves it to the specified target.
        /// </summary>
        /// <remarks>
        /// For HTTP(S) requests whose target file already exists, an If-Modified-Since header is sent
        /// with the file's last write time; a 304 Not Modified response leaves the target untouched.
        /// The payload is first written to <c>To + ".tmp"</c> and then moved over the target.
        /// </remarks>
        /// <exception cref="WebException">
        /// Download failure. FailOnError flag should be processed outside.
        /// </exception>
#if VNEXT
        public async Task PerformAsync()
#else
        public void Perform()
#endif
        {
            WebRequest request = WebRequest.Create(From);
            if (!string.IsNullOrEmpty(Proxy))
            {
                CustomProxyInformation proxyInformation = new CustomProxyInformation(Proxy);
                if (proxyInformation.Credentials != null)
                {
                    request.Proxy = new WebProxy(proxyInformation.ServerAddress, false, null, proxyInformation.Credentials);
                }
                else
                {
                    request.Proxy = new WebProxy(proxyInformation.ServerAddress);
                }
            }

            switch (Auth)
            {
                case AuthType.none:
                    // Do nothing
                    break;

                case AuthType.sspi:
                    request.UseDefaultCredentials = true;
                    request.PreAuthenticate = true;
                    request.Credentials = CredentialCache.DefaultCredentials;
                    break;

                case AuthType.basic:
                    // Non-null here: the XML constructor rejects basic auth without user/password.
                    SetBasicAuthHeader(request, Username!, Password!);
                    break;

                default:
                    throw new WebException("Code defect. Unsupported authentication type: " + Auth);
            }

            // Only HTTP requests can be made conditional on the target's modification time.
            bool supportsIfModifiedSince = false;
            if (request is HttpWebRequest httpRequest && File.Exists(To))
            {
                supportsIfModifiedSince = true;
                httpRequest.IfModifiedSince = File.GetLastWriteTime(To);
            }

            DateTime lastModified = default;
            string tmpFilePath = To + ".tmp";
            try
            {
#if VNEXT
                using (WebResponse response = await request.GetResponseAsync())
#else
                using (WebResponse response = request.GetResponse())
#endif
                using (Stream responseStream = response.GetResponseStream())
                using (FileStream tmpStream = new FileStream(tmpFilePath, FileMode.Create))
                {
                    if (supportsIfModifiedSince)
                    {
                        lastModified = ((HttpWebResponse)response).LastModified;
                    }

#if VNEXT
                    await responseStream.CopyToAsync(tmpStream);
#elif NET20
                    CopyStream(responseStream, tmpStream);
#else
                    responseStream.CopyTo(tmpStream);
#endif
                }

                // The temporary stream is closed at this point, so it can be moved over the target.
                FileHelper.MoveOrReplaceFile(tmpFilePath, To);

                if (supportsIfModifiedSince)
                {
                    File.SetLastWriteTime(To, lastModified);
                }
            }
            catch (WebException e)
            {
                // e.Response is null for failures that never produced a response (DNS, refused
                // connection, timeout), so type-test it instead of casting; a blind cast would throw
                // NullReferenceException and hide the original WebException from the caller.
                if (supportsIfModifiedSince &&
                    e.Response is HttpWebResponse errorResponse &&
                    errorResponse.StatusCode == HttpStatusCode.NotModified)
                {
                    Logger.Info($"Skipped downloading unmodified resource '{From}'");
                }
                else
                {
                    throw;
                }
            }
        }

#if NET20
        /// <summary>
        /// Copies <paramref name="source"/> to <paramref name="destination"/> in 8 KiB chunks
        /// (replacement for <c>Stream.CopyTo</c> in this build configuration).
        /// </summary>
        private static void CopyStream(Stream source, Stream destination)
        {
            byte[] buffer = new byte[8192];
            int read;
            while ((read = source.Read(buffer, 0, buffer.Length)) != 0)
            {
                destination.Write(buffer, 0, read);
            }
        }
#endif
    }

    /// <summary>
    /// Splits a proxy specification of the form <c>[scheme://][user[:password]@]address</c>
    /// into the proxy address and optional credentials.
    /// </summary>
    public class CustomProxyInformation
    {
        /// <summary>Proxy address with any embedded <c>user:password@</c> part removed.</summary>
        public string ServerAddress { get; set; }

        /// <summary>Credentials embedded in the specification, or <c>null</c> if there were none.</summary>
        public NetworkCredential? Credentials { get; set; }

        /// <param name="proxy">Proxy specification, e.g. <c>http://user:pass@proxy.example.com:8080/</c>.</param>
        public CustomProxyInformation(string proxy)
        {
            // The last '@' ends the credentials, so a password may itself contain '@'.
            int credsTo = proxy.LastIndexOf('@');

            // Credentials start right after "scheme://"; a scheme-less value such as
            // "user:pass@host:8080" has them at the very beginning.
            int schemeEnd = proxy.IndexOf("://", StringComparison.Ordinal);
            int credsFrom = schemeEnd < 0 ? 0 : schemeEnd + 3;

            if (credsTo < credsFrom)
            {
                // No '@', or an '@' that precedes the scheme separator: no embedded credentials.
                ServerAddress = proxy;
                return;
            }

            string completeCredsStr = proxy.Substring(credsFrom, credsTo - credsFrom);

            // Split at the first ':' so the password may contain ':'; without any ':' the whole
            // value is the user name and the password is empty.
            int credsSeparator = completeCredsStr.IndexOf(':');
            string username = credsSeparator < 0 ? completeCredsStr : completeCredsStr.Substring(0, credsSeparator);
            string password = credsSeparator < 0 ? string.Empty : completeCredsStr.Substring(credsSeparator + 1);
            Credentials = new NetworkCredential(username, password);

            // Cut out exactly "user:password@" at its parsed position (string.Replace would also
            // strip any other occurrence of the same text).
            ServerAddress = proxy.Remove(credsFrom, credsTo - credsFrom + 1);
        }
    }
}