spark-deep-learning:基于 Apache Spark 的分布式深度学习训练项目

Deep Learning Pipelines for Apache Spark

分支11Tags9
文件最后提交记录最后更新时间
4 年前
5 年前
4 年前
5 年前
5 年前
5 年前
4 年前
5 年前

Apache Spark 深度学习管道

构建状态 覆盖率

此仓库仅包含用于本地持续集成和API文档的HorovodRunner代码。若要使用HorovodRunner进行分布式训练,请使用Databricks机器学习运行时。详情请查阅Databricks文档中的HorovodRunner:使用Horovod的分布式深度学习训练

如需使用包含Spark深度学习管道API的早期版本,请访问Spark Packages页面

API 文档

类 sparkdl.HorovodRunner(*, np, driver_log_verbosity='all')

基类: object

HorovodRunner使用Horovod执行分布式深度学习训练任务。

在Databricks Runtime 5.0 ML及以上版本中,它作为分布式Spark作业启动Horovod任务。它通过管理集群设置和与Spark的集成,使得在Databricks上运行Horovod变得简单。查看Databricks文档以获取端到端示例和性能调优提示。

开源版仅在同一Python进程中本地运行该作业,适用于本地开发。

注意:Horovod是由Uber开发的分布式训练框架。

  • 参数:

    • np - 用于Horovod作业的并行进程数。 此参数仅在Databricks Runtime 5.0 ML及更高版本生效。 在开源版本中被忽略。 在Databricks上,每个进程会占用一个可用的任务槽, 对于GPU集群来说对应一个GPU,对于CPU集群则对应一个CPU核心。 接受以下值:

      • <0,将在驱动程序节点上启动-np 个子进程来运行Horovod(本地模式)。 训练标准输出和错误信息会显示在笔记本单元格输出中,如果单元格输出被截断,也可以在驱动程序日志中找到。这对于调试很有用,建议首先在这种模式下测试你的代码。但是请注意,在共享的Databricks集群上过度使用Spark驱动程序可能会造成问题。 注意,np < -1 只在Databricks Runtime 5.5 ML及更高版本支持。
      • >0,将启动一个Spark作业,一次性启动np 个任务来运行Horovod作业。 它会等待直到有np 个任务槽可用才启动作业。 如果np大于集群上的总任务槽数量,作业将会失败。 截止Databricks Runtime 5.4 ML,训练标准输出和错误信息会显示在笔记本单元格输出中。 如遇到单元格输出被截断的情况,完整日志可在由HorovodRunner启动的第二个Spark作业的第0个任务的标准错误流中找到,可以在Spark UI中找到。
      • 0,将使用集群上所有的任务槽来启动作业。 警告:将np设为0已被弃用,并将在下一个主要的Databricks Runtime版本中移除。 因为动态执行器注册,运行时根据总任务槽选择np是不可靠的。 请明确设定所需的并行进程数。
    • driver_log_verbosity:此参数仅在Databricks Runtime上可用。

方法 run(main, **kwargs)

运行Horovod训练任务,调用main(**kwargs)。

开源版本只在同一Python进程中调用main(**kwargs)。 在Databricks Runtime 5.0 ML及更高版本上,它将按照np的文档化行为启动Horovod作业。 主函数和关键字参数都使用cloudpickle序列化并分发给集群工作者。

  • 参数

    • main – 包含Horovod训练代码的Python函数。 预期签名是def main(\*\*kwargs) 或兼容形式。 由于函数会被pickled并分发给工作节点,因此请在函数内部更改全局状态,例如设置日志级别,并注意pickling的限制。 避免在函数中引用大型对象,这可能导致较大的序列化数据,从而使作业启动变慢。

    • kwargs – 调用主函数时传入的关键字参数。

  • 返回

    主函数的返回值。 当np >= 0时,从rank 0 进程返回。请注意,返回值应能被cloudpickle序列化。

发布版本

请访问Github发布页面查看发布说明。

许可证

  • 源代码根据Apache许可证2.0(见LICENSE文件)发布。

项目介绍

Deep Learning Pipelines for Apache Spark

定制我的领域
1521.99 K488访问 GitHub