mirror of https://github.com/InternLM/InternLM
[Daily Pull] Merge Main to Develop 20230901 (#260)
* Standard and experiment docker (#220) * feat:standard docker image * feat:standard docker image * feat: standard dockerfile * feat: standard dockerfile * feat: standard dockerfile * feat: standard dockerfile * feat: standard dockerfile * feat: standard dockerfile * feat: standard dockerfile * experiment and standard docker * experiment and standard docker * fix(core/trainer.py): fix streaming train state load error (#247) * Fix requirement (#243) * feat:standard docker image * feat:standard docker image * fix: a little problem * fix: a little problem * fix(eval): StreamingDataset does not have an __len__ method. (#251) * fix(metric): argument missing in getting loss metrics. (#256) * feat(model): implement uniform_init for tensor. (#252) * Implement uniform_init for tensor. * Fix functinal calling bugs: normal->uniform. * Format editting: remove unused torch importing. --------- Co-authored-by: li126com <43110891+li126com@users.noreply.github.com> Co-authored-by: huangting4201 <1538303371@qq.com> Co-authored-by: Shuo Zhang <zhangshuolove@live.com> Co-authored-by: Ryan (张磊) <MagicDevil.Zhang@qq.com> Co-authored-by: Pryest <54388244+Pryest@users.noreply.github.com>pull/262/head
parent
992499d00d
commit
fca1df20ae
|
@ -59,12 +59,28 @@ cd ../../
|
|||
```
|
||||
|
||||
### Environment Image
|
||||
Users can obtain an image with the InternLM runtime environment installed from https://hub.docker.com/r/sunpengsdu/internlm. The commands for pulling the image and starting the container are as follows:
|
||||
Users can use the provided dockerfile combined with docker.Makefile to build their own images, or obtain images with InternLM runtime environment installed from https://hub.docker.com/r/internlm/internlm.
|
||||
|
||||
#### Image Configuration and Build
|
||||
The configuration and build of the Dockerfile are implemented through the docker.Makefile. To build the image, execute the following command in the root directory of InternLM:
|
||||
``` bash
|
||||
make -f docker.Makefile BASE_OS=centos7
|
||||
```
|
||||
In docker.Makefile, you can customize the basic image, environment version, etc., and the corresponding parameters can be passed directly through the command line. For BASE_OS, ubuntu20.04 and centos7 are respectively supported.
|
||||
|
||||
#### Pull Standard Image
|
||||
The standard image based on ubuntu and centos has been built and can be directly pulled:
|
||||
|
||||
```bash
|
||||
# pull image
|
||||
docker pull sunpengsdu/internlm:torch1.13-cuda11.7-flashatten1.0.5-centos
|
||||
# start container
|
||||
docker run --gpus all -d -it --shm-size=2gb --name myinternlm sunpengsdu/internlm:torch1.13-cuda11.7-flashatten1.0.5-centos
|
||||
docker exec -it myinternlm bash
|
||||
# ubuntu20.04
|
||||
docker pull internlm/internlm:torch1.13.1-cuda11.7.1-flashatten1.0.5-ubuntu20.04
|
||||
# centos7
|
||||
docker pull internlm/internlm:torch1.13.1-cuda11.7.1-flashatten1.0.5-centos7
|
||||
```
|
||||
|
||||
#### Run Container
|
||||
For the local standard image built with dockerfile or pulled, use the following command to run and enter the container:
|
||||
```bash
|
||||
docker run --gpus all -it -m 500g --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm-size 20g --network=host --name myinternlm internlm/internlm:torch1.13.1-cuda11.7.1-flashatten1.0.5-centos7 bash
|
||||
```
|
||||
The default directory in the container is `/InternLM`, please start training according to the [Usage](./usage.md).
|
||||
|
|
|
@ -59,11 +59,28 @@ cd ../../
|
|||
```
|
||||
|
||||
### 环境镜像
|
||||
用户可以从 https://hub.docker.com/r/sunpengsdu/internlm 获取安装了 InternLM 运行环境的镜像,拉取镜像及启动容器的命令如下:
|
||||
用户可以使用提供的 dockerfile 结合 docker.Makefile 来构建自己的镜像,或者也可以从 https://hub.docker.com/r/internlm/internlm 获取安装了 InternLM 运行环境的镜像。
|
||||
|
||||
#### 镜像配置及构造
|
||||
dockerfile 的配置以及构造均通过 docker.Makefile 文件实现,在 InternLM 根目录下执行如下命令即可 build 镜像:
|
||||
``` bash
|
||||
make -f docker.Makefile BASE_OS=centos7
|
||||
```
|
||||
在 docker.Makefile 中可自定义基础镜像,环境版本等内容,对应参数可直接通过命令行传递。对于 BASE_OS 分别支持 ubuntu20.04 和 centos7。
|
||||
|
||||
#### 镜像拉取
|
||||
基于 ubuntu 和 centos 的标准镜像已经 build 完成也可直接拉取使用:
|
||||
|
||||
```bash
|
||||
# 拉取镜像
|
||||
docker pull sunpengsdu/internlm:torch1.13-cuda11.7-flashatten1.0.5-centos
|
||||
# 启动容器
|
||||
docker run --gpus all -d -it --shm-size=2gb --name myinternlm sunpengsdu/internlm:torch1.13-cuda11.7-flashatten1.0.5-centos
|
||||
docker exec -it myinternlm bash
|
||||
# ubuntu20.04
|
||||
docker pull internlm/internlm:torch1.13.1-cuda11.7.1-flashatten1.0.5-ubuntu20.04
|
||||
# centos7
|
||||
docker pull internlm/internlm:torch1.13.1-cuda11.7.1-flashatten1.0.5-centos7
|
||||
```
|
||||
|
||||
#### 容器启动
|
||||
对于使用 dockerfile 构建或拉取的本地标准镜像,使用如下命令启动并进入容器:
|
||||
```bash
|
||||
docker run --gpus all -it -m 500g --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm-size 20g --network=host --name myinternlm internlm/internlm:torch1.13.1-cuda11.7.1-flashatten1.0.5-centos7 bash
|
||||
```
|
||||
容器内默认目录即 `/InternLM`,根据[使用文档](./usage.md)即可启动训练。
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
DOCKER_REGISTRY ?= docker.io
|
||||
DOCKER_ORG ?= my
|
||||
DOCKER_IMAGE ?= internlm
|
||||
DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE)
|
||||
|
||||
CUDA_VERSION = 11.7.1
|
||||
GCC_VERSION = 10.2.0
|
||||
|
||||
CUDNN_VERSION = 8
|
||||
BASE_RUNTIME =
|
||||
# ubuntu20.04 centos7
|
||||
BASE_OS = centos7
|
||||
BASE_DEVEL = nvidia/cuda:$(CUDA_VERSION)-cudnn$(CUDNN_VERSION)-devel-${BASE_OS}
|
||||
# The conda channel to use to install cudatoolkit
|
||||
CUDA_CHANNEL = nvidia
|
||||
# The conda channel to use to install pytorch / torchvision
|
||||
INSTALL_CHANNEL ?= pytorch
|
||||
|
||||
PYTHON_VERSION ?= 3.10
|
||||
PYTORCH_VERSION ?= 1.13.1
|
||||
TORCHVISION_VERSION ?= 0.14.1
|
||||
TORCHAUDIO_VERSION ?= 0.13.1
|
||||
BUILD_PROGRESS ?= auto
|
||||
TRITON_VERSION ?=
|
||||
GMP_VERSION ?= 6.2.1
|
||||
MPFR_VERSION ?= 4.1.0
|
||||
MPC_VERSION ?= 1.2.1
|
||||
GCC_VERSION ?= 10.2.0
|
||||
HTTPS_PROXY_I ?=
|
||||
HTTP_PROXY_I ?=
|
||||
FLASH_ATTEN_VERSION ?= 1.0.5
|
||||
FLASH_ATTEN_TAG ?= v${FLASH_ATTEN_VERSION}
|
||||
|
||||
BUILD_ARGS = --build-arg BASE_IMAGE=$(BASE_IMAGE) \
|
||||
--build-arg PYTHON_VERSION=$(PYTHON_VERSION) \
|
||||
--build-arg CUDA_VERSION=$(CUDA_VERSION) \
|
||||
--build-arg CUDA_CHANNEL=$(CUDA_CHANNEL) \
|
||||
--build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) \
|
||||
--build-arg TORCHVISION_VERSION=$(TORCHVISION_VERSION) \
|
||||
--build-arg TORCHAUDIO_VERSION=$(TORCHAUDIO_VERSION) \
|
||||
--build-arg INSTALL_CHANNEL=$(INSTALL_CHANNEL) \
|
||||
--build-arg TRITON_VERSION=$(TRITON_VERSION) \
|
||||
--build-arg GMP_VERSION=$(GMP_VERSION) \
|
||||
--build-arg MPFR_VERSION=$(MPFR_VERSION) \
|
||||
--build-arg MPC_VERSION=$(MPC_VERSION) \
|
||||
--build-arg GCC_VERSION=$(GCC_VERSION) \
|
||||
--build-arg https_proxy=$(HTTPS_PROXY_I) \
|
||||
--build-arg http_proxy=$(HTTP_PROXY_I) \
|
||||
--build-arg FLASH_ATTEN_TAG=$(FLASH_ATTEN_TAG)
|
||||
|
||||
EXTRA_DOCKER_BUILD_FLAGS ?=
|
||||
|
||||
BUILD ?= build
|
||||
# Intentionally left blank
|
||||
PLATFORMS_FLAG ?=
|
||||
PUSH_FLAG ?=
|
||||
USE_BUILDX ?=1
|
||||
BUILD_PLATFORMS ?=
|
||||
WITH_PUSH ?= false
|
||||
BUILD_TYPE ?= intrenlm-dev
|
||||
|
||||
# Setup buildx flags
|
||||
ifneq ("$(USE_BUILDX)","")
|
||||
BUILD = buildx build
|
||||
ifneq ("$(BUILD_PLATFORMS)","")
|
||||
PLATFORMS_FLAG = --platform="$(BUILD_PLATFORMS)"
|
||||
endif
|
||||
endif
|
||||
# endif
|
||||
|
||||
# # Only set platforms flags if using buildx
|
||||
# ifeq ("$(WITH_PUSH)","true")
|
||||
# PUSH_FLAG = --push
|
||||
# endif
|
||||
# endif
|
||||
|
||||
ifeq ($(findstring centos,$(BASE_OS)),centos)
|
||||
DOCKERFILE_PATH ?= ./docker/Dockerfile-centos
|
||||
else
|
||||
DOCKERFILE_PATH ?= ./docker/Dockerfile-ubuntu
|
||||
endif
|
||||
|
||||
#use -f to specify dockerfile
|
||||
DOCKER_BUILD = DOCKER_BUILDKIT=1 \
|
||||
docker $(BUILD) \
|
||||
--progress=$(BUILD_PROGRESS) \
|
||||
$(EXTRA_DOCKER_BUILD_FLAGS) \
|
||||
$(PLATFORMS_FLAG) \
|
||||
$(PUSH_FLAG) \
|
||||
-f $(DOCKERFILE_PATH) \
|
||||
-t $(DOCKER_FULL_NAME):$(DOCKER_TAG) \
|
||||
$(BUILD_ARGS) .
|
||||
|
||||
# --target $(BUILD_TYPE)
|
||||
|
||||
.PHONY: all
|
||||
all: devel-image
|
||||
|
||||
.PHONY: devel-image
|
||||
devel-image: BASE_IMAGE := $(BASE_DEVEL)
|
||||
devel-image: DOCKER_TAG := torch${PYTORCH_VERSION}-cuda${CUDA_VERSION}-flashatten${FLASH_ATTEN_VERSION}-${BASE_OS}
|
||||
devel-image:
|
||||
$(DOCKER_BUILD)
|
||||
|
||||
.PHONY: clean
|
||||
clean:
|
||||
-docker rmi -f $(shell docker images -q $(DOCKER_FULL_NAME))
|
|
@ -0,0 +1,131 @@
|
|||
ARG BASE_IMAGE
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
|
||||
##############################################################################
|
||||
# Install the basic environment on centos
|
||||
##############################################################################
|
||||
FROM ${BASE_IMAGE} as base
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
RUN yum install deltarpm -y && yum update -y \
|
||||
&& yum install -y \
|
||||
ca-certificates \
|
||||
cmake \
|
||||
curl \
|
||||
git \
|
||||
wget \
|
||||
tar \
|
||||
m4 \
|
||||
bzip2 \
|
||||
gcc \
|
||||
gcc-c++ \
|
||||
file \
|
||||
texinfo \
|
||||
which
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Install the conda environment
|
||||
##############################################################################
|
||||
FROM base as conda
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ARG TARGETPLATFORM
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") MINICONDA_ARCH=aarch64 ;; \
|
||||
*) MINICONDA_ARCH=x86_64 ;; \
|
||||
esac && \
|
||||
curl -fsSL -v -o ~/miniconda.sh -O "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh"
|
||||
|
||||
RUN chmod +x ~/miniconda.sh && \
|
||||
bash ~/miniconda.sh -b -p /opt/conda && \
|
||||
rm ~/miniconda.sh && \
|
||||
/opt/conda/bin/conda install -y python=${PYTHON_VERSION} cmake conda-build pyyaml numpy ipython && \
|
||||
/opt/conda/bin/conda clean -ya
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Install environment dependencies
|
||||
##############################################################################
|
||||
FROM conda as dep
|
||||
WORKDIR /dep
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
ARG GMP_VERSION
|
||||
ARG MPFR_VERSION
|
||||
ARG MPC_VERSION
|
||||
RUN wget https://ftp.gnu.org/gnu/gmp/gmp-${GMP_VERSION}.tar.bz2 \
|
||||
&& tar -vxf gmp-${GMP_VERSION}.tar.bz2 \
|
||||
&& cd gmp-${GMP_VERSION}/ \
|
||||
&& ./configure --prefix=/usr/local/gmp-${GMP_VERSION} \
|
||||
&& make -j64 && make install \
|
||||
&& cd .. \
|
||||
&& wget https://ftp.gnu.org/gnu/mpfr/mpfr-${MPFR_VERSION}.tar.gz \
|
||||
&& tar -vxf mpfr-${MPFR_VERSION}.tar.gz \
|
||||
&& cd mpfr-${MPFR_VERSION}/ \
|
||||
&& ./configure --prefix=/usr/local/mpfr-${MPFR_VERSION} --with-gmp=/usr/local/gmp-${GMP_VERSION} \
|
||||
&& make -j64 && make install \
|
||||
&& cd .. \
|
||||
&& wget http://www.multiprecision.org/downloads/mpc-${MPC_VERSION}.tar.gz \
|
||||
&& tar -vxf mpc-${MPC_VERSION}.tar.gz \
|
||||
&& cd mpc-${MPC_VERSION}/ \
|
||||
&& ./configure --prefix=/usr/local/mpc-${MPC_VERSION} --with-gmp=/usr/local/gmp-${GMP_VERSION} --with-mpfr=/usr/local/mpfr-${MPFR_VERSION} \
|
||||
&& make -j64 && make install \
|
||||
&& cd .. \
|
||||
&& git clone https://github.com/ninja-build/ninja.git \
|
||||
&& cd ninja \
|
||||
&& git checkout release \
|
||||
&& ./configure.py --bootstrap \
|
||||
&& mv ./ninja /usr/bin \
|
||||
&& cd ..
|
||||
|
||||
ENV MPFR_HOME=/usr/local/mpfr-${MPFR_VERSION}
|
||||
ENV LD_LIBRARY_PATH=${MPFR_HOME}/lib:$LD_LIBRARY_PATH
|
||||
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
ARG GCC_VERSION
|
||||
ARG GMP_VERSION
|
||||
ARG MPFR_VERSION
|
||||
ARG MPC_VERSION
|
||||
RUN wget https://ftp.gnu.org/gnu/gcc/gcc-${GCC_VERSION}/gcc-${GCC_VERSION}.tar.xz \
|
||||
&& tar -vxf gcc-${GCC_VERSION}.tar.xz \
|
||||
&& mkdir build \
|
||||
&& cd build/ \
|
||||
&& ../gcc-${GCC_VERSION}/configure --prefix=/usr/local/gcc-${GCC_VERSION}/ --enable-threads=posix --disable-checking --enable-languages=c,c++ --disable-multilib \
|
||||
--with-gmp=/usr/local/gmp-${GMP_VERSION} --with-mpfr=/usr/local/mpfr-${MPFR_VERSION} --with-mpc=/usr/local/mpc-${MPC_VERSION} \
|
||||
&& make -j64 && make install
|
||||
|
||||
ENV GCC_HOME=/usr/local/gcc-${GCC_VERSION}
|
||||
ENV LD_LIBRARY_PATH=${GCC_HOME}/lib64:${CUDA_PATH}/lib64:$LD_LIBRARY_PATH
|
||||
ENV PATH=${GCC_HOME}/bin:${CUDA_PATH}/bin:$PATH
|
||||
ENV CC=${GCC_HOME}/bin/gcc
|
||||
ENV CXX=${GCC_HOME}/bin/c++
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Install InternLM development environment, including flash-attention and apex
|
||||
##############################################################################
|
||||
FROM dep as intrenlm-dev
|
||||
COPY . /InternLM
|
||||
WORKDIR /InternLM
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
ARG TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX"
|
||||
RUN git submodule update --init --recursive \
|
||||
&& /opt/conda/bin/pip --no-cache-dir install -r requirements/torch.txt \
|
||||
&& /opt/conda/bin/pip --no-cache-dir install -r requirements/runtime.txt \
|
||||
&& cd /InternLM/third_party/flash-attention \
|
||||
&& /opt/conda/bin/python setup.py install \
|
||||
&& cd ./csrc \
|
||||
&& cd fused_dense_lib && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../xentropy && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../rotary && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../layer_norm && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../../../../ \
|
||||
&& cd ./third_party/apex \
|
||||
&& /opt/conda/bin/pip --no-cache-dir install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ \
|
||||
&& /opt/conda/bin/pip cache purge \
|
||||
&& rm -rf ~/.cache/pip
|
|
@ -0,0 +1,112 @@
|
|||
ARG BASE_IMAGE
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
|
||||
##############################################################################
|
||||
# Install the basic environment on ubuntu
|
||||
##############################################################################
|
||||
FROM ${BASE_IMAGE} as base
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
ca-certificates \
|
||||
cmake \
|
||||
curl \
|
||||
git \
|
||||
wget \
|
||||
tar \
|
||||
m4 \
|
||||
ninja-build
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Install the conda environment
|
||||
##############################################################################
|
||||
FROM base as conda
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ARG TARGETPLATFORM
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") MINICONDA_ARCH=aarch64 ;; \
|
||||
*) MINICONDA_ARCH=x86_64 ;; \
|
||||
esac && \
|
||||
curl -fsSL -v -o ~/miniconda.sh -O "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh"
|
||||
|
||||
RUN chmod +x ~/miniconda.sh && \
|
||||
bash ~/miniconda.sh -b -p /opt/conda && \
|
||||
rm ~/miniconda.sh && \
|
||||
/opt/conda/bin/conda install -y python=${PYTHON_VERSION} cmake conda-build pyyaml numpy ipython && \
|
||||
/opt/conda/bin/conda clean -ya
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Install environment dependencies
|
||||
##############################################################################
|
||||
FROM conda as dep
|
||||
WORKDIR /dep
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
ARG GCC_VERSION
|
||||
ARG GMP_VERSION
|
||||
ARG MPFR_VERSION
|
||||
ARG MPC_VERSION
|
||||
RUN wget https://ftp.gnu.org/gnu/gmp/gmp-${GMP_VERSION}.tar.bz2 \
|
||||
&& tar -vxf gmp-${GMP_VERSION}.tar.bz2 \
|
||||
&& cd gmp-${GMP_VERSION}/ \
|
||||
&& ./configure --prefix=/usr/local/gmp-${GMP_VERSION} \
|
||||
&& make -j64 && make install \
|
||||
&& cd .. \
|
||||
&& wget https://ftp.gnu.org/gnu/mpfr/mpfr-${MPFR_VERSION}.tar.gz \
|
||||
&& tar -vxf mpfr-${MPFR_VERSION}.tar.gz \
|
||||
&& cd mpfr-${MPFR_VERSION}/ \
|
||||
&& ./configure --prefix=/usr/local/mpfr-${MPFR_VERSION} --with-gmp=/usr/local/gmp-${GMP_VERSION} \
|
||||
&& make -j64 && make install \
|
||||
&& cd .. \
|
||||
&& wget http://www.multiprecision.org/downloads/mpc-${MPC_VERSION}.tar.gz \
|
||||
&& tar -vxf mpc-${MPC_VERSION}.tar.gz \
|
||||
&& cd mpc-${MPC_VERSION}/ \
|
||||
&& ./configure --prefix=/usr/local/mpc-${MPC_VERSION} --with-gmp=/usr/local/gmp-${GMP_VERSION} --with-mpfr=/usr/local/mpfr-${MPFR_VERSION} \
|
||||
&& make -j64 && make install \
|
||||
&& cd .. \
|
||||
&& wget https://ftp.gnu.org/gnu/gcc/gcc-${GCC_VERSION}/gcc-${GCC_VERSION}.tar.xz \
|
||||
&& tar -vxJf gcc-${GCC_VERSION}.tar.xz \
|
||||
&& mkdir build \
|
||||
&& cd build/ \
|
||||
&& ../gcc-${GCC_VERSION}/configure --prefix=/usr/local/gcc-${GCC_VERSION}/ --enable-checking=release --enable-languages=c,c++ --disable-multilib \
|
||||
--with-gmp=/usr/local/gmp-${GMP_VERSION} --with-mpfr=/usr/local/mpfr-${MPFR_VERSION} --with-mpc=/usr/local/mpc-${MPC_VERSION} \
|
||||
&& make -j64 && make install
|
||||
|
||||
ENV GCC_HOME=/usr/local/gcc-${GCC_VERSION}
|
||||
ENV MPFR_HOME=/usr/local/mpfr-${MPFR_VERSION}
|
||||
ENV LD_LIBRARY_PATH=${GCC_HOME}/lib64:${MPFR_HOME}/lib:${CUDA_PATH}/lib64:$LD_LIBRARY_PATH
|
||||
ENV PATH=${GCC_HOME}/bin:${CUDA_PATH}/bin:$PATH
|
||||
ENV CC=${GCC_HOME}/bin/gcc
|
||||
ENV CXX=${GCC_HOME}/bin/c++
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Install InternLM development environment, including flash-attention and apex
|
||||
##############################################################################
|
||||
FROM dep as intrenlm-dev
|
||||
COPY . /InternLM
|
||||
WORKDIR /InternLM
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
ARG TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX"
|
||||
RUN git submodule update --init --recursive \
|
||||
&& /opt/conda/bin/pip --no-cache-dir install -r requirements/torch.txt \
|
||||
&& /opt/conda/bin/pip --no-cache-dir install -r requirements/runtime.txt \
|
||||
&& cd /InternLM/third_party/flash-attention \
|
||||
&& /opt/conda/bin/python setup.py install \
|
||||
&& cd ./csrc \
|
||||
&& cd fused_dense_lib && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../xentropy && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../rotary && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../layer_norm && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../../../../ \
|
||||
&& cd ./third_party/apex \
|
||||
&& /opt/conda/bin/pip --no-cache-dir install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ \
|
||||
&& /opt/conda/bin/pip cache purge \
|
||||
&& rm -rf ~/.cache/pip
|
|
@ -0,0 +1,161 @@
|
|||
ARG BASE_IMAGE
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
|
||||
##############################################################################
|
||||
# Install the basic environment on centos
|
||||
##############################################################################
|
||||
FROM ${BASE_IMAGE} as base
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
RUN yum install deltarpm -y && yum update -y \
|
||||
&& yum install -y \
|
||||
ca-certificates \
|
||||
cmake \
|
||||
curl \
|
||||
git \
|
||||
wget \
|
||||
tar \
|
||||
m4 \
|
||||
bzip2 \
|
||||
gcc \
|
||||
gcc-c++ \
|
||||
file \
|
||||
texinfo \
|
||||
which
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Install the conda environment
|
||||
##############################################################################
|
||||
FROM base as conda
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ARG TARGETPLATFORM
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") MINICONDA_ARCH=aarch64 ;; \
|
||||
*) MINICONDA_ARCH=x86_64 ;; \
|
||||
esac && \
|
||||
curl -fsSL -v -o ~/miniconda.sh -O "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh"
|
||||
|
||||
RUN chmod +x ~/miniconda.sh && \
|
||||
bash ~/miniconda.sh -b -p /opt/conda && \
|
||||
rm ~/miniconda.sh && \
|
||||
/opt/conda/bin/conda install -y python=${PYTHON_VERSION} cmake conda-build pyyaml numpy ipython && \
|
||||
/opt/conda/bin/conda clean -ya
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Install environment dependencies
|
||||
##############################################################################
|
||||
FROM conda as dep
|
||||
WORKDIR /dep
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
ARG GMP_VERSION
|
||||
ARG MPFR_VERSION
|
||||
ARG MPC_VERSION
|
||||
RUN wget https://ftp.gnu.org/gnu/gmp/gmp-${GMP_VERSION}.tar.bz2 \
|
||||
&& tar -vxf gmp-${GMP_VERSION}.tar.bz2 \
|
||||
&& cd gmp-${GMP_VERSION}/ \
|
||||
&& ./configure --prefix=/usr/local/gmp-${GMP_VERSION} \
|
||||
&& make -j64 && make install \
|
||||
&& cd .. \
|
||||
&& wget https://ftp.gnu.org/gnu/mpfr/mpfr-${MPFR_VERSION}.tar.gz \
|
||||
&& tar -vxf mpfr-${MPFR_VERSION}.tar.gz \
|
||||
&& cd mpfr-${MPFR_VERSION}/ \
|
||||
&& ./configure --prefix=/usr/local/mpfr-${MPFR_VERSION} --with-gmp=/usr/local/gmp-${GMP_VERSION} \
|
||||
&& make -j64 && make install \
|
||||
&& cd .. \
|
||||
&& wget http://www.multiprecision.org/downloads/mpc-${MPC_VERSION}.tar.gz \
|
||||
&& tar -vxf mpc-${MPC_VERSION}.tar.gz \
|
||||
&& cd mpc-${MPC_VERSION}/ \
|
||||
&& ./configure --prefix=/usr/local/mpc-${MPC_VERSION} --with-gmp=/usr/local/gmp-${GMP_VERSION} --with-mpfr=/usr/local/mpfr-${MPFR_VERSION} \
|
||||
&& make -j64 && make install \
|
||||
&& cd .. \
|
||||
&& git clone https://github.com/ninja-build/ninja.git \
|
||||
&& cd ninja \
|
||||
&& git checkout release \
|
||||
&& ./configure.py --bootstrap \
|
||||
&& mv ./ninja /usr/bin \
|
||||
&& cd ..
|
||||
|
||||
ENV MPFR_HOME=/usr/local/mpfr-${MPFR_VERSION}
|
||||
ENV LD_LIBRARY_PATH=${MPFR_HOME}/lib:$LD_LIBRARY_PATH
|
||||
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
ARG GCC_VERSION
|
||||
ARG GMP_VERSION
|
||||
ARG MPFR_VERSION
|
||||
ARG MPC_VERSION
|
||||
RUN wget https://ftp.gnu.org/gnu/gcc/gcc-${GCC_VERSION}/gcc-${GCC_VERSION}.tar.xz \
|
||||
&& tar -vxf gcc-${GCC_VERSION}.tar.xz \
|
||||
&& mkdir build \
|
||||
&& cd build/ \
|
||||
&& ../gcc-${GCC_VERSION}/configure --prefix=/usr/local/gcc-${GCC_VERSION}/ --enable-threads=posix --disable-checking --enable-languages=c,c++ --disable-multilib \
|
||||
--with-gmp=/usr/local/gmp-${GMP_VERSION} --with-mpfr=/usr/local/mpfr-${MPFR_VERSION} --with-mpc=/usr/local/mpc-${MPC_VERSION} \
|
||||
&& make -j64 && make install
|
||||
|
||||
ENV GCC_HOME=/usr/local/gcc-${GCC_VERSION}
|
||||
ENV LD_LIBRARY_PATH=${GCC_HOME}/lib64:${CUDA_PATH}/lib64:$LD_LIBRARY_PATH
|
||||
ENV PATH=${GCC_HOME}/bin:${CUDA_PATH}/bin:$PATH
|
||||
ENV CC=${GCC_HOME}/bin/gcc
|
||||
ENV CXX=${GCC_HOME}/bin/c++
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Install InternLM development environment, including flash-attention and apex
|
||||
##############################################################################
|
||||
FROM dep as intrenlm-dev
|
||||
COPY . /InternLM
|
||||
WORKDIR /InternLM
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
ARG PYTORCH_VERSION
|
||||
ARG TORCHVISION_VERSION
|
||||
ARG TORCHAUDIO_VERSION
|
||||
|
||||
RUN /opt/conda/bin/pip --no-cache-dir install \
|
||||
transformers==4.29.2 \
|
||||
sentencepiece \
|
||||
numpy \
|
||||
tqdm \
|
||||
psutil \
|
||||
packaging \
|
||||
pre-commit \
|
||||
ninja \
|
||||
gputil \
|
||||
pytest \
|
||||
packaging \
|
||||
boto3 \
|
||||
botocore \
|
||||
torch-scatter \
|
||||
pyecharts \
|
||||
-f https://data.pyg.org/whl/torch-${PYTORCH_VERSION}+cu117.html \
|
||||
&& /opt/conda/bin/pip --no-cache-dir install \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu117 \
|
||||
torch==${PYTORCH_VERSION}+cu117 \
|
||||
torchvision==${TORCHVISION_VERSION}+cu117 \
|
||||
torchaudio==${TORCHAUDIO_VERSION}
|
||||
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
ARG TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX"
|
||||
ARG FLASH_ATTEN_TAG
|
||||
|
||||
RUN git submodule update --init --recursive \
|
||||
&& cd /InternLM/third_party/flash-attention \
|
||||
&& git checkout ${FLASH_ATTEN_TAG} \
|
||||
&& /opt/conda/bin/python setup.py install \
|
||||
&& cd ./csrc \
|
||||
&& cd fused_dense_lib && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../xentropy && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../rotary && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../layer_norm && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../../../../ \
|
||||
&& cd ./third_party/apex \
|
||||
&& /opt/conda/bin/pip --no-cache-dir install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ \
|
||||
&& /opt/conda/bin/pip cache purge \
|
||||
&& rm -rf ~/.cache/pip
|
|
@ -0,0 +1,142 @@
|
|||
ARG BASE_IMAGE
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
|
||||
##############################################################################
|
||||
# Install the basic environment on ubuntu
|
||||
##############################################################################
|
||||
FROM ${BASE_IMAGE} as base
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
ca-certificates \
|
||||
cmake \
|
||||
curl \
|
||||
git \
|
||||
wget \
|
||||
tar \
|
||||
m4 \
|
||||
ninja-build
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Install the conda environment
|
||||
##############################################################################
|
||||
FROM base as conda
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ARG TARGETPLATFORM
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") MINICONDA_ARCH=aarch64 ;; \
|
||||
*) MINICONDA_ARCH=x86_64 ;; \
|
||||
esac && \
|
||||
curl -fsSL -v -o ~/miniconda.sh -O "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh"
|
||||
|
||||
RUN chmod +x ~/miniconda.sh && \
|
||||
bash ~/miniconda.sh -b -p /opt/conda && \
|
||||
rm ~/miniconda.sh && \
|
||||
/opt/conda/bin/conda install -y python=${PYTHON_VERSION} cmake conda-build pyyaml numpy ipython && \
|
||||
/opt/conda/bin/conda clean -ya
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Install environment dependencies
|
||||
##############################################################################
|
||||
FROM conda as dep
|
||||
WORKDIR /dep
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
ARG GCC_VERSION
|
||||
ARG GMP_VERSION
|
||||
ARG MPFR_VERSION
|
||||
ARG MPC_VERSION
|
||||
RUN wget https://ftp.gnu.org/gnu/gmp/gmp-${GMP_VERSION}.tar.bz2 \
|
||||
&& tar -vxf gmp-${GMP_VERSION}.tar.bz2 \
|
||||
&& cd gmp-${GMP_VERSION}/ \
|
||||
&& ./configure --prefix=/usr/local/gmp-${GMP_VERSION} \
|
||||
&& make -j64 && make install \
|
||||
&& cd .. \
|
||||
&& wget https://ftp.gnu.org/gnu/mpfr/mpfr-${MPFR_VERSION}.tar.gz \
|
||||
&& tar -vxf mpfr-${MPFR_VERSION}.tar.gz \
|
||||
&& cd mpfr-${MPFR_VERSION}/ \
|
||||
&& ./configure --prefix=/usr/local/mpfr-${MPFR_VERSION} --with-gmp=/usr/local/gmp-${GMP_VERSION} \
|
||||
&& make -j64 && make install \
|
||||
&& cd .. \
|
||||
&& wget http://www.multiprecision.org/downloads/mpc-${MPC_VERSION}.tar.gz \
|
||||
&& tar -vxf mpc-${MPC_VERSION}.tar.gz \
|
||||
&& cd mpc-${MPC_VERSION}/ \
|
||||
&& ./configure --prefix=/usr/local/mpc-${MPC_VERSION} --with-gmp=/usr/local/gmp-${GMP_VERSION} --with-mpfr=/usr/local/mpfr-${MPFR_VERSION} \
|
||||
&& make -j64 && make install \
|
||||
&& cd .. \
|
||||
&& wget https://ftp.gnu.org/gnu/gcc/gcc-${GCC_VERSION}/gcc-${GCC_VERSION}.tar.xz \
|
||||
&& tar -vxJf gcc-${GCC_VERSION}.tar.xz \
|
||||
&& mkdir build \
|
||||
&& cd build/ \
|
||||
&& ../gcc-${GCC_VERSION}/configure --prefix=/usr/local/gcc-${GCC_VERSION}/ --enable-checking=release --enable-languages=c,c++ --disable-multilib \
|
||||
--with-gmp=/usr/local/gmp-${GMP_VERSION} --with-mpfr=/usr/local/mpfr-${MPFR_VERSION} --with-mpc=/usr/local/mpc-${MPC_VERSION} \
|
||||
&& make -j64 && make install
|
||||
|
||||
ENV GCC_HOME=/usr/local/gcc-${GCC_VERSION}
|
||||
ENV MPFR_HOME=/usr/local/mpfr-${MPFR_VERSION}
|
||||
ENV LD_LIBRARY_PATH=${GCC_HOME}/lib64:${MPFR_HOME}/lib:${CUDA_PATH}/lib64:$LD_LIBRARY_PATH
|
||||
ENV PATH=${GCC_HOME}/bin:${CUDA_PATH}/bin:$PATH
|
||||
ENV CC=${GCC_HOME}/bin/gcc
|
||||
ENV CXX=${GCC_HOME}/bin/c++
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Install InternLM development environment, including flash-attention and apex
|
||||
##############################################################################
|
||||
FROM dep as intrenlm-dev
|
||||
COPY . /InternLM
|
||||
WORKDIR /InternLM
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
ARG PYTORCH_VERSION
|
||||
ARG TORCHVISION_VERSION
|
||||
ARG TORCHAUDIO_VERSION
|
||||
|
||||
RUN /opt/conda/bin/pip --no-cache-dir install \
|
||||
transformers==4.29.2 \
|
||||
sentencepiece \
|
||||
numpy \
|
||||
tqdm \
|
||||
psutil \
|
||||
packaging \
|
||||
pre-commit \
|
||||
ninja \
|
||||
gputil \
|
||||
pytest \
|
||||
packaging \
|
||||
boto3 \
|
||||
botocore \
|
||||
torch-scatter \
|
||||
pyecharts \
|
||||
-f https://data.pyg.org/whl/torch-${PYTORCH_VERSION}+cu117.html \
|
||||
&& /opt/conda/bin/pip --no-cache-dir install \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu117 \
|
||||
torch==${PYTORCH_VERSION}+cu117 \
|
||||
torchvision==${TORCHVISION_VERSION}+cu117 \
|
||||
torchaudio==${TORCHAUDIO_VERSION}
|
||||
|
||||
ARG https_proxy
|
||||
ARG http_proxy
|
||||
ARG TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX"
|
||||
ARG FLASH_ATTEN_TAG
|
||||
|
||||
RUN git submodule update --init --recursive \
|
||||
&& cd /InternLM/third_party/flash-attention \
|
||||
&& git checkout ${FLASH_ATTEN_TAG} \
|
||||
&& /opt/conda/bin/python setup.py install \
|
||||
&& cd ./csrc \
|
||||
&& cd fused_dense_lib && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../xentropy && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../rotary && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../layer_norm && /opt/conda/bin/pip install -v . \
|
||||
&& cd ../../../../ \
|
||||
&& cd ./third_party/apex \
|
||||
&& /opt/conda/bin/pip --no-cache-dir install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ \
|
||||
&& /opt/conda/bin/pip cache purge \
|
||||
&& rm -rf ~/.cache/pip
|
|
@ -0,0 +1,25 @@
|
|||
## 实验性环境镜像
|
||||
本模块用于测试新版本环境,默认测试新环境 torch=2.0.1,flash-attention=2.1.0。新环境可能具有不稳定性,标准环境安装请参考:[安装文档](../doc/install.md)
|
||||
|
||||
### 镜像构建及拉取
|
||||
构建镜像时请于 InternLM 根目录下执行 docker.Makefile,该文件与标准环境镜像共用,所使用的 Dockerfile 位于 experiment 目录下。也可直接从 https://hub.docker.com/r/internlm/internlm 拉取镜像,命令如下:
|
||||
```bash
|
||||
# 构建镜像
|
||||
# ubuntu20.04
|
||||
make -f docker.Makefile BASE_OS=ubuntu20.04 DOCKERFILE_PATH=./experiment/Dockerfile-ubuntu PYTORCH_VERSION=2.0.1 TORCHVISION_VERSION=0.15.2 TORCHAUDIO_VERSION=2.0.2 FLASH_ATTEN_VERSION=2.1.0
|
||||
# centos7
|
||||
make -f docker.Makefile BASE_OS=centos7 DOCKERFILE_PATH=./experiment/Dockerfile-centos PYTORCH_VERSION=2.0.1 TORCHVISION_VERSION=0.15.2 TORCHAUDIO_VERSION=2.0.2 FLASH_ATTEN_VERSION=2.1.0
|
||||
|
||||
# 拉取镜像
|
||||
# ubuntu20.04
|
||||
docker pull internlm/internlm:experiment-torch2.0.1-flashatten2.1.0-ubuntu20.04
|
||||
# centos7
|
||||
docker pull internlm/internlm:experiment-torch2.0.1-flashatten2.1.0-centos7
|
||||
```
|
||||
|
||||
### 容器启动
|
||||
对于使用 dockerfile 构建或拉取的本地标准镜像,使用如下命令启动并进入容器:
|
||||
```bash
|
||||
docker run --gpus all -it -m 500g --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm-size 20g --network=host --name myinternlm internlm/internlm:experiment-torch2.0.1-flashatten2.1.0-centos7 bash
|
||||
```
|
||||
容器内默认目录即 `/InternLM`,根据[使用文档](../doc/usage.md)即可启动训练。
|
|
@ -0,0 +1,25 @@
|
|||
## Environment Image for experiment
|
||||
This module is used to test the new version environment, the default test new environment is torch=2.0.1, flash-attention=2.1.0. The new environment may be unstable, for the standard environment installation please refer to: [installation guide](../doc/en/install.md)
|
||||
|
||||
### Build and Pull Image
|
||||
When building the image, please make docker.Makefile in the InternLM root directory. This Makefile is shared with the standard environment image, and the Dockerfile used is located in the experiment directory. You can also pull the image directly from https://hub.docker.com/r/internlm/internlm, the command is as follows:
|
||||
```bash
|
||||
# Build Image
|
||||
# ubuntu20.04
|
||||
make -f docker.Makefile BASE_OS=ubuntu20.04 DOCKERFILE_PATH=./experiment/Dockerfile-ubuntu PYTORCH_VERSION=2.0.1 TORCHVISION_VERSION=0.15.2 TORCHAUDIO_VERSION=2.0.2 FLASH_ATTEN_VERSION=2.1.0
|
||||
# centos7
|
||||
make -f docker.Makefile BASE_OS=centos7 DOCKERFILE_PATH=./experiment/Dockerfile-centos PYTORCH_VERSION=2.0.1 TORCHVISION_VERSION=0.15.2 TORCHAUDIO_VERSION=2.0.2 FLASH_ATTEN_VERSION=2.1.0
|
||||
|
||||
# Pull Image
|
||||
# ubuntu20.04
|
||||
docker pull internlm/internlm:experiment-torch2.0.1-flashatten2.1.0-ubuntu20.04
|
||||
# centos7
|
||||
docker pull internlm/internlm:experiment-torch2.0.1-flashatten2.1.0-centos7
|
||||
```
|
||||
|
||||
### Run Container
|
||||
For the local standard image built with dockerfile or pulled, use the following command to run and enter the container:
|
||||
```bash
|
||||
docker run --gpus all -it -m 500g --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm-size 20g --network=host --name myinternlm internlm/internlm:experiment-torch2.0.1-flashatten2.1.0-centos7 bash
|
||||
```
|
||||
The default directory in the container is `/InternLM`, please start training according to the [Usage](../doc/en/usage.md).
|
|
@ -78,8 +78,9 @@ class TrainState:
|
|||
self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1
|
||||
|
||||
# track the actual updates of sampler when using weighted sampling
|
||||
self.batch_sampler = train_dl.batch_sampler.copy()
|
||||
self.batch_sampler_iter = iter(self.batch_sampler)
|
||||
if hasattr(self, "batch_sampler"):
|
||||
self.batch_sampler = train_dl.batch_sampler.copy()
|
||||
self.batch_sampler_iter = iter(self.batch_sampler)
|
||||
|
||||
# resume tensorboard from older tensorboard_folder
|
||||
self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
|
||||
|
|
|
@ -3,16 +3,15 @@
|
|||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
def scaled_init_method_normal(sigma, num_layers):
|
||||
def scaled_init_method_normal(sigma: float = 1.0, num_layers: int = 1):
|
||||
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
|
||||
std = sigma / math.sqrt(2.0 * num_layers)
|
||||
|
||||
def init_(tensor):
|
||||
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
|
||||
return nn.init.normal_(tensor, mean=0.0, std=std)
|
||||
|
||||
return init_
|
||||
|
||||
|
@ -32,3 +31,33 @@ def normal_(mean: float = 0.0, std: float = 1.0):
|
|||
return nn.init.normal_(tensor, mean, std)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def scaled_init_method_uniform(sigma: float = 1.0, num_layers: int = 1):
|
||||
"""Init method based on p(x)=Uniform(-a, a) where std(x)=sigma/sqrt(2*num_layers)."""
|
||||
std = sigma / math.sqrt(2.0 * num_layers)
|
||||
a = math.sqrt(3.0 * std)
|
||||
|
||||
def init_(tensor):
|
||||
return nn.init.uniform_(tensor, -a, a)
|
||||
|
||||
return init_
|
||||
|
||||
|
||||
def uniform_(mean: float = 0.0, std: float = 1.0):
|
||||
r"""Return the initializer filling the input Tensor with values drawn from the uniform distribution
|
||||
|
||||
.. math::
|
||||
\mathcal{U}(mean-a, mean+a), where a satisfies \mathcal{U}_{std}=std.
|
||||
|
||||
Args:
|
||||
mean (float): the mean of the uniform distribution. Defaults 0.0.
|
||||
std (float): the standard deviation of the uniform distribution. Defaults 1.0.
|
||||
"""
|
||||
|
||||
a = math.sqrt(3.0 * std)
|
||||
|
||||
def initializer(tensor: Tensor):
|
||||
return nn.init.uniform_(tensor, mean - a, mean + a)
|
||||
|
||||
return initializer
|
||||
|
|
|
@ -176,7 +176,7 @@ class AccPerplex:
|
|||
res.update(ds_acc)
|
||||
res.update(ds_tokens)
|
||||
|
||||
loss_res = self.loss_with_type_id.get_metric()
|
||||
loss_res = self.loss_with_type_id.get_metric(reset)
|
||||
res.update(loss_res)
|
||||
|
||||
return res
|
||||
|
|
|
@ -76,7 +76,7 @@ def evaluate_on_val_dls(
|
|||
data_cfg = gpc.config.data
|
||||
|
||||
for val_name, val_dl in val_dls.items():
|
||||
if len(val_dl) == 0 and verbose and not streaming:
|
||||
if not streaming and len(val_dl) == 0 and verbose:
|
||||
logger.info(f"Validation dataset: {val_name} is empty")
|
||||
continue
|
||||
|
||||
|
|
|
@ -13,4 +13,4 @@ boto3
|
|||
botocore
|
||||
torch-scatter
|
||||
pyecharts
|
||||
-f https://data.pyg.org/whl/torch-1.13.0+cu117.html
|
||||
-f https://data.pyg.org/whl/torch-1.13.1+cu117.html
|
Loading…
Reference in New Issue