众所周知,async/await
在C#中是一个非常好用的特性,我们可以利用它来做到异步逻辑同步化,假设我们有这样一个耗时函数:
public Task<string> GetString(string url)
{
var httpClient = new HttpClient();
return httpClient.GetStringAsync(url);
}
现在我们需要对一个List<string>
中所有的url都执行此操作,并且将结果汇总为一个string[]
,该怎么做?
一般来说可能会有人提出这种方案:
var results = new List<string>();
foreach (var url in list)
{
results.Add(await GetString(url));
}
但是实际上你会发现这种做法是低效的,因为每个单独的GetString
操作都与其他的无关,也就是说这段代码应该是可以并行化的,但是在上面的例子中由于await
的存在,使得整个操作都串行化了,不信的话看如下代码的执行结果:
public static async Task<string> SomeLongRunningTask()
{
await Task.Delay(1000);
return null;
}
public static async Task Main(string[] args)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
foreach (var task in Enumerable.Range(1, 10).Select(_ => SomeLongRunningTask()))
{
await task;
}
Console.WriteLine(stopWatch.Elapsed.Seconds);
}
没有其他问题的话,程序输出的结果应该是10
,也就是这10个task
完全串行执行,我们该如何并行化他们呢?
C#的异步库为我们专门准备了一个方法,叫做Task.WhenAll
,这个方法可以并行的等待一系列Task
,并返回代表整个等待任务的Task
,函数签名是public static Task<TResult[]> WhenAll<TResult>(IEnumerable<Task<TResult>> tasks)
,msdn上面这么介绍这个函数:
Creates a task that will complete when all of the supplied tasks have completed.
用法也很简单,稍微改一下上面的例子:
var stopWatch = new Stopwatch();
stopWatch.Start();
await Task.WhenAll(Enumerable.Range(1, 10).Select(_ => SomeLongRunningTask()));
Console.WriteLine(stopWatch.Elapsed.Seconds);
不出意外的话,这次的控制台上应该会显示1
你可能会奇怪这是怎么做到的,C#的内部实现使用了InternalWait
这种比较暴力的手段,当然,我们今天使用另外一种方法来实现,那就是使用TaskCompletionSource<T>
。
什么是TaskCompletionSource<T>
呢?MSDN这样描述它:
Represents the producer side of a
Task<TResult>
unbound to a delegate, providing access to the consumer side through the Task property.
看上去是不是非常深奥?其实就是提供一个可以自由控制的Task
,你可以自己控制在什么时间设置这个Task
的结果(可以是正常结束也可以是异常,还可以是取消),知道这点以后,我们就可以通过"控制一个TaskCompletionSource<T>
使其在所有Task<T>
完成的时候完成"来实现我们自己的WhenAll
方法:
public static Task<T[]> WhenAll<T>(this Task<T>[] tasks, CancellationToken cancellationToken)
{
// 创建一个TaskCompletionSource对象,用于在所有的tasks都完成后完成
var taskCompletionSource = new TaskCompletionSource<T[]>();
// 存储结果的数组
var results = new T[tasks.Length];
// 剩余未完成的Task
var left = results.Length;
// 在CancellationToken被取消时将taskCompletionSource的内部Task的状态同样设置为Canceled以广播取消事件
cancellationToken.Register(() => taskCompletionSource.TrySetCanceled());
for (var i = 0; i < tasks.Length; i++)
{
var j = i;
// 给每个Task注册回调
tasks[i].ContinueWith(t =>
{
if (t.IsFaulted)
{
// 一旦任意一个Task失败,那么WhenAll将会立即处于失败状态(广播Faulted事件)
taskCompletionSource.TrySetException(t.Exception!);
}
else if (t.IsCanceled)
{
// 一旦任意一个Task被取消,那么WhenAll将会立即处于取消状态(广播Canceled事件)
taskCompletionSource.TrySetCanceled();
}
else // Task成功完成
{
// 否则给结果数组中对应的位置赋值
results[j] = t.Result;
// 使用CAS操作检测剩余的未完成Task是否已为0,如果是那么说明所有Task都已经执行完毕
if (0 == Interlocked.Decrement(ref left))
{
// 说明等待任务成功完成,这时设置taskCompletionSource的结果,结束状态流转
taskCompletionSource.SetResult(results);
}
}
}, cancellationToken, TaskContinuationOptions.ExecuteSynchronously /* 保证该Continuation在执行Task的相同线程上执行以避免奇怪的并发问题 */ , TaskScheduler.Default /* 参数列表要求,使用默认的任务调度器就好了 */);
}
return taskCompletionSource.Task;
}
现在再运行一次
var stopwatch = new Stopwatch();
stopwatch.Start();
await WhenAll(Enumerable.Range(1, 10).Select(async _ =>
{
await Task.Delay(1000);
return "dummy";
}).ToArray());
Console.WriteLine(stopwatch.Elapsed.Seconds);
没有问题的话,控制台上应该已经显示1
了
Comments NOTHING