diff --git a/src/Core/ServiceWrapper/WrapperService.cs b/src/Core/ServiceWrapper/WrapperService.cs index 6dc356a..041458c 100644 --- a/src/Core/ServiceWrapper/WrapperService.cs +++ b/src/Core/ServiceWrapper/WrapperService.cs @@ -334,7 +334,7 @@ namespace winsw try { Log.Debug("ProcessKill " + _process.Id); - ProcessHelper.StopProcessAndChildren(_process.Id, _descriptor.StopTimeout, _descriptor.StopParentProcessFirst); + ProcessHelper.StopProcessAndChildren(_process, _descriptor.StopTimeout, _descriptor.StopParentProcessFirst); ExtensionManager.FireOnProcessTerminated(_process); } catch (InvalidOperationException) diff --git a/src/Core/WinSWCore/Native/Libraries.cs b/src/Core/WinSWCore/Native/Libraries.cs index 82d75f3..4d01c0a 100644 --- a/src/Core/WinSWCore/Native/Libraries.cs +++ b/src/Core/WinSWCore/Native/Libraries.cs @@ -4,5 +4,6 @@ { internal const string Advapi32 = "advapi32.dll"; internal const string Kernel32 = "kernel32.dll"; + internal const string NtDll = "ntdll.dll"; } } diff --git a/src/Core/WinSWCore/Native/ProcessApis.cs b/src/Core/WinSWCore/Native/ProcessApis.cs index 1c6292f..36241bd 100644 --- a/src/Core/WinSWCore/Native/ProcessApis.cs +++ b/src/Core/WinSWCore/Native/ProcessApis.cs @@ -24,12 +24,36 @@ namespace winsw.Native [DllImport(Libraries.Kernel32)] internal static extern IntPtr GetCurrentProcess(); + [DllImport(Libraries.NtDll)] + internal static extern int NtQueryInformationProcess( + IntPtr processHandle, + PROCESSINFOCLASS processInformationClass, + out PROCESS_BASIC_INFORMATION processInformation, + int processInformationLength, + IntPtr returnLength = default); + [DllImport(Libraries.Advapi32, SetLastError = true)] internal static extern bool OpenProcessToken( IntPtr processHandle, TokenAccessLevels desiredAccess, out IntPtr tokenHandle); + internal enum PROCESSINFOCLASS + { + ProcessBasicInformation = 0, + } + + [StructLayout(LayoutKind.Sequential)] + internal unsafe struct PROCESS_BASIC_INFORMATION + { + private readonly IntPtr Reserved1; + private readonly IntPtr PebBaseAddress; + private readonly IntPtr Reserved2_1; + private readonly IntPtr Reserved2_2; + internal readonly IntPtr UniqueProcessId; + internal readonly IntPtr InheritedFromUniqueProcessId; + } + internal struct PROCESS_INFORMATION { public IntPtr ProcessHandle; diff --git a/src/Core/WinSWCore/Util/ProcessHelper.cs b/src/Core/WinSWCore/Util/ProcessHelper.cs index 0c22adf..284aba1 100644 --- a/src/Core/WinSWCore/Util/ProcessHelper.cs +++ b/src/Core/WinSWCore/Util/ProcessHelper.cs @@ -1,9 +1,10 @@ using System; using System.Collections.Generic; +using System.ComponentModel; using System.Diagnostics; -using System.Management; using System.Threading; using log4net; +using static winsw.Native.ProcessApis; namespace winsw.Util { @@ -18,30 +19,48 @@ namespace winsw.Util /// /// Gets all children of the specified process. /// - /// Process PID + /// Process PID /// List of child process PIDs - public static List GetChildPids(int pid) + private static unsafe List GetChildProcesses(int processId) { - var childPids = new List(); + var children = new List(); - try + foreach (Process process in Process.GetProcesses()) { - string query = "SELECT * FROM Win32_Process WHERE ParentProcessID = " + pid; - using ManagementObjectSearcher searcher = new ManagementObjectSearcher(query); - using ManagementObjectCollection results = searcher.Get(); - foreach (ManagementBaseObject wmiObject in results) + IntPtr handle; + try { - var childProcessId = wmiObject["ProcessID"]; - Logger.Info("Found child process: " + childProcessId + " Name: " + wmiObject["Name"]); - childPids.Add(Convert.ToInt32(childProcessId)); + handle = process.Handle; + } + catch (Win32Exception) + { + process.Dispose(); + continue; + } + + if (NtQueryInformationProcess( + handle, + PROCESSINFOCLASS.ProcessBasicInformation, + out PROCESS_BASIC_INFORMATION information, + sizeof(PROCESS_BASIC_INFORMATION)) != 0) + { + Logger.Warn("Failed to locate children of the process with PID=" + processId + ". Child processes won't be terminated"); + process.Dispose(); + continue; + } + + if ((int)information.InheritedFromUniqueProcessId == processId) + { + Logger.Info("Found child process: " + process.Id + " Name: " + process.ProcessName); + children.Add(process); + } + else + { + process.Dispose(); } } - catch (Exception ex) - { - Logger.Warn("Failed to locate children of the process with PID=" + pid + ". Child processes won't be terminated", ex); - } - return childPids; + return children; } /// @@ -50,22 +69,18 @@ namespace winsw.Util /// /// PID of the process /// Stop timeout - public static void StopProcess(int pid, TimeSpan stopTimeout) + public static void StopProcess(Process process, TimeSpan stopTimeout) { - Logger.Info("Stopping process " + pid); - Process proc; - try + Logger.Info("Stopping process " + process.Id); + + if (process.HasExited) { - proc = Process.GetProcessById(pid); - } - catch (ArgumentException ex) - { - Logger.Info("Process " + pid + " is already stopped", ex); + Logger.Info("Process " + process.Id + " is already stopped"); return; } // (bool sent, bool exited) - KeyValuePair result = SignalHelper.SendCtrlCToProcess(proc, stopTimeout); + KeyValuePair result = SignalHelper.SendCtrlCToProcess(process, stopTimeout); bool exited = result.Value; if (!exited) { @@ -74,13 +89,13 @@ namespace winsw.Util bool sent = result.Key; if (sent) { - Logger.Warn("Process " + pid + " did not respond to Ctrl+C signal - Killing as fallback"); + Logger.Warn("Process " + process.Id + " did not respond to Ctrl+C signal - Killing as fallback"); } - proc.Kill(); + process.Kill(); } catch (Exception ex) { - if (!proc.HasExited) + if (!process.HasExited) { throw; } @@ -100,23 +115,23 @@ namespace winsw.Util /// Process PID /// Stop timeout (for each process) /// If enabled, the perent process will be terminated before its children on all levels - public static void StopProcessAndChildren(int pid, TimeSpan stopTimeout, bool stopParentProcessFirst) + public static void StopProcessAndChildren(Process process, TimeSpan stopTimeout, bool stopParentProcessFirst) { if (!stopParentProcessFirst) { - foreach (var childPid in GetChildPids(pid)) + foreach (Process child in GetChildProcesses(process.Id)) { - StopProcessAndChildren(childPid, stopTimeout, stopParentProcessFirst); + StopProcessAndChildren(child, stopTimeout, stopParentProcessFirst); } } - StopProcess(pid, stopTimeout); + StopProcess(process, stopTimeout); if (stopParentProcessFirst) { - foreach (var childPid in GetChildPids(pid)) + foreach (Process child in GetChildProcesses(process.Id)) { - StopProcessAndChildren(childPid, stopTimeout, stopParentProcessFirst); + StopProcessAndChildren(child, stopTimeout, stopParentProcessFirst); } } } diff --git a/src/Core/WinSWCore/WinSWCore.csproj b/src/Core/WinSWCore/WinSWCore.csproj index 7b655ad..284f8ea 100644 --- a/src/Core/WinSWCore/WinSWCore.csproj +++ b/src/Core/WinSWCore/WinSWCore.csproj @@ -15,7 +15,6 @@ - @@ -32,7 +31,6 @@ - diff --git a/src/Plugins/RunawayProcessKiller/RunawayProcessKillerExtension.cs b/src/Plugins/RunawayProcessKiller/RunawayProcessKillerExtension.cs index 1623004..e06ce3f 100644 --- a/src/Plugins/RunawayProcessKiller/RunawayProcessKillerExtension.cs +++ b/src/Plugins/RunawayProcessKiller/RunawayProcessKillerExtension.cs @@ -276,7 +276,7 @@ namespace winsw.Plugins.RunawayProcessKiller bldr.Append(proc); Logger.Warn(bldr.ToString()); - ProcessHelper.StopProcessAndChildren(pid, this.StopTimeout, this.StopParentProcessFirst); + ProcessHelper.StopProcessAndChildren(proc, this.StopTimeout, this.StopParentProcessFirst); } /// diff --git a/src/Test/winswTests/Extensions/RunawayProcessKillerTest.cs b/src/Test/winswTests/Extensions/RunawayProcessKillerTest.cs index 007cf67..4db32ff 100644 --- a/src/Test/winswTests/Extensions/RunawayProcessKillerTest.cs +++ b/src/Test/winswTests/Extensions/RunawayProcessKillerTest.cs @@ -110,7 +110,7 @@ $@" if (!proc.HasExited) { Console.Error.WriteLine("Test: Killing runaway process with ID=" + proc.Id); - ProcessHelper.StopProcessAndChildren(proc.Id, TimeSpan.FromMilliseconds(100), false); + ProcessHelper.StopProcessAndChildren(proc, TimeSpan.FromMilliseconds(100), false); if (!proc.HasExited) { // The test is failed here anyway, but we add additional diagnostics info