mirror of https://github.com/hpcaitech/ColossalAI
remote comments
parent
7416e4943b
commit
df63db7e63
|
@ -37,18 +37,16 @@ class GeminiZeROHook(ColoParamOpHook):
|
|||
|
||||
# transfer state
|
||||
for p in params:
|
||||
# TODO(haze188): check the state transition
|
||||
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
|
||||
self._gemini_manager.sample_overall_data()
|
||||
|
||||
# evict chunks, taking asynchronously fetched chunks into account
|
||||
# TODO(haze188): chunks we prefetched may get evicted again; check this
|
||||
# TODO: check if prefetched chunks will be evicted
|
||||
self._gemini_manager.adjust_layout(
|
||||
all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
|
||||
)
|
||||
|
||||
# fetch the rest synchronously
|
||||
# TODO(haze188): 1. prefetch first or fetch first? (prefetch is asynchronous, fetch is synchronous)
|
||||
for chunk in chunks_fetch_sync:
|
||||
self._chunk_manager.access_chunk(chunk)
|
||||
|
||||
|
|
|
@ -154,7 +154,6 @@ class GeminiManager:
|
|||
|
||||
def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
|
||||
self._compute_idx += 1
|
||||
# TODO(haze188): _compute_list records the access order of chunks
|
||||
if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
|
||||
self._compute_list.append(chunks)
|
||||
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn as nn"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Linear(in_features=10, out_features=5, bias=False) 50\n",
|
||||
"Linear(in_features=5, out_features=10, bias=False) 50\n",
|
||||
"Linear(in_features=10, out_features=10, bias=False) 100\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"class Toy(nn.Module):\n",
|
||||
" \n",
|
||||
" def __init__(self):\n",
|
||||
" super(Toy, self).__init__()\n",
|
||||
" self.fc1 = nn.Linear(10,5, bias=False)\n",
|
||||
" self.m3 = nn.Sequential(nn.Linear(5, 10, bias=False), nn.Linear(10,10, bias=False))\n",
|
||||
"\n",
|
||||
"t = Toy()\n",
|
||||
"for mod in t.modules():\n",
|
||||
" for p in mod.parameters(recurse=False):\n",
|
||||
" print(mod, p.numel())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([5, 10]) 50\n",
|
||||
"torch.Size([10, 5]) 50\n",
|
||||
"torch.Size([10, 10]) 100\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for p in t.parameters():\n",
|
||||
" print(p.shape, p.numel())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'224'"
|
||||
]
|
||||
},
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"conf_str = torch.__config__.parallel_info()\n",
|
||||
"inter_str = conf_str.split(\"hardware_concurrency() : \")[1]\n",
|
||||
"max_concurrency = inter_str.split(\"\\n\")[0]\n",
|
||||
"max_concurrency"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0 0\n",
|
||||
"0 1\n",
|
||||
"0 2\n",
|
||||
"1 0\n",
|
||||
"1 1\n",
|
||||
"1 2\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for i in range(3):\n",
|
||||
" for j in range(3):\n",
|
||||
" print(i, j)\n",
|
||||
" if i == 1 and j == 2:break\n",
|
||||
" else:\n",
|
||||
" continue\n",
|
||||
" break"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "colossalai-py310",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -6,7 +6,7 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
|
|||
export GPUNUM=${GPUNUM:-1}
|
||||
export BATCH_SIZE=${BATCH_SIZE:-16}
|
||||
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
|
||||
export TRAIN_STEP=${TRAIN_STEP:-10}
|
||||
export TRAIN_STEP=${TRAIN_STEP:-2}
|
||||
# export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
|
||||
|
||||
|
|
|
@ -66,18 +66,18 @@ class GPTLMLoss(nn.Module):
|
|||
|
||||
|
||||
def get_cpu_mem():
|
||||
return psutil.Process().memory_info().rss / 1024**2
|
||||
return psutil.Process().memory_info().rss / 1024**2 # rss is in bytes; convert to MB
|
||||
|
||||
|
||||
def get_gpu_mem():
|
||||
return torch.cuda.memory_allocated() / 1024**2
|
||||
return torch.cuda.memory_allocated() / 1024**2 # convert to MB
|
||||
|
||||
|
||||
def get_mem_info(prefix=""):
|
||||
return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB"
|
||||
|
||||
|
||||
def get_model_size(model: nn.Module):
|
||||
def get_model_size(model: nn.Module): # get the number of model parameters
|
||||
total_numel = 0
|
||||
for module in model.modules():
|
||||
for p in module.parameters(recurse=False):
|
||||
|
|
|
@ -26,7 +26,7 @@ PLACEMENT_CONFIGS = [
|
|||
"offload_optim_frac": 1.0,
|
||||
"offload_param_frac": 1.0,
|
||||
}, # zero3-offload-all
|
||||
{"placement_policy": "auto"},
|
||||
# {"placement_policy": "auto"},
|
||||
]
|
||||
|
||||
# this model is large enough to slice to chunks
|
||||
|
|
Loading…
Reference in New Issue