[Source Code Analysis] TensorFlow Distributed Environment (3): Worker Static Logic

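Before covering TensorFlow's various distributed Strategies in detail, we first need to look at the foundation they build on: the distributed environment. Only with a solid grasp of this foundation can later analysis proceed smoothly and efficiently. This article describes the static architecture of the Worker and a series of related concepts.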

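Other articles in this series:

[Translation] TensorFlow Distributed, Papers: "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Syst

[Translation] TensorFlow Distributed, Papers: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow Distributed Environment (1): Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2): Master Static Logic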

1. Inheritance

1.1 Role

TensorFlow's Worker class is the entity that performs computation. Its main responsibilities are:

  • Receiving requests from the Master.
  • Managing WorkerSessions.
  • Processing registered subgraphs, for example partitioning a subgraph a second time according to the devices on its own node.
  • Running the registered subgraphs on each device.
  • Supporting worker-to-worker tensor transfer, among other things. How a transfer is done depends on where the two endpoints are relative to each other: cudaMemcpyAsync between CPU and GPU, DMA between local GPUs, and gRPC or RDMA for remote workers.
  • Fetching the results from the graph's terminal (sink) node once execution finishes.
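The transfer-path rule in the list above can be pictured as a small dispatch on where the two endpoints live. This is only a sketch under stated assumptions: `Location`, `Transport` and `ChooseTransport` are hypothetical names, not TensorFlow APIs; the CPU-to-CPU case is my own addition; and in TensorFlow the real decision is spread across the rendezvous and device-context code.

```cpp
#include <cassert>
#include <string>

// Hypothetical description of where a tensor lives (not a TensorFlow type).
struct Location {
  std::string task;  // e.g. "/job:worker/task:0"
  bool is_gpu;       // true for a GPU device, false for a CPU device
};

// Hypothetical transfer mechanisms, mirroring the bullet list above.
enum class Transport { kLocalCopy, kCudaMemcpyAsync, kGpuDma, kRemoteRpc };

Transport ChooseTransport(const Location& src, const Location& dst) {
  if (src.task != dst.task) {
    // Different worker processes: go over the network (gRPC or RDMA).
    return Transport::kRemoteRpc;
  }
  if (src.is_gpu && dst.is_gpu) {
    // Two GPUs on the same worker: device-to-device DMA.
    return Transport::kGpuDma;
  }
  if (src.is_gpu != dst.is_gpu) {
    // CPU <-> GPU on the same worker: asynchronous CUDA memcpy.
    return Transport::kCudaMemcpyAsync;
  }
  // Assumption: CPU to CPU in one process is a plain in-memory handoff.
  return Transport::kLocalCopy;
}
```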

See protobuf/worker_service.proto for more details on each method.

1.2 Interface

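Access to the WorkerService goes through WorkerInterface. WorkerInterface is the worker's interface class, used to interact with the TensorFlow Worker service. It mainly:

  • Defines a set of asynchronous virtual functions, such as CreateWorkerSessionAsync, which derived classes implement. These functions correspond one-to-one with the GrpcWorkerMethod values supported by GrpcWorkerService, and with the RPCs declared in the protobuf service.
  • Defines a set of synchronous functions, such as CreateWorkerSession, which reach the corresponding asynchronous virtual function through a call like CallAndWait(&ME::CreateWorkerSessionAsync, request, response).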
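The sync-over-async pattern described above is easy to reproduce in standalone C++. The following is a minimal sketch, assuming a `std::string` stand-in for `Status` (empty means OK) and using `std::promise`/`std::future` in place of TensorFlow's `Notification`; `MiniWorkerInterface`, `EchoAsync` and `ThreadedEchoWorker` are hypothetical names.

```cpp
#include <cassert>
#include <functional>
#include <future>
#include <string>
#include <thread>
#include <vector>

using Status = std::string;  // hypothetical: empty string means OK
using StatusCallback = std::function<void(const Status&)>;

struct EchoRequest { int value = 0; };
struct EchoResponse { int value = 0; };

class MiniWorkerInterface {
 public:
  virtual ~MiniWorkerInterface() = default;

  // Asynchronous primitive that derived classes implement.
  virtual void EchoAsync(const EchoRequest* req, EchoResponse* resp,
                         StatusCallback done) = 0;

  // Synchronous wrapper built on top of the async call.
  Status Echo(const EchoRequest* req, EchoResponse* resp) {
    return CallAndWait(&MiniWorkerInterface::EchoAsync, req, resp);
  }

 private:
  // Same idea as WorkerInterface::CallAndWait, with std::promise standing
  // in for tensorflow::Notification.
  template <typename Method, typename Req, typename Resp>
  Status CallAndWait(Method func, const Req* req, Resp* resp) {
    std::promise<Status> result;
    std::future<Status> future = result.get_future();
    (this->*func)(req, resp,
                  [&result](const Status& s) { result.set_value(s); });
    return future.get();  // block until the callback has fired
  }
};

// Completes each call on another thread, as a real RPC would.
class ThreadedEchoWorker : public MiniWorkerInterface {
 public:
  ~ThreadedEchoWorker() override {
    for (auto& t : threads_) t.join();
  }
  void EchoAsync(const EchoRequest* req, EchoResponse* resp,
                 StatusCallback done) override {
    threads_.emplace_back([req, resp, done]() {
      resp->value = req->value;
      done(Status());
    });
  }

 private:
  std::vector<std::thread> threads_;
};
```

Because the callback may fire on another thread, the synchronous wrapper has to block until it does, which is what `CallAndWait` achieves with `Notification::WaitForNotification()`.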

1.3 WorkerInterface Derived Classes

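As shown in the figure below, WorkerInterface has three implementations.

  • Worker: this class can be subclassed to provide specialized implementations of particular methods for different transport mechanisms. For example, GrpcWorker specializes the RecvTensorAsync() method to use more efficient gRPC data structures for handling large binary data.
  • GrpcWorker: derived from Worker, it plays the Worker role in local mode. If the Master and Worker are both local, it can be called directly without any RPC over the network.
  • GrpcRemoteWorker: in distributed mode the Worker is remote, and the local side uses GrpcRemoteWorker to access it.
  • GrpcRemoteWorker is a gRPC client; it accesses the GrpcWorkerService running on the remote Worker through a stub.
  • GrpcWorkerService implements every interface defined by WorkerService, but forwards the actual work to the local GrpcWorker.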

The relationships are shown below:

Figure 1: Worker logical relationships

2. GrpcRemoteWorker

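GrpcRemoteWorker acts as a local proxy for a remote Worker.

  • The local Master partitions the graph and then, depending on whether each partition is local or remote, calls either the local Worker or a GrpcRemoteWorker to run that partition's subgraph.
  • The local GrpcRemoteWorker is created in GetOrCreateWorker in tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc.
  • GrpcRemoteWorker sends gRPC requests to the remote side through IssueRequest.
  • After the remote GrpcWorkerService daemon receives a request, it calls its local Worker to handle it and returns the result when done.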
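The proxy relationship in the list above (GrpcRemoteWorker on the caller's side, GrpcWorkerService dispatching to a local worker on the remote side) can be sketched without any networking. In this hypothetical sketch the "channel" is just a function call, and `LocalWorker`, `WorkerServiceSide` and `RemoteWorkerProxy` are illustrative names; the point is that the proxy knows only method names, while the service maps names back to real work.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>

// Does the real work (plays the role of GrpcWorker).
class LocalWorker {
 public:
  std::string CreateWorkerSession(const std::string& session_handle) {
    ++sessions_created_;
    return "created:" + session_handle;
  }
  int sessions_created() const { return sessions_created_; }

 private:
  int sessions_created_ = 0;
};

// Plays the role of GrpcWorkerService: routes a method name to the worker.
class WorkerServiceSide {
 public:
  explicit WorkerServiceSide(LocalWorker* worker) : worker_(worker) {}
  std::string Dispatch(const std::string& method, const std::string& payload) {
    if (method == "/tensorflow.WorkerService/CreateWorkerSession") {
      return worker_->CreateWorkerSession(payload);
    }
    return "error:unimplemented";
  }

 private:
  LocalWorker* worker_;
};

// Plays the role of GrpcRemoteWorker: no business logic, only method names
// handed to a "channel" (here a direct function call, not a network).
class RemoteWorkerProxy {
 public:
  using Channel =
      std::function<std::string(const std::string&, const std::string&)>;
  explicit RemoteWorkerProxy(Channel channel) : channel_(std::move(channel)) {}
  std::string CreateWorkerSession(const std::string& session_handle) {
    return channel_("/tensorflow.WorkerService/CreateWorkerSession",
                    session_handle);
  }

 private:
  Channel channel_;
};
```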

2.1 Definition

The GrpcRemoteWorker code is shown below; some parts are omitted, such as the implementation of DeleteWorkerSessionAsync.

class GrpcRemoteWorker : public WorkerInterface {
 public:
  explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
                            ::grpc::CompletionQueue* completion_queue,
                            thread::ThreadPool* callback_threadpool,
                            WorkerCacheLogger* logger, const string& target)
      : channel_(std::move(channel)),
        stub_(channel_),
        cq_(completion_queue),
        callback_threadpool_(callback_threadpool),
        getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
        createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
        deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
        registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
        deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
        rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
        cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
        cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
        recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
        recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
        logging_(Method(GrpcWorkerMethod::kLogging)),
        tracing_(Method(GrpcWorkerMethod::kTracing)),
        completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
        instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
        getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
        markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
        logger_(logger),
        target_(target) {}

  ~GrpcRemoteWorker() override {}

  void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                CreateWorkerSessionResponse* response,
                                StatusCallback done) override {
    IssueRequest(request, response, createworkersession_, std::move(done));
  }

  void RegisterGraphAsync(const RegisterGraphRequest* request,
                          RegisterGraphResponse* response,
                          StatusCallback done) override {
    IssueRequest(request, response, registergraph_, std::move(done));
  }

  void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
                     RunGraphResponse* response, StatusCallback done) override {
    IssueRequest(request, response, rungraph_, std::move(done), call_opts);
  }
  void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
                     MutableRunGraphResponseWrapper* response,
                     StatusCallback done) override {
    IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
                 rungraph_, std::move(done), call_opts);
  }

 private:

  void IssueRequest(const protobuf::Message* request,
                    protobuf::Message* response, const ::grpc::string& method,
                    StatusCallback done, CallOptions* call_opts = nullptr,
                    bool fail_fast = true) {
    new RPCState<protobuf::Message>(
        &stub_, cq_, method, *request, response, std::move(done), call_opts,
        callback_threadpool_, MaxRetries(), fail_fast, &target_);
  }

  void IssueRequest(const protobuf::Message* request, TensorResponse* response,
                    const ::grpc::string& method, StatusCallback done,
                    CallOptions* call_opts = nullptr) {
    new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
                                 std::move(done), call_opts,
                                 callback_threadpool_, MaxRetries(),
                                 true, &target_);
  }

  const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }

  const int64_t MaxRetries() {
    int64_t max_retries = -1;
    TF_CHECK_OK(ReadInt64FromEnvVar("GRPC_MAX_RETRIES", 0, &max_retries));
    return max_retries;
  }

  SharedGrpcChannelPtr channel_;
  ::grpc::GenericStub stub_;
  ::grpc::CompletionQueue* cq_;
  thread::ThreadPool* callback_threadpool_;

  const ::grpc::string getstatus_;
  const ::grpc::string createworkersession_;
  const ::grpc::string deleteworkersession_;
  const ::grpc::string registergraph_;
  const ::grpc::string deregistergraph_;
  const ::grpc::string rungraph_;
  const ::grpc::string cleanupgraph_;
  const ::grpc::string cleanupall_;
  const ::grpc::string recvtensor_;
  const ::grpc::string recvbuf_;
  const ::grpc::string logging_;
  const ::grpc::string tracing_;
  const ::grpc::string completegroup_;
  const ::grpc::string instancesource_;
  const ::grpc::string getstepsequence_;
  const ::grpc::string markrecvfinished_;

  WorkerCacheLogger* logger_;
  const string target_;

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
};
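One detail worth noting in `IssueRequest` is that it writes `new RPCState<...>(...)` and discards the pointer. As far as I can tell, the RPC state object registers itself with the completion queue and deletes itself once the reply has been handled. The sketch below reproduces that self-owning pattern under assumptions: `PendingCall`, `g_fake_cq` and `DrainFakeQueue` are hypothetical, and the "completion queue" is a plain vector drained by hand.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

using StatusCallback = std::function<void(const std::string&)>;

class PendingCall;

// Hypothetical stand-in for the gRPC completion queue: a list of tags.
static std::vector<PendingCall*> g_fake_cq;
// Number of PendingCall objects currently alive.
static int g_live_calls = 0;

class PendingCall {
 public:
  PendingCall(std::string method, StatusCallback done)
      : method_(std::move(method)), done_(std::move(done)) {
    ++g_live_calls;
    g_fake_cq.push_back(this);  // register itself as the completion tag
  }

  // Invoked when the reply (or a failure) comes back.
  void OnCompleted(bool ok) {
    done_(ok ? std::string() : "UNAVAILABLE: " + method_);
    delete this;  // the object owns itself; nobody else holds the pointer
  }

 private:
  ~PendingCall() { --g_live_calls; }  // private: only OnCompleted deletes

  std::string method_;
  StatusCallback done_;
};

// Same shape as GrpcRemoteWorker::IssueRequest: allocate and forget.
void IssueRequest(const std::string& method, StatusCallback done) {
  new PendingCall(method, std::move(done));
}

// Simulates the completion-queue thread delivering every pending reply.
void DrainFakeQueue(bool ok) {
  std::vector<PendingCall*> tags;
  tags.swap(g_fake_cq);
  for (PendingCall* tag : tags) tag->OnCompleted(ok);
}
```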

2.2 Creation

The creation code is as follows:

WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
                                     ::grpc::CompletionQueue* completion_queue,
                                     thread::ThreadPool* callback_threadpool,
                                     WorkerCacheLogger* logger,
                                     const string& target) {
  return new GrpcRemoteWorker(std::move(channel), completion_queue,
                              callback_threadpool, logger, target);
}

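The actual call happens in the worker cache, in tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc, which decides from the target which kind of Worker to return: the local worker for the local target, otherwise a newly created GrpcRemoteWorker.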

WorkerInterface* GetOrCreateWorker(const string& target) override {
  if (target == local_target_) {
    return local_worker_;
  } else {
    SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
    if (!channel) {
      return nullptr;
    }
    size_t index = AssignWorkerToThread(target);
    return NewGrpcRemoteWorker(
        channel, worker_env_->GetCompletionQueue(index),
        worker_env_->GetThreadPool(), &logger_, target);
  }
}
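`AssignWorkerToThread` is not shown above. My understanding from the TensorFlow source is that it hands targets to completion queues round-robin while always keeping a given target on the same queue; the following sketch assumes that behavior and uses the hypothetical class name `QueueAssigner`.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <mutex>
#include <string>

class QueueAssigner {
 public:
  explicit QueueAssigner(size_t num_queues) : num_queues_(num_queues) {}

  // Returns the completion-queue index for `target`. A new target gets the
  // next index round-robin; a known target always gets its previous index.
  size_t Assign(const std::string& target) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = assignments_.find(target);
    if (it == assignments_.end()) {
      it = assignments_.emplace(target, next_++ % num_queues_).first;
    }
    return it->second;
  }

 private:
  std::mutex mu_;
  const size_t num_queues_;
  size_t next_ = 0;
  std::map<std::string, size_t> assignments_;
};
```

Keeping a target pinned to one queue means all traffic to a given remote worker is polled by the same thread.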

2.3 Sending Requests

Next, let's look at how requests are sent. CreateWorkerSessionAsync actually sends the request identified by the string createworkersession_.

  void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                CreateWorkerSessionResponse* response,
                                StatusCallback done) override {
    IssueRequest(request, response, createworkersession_, std::move(done));
  }

IssueRequest appeared in the definition above and is listed again below. It invokes the remote method named by method, which in our case is createworkersession_.

void IssueRequest(const protobuf::Message* request,
                  protobuf::Message* response, const ::grpc::string& method,
                  StatusCallback done, CallOptions* call_opts = nullptr,
                  bool fail_fast = true) {
  new RPCState<protobuf::Message>(
      &stub_, cq_, method, *request, response, std::move(done), call_opts,
      callback_threadpool_, MaxRetries(), fail_fast, &target_);
}

createworkersession_ is initialized in the constructor.

explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
                          ::grpc::CompletionQueue* completion_queue,
                          thread::ThreadPool* callback_threadpool,
                          WorkerCacheLogger* logger, const string& target)
    : channel_(std::move(channel)),
      createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),

GrpcWorkerMethodName is defined in tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc. It maps each method to a concrete string, the name of the method on the remote GrpcWorker service. As the code shows, CreateWorkerSessionAsync actually calls "/tensorflow.WorkerService/CreateWorkerSession".


enum class GrpcWorkerMethod {
  kGetStatus,
  kCreateWorkerSession,
  kDeleteWorkerSession,
  kRegisterGraph,
  kDeregisterGraph,
  kRunGraph,
  kCleanupGraph,
  kCleanupAll,
  kRecvTensor,
  kRecvBuf,
  kLogging,
  kTracing,
  kCompleteGroup,
  kCompleteInstance,
  kGetStepSequence,
  kMarkRecvFinished,
};

const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
  switch (id) {
    case GrpcWorkerMethod::kGetStatus:
      return "/tensorflow.WorkerService/GetStatus";
    case GrpcWorkerMethod::kCreateWorkerSession:
      return "/tensorflow.WorkerService/CreateWorkerSession";
    case GrpcWorkerMethod::kDeleteWorkerSession:
      return "/tensorflow.WorkerService/DeleteWorkerSession";
    case GrpcWorkerMethod::kRegisterGraph:
      return "/tensorflow.WorkerService/RegisterGraph";
    case GrpcWorkerMethod::kDeregisterGraph:
      return "/tensorflow.WorkerService/DeregisterGraph";
    case GrpcWorkerMethod::kRunGraph:
      return "/tensorflow.WorkerService/RunGraph";
    case GrpcWorkerMethod::kCleanupGraph:
      return "/tensorflow.WorkerService/CleanupGraph";
    case GrpcWorkerMethod::kCleanupAll:
      return "/tensorflow.WorkerService/CleanupAll";
    case GrpcWorkerMethod::kRecvTensor:
      return "/tensorflow.WorkerService/RecvTensor";
    case GrpcWorkerMethod::kRecvBuf:
      return "/tensorflow.WorkerService/RecvBuf";
    case GrpcWorkerMethod::kLogging:
      return "/tensorflow.WorkerService/Logging";
    case GrpcWorkerMethod::kTracing:
      return "/tensorflow.WorkerService/Tracing";
    case GrpcWorkerMethod::kCompleteGroup:
      return "/tensorflow.WorkerService/CompleteGroup";
    case GrpcWorkerMethod::kCompleteInstance:
      return "/tensorflow.WorkerService/CompleteInstance";
    case GrpcWorkerMethod::kGetStepSequence:
      return "/tensorflow.WorkerService/GetStepSequence";
    case GrpcWorkerMethod::kMarkRecvFinished:
      return "/tensorflow.WorkerService/MarkRecvFinished";
  }

  LOG(FATAL) << "Invalid id: this line shouldn't be reached.";
  return "invalid id";
}
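The strings returned above follow gRPC's method-path convention, `/<package>.<Service>/<Method>`, where the package is `tensorflow` and the service is `WorkerService`. A trimmed-down sketch with a hypothetical three-entry enum (`MiniWorkerMethod`) and a hypothetical helper `MethodPath` makes the rule explicit:

```cpp
#include <cassert>
#include <cstring>
#include <string>

// Hypothetical, trimmed-down version of GrpcWorkerMethod (3 entries only).
enum class MiniWorkerMethod { kGetStatus, kCreateWorkerSession, kRunGraph };

// gRPC identifies a method by the path "/<package>.<Service>/<Method>".
const char* MiniWorkerMethodName(MiniWorkerMethod id) {
  switch (id) {
    case MiniWorkerMethod::kGetStatus:
      return "/tensorflow.WorkerService/GetStatus";
    case MiniWorkerMethod::kCreateWorkerSession:
      return "/tensorflow.WorkerService/CreateWorkerSession";
    case MiniWorkerMethod::kRunGraph:
      return "/tensorflow.WorkerService/RunGraph";
  }
  return "invalid id";
}

// Builds the same path from its parts, to make the naming rule explicit.
std::string MethodPath(const std::string& package, const std::string& service,
                       const std::string& method) {
  return "/" + package + "." + service + "/" + method;
}
```

A `switch` over an `enum class` lets compilers warn (for example with `-Wswitch`) when a new enumerator lacks a case. The enum order also matters, because `AsyncService` registers methods by their integer index.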

3. Worker Service

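WorkerService is a gRPC service that defines a TensorFlow service. A WorkerService executes dataflow graphs on a set of local devices on behalf of a MasterService. A WorkerService keeps track of multiple "registered graphs". Each registered graph is a subgraph of the client's graph, containing only the nodes that should run on this worker (plus any extra nodes needed for inter-process communication through the RecvTensor method).

The Master uses the contents of the ClusterSpec to find the other Server instances in the cluster and treats them as Workers. It then distributes subgraphs to these Worker nodes and has them carry out the computation of each subgraph. If Workers have data dependencies on one another, they interact through inter-process communication. Whether the Master is calling a Worker or Workers are calling each other, every call follows the interface defined by WorkerService. All WorkerService interfaces are defined in worker_service.proto.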

service WorkerService {

  rpc GetStatus(GetStatusRequest) returns (GetStatusResponse);

  rpc CreateWorkerSession(CreateWorkerSessionRequest)
      returns (CreateWorkerSessionResponse);

  rpc DeleteWorkerSession(DeleteWorkerSessionRequest)
      returns (DeleteWorkerSessionResponse);

  rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse);

  rpc DeregisterGraph(DeregisterGraphRequest) returns (DeregisterGraphResponse);

  rpc RunGraph(RunGraphRequest) returns (RunGraphResponse);

  rpc CleanupGraph(CleanupGraphRequest) returns (CleanupGraphResponse);

  rpc CleanupAll(CleanupAllRequest) returns (CleanupAllResponse);

  rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) {}

  rpc Logging(LoggingRequest) returns (LoggingResponse);

  rpc Tracing(TracingRequest) returns (TracingResponse);

  rpc RecvBuf(RecvBufRequest) returns (RecvBufResponse) {}

  rpc GetStepSequence(GetStepSequenceRequest) returns (GetStepSequenceResponse);

  rpc CompleteGroup(CompleteGroupRequest) returns (CompleteGroupResponse);

  rpc CompleteInstance(CompleteInstanceRequest)
      returns (CompleteInstanceResponse);
}

3.3.1 WorkerInterface

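As with MasterService, access to WorkerService goes through WorkerInterface, the worker's interface class for interacting with the TensorFlow Worker service. It mainly:

  • Defines a set of asynchronous virtual functions, such as CreateWorkerSessionAsync, which derived classes implement. These functions correspond one-to-one with the GrpcWorkerMethod values supported by GrpcWorkerService, and with the RPCs declared in the protobuf service.
  • Defines a set of synchronous functions, such as CreateWorkerSession, which reach the corresponding asynchronous virtual function through a call like CallAndWait(&ME::CreateWorkerSessionAsync, request, response).

We first list its asynchronous interface: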


class WorkerInterface {
 public:
  virtual void GetStatusAsync(CallOptions* opts,
                              const GetStatusRequest* request,
                              GetStatusResponse* response, bool fail_fast,
                              StatusCallback done) = 0;

  virtual void CreateWorkerSessionAsync(
      const CreateWorkerSessionRequest* request,
      CreateWorkerSessionResponse* response, StatusCallback done) = 0;

  virtual void DeleteWorkerSessionAsync(
      CallOptions* opts, const DeleteWorkerSessionRequest* request,
      DeleteWorkerSessionResponse* response, StatusCallback done) = 0;

  virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
                                  RegisterGraphResponse* response,
                                  StatusCallback done) = 0;

  virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                    DeregisterGraphResponse* response,
                                    StatusCallback done) = 0;

  virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                             MutableRunGraphResponseWrapper* response,
                             StatusCallback done) = 0;

  virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
                             RunGraphResponse* response, StatusCallback done) {
    RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request);
    MutableRunGraphResponseWrapper* wrapped_response =
        new NonOwnedProtoRunGraphResponse(response);
    RunGraphAsync(opts, wrapped_request, wrapped_response,
                  [wrapped_request, wrapped_response,
                   done = std::move(done)](const Status& s) {
                    done(s);
                    delete wrapped_request;
                    delete wrapped_response;
                  });
  }

  virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
                                 CleanupGraphResponse* response,
                                 StatusCallback done) = 0;

  virtual void CleanupAllAsync(const CleanupAllRequest* request,
                               CleanupAllResponse* response,
                               StatusCallback done) = 0;

  virtual void RecvTensorAsync(CallOptions* opts,
                               const RecvTensorRequest* request,
                               TensorResponse* response,
                               StatusCallback done) = 0;

  virtual void LoggingAsync(const LoggingRequest* request,
                            LoggingResponse* response, StatusCallback done) = 0;

  virtual void TracingAsync(const TracingRequest* request,
                            TracingResponse* response, StatusCallback done) = 0;

  virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                            RecvBufResponse* response, StatusCallback done) = 0;

  virtual void CompleteGroupAsync(CallOptions* opts,
                                  const CompleteGroupRequest* request,
                                  CompleteGroupResponse* response,
                                  StatusCallback done) = 0;

  virtual void CompleteInstanceAsync(CallOptions* ops,
                                     const CompleteInstanceRequest* request,
                                     CompleteInstanceResponse* response,
                                     StatusCallback done) = 0;

  virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                    GetStepSequenceResponse* response,
                                    StatusCallback done) = 0;
};
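The default `RunGraphAsync(CallOptions*, const RunGraphRequest*, ...)` overload above adapts plain protos to the wrapper-based overload, and its ownership handling is the subtle part: the wrappers are allocated on the heap and freed only inside the completion callback, after `done` has run. The sketch below mimics that with hypothetical types (`RunRequest`, `RequestWrapper`, `MiniWorker`, `ImmediateWorker`) and drops `CallOptions` for brevity.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>

using Status = std::string;  // hypothetical: empty string means OK
using StatusCallback = std::function<void(const Status&)>;

struct RunRequest { int step_id = 0; };
struct RunResponse { int step_id = -1; };

static int g_live_wrappers = 0;  // counts wrappers that are still alive

// Hypothetical wrappers standing in for ProtoRunGraphRequest and
// NonOwnedProtoRunGraphResponse: they borrow the caller's messages.
struct RequestWrapper {
  explicit RequestWrapper(const RunRequest* r) : req(r) { ++g_live_wrappers; }
  ~RequestWrapper() { --g_live_wrappers; }
  const RunRequest* req;
};
struct ResponseWrapper {
  explicit ResponseWrapper(RunResponse* r) : resp(r) { ++g_live_wrappers; }
  ~ResponseWrapper() { --g_live_wrappers; }
  RunResponse* resp;
};

class MiniWorker {
 public:
  virtual ~MiniWorker() = default;

  // The wrapper-based overload that every derived class must implement.
  virtual void RunAsync(RequestWrapper* req, ResponseWrapper* resp,
                        StatusCallback done) = 0;

  // Convenience overload: wrappers live on the heap and are freed in the
  // callback, because the call may complete after this function returns.
  void RunAsync(const RunRequest* req, RunResponse* resp, StatusCallback done) {
    auto* wrapped_req = new RequestWrapper(req);
    auto* wrapped_resp = new ResponseWrapper(resp);
    RunAsync(wrapped_req, wrapped_resp,
             [wrapped_req, wrapped_resp, done = std::move(done)](const Status& s) {
               done(s);
               delete wrapped_req;
               delete wrapped_resp;
             });
  }
};

class ImmediateWorker : public MiniWorker {
 public:
  using MiniWorker::RunAsync;  // otherwise the override hides the overload
  void RunAsync(RequestWrapper* req, ResponseWrapper* resp,
                StatusCallback done) override {
    resp->resp->step_id = req->req->step_id;
    done(Status());
  }
};
```

The `using MiniWorker::RunAsync;` line is needed because overriding one overload in a derived class otherwise hides the other overloads of the same name.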

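WorkerInterface also provides a synchronous interface, so that the Master or a Worker can call methods of a remote WorkerService as if they were local functions. The synchronous interface is implemented on top of the asynchronous one, using the CallAndWait adapter to wrap each asynchronous call. In addition, to prevent outside code from deleting WorkerInterface instances improperly, a few restrictions are imposed: the destructor is protected, WorkerCacheInterface is declared a friend, and WorkerCacheInterface::ReleaseWorker is responsible for deleting WorkerInterface instances. Below are the synchronous interface and its supporting helpers.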


class WorkerInterface {
 public:

  virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() {
    return new MutableProtoRunGraphRequest;
  }

  virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() {
    return new OwnedProtoRunGraphResponse;
  }

  Status GetStatus(const GetStatusRequest* request,
                   GetStatusResponse* response) {
    Status ret;
    Notification n;
    GetStatusAsync(nullptr, request, response, true,
                   [&ret, &n](const Status& s) {
                     ret = s;
                     n.Notify();
                   });
    n.WaitForNotification();
    return ret;
  }

  Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
                             CreateWorkerSessionResponse* response) {
    return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
  }

  Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request,
                             DeleteWorkerSessionResponse* response) {
    return CallAndWaitWithOptions(&ME::DeleteWorkerSessionAsync, request,
                                  response);
  }

  Status RegisterGraph(const RegisterGraphRequest* request,
                       RegisterGraphResponse* response) {
    return CallAndWait(&ME::RegisterGraphAsync, request, response);
  }

  Status DeregisterGraph(const DeregisterGraphRequest* request,
                         DeregisterGraphResponse* response) {
    return CallAndWait(&ME::DeregisterGraphAsync, request, response);
  }

  Status CleanupGraph(const CleanupGraphRequest* request,
                      CleanupGraphResponse* response) {
    return CallAndWait(&ME::CleanupGraphAsync, request, response);
  }

  Status CleanupAll(const CleanupAllRequest* request,
                    CleanupAllResponse* response) {
    return CallAndWait(&ME::CleanupAllAsync, request, response);
  }

  Status Logging(const LoggingRequest* request, LoggingResponse* response) {
    return CallAndWait(&ME::LoggingAsync, request, response);
  }

  Status Tracing(const TracingRequest* request, TracingResponse* response) {
    return CallAndWait(&ME::TracingAsync, request, response);
  }

  Status GetStepSequence(const GetStepSequenceRequest* request,
                         GetStepSequenceResponse* response) {
    return CallAndWait(&ME::GetStepSequenceAsync, request, response);
  }

 protected:

  virtual ~WorkerInterface() {}
  friend class WorkerCacheInterface;

  RunGraphResponse* get_proto_from_wrapper(
      MutableRunGraphResponseWrapper* wrapper) {
    return wrapper->get_proto();
  }

 private:
  typedef WorkerInterface ME;

  template <typename Method, typename Req, typename Resp>
  Status CallAndWait(Method func, const Req* req, Resp* resp) {
    Status ret;
    Notification n;
    (this->*func)(req, resp, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }

  template <typename Method, typename Req, typename Resp>
  Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) {
    CallOptions call_opts;
    Status ret;
    Notification n;
    (this->*func)(&call_opts, req, resp, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }
};
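The protected destructor plus `friend class WorkerCacheInterface` described above means that only the cache can destroy a worker. Here is a minimal sketch of the same access-control pattern, with hypothetical `GuardedWorker` and `GuardedWorkerCache` classes (the real cache may also special-case the local worker, which is not modeled here):

```cpp
#include <cassert>
#include <string>
#include <utility>

class GuardedWorkerCache;

// Hypothetical worker whose destructor is protected, as in WorkerInterface.
class GuardedWorker {
 public:
  explicit GuardedWorker(std::string target) : target_(std::move(target)) {
    ++live_count;
  }
  const std::string& target() const { return target_; }

  static int live_count;  // how many workers currently exist

 protected:
  // `delete worker;` outside the friend class does not compile.
  virtual ~GuardedWorker() { --live_count; }
  friend class GuardedWorkerCache;

 private:
  std::string target_;
};

int GuardedWorker::live_count = 0;

// Hypothetical cache: the only code allowed to destroy a GuardedWorker.
class GuardedWorkerCache {
 public:
  GuardedWorker* GetOrCreateWorker(const std::string& target) {
    return new GuardedWorker(target);
  }
  void ReleaseWorker(const std::string& target, GuardedWorker* worker) {
    assert(worker->target() == target);
    delete worker;  // allowed: GuardedWorkerCache is a friend
  }
};
```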

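3.3.2 Key Concepts

The WorkerService interface involves many concepts, so let's sort them out carefully.

As mentioned earlier, the Client and the Master cooperate through the session_handle / MasterSession pair, while the Master and the Worker cooperate through MasterSession and WorkerSession, with one MasterSession managing all the WorkerSessions that belong to it. The relationships among the following concepts need to be clear:

  • session_handle: lets a MasterSession manage all of its WorkerSessions uniformly. It corresponds one-to-one with a MasterSession and is generated when the MasterSession is created. It is returned to the Client in CreateSessionResponse and sent to Workers in CreateWorkerSessionRequest, so the whole chain from Client to Master to Worker is uniquely identified by the session_handle.
  • graph_handle: generated by GraphMgr::Register when a subgraph is registered, and returned to the Master in RegisterGraphResponse. The subgraph is identified by this graph_handle; within the cluster, a subgraph is uniquely identified by the (session_handle, graph_handle) pair.
  • step_id: because the Master has multiple Workers compute concurrently, it broadcasts RunGraph to all of them. To tell steps apart, the Master generates a globally unique step_id for each RunStep and carries it to the Workers in the RunGraphRequest message.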
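The identifiers above can be illustrated with a hypothetical registry keyed by the `(session_handle, graph_handle)` pair. Assumption: handles come from separate per-session counters, so the same graph_handle string can appear in two sessions, which is why the pair rather than graph_handle alone works as the key. The step counter is also a simplification; the real Master may generate step_id differently. `SubgraphRegistry` and `StepIdSource` are illustrative names.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <utility>

// Hypothetical registry: (session_handle, graph_handle) -> subgraph name.
class SubgraphRegistry {
 public:
  using Key = std::pair<std::string, std::string>;

  void Register(const std::string& session_handle,
                const std::string& graph_handle, const std::string& subgraph) {
    table_[Key(session_handle, graph_handle)] = subgraph;
  }

  // Returns nullptr when the pair is unknown.
  const std::string* Find(const std::string& session_handle,
                          const std::string& graph_handle) const {
    auto it = table_.find(Key(session_handle, graph_handle));
    return it == table_.end() ? nullptr : &it->second;
  }

 private:
  std::map<Key, std::string> table_;
};

// Hypothetical Master-side source of step ids: one fresh id per RunStep,
// which would be copied into every RunGraphRequest of that step.
class StepIdSource {
 public:
  int64_t Next() { return ++last_; }

 private:
  int64_t last_ = 0;
};
```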

Let's trace graph_handle first. It is generated in GraphMgr::Register.

Status GraphMgr::Register(
    const string& handle, const GraphDef& gdef, WorkerSession* session,
    const GraphOptions& graph_options, const DebugOptions& debug_options,
    const ConfigProto& config_proto, int64_t collective_graph_key,
    DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
  Item* item = new Item;
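  Status s = InitItem(handle, gdef, session, graph_options, debug_options,
                      config_proto, collective_graph_key, cluster_flr, item);
  if (!s.ok()) {
    // Initialization failed: release the partially built item.
    item->Unref();
    return s;
  }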

  {
    mutex_lock l(mu_);
    *graph_handle =
        strings::Printf("%016llx", static_cast<long long>(++next_id_));
    item->handle = *graph_handle;
    CHECK(table_.insert({*graph_handle, item}).second);
  }
  return Status::OK();
}
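The handle format above, `%016llx` of an incrementing counter, can be reproduced on its own. This sketch uses a hypothetical `GraphHandleGenerator` with `std::mutex` in place of TensorFlow's `mutex`, and casts to `unsigned long long` to match `%llx` exactly:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <mutex>
#include <string>

// Hypothetical generator mimicking GraphMgr::Register's handle format:
// a per-instance counter, incremented under a lock, printed as 16 hex digits.
class GraphHandleGenerator {
 public:
  std::string Next() {
    std::lock_guard<std::mutex> lock(mu_);
    char buf[17];  // 16 hex digits + '\0'
    std::snprintf(buf, sizeof(buf), "%016llx",
                  static_cast<unsigned long long>(++next_id_));
    return std::string(buf);
  }

 private:
  std::mutex mu_;
  int64_t next_id_ = 0;
};
```

Zero-padding to a fixed width keeps every handle the same length, and the mutex makes concurrent registrations receive distinct ids.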

RegisterGraphResponse returns the graph_handle to the Master.

message RegisterGraphResponse {

  string graph_handle = 1;
}

Each partitioned subgraph (Part) stores its graph_handle.


struct Part {

  string name;

  std::unordered_map<string, string> feed_key;

  std::unordered_map<string, string> key_fetch;

  WorkerInterface* worker = nullptr;

  string graph_handle;

  Part() : feed_key(3), key_fetch(3) {}
};

When registration returns, the graph_handle is set on each subgraph.

Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
    const PartitionOptions& popts,
    std::unordered_map<string, GraphDef> graph_partitions) {
  partitions_.reserve(graph_partitions.size());
  Status s;
  for (auto& name_def : graph_partitions) {
    partitions_.emplace_back();
    Part* part = &partitions_.back();
    part->name = name_def.first;
    TrackFeedsAndFetches(part, name_def.second, popts);
    part->worker = worker_cache_->GetOrCreateWorker(part->name);
    if (part->worker == nullptr) {
      s = errors::NotFound("worker ", part->name);
      break;
    }
  }
  if (!s.ok()) {
    for (Part& part : partitions_) {
      worker_cache_->ReleaseWorker(part.name, part.worker);
      part.worker = nullptr;
    }
    return s;
  }
  struct Call {
    RegisterGraphRequest req;
    RegisterGraphResponse resp;
    Status status;
  };
  const int num = partitions_.size();
  gtl::InlinedVector<Call, 4> calls(num);
  BlockingCounter done(num);
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    Call* c = &calls[i];
    c->req.set_session_handle(session_handle_);
    c->req.set_create_worker_session_called(!should_deregister_);
    c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
    StripDefaultAttributes(*OpRegistry::Global(),
                           c->req.mutable_graph_def()->mutable_node());
    *c->req.mutable_config_proto() = session_opts_.config;
    *c->req.mutable_graph_options() = session_opts_.config.graph_options();
    *c->req.mutable_debug_options() =
        callable_opts_.run_options().debug_options();
    c->req.set_collective_graph_key(collective_graph_key_);

    auto cb = [c, &done](const Status& s) {
      c->status = s;
      done.DecrementCount();
    };
    part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
  }
  done.Wait();
  for (int i = 0; i < num; ++i) {
    Call* c = &calls[i];
    s.Update(c->status);
    partitions_[i].graph_handle = c->resp.graph_handle();
  }
  return s;
}
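`DoRegisterPartitions` is a classic fan-out/fan-in: issue one async call per partition, block on a counter until every callback has fired, then merge the statuses and read the handles. A standalone sketch, assuming hypothetical names (`CountDownLatch` standing in for `BlockingCounter`, `RegisterAll` for the loop) and string statuses where empty means OK:

```cpp
#include <cassert>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Minimal stand-in for tensorflow::BlockingCounter.
class CountDownLatch {
 public:
  explicit CountDownLatch(int count) : count_(count) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

// Callback a worker invokes with (status, graph_handle); empty status = OK.
using RegisterDone = std::function<void(const std::string&, const std::string&)>;
// Hypothetical stand-in for part.worker->RegisterGraphAsync on partition i.
using RegisterFn = std::function<void(int, RegisterDone)>;

struct FanOutResult {
  std::string first_error;           // empty if every call succeeded
  std::vector<std::string> handles;  // graph_handle per partition, in order
};

FanOutResult RegisterAll(int num, const RegisterFn& register_async) {
  struct Call {
    std::string status;
    std::string handle;
  };
  std::vector<Call> calls(num);  // sized up front: pointers stay valid
  CountDownLatch done(num);
  for (int i = 0; i < num; ++i) {
    Call* c = &calls[i];
    register_async(i, [c, &done](const std::string& status,
                                 const std::string& handle) {
      c->status = status;
      c->handle = handle;
      done.DecrementCount();
    });
  }
  done.Wait();  // all callbacks have run, so `calls` is safe to read
  FanOutResult result;
  for (const Call& c : calls) {
    if (result.first_error.empty()) result.first_error = c.status;
    result.handles.push_back(c.handle);
  }
  return result;
}
```

Keeping the first error mirrors what `Status::Update` does, and sizing `calls` before issuing any request matters because each callback holds a raw pointer into it.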

Later operations use graph_handle to identify a subgraph uniquely, as DeregisterPartitions does.


void MasterSession::ReffedClientGraph::DeregisterPartitions() {
  struct Call {
    DeregisterGraphRequest req;
    DeregisterGraphResponse resp;
  };
  for (Part& part : partitions_) {

    if (!part.graph_handle.empty()) {
      Call* c = new Call;
      c->req.set_session_handle(session_handle_);
      c->req.set_create_worker_session_called(!should_deregister_);
      c->req.set_graph_handle(part.graph_handle);

      WorkerCacheInterface* worker_cache = worker_cache_;
      const string name = part.name;
      WorkerInterface* w = part.worker;
      CHECK_NOTNULL(w);
      auto cb = [worker_cache, c, name, w](const Status& s) {
        delete c;
        worker_cache->ReleaseWorker(name, w);
      };
      w->DeregisterGraphAsync(&c->req, &c->resp, cb);
    }
  }
}

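3.3.4 WorkerInterface Derived Classes

As shown in the figure below, WorkerInterface has two implementations.

  • GrpcWorker: the Worker role in local mode. If the Master and Worker are both local, it can be called directly without any RPC over the network.
  • GrpcRemoteWorker: in distributed mode the Worker is remote, and the local side uses GrpcRemoteWorker to access it.
  • GrpcRemoteWorker is a gRPC client; it accesses the GrpcWorkerService on the remote Worker through a stub.
  • GrpcWorkerService implements every interface defined by WorkerService, but forwards the actual work to the local GrpcWorker.

The relationships are shown below: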

Figure 1: WorkerInterface derived classes

3.3.5 Usage

When the Server initializes, it builds the Worker Service with the following code.


  worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
                                  : NewGrpcWorker(&worker_env_, config);
  worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
                                         opts.worker_service_options);

NewGrpcWorkerService simply returns a GrpcWorkerService.


std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
    GrpcWorker* worker, ::grpc::ServerBuilder* builder,
    GrpcWorkerServiceOptions options) {
  return std::unique_ptr<AsyncServiceInterface>(
      new GrpcWorkerService(worker, builder, options));
}

In GrpcServer, the worker_thread_ thread runs the HandleRPCsLoop method of GrpcWorkerService.

worker_thread_.reset(
    env_->StartThread(ThreadOptions(), "TF_worker_service",
                      [this] { worker_service_->HandleRPCsLoop(); }));

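3.3.6 Definition

GrpcWorkerService is defined as follows. Because it acts as a daemon that handles incoming gRPC requests, its constructor creates a number of serving-thread objects to respond to requests; HandleRPCsLoop then starts these threads and joins them.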

class GrpcWorkerService : public AsyncServiceInterface {
 public:
  GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
                    GrpcWorkerServiceOptions options)
      : is_shutdown_(false) {
    builder->RegisterService(&worker_service_);

    for (int i = 0; i < options.num_serving_threads; i++) {
      threads_.emplace_back(
          new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
                                      cache_.get(), &worker_service_));
    }
  }

  void HandleRPCsLoop() override {
    for (auto& worker_thread : threads_) {
      worker_thread->Start();
    }
    for (auto& worker_thread : threads_) {
      worker_thread->Join();
    }
  }

 private:
  grpc::WorkerService::AsyncService worker_service_;
  std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;

  std::unique_ptr<GrpcResponseCache> cache_;
  mutex service_shutdown_mu_;
  bool is_shutdown_ TF_GUARDED_BY(service_shutdown_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
};
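The start-then-join structure of `HandleRPCsLoop` can be sketched with `std::thread`. The names `ServingThread` and `MiniWorkerService` are hypothetical, and each thread here only bumps a counter instead of polling a completion queue:

```cpp
#include <atomic>
#include <cassert>
#include <memory>
#include <thread>
#include <vector>

// Hypothetical serving thread: a real one would poll its own completion
// queue; this one only records that it ran.
class ServingThread {
 public:
  explicit ServingThread(std::atomic<int>* counter) : counter_(counter) {}
  void Start() { thread_ = std::thread([this] { counter_->fetch_add(1); }); }
  void Join() { thread_.join(); }

 private:
  std::atomic<int>* counter_;
  std::thread thread_;
};

// Mirrors the shape of GrpcWorkerService: thread objects are created in the
// constructor, then started and joined inside HandleRPCsLoop().
class MiniWorkerService {
 public:
  MiniWorkerService(int num_serving_threads, std::atomic<int>* counter) {
    for (int i = 0; i < num_serving_threads; ++i) {
      threads_.emplace_back(new ServingThread(counter));
    }
  }
  void HandleRPCsLoop() {
    for (auto& t : threads_) t->Start();
    for (auto& t : threads_) t->Join();  // returns once every thread exits
  }

 private:
  std::vector<std::unique_ptr<ServingThread>> threads_;
};
```

Because `HandleRPCsLoop()` only returns once every serving thread has exited, the server calls it from its own `worker_thread_` rather than from the thread that starts the server.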

3.3.7 Threads

The actual request loop and request handling happen inside these threads; cq_ is the gRPC completion queue.


class GrpcWorkerServiceThread {
 public:
  explicit GrpcWorkerServiceThread(
      GrpcWorker* worker, ::grpc::ServerBuilder* builder,
      std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
      grpc::WorkerService::AsyncService* worker_service)
      : worker_(worker),
        queue_depth_(queue_depth),
        cache_(cache),
        worker_service_(worker_service),
        is_shutdown_(false) {
    cq_ = builder->AddCompletionQueue();
  }

  void Start() {
    thread_.reset(
        worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
                                         [this]() { HandleRPCsLoop(); }));
  }
};

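Main loop

GrpcWorkerServiceThread::HandleRPCsLoop is the thread's main loop and resembles the master service's. It first pre-posts pending gRPC request slots for each method: the number of slots comes from queue_depth_ if it has an entry for that method, otherwise from the default given in each SETUP_FOR_REQUEST line (1000 for RecvTensor). These requests correspond one-to-one with GrpcWorkerMethod values, and the code that handles each method is covered below. The thread then loops on the completion queue, passing every completed tag to its OnCompleted callback.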


void GrpcWorkerServiceThread::HandleRPCsLoop() {

  SETUP_FOR_REQUEST(GetStatus, 1, false);
  SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
  SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
  SETUP_FOR_REQUEST(CleanupAll, 1, false);
  SETUP_FOR_REQUEST(RegisterGraph, 1, false);
  SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
  SETUP_FOR_REQUEST(Logging, 1, false);
  SETUP_FOR_REQUEST(Tracing, 1, false);
  SETUP_FOR_REQUEST(CompleteGroup, 10, true);
  SETUP_FOR_REQUEST(CompleteInstance, 10, true);
  SETUP_FOR_REQUEST(GetStepSequence, 10, true);
  SETUP_FOR_REQUEST(RecvBuf, 500, true);
  SETUP_FOR_REQUEST(RunGraph, 100, true);
  SETUP_FOR_REQUEST(CleanupGraph, 100, false);
  SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);

  for (int i = 0;
       i < gtl::FindWithDefault(
               queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
               1000);
       ++i) {
    EnqueueRecvTensorRequestRaw();
  }

  void* tag;
  bool ok;

  while (cq_->Next(&tag, &ok)) {
    UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
        static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
    CHECK(callback_tag);
    callback_tag->OnCompleted(this, ok);
  }
}
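The two phases of this loop, pre-posting a per-method number of request slots and then dispatching completed tags through a virtual `OnCompleted`, can be modeled in a single thread. Assumptions in this sketch: `FakeCompletionQueue`, `RequestSlot` and `RunLoop` are hypothetical; every slot is treated as if a request had already arrived; handlers do not re-enqueue a slot as the real ones do; and `Next()` returns `false` when drained instead of blocking until shutdown.

```cpp
#include <cassert>
#include <deque>
#include <map>

// Same contract as gtl::FindWithDefault: the mapped value, or `def`.
int FindWithDefault(const std::map<int, int>& m, int key, int def) {
  auto it = m.find(key);
  return it == m.end() ? def : it->second;
}

// Like UntypedCall<...>::Tag: an object that knows how to finish itself.
struct Tag {
  virtual ~Tag() = default;
  virtual void OnCompleted(bool ok) = 0;
};

// Hypothetical single-threaded stand-in for ::grpc::CompletionQueue.
class FakeCompletionQueue {
 public:
  void Push(Tag* tag) { events_.push_back(tag); }
  // Unlike the real Next(), which blocks, this returns false once drained.
  bool Next(Tag** tag, bool* ok) {
    if (events_.empty()) return false;
    *tag = events_.front();
    events_.pop_front();
    *ok = true;
    return true;
  }

 private:
  std::deque<Tag*> events_;
};

// One pre-posted request slot; completing it counts one handled call.
struct RequestSlot : Tag {
  RequestSlot(std::map<int, int>* handled, int method)
      : handled_(handled), method_(method) {}
  void OnCompleted(bool ok) override {
    if (ok) ++(*handled_)[method_];
  }
  std::map<int, int>* handled_;
  int method_;
};

// `defaults`: method id -> default depth (the SETUP_FOR_REQUEST values);
// `queue_depth`: optional per-method overrides, like queue_depth_.
std::map<int, int> RunLoop(const std::map<int, int>& queue_depth,
                           const std::map<int, int>& defaults) {
  std::map<int, int> handled;
  std::deque<RequestSlot> slots;  // deque keeps element addresses stable
  FakeCompletionQueue cq;
  for (const auto& entry : defaults) {
    int depth = FindWithDefault(queue_depth, entry.first, entry.second);
    for (int i = 0; i < depth; ++i) {
      slots.emplace_back(&handled, entry.first);
      cq.Push(&slots.back());  // pretend a request arrived for this slot
    }
  }
  Tag* tag = nullptr;
  bool ok = false;
  while (cq.Next(&tag, &ok)) {
    tag->OnCompleted(ok);  // same dispatch as callback_tag->OnCompleted
  }
  return handled;
}
```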

gRPC requests

Request handling is similar to the master's. Each request is dispatched to a business handler, GrpcWorkerServiceThread::method##Handler, as defined by the macros below.

#define ENQUEUE_REQUEST(method, supports_cancel)                             \
  do {                                                                       \
    mutex_lock l(shutdown_mu_);                                              \
    if (!is_shutdown_) {                                                     \
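      Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,       \
           method##Request, method##Response>::                              \
          EnqueueRequestForMethod(                                           \
              worker_service_, cq_.get(),                                    \
              static_cast<int>(GrpcWorkerMethod::k##method),                 \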
              &GrpcWorkerServiceThread::method##Handler, (supports_cancel)); \
    }                                                                        \
  } while (0)

#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel)              \
  for (int i = 0;                                                              \
       i < gtl::FindWithDefault(queue_depth_,                                  \
                                static_cast<int>(GrpcWorkerMethod::k##method), \
                                default_depth);                                \
       ++i) {                                                                  \
    ENQUEUE_REQUEST(method, supports_cancel);                                  \
  }
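The following minimal standalone sketch shows what SETUP_FOR_REQUEST and ENQUEUE_REQUEST achieve together. It uses a std::map lookup in place of gtl::FindWithDefault and a per-method counter in place of a real gRPC completion queue. `ToyMethod`, `FindWithDefault` and `PendingSlots` are hypothetical names. Each method gets queue_depth outstanding request slots, or its default depth when no override is configured. Every handler re-enqueues one slot after consuming one, so the number of outstanding slots stays constant.

```cpp
#include <cassert>
#include <map>

// Hypothetical subset of method ids, mirroring GrpcWorkerMethod.
enum class ToyMethod { kGetStatus, kRunGraph, kRecvTensor };

// Hypothetical stand-in for gtl::FindWithDefault: m[key] if present, else dflt.
int FindWithDefault(const std::map<int, int>& m, int key, int dflt) {
  auto it = m.find(key);
  return it == m.end() ? dflt : it->second;
}

// Hypothetical bookkeeping: how many request slots are posted per method.
struct PendingSlots {
  std::map<ToyMethod, int> outstanding;

  // Analogue of ENQUEUE_REQUEST: post one slot that can accept one RPC.
  void Enqueue(ToyMethod m) { ++outstanding[m]; }

  // Analogue of SETUP_FOR_REQUEST: post `depth` slots up front.
  void Setup(const std::map<int, int>& queue_depth, ToyMethod m,
             int default_depth) {
    int depth =
        FindWithDefault(queue_depth, static_cast<int>(m), default_depth);
    for (int i = 0; i < depth; ++i) Enqueue(m);
  }

  // A call consumes one slot; the handler then re-posts it, like the
  // trailing ENQUEUE_REQUEST inside each generated handler.
  void OnCallArrived(ToyMethod m) {
    --outstanding[m];
    Enqueue(m);
  }
};
```

Usage: with an override `{kRunGraph: 3}`, `Setup(..., kRunGraph, 100)` posts 3 slots instead of the default 100.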

Every RPC method must also be registered as an asynchronous service. This is done with gRPC's own AddMethod and MarkMethodAsync interfaces.

WorkerService::AsyncService::AsyncService() {
  for (int i = 0; i < kGrpcNumWorkerMethods; ++i) {
    AddMethod(new ::grpc::internal::RpcServiceMethod(
        GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)),
        ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
    ::grpc::Service::MarkMethodAsync(i);
  }
}
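The registration loop relies on an index contract. The i-th AddMethod call must correspond to the GrpcWorkerMethod whose value is i, because MarkMethodAsync(i) and the later EnqueueRequestForMethod calls refer to methods only by that index. The sketch below illustrates this contract with a hypothetical `ToyAsyncService`. It is not the gRPC API, just a model of index-addressed registration.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for ::grpc::Service: a method is addressed by the
// index at which it was added, as with MarkMethodAsync(i).
class ToyAsyncService {
 public:
  void AddMethod(const std::string& name) {
    names_.push_back(name);
    is_async_.push_back(false);
  }
  void MarkMethodAsync(int index) { is_async_.at(index) = true; }

  const std::string& NameAt(int index) const { return names_.at(index); }
  bool IsAsync(int index) const { return is_async_.at(index); }
  int size() const { return static_cast<int>(names_.size()); }

 private:
  std::vector<std::string> names_;
  std::vector<bool> is_async_;
};

// Registers names in enum order, so index i always refers to method i.
ToyAsyncService BuildService(const std::vector<std::string>& method_names) {
  ToyAsyncService service;
  for (int i = 0; i < static_cast<int>(method_names.size()); ++i) {
    service.AddMethod(method_names[i]);
    service.MarkMethodAsync(i);
  }
  return service;
}
```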

Handlers & thread pools

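The concrete handlers are generated by the HANDLE_CALL macro below. Each handler wraps the call to the corresponding Worker method, plus sending the response, in a closure. It then runs the closure on a thread facility integrated into WorkerEnv:

- Methods flagged may_block_on_compute_pool are scheduled through env->SchedClosure.
- All others run on the compute_pool thread pool via compute_pool->Schedule.

Finally, the handler re-enqueues a request slot for the same method.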


#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(3) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

  HANDLE_CALL(GetStatus, false);
  HANDLE_CALL(CreateWorkerSession, false);
  HANDLE_CALL(DeleteWorkerSession, true);
  HANDLE_CALL(CleanupAll, false);
  HANDLE_CALL(RegisterGraph, false);
  HANDLE_CALL(DeregisterGraph, false);
  HANDLE_CALL(CleanupGraph, false);
  HANDLE_CALL(Logging, false);
  HANDLE_CALL(Tracing, false);

#undef HANDLE_CALL
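The pool-selection rule in HANDLE_CALL can be isolated as follows. This simplified sketch uses two hypothetical `RecordingExecutor` objects that queue closures instead of running them on real threads; `DispatchHandler` is likewise illustrative. As in the macro, a handler flagged may_block_on_compute_pool (DeleteWorkerSession above) goes to the general env scheduler. Presumably this keeps a potentially blocking call from occupying compute_pool threads. Every other handler runs on compute_pool.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical executor that queues closures and runs them on demand.
struct RecordingExecutor {
  std::vector<std::function<void()>> queued;
  void Schedule(std::function<void()> fn) { queued.push_back(std::move(fn)); }
  void RunAll() {
    for (auto& fn : queued) fn();
    queued.clear();
  }
};

// Hypothetical mirror of HANDLE_CALL: wrap the work in a closure, pick a pool.
void DispatchHandler(const std::string& method, bool may_block_on_compute_pool,
                     RecordingExecutor& env_scheduler,
                     RecordingExecutor& compute_pool,
                     std::vector<std::string>& responses) {
  auto closure = [method, &responses]() {
    // Stand-in for worker_->method(...) followed by call->SendResponse(...).
    responses.push_back(method + ": OK");
  };
  if (may_block_on_compute_pool) {
    env_scheduler.Schedule(std::move(closure));
  } else {
    compute_pool.Schedule(std::move(closure));
  }
}
```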

Messages & methods

GrpcWorkerMethod enumerates the methods that the worker provides.


enum class GrpcWorkerMethod {
  kGetStatus,
  kCreateWorkerSession,
  kDeleteWorkerSession,
  kRegisterGraph,
  kDeregisterGraph,
  kRunGraph,
  kCleanupGraph,
  kCleanupAll,
  kRecvTensor,
  kRecvBuf,
  kLogging,
  kTracing,
  kCompleteGroup,
  kCompleteInstance,
  kGetStepSequence,
  kMarkRecvFinished,
};

GrpcWorkerMethodName maps each of these enum values to the corresponding gRPC method name.

const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
  switch (id) {
    case GrpcWorkerMethod::kGetStatus:
      return "/tensorflow.WorkerService/GetStatus";
    case GrpcWorkerMethod::kCreateWorkerSession:
      return "/tensorflow.WorkerService/CreateWorkerSession";
    case GrpcWorkerMethod::kDeleteWorkerSession:
      return "/tensorflow.WorkerService/DeleteWorkerSession";
    case GrpcWorkerMethod::kRegisterGraph:
      return "/tensorflow.WorkerService/RegisterGraph";
    case GrpcWorkerMethod::kDeregisterGraph:
      return "/tensorflow.WorkerService/DeregisterGraph";
    case GrpcWorkerMethod::kRunGraph:
      return "/tensorflow.WorkerService/RunGraph";
    case GrpcWorkerMethod::kCleanupGraph:
      return "/tensorflow.WorkerService/CleanupGraph";
    case GrpcWorkerMethod::kCleanupAll:
      return "/tensorflow.WorkerService/CleanupAll";
    case GrpcWorkerMethod::kRecvTensor:
      return "/tensorflow.WorkerService/RecvTensor";
    case GrpcWorkerMethod::kRecvBuf:
      return "/tensorflow.WorkerService/RecvBuf";
    case GrpcWorkerMethod::kLogging:
      return "/tensorflow.WorkerService/Logging";
    case GrpcWorkerMethod::kTracing:
      return "/tensorflow.WorkerService/Tracing";
    case GrpcWorkerMethod::kCompleteGroup:
      return "/tensorflow.WorkerService/CompleteGroup";
    case GrpcWorkerMethod::kCompleteInstance:
      return "/tensorflow.WorkerService/CompleteInstance";
    case GrpcWorkerMethod::kGetStepSequence:
      return "/tensorflow.WorkerService/GetStepSequence";
    case GrpcWorkerMethod::kMarkRecvFinished:
      return "/tensorflow.WorkerService/MarkRecvFinished";
  }

  return "invalid id";
}
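The enum values are contiguous and start at zero, so the number of methods can be derived from the last enumerator. The registration loop can then walk every id from 0 up to that count. The sketch assumes kGrpcNumWorkerMethods is obtained this way. It reproduces the pattern on a hypothetical three-method subset (`ToyWorkerMethod`, `ToyMethodName`) and checks that every id maps to a distinct, valid path.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical three-method subset of GrpcWorkerMethod.
enum class ToyWorkerMethod { kGetStatus, kRunGraph, kRecvTensor };

// Count derived from the last enumerator (values start at 0, no gaps).
constexpr int kToyNumMethods =
    static_cast<int>(ToyWorkerMethod::kRecvTensor) + 1;

// Same switch-based mapping style as GrpcWorkerMethodName.
const char* ToyMethodName(ToyWorkerMethod id) {
  switch (id) {
    case ToyWorkerMethod::kGetStatus:
      return "/tensorflow.WorkerService/GetStatus";
    case ToyWorkerMethod::kRunGraph:
      return "/tensorflow.WorkerService/RunGraph";
    case ToyWorkerMethod::kRecvTensor:
      return "/tensorflow.WorkerService/RecvTensor";
  }
  return "invalid id";
}

// True if every id in [0, kToyNumMethods) has a distinct, valid name.
bool AllNamesDistinct() {
  std::set<std::string> seen;
  for (int i = 0; i < kToyNumMethods; ++i) {
    std::string name = ToyMethodName(static_cast<ToyWorkerMethod>(i));
    if (name == "invalid id" || !seen.insert(name).second) return false;
  }
  return true;
}
```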

The AsyncService constructor calls GrpcWorkerMethodName to register each of these methods with gRPC.

Business logic

The actual business logic is delegated to Worker. For example, GetStepSequence is handled as follows:

void GetStepSequenceHandler(
    WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
  Schedule([this, call]() {
    worker_->GetStepSequenceAsync(
        &call->request, &call->response, [call](const Status& s) {
          call->SendResponse(ToGrpcStatus(s));
        });
  });
  ENQUEUE_REQUEST(GetStepSequence, true);
}
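The HANDLE_CALL handlers call a synchronous Worker method inside the closure. This handler instead calls GetStepSequenceAsync and sends the response from the completion callback, so the scheduled closure does not itself wait for the result. Below is a minimal sketch of the pattern. Hypothetical `ToyStatus`, `ToyCall` and `ToyWorker` types stand in for the TensorFlow ones, and a caller-supplied `schedule` function stands in for the thread pool.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical status type standing in for tensorflow::Status.
struct ToyStatus {
  bool ok;
  std::string message;
};
using StatusCallback = std::function<void(const ToyStatus&)>;

// Hypothetical call object holding request/response and the sent reply.
struct ToyCall {
  int64_t request_graph_key = 0;
  int64_t response_next_step_id = 0;
  bool responded = false;
  ToyStatus sent_status{true, ""};
  void SendResponse(const ToyStatus& s) {
    responded = true;
    sent_status = s;
  }
};

// Hypothetical worker whose async method reports completion through `done`.
struct ToyWorker {
  int64_t next_step_id = 100;
  void GetStepSequenceAsync(const int64_t* graph_key, int64_t* next_step,
                            StatusCallback done) {
    if (*graph_key < 0) {
      done(ToyStatus{false, "invalid graph key"});
      return;
    }
    *next_step = next_step_id++;
    done(ToyStatus{true, ""});
  }
};

// Mirrors GetStepSequenceHandler: schedule a closure, reply from the callback.
void GetStepSequenceHandler(
    ToyWorker* worker, ToyCall* call,
    const std::function<void(std::function<void()>)>& schedule) {
  schedule([worker, call]() {
    worker->GetStepSequenceAsync(
        &call->request_graph_key, &call->response_next_step_id,
        [call](const ToyStatus& s) { call->SendResponse(s); });
  });
}
```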

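From the thread perspective, the logic so far is as follows, assuming three threads in total. The Server's worker_thread_ runs GrpcWorkerService::HandleRPCsLoop(), whose job is to start two GrpcWorkerServiceThreads. Each GrpcWorkerServiceThread responds to gRPC requests and performs the business processing in GrpcWorkerServiceThread::HandleRPCsLoop. Note that GrpcWorkerService and GrpcWorkerServiceThread both have a method named HandleRPCsLoop.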

Figure 2: Thread view
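The structure in Figure 2 can be approximated with standard C++ threads. In this hypothetical sketch:

- `ToyCompletionQueue` is a mutex/condition-variable queue (not the gRPC class) standing in for each thread's cq_.
- Each `ToyTag` carries an OnCompleted-style callback like UntypedCall::Tag.
- Next returns false once the queue is shut down and drained, which ends the loop just as cq_->Next does in HandleRPCsLoop.

Two serving threads are used to match the figure.

```cpp
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

// Hypothetical tag: knows how to finish its own event, like UntypedCall::Tag.
struct ToyTag {
  std::function<void(bool)> on_completed;
};

// Hypothetical completion queue: Next() blocks until a tag or shutdown.
class ToyCompletionQueue {
 public:
  void Push(ToyTag* tag) {
    std::lock_guard<std::mutex> lock(mu_);
    tags_.push_back(tag);
    cv_.notify_one();
  }
  void Shutdown() {
    std::lock_guard<std::mutex> lock(mu_);
    shutdown_ = true;
    cv_.notify_all();
  }
  // Returns false only after shutdown once all pending tags are drained.
  bool Next(ToyTag** tag, bool* ok) {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return shutdown_ || !tags_.empty(); });
    if (tags_.empty()) return false;
    *tag = tags_.front();
    tags_.pop_front();
    *ok = true;
    return true;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<ToyTag*> tags_;
  bool shutdown_ = false;
};

// One serving thread: analogue of GrpcWorkerServiceThread::HandleRPCsLoop.
void HandleRPCsLoop(ToyCompletionQueue* cq) {
  ToyTag* tag;
  bool ok;
  while (cq->Next(&tag, &ok)) tag->on_completed(ok);
}

// Analogue of GrpcWorkerService::HandleRPCsLoop: one thread per queue.
int RunServingThreads(int num_threads, int events_per_queue) {
  std::atomic<int> handled{0};
  std::vector<ToyCompletionQueue> queues(num_threads);
  std::vector<ToyTag> tags(events_per_queue,
                           ToyTag{[&handled](bool) { ++handled; }});
  std::vector<std::thread> threads;
  for (auto& q : queues) threads.emplace_back(HandleRPCsLoop, &q);
  for (auto& q : queues) {
    for (auto& t : tags) q.Push(&t);
    q.Shutdown();
  }
  for (auto& th : threads) th.join();
  return handled.load();
}
```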

3.3.8 Business Logic

CreateWorkerSession

The CreateWorkerSessionRequest message carries the session_handle of the corresponding MasterSession. On receiving it, the Worker creates a WorkerSession. Within a cluster, whenever a MasterSession creates a WorkerSession, it passes along its own session_handle. Each WorkerSession therefore knows which MasterSession it belongs to, and the MasterSession can manage all of its WorkerSessions in one place.

GrpcWorker manages WorkerSessions through SessionMgr, which can identify a WorkerSession either by the master task name or by the session_handle.

class SessionMgr {

  WorkerEnv* const worker_env_;
  std::unique_ptr<WorkerCacheInterface> default_worker_cache_;
  std::shared_ptr<WorkerSession> legacy_session_;
  const WorkerCacheFactory worker_cache_factory_;

  std::map<string, std::shared_ptr<WorkerSession>> sessions_ TF_GUARDED_BY(mu_);

  struct MasterAssociatedSession {
    const int64_t master_incarnation;
    const string session_handle;
  };

  std::unordered_multimap<string, MasterAssociatedSession>
      master_to_associated_sessions_ TF_GUARDED_BY(mu_);
};
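The two lookup paths can be modelled with standard containers. This is a simplified sketch, not SessionMgr itself. `ToySessionMgr` and `ToyWorkerSession` are hypothetical. Sessions are keyed by session_handle, and a multimap from master task name to (incarnation, handle) records which sessions each master created. To illustrate why master_incarnation is tracked, the sketch also assumes that a master reappearing with a new incarnation has restarted, so its older sessions are dropped.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical per-session state.
struct ToyWorkerSession {
  std::string session_handle;
  std::string master_task;
};

class ToySessionMgr {
 public:
  // Creates a session. Sessions from a different incarnation of the same
  // master are dropped first (illustrative policy for a restarted master).
  void CreateSession(const std::string& handle, const std::string& master_task,
                     int64_t master_incarnation) {
    auto range = master_to_sessions_.equal_range(master_task);
    for (auto it = range.first; it != range.second;) {
      if (it->second.master_incarnation != master_incarnation) {
        sessions_.erase(it->second.session_handle);
        it = master_to_sessions_.erase(it);
      } else {
        ++it;
      }
    }
    sessions_[handle] = std::make_shared<ToyWorkerSession>(
        ToyWorkerSession{handle, master_task});
    master_to_sessions_.emplace(master_task,
                                Associated{master_incarnation, handle});
  }

  // Lookup by session_handle.
  std::shared_ptr<ToyWorkerSession> Find(const std::string& handle) const {
    auto it = sessions_.find(handle);
    return it == sessions_.end() ? nullptr : it->second;
  }

  // Lookup by master task name.
  std::vector<std::string> HandlesForMaster(const std::string& task) const {
    std::vector<std::string> out;
    auto range = master_to_sessions_.equal_range(task);
    for (auto it = range.first; it != range.second; ++it)
      out.push_back(it->second.session_handle);
    return out;
  }

 private:
  struct Associated {
    int64_t master_incarnation;
    std::string session_handle;
  };
  std::map<std::string, std::shared_ptr<ToyWorkerSession>> sessions_;
  std::unordered_multimap<std::string, Associated> master_to_sessions_;
};
```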

The messages are shown below. Note that CreateWorkerSessionResponse carries no fields:

message CreateWorkerSessionRequest {

  string session_handle = 1;

  ServerDef server_def = 2;

  bool isolate_session_state = 3;

  repeated DeviceAttributes cluster_device_attributes = 4;

  string master_task = 5;

  int64 master_incarnation = 6;
}

message CreateWorkerSessionResponse {}

Figure 3: CreateWorkerSession

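As noted earlier, the GrpcWorkerService handlers for these messages, including CreateWorkerSession, are generated by the HANDLE_CALL macro.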

RegisterGraph

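The RegisterGraphRequest message carries the MasterSession's session_handle and the subgraph graph_def. After receiving the message and registering/initializing the subgraph, the Worker returns the subgraph's graph_handle to the Master.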

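For each session, once the master has placed every node on a device, it partitions the whole graph into many subgraphs. All nodes in one subgraph live on the same worker, but possibly on several of the devices that worker owns (e.g. cpu0 plus gpu0, gpu1, …, gpu7). Before running any step, the master registers the subgraphs with the worker. A successful registration returns a graph handle, which is used in subsequent RunGraph requests.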


message RegisterGraphRequest {

  string session_handle = 1;

  bool create_worker_session_called = 6;

  GraphDef graph_def = 2;

  bool has_control_flow = 3 [deprecated = true];

  GraphOptions graph_options = 4;

  DebugOptions debug_options = 5;

  int64 collective_graph_key = 7;

  ConfigProto config_proto = 8;
}

message RegisterGraphResponse {

  string graph_handle = 1;
}

Figure 4: RegisterGraph
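The partitioning rule above assigns one subgraph per worker, regardless of which of that worker's devices each node runs on. It can be illustrated by grouping node placements by the task prefix of their device names. This hypothetical sketch assumes fully specified device names of the form /job:J/replica:R/task:T/device:TYPE:N. TensorFlow's graph partitioner additionally inserts Send/Recv pairs on edges that cross partitions; that step is omitted here.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: "/job:J/replica:R/task:T/device:X:N" -> the task part.
// Assumes a fully specified device name.
std::string TaskOf(const std::string& device) {
  auto pos = device.find("/device:");
  return pos == std::string::npos ? device : device.substr(0, pos);
}

// Groups (node, device) placements into one partition per worker task.
std::map<std::string, std::vector<std::string>> PartitionByWorker(
    const std::vector<std::pair<std::string, std::string>>& placements) {
  std::map<std::string, std::vector<std::string>> partitions;
  for (const auto& p : placements) {
    partitions[TaskOf(p.second)].push_back(p.first);
  }
  return partitions;
}
```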

DeregisterGraph

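When a graph is no longer needed (for example, when the whole graph is rescheduled and its nodes are re-placed), the Master deregisters it using the graph's graph_handle. If the Master restarts, the Worker automatically deregisters the corresponding graph_handle according to a TTL-based policy.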


message DeregisterGraphRequest {

  string session_handle = 2;

  bool create_worker_session_called = 3;

  string graph_handle = 1;
}

message DeregisterGraphResponse {

}

Figure 5: DeregisterGraph
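Below is one way to write a worker-side registry that hands out graph_handle values, removes them on DeregisterGraph, and garbage-collects handles idle for longer than a TTL. `ToyGraphRegistry`, the handle format and the caller-supplied `now` timestamps are all hypothetical. The text above only says a TTL-based policy exists, not how it is implemented.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Hypothetical registry of subgraphs keyed by graph_handle.
class ToyGraphRegistry {
 public:
  explicit ToyGraphRegistry(int64_t ttl) : ttl_(ttl) {}

  // RegisterGraph: store the subgraph and return a fresh handle.
  std::string Register(const std::string& graph_def, int64_t now) {
    std::string handle = "graph_" + std::to_string(next_id_++);
    entries_[handle] = Entry{graph_def, now};
    return handle;
  }

  // RunGraph-style lookup; a successful lookup refreshes the timestamp.
  bool Touch(const std::string& handle, int64_t now) {
    auto it = entries_.find(handle);
    if (it == entries_.end()) return false;
    it->second.last_used = now;
    return true;
  }

  // DeregisterGraph: explicit removal requested by the master.
  bool Deregister(const std::string& handle) {
    return entries_.erase(handle) > 0;
  }

  // TTL policy: drop handles idle for longer than ttl_ (e.g. master gone).
  int CollectExpired(int64_t now) {
    int removed = 0;
    for (auto it = entries_.begin(); it != entries_.end();) {
      if (now - it->second.last_used > ttl_) {
        it = entries_.erase(it);
        ++removed;
      } else {
        ++it;
      }
    }
    return removed;
  }

 private:
  struct Entry {
    std::string graph_def;
    int64_t last_used;
  };
  int64_t ttl_;
  int64_t next_id_ = 0;
  std::map<std::string, Entry> entries_;
};
```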

RunGraph

The Master uses RunGraphRequest to execute all the subgraphs registered under graph_handle.

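The Master generates a globally unique step_id to distinguish different runs (steps) of the graph computation. Subgraphs use step_id when communicating with each other (e.g. through Send/Recv ops), so that tensors produced by different runs are kept apart.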

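In RunGraphRequest, send holds the tensors fed as the subgraph's inputs, and recv_key specifies the subgraph's output tensors. RunGraphResponse returns the list of tensors corresponding to recv_key.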

Figure 6: RunGraph


message ExecutorOpts {
  bool record_costs = 1;
  bool record_timeline = 3;
  bool record_partition_graphs = 4;
  bool report_tensor_allocations_upon_oom = 5;
}

message RunGraphRequest {

  string session_handle = 8;

  bool create_worker_session_called = 10;

  string graph_handle = 1;

  int64 step_id = 2;

  ExecutorOpts exec_opts = 5;

  repeated NamedTensorProto send = 3;
  repeated string recv_key = 4;

  bool is_partial = 6;

  bool is_last_partial_run = 7;

  bool store_errors_in_response_body = 9;

  int64 request_id = 11;

}

message RunGraphResponse {

  repeated NamedTensorProto recv = 1;

  StepStats step_stats = 2;
  CostGraphDef cost_graph = 3;
  repeated GraphDef partition_graph = 4;

  error.Code status_code = 5;
  string status_error_message = 6;
}
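A table keyed by (step_id, tensor name) shows how step_id keeps different runs apart. In this hypothetical sketch, floats stand in for tensors, `ToyStepTable` is not a TensorFlow class, and the "subgraph" is simply y = 2 * x. A RunGraph-like call feeds its send tensors, executes, and resolves every recv_key within its own step. It fails if a requested output was not produced.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical per-step tensor store; floats stand in for tensors.
class ToyStepTable {
 public:
  // One RunGraph-like call: feed `send`, run the toy subgraph, fetch
  // `recv_key`. Every lookup is scoped to step_id, so tensors from
  // different steps never collide.
  bool RunGraph(int64_t step_id, const std::map<std::string, float>& send,
                const std::vector<std::string>& recv_key,
                std::vector<float>* recv) {
    for (const auto& kv : send) table_[{step_id, kv.first}] = kv.second;
    Execute(step_id);
    recv->clear();
    for (const auto& key : recv_key) {
      auto it = table_.find({step_id, key});
      if (it == table_.end()) return false;  // not produced in this step
      recv->push_back(it->second);
    }
    return true;
  }

 private:
  // Toy subgraph: y = 2 * x, evaluated only within the given step.
  void Execute(int64_t step_id) {
    auto it = table_.find({step_id, "x"});
    if (it != table_.end()) table_[{step_id, "y"}] = 2.0f * it->second;
  }

  std::map<std::pair<int64_t, std::string>, float> table_;
};
```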

RecvTensor

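During execution, two Workers may need to exchange data. The producer simply places the ready tensor into the rendezvous, and the consumer actively issues a RecvTensorRequest. In that request, step_id identifies the step and rendezvous_key identifies the channel from which the tensor is to be received.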

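A single RecvTensor request fetches one tensor from a channel, and multiple RecvTensor requests can send and receive multiple tensors over the same channel. The producer's tensor is ultimately returned to the consumer in a RecvTensorResponse.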

Figure 7: RecvTensor


message RecvTensorRequest {

  int64 step_id = 1;

  string rendezvous_key = 2;

  bool dma_ok = 3;

  DeviceLocality client_locality = 4;

  DeviceLocality server_locality = 5;

  google.protobuf.Any transport_options = 6;

  int64 request_id = 7;
}

message RecvTensorResponse {

  TensorProto tensor = 1;

  bool is_dead = 2;

  int64 send_start_micros = 3;

  google.protobuf.Any transport_options = 4;

  bool require_ack = 5;
}
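The producer/consumer hand-off through the rendezvous works regardless of which side arrives first. The following single-threaded sketch is a hypothetical model. `ToyRendezvous` is not TensorFlow's Rendezvous class, and the key string is an opaque placeholder rather than the real rendezvous key format. A Send that arrives first parks the tensor, a Recv that arrives first parks its callback, and each (step_id, rendezvous_key) pair is consumed exactly once.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <utility>

using Key = std::pair<int64_t, std::string>;  // (step_id, rendezvous_key)
using DoneCallback = std::function<void(float tensor, bool is_dead)>;

// Hypothetical single-threaded rendezvous.
class ToyRendezvous {
 public:
  // Producer side: deliver now if a Recv is already waiting, else park it.
  void Send(const Key& key, float tensor, bool is_dead) {
    auto waiter = waiting_.find(key);
    if (waiter != waiting_.end()) {
      DoneCallback done = std::move(waiter->second);
      waiting_.erase(waiter);
      done(tensor, is_dead);
      return;
    }
    ready_[key] = {tensor, is_dead};
  }

  // Consumer side: roughly what a RecvTensor request triggers on the
  // producer's worker. Deliver now if the tensor is ready, else park.
  void RecvAsync(const Key& key, DoneCallback done) {
    auto item = ready_.find(key);
    if (item != ready_.end()) {
      auto value = item->second;
      ready_.erase(item);
      done(value.first, value.second);
      return;
    }
    waiting_[key] = std::move(done);
  }

 private:
  std::map<Key, std::pair<float, bool>> ready_;
  std::map<Key, DoneCallback> waiting_;
};
```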

4. Worker

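The Worker class mainly holds a WorkerEnv and a PartialRunMgr. It can be subclassed to provide specialized implementations of particular methods for different transport mechanisms. For example, GrpcWorker specializes RecvTensorAsync so that it can use more efficient gRPC data structures for large binary data.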

class Worker : public WorkerInterface {
 protected:
  WorkerEnv* const env_;
  RecentRequestIds recent_request_ids_;

 private:
  PartialRunMgr partial_run_mgr_;

  CancellationManager cancellation_manager_;

  TF_DISALLOW_COPY_AND_ASSIGN(Worker);
};
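The subclassing idea is a generic Worker whose transport-specific method a subclass replaces. The sketch below uses hypothetical `ToyWorkerBase` and `ToyGrpcWorker` classes. As an illustrative assumption, the base version simply reports the call as unsupported. The subclass fills a string where GrpcWorker::GrpcRecvTensorAsync writes into a ::grpc::ByteBuffer.

```cpp
#include <cassert>
#include <string>

// Hypothetical generic worker: the transport-specific method has a default
// that subclasses replace.
class ToyWorkerBase {
 public:
  virtual ~ToyWorkerBase() = default;
  // Returns false when this transport does not support the call.
  virtual bool RecvTensor(const std::string& key, std::string* wire_bytes) {
    (void)key;
    (void)wire_bytes;
    return false;  // assumed "unimplemented" in the generic worker
  }
};

// Hypothetical gRPC-flavoured subclass that encodes the tensor for its wire.
class ToyGrpcWorker : public ToyWorkerBase {
 public:
  bool RecvTensor(const std::string& key, std::string* wire_bytes) override {
    *wire_bytes = "grpc-bytes:" + key;  // stand-in for a ByteBuffer payload
    return true;
  }
};

// Callers only depend on the base interface; dispatch picks the override.
bool Fetch(ToyWorkerBase& worker, const std::string& key, std::string* out) {
  return worker.RecvTensor(key, out);
}
```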

Let's look at one method as an example; the other methods will be covered as we encounter them later.

void Worker::CleanupAllAsync(const CleanupAllRequest* request,
                             CleanupAllResponse* response,
                             StatusCallback done) {
  std::vector<string> containers;
  for (const auto& c : request->container()) containers.push_back(c);
  env_->device_mgr->ClearContainers(containers);
  done(Status::OK());
}
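CleanupAllAsync does all of its work synchronously and then calls done, so the callback runs before the method returns. The sketch below models the device-manager side with a hypothetical `ToyDeviceMgr` that maps device name to container name to resource names. Modelled on DeviceMgr::ClearContainers, it assumes that an empty container list clears each device's default container. The default name "localhost" used here is likewise an assumption.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical device manager: device -> container -> resource names.
struct ToyDeviceMgr {
  std::string default_container = "localhost";  // assumed default name
  std::map<std::string, std::map<std::string, std::set<std::string>>> devices;

  // Drops the named containers on every device (default one if none given).
  void ClearContainers(const std::vector<std::string>& containers) {
    for (auto& dev : devices) {
      auto& per_device = dev.second;
      if (containers.empty()) per_device.erase(default_container);
      for (const auto& c : containers) per_device.erase(c);
    }
  }
};

// Mirrors Worker::CleanupAllAsync: copy names, clear, then report success.
void CleanupAllAsync(ToyDeviceMgr* mgr,
                     const std::vector<std::string>& request_containers,
                     const std::function<void(bool ok)>& done) {
  std::vector<std::string> containers(request_containers.begin(),
                                      request_containers.end());
  mgr->ClearContainers(containers);
  done(true);
}
```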

5. GrpcWorker

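GrpcWorker is the remote Worker that GrpcRemoteWorker talks to. It is also the object that GrpcWorkerService calls into, and it implements the business logic. Its definition follows; note the methods it implements itself.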

class GrpcWorker : public Worker {
 public:
  GrpcWorker(WorkerEnv* env, const ConfigProto& config);

  virtual void GrpcRecvTensorAsync(CallOptions* opts,
                                   const RecvTensorRequest* request,
                                   ::grpc::ByteBuffer* response,
                                   StatusCallback done);

  void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
                    StatusCallback done) override;

  void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                    RecvBufResponse* response, StatusCallback done) override;

  void CleanupGraphAsync(const CleanupGraphRequest* request,
                         CleanupGraphResponse* response,
                         StatusCallback done) override;

  WorkerEnv* env();

  void EnableResponseCache();

  void RemoveCacheEntryForId(int64 request_id);

 private:
  std::unique_ptr<GrpcResponseCache> response_cache_;
  const int32 recv_buf_max_chunk_;
};
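The response_cache_ member, together with EnableResponseCache and RemoveCacheEntryForId, suggests that GrpcWorker can cache responses by request_id. A retried request could then be answered without recomputing it. The sketch below illustrates only that idea. `ToyResponseCache` is hypothetical and not thread-safe, and it does not reproduce GrpcResponseCache's actual handling of concurrent or in-flight duplicates.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical cache of serialized responses keyed by request_id.
class ToyResponseCache {
 public:
  // Returns the cached response for request_id, computing it on first use.
  std::string LookupOrCompute(int64_t request_id,
                              const std::function<std::string()>& compute) {
    auto it = cache_.find(request_id);
    if (it != cache_.end()) return it->second;
    std::string response = compute();
    cache_.emplace(request_id, response);
    return response;
  }

  // Analogue of RemoveCacheEntryForId: forget a finished request.
  void Remove(int64_t request_id) { cache_.erase(request_id); }

 private:
  std::unordered_map<int64_t, std::string> cache_;
};
```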

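This concludes our introduction to the static structure of the Worker. Its concrete functionality will be covered in detail in the Session part of a later article.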

0xEE Personal Information

★ Thoughts on Life and Technology ★

WeChat official account: 罗西的思考

If you want to be notified promptly of new articles, or to see the technical materials I recommend, please follow the account.

Original: https://blog.csdn.net/weixin_47364682/article/details/123643814
Author: 罗西的思考
Title: [源码解析] TensorFlow 分布式环境(3)— Worker 静态逻辑
