diff --git a/src/WinSW.Core/AssemblyInfo.cs b/src/WinSW.Core/AssemblyInfo.cs index 3e8e4d2..3b0a125 100644 --- a/src/WinSW.Core/AssemblyInfo.cs +++ b/src/WinSW.Core/AssemblyInfo.cs @@ -1,3 +1,4 @@ using System.Runtime.CompilerServices; [assembly: InternalsVisibleTo("WinSW")] +[assembly: InternalsVisibleTo("WinSW.Tests")] diff --git a/src/WinSW.Core/Native/Credentials.cs b/src/WinSW.Core/Native/Credentials.cs new file mode 100644 index 0000000..025ac0e --- /dev/null +++ b/src/WinSW.Core/Native/Credentials.cs @@ -0,0 +1,103 @@ +using System; +using System.ComponentModel; +using System.Runtime.InteropServices; +using static WinSW.Native.CredentialApis; + +namespace WinSW.Native +{ + internal static class Credentials + { + internal static void PropmtForCredentialsDialog(ref string? userName, ref string? password, string caption, string message) + { + userName ??= string.Empty; + password ??= string.Empty; + + int inBufferSize = 0; + _ = CredPackAuthenticationBuffer( + 0, + userName, + password, + IntPtr.Zero, + ref inBufferSize); + + IntPtr inBuffer = Marshal.AllocCoTaskMem(inBufferSize); + try + { + if (!CredPackAuthenticationBuffer( + 0, + userName, + password, + inBuffer, + ref inBufferSize)) + { + Throw.Command.Win32Exception("Failed to pack auth buffer."); + } + + CREDUI_INFO info = new CREDUI_INFO + { + Size = Marshal.SizeOf(typeof(CREDUI_INFO)), + CaptionText = caption, + MessageText = message, + }; + uint authPackage = 0; + bool save = false; + int error = CredUIPromptForWindowsCredentials( + info, + 0, + ref authPackage, + inBuffer, + inBufferSize, + out IntPtr outBuffer, + out uint outBufferSize, + ref save, + CREDUIWIN_GENERIC); + + if (error != Errors.ERROR_SUCCESS) + { + throw new Win32Exception(error); + } + + try + { + int userNameLength = 0; + int passwordLength = 0; + _ = CredUnPackAuthenticationBuffer( + 0, + outBuffer, + outBufferSize, + null, + ref userNameLength, + default, + default, + null, + ref passwordLength); + + userName = userNameLength == 0 ? null : new string('\0', userNameLength - 1); + password = passwordLength == 0 ? null : new string('\0', passwordLength - 1); + + if (!CredUnPackAuthenticationBuffer( + 0, + outBuffer, + outBufferSize, + userName, + ref userNameLength, + default, + default, + password, + ref passwordLength)) + { + Throw.Command.Win32Exception("Failed to unpack auth buffer."); + } + } + finally + { + Marshal.FreeCoTaskMem(outBuffer); + } + } + finally + { + Marshal.FreeCoTaskMem(inBuffer); + } + } + } +} diff --git a/src/WinSW.Core/Native/Service.cs b/src/WinSW.Core/Native/Service.cs index 41ccf4b..6c4d895 100644 --- a/src/WinSW.Core/Native/Service.cs +++ b/src/WinSW.Core/Native/Service.cs @@ -76,24 +76,6 @@ namespace WinSW.Native string? username, string? password) { - int arrayLength = 1; - for (int i = 0; i < dependencies.Length; i++) - { - arrayLength += dependencies[i].Length + 1; - } - - StringBuilder? array = null; - if (dependencies.Length != 0) - { - array = new StringBuilder(arrayLength); - for (int i = 0; i < dependencies.Length; i++) - { - _ = array.Append(dependencies[i]).Append('\0'); - } - - _ = array.Append('\0'); - } - IntPtr handle = ServiceApis.CreateService( this.handle, serviceName, @@ -105,7 +87,7 @@ namespace WinSW.Native executablePath, default, default, - array, + Service.GetNativeDependencies(dependencies), username, password); if (handle == IntPtr.Zero) @@ -171,6 +153,29 @@ namespace WinSW.Native } } + internal static StringBuilder? GetNativeDependencies(string[] dependencies) + { + int arrayLength = 1; + for (int i = 0; i < dependencies.Length; i++) + { + arrayLength += dependencies[i].Length + 1; + } + + StringBuilder? array = null; + if (dependencies.Length != 0) + { + array = new StringBuilder(arrayLength); + for (int i = 0; i < dependencies.Length; i++) + { + _ = array.Append(dependencies[i]).Append('\0'); + } + + _ = array.Append('\0'); + } + + return array; + } + /// internal void SetStatus(IntPtr statusHandle, ServiceControllerStatus state) { @@ -192,7 +197,8 @@ namespace WinSW.Native /// internal void ChangeConfig( string displayName, - ServiceStartMode startMode) + ServiceStartMode startMode, + string[] dependencies) { if (!ChangeServiceConfig( this.handle, @@ -202,7 +208,7 @@ namespace WinSW.Native null, null, IntPtr.Zero, - null, + GetNativeDependencies(dependencies), null, null, displayName)) diff --git a/src/WinSW.Core/ServiceDescriptor.cs b/src/WinSW.Core/ServiceDescriptor.cs index 99b0e33..14ca121 100644 --- a/src/WinSW.Core/ServiceDescriptor.cs +++ b/src/WinSW.Core/ServiceDescriptor.cs @@ -21,6 +21,8 @@ namespace WinSW private readonly Dictionary environmentVariables; + internal static ServiceDescriptor? TestDescriptor; + public static DefaultWinSWSettings Defaults { get; } = new DefaultWinSWSettings(); /// @@ -42,34 +44,17 @@ namespace WinSW public ServiceDescriptor() { - // find co-located configuration xml. We search up to the ancestor directories to simplify debugging, - // as well as trimming off ".vshost" suffix (which is used during debugging) - // Get the first parent to go into the recursive loop - string p = this.ExecutablePath; - string baseName = Path.GetFileNameWithoutExtension(p); - if (baseName.EndsWith(".vshost")) + string path = this.ExecutablePath; + string baseName = Path.GetFileNameWithoutExtension(path); + string baseDir = Path.GetDirectoryName(path)!; + + if (!File.Exists(Path.Combine(baseDir, baseName + ".xml"))) { - baseName = baseName.Substring(0, baseName.Length - 7); - } - - DirectoryInfo d = new DirectoryInfo(Path.GetDirectoryName(p)); - while (true) - { - if (File.Exists(Path.Combine(d.FullName, baseName + ".xml"))) - { - break; - } - - if (d.Parent is null) - { - throw new FileNotFoundException("Unable to locate " + baseName + ".xml file within executable directory or any parents"); - } - - d = d.Parent; + throw new FileNotFoundException("Unable to locate " + baseName + ".xml file within executable directory"); } this.BaseName = baseName; - this.BasePath = Path.Combine(d.FullName, this.BaseName); + this.BasePath = Path.Combine(baseDir, baseName); try { @@ -81,7 +66,45 @@ namespace WinSW } // register the base directory as environment variable so that future expansions can refer to this. - Environment.SetEnvironmentVariable("BASE", d.FullName); + Environment.SetEnvironmentVariable("BASE", baseDir); + + // ditto for ID + Environment.SetEnvironmentVariable("SERVICE_ID", this.Id); + + // New name + Environment.SetEnvironmentVariable(WinSWSystem.EnvVarNameExecutablePath, this.ExecutablePath); + + // Also inject system environment variables + Environment.SetEnvironmentVariable(WinSWSystem.EnvVarNameServiceId, this.Id); + + this.environmentVariables = this.LoadEnvironmentVariables(); + } + + /// + public ServiceDescriptor(string path) + { + if (!File.Exists(path)) + { + throw new FileNotFoundException(null, path); + } + + string baseName = Path.GetFileNameWithoutExtension(path); + string baseDir = Path.GetDirectoryName(Path.GetFullPath(path))!; + + this.BaseName = baseName; + this.BasePath = Path.Combine(baseDir, baseName); + + try + { + this.dom.Load(path); + } + catch (XmlException e) + { + throw new InvalidDataException(e.Message, e); + } + + // register the base directory as environment variable so that future expansions can refer to this. + Environment.SetEnvironmentVariable("BASE", baseDir); // ditto for ID Environment.SetEnvironmentVariable("SERVICE_ID", this.Id); @@ -107,6 +130,11 @@ namespace WinSW this.environmentVariables = this.LoadEnvironmentVariables(); } + internal static ServiceDescriptor Create(string? path) + { + return path != null ? new ServiceDescriptor(path) : TestDescriptor ?? new ServiceDescriptor(); + } + public static ServiceDescriptor FromXml(string xml) { var dom = new XmlDocument(); diff --git a/src/WinSW.Tests/Attributes/ElevatedFactAttribute.cs b/src/WinSW.Tests/Attributes/ElevatedFactAttribute.cs index ab4cbe8..41aef3a 100644 --- a/src/WinSW.Tests/Attributes/ElevatedFactAttribute.cs +++ b/src/WinSW.Tests/Attributes/ElevatedFactAttribute.cs @@ -4,9 +4,9 @@ using Xunit; namespace WinSW.Tests { [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] - internal sealed class ElevatedFactAttribute : FactAttribute + public sealed class ElevatedFactAttribute : FactAttribute { - internal ElevatedFactAttribute() + public ElevatedFactAttribute() { if (!Program.IsProcessElevated()) { diff --git a/src/WinSW.Tests/MainTest.cs b/src/WinSW.Tests/MainTest.cs index 1acc901..95add34 100644 --- a/src/WinSW.Tests/MainTest.cs +++ b/src/WinSW.Tests/MainTest.cs @@ -12,10 +12,10 @@ namespace WinSW.Tests { try { - _ = CLITestHelper.CLITest(new[] { "install" }); + _ = CommandLineTestHelper.Test(new[] { "install" }); - using ServiceController controller = new ServiceController(CLITestHelper.Id); - Assert.Equal(CLITestHelper.Name, controller.DisplayName); + using ServiceController controller = new ServiceController(CommandLineTestHelper.Id); + Assert.Equal(CommandLineTestHelper.Name, controller.DisplayName); Assert.False(controller.CanStop); Assert.False(controller.CanShutdown); Assert.False(controller.CanPauseAndContinue); @@ -24,42 +24,18 @@ namespace WinSW.Tests } finally { - _ = CLITestHelper.CLITest(new[] { "uninstall" }); + _ = CommandLineTestHelper.Test(new[] { "uninstall" }); } } [Fact] - public void PrintVersion() + public void FailOnUnknownCommand() { - string expectedVersion = WrapperService.Version.ToString(); - string cliOut = CLITestHelper.CLITest(new[] { "version" }); - Assert.Contains(expectedVersion, cliOut); - } + const string commandName = "unknown"; - [Fact] - public void PrintHelp() - { - string expectedVersion = WrapperService.Version.ToString(); - string cliOut = CLITestHelper.CLITest(new[] { "help" }); + CommandLineTestResult result = CommandLineTestHelper.ErrorTest(new[] { commandName }); - Assert.Contains(expectedVersion, cliOut); - Assert.Contains("start", cliOut); - Assert.Contains("help", cliOut); - Assert.Contains("version", cliOut); - - // TODO: check all commands after the migration of ccommands to enum - } - - [Fact] - public void FailOnUnsupportedCommand() - { - const string commandName = "nonExistentCommand"; - string expectedMessage = "Unknown command: " + commandName; - CLITestResult result = CLITestHelper.CLIErrorTest(new[] { commandName }); - - Assert.True(result.HasException); - Assert.Contains(expectedMessage, result.Out); - Assert.Contains(expectedMessage, result.Exception.Message); + Assert.Equal($"Unrecognized command or argument '{commandName}'\r\n\r\n", result.Error); } /// @@ -68,7 +44,7 @@ namespace WinSW.Tests [Fact] public void ShouldNotPrintLogsForStatusCommand() { - string cliOut = CLITestHelper.CLITest(new[] { "status" }); + string cliOut = CommandLineTestHelper.Test(new[] { "status" }); Assert.Equal("NonExistent" + Environment.NewLine, cliOut); } } diff --git a/src/WinSW.Tests/Util/CLITestHelper.cs b/src/WinSW.Tests/Util/CommandLineTestHelper.cs similarity index 62% rename from src/WinSW.Tests/Util/CLITestHelper.cs rename to src/WinSW.Tests/Util/CommandLineTestHelper.cs index 65cc55b..8a91fdd 100644 --- a/src/WinSW.Tests/Util/CLITestHelper.cs +++ b/src/WinSW.Tests/Util/CommandLineTestHelper.cs @@ -7,7 +7,7 @@ namespace WinSW.Tests.Util /// /// Helper for WinSW CLI testing /// - public static class CLITestHelper + public static class CommandLineTestHelper { public const string Id = "WinSW.Tests"; public const string Name = "WinSW Test Service"; @@ -33,28 +33,29 @@ $@" /// Optional Service descriptor (will be used for initializationpurposes) /// STDOUT if there's no exceptions /// Command failure - public static string CLITest(string[] arguments, ServiceDescriptor descriptor = null) + public static string Test(string[] arguments, ServiceDescriptor descriptor = null) { TextWriter tmpOut = Console.Out; - TextWriter tmpErr = Console.Error; + TextWriter tmpError = Console.Error; using StringWriter swOut = new StringWriter(); - using StringWriter swErr = new StringWriter(); + using StringWriter swError = new StringWriter(); Console.SetOut(swOut); - Console.SetError(swErr); + Console.SetError(swError); + ServiceDescriptor.TestDescriptor = descriptor ?? DefaultServiceDescriptor; try { - Program.Run(arguments, descriptor ?? DefaultServiceDescriptor); + _ = Program.Run(arguments); } finally { Console.SetOut(tmpOut); - Console.SetError(tmpErr); + Console.SetError(tmpError); + ServiceDescriptor.TestDescriptor = null; } - Assert.Equal(0, swErr.GetStringBuilder().Length); - Console.Write(swOut.ToString()); + Assert.Equal(string.Empty, swError.ToString()); return swOut.ToString(); } @@ -64,49 +65,44 @@ $@" /// CLI arguments to be passed /// Optional Service descriptor (will be used for initializationpurposes) /// Test results - public static CLITestResult CLIErrorTest(string[] arguments, ServiceDescriptor descriptor = null) + public static CommandLineTestResult ErrorTest(string[] arguments, ServiceDescriptor descriptor = null) { - Exception testEx = null; + Exception exception = null; + TextWriter tmpOut = Console.Out; - TextWriter tmpErr = Console.Error; + TextWriter tmpError = Console.Error; using StringWriter swOut = new StringWriter(); - using StringWriter swErr = new StringWriter(); + using StringWriter swError = new StringWriter(); Console.SetOut(swOut); - Console.SetError(swErr); + Console.SetError(swError); + ServiceDescriptor.TestDescriptor = descriptor ?? DefaultServiceDescriptor; + Program.TestExceptionHandler = (e, _) => exception = e; try { - Program.Run(arguments, descriptor ?? DefaultServiceDescriptor); + _ = Program.Run(arguments); } - catch (Exception ex) + catch (Exception e) { - testEx = ex; + exception = e; } finally { Console.SetOut(tmpOut); - Console.SetError(tmpErr); + Console.SetError(tmpError); + ServiceDescriptor.TestDescriptor = null; + Program.TestExceptionHandler = null; } - Console.WriteLine("\n>>> Output: "); - Console.Write(swOut.ToString()); - Console.WriteLine("\n>>> Error: "); - Console.Write(swErr.ToString()); - if (testEx != null) - { - Console.WriteLine("\n>>> Exception: "); - Console.WriteLine(testEx); - } - - return new CLITestResult(swOut.ToString(), swErr.ToString(), testEx); + return new CommandLineTestResult(swOut.ToString(), swError.ToString(), exception); } } /// /// Aggregated test report /// - public class CLITestResult + public class CommandLineTestResult { public string Out { get; } @@ -116,7 +112,7 @@ $@" public bool HasException => this.Exception != null; - public CLITestResult(string output, string error, Exception exception = null) + public CommandLineTestResult(string output, string error, Exception exception = null) { this.Out = output; this.Error = error; diff --git a/src/WinSW/Program.cs b/src/WinSW/Program.cs index 22e05b5..1e20936 100644 --- a/src/WinSW/Program.cs +++ b/src/WinSW/Program.cs @@ -1,9 +1,16 @@ using System; using System.Collections.Generic; +using System.CommandLine; +using System.CommandLine.Builder; +using System.CommandLine.Invocation; +using System.CommandLine.IO; +using System.CommandLine.Parsing; using System.ComponentModel; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; +using System.Linq; +using System.Reflection; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Security.AccessControl; @@ -11,6 +18,7 @@ using System.Security.Principal; using System.ServiceProcess; using System.Text; using System.Threading; +using System.Threading.Tasks; using log4net; using log4net.Appender; using log4net.Config; @@ -18,94 +26,36 @@ using log4net.Core; using log4net.Layout; using WinSW.Logging; using WinSW.Native; +using Process = System.Diagnostics.Process; +using TimeoutException = System.ServiceProcess.TimeoutException; namespace WinSW { + // NOTE: Keep description strings in sync with docs. public static class Program { private static readonly ILog Log = LogManager.GetLogger(typeof(Program)); - public static int Main(string[] args) + internal static Action? TestExceptionHandler; + + private static int Main(string[] args) { - try - { - Run(args); - Log.Debug("Completed. Exit code is 0"); - return 0; - } - catch (InvalidDataException e) - { - string message = "The configuration file cound not be loaded. " + e.Message; - Log.Fatal(message, e); - Console.Error.WriteLine(message); - return -1; - } - catch (CommandException e) - { - string message = e.Message; - Log.Fatal(message); - Console.Error.WriteLine(message); - return e.InnerException is Win32Exception inner ? inner.NativeErrorCode : -1; - } - catch (InvalidOperationException e) when (e.InnerException is Win32Exception inner) - { - string message = e.Message; - Log.Fatal(message, e); - Console.Error.WriteLine(message); - return inner.NativeErrorCode; - } - catch (Win32Exception e) - { - string message = e.Message; - Log.Fatal(message, e); - Console.Error.WriteLine(message); - return e.NativeErrorCode; - } - catch (Exception e) - { - Log.Fatal("Unhandled exception", e); - Console.Error.WriteLine(e); - return -1; - } + int exitCode = Run(args); + Log.Debug("Completed. Exit code is " + exitCode); + return exitCode; } - public static void Run(string[] argsArray, ServiceDescriptor? descriptor = null) + internal static int Run(string[] args) { - bool inConsoleMode = argsArray.Length > 0; - - // If descriptor is not specified, initialize the new one (and load configs from there) - descriptor ??= new ServiceDescriptor(); - - // Configure the wrapper-internal logging. - // STDOUT and STDERR of the child process will be handled independently. - InitLoggers(descriptor, inConsoleMode); - - if (!inConsoleMode) - { - Log.Debug("Starting WinSW in service mode"); - ServiceBase.Run(new WrapperService(descriptor)); - return; - } - - Log.Debug("Starting WinSW in console mode"); - - if (argsArray.Length == 0) - { - PrintHelp(); - return; - } - - var args = new List(argsArray); - bool elevated; - if (args[0] == "/elevated") + if (args[0] == "--elevated") { elevated = true; _ = ConsoleApis.FreeConsole(); _ = ConsoleApis.AttachConsole(ConsoleApis.ATTACH_PARENT_PROCESS); - args = args.GetRange(1, args.Count - 1); + args = new List(args).GetRange(1, args.Length - 1).ToArray(); } else if (Environment.OSVersion.Version.Major == 5) { @@ -117,75 +67,265 @@ namespace WinSW elevated = IsProcessElevated(); } - switch (args[0].ToLower()) + var root = new RootCommand("A wrapper binary that can be used to host executables as Windows services. https://github.com/winsw/winsw") { - case "install": - Install(); - return; + Handler = CommandHandler.Create((string? pathToConfig) => + { + ServiceDescriptor descriptor; + try + { + descriptor = ServiceDescriptor.Create(pathToConfig); + } + catch (FileNotFoundException) + { + throw new CommandException("The specified command or file was not found."); + } - case "uninstall": - Uninstall(); - return; + InitLoggers(descriptor, enableConsoleLogging: false); - case "start": - Start(); - return; + Log.Debug("Starting WinSW in service mode"); + ServiceBase.Run(new WrapperService(descriptor)); + }), + }; - case "stop": - Stop(); - return; - - case "stopwait": - StopWait(); - return; - - case "restart": - Restart(); - return; - - case "restart!": - RestartSelf(); - return; - - case "status": - Status(); - return; - - case "test": - Test(); - return; - - case "testwait": - TestWait(); - return; - - case "refresh": - Refresh(); - return; - - case "help": - case "--help": - case "-h": - case "-?": - case "/?": - PrintHelp(); - return; - - case "version": - PrintVersion(); - return; - - default: - Console.WriteLine("Unknown command: " + args[0]); - PrintAvailableCommands(); - throw new Exception("Unknown command: " + args[0]); + using (WindowsIdentity identity = WindowsIdentity.GetCurrent()) + { + WindowsPrincipal principal = new WindowsPrincipal(identity); + if (principal.IsInRole(new SecurityIdentifier(WellKnownSidType.ServiceSid, null)) || + principal.IsInRole(new SecurityIdentifier(WellKnownSidType.LocalSystemSid, null)) || + principal.IsInRole(new SecurityIdentifier(WellKnownSidType.LocalServiceSid, null)) || + principal.IsInRole(new SecurityIdentifier(WellKnownSidType.NetworkServiceSid, null))) + { + root.Add(new Argument("path-to-config") + { + Arity = ArgumentArity.ZeroOrOne, + IsHidden = true, + }); + } } - void Install() + var config = new Argument("path-to-config", "The path to the configuration file.") { + Arity = ArgumentArity.ZeroOrOne, + }; + + var noElevate = new Option("--no-elevate", "Doesn't automatically trigger a UAC prompt."); + + { + var install = new Command("install", "Installs the service.") + { + Handler = CommandHandler.Create(Install), + }; + + install.Add(config); + install.Add(noElevate); + install.Add(new Option(new[] { "--username", "--user" }, "Specifies the user name of the service account.")); + install.Add(new Option(new[] { "--password", "--pass" }, "Specifies the password of the service account.")); + + root.Add(install); + } + + { + var uninstall = new Command("uninstall", "Uninstalls the service.") + { + Handler = CommandHandler.Create(Uninstall), + }; + + uninstall.Add(config); + uninstall.Add(noElevate); + + root.Add(uninstall); + } + + { + var start = new Command("start", "Starts the service.") + { + Handler = CommandHandler.Create(Start), + }; + + start.Add(config); + start.Add(noElevate); + + root.Add(start); + } + + { + var stop = new Command("stop", "Stops the service.") + { + Handler = CommandHandler.Create(Stop), + }; + + stop.Add(config); + stop.Add(noElevate); + stop.Add(new Option("--no-wait", "Doesn't wait for the service to actually stop.")); + stop.Add(new Option("--force", "Stops the service even if it has started dependent services.")); + + root.Add(stop); + } + + { + var restart = new Command("restart", "Stops and then starts the service.") + { + Handler = CommandHandler.Create(Restart), + }; + + restart.Add(config); + restart.Add(noElevate); + restart.Add(new Option("--force", "Restarts the service even if it has started dependent services.")); + + root.Add(restart); + } + + { + var restartSelf = new Command("restart!", "self-restart (can be called from child processes)") + { + Handler = CommandHandler.Create(RestartSelf), + }; + + restartSelf.Add(config); + + root.Add(restartSelf); + } + + { + var status = new Command("status", "Checks the status of the service.") + { + Handler = CommandHandler.Create(Status), + }; + + status.Add(config); + + root.Add(status); + } + + { + var test = new Command("test", "Checks if the service can be started and then stopped without installation.") + { + Handler = CommandHandler.Create(Test), + }; + + test.Add(config); + test.Add(noElevate); + + const int minTimeout = -1; + const int maxTimeout = int.MaxValue / 1000; + + var timeout = new Option("--timeout", "Specifies the number of seconds to wait before the service is stopped."); + timeout.Argument.AddValidator(argument => + { + string token = argument.Tokens.Single().Value; + return !int.TryParse(token, out int value) ? null : + value < minTimeout ? $"Argument '{token}' must be greater than or equal to {minTimeout}." : + value > maxTimeout ? $"Argument '{token}' must be less than or equal to {maxTimeout}." : + null; + }); + + test.Add(timeout); + test.Add(new Option("--no-break", "Ignores keystrokes.")); + + root.Add(test); + } + + { + var refresh = new Command("refresh", "Refreshes the service properties without reinstallation.") + { + Handler = CommandHandler.Create(Refresh), + }; + + refresh.Add(config); + refresh.Add(noElevate); + + root.Add(refresh); + } + + return new CommandLineBuilder(root) + + // see UseDefaults + .UseVersionOption() + .UseHelp() + /* .UseEnvironmentVariableDirective() */ + .UseParseDirective() + .UseDebugDirective() + .UseSuggestDirective() + .RegisterWithDotnetSuggest() + .UseTypoCorrections() + .UseParseErrorReporting() + .UseExceptionHandler(TestExceptionHandler ?? OnException) + .CancelOnProcessTermination() + .Build() + .Invoke(args); + + static void OnException(Exception exception, InvocationContext context) + { + Console.ForegroundColor = ConsoleColor.Red; + try + { + IStandardStreamWriter error = context.Console.Error; + + Debug.Assert(exception is TargetInvocationException); + Debug.Assert(exception.InnerException != null); + exception = exception.InnerException!; + switch (exception) + { + case InvalidDataException e: + { + string message = "The configuration file cound not be loaded. " + e.Message; + Log.Fatal(message, e); + error.WriteLine(message); + context.ResultCode = -1; + break; + } + + case CommandException e: + { + string message = e.Message; + Log.Fatal(message); + error.WriteLine(message); + context.ResultCode = e.InnerException is Win32Exception inner ? inner.NativeErrorCode : -1; + break; + } + + case InvalidOperationException e when e.InnerException is Win32Exception inner: + { + string message = e.Message; + Log.Fatal(message, e); + error.WriteLine(message); + context.ResultCode = inner.NativeErrorCode; + break; + } + + case Win32Exception e: + { + string message = e.Message; + Log.Fatal(message, e); + error.WriteLine(message); + context.ResultCode = e.NativeErrorCode; + break; + } + + default: + { + Log.Fatal("Unhandled exception", exception); + error.WriteLine(exception.ToString()); + context.ResultCode = -1; + break; + } + } + } + finally + { + Console.ResetColor(); + } + } + + void Install(string? pathToConfig, bool noElevate, string? username, string? password) + { + ServiceDescriptor descriptor = ServiceDescriptor.Create(pathToConfig); + InitLoggers(descriptor, enableConsoleLogging: true); + if (!elevated) { - Elevate(); + Elevate(noElevate); return; } @@ -200,36 +340,21 @@ namespace WinSW throw new CommandException("Installation failure: Service with id '" + descriptor.Id + "' already exists"); } - string? username = null; - string? password = null; - bool allowServiceLogonRight = false; - if (args.Count > 1 && args[1] == "/p") + if (descriptor.HasServiceAccount()) { - Console.Write("Username: "); - username = Console.ReadLine(); - Console.Write("Password: "); - password = ReadPassword(); - Console.WriteLine(); - Console.Write("Set Account rights to allow log on as a service (y/n)?: "); - var keypressed = Console.ReadKey(); - Console.WriteLine(); - if (keypressed.Key == ConsoleKey.Y) - { - allowServiceLogonRight = true; - } - } - else if (descriptor.HasServiceAccount()) - { - username = descriptor.ServiceAccountUserName; - password = descriptor.ServiceAccountPassword; - allowServiceLogonRight = descriptor.AllowServiceAcountLogonRight; + username = descriptor.ServiceAccountUserName ?? username; + password = descriptor.ServiceAccountPassword ?? password; if (username is null || password is null) { switch (descriptor.ServiceAccountPrompt) { case "dialog": - PropmtForCredentialsDialog(); + Credentials.PropmtForCredentialsDialog( + ref username, + ref password, + "Windows Service Wrapper", + "service account credentials"); // TODO break; case "console": @@ -239,16 +364,16 @@ namespace WinSW } } - if (allowServiceLogonRight) + if (username != null) { - Security.AddServiceLogonRight(descriptor.ServiceAccountUserName!); + Security.AddServiceLogonRight(username); } using Service sc = scm.CreateService( descriptor.Id, descriptor.Caption, descriptor.StartMode, - "\"" + descriptor.ExecutablePath + "\"", + "\"" + descriptor.ExecutablePath + "\"" + (pathToConfig != null ? " \"" + Path.GetFullPath(pathToConfig) + "\"" : null), descriptor.ServiceDependencies, username, password); @@ -280,99 +405,6 @@ namespace WinSW EventLog.CreateEventSource(eventLogSource, "Application"); } - void PropmtForCredentialsDialog() - { - username ??= string.Empty; - password ??= string.Empty; - - int inBufferSize = 0; - _ = CredentialApis.CredPackAuthenticationBuffer( - 0, - username, - password, - IntPtr.Zero, - ref inBufferSize); - - IntPtr inBuffer = Marshal.AllocCoTaskMem(inBufferSize); - try - { - if (!CredentialApis.CredPackAuthenticationBuffer( - 0, - username, - password, - inBuffer, - ref inBufferSize)) - { - Throw.Command.Win32Exception("Failed to pack auth buffer."); - } - - CredentialApis.CREDUI_INFO info = new CredentialApis.CREDUI_INFO - { - Size = Marshal.SizeOf(typeof(CredentialApis.CREDUI_INFO)), - CaptionText = "Windows Service Wrapper", // TODO - MessageText = "service account credentials", // TODO - }; - uint authPackage = 0; - bool save = false; - int error = CredentialApis.CredUIPromptForWindowsCredentials( - info, - 0, - ref authPackage, - inBuffer, - inBufferSize, - out IntPtr outBuffer, - out uint outBufferSize, - ref save, - CredentialApis.CREDUIWIN_GENERIC); - - if (error != Errors.ERROR_SUCCESS) - { - throw new Win32Exception(error); - } - - try - { - int userNameLength = 0; - int passwordLength = 0; - _ = CredentialApis.CredUnPackAuthenticationBuffer( - 0, - outBuffer, - outBufferSize, - null, - ref userNameLength, - default, - default, - null, - ref passwordLength); - - username = userNameLength == 0 ? null : new string('\0', userNameLength - 1); - password = passwordLength == 0 ? null : new string('\0', passwordLength - 1); - - if (!CredentialApis.CredUnPackAuthenticationBuffer( - 0, - outBuffer, - outBufferSize, - username, - ref userNameLength, - default, - default, - password, - ref passwordLength)) - { - Throw.Command.Win32Exception("Failed to unpack auth buffer."); - } - } - finally - { - Marshal.FreeCoTaskMem(outBuffer); - } - } - finally - { - Marshal.FreeCoTaskMem(inBuffer); - } - } - void PromptForCredentialsConsole() { if (username is null) @@ -391,11 +423,14 @@ namespace WinSW } } - void Uninstall() + void Uninstall(string? pathToConfig, bool noElevate) { + ServiceDescriptor descriptor = ServiceDescriptor.Create(pathToConfig); + InitLoggers(descriptor, enableConsoleLogging: true); + if (!elevated) { - Elevate(); + Elevate(noElevate); return; } @@ -437,11 +472,14 @@ namespace WinSW } } - void Start() + void Start(string? pathToConfig, bool noElevate) { + ServiceDescriptor descriptor = ServiceDescriptor.Create(pathToConfig); + InitLoggers(descriptor, enableConsoleLogging: true); + if (!elevated) { - Elevate(); + Elevate(noElevate); return; } @@ -453,29 +491,26 @@ namespace WinSW { svc.Start(); } - catch (InvalidOperationException e) when (e.InnerException is Win32Exception inner) + catch (InvalidOperationException e) + when (e.InnerException is Win32Exception inner && inner.NativeErrorCode == Errors.ERROR_SERVICE_DOES_NOT_EXIST) { - switch (inner.NativeErrorCode) - { - case Errors.ERROR_SERVICE_DOES_NOT_EXIST: - ThrowNoSuchService(inner); - break; - - case Errors.ERROR_SERVICE_ALREADY_RUNNING: - Log.Info($"The service with ID '{descriptor.Id}' has already been started"); - break; - - default: - throw; - } + ThrowNoSuchService(inner); + } + catch (InvalidOperationException e) + when (e.InnerException is Win32Exception inner && inner.NativeErrorCode == Errors.ERROR_SERVICE_ALREADY_RUNNING) + { + Log.Info($"The service with ID '{descriptor.Id}' has already been started"); } } - void Stop() + void Stop(string? pathToConfig, bool noElevate, bool noWait, bool force) { + ServiceDescriptor descriptor = ServiceDescriptor.Create(pathToConfig); + InitLoggers(descriptor, enableConsoleLogging: true); + if (!elevated) { - Elevate(); + Elevate(noElevate); return; } @@ -485,72 +520,50 @@ namespace WinSW try { - svc.Stop(); - } - catch (InvalidOperationException e) when (e.InnerException is Win32Exception inner) - { - switch (inner.NativeErrorCode) + if (!force) { - case Errors.ERROR_SERVICE_DOES_NOT_EXIST: - ThrowNoSuchService(inner); - break; - - case Errors.ERROR_SERVICE_NOT_ACTIVE: - Log.Info($"The service with ID '{descriptor.Id}' is not running"); - break; - - default: - throw; - + if (svc.HasAnyStartedDependentService()) + { + throw new CommandException("Failed to stop the service because it has started dependent services. Specify '--force' to proceed."); + } } - } - } - void StopWait() - { - if (!elevated) - { - Elevate(); - return; - } - - Log.Info("Stopping the service with id '" + descriptor.Id + "'"); - - using var svc = new ServiceController(descriptor.Id); - - try - { svc.Stop(); - while (!ServiceControllerExtension.TryWaitForStatus(svc, ServiceControllerStatus.Stopped, TimeSpan.FromSeconds(1))) + if (!noWait) { - Log.Info("Waiting the service to stop..."); + Log.Info("Waiting for the service to stop..."); + try + { + svc.WaitForStatus(ServiceControllerStatus.Stopped, ServiceControllerStatus.StopPending); + } + catch (TimeoutException e) + { + throw new CommandException("Failed to stop the service.", e); + } } } - catch (InvalidOperationException e) when (e.InnerException is Win32Exception inner) + catch (InvalidOperationException e) + when (e.InnerException is Win32Exception inner && inner.NativeErrorCode == Errors.ERROR_SERVICE_DOES_NOT_EXIST) + { + ThrowNoSuchService(inner); + } + catch (InvalidOperationException e) + when (e.InnerException is Win32Exception inner && inner.NativeErrorCode == Errors.ERROR_SERVICE_NOT_ACTIVE) { - switch (inner.NativeErrorCode) - { - case Errors.ERROR_SERVICE_DOES_NOT_EXIST: - ThrowNoSuchService(inner); - break; - - case Errors.ERROR_SERVICE_NOT_ACTIVE: - break; - - default: - throw; - } } Log.Info("The service stopped."); } - void Restart() + void Restart(string? pathToConfig, bool noElevate, bool force) { + ServiceDescriptor descriptor = ServiceDescriptor.Create(pathToConfig); + InitLoggers(descriptor, enableConsoleLogging: true); + if (!elevated) { - Elevate(); + Elevate(noElevate); return; } @@ -558,39 +571,64 @@ namespace WinSW using var svc = new ServiceController(descriptor.Id); + List? startedDependentServices = null; + try { + if (svc.HasAnyStartedDependentService()) + { + if (!force) + { + throw new CommandException("Failed to restart the service because it has started dependent services. Specify '--force' to proceed."); + } + + startedDependentServices = svc.DependentServices.Where(service => service.Status != ServiceControllerStatus.Stopped).ToList(); + } + svc.Stop(); - while (!ServiceControllerExtension.TryWaitForStatus(svc, ServiceControllerStatus.Stopped, TimeSpan.FromSeconds(1))) + Log.Info("Waiting for the service to stop..."); + try { + svc.WaitForStatus(ServiceControllerStatus.Stopped, ServiceControllerStatus.StopPending); + } + catch (TimeoutException e) + { + throw new CommandException("Failed to stop the service.", e); } } - catch (InvalidOperationException e) when (e.InnerException is Win32Exception inner) + catch (InvalidOperationException e) + when (e.InnerException is Win32Exception inner && inner.NativeErrorCode == Errors.ERROR_SERVICE_DOES_NOT_EXIST) + { + ThrowNoSuchService(inner); + } + catch (InvalidOperationException e) + when (e.InnerException is Win32Exception inner && inner.NativeErrorCode == Errors.ERROR_SERVICE_NOT_ACTIVE) { - switch (inner.NativeErrorCode) - { - case Errors.ERROR_SERVICE_DOES_NOT_EXIST: - ThrowNoSuchService(inner); - break; - - case Errors.ERROR_SERVICE_NOT_ACTIVE: - break; - - default: - throw; - - } } svc.Start(); + + if (startedDependentServices != null) + { + foreach (ServiceController service in startedDependentServices) + { + if (service.Status == ServiceControllerStatus.Stopped) + { + service.Start(); + } + } + } } - void RestartSelf() + void RestartSelf(string? pathToConfig) { + ServiceDescriptor descriptor = ServiceDescriptor.Create(pathToConfig); + InitLoggers(descriptor, enableConsoleLogging: true); + if (!elevated) { - throw new UnauthorizedAccessException("Access is denied."); + throw new CommandException(new Win32Exception(Errors.ERROR_ACCESS_DENIED)); } Log.Info("Restarting the service with id '" + descriptor.Id + "'"); @@ -603,54 +641,77 @@ namespace WinSW } } - void Status() + static void Status(string? pathToConfig) { + ServiceDescriptor descriptor = ServiceDescriptor.Create(pathToConfig); + InitLoggers(descriptor, enableConsoleLogging: true); + Log.Debug("User requested the status of the process with id '" + descriptor.Id + "'"); using var svc = new ServiceController(descriptor.Id); try { Console.WriteLine(svc.Status == ServiceControllerStatus.Running ? "Started" : "Stopped"); } - catch (InvalidOperationException e) when (e.InnerException is Win32Exception inner && inner.NativeErrorCode == Errors.ERROR_SERVICE_DOES_NOT_EXIST) + catch (InvalidOperationException e) + when (e.InnerException is Win32Exception inner && inner.NativeErrorCode == Errors.ERROR_SERVICE_DOES_NOT_EXIST) { Console.WriteLine("NonExistent"); } } - void Test() + void Test(string? pathToConfig, bool noElevate, int? timeout, bool noBreak) { + ServiceDescriptor descriptor = ServiceDescriptor.Create(pathToConfig); + InitLoggers(descriptor, enableConsoleLogging: true); + if (!elevated) { - Elevate(); + Elevate(noElevate); return; } using WrapperService wsvc = new WrapperService(descriptor); - wsvc.RaiseOnStart(args.ToArray()); - Thread.Sleep(1000); - wsvc.RaiseOnStop(); - } - - void TestWait() - { - if (!elevated) + wsvc.RaiseOnStart(args); + try { - Elevate(); - return; + // validated [-1, int.MaxValue / 1000] + int millisecondsTimeout = timeout is int secondsTimeout && secondsTimeout >= 0 ? secondsTimeout * 1000 : -1; + + if (!noBreak) + { + Console.WriteLine("Press any key to stop the service..."); + _ = Task.Run(() => _ = Console.ReadKey()).Wait(millisecondsTimeout); + } + else + { + using ManualResetEventSlim evt = new ManualResetEventSlim(); + + Console.WriteLine("Press Ctrl+C to stop the service..."); + Console.CancelKeyPress += CancelKeyPress; + + _ = evt.Wait(millisecondsTimeout); + Console.CancelKeyPress -= CancelKeyPress; + + void CancelKeyPress(object sender, ConsoleCancelEventArgs e) + { + evt.Set(); + } + } + } + finally + { + wsvc.RaiseOnStop(); } - - using WrapperService wsvc = new WrapperService(descriptor); - wsvc.RaiseOnStart(args.ToArray()); - Console.WriteLine("Press any key to stop the service..."); - _ = Console.Read(); - wsvc.RaiseOnStop(); } - void Refresh() + void Refresh(string? pathToConfig, bool noElevate) { + ServiceDescriptor descriptor = ServiceDescriptor.Create(pathToConfig); + InitLoggers(descriptor, enableConsoleLogging: true); + if (!elevated) { - Elevate(); + Elevate(noElevate); return; } @@ -659,7 +720,7 @@ namespace WinSW { using Service sc = scm.OpenService(descriptor.Id); - sc.ChangeConfig(descriptor.Caption, descriptor.StartMode); + sc.ChangeConfig(descriptor.Caption, descriptor.StartMode, descriptor.ServiceDependencies); sc.SetDescription(descriptor.Description); @@ -682,27 +743,33 @@ namespace WinSW sc.SetSecurityDescriptor(new RawSecurityDescriptor(securityDescriptor)); } } - catch (CommandException e) when (e.InnerException is Win32Exception inner && inner.NativeErrorCode == Errors.ERROR_SERVICE_DOES_NOT_EXIST) + catch (CommandException e) + when (e.InnerException is Win32Exception inner && inner.NativeErrorCode == Errors.ERROR_SERVICE_DOES_NOT_EXIST) { ThrowNoSuchService(inner); } } // [DoesNotReturn] - void Elevate() + static void Elevate(bool noElevate) { + if (noElevate) + { + throw new CommandException(new Win32Exception(Errors.ERROR_ACCESS_DENIED)); + } + using Process current = Process.GetCurrentProcess(); + string exe = Environment.GetCommandLineArgs()[0]; + string commandLine = Environment.CommandLine; + string arguments = "--elevated" + commandLine.Remove(commandLine.IndexOf(exe), exe.Length).TrimStart('"'); + ProcessStartInfo startInfo = new ProcessStartInfo { UseShellExecute = true, Verb = "runas", FileName = current.MainModule.FileName, -#if NETCOREAPP - Arguments = "/elevated " + string.Join(' ', args), -#else - Arguments = "/elevated " + string.Join(" ", args), -#endif + Arguments = arguments, WindowStyle = ProcessWindowStyle.Hidden, }; @@ -728,6 +795,11 @@ namespace WinSW private static void InitLoggers(ServiceDescriptor descriptor, bool enableConsoleLogging) { + if (ServiceDescriptor.TestDescriptor != null) + { + return; + } + // TODO: Make logging levels configurable Level fileLogLevel = Level.Debug; // TODO: Debug should not be printed to console by default. Otherwise commands like 'status' will be pollutted @@ -840,43 +912,5 @@ namespace WinSW } } } - - private static void PrintHelp() - { - Console.WriteLine("A wrapper binary that can be used to host executables as Windows services"); - Console.WriteLine(); - Console.WriteLine("Usage: winsw []"); - Console.WriteLine(" Missing arguments triggers the service mode"); - Console.WriteLine(); - PrintAvailableCommands(); - Console.WriteLine(); - PrintVersion(); - Console.WriteLine("More info: https://github.com/winsw/winsw"); - Console.WriteLine("Bug tracker: https://github.com/winsw/winsw/issues"); - } - - // TODO: Rework to enum in winsw-2.0 - private static void PrintAvailableCommands() - { - Console.WriteLine( -@"Available commands: - install install the service to Windows Service Controller - uninstall uninstall the service - start start the service (must be installed before) - stop stop the service - stopwait stop the service and wait until it's actually stopped - restart restart the service - restart! self-restart (can be called from child processes) - status check the current status of the service - test check if the service can be started and then stopped - testwait starts the service and waits until a key is pressed then stops the service - version print the version info - help print the help info (aliases: -h,--help,-?,/?)"); - } - - private static void PrintVersion() - { - Console.WriteLine("WinSW " + WrapperService.Version); - } } } diff --git a/src/WinSW/ServiceControllerExtension.cs b/src/WinSW/ServiceControllerExtension.cs index e5af01d..499424a 100644 --- a/src/WinSW/ServiceControllerExtension.cs +++ b/src/WinSW/ServiceControllerExtension.cs @@ -6,17 +6,26 @@ namespace WinSW { internal static class ServiceControllerExtension { - internal static bool TryWaitForStatus(/*this*/ ServiceController serviceController, ServiceControllerStatus desiredStatus, TimeSpan timeout) + /// + internal static void WaitForStatus(this ServiceController serviceController, ServiceControllerStatus desiredStatus, ServiceControllerStatus pendingStatus) { - try + TimeSpan timeout = TimeSpan.FromSeconds(1); + for (; ; ) { - serviceController.WaitForStatus(desiredStatus, timeout); - return true; - } - catch (TimeoutException) - { - return false; + try + { + serviceController.WaitForStatus(desiredStatus, timeout); + break; + } + catch (TimeoutException) when (serviceController.Status == desiredStatus || serviceController.Status == pendingStatus) + { + } } } + + internal static bool HasAnyStartedDependentService(this ServiceController serviceController) + { + return Array.Exists(serviceController.DependentServices, service => service.Status != ServiceControllerStatus.Stopped); + } } } diff --git a/src/WinSW/UserException.cs b/src/WinSW/UserException.cs deleted file mode 100644 index 0a732f0..0000000 --- a/src/WinSW/UserException.cs +++ /dev/null @@ -1,17 +0,0 @@ -using System; - -namespace WinSW -{ - internal sealed class UserException : Exception - { - internal UserException(string message) - : base(message) - { - } - - internal UserException(string? message, Exception inner) - : base(message, inner) - { - } - } -} diff --git a/src/WinSW/WinSW.csproj b/src/WinSW/WinSW.csproj index 96391ad..112dc4e 100644 --- a/src/WinSW/WinSW.csproj +++ b/src/WinSW/WinSW.csproj @@ -19,6 +19,10 @@ 3.0.40 + + + + @@ -59,6 +63,7 @@ $(InputAssemblies) "$(OutDir)WinSW.Core.dll" $(InputAssemblies) "$(OutDir)WinSW.Plugins.dll" $(InputAssemblies) "$(OutDir)log4net.dll" + $(InputAssemblies) "$(OutDir)System.CommandLine.dll" "$(ArtifactsDir)WinSW.$(TargetFrameworkSuffix).exe"