diff --git a/dlls/urlmon/mk.c b/dlls/urlmon/mk.c index a11e40fe16c..3ee43292698 100644 --- a/dlls/urlmon/mk.c +++ b/dlls/urlmon/mk.c @@ -28,12 +28,16 @@ #include "urlmon_main.h" #include "wine/debug.h" +#include "wine/unicode.h" WINE_DEFAULT_DEBUG_CHANNEL(urlmon); typedef struct { const IInternetProtocolVtbl *lpInternetProtocolVtbl; + LONG ref; + + IStream *stream; } MkProtocol; #define PROTOCOL_THIS(iface) DEFINE_THIS(MkProtocol, InternetProtocol, iface) @@ -81,6 +85,9 @@ static ULONG WINAPI MkProtocol_Release(IInternetProtocol *iface) TRACE("(%p) ref=%d\n", This, ref); if(!ref) { + if(This->stream) + IStream_Release(This->stream); + HeapFree(GetProcessHeap(), 0, This); URLMON_UnlockModule(); @@ -89,14 +96,112 @@ static ULONG WINAPI MkProtocol_Release(IInternetProtocol *iface) return ref; } +static HRESULT report_result(IInternetProtocolSink *sink, HRESULT hres, DWORD dwError) +{ + IInternetProtocolSink_ReportResult(sink, hres, dwError, NULL); + return hres; +} + static HRESULT WINAPI MkProtocol_Start(IInternetProtocol *iface, LPCWSTR szUrl, IInternetProtocolSink *pOIProtSink, IInternetBindInfo *pOIBindInfo, DWORD grfPI, DWORD dwReserved) { MkProtocol *This = PROTOCOL_THIS(iface); - FIXME("(%p)->(%s %p %p %08x %d)\n", This, debugstr_w(szUrl), pOIProtSink, + IParseDisplayName *pdn; + IMoniker *mon; + LPWSTR mime, progid, display_name; + LPCWSTR ptr, ptr2; + BINDINFO bindinfo; + STATSTG statstg; + DWORD bindf=0, eaten=0, len; + CLSID clsid; + HRESULT hres; + + static const WCHAR wszMK[] = {'m','k',':'}; + + TRACE("(%p)->(%s %p %p %08x %d)\n", This, debugstr_w(szUrl), pOIProtSink, pOIBindInfo, grfPI, dwReserved); - return E_NOTIMPL; + + memset(&bindinfo, 0, sizeof(bindinfo)); + bindinfo.cbSize = sizeof(BINDINFO); + hres = IInternetBindInfo_GetBindInfo(pOIBindInfo, &bindf, &bindinfo); + if(FAILED(hres)) { + WARN("GetBindInfo failed: %08x\n", hres); + return hres; + } + + ReleaseBindInfo(&bindinfo); + + if(strncmpiW(szUrl, wszMK, sizeof(wszMK)/sizeof(WCHAR))) + return MK_E_SYNTAX; + + IInternetProtocolSink_ReportProgress(pOIProtSink, BINDSTATUS_DIRECTBIND, NULL); + IInternetProtocolSink_ReportProgress(pOIProtSink, BINDSTATUS_SENDINGREQUEST, NULL); + + hres = FindMimeFromData(NULL, szUrl, NULL, 0, NULL, 0, &mime, 0); + if(SUCCEEDED(hres)) { + IInternetProtocolSink_ReportProgress(pOIProtSink, BINDSTATUS_MIMETYPEAVAILABLE, mime); + CoTaskMemFree(mime); + } + + ptr2 = szUrl + sizeof(wszMK)/sizeof(WCHAR); + if(*ptr2 != '@') + return report_result(pOIProtSink, INET_E_RESOURCE_NOT_FOUND, ERROR_INVALID_PARAMETER); + ptr2++; + + ptr = strchrW(ptr2, ':'); + if(!ptr) + return report_result(pOIProtSink, INET_E_RESOURCE_NOT_FOUND, ERROR_INVALID_PARAMETER); + + progid = HeapAlloc(GetProcessHeap(), 0, (ptr-ptr2+1)*sizeof(WCHAR)); + memcpy(progid, ptr2, (ptr-ptr2)*sizeof(WCHAR)); + progid[ptr-ptr2] = 0; + hres = CLSIDFromProgID(progid, &clsid); + HeapFree(GetProcessHeap(), 0, progid); + if(FAILED(hres)) + return report_result(pOIProtSink, INET_E_RESOURCE_NOT_FOUND, ERROR_INVALID_PARAMETER); + + hres = CoCreateInstance(&clsid, NULL, CLSCTX_INPROC_SERVER|CLSCTX_INPROC_HANDLER, + &IID_IParseDisplayName, (void**)&pdn); + if(FAILED(hres)) { + WARN("Could not create object %s\n", debugstr_guid(&clsid)); + return report_result(pOIProtSink, hres, ERROR_INVALID_PARAMETER); + } + + len = strlenW(--ptr2); + display_name = HeapAlloc(GetProcessHeap(), 0, (len+1)*sizeof(WCHAR)); + memcpy(display_name, ptr2, (len+1)*sizeof(WCHAR)); + hres = IParseDisplayName_ParseDisplayName(pdn, NULL /* FIXME */, display_name, &eaten, &mon); + HeapFree(GetProcessHeap(), 0, display_name); + IParseDisplayName_Release(pdn); + if(FAILED(hres)) { + WARN("ParseDisplayName failed: %08x\n", hres); + return report_result(pOIProtSink, hres, ERROR_INVALID_PARAMETER); + } + + if(This->stream) { + IStream_Release(This->stream); + This->stream = NULL; + } + + hres = IMoniker_BindToStorage(mon, NULL /* FIXME */, NULL, &IID_IStream, (void**)&This->stream); + IMoniker_Release(mon); + if(FAILED(hres)) { + WARN("BindToStorage failed: %08x\n", hres); + return report_result(pOIProtSink, hres, ERROR_INVALID_PARAMETER); + } + + hres = IStream_Stat(This->stream, &statstg, STATFLAG_NONAME); + if(FAILED(hres)) { + WARN("Stat failed: %08x\n", hres); + return report_result(pOIProtSink, hres, ERROR_INVALID_PARAMETER); + } + + IInternetProtocolSink_ReportData(pOIProtSink, + BSCF_FIRSTDATANOTIFICATION | BSCF_LASTDATANOTIFICATION, + statstg.cbSize.u.LowPart, statstg.cbSize.u.LowPart); + + return report_result(pOIProtSink, S_OK, ERROR_SUCCESS); } static HRESULT WINAPI MkProtocol_Continue(IInternetProtocol *iface, PROTOCOLDATA *pProtocolData) @@ -117,8 +222,10 @@ static HRESULT WINAPI MkProtocol_Abort(IInternetProtocol *iface, HRESULT hrReaso static HRESULT WINAPI MkProtocol_Terminate(IInternetProtocol *iface, DWORD dwOptions) { MkProtocol *This = PROTOCOL_THIS(iface); - FIXME("(%p)->(%08x)\n", This, dwOptions); - return E_NOTIMPL; + + TRACE("(%p)->(%08x)\n", This, dwOptions); + + return S_OK; } static HRESULT WINAPI MkProtocol_Suspend(IInternetProtocol *iface) @@ -139,8 +246,13 @@ static HRESULT WINAPI MkProtocol_Read(IInternetProtocol *iface, void *pv, ULONG cb, ULONG *pcbRead) { MkProtocol *This = PROTOCOL_THIS(iface); - FIXME("(%p)->(%p %u %p)\n", This, pv, cb, pcbRead); - return E_NOTIMPL; + + TRACE("(%p)->(%p %u %p)\n", This, pv, cb, pcbRead); + + if(!This->stream) + return E_FAIL; + + return IStream_Read(This->stream, pv, cb, pcbRead); } static HRESULT WINAPI MkProtocol_Seek(IInternetProtocol *iface, LARGE_INTEGER dlibMove, @@ -154,15 +266,19 @@ static HRESULT WINAPI MkProtocol_Seek(IInternetProtocol *iface, LARGE_INTEGER dl static HRESULT WINAPI MkProtocol_LockRequest(IInternetProtocol *iface, DWORD dwOptions) { MkProtocol *This = PROTOCOL_THIS(iface); - FIXME("(%p)->(%08x)\n", This, dwOptions); - return E_NOTIMPL; + + TRACE("(%p)->(%08x)\n", This, dwOptions); + + return S_OK; } static HRESULT WINAPI MkProtocol_UnlockRequest(IInternetProtocol *iface) { MkProtocol *This = PROTOCOL_THIS(iface); - FIXME("(%p)\n", This); - return E_NOTIMPL; + + TRACE("(%p)\n", This); + + return S_OK; } #undef PROTOCOL_THIS @@ -195,6 +311,7 @@ HRESULT MkProtocol_Construct(IUnknown *pUnkOuter, LPVOID *ppobj) ret->lpInternetProtocolVtbl = &MkProtocolVtbl; ret->ref = 1; + ret->stream = NULL; /* NOTE: * Native returns NULL ppobj and S_OK in CreateInstance if called with IID_IUnknown riid. diff --git a/dlls/urlmon/tests/protocol.c b/dlls/urlmon/tests/protocol.c index 221df9c41f7..4551a89c59e 100644 --- a/dlls/urlmon/tests/protocol.c +++ b/dlls/urlmon/tests/protocol.c @@ -86,7 +86,8 @@ static DWORD bindf = 0; static enum { FILE_TEST, - HTTP_TEST + HTTP_TEST, + MK_TEST } tested_protocol; static HRESULT WINAPI HttpNegotiate_QueryInterface(IHttpNegotiate2 *iface, REFIID riid, void **ppv) @@ -926,6 +927,12 @@ static void test_mk_protocol(void) IUnknown *unk; HRESULT hres; + static const WCHAR wrong_url1[] = {'t','e','s','t',':','@','M','S','I','T','S','t','o','r','e', + ':',':','/','t','e','s','t','.','h','t','m','l',0}; + static const WCHAR wrong_url2[] = {'m','k',':','/','t','e','s','t','.','h','t','m','l',0}; + + tested_protocol = MK_TEST; + hres = CoGetClassObject(&CLSID_MkProtocol, CLSCTX_INPROC_SERVER, NULL, &IID_IUnknown, (void**)&unk); ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres); @@ -946,6 +953,27 @@ static void test_mk_protocol(void) IClassFactory_Release(factory); ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres); + SET_EXPECT(GetBindInfo); + hres = IInternetProtocol_Start(protocol, wrong_url1, &protocol_sink, &bind_info, 0, 0); + ok(hres == MK_E_SYNTAX, "Start failed: %08x, expected MK_E_SYNTAX\n", hres); + CHECK_CALLED(GetBindInfo); + + SET_EXPECT(GetBindInfo); + SET_EXPECT(ReportProgress_DIRECTBIND); + SET_EXPECT(ReportProgress_SENDINGREQUEST); + SET_EXPECT(ReportProgress_MIMETYPEAVAILABLE); + SET_EXPECT(ReportResult); + expect_hrResult = INET_E_RESOURCE_NOT_FOUND; + + hres = IInternetProtocol_Start(protocol, wrong_url2, &protocol_sink, &bind_info, 0, 0); + ok(hres == INET_E_RESOURCE_NOT_FOUND, "Start failed: %08x, expected INET_E_RESOURCE_NOT_FOUND\n", hres); + + CHECK_CALLED(GetBindInfo); + CHECK_CALLED(ReportProgress_DIRECTBIND); + CHECK_CALLED(ReportProgress_SENDINGREQUEST); + CHECK_CALLED(ReportProgress_MIMETYPEAVAILABLE); + CHECK_CALLED(ReportResult); + IInternetProtocol_Release(protocol); }