Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.

Commit

Permalink
Don't rely on the built-in marshaller during activation. (#28073)
Browse files Browse the repository at this point in the history
Relying on the built-in marshaller leverages the Class interface approach
which doesn't work for some interface types (e.g. interfaces inheriting
from IDispatch).

This approach is wrong regardless of why given that COM dictates the
returned value must be properly cast the specific interface vtable.

Updated tests so they would have found this issue.
  • Loading branch information
AaronRobinsonMSFT authored Aug 11, 2020
1 parent 4cf9136 commit d7c967a
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public interface IClassFactory
void CreateInstance(
[MarshalAs(UnmanagedType.Interface)] object? pUnkOuter,
ref Guid riid,
[MarshalAs(UnmanagedType.Interface)] out object? ppvObject);
out IntPtr ppvObject);

void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock);
}
Expand All @@ -51,7 +51,7 @@ internal interface IClassFactory2 : IClassFactory
new void CreateInstance(
[MarshalAs(UnmanagedType.Interface)] object? pUnkOuter,
ref Guid riid,
[MarshalAs(UnmanagedType.Interface)] out object? ppvObject);
out IntPtr ppvObject);

new void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock);

Expand All @@ -66,7 +66,7 @@ void CreateInstanceLic(
[MarshalAs(UnmanagedType.Interface)] object? pUnkReserved,
ref Guid riid,
[MarshalAs(UnmanagedType.BStr)] string bstrKey,
[MarshalAs(UnmanagedType.Interface)] out object ppvObject);
out IntPtr ppvObject);
}

[StructLayout(LayoutKind.Sequential)]
Expand Down Expand Up @@ -424,27 +424,31 @@ public static Type GetValidatedInterfaceType(Type classType, ref Guid riid, obje
throw new InvalidCastException();
}

public static void ValidateObjectIsMarshallableAsInterface(object obj, Type interfaceType)
public static IntPtr GetObjectAsInterface(object obj, Type interfaceType)
{
// If the requested "interface type" is type object then return
// because type object is always marshallable.
// If the requested "interface type" is type object then return as IUnknown
if (interfaceType == typeof(object))
{
return;
return Marshal.GetIUnknownForObject(obj);
}

Debug.Assert(interfaceType.IsInterface);

// The intent of this call is to validate the interface can be
// The intent of this call is to get AND validate the interface can be
// marshalled to native code. An exception will be thrown if the
// type is unable to be marshalled to native code.
// Scenarios where this is relevant:
// - Interfaces that use Generics
// - Interfaces that define implementation
IntPtr ptr = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore);
IntPtr interfaceMaybe = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore);

// Decrement the above 'Marshal.GetComInterfaceForObject()'
Marshal.Release(ptr);
if (interfaceMaybe == IntPtr.Zero)
{
// E_NOINTERFACE
throw new InvalidCastException();
}

return interfaceMaybe;
}

public static object CreateAggregatedObject(object pUnkOuter, object comObject)
Expand All @@ -467,17 +471,17 @@ public static object CreateAggregatedObject(object pUnkOuter, object comObject)
public void CreateInstance(
[MarshalAs(UnmanagedType.Interface)] object? pUnkOuter,
ref Guid riid,
[MarshalAs(UnmanagedType.Interface)] out object? ppvObject)
out IntPtr ppvObject)
{
Type interfaceType = BasicClassFactory.GetValidatedInterfaceType(_classType, ref riid, pUnkOuter);

ppvObject = Activator.CreateInstance(_classType)!;
object obj = Activator.CreateInstance(_classType)!;
if (pUnkOuter != null)
{
ppvObject = BasicClassFactory.CreateAggregatedObject(pUnkOuter, ppvObject);
obj = BasicClassFactory.CreateAggregatedObject(pUnkOuter, obj);
}

BasicClassFactory.ValidateObjectIsMarshallableAsInterface(ppvObject, interfaceType);
ppvObject = BasicClassFactory.GetObjectAsInterface(obj, interfaceType);
}

public void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock)
Expand All @@ -502,7 +506,7 @@ public LicenseClassFactory(Guid clsid, Type classType)
public void CreateInstance(
[MarshalAs(UnmanagedType.Interface)] object? pUnkOuter,
ref Guid riid,
[MarshalAs(UnmanagedType.Interface)] out object? ppvObject)
out IntPtr ppvObject)
{
CreateInstanceInner(pUnkOuter, ref riid, key: null, isDesignTime: true, out ppvObject);
}
Expand Down Expand Up @@ -535,7 +539,7 @@ public void CreateInstanceLic(
[MarshalAs(UnmanagedType.Interface)] object? pUnkReserved,
ref Guid riid,
[MarshalAs(UnmanagedType.BStr)] string bstrKey,
[MarshalAs(UnmanagedType.Interface)] out object ppvObject)
out IntPtr ppvObject)
{
Debug.Assert(pUnkReserved == null);
CreateInstanceInner(pUnkOuter, ref riid, bstrKey, isDesignTime: false, out ppvObject);
Expand All @@ -546,17 +550,17 @@ private void CreateInstanceInner(
ref Guid riid,
string? key,
bool isDesignTime,
out object ppvObject)
out IntPtr ppvObject)
{
Type interfaceType = BasicClassFactory.GetValidatedInterfaceType(_classType, ref riid, pUnkOuter);

ppvObject = _licenseProxy.AllocateAndValidateLicense(_classType, key, isDesignTime);
object obj = _licenseProxy.AllocateAndValidateLicense(_classType, key, isDesignTime);
if (pUnkOuter != null)
{
ppvObject = BasicClassFactory.CreateAggregatedObject(pUnkOuter, ppvObject);
obj = BasicClassFactory.CreateAggregatedObject(pUnkOuter, obj);
}

BasicClassFactory.ValidateObjectIsMarshallableAsInterface(ppvObject, interfaceType);
ppvObject = BasicClassFactory.GetObjectAsInterface(obj, interfaceType);
}
}
}
Expand Down
28 changes: 18 additions & 10 deletions tests/src/Interop/COM/Activator/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@ static void ValidateAssemblyIsolation()

var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt);

object svr;
factory.CreateInstance(null, ref iid, out svr);
typeCFromAssemblyA = (Type)((IGetTypeFromC)svr).GetTypeFromC();
IntPtr svrRaw;
factory.CreateInstance(null, ref iid, out svrRaw);
var svr = (IGetTypeFromC)Marshal.GetObjectForIUnknown(svrRaw);
Marshal.Release(svrRaw);
typeCFromAssemblyA = (Type)svr.GetTypeFromC();
}

using (HostPolicyMock.Mock_corehost_resolve_component_dependencies(
Expand All @@ -128,9 +130,11 @@ static void ValidateAssemblyIsolation()

var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt);

object svr;
factory.CreateInstance(null, ref iid, out svr);
typeCFromAssemblyB = (Type)((IGetTypeFromC)svr).GetTypeFromC();
IntPtr svrRaw;
factory.CreateInstance(null, ref iid, out svrRaw);
var svr = (IGetTypeFromC)Marshal.GetObjectForIUnknown(svrRaw);
Marshal.Release(svrRaw);
typeCFromAssemblyB = (Type)svr.GetTypeFromC();
}

Assert.AreNotEqual(typeCFromAssemblyA, typeCFromAssemblyB, "Types should be from different AssemblyLoadContexts");
Expand Down Expand Up @@ -172,8 +176,10 @@ static void ValidateUserDefinedRegistrationCallbacks()

var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt);

object svr;
factory.CreateInstance(null, ref iid, out svr);
IntPtr svrRaw;
factory.CreateInstance(null, ref iid, out svrRaw);
var svr = Marshal.GetObjectForIUnknown(svrRaw);
Marshal.Release(svrRaw);

var inst = (IValidateRegistrationCallbacks)svr;
Assert.IsFalse(inst.DidRegister());
Expand Down Expand Up @@ -209,8 +215,10 @@ static void ValidateUserDefinedRegistrationCallbacks()

var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt);

object svr;
factory.CreateInstance(null, ref iid, out svr);
IntPtr svrRaw;
factory.CreateInstance(null, ref iid, out svrRaw);
var svr = Marshal.GetObjectForIUnknown(svrRaw);
Marshal.Release(svrRaw);

var inst = (IValidateRegistrationCallbacks)svr;
cxt.InterfaceId = Guid.Empty;
Expand Down

0 comments on commit d7c967a

Please sign in to comment.