diff --git a/Credentials/Credential.cs b/Credentials/Credential.cs index f76cbf9..2edb597 100644 --- a/Credentials/Credential.cs +++ b/Credentials/Credential.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.DirectoryServices.AccountManagement; using System.Linq; using System.Runtime.CompilerServices; +using System.Runtime.ConstrainedExecution; using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks; @@ -116,27 +117,51 @@ namespace NSspi } } - public string Name + public string Name { get { QueryNameAttribCarrier carrier = new QueryNameAttribCarrier(); - SecurityStatus status; + SecurityStatus status = SecurityStatus.InternalError; string name = null; + bool gotRef = false; - status = CredentialNativeMethods.QueryCredentialsAttribute_Name( - ref this.safeCredHandle.rawHandle, - CredentialQueryAttrib.Names, - ref carrier - ); - - if ( status == SecurityStatus.OK ) + RuntimeHelpers.PrepareConstrainedRegions(); + try { - name = Marshal.PtrToStringUni( carrier.Name ); - NativeMethods.FreeContextBuffer( carrier.Name ); + this.safeCredHandle.DangerousAddRef( ref gotRef ); } - else + catch( Exception ) + { + if( gotRef == true ) + { + this.safeCredHandle.DangerousRelease(); + gotRef = false; + } + throw; + } + finally + { + if( gotRef ) + { + status = CredentialNativeMethods.QueryCredentialsAttribute_Name( + ref this.safeCredHandle.rawHandle, + CredentialQueryAttrib.Names, + ref carrier + ); + + this.safeCredHandle.DangerousRelease(); + + if( status == SecurityStatus.OK && carrier.Name != IntPtr.Zero ) + { + name = Marshal.PtrToStringUni( carrier.Name ); + NativeMethods.FreeContextBuffer( carrier.Name ); + } + } + } + + if( status.IsError() ) { throw new SSPIException( "Failed to query credential name", status ); } diff --git a/Credentials/CredentialNativeMethods.cs b/Credentials/CredentialNativeMethods.cs index 8d848c2..e8ae888 100644 --- a/Credentials/CredentialNativeMethods.cs +++ b/Credentials/CredentialNativeMethods.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.ConstrainedExecution; using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks; @@ -63,6 +64,7 @@ namespace NSspi /// /// /// + [ReliabilityContract( Consistency.WillNotCorruptState, Cer.Success )] [DllImport( "Secur32.dll", EntryPoint = "QueryCredentialsAttributes", CharSet = CharSet.Unicode )] public static extern SecurityStatus QueryCredentialsAttribute_Name( ref RawSspiHandle credentialHandle, diff --git a/NativeMethods.cs b/NativeMethods.cs index db87fc7..091691c 100644 --- a/NativeMethods.cs +++ b/NativeMethods.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks; +using System.Runtime.ConstrainedExecution; namespace NSspi { @@ -23,13 +24,9 @@ namespace NSspi _In_ PVOID pvContextBuffer ); */ - [DllImport( - "Secur32.dll", - EntryPoint = "FreeContextBuffer", - CallingConvention = CallingConvention.Winapi, - CharSet = CharSet.Unicode, - SetLastError = true - )] + + [ReliabilityContract( Consistency.WillNotCorruptState, Cer.Success)] + [DllImport( "Secur32.dll", EntryPoint = "FreeContextBuffer", CharSet = CharSet.Unicode )] public static extern SecurityStatus FreeContextBuffer( IntPtr buffer ); }