From e8726d7c1b981ce967e6a6ad4f3e005b93e2433e Mon Sep 17 00:00:00 2001 From: NextTurn <45985406+NextTurn@users.noreply.github.com> Date: Tue, 1 Oct 2019 00:00:00 +0800 Subject: [PATCH] Rework Shared Directory Mapper --- .../SharedDirectoryMapper/NativeMethods.cs | 30 ++++ .../SharedDirectoryMapper.cs | 47 +++--- .../SharedDirectoryMapperHelper.cs | 71 --------- ....cs => SharedDirectoryMapperConfigTest.cs} | 2 +- .../Extensions/SharedDirectoryMapperTests.cs | 148 ++++++++++++++++++ 5 files changed, 204 insertions(+), 94 deletions(-) create mode 100644 src/Plugins/SharedDirectoryMapper/NativeMethods.cs delete mode 100644 src/Plugins/SharedDirectoryMapper/SharedDirectoryMapperHelper.cs rename src/Test/winswTests/Extensions/{SharedDirectoryMapperTest.cs => SharedDirectoryMapperConfigTest.cs} (96%) create mode 100644 src/Test/winswTests/Extensions/SharedDirectoryMapperTests.cs diff --git a/src/Plugins/SharedDirectoryMapper/NativeMethods.cs b/src/Plugins/SharedDirectoryMapper/NativeMethods.cs new file mode 100644 index 0000000..79b8603 --- /dev/null +++ b/src/Plugins/SharedDirectoryMapper/NativeMethods.cs @@ -0,0 +1,30 @@ +using System.Runtime.InteropServices; + +namespace winsw.Plugins.SharedDirectoryMapper +{ + internal static class NativeMethods + { + internal const uint RESOURCETYPE_DISK = 0x00000001; + + private const string MprLibraryName = "mpr.dll"; + + [DllImport(MprLibraryName, SetLastError = true, CharSet = CharSet.Unicode, EntryPoint = "WNetAddConnection2W")] + internal static extern int WNetAddConnection2(in NETRESOURCE netResource, string? password = null, string? userName = null, uint flags = 0); + + [DllImport(MprLibraryName, SetLastError = true, CharSet = CharSet.Unicode, EntryPoint = "WNetCancelConnection2W")] + internal static extern int WNetCancelConnection2(string name, uint flags = 0, bool force = false); + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + internal struct NETRESOURCE + { + public uint Scope; + public uint Type; + public uint DisplayType; + public uint Usage; + public string LocalName; + public string RemoteName; + public string Comment; + public string Provider; + } + } +} diff --git a/src/Plugins/SharedDirectoryMapper/SharedDirectoryMapper.cs b/src/Plugins/SharedDirectoryMapper/SharedDirectoryMapper.cs index 17b486b..a17e878 100644 --- a/src/Plugins/SharedDirectoryMapper/SharedDirectoryMapper.cs +++ b/src/Plugins/SharedDirectoryMapper/SharedDirectoryMapper.cs @@ -1,14 +1,15 @@ using System.Collections.Generic; +using System.ComponentModel; using System.Xml; using log4net; using winsw.Extensions; using winsw.Util; +using static winsw.Plugins.SharedDirectoryMapper.NativeMethods; namespace winsw.Plugins.SharedDirectoryMapper { public class SharedDirectoryMapper : AbstractWinSWExtension { - private readonly SharedDirectoryMappingHelper _mapper = new SharedDirectoryMappingHelper(); private readonly List _entries = new List(); public override string DisplayName => "Shared Directory Mapper"; @@ -22,7 +23,7 @@ namespace winsw.Plugins.SharedDirectoryMapper public SharedDirectoryMapper(bool enableMapping, string directoryUNC, string driveLabel) { SharedDirectoryMapperConfig config = new SharedDirectoryMapperConfig(enableMapping, driveLabel, directoryUNC); - _entries.Add(config); + this._entries.Add(config); } public override void Configure(ServiceDescriptor descriptor, XmlNode node) @@ -35,7 +36,7 @@ namespace winsw.Plugins.SharedDirectoryMapper if (mapNodes[i] is XmlElement mapElement) { var config = SharedDirectoryMapperConfig.FromXml(mapElement); - _entries.Add(config); + this._entries.Add(config); } } } @@ -43,50 +44,52 @@ namespace winsw.Plugins.SharedDirectoryMapper public override void OnWrapperStarted() { - foreach (SharedDirectoryMapperConfig config in _entries) + foreach (SharedDirectoryMapperConfig config in this._entries) { + string label = config.Label; + string uncPath = config.UNCPath; if (config.EnableMapping) { - Logger.Info(DisplayName + ": Mapping shared directory " + config.UNCPath + " to " + config.Label); - try + Logger.Info(this.DisplayName + ": Mapping shared directory " + uncPath + " to " + label); + + int error = WNetAddConnection2(new NETRESOURCE { - _mapper.MapDirectory(config.Label, config.UNCPath); - } - catch (MapperException ex) + Type = RESOURCETYPE_DISK, + LocalName = label, + RemoteName = uncPath, + }); + if (error != 0) { - HandleMappingError(config, ex); + this.ThrowExtensionException(error, $"Mapping of {label} failed."); } } else { - Logger.Warn(DisplayName + ": Mapping of " + config.Label + " is disabled"); + Logger.Warn(this.DisplayName + ": Mapping of " + label + " is disabled"); } } } public override void BeforeWrapperStopped() { - foreach (SharedDirectoryMapperConfig config in _entries) + foreach (SharedDirectoryMapperConfig config in this._entries) { + string label = config.Label; if (config.EnableMapping) { - try + int error = WNetCancelConnection2(label); + if (error != 0) { - _mapper.UnmapDirectory(config.Label); - } - catch (MapperException ex) - { - HandleMappingError(config, ex); + this.ThrowExtensionException(error, $"Unmapping of {label} failed."); } } } } - private void HandleMappingError(SharedDirectoryMapperConfig config, MapperException ex) + private void ThrowExtensionException(int error, string message) { - Logger.Error("Mapping of " + config.Label + " failed. STDOUT: " + ex.Process.StandardOutput.ReadToEnd() - + " \r\nSTDERR: " + ex.Process.StandardError.ReadToEnd(), ex); - throw new ExtensionException(Descriptor.Id, DisplayName + ": Mapping of " + config.Label + "failed", ex); + Win32Exception inner = new Win32Exception(error); + throw new ExtensionException(this.Descriptor.Id, $"{this.DisplayName}: {message} {inner.Message}", inner); } } } diff --git a/src/Plugins/SharedDirectoryMapper/SharedDirectoryMapperHelper.cs b/src/Plugins/SharedDirectoryMapper/SharedDirectoryMapperHelper.cs deleted file mode 100644 index 101f80d..0000000 --- a/src/Plugins/SharedDirectoryMapper/SharedDirectoryMapperHelper.cs +++ /dev/null @@ -1,71 +0,0 @@ -using System.Diagnostics; - -namespace winsw.Plugins.SharedDirectoryMapper -{ - class SharedDirectoryMappingHelper - { - /// - /// Invokes a system command - /// - /// - /// Command to be executed - /// Command arguments - /// Operation failure - private void InvokeCommand(string command, string args) - { - Process p = new Process - { - StartInfo = - { - UseShellExecute = false, - CreateNoWindow = true, - RedirectStandardError = true, - RedirectStandardOutput = true, - FileName = command, - Arguments = args - } - }; - - p.Start(); - p.WaitForExit(); - if (p.ExitCode != 0) - { - throw new MapperException(p, command, args); - } - } - - /// - /// Maps the remote directory - /// - /// Disk label - /// UNC path to the directory - /// Operation failure - public void MapDirectory(string label, string uncPath) - { - InvokeCommand("net.exe", " use " + label + " " + uncPath); - } - - /// - /// Unmaps the label - /// - /// Disk label - /// Operation failure - public void UnmapDirectory(string label) - { - InvokeCommand("net.exe", " use /DELETE /YES " + label); - } - } - - class MapperException : WinSWException - { - public string Call { get; private set; } - public Process Process { get; private set; } - - public MapperException(Process process, string command, string args) - : base("Command " + command + " " + args + " failed with code " + process.ExitCode) - { - Call = command + " " + args; - Process = process; - } - } -} diff --git a/src/Test/winswTests/Extensions/SharedDirectoryMapperTest.cs b/src/Test/winswTests/Extensions/SharedDirectoryMapperConfigTest.cs similarity index 96% rename from src/Test/winswTests/Extensions/SharedDirectoryMapperTest.cs rename to src/Test/winswTests/Extensions/SharedDirectoryMapperConfigTest.cs index 28b227e..7013d24 100644 --- a/src/Test/winswTests/Extensions/SharedDirectoryMapperTest.cs +++ b/src/Test/winswTests/Extensions/SharedDirectoryMapperConfigTest.cs @@ -6,7 +6,7 @@ using winsw.Plugins.SharedDirectoryMapper; namespace winswTests.Extensions { [TestFixture] - class SharedDirectoryMapperTest : ExtensionTestBase + class SharedDirectoryMapperConfigTest : ExtensionTestBase { ServiceDescriptor _testServiceDescriptor; diff --git a/src/Test/winswTests/Extensions/SharedDirectoryMapperTests.cs b/src/Test/winswTests/Extensions/SharedDirectoryMapperTests.cs new file mode 100644 index 0000000..e4f1e7e --- /dev/null +++ b/src/Test/winswTests/Extensions/SharedDirectoryMapperTests.cs @@ -0,0 +1,148 @@ +#if NETCOREAPP +using System; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using NUnit.Framework; +using winsw.Plugins.SharedDirectoryMapper; + +namespace winswTests.Extensions +{ + // TODO: Throws.TypeOf() + [TestFixture] + public class SharedDirectoryMapperTests + { + [Test] + public void TestMap() + { + using TestData data = TestData.Create(); + + const string label = "W:"; + SharedDirectoryMapper mapper = new SharedDirectoryMapper(true, $@"\\{Environment.MachineName}\{data.name}", label); + + mapper.OnWrapperStarted(); + Assert.That($@"{label}\", Does.Exist); + mapper.BeforeWrapperStopped(); + Assert.That($@"{label}\", Does.Not.Exist); + } + + [Test] + public void TestDisableMapping() + { + using TestData data = TestData.Create(); + + const string label = "W:"; + SharedDirectoryMapper mapper = new SharedDirectoryMapper(enableMapping: false, $@"\\{Environment.MachineName}\{data.name}", label); + + mapper.OnWrapperStarted(); + Assert.That($@"{label}\", Does.Not.Exist); + mapper.BeforeWrapperStopped(); + } + + [Test] + public void TestMap_PathEndsWithSlash_Throws() + { + using TestData data = TestData.Create(); + + const string label = "W:"; + SharedDirectoryMapper mapper = new SharedDirectoryMapper(true, $@"\\{Environment.MachineName}\{data.name}\", label); + + Assert.That(() => mapper.OnWrapperStarted(), Throws.Exception); + Assert.That($@"{label}\", Does.Not.Exist); + Assert.That(() => mapper.BeforeWrapperStopped(), Throws.Exception); + } + + [Test] + public void TestMap_LabelDoesNotEndWithColon_Throws() + { + using TestData data = TestData.Create(); + + const string label = "W"; + SharedDirectoryMapper mapper = new SharedDirectoryMapper(true, $@"\\{Environment.MachineName}\{data.name}", label); + + Assert.That(() => mapper.OnWrapperStarted(), Throws.Exception); + Assert.That($@"{label}\", Does.Not.Exist); + Assert.That(() => mapper.BeforeWrapperStopped(), Throws.Exception); + } + + private readonly ref struct TestData + { + internal readonly string name; + internal readonly string path; + + private TestData(string name, string path) + { + this.name = name; + this.path = path; + } + + internal static TestData Create([CallerMemberName] string name = null) + { + string path = Path.Combine(Path.GetTempPath(), name); + _ = Directory.CreateDirectory(path); + + try + { + NativeMethods.SHARE_INFO_2 shareInfo = new NativeMethods.SHARE_INFO_2 + { + netname = name, + type = NativeMethods.STYPE_DISKTREE | NativeMethods.STYPE_TEMPORARY, + max_uses = unchecked((uint)-1), + path = path, + }; + + uint error = NativeMethods.NetShareAdd(null, 2, shareInfo, out _); + Assert.That(error, Is.Zero); + + return new TestData(name, path); + } + catch + { + Directory.Delete(path); + throw; + } + } + + public void Dispose() + { + try + { + uint error = NativeMethods.NetShareDel(null, this.name); + Assert.That(error, Is.Zero); + } + finally + { + Directory.Delete(this.path); + } + } + } + + private static class NativeMethods + { + internal const uint STYPE_DISKTREE = 0; + internal const uint STYPE_TEMPORARY = 0x40000000; + + private const string Netapi32LibraryName = "netapi32.dll"; + + [DllImport(Netapi32LibraryName, CharSet = CharSet.Unicode)] + internal static extern uint NetShareAdd(string servername, uint level, in SHARE_INFO_2 buf, out uint parm_err); + + [DllImport(Netapi32LibraryName, CharSet = CharSet.Unicode)] + internal static extern uint NetShareDel(string servername, string netname, uint reserved = 0); + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + internal struct SHARE_INFO_2 + { + public string netname; + public uint type; + public string remark; + public uint permissions; + public uint max_uses; + public uint current_uses; + public string path; + public string passwd; + } + } + } +} +#endif