// SignUsbToken.cs
//
// This code is part of Document Solutions for PDF demos.
// Copyright (c) MESCIUS inc. All rights reserved.
//
using System;
using System.IO;
using System.Drawing;
using System.Text;
using System.Collections.Generic;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;

using Org.BouncyCastle.Crypto;
using Org.BouncyCastle.Crypto.Digests;
using Org.BouncyCastle.Asn1;
using Org.BouncyCastle.Asn1.X509;

using Net.Pkcs11Interop.Common;
using Net.Pkcs11Interop.HighLevelAPI;

using GrapeCity.Documents.Pdf;
using GrapeCity.Documents.Pdf.Security;
using GrapeCity.Documents.Pdf.AcroForms;
using GrapeCity.Documents.Text;


namespace DsPdfWeb.Demos
{
    // This sample shows how to sign an existing PDF file that contains
    // an empty signature field with a certificate that is stored
    // on a USB Token for DSC (Digital Signature Certificate).
    //
    // The sample includes a ready-to-use utility class Pkcs11SignatureGenerator
    // that implements the GrapeCity.Documents.Pdf.IPkcs7SignatureGenerator interface,
    // and can be used to sign PDFs with certificates stored on a USB Token for DSC.
    //
    // Please note that when run directly off the DsPdf demo site,
    // this sample will NOT sign the PDF, as it passes a dummy library name and dummy
    // parameters to the Pkcs11SignatureGenerator's ctor. You will need to download the sample
    // and provide your own library and parameters for the sample code to actually sign a PDF.
    //
    public class SignUsbToken
    {
        /// <summary>
        /// Loads <c>Resources/PDFs/SignUsbToken.pdf</c> and tries to sign its first form field
        /// with a <see cref="Pkcs11SignatureGenerator"/>, writing the signed PDF to <paramref name="stream"/>.
        /// If anything in the signing path throws, a note is added below the signature field
        /// and the unsigned document is saved to <paramref name="stream"/> instead.
        /// </summary>
        /// <param name="stream">Output stream that receives the signed (or annotated) PDF.</param>
        /// <returns>The number of pages in the document.</returns>
        public int CreatePDF(Stream stream)
        {
            var doc = new GcPdfDocument();
            using var s = File.OpenRead(Path.Combine("Resources", "PDFs", "SignUsbToken.pdf"));
            doc.Load(s);

            // Remember where the output begins (when the stream supports seeking) so that any
            // bytes a failed doc.Sign() may already have written can be discarded before the
            // fallback doc.Save() below; otherwise the saved PDF would follow partial output.
            long startPosition = stream.CanSeek ? stream.Position : -1;

            try
            {
                // This WILL NOT WORK due to dummy USB Token for DSC library name/parameters.
                // Supply valid library name and parameters to actually sign the PDF.
                using var sg = new Pkcs11SignatureGenerator(
                    "path-to-dummy-PKCS11.dll",
                    null,                               // tokenSerial
                    null,                               // tokenLabel
                    Encoding.ASCII.GetBytes("12345"),   // PIN
                    null,                               // ckaLabel
                    null,                               // ckaId
                    OID.HashAlgorithms.SHA512);

                var sp = new SignatureProperties()
                {
                    SignatureBuilder = new Pkcs7SignatureBuilder()
                    {
                        SignatureGenerator = sg,
                        CertificateChain = new X509Certificate2[] { sg.Certificate },
                    },
                    SignatureField = doc.AcroForm.Fields[0]
                };
                doc.Sign(sp, stream);
            }
            catch (Exception)
            {
                // Drop whatever a partially completed Sign() wrote, so Save() starts clean.
                if (startPosition >= 0 && stream.Position != startPosition)
                {
                    stream.SetLength(startPosition);
                    stream.Position = startPosition;
                }

                // Place an explanatory note just below the first widget of the signature field.
                var page = doc.Pages[0];
                var r = doc.AcroForm.Fields[0].Widgets[0].Rect;
                Common.Util.AddNote(
                    "Signing failed because a dummy USB Token for DSC library name and dummy parameters were used.\n" +
                    "Provide a valid USB Token library and correct parameters to sign the PDF.",
                    page,
                    new RectangleF(r.Left, r.Bottom + 24, page.Size.Width - r.Left * 2, 0));
                doc.Save(stream);
            }

            // Done.
            return doc.Pages.Count;
        }
    }

    /// <summary>
    /// Implements <see cref="IPkcs7SignatureGenerator"/>
    /// and allows generating a digital signature using a certificate
    /// stored on a USB Token for DSC (Digital Signature Certificate).
    ///
    /// The <b>Pkcs11Interop</b> NuGet package is used to manage the token.
    /// Only RSA private keys are looked up, and signatures are produced with the
    /// <c>CKM_RSA_PKCS</c> mechanism over a DER-encoded <c>DigestInfo</c>.
    /// </summary>
    public class Pkcs11SignatureGenerator : IPkcs7SignatureGenerator, IDisposable
    {
        /// <summary>
        /// Pkcs11Interop factories used to load the library and to create attributes and mechanisms.
        /// </summary>
        public static readonly Pkcs11InteropFactories Factories = new Pkcs11InteropFactories();

        private IPkcs11Library _pkcs11Library;
        private ISlot _slot;
        // Logged-in session kept open for the lifetime of this object. SignData opens
        // separate sessions without logging in again, relying on PKCS#11 login state
        // being shared by all of the application's sessions with the token.
        private ISession _session;
        private IObjectHandle _privateKeyHandle;
        private X509Certificate2 _certificate;
        private OID _hashAlgorithm;
        private IDigest _hashDigest;

        /// <summary>
        /// Initializes a new instance of the <see cref="Pkcs11SignatureGenerator"/> class.
        /// The <paramref name="tokenSerial"/> and <paramref name="tokenLabel"/> parameters are used
        /// to select the token to use if several tokens are connected.
        /// If only one token is connected then both these parameters can be <see langword="null"/>.
        /// The <paramref name="ckaLabel"/> and <paramref name="ckaId"/> parameters are used
        /// to select the private key to use if the token contains multiple keys.
        /// If the token contains a single private key then both these parameters can be <see langword="null"/>.
        /// </summary>
        /// <param name="libraryPath">Path to the unmanaged PKCS#11 library to use.</param>
        /// <param name="tokenSerial">Serial number of the token (smartcard) that contains the signing key.</param>
        /// <param name="tokenLabel">Label of the token (smartcard) that contains the signing key.</param>
        /// <param name="pin">PIN for the token (smartcard).</param>
        /// <param name="ckaLabel">Label (value of CKA_LABEL attribute) of the private key used for signing.</param>
        /// <param name="ckaId">Identifier (raw bytes of the CKA_ID attribute) of the private key used for signing.</param>
        /// <param name="hashAlgorihtm">The hash algorithm to use when creating the signature (SHA1, SHA256, SHA384 or SHA512).</param>
        public Pkcs11SignatureGenerator(string libraryPath, string tokenSerial, string tokenLabel, byte[] pin, string ckaLabel, byte[] ckaId, OID hashAlgorihtm)
        {
            Init(libraryPath, tokenSerial, tokenLabel, pin, ckaLabel, ckaId, hashAlgorihtm);
        }

        /// <summary>
        /// Releases resources used by this object.
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases the certificate, the logged-in session and the PKCS#11 library
        /// when <paramref name="disposing"/> is <see langword="true"/>.
        /// </summary>
        /// <param name="disposing"><see langword="true"/> when called from <see cref="Dispose()"/>.</param>
        protected void Dispose(bool disposing)
        {
            if (disposing)
                ReleaseResources();
        }

        // Disposes and clears every owned disposable; safe to call more than once.
        private void ReleaseResources()
        {
            _certificate?.Dispose();
            _certificate = null;
            _session?.Dispose();
            _session = null;
            _pkcs11Library?.Dispose();
            _pkcs11Library = null;
        }

        /// <summary>
        /// Finds the slot whose token matches <paramref name="tokenSerial"/> and/or
        /// <paramref name="tokenLabel"/> (case-insensitive). When both are null or empty,
        /// the single connected token is selected. Slots whose token reports
        /// CKR_TOKEN_NOT_RECOGNIZED or CKR_TOKEN_NOT_PRESENT are skipped.
        /// </summary>
        /// <returns>The matching slot, or <see langword="null"/> if no token matches.</returns>
        /// <exception cref="InvalidOperationException">
        /// Neither serial nor label was given and more than one token is connected.
        /// </exception>
        private ISlot FindSlot(string tokenSerial, string tokenLabel)
        {
            bool hasSerial = !string.IsNullOrEmpty(tokenSerial);
            bool hasLabel = !string.IsNullOrEmpty(tokenLabel);
            ISlot onlyMatch = null;

            List<ISlot> slots = _pkcs11Library.GetSlotList(SlotsType.WithTokenPresent);
            foreach (ISlot slot in slots)
            {
                ITokenInfo tokenInfo;
                try
                {
                    tokenInfo = slot.GetTokenInfo();
                }
                catch (Pkcs11Exception ex) when (ex.RV == CKR.CKR_TOKEN_NOT_RECOGNIZED || ex.RV == CKR.CKR_TOKEN_NOT_PRESENT)
                {
                    continue;
                }

                if (tokenInfo == null)
                    continue;
                if (hasSerial && !string.Equals(tokenSerial, tokenInfo.SerialNumber, StringComparison.OrdinalIgnoreCase))
                    continue;
                if (hasLabel && !string.Equals(tokenLabel, tokenInfo.Label, StringComparison.OrdinalIgnoreCase))
                    continue;

                // An explicit serial/label selects the first matching token.
                if (hasSerial || hasLabel)
                    return slot;

                // Without criteria, refuse to guess between several tokens.
                if (onlyMatch != null)
                    throw new InvalidOperationException("More than one token is connected; specify the token serial and/or label to select one.");
                onlyMatch = slot;
            }
            return onlyMatch;
        }

        /// <summary>
        /// Loads the PKCS#11 library, logs in to the selected token and locates
        /// the signing key and (if present) its certificate.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="libraryPath"/> is null or empty.</exception>
        /// <exception cref="NotSupportedException"><paramref name="hashAlgorihtm"/> is not SHA1/SHA256/SHA384/SHA512.</exception>
        /// <exception cref="InvalidOperationException">The token or a unique private key could not be found.</exception>
        protected void Init(string libraryPath, string tokenSerial, string tokenLabel, byte[] pin, string ckaLabel, byte[] ckaId, OID hashAlgorihtm)
        {
            if (string.IsNullOrEmpty(libraryPath))
                throw new ArgumentNullException(nameof(libraryPath), $"Invalid library path \"{libraryPath}\".");

            // Validate the hash algorithm before touching the token, so an unsupported
            // algorithm fails without loading the library or presenting the PIN.
            _hashDigest = CreateDigest(hashAlgorihtm);
            _hashAlgorithm = hashAlgorihtm;

            try
            {
                _pkcs11Library = Factories.Pkcs11LibraryFactory.LoadPkcs11Library(Factories, libraryPath, AppType.SingleThreaded);

                _slot = FindSlot(tokenSerial, tokenLabel);
                if (_slot == null)
                    throw new InvalidOperationException(string.Format("Token with serial \"{0}\" and label \"{1}\" was not found", tokenSerial, tokenLabel));

                _session = _slot.OpenSession(SessionType.ReadOnly);
                _session.Login(CKU.CKU_USER, pin);

                // initialize _privateKeyHandle and _certificate
                using (ISession session = _slot.OpenSession(SessionType.ReadOnly))
                {
                    _privateKeyHandle = FindPrivateKey(session, ckaLabel, ckaId);
                    _certificate = FindCertificate(session, ckaLabel, ckaId);
                }
            }
            catch
            {
                ReleaseResources();
                throw;
            }
        }

        // Maps a supported hash algorithm OID to a BouncyCastle digest implementation.
        private static IDigest CreateDigest(OID hashAlgorithm)
        {
            if (hashAlgorithm == OID.HashAlgorithms.SHA1)
                return new Sha1Digest();
            if (hashAlgorithm == OID.HashAlgorithms.SHA256)
                return new Sha256Digest();
            if (hashAlgorithm == OID.HashAlgorithms.SHA384)
                return new Sha384Digest();
            if (hashAlgorithm == OID.HashAlgorithms.SHA512)
                return new Sha512Digest();
            throw new NotSupportedException($"Unsupported HASH algorithm {hashAlgorithm}.");
        }

        // Builds a search template for objects of the given class, narrowed by CKA_LABEL / CKA_ID when supplied.
        private static List<IObjectAttribute> CreateSearchTemplate(CKO objectClass, string ckaLabel, byte[] ckaId)
        {
            var template = new List<IObjectAttribute>
            {
                Factories.ObjectAttributeFactory.Create(CKA.CKA_CLASS, objectClass),
            };
            if (!string.IsNullOrEmpty(ckaLabel))
                template.Add(Factories.ObjectAttributeFactory.Create(CKA.CKA_LABEL, ckaLabel));
            if (ckaId != null)
                template.Add(Factories.ObjectAttributeFactory.Create(CKA.CKA_ID, ckaId));
            return template;
        }

        // Returns the single RSA private key matching the filters; throws if none or several match.
        private static IObjectHandle FindPrivateKey(ISession session, string ckaLabel, byte[] ckaId)
        {
            List<IObjectAttribute> searchTemplate = CreateSearchTemplate(CKO.CKO_PRIVATE_KEY, ckaLabel, ckaId);
            searchTemplate.Add(Factories.ObjectAttributeFactory.Create(CKA.CKA_KEY_TYPE, CKK.CKK_RSA));

            List<IObjectHandle> foundObjects = session.FindAllObjects(searchTemplate);
            string ckaIdHex = (ckaId == null) ? null : ConvertUtils.BytesToHexString(ckaId);
            if (foundObjects.Count < 1)
                throw new InvalidOperationException(string.Format("Private key with label \"{0}\" and id \"{1}\" was not found.", ckaLabel, ckaIdHex));
            if (foundObjects.Count > 1)
                throw new InvalidOperationException(string.Format("More than one private key with label \"{0}\" and id \"{1}\" was found.", ckaLabel, ckaIdHex));
            return foundObjects[0];
        }

        // Reads the certificate matching the same filters as the private key.
        // Returns null when no certificate, or more than one, matches.
        private static X509Certificate2 FindCertificate(ISession session, string ckaLabel, byte[] ckaId)
        {
            List<IObjectHandle> foundObjects = session.FindAllObjects(CreateSearchTemplate(CKO.CKO_CERTIFICATE, ckaLabel, ckaId));
            if (foundObjects.Count != 1)
                return null;

            List<IObjectAttribute> certificateAttributes = session.GetAttributeValue(foundObjects[0], new List<CKA> { CKA.CKA_VALUE });
            byte[] certificateData = certificateAttributes[0].GetValueAsByteArray();
            return new X509Certificate2(certificateData);
        }

        /// <summary>
        /// Gets the <see cref="X509Certificate2"/> object found on the token
        /// with same <b>ckaLabel</b> and <b>ckaId</b> as a private key,
        /// or <see langword="null"/> if no single matching certificate was found.
        /// </summary>
        public X509Certificate2 Certificate
        {
            get { return _certificate; }
        }

        /// <summary>
        /// Gets the ID of the hash algorithm.
        /// </summary>
        public OID HashAlgorithm => _hashAlgorithm;

        /// <summary>
        /// Gets the ID of the encryption algorithm.
        /// </summary>
        public OID DigestEncryptionAlgorithm => OID.EncryptionAlgorithms.RSA;

        /// <summary>
        /// Signs data.
        /// </summary>
        /// <remarks>
        /// The input is hashed with <see cref="HashAlgorithm"/>, wrapped in a DER-encoded
        /// <c>DigestInfo</c>, and signed on the token with <c>CKM_RSA_PKCS</c>.
        /// </remarks>
        /// <param name="input">The input data to sign.</param>
        /// <returns>The signed data.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
        /// <exception cref="ObjectDisposedException">This object has been disposed.</exception>
        public byte[] SignData(byte[] input)
        {
            if (input == null)
                throw new ArgumentNullException(nameof(input));
            if (_pkcs11Library == null)
                throw new ObjectDisposedException(GetType().Name);

            byte[] hash = new byte[_hashDigest.GetDigestSize()];
            _hashDigest.Reset();
            _hashDigest.BlockUpdate(input, 0, input.Length);
            _hashDigest.DoFinal(hash, 0);

            // CKM_RSA_PKCS only applies PKCS#1 v1.5 padding, so the DigestInfo is built here.
            // RFC 8017 requires the hash AlgorithmIdentifier parameters inside this DigestInfo
            // to be an explicit ASN.1 NULL; some verifiers compare the encoding byte-for-byte.
            var algorithmIdentifier = new AlgorithmIdentifier(new DerObjectIdentifier(_hashAlgorithm.ID), DerNull.Instance);
            byte[] digestInfoBytes = new DigestInfo(algorithmIdentifier, hash).GetDerEncoded();

            using (ISession session = _slot.OpenSession(SessionType.ReadOnly))
            using (IMechanism mechanism = Factories.MechanismFactory.Create(CKM.CKM_RSA_PKCS))
            {
                return session.Sign(mechanism, _privateKeyHandle, digestInfoBytes);
            }
        }
    }
}