Handle special accounts

pull/620/head
NextTurn 2020-08-01 00:00:00 +08:00 committed by Next Turn
parent bd61d2986f
commit fcbc087b4b
3 changed files with 28 additions and 23 deletions

View File

@ -9,9 +9,9 @@ namespace WinSW.Native
{ {
/// <exception cref="CommandException" /> /// <exception cref="CommandException" />
/// <exception cref="Win32Exception" /> /// <exception cref="Win32Exception" />
internal static void AddServiceLogonRight(string userName) internal static void AddServiceLogonRight(ref string userName)
{ {
IntPtr sid = GetAccountSid(userName); IntPtr sid = GetAccountSid(ref userName);
try try
{ {
@ -19,13 +19,12 @@ namespace WinSW.Native
} }
finally finally
{ {
_ = FreeSid(sid);
Marshal.FreeHGlobal(sid); Marshal.FreeHGlobal(sid);
} }
} }
/// <exception cref="CommandException" /> /// <exception cref="CommandException" />
private static IntPtr GetAccountSid(string accountName) private static IntPtr GetAccountSid(ref string accountName)
{ {
int sidSize = 0; int sidSize = 0;
int domainNameLength = 0; int domainNameLength = 0;
@ -35,24 +34,23 @@ namespace WinSW.Native
accountName = Environment.MachineName + accountName.Substring(1); accountName = Environment.MachineName + accountName.Substring(1);
} }
_ = LookupAccountName(null, accountName, IntPtr.Zero, ref sidSize, IntPtr.Zero, ref domainNameLength, out _); _ = LookupAccountName(null, accountName, IntPtr.Zero, ref sidSize, null, ref domainNameLength, out _);
IntPtr sid = Marshal.AllocHGlobal(sidSize); IntPtr sid = Marshal.AllocHGlobal(sidSize);
IntPtr domainName = Marshal.AllocHGlobal(domainNameLength * sizeof(char)); string? domainName = domainNameLength == 0 ? null : new string('\0', domainNameLength - 1);
try if (!LookupAccountName(null, accountName, sid, ref sidSize, domainName, ref domainNameLength, out _))
{ {
if (!LookupAccountName(null, accountName, sid, ref sidSize, domainName, ref domainNameLength, out _)) Throw.Command.Win32Exception("Failed to find the account.");
{ }
Throw.Command.Win32Exception("Failed to find the account.");
}
return sid; // intentionally undocumented
} if (!accountName.Contains("\\") && !accountName.Contains("@"))
finally
{ {
Marshal.FreeHGlobal(domainName); accountName = domainName + '\\' + accountName;
} }
return sid;
} }
/// <exception cref="Win32Exception" /> /// <exception cref="Win32Exception" />

View File

@ -7,9 +7,6 @@ namespace WinSW.Native
{ {
internal static class SecurityApis internal static class SecurityApis
{ {
[DllImport(Libraries.Advapi32, SetLastError = false)]
internal static extern IntPtr FreeSid(IntPtr sid);
[DllImport(Libraries.Advapi32, SetLastError = true)] [DllImport(Libraries.Advapi32, SetLastError = true)]
internal static extern bool GetTokenInformation( internal static extern bool GetTokenInformation(
IntPtr tokenHandle, IntPtr tokenHandle,
@ -24,7 +21,7 @@ namespace WinSW.Native
string accountName, string accountName,
IntPtr sid, IntPtr sid,
ref int sidSize, ref int sidSize,
IntPtr referencedDomainName, string? referencedDomainName,
ref int referencedDomainNameLength, ref int referencedDomainNameLength,
out int use); out int use);

View File

@ -345,7 +345,7 @@ namespace WinSW
username = config.ServiceAccountUserName ?? username; username = config.ServiceAccountUserName ?? username;
password = config.ServiceAccountPassword ?? password; password = config.ServiceAccountPassword ?? password;
if (username is null || password is null) if (username is null || password is null && !IsSpecialAccount(username))
{ {
switch (config.ServiceAccountPrompt) switch (config.ServiceAccountPrompt)
{ {
@ -364,9 +364,9 @@ namespace WinSW
} }
} }
if (username != null) if (username != null && !IsSpecialAccount(username))
{ {
Security.AddServiceLogonRight(username); Security.AddServiceLogonRight(ref username);
} }
using Service sc = scm.CreateService( using Service sc = scm.CreateService(
@ -422,7 +422,7 @@ namespace WinSW
username = Console.ReadLine(); username = Console.ReadLine();
} }
if (password is null) if (password is null && !IsSpecialAccount(username))
{ {
Console.Write("Password: "); Console.Write("Password: ");
password = ReadPassword(); password = ReadPassword();
@ -430,6 +430,16 @@ namespace WinSW
Console.WriteLine(); Console.WriteLine();
} }
static bool IsSpecialAccount(string accountName) => accountName switch
{
@"LocalSystem" => true,
@".\LocalSystem" => true,
@"NT AUTHORITY\LocalService" => true,
@"NT AUTHORITY\NetworkService" => true,
string name when name == $@"{Environment.MachineName}\LocalSystem" => true,
_ => false
};
} }
void Uninstall(string? pathToConfig, bool noElevate) void Uninstall(string? pathToConfig, bool noElevate)