diff --git a/src/Core/ServiceWrapper/Main.cs b/src/Core/ServiceWrapper/Main.cs index 4602453..931b285 100644 --- a/src/Core/ServiceWrapper/Main.cs +++ b/src/Core/ServiceWrapper/Main.cs @@ -8,6 +8,9 @@ using System.Runtime.InteropServices; using System.ServiceProcess; using System.Text; using System.Threading; +#if VNEXT +using System.Threading.Tasks; +#endif using log4net; using log4net.Appender; using log4net.Config; @@ -199,23 +202,54 @@ namespace winsw HandleFileCopies(); // handle downloads - foreach (Download d in _descriptor.Downloads) +#if VNEXT + List downloads = _descriptor.Downloads; + Task[] tasks = new Task[downloads.Count]; + for (int i = 0; i < downloads.Count; i++) { - string downloadMsg = "Downloading: " + d.From + " to " + d.To + ". failOnError=" + d.FailOnError; - LogEvent(downloadMsg); - Log.Info(downloadMsg); + Download download = downloads[i]; + string downloadMessage = $"Downloading: {download.From} to {download.To}. failOnError={download.FailOnError.ToString()}"; + LogEvent(downloadMessage); + Log.Info(downloadMessage); + tasks[i] = download.PerformAsync(); + } + + Task.WhenAll(tasks); + for (int i = 0; i < tasks.Length; i++) + { + if (tasks[i].IsFaulted) + { + Download download = downloads[i]; + string errorMessage = $"Failed to download {download.From} to {download.To}"; + AggregateException exception = tasks[i].Exception!; + LogEvent($"{errorMessage}. {exception.Message}"); + Log.Error(errorMessage, exception); + + // TODO: move this code into the download logic + if (download.FailOnError) + { + throw new IOException(errorMessage, exception); + } + } + } +#else + foreach (Download download in _descriptor.Downloads) + { + string downloadMessage = $"Downloading: {download.From} to {download.To}. failOnError={download.FailOnError.ToString()}"; + LogEvent(downloadMessage); + Log.Info(downloadMessage); try { - d.Perform(); + download.Perform(); } catch (Exception e) { - string errorMessage = "Failed to download " + d.From + " to " + d.To; - LogEvent(errorMessage + ". " + e.Message); + string errorMessage = $"Failed to download {download.From} to {download.To}"; + LogEvent($"{errorMessage}. {e.Message}"); Log.Error(errorMessage, e); // TODO: move this code into the download logic - if (d.FailOnError) + if (download.FailOnError) { throw new IOException(errorMessage, e); } @@ -223,6 +257,7 @@ namespace winsw // Else just keep going } } +#endif string? startarguments = _descriptor.Startarguments; diff --git a/src/Core/WinSWCore/Download.cs b/src/Core/WinSWCore/Download.cs index 9829995..197e7a3 100755 --- a/src/Core/WinSWCore/Download.cs +++ b/src/Core/WinSWCore/Download.cs @@ -2,6 +2,9 @@ using System; using System.IO; using System.Net; using System.Text; +#if VNEXT +using System.Threading.Tasks; +#endif using System.Xml; using winsw.Util; @@ -103,9 +106,13 @@ namespace winsw /// /// Download failure. FailOnError flag should be processed outside. /// +#if VNEXT + public async Task PerformAsync() +#else public void Perform() +#endif { - WebRequest req = WebRequest.Create(From); + WebRequest request = WebRequest.Create(From); switch (Auth) { @@ -114,43 +121,57 @@ namespace winsw break; case AuthType.sspi: - req.UseDefaultCredentials = true; - req.PreAuthenticate = true; - req.Credentials = CredentialCache.DefaultCredentials; + request.UseDefaultCredentials = true; + request.PreAuthenticate = true; + request.Credentials = CredentialCache.DefaultCredentials; break; case AuthType.basic: - SetBasicAuthHeader(req, Username!, Password!); + SetBasicAuthHeader(request, Username!, Password!); break; default: throw new WebException("Code defect. Unsupported authentication type: " + Auth); } - WebResponse rsp = req.GetResponse(); - FileStream tmpstream = new FileStream(To + ".tmp", FileMode.Create); - CopyStream(rsp.GetResponseStream(), tmpstream); - // only after we successfully downloaded a file, overwrite the existing one + string tmpFilePath = To + ".tmp"; +#if VNEXT + using (WebResponse response = await request.GetResponseAsync()) +#else + using (WebResponse response = request.GetResponse()) +#endif + using (Stream responseStream = response.GetResponseStream()) + using (FileStream tmpStream = new FileStream(tmpFilePath, FileMode.Create)) + { +#if VNEXT + await responseStream.CopyToAsync(tmpStream); +#elif NET20 + CopyStream(responseStream, tmpStream); +#else + responseStream.CopyTo(tmpStream); +#endif + } + +#if NETCOREAPP + File.Move(tmpFilePath, To, true); +#else if (File.Exists(To)) File.Delete(To); - File.Move(To + ".tmp", To); + File.Move(tmpFilePath, To); +#endif } +#if NET20 - private static void CopyStream(Stream i, Stream o) + private static void CopyStream(Stream source, Stream destination) { - byte[] buf = new byte[8192]; - while (true) + byte[] buffer = new byte[8192]; + int read; + while ((read = source.Read(buffer, 0, buffer.Length)) != 0) { - int len = i.Read(buf, 0, buf.Length); - if (len <= 0) - break; - - o.Write(buf, 0, len); + destination.Write(buffer, 0, read); } - - i.Close(); - o.Close(); } +#endif } }