所周知,async/await在C#中是一个非常好用的特性,我们可以利用它来做到异步逻辑同步化,假设我们有这样一个耗时函数:

public Task<string> GetString(string url)
{
    var httpClient = new HttpClient();
    return httpClient.GetStringAsync(url);
}

现在我们需要对一个List<string>中所有的url都执行此操作,并且将结果汇总为一个string[],该怎么做?
一般来说可能会有人提出这种方案:

var results = new List<string>();
foreach (var url in list)
{
    results.Add(await GetString(url));
}

但是实际上你会发现这种做法是低效的,因为每个单独的GetString操作都与其他的无关,也就是说这段代码应该是可以并行化的,但是在上面的例子中由于await的存在,使得整个操作都串行化了,不信的话看如下代码的执行结果:

public static async Task<string> SomeLongRunningTask()
{
    await Task.Delay(1000);
    return null;
}

public static async Task Main(string[] args)
{
    var stopWatch = new Stopwatch();
    stopWatch.Start();
    foreach (var task in Enumerable.Range(1, 10).Select(_ => SomeLongRunningTask()))
    {
        await task;
    }

    Console.WriteLine(stopWatch.Elapsed.Seconds);
}

没有其他问题的话,程序输出的结果应该是10,也就是这10个task完全串行执行,我们该如何并行化他们呢?
C#的异步库为我们专门准备了一个方法,叫做Task.WhenAll,这个方法可以并行的等待一系列Task,并返回代表整个等待任务的Task,函数签名是public static Task<TResult[]> WhenAll<TResult>(IEnumerable<Task<TResult>> tasks),msdn上面这么介绍这个函数:

Creates a task that will complete when all of the supplied tasks have completed.

用法也很简单,稍微改一下上面的例子:

var stopWatch = new Stopwatch();
stopWatch.Start();
await Task.WhenAll(Enumerable.Range(1, 10).Select(_ => SomeLongRunningTask()));
Console.WriteLine(stopWatch.Elapsed.Seconds);

不出意外的话,这次的控制台上应该会显示1
你可能会奇怪这是怎么做到的,C#的内部实现使用了InternalWait这种比较暴力的手段,当然,我们今天使用另外一种方法来实现,那就是使用TaskCompletionSource<T>
什么是TaskCompletionSource<T>呢?MSDN这样描述它:

Represents the producer side of a Task<TResult> unbound to a delegate, providing access to the consumer side through the Task property.

看上去是不是非常深奥?其实就是提供一个可以自由控制的Task,你可以自己控制在什么时间设置这个Task的结果(可以是正常结束也可以是异常,还可以是取消),知道这点以后,我们就可以通过"控制一个TaskCompletionSource<T>使其在所有Task<T>完成的时候完成"来实现我们自己的WhenAll方法:

public static Task<T[]> WhenAll<T>(this Task<T>[] tasks, CancellationToken cancellationToken)
{
    // 创建一个TaskCompletionSource对象,用于在所有的tasks都完成后完成
    var taskCompletionSource = new TaskCompletionSource<T[]>();
    // 存储结果的数组
    var results = new T[tasks.Length];
    // 剩余未完成的Task
    var left = results.Length;
    // 在CancellationToken被取消时将taskCompletionSource的内部Task的状态同样设置为Canceled以广播取消事件
    cancellationToken.Register(() => taskCompletionSource.TrySetCanceled());
    for (var i = 0; i < tasks.Length; i++)
    {
        var j = i;
        // 给每个Task注册回调
        tasks[i].ContinueWith(t =>
        {
            if (t.IsFaulted)
            {
                // 一旦任意一个Task失败,那么WhenAll将会立即处于失败状态(广播Faulted事件)
                taskCompletionSource.TrySetException(t.Exception!);
            }
            else if (t.IsCanceled)
            {
                // 一旦任意一个Task被取消,那么WhenAll将会立即处于取消状态(广播Canceled事件)
                taskCompletionSource.TrySetCanceled();
            }
            else // Task成功完成
            {
                // 否则给结果数组中对应的位置赋值
                results[j] = t.Result;
                // 使用CAS操作检测剩余的未完成Task是否已为0,如果是那么说明所有Task都已经执行完毕
                if (0 == Interlocked.Decrement(ref left))
                {
                    // 说明等待任务成功完成,这时设置taskCompletionSource的结果,结束状态流转
                    taskCompletionSource.SetResult(results);
                }
            }
        }, cancellationToken, TaskContinuationOptions.ExecuteSynchronously /* 保证该Continuation在执行Task的相同线程上执行以避免奇怪的并发问题 */ , TaskScheduler.Default /* 参数列表要求,使用默认的任务调度器就好了 */);
    }
    return taskCompletionSource.Task;
}

现在再运行一次

var stopwatch = new Stopwatch();
stopwatch.Start();
await WhenAll(Enumerable.Range(1, 10).Select(async _ =>
{
    await Task.Delay(1000);
    return "dummy";
}).ToArray());
Console.WriteLine(stopwatch.Elapsed.Seconds);

没有问题的话,控制台上应该已经显示1

Jusqu'à ce que le mort nous sépare.
最后更新于 2020-10-17