openvpn-gui/plap/test_plap.cpp

178 lines
5.8 KiB
C++
Raw Normal View History

/*
* OpenVPN-PLAP-Provider
*
* Copyright (C) 2019-2022 Selva Nair <selva.nair@gmail.com>
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program (see the file COPYING included with this
* distribution); if not, write to the Free Software Foundation, Inc.,
* 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/
#undef WIN32_LEAN_AND_MEAN
#include <initguid.h>
#include "plap_common.h"
#include <windows.h>
#include <credentialprovider.h>
#include <fcntl.h>
#include <io.h>
#include <stdio.h>
#include <assert.h>
#include <time.h>
#define error_return(fname, hr) do { fwprintf(stdout, L"Error in %ls: status = 0x%08x", fname, hr);\
return 1;} while (0)
typedef HRESULT (WINAPI * f_func)(REFCLSID rclsid, REFIID riid, LPVOID* ppv);
class MyQueryContinue : public IQueryContinueWithStatus
{
public:
STDMETHODIMP_(ULONG) AddRef() {
return InterlockedIncrement(&ref_count);
}
STDMETHODIMP_(ULONG) Release() {
int count = InterlockedDecrement(&ref_count);
if (ref_count == 0) delete this;
return count;
}
STDMETHODIMP QueryInterface(REFIID riid, void **ppv) { return E_FAIL; }
STDMETHODIMP QueryContinue() {return time(NULL) > timeout ? S_FALSE : S_OK;}
STDMETHODIMP SetStatusMessage(const wchar_t *ws) { wprintf(L"%ls\r", ws); return S_OK; }
MyQueryContinue() : ref_count(1) {};
time_t timeout;
private:
~MyQueryContinue()=default;
LONG ref_count;
};
static int test_provider(IClassFactory *cf)
{
assert(cf != NULL);
ICredentialProvider *o = NULL;
HRESULT hr;
hr = cf->CreateInstance(NULL, IID_ICredentialProvider, (void**) &o);
if (!SUCCEEDED(hr)) error_return(L"IID_ICredentialProvider", hr);
hr = o->SetUsageScenario(CPUS_PLAP, 0);
if (!SUCCEEDED(hr)) error_return(L"SetUsageScenario", hr);
DWORD count, def;
BOOL auto_def;
hr = o->GetCredentialCount(&count, &def, &auto_def);
if (!SUCCEEDED(hr)) error_return(L"GetCredentialCount", hr);
fwprintf(stdout, L"credential count = %lu, default = %d, autologon = %d\n", count, (int) def, auto_def);
if (count < 1) fwprintf(stdout, L"No persistent configs found!\n");
ICredentialProviderCredential *c = NULL;
MyQueryContinue *qc = new MyQueryContinue();
for (DWORD i = 0; i < count; i++)
{
hr = o->GetCredentialAt(i, &c);
if (!SUCCEEDED(hr)) error_return(L"GetCredentialAt", hr);
fwprintf(stdout, L"credential # = %lu: ", i);
wchar_t *ws;
for (DWORD j = 0; j < 4; j++)
{
hr = c->GetStringValue(j, &ws);
if (!SUCCEEDED(hr)) error_return(L"GetStringValue", hr);
CoTaskMemFree(ws);
}
/* test getbitmap */
HBITMAP bmp;
hr = c->GetBitmapValue(0, &bmp);
if (!SUCCEEDED(hr))
fwprintf(stdout, L"Warning: could not get bitmap"); /* not fatal */
else
DeleteObject(bmp);
/* set a time out so that we can move to next config in case */
qc->timeout = time(NULL) + 20;
/* get a connection instance and call connect on it */
IConnectableCredentialProviderCredential *c1 = NULL;
hr = c->QueryInterface(IID_IConnectableCredentialProviderCredential, (void**)&c1);
fwprintf(stdout, L"\nConnecting connection # <%lu>\n", i);
c1->Connect(qc); /* this will return when connected/failed or qc timesout */
fwprintf(stdout, L"\nsleep for 2 sec\n");
Sleep(2000);
c1->Release();
}
assert(o->Release() == 0); /* check refcount */
assert(qc->Release() == 0);
return 0;
}
int wmain()
{
HRESULT hr;
_setmode(_fileno(stdout), _O_U16TEXT);
_setmode(_fileno(stderr), _O_U16TEXT);
IClassFactory *cf = NULL;
DWORD ctx = CLSCTX_INPROC_SERVER;
hr = CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
if (!SUCCEEDED(hr)) error_return(L"CoIntialize", hr);
/* Test by loading the dll */
fwprintf(stdout, L"Test plap dll direct loading\n");
HMODULE lib = LoadLibraryW(L"C:\\Program Files\\OpenVPN\\bin\\libopenvpn_plap.dll");
f_func func = NULL;
if (lib == NULL)
{
fwprintf(stderr, L"Failed to load the dll: error = 0x%08x\n", GetLastError());
}
else {
func = (f_func) GetProcAddress(lib, "DllGetClassObject");
if (!func)
fwprintf(stderr, L"Failed to find DllGetClassObject in dll: error = 0x%08x\n", GetLastError());
}
if (func) {
hr = func(CLSID_OpenVPNProvider, IID_IClassFactory, (void **)&cf);
if (!SUCCEEDED(hr)) fwprintf(stdout, L"Error in DllGetClassObject: status = 0x%08x\n", hr);
else {
fwprintf(stdout, L"Success: found ovpn provider class factory by direct access\n");
cf->Release();
}
}
/* Test by finding the class through COM's registration mechanism */
fwprintf(stdout, L"Testing plap using CoGetclassobject -- requires proper dll registration\n");
hr = CoGetClassObject(CLSID_OpenVPNProvider, ctx, NULL, IID_IClassFactory, (void **)&cf);
if (SUCCEEDED(hr)) {
test_provider(cf);
cf->Release();
}
else {
fwprintf(stdout, L"CoGetClassObject (class not registered?): error = 0x%08x\n", hr);
}
CoUninitialize();
if (lib) {
FreeLibrary(lib);
}
return 0;
}