diff --git a/as.c b/as.c index 712b4da..cfa4d4d 100644 --- a/as.c +++ b/as.c @@ -22,6 +22,7 @@ #include #include #include +#include #include "config.h" #include "localization.h" @@ -38,6 +39,21 @@ #define PROFILE_NAME_TOKEN L"# OVPN_ACCESS_SERVER_PROFILE=" #define FRIENDLY_NAME_TOKEN L"# OVPN_ACCESS_SERVER_FRIENDLY_NAME=" +/** Replace characters not allowed in Windows filenames with '_' */ +void +SanitizeFilename(wchar_t *fname) +{ + const wchar_t *reserved = L"<>:\"/\\|?*;"; /* remap these and ascii 1 to 31 */ + while (*fname) { + wchar_t c = *fname; + if (c < 32 || wcschr(reserved, c)) + { + *fname = L'_'; + } + ++fname; + } +} + /** * Extract profile name from profile content. * @@ -87,14 +103,7 @@ ExtractProfileName(const WCHAR *profile, const WCHAR *default_name, WCHAR *out_n out_name[out_name_length - 1] = L'\0'; - /* sanitize profile name */ - const WCHAR *reserved = L"<>:\"/\\|?*;"; /* remap these and ascii 1 to 31 */ - while (*out_name) { - wchar_t c = *out_name; - if (c < 32 || wcschr(reserved, c)) - *out_name = L'_'; - ++out_name; - } + SanitizeFilename(out_name); free(buf); } @@ -295,6 +304,77 @@ GetASUrl(const WCHAR *host, bool autologin, struct UrlComponents *comps) comps->path[URL_LEN - 1] = L'\0'; } +/** + * Read content-disposition header and extract file name if any. + * Returns true on success, false otherwise. + */ +bool +ExtractFilenameFromHeader(HINTERNET hRequest, wchar_t *name, size_t len) +{ + DWORD index = 0; + char *buf = NULL; + DWORD buflen = 256; + bool res = false; + UINT codepage = 28591; /* ISO 8859_1 -- the default char set for http header */ + + buf = malloc(buflen); + if (!buf + || (!HttpQueryInfoA(hRequest, HTTP_QUERY_CONTENT_DISPOSITION, buf, &buflen, &index) + && GetLastError() != ERROR_INSUFFICIENT_BUFFER)) + { + goto done; + } + + if (GetLastError() == ERROR_INSUFFICIENT_BUFFER) + { + /* try again with more space */ + free(buf); + buf = malloc(buflen); + if (!buf + || !HttpQueryInfoA(hRequest, HTTP_QUERY_CONTENT_DISPOSITION, buf, &buflen, &index)) + { + goto done; + } + } + + /* look for filename= */ + char *p = strtok(buf, ";"); + char *fn = NULL; + for ( ; p; p = strtok(NULL, ";")) + { + if ((fn = strstr(p, "filename=")) != NULL) + { + fn += 9; + continue; + } + else if ((fn = strstr(p, "filename*=utf-8''")) != NULL) + { + fn += 17; + UrlUnescapeA(fn, NULL, NULL, URL_UNESCAPE_INPLACE); + codepage = CP_UTF8; + break; /* we prefer filename*= value */ + } + } + + if (fn && strlen(fn)) + { + StrTrimA(fn, " \""); /* strip leading and trailing spaces and quotes */ + wchar_t *wfn = WidenEx(codepage, fn); + if (wfn) + { + wcsncpy_s(name, len, wfn, _TRUNCATE); + res = true; + free(wfn); + } + } + + SanitizeFilename(name); + +done: + free(buf); + return res; +} + /** * Download profile from a generic URL and save it to a temp file * @@ -467,13 +547,18 @@ again: } WCHAR name[MAX_PATH] = {0}; - WCHAR* wbuf = Widen(buf); - if (!wbuf) { - MessageBoxW(hWnd, L"Failed to convert profile content to wchar", _T(PACKAGE_NAME), MB_OK); - goto done; + /* read filename from header or from the profile metadata */ + if (strlen(comps->content_type) == 0 /* AS profile */ + || !ExtractFilenameFromHeader(hRequest, name, MAX_PATH)) + { + WCHAR* wbuf = Widen(buf); + if (!wbuf) { + MessageBoxW(hWnd, L"Failed to convert profile content to wchar", _T(PACKAGE_NAME), MB_OK); + goto done; + } + ExtractProfileName(wbuf, comps->host, name, MAX_PATH); + free(wbuf); } - ExtractProfileName(wbuf, comps->host, name, MAX_PATH); - free(wbuf); /* save profile content into tmp file */ DWORD res = GetTempPathW((DWORD)out_path_size, out_path); diff --git a/misc.c b/misc.c index 5c853e9..7b94a7d 100644 --- a/misc.c +++ b/misc.c @@ -466,22 +466,23 @@ CheckFileAccess (const TCHAR *path, int access) return ret; } -/* - * Convert a NUL terminated utf8 string to widechar. The caller must free +/** + * Convert a NUL terminated narrow string to wide string using + * specified codepage. The caller must free * the returned pointer. Return NULL on error. */ WCHAR * -Widen(const char *utf8) +WidenEx(UINT codepage, const char *str) { WCHAR *wstr = NULL; - if (!utf8) + if (!str) return wstr; - int nch = MultiByteToWideChar(CP_UTF8, 0, utf8, -1, NULL, 0); + int nch = MultiByteToWideChar(codepage, 0, str, -1, NULL, 0); if (nch > 0) wstr = malloc(sizeof(WCHAR) * nch); if (wstr) - nch = MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wstr, nch); + nch = MultiByteToWideChar(codepage, 0, str, -1, wstr, nch); if (nch == 0 && wstr) { @@ -492,6 +493,15 @@ Widen(const char *utf8) return wstr; } +/* + * Same as WidenEx with codepage = UTF8 + */ +WCHAR * +Widen(const char *utf8) +{ + return WidenEx(CP_UTF8, utf8); +} + /* Return false if input contains any characters in exclude */ BOOL validate_input(const WCHAR *input, const WCHAR *exclude) diff --git a/misc.h b/misc.h index 93ac370..625a9b1 100644 --- a/misc.h +++ b/misc.h @@ -46,6 +46,7 @@ BOOL CheckFileAccess (const TCHAR *path, int access); BOOL Base64Encode(const char *input, int input_len, char **output); int Base64Decode(const char *input, char **output); WCHAR *Widen(const char *utf8); +WCHAR *WidenEx(UINT codepage, const char *utf8); BOOL validate_input(const WCHAR *input, const WCHAR *exclude); /* Concatenate two wide strings with a separator */ void wcs_concat2(WCHAR *dest, int len, const WCHAR *src1, const WCHAR *src2, const WCHAR *sep);