mirror of https://github.com/InternLM/InternLM
				
				
				
			
		
			
				
	
	
		
			352 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			352 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
| #!/usr/bin/env python
 | |
| # -*- encoding: utf-8 -*-
 | |
| 
 | |
| import hashlib
 | |
| import io
 | |
| import os
 | |
| import re
 | |
| import socket
 | |
| from enum import Enum
 | |
| from typing import Any, Dict, List, Union
 | |
| 
 | |
| import boto3
 | |
| import botocore
 | |
| import torch
 | |
| 
 | |
| from internlm.utils.common import SingletonMeta
 | |
| from internlm.utils.logger import get_logger
 | |
| 
 | |
logger = get_logger(__file__)

# Matches the host part of an s3 url, "<bucket>.<dotted-number endpoint>", e.g.
# "mybucket.10.140.0.1": group(1) is the bucket name, group(2) the endpoint.
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")

# One mebibyte (1024 ** 2), used for the boto3 multipart transfer sizes.
MB = 1 << 20

# Placeholder; rebound to a StorageManager instance at the bottom of this module.
storage_manager = None
 | |
| 
 | |
| 
 | |
def check_folder(fp: str):
    """Assert that ``fp`` exists on the storage backend selected by its prefix.

    Args:
        fp (str): Backend-prefixed path, e.g. ``local:/path/to/ckpt``.
    """
    storage_manager.assert_fp_exists(folder=fp)
 | |
| 
 | |
| 
 | |
def get_fns(fp: str):
    """Return the entry names found under ``fp`` on its storage backend.

    Args:
        fp (str): Backend-prefixed folder path, e.g. ``local:/path/to/ckpt``.
    """
    names = storage_manager.get_fns(folder=fp)
    return names
 | |
| 
 | |
| 
 | |
def llm_load(fp: str, *args, **kwargs):
    """Load and return the object stored at the backend-prefixed path ``fp``.

    Extra positional/keyword arguments are forwarded to ``StorageManager.load``
    (which accepts ``map_location``, defaulting to "cpu").
    """
    loaded = storage_manager.load(fp, *args, **kwargs)
    return loaded
 | |
| 
 | |
| 
 | |
def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
    """Save ``saved_obj`` to the backend-prefixed path ``save_path``.

    Args:
        save_path (str): Destination such as ``local:/ckpt/model.pt``.
        saved_obj (Any): Object to serialize.
        *args, **kwargs: Forwarded to ``StorageManager.save`` and from there to the
            backend's ``sync_upload_fileobj``, which passes them to ``torch.save``.
    """
    # Pass ``saved_obj`` positionally so it binds to StorageManager.save's second
    # parameter. Forwarding ``*args`` ahead of a ``saved_obj=`` keyword would bind
    # args[0] to ``saved_obj`` and raise "got multiple values for argument".
    storage_manager.save(save_path, saved_obj, *args, **kwargs)
 | |
| 
 | |
| 
 | |
class CheckpointType(Enum):
    """Kinds of checkpoint handled by the storage layer (currently just one)."""

    NORMAL_CHECKPOINT = 1
 | |
| 
 | |
| 
 | |
class StorageClient:
    """
    Base class for storage backend clients (e.g. Boto3Client, LocalClient).

    Backend operations are static methods whose first argument is the client
    instance (``handler``); the remaining positional arguments are the values of
    the backend's meta-info object. Every operation here raises
    NotImplementedError and must be overridden by a concrete backend.
    """

    def __init__(self, handler) -> None:
        # Backend-specific handle, e.g. the ``boto3`` module for Boto3Client or
        # ``None`` for LocalClient.
        self.handler = handler

    @staticmethod
    def load(client, load_path: str, map_location):
        """Load and return the object stored at ``load_path``."""
        raise NotImplementedError

    @staticmethod
    def sync_upload_fileobj(*args, saved_obj=None, **kwargs):
        """Serialize ``saved_obj`` and store it synchronously."""
        raise NotImplementedError

    @staticmethod
    def assert_fp_exists(client):
        """Assert that the given path exists on the backend."""
        raise NotImplementedError

    @staticmethod
    def get_fns(client):
        """Return the entry names under the given path."""
        raise NotImplementedError

    @staticmethod
    def delete_obj(client, *args, **kwargs):
        """Delete the object at the given path.

        Declared here so that backends lacking support fail with
        NotImplementedError rather than AttributeError.
        """
        raise NotImplementedError
 | |
| 
 | |
| 
 | |
class Boto3MetaInfo:
    """Location of an object in an S3-compatible bucket, parsed from a boto3 path.

    NOTE: ``unpack_meta`` turns an instance into positional arguments by iterating
    ``__dict__`` (skipping ``endpoint``), so the attribute assignment order below
    is part of the contract: (client, bucket_name, file_path) must line up with the
    parameters of the Boto3Client static methods.
    """

    def __init__(
        self,
        client: Union["StorageClient", None],
        bucket_name: str,
        endpoint: str,
        file_path: str,
    ) -> None:
        # Created as None by get_boto3_meta; the cached client is attached later
        # by StorageManager._get_client.
        self.client = client
        self.bucket_name = bucket_name
        # Used to select/construct the client; not passed to client methods.
        self.endpoint = endpoint
        # Object key inside the bucket.
        self.file_path = file_path
 | |
| 
 | |
| 
 | |
class LocalMetaInfo:
    """Location of a file on the local filesystem, parsed from a ``local:`` path.

    NOTE: ``unpack_meta`` turns an instance into positional arguments by iterating
    ``__dict__``, so the attribute assignment order (client, dest_path) must match
    the parameters of the LocalClient static methods.
    """

    def __init__(self, client: Union["StorageClient", None], dest_path: str) -> None:
        # Created as None by get_local_meta; the cached client is attached later
        # by StorageManager._get_client.
        self.client = client
        self.dest_path = dest_path
 | |
| 
 | |
| 
 | |
def unpack_meta(meta):
    """Flatten a meta-info object into the positional arguments of its client's methods.

    Attribute values are returned in ``__dict__`` insertion order. ``endpoint`` is
    skipped because the client methods do not take it (it is only used when the
    client is selected/created).
    """
    return [value for name, value in meta.__dict__.items() if name != "endpoint"]
 | |
| 
 | |
| 
 | |
def compute_file_md5_by_chunk(file_name: str):
    """Return the hex MD5 digest of the file ``file_name``, read 4096 bytes at a time."""
    digest = hashlib.md5()
    with open(file_name, "rb") as stream:
        while True:
            block = stream.read(4096)
            if block == b"":
                break
            digest.update(block)
    return digest.hexdigest()
 | |
| 
 | |
| 
 | |
def get_boto3_meta(fp: str) -> Boto3MetaInfo:
    """Parse an ``s3://<bucket>.<endpoint>/<key>`` url into a Boto3MetaInfo.

    Args:
        fp (str): Url such as ``s3://mybucket.10.140.0.1/model_weights/ckpt.pt``.

    Returns:
        Boto3MetaInfo: ``client`` is None, ``endpoint`` is ``http://<endpoint>:80``
        and ``file_path`` is everything after the host part.
    """
    prefix = "s3://"
    assert fp.startswith(prefix), f"Path '{fp}' is not a boto3 url"
    # Slice the scheme off. ``str.lstrip("s3://")`` strips any leading run of the
    # characters {'s', '3', ':', '/'} and would also eat, e.g., the leading "s"
    # of a bucket named "ssd-weights".
    parts = fp[len(prefix):].split("/")
    match = boto3_url_re.match(parts[0])
    assert match is not None, f"url '{fp}' is not a valid boto3 url"
    bucket_name, endpoint = match.group(1), match.group(2)
    endpoint = "http://" + endpoint + ":80"
    # S3 keys are always "/"-separated, independent of the local OS separator.
    return Boto3MetaInfo(None, bucket_name, endpoint, "/".join(parts[1:]))
 | |
| 
 | |
| 
 | |
def get_local_meta(fp: str) -> LocalMetaInfo:
    """Wrap a local filesystem path (anything not starting with ``s3://``) in a LocalMetaInfo."""
    is_s3_url = fp.startswith("s3://")
    assert not is_s3_url, f"Path '{fp}' is not a local path"
    return LocalMetaInfo(client=None, dest_path=fp)
 | |
| 
 | |
| 
 | |
class Boto3Client(StorageClient):
    """
    Storage client for S3-compatible object storage, backed by ``boto3``.

    The static operations receive this client as ``handler`` followed by the
    bucket name and the object key.
    """

    def __init__(
        self,
        s3_endpoint_url: str,
        use_threads: bool = True,
        multipart_chunksize=8 * MB,
        max_concurrency: int = 10,
        multipart_threshold=100 * MB,
    ) -> None:
        """S3 object/file storage management class.

        Credentials are read from the ``S3_ACCESS_KEY_ID`` and
        ``S3_SECRET_ACCESS_KEY_ID`` environment variables.

        Args:
            s3_endpoint_url (str): Endpoint of the S3 service, e.g. ``http://10.140.0.1:80``.
            use_threads (bool, optional): Passed to ``TransferConfig``; whether
                transfers may use threads. Defaults to True.
            multipart_chunksize (int, optional): Part size in bytes for multipart
                transfers. Defaults to 8*MB.
            max_concurrency (int, optional): Maximum number of concurrent transfer
                requests. Defaults to 10.
            multipart_threshold (int, optional): Size in bytes from which multipart
                transfers are used. Defaults to 100*MB.

        Raises:
            RuntimeError: If either credential environment variable is unset.
        """
        super().__init__(boto3)
        self.botocore = botocore
        try:
            s3_access_key_id = os.environ["S3_ACCESS_KEY_ID"]
            s3_secret_access_key = os.environ["S3_SECRET_ACCESS_KEY_ID"]
        except KeyError as exc:
            raise RuntimeError(
                "Please set boto3 bucket 'S3_ACCESS_KEY_ID' and 'S3_SECRET_ACCESS_KEY_ID' using environment variable!"
            ) from exc

        # Plain-HTTP connection (SSL and certificate verification disabled).
        self.client = self.handler.client(
            "s3",
            "",
            use_ssl=False,
            verify=False,
            endpoint_url=s3_endpoint_url,
            aws_access_key_id=s3_access_key_id,
            aws_secret_access_key=s3_secret_access_key,
        )

        self.config = self.handler.s3.transfer.TransferConfig(
            multipart_threshold=multipart_threshold,
            max_concurrency=max_concurrency,
            multipart_chunksize=multipart_chunksize,
            use_threads=use_threads,
        )

    @staticmethod
    def sync_upload_fileobj(handler, bucket_name: str, fp: str, *args, saved_obj=None, **kwargs):
        """Serialize ``saved_obj`` with ``torch.save`` in memory and upload it to ``bucket_name/fp``.

        Raises:
            RuntimeError: If the S3 endpoint cannot be reached.
        """
        assert saved_obj is not None, "saved_obj is None!"
        try:
            with io.BytesIO() as f:
                torch.save(saved_obj, f, *args, **kwargs)
                f.seek(0)
                handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
        except handler.botocore.exceptions.EndpointConnectionError as exc:
            raise RuntimeError(
                f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
            ) from exc

    @staticmethod
    def load(handler, bucket_name: str, fp: str, *args, map_location="cpu", **kwargs) -> Dict:
        """Download ``bucket_name/fp`` into memory and deserialize it with ``torch.load``.

        Args:
            fp (str): Object key inside the bucket, e.g. ``model_weights/xxx/ddd.pt``.

        Raises:
            RuntimeError: If the S3 endpoint cannot be reached.
        """
        try:
            with io.BytesIO() as f:
                handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
                f.seek(0)
                states = torch.load(f, *args, map_location=map_location, **kwargs)
        except handler.botocore.exceptions.EndpointConnectionError as exc:
            raise RuntimeError(
                f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
            ) from exc
        return states

    @staticmethod
    def assert_fp_exists(
        handler,
        bucket_name: str,
        fp: str,
    ):
        """Assert that at least one object key in ``bucket_name`` starts with ``fp``."""
        # The response has no "Contents" entry when nothing matches; default to an
        # empty list so the assertion (not a KeyError) reports the missing path.
        # A single key is enough to prove existence.
        response = handler.client.list_objects(Bucket=bucket_name, Prefix=fp, MaxKeys=1)
        assert len(response.get("Contents", [])) > 0, fp

    @staticmethod
    def get_fns(handler, bucket_name: str, fp: str):
        """
        Return the last "/"-separated component of every object key under prefix ``fp``.

        Returns an empty list when no key matches. A paginator is used because a
        single list call returns at most 1000 keys.
        Ref: https://stackoverflow.com/questions/54314563/
        how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
        """
        paginator = handler.client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)

        folder_name_list = []
        for page in pages:
            # Pages without matching keys carry no "Contents" entry.
            for obj in page.get("Contents", []):
                key: str = obj["Key"]
                # [-1] also copes with keys that contain no "/" at all.
                folder_name_list.append(key.rsplit("/", maxsplit=1)[-1])
        return folder_name_list

    @staticmethod
    def delete_obj(handler, bucket_name: str, fp: str):
        """Delete the object with key ``fp`` from ``bucket_name``.

        Raises:
            RuntimeError: If the S3 endpoint cannot be reached.
        """
        try:
            handler.client.delete_object(Bucket=bucket_name, Key=fp)
        except handler.botocore.exceptions.EndpointConnectionError as exc:
            raise RuntimeError(
                f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
            ) from exc
 | |
| 
 | |
| 
 | |
class LocalClient(StorageClient):
    """
    Storage Client for the local filesystem (including NFS mounts).
    """

    def __init__(self, *args, **kwargs) -> None:  # pylint: disable=W0613
        super().__init__(None)

    @staticmethod
    def sync_upload_fileobj(handler, fp: str, *args, saved_obj=None, **kwargs):
        """Save ``saved_obj`` to ``fp`` with ``torch.save``, creating parent directories as needed."""
        assert isinstance(handler, LocalClient)
        assert saved_obj is not None
        fp_dirname = os.path.dirname(fp)
        # dirname is "" for a bare file name (current directory); os.makedirs("")
        # raises FileNotFoundError, so only create a directory when there is one.
        if fp_dirname:
            os.makedirs(fp_dirname, exist_ok=True)
        torch.save(saved_obj, fp, *args, **kwargs)

    @staticmethod
    def load(handler, fp: str, *args, map_location="cpu", **kwargs):
        """Load and return the object saved at ``fp`` using ``torch.load``."""
        assert isinstance(handler, LocalClient)
        assert os.path.exists(fp), f"{fp} is not found!"
        with open(fp, "rb") as f:
            states = torch.load(f, *args, map_location=map_location, **kwargs)
        return states

    @staticmethod
    def assert_fp_exists(handler, folder):
        """Assert that ``folder`` exists on disk."""
        assert isinstance(handler, LocalClient)
        assert os.path.exists(folder), folder

    @staticmethod
    def get_fns(handler, folder):
        """Return the entry names directly inside ``folder`` (as ``os.listdir``)."""
        assert isinstance(handler, LocalClient)
        assert os.path.exists(folder), f"folder '{folder}' not exists!"
        fns = os.listdir(folder)
        return fns

    @staticmethod
    def delete_obj(handler, fp: str):
        """Remove the file at ``fp``; directories are silently left in place."""
        assert isinstance(handler, LocalClient)
        if not os.path.isdir(fp):
            os.remove(fp)
 | |
| 
 | |
| 
 | |
class StorageManager(metaclass=SingletonMeta):
    """
    Storage Manager for saving or loading checkpoint.

    Paths carry a backend prefix: ``local:<path>`` or ``boto3:s3://...``. One
    client per backend (per endpoint for boto3) is created lazily and cached in
    the class-level ``CLI_DICT``.
    """

    BACKEND_TYPE = {"boto3", "local"}
    BACKEND_INIT_METHOD = {
        "boto3": Boto3Client,
        "local": LocalClient,
    }
    CLI_DICT = {}
    # Set once the boto3 proxy warning has been logged, so it is emitted only once.
    _proxy_warned = False

    def __init__(self) -> None:
        pass

    @staticmethod
    def _warn_if_proxy_set() -> None:
        """Log a warning (once) when HTTP/HTTPS proxy variables are set while using boto3."""
        if StorageManager._proxy_warned:
            return
        proxy_vars = ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY")
        if any(var in os.environ for var in proxy_vars):
            # Previously this *raised* RuntimeWarning, aborting every boto3 call
            # even though the condition is only a potential problem.
            logger.warning(
                "HTTP/HTTPS proxy is detected when using boto3, incorrectly setting "
                "the proxy may make boto3 unavailable or affect performance."
            )
            StorageManager._proxy_warned = True

    def _get_client(self, path: str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
        """
        Resolve a backend-prefixed path into meta info with the cached client attached.

        example:
        local:/path/to/checkpoint
        boto3:s3://model_weights.10.140.0.1/0331/120bi

        Args:
            path (str): Backend-prefixed path.

        Returns:
            LocalMetaInfo or Boto3MetaInfo whose ``client`` is the cached backend client.

        Raises:
            AttributeError: If ``path`` has no ``<backend>:`` prefix.
        """
        try:
            backend, path = path.split(":", maxsplit=1)
        except (AttributeError, ValueError) as exc:
            raise AttributeError(f"Given path '{path}' is not startwith backend prefix:'local/boto3'") from exc

        # Validate before dispatching so an unknown backend is reported here rather
        # than surfacing later as an unbound local variable.
        assert backend in StorageManager.BACKEND_TYPE, f"Unkown backend: {backend}"

        init_args = (None,)
        if backend == "local":
            meta_info = get_local_meta(path)
            backend_key = backend
        elif backend == "boto3":
            meta_info = get_boto3_meta(path)
            # boto3 clients are bound to an endpoint, so cache one per endpoint.
            backend_key = backend + ":" + meta_info.endpoint
            init_args = (meta_info.endpoint,)
            self._warn_if_proxy_set()

        if backend_key not in StorageManager.CLI_DICT:
            StorageManager.CLI_DICT[backend_key] = StorageManager.BACKEND_INIT_METHOD[backend](*init_args)

        meta_info.client = StorageManager.CLI_DICT[backend_key]
        return meta_info

    def assert_fp_exists(self, folder) -> None:
        """Assert that the backend-prefixed path ``folder`` exists."""
        meta = self._get_client(path=folder)
        meta.client.assert_fp_exists(*unpack_meta(meta))

    def get_fns(self, folder) -> List[str]:
        """Return the entry names under the backend-prefixed path ``folder``."""
        meta = self._get_client(path=folder)
        return meta.client.get_fns(*unpack_meta(meta))

    def save(self, save_path: str, saved_obj: Any, *args, **kwargs):
        """Save ``saved_obj`` to ``save_path``; extra arguments are forwarded to the backend."""
        meta = self._get_client(path=save_path)
        meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)

    def load(self, load_path: str, *args, map_location="cpu", **kwargs) -> Any:
        """Load and return the object stored at ``load_path``."""
        meta = self._get_client(path=load_path)
        return meta.client.load(*unpack_meta(meta), *args, map_location=map_location, **kwargs)

    def delete_obj(self, fp: str):
        """Delete the object at the backend-prefixed path ``fp``."""
        meta = self._get_client(path=fp)
        meta.client.delete_obj(*unpack_meta(meta))
 | |
| 
 | |
| 
 | |
# Module-level StorageManager used by check_folder, get_fns, llm_load and llm_save.
storage_manager: StorageManager = StorageManager()