CSharp/AzureStorageCmdletBase.cs

namespace AzureStorageCmdlets
{
    using System;
    using System.Linq;
    using System.Collections.Generic;
    using System.Text;
    using System.Management.Automation;
    using System.Collections.Specialized;
    using System.Net;
    using System.Collections;
    using System.Web;
    using System.Threading;
 
    public class AzureStorageAccountData
    {
        public string AccountName
        {
            get;
            set;
        }
 
        public string Key
        {
            get;
            set;
        }
    }
    public abstract class AzureStorageCmdletBase: PSCmdlet
    {
        public abstract string Endpoint
        {
            get;
        }
 
        [Parameter()]
        public string StorageAccount
        {
            get; set;
        }
 
        [Parameter()]
        public string StorageKey
        {
            get; set;
        }
 
        protected virtual bool IsTableStorage
        {
            get
            {
                return false;
            }
        }
 
        #region REST HTTP Request Helper Methods
 
        // Construct and issue a REST request and return the response.
 
        public virtual HttpWebRequest CreateRESTRequest(string method, string resource, string requestBody, SortedList<string, string> headers, string ifMatch, string md5)
        {
            byte[] byteArray = null;
            DateTime now = DateTime.UtcNow;
            string uri = Endpoint + resource;
 
            HttpWebRequest request = HttpWebRequest.Create(uri) as HttpWebRequest;
            request.Method = method;
            request.ContentLength = 0;
            request.Headers.Add("x-ms-date", now.ToString("R", System.Globalization.CultureInfo.InvariantCulture));
            request.Headers.Add("x-ms-version", "2011-08-18");
 
            if (IsTableStorage)
            {
                request.ContentType = "application/atom+xml";
 
                request.Headers.Add("DataServiceVersion", "2.0;NetFx");
                request.Headers.Add("MaxDataServiceVersion", "2.0;NetFx");
            }
 
            if (headers != null)
            {
                foreach (KeyValuePair<string, string> header in headers)
                {
                    request.Headers.Add(header.Key, header.Value);
                }
            }
 
            if (!String.IsNullOrEmpty(requestBody))
            {
                request.Headers.Add("Accept-Charset", "UTF-8");
 
                byteArray = Encoding.UTF8.GetBytes(requestBody);
                request.ContentLength = byteArray.Length;
            }
 
            request.Headers.Add("Authorization", AuthorizationHeader(method, now, request, ifMatch, md5));
 
            if (!String.IsNullOrEmpty(requestBody))
            {
                request.GetRequestStream().Write(byteArray, 0, byteArray.Length);
            }
 
            return request;
        }
 
 
        // Generate an authorization header.
 
        public virtual string AuthorizationHeader(string method, DateTime now, HttpWebRequest request, string ifMatch, string md5)
        {
            string MessageSignature;
 
            if (IsTableStorage)
            {
                MessageSignature = String.Format("{0}\n\n{1}\n{2}\n{3}",
                    method,
                    "application/atom+xml",
                    now.ToString("R", System.Globalization.CultureInfo.InvariantCulture),
                    GetCanonicalizedResource(request.RequestUri, StorageAccount)
                    );
            }
            else
            {
                MessageSignature = String.Format("{0}\n\n\n{1}\n{5}\n\n\n\n{2}\n\n\n\n{3}{4}",
                    method,
                    (method == "GET" || method == "HEAD") ? String.Empty : request.ContentLength.ToString(),
                    ifMatch,
                    GetCanonicalizedHeaders(request),
                    GetCanonicalizedResource(request.RequestUri, StorageAccount),
                    md5
                    );
            }
            byte[] SignatureBytes = System.Text.Encoding.UTF8.GetBytes(MessageSignature);
            System.Security.Cryptography.HMACSHA256 SHA256 = new System.Security.Cryptography.HMACSHA256(Convert.FromBase64String(StorageKey));
            String AuthorizationHeader = "SharedKey " + StorageAccount + ":" + Convert.ToBase64String(SHA256.ComputeHash(SignatureBytes));
            return AuthorizationHeader;
        }
 
        // Get canonicalized headers.
 
        public static string GetCanonicalizedHeaders(HttpWebRequest request)
        {
            ArrayList headerNameList = new ArrayList();
            StringBuilder sb = new StringBuilder();
            foreach (string headerName in request.Headers.Keys)
            {
                if (headerName.ToLowerInvariant().StartsWith("x-ms-", StringComparison.Ordinal))
                {
                    headerNameList.Add(headerName.ToLowerInvariant());
                }
            }
            headerNameList.Sort();
            foreach (string headerName in headerNameList)
            {
                StringBuilder builder = new StringBuilder(headerName);
                string separator = ":";
                foreach (string headerValue in GetHeaderValues(request.Headers, headerName))
                {
                    string trimmedValue = headerValue.Replace("\r\n", String.Empty);
                    builder.Append(separator);
                    builder.Append(trimmedValue);
                    separator = ",";
                }
                sb.Append(builder.ToString());
                sb.Append("\n");
            }
            return sb.ToString();
        }
 
        // Get header values.
 
        public static ArrayList GetHeaderValues(NameValueCollection headers, string headerName)
        {
            ArrayList list = new ArrayList();
            string[] values = headers.GetValues(headerName);
            if (values != null)
            {
                foreach (string str in values)
                {
                    list.Add(str.TrimStart(null));
                }
            }
            return list;
        }
 
        // Get canonicalized resource.
 
        public virtual string GetCanonicalizedResource(Uri address, string accountName)
        {
            StringBuilder str = new StringBuilder();
            StringBuilder builder = new StringBuilder("/");
            builder.Append(accountName);
            builder.Append(address.AbsolutePath);
            str.Append(builder.ToString());
            NameValueCollection values2 = new NameValueCollection();
            if (!IsTableStorage)
            {
                NameValueCollection values = HttpUtility.ParseQueryString(address.Query);
                foreach (string str2 in values.Keys)
                {
                    ArrayList list = new ArrayList(values.GetValues(str2));
                    list.Sort();
                    StringBuilder builder2 = new StringBuilder();
                    foreach (object obj2 in list)
                    {
                        if (builder2.Length > 0)
                        {
                            builder2.Append(",");
                        }
                        builder2.Append(obj2.ToString());
                    }
                    values2.Add((str2 == null) ? str2 : str2.ToLowerInvariant(), builder2.ToString());
                }
            }
            ArrayList list2 = new ArrayList(values2.AllKeys);
            list2.Sort();
            foreach (string str3 in list2)
            {
                StringBuilder builder3 = new StringBuilder(string.Empty);
                builder3.Append(str3);
                builder3.Append(":");
                builder3.Append(values2[str3]);
                str.Append("\n");
                str.Append(builder3.ToString());
            }
            return str.ToString();
        }
 
        #endregion
 
        #region Retry Delegate
 
        public delegate T RetryDelegate<T>();
        public delegate void RetryDelegate();
 
        const int retryCount = 3;
        const int retryIntervalMS = 200;
 
        // Retry delegate with default retry settings.
 
        public static T Retry<T>(RetryDelegate<T> del)
        {
            return Retry<T>(del, retryCount, retryIntervalMS);
        }
 
        // Retry delegate.
 
        public static T Retry<T>(RetryDelegate<T> del, int numberOfRetries, int msPause)
        {
            int counter = 0;
        RetryLabel:
 
            try
            {
                counter++;
                return del.Invoke();
            }
            catch (Exception ex)
            {
                if (counter > numberOfRetries)
                {
                    throw ex;
                }
                else
                {
                    if (msPause > 0)
                    {
                        Thread.Sleep(msPause);
                    }
                    goto RetryLabel;
                }
            }
        }
 
 
        // Retry delegate with default retry settings.
 
        public static bool Retry(RetryDelegate del)
        {
            return Retry(del, retryCount, retryIntervalMS);
        }
 
 
        public static bool Retry(RetryDelegate del, int numberOfRetries, int msPause)
        {
            int counter = 0;
 
        RetryLabel:
            try
            {
                counter++;
                del.Invoke();
                return true;
            }
            catch (Exception ex)
            {
                if (counter > numberOfRetries)
                {
                    throw ex;
                }
                else
                {
                    if (msPause > 0)
                    {
                        Thread.Sleep(msPause);
                    }
                    goto RetryLabel;
                }
            }
        }
 
        protected void WriteWebError(WebException ex, string extraString)
        {
            WriteError(
            new ErrorRecord(
                new InvalidOperationException(
                    ((ex.Response as HttpWebResponse).StatusCode.ToString()) + ' ' + extraString),
                "SetAzureTableCommand.WebError." + ((int)(ex.Response as HttpWebResponse).StatusCode).ToString(),
                ErrorCategory.InvalidOperation,
                this)
                );
        }
 
        protected override void ProcessRecord()
        {
            if (! (this.MyInvocation.MyCommand.Module.PrivateData is AzureStorageAccountData)) {
                this.MyInvocation.MyCommand.Module.PrivateData = new AzureStorageAccountData();
                AzureStorageAccountData accountData = new AzureStorageAccountData();
                accountData.AccountName = StorageAccount;
                accountData.Key = StorageKey;
            }
 
            if (!this.MyInvocation.BoundParameters.ContainsKey("StorageKey"))
            {
                StorageKey = (this.MyInvocation.MyCommand.Module.PrivateData as AzureStorageAccountData).Key;
            }
            else
            {
                (this.MyInvocation.MyCommand.Module.PrivateData as AzureStorageAccountData).Key = StorageKey;
            }
 
            if (!this.MyInvocation.BoundParameters.ContainsKey("StorageAccount"))
            {
                StorageAccount = (this.MyInvocation.MyCommand.Module.PrivateData as AzureStorageAccountData).AccountName;
            }
            else
            {
                (this.MyInvocation.MyCommand.Module.PrivateData as AzureStorageAccountData).AccountName = StorageAccount;
            }
 
 
            if (String.IsNullOrEmpty(StorageAccount) || String.IsNullOrEmpty(StorageKey))
            {
                WriteError(
                    new ErrorRecord(
                        new Exception("Must provide a StorageAccount and StorageKey"),
                            "MissingAzureStorageAccountOrKey",
                            ErrorCategory.InvalidOperation,
                            null)
                        );
            }
        }
         
 
        #endregion
    }
}