mirror of https://github.com/hpcaitech/ColossalAI
[fix] multi graphs capture error
parent b2c0d9ff2b
commit 9dec66fad6
@@ -1,4 +1,3 @@
-import copy
 import time
 from itertools import count
 from typing import Dict, List, Optional, Tuple, Union
@@ -110,7 +109,6 @@ class InferenceEngine:
 
         t_capture_begin = time.perf_counter()
 
-
         block_size = self.inference_config.block_size
         head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
 
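As a quick worked example of the head_dim line in the hunk above (the concrete numbers are illustrative assumptions, not values taken from this commit):

# Assumed Llama-7B-like configuration; only the formula comes from the diff above.
hidden_size = 4096
num_attention_heads = 32
block_size = 16  # assumed KV-cache block size from the inference config

head_dim = hidden_size // num_attention_heads
print(head_dim)  # 4096 // 32 = 128; together with block_size this typically fixes
                 # the shape of the dummy KV-cache blocks used during graph capture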
@@ -133,7 +131,6 @@
         # NOTE: Capturing the largest batch size first may help reduce the
         # memory usage of CUDA graph.
         for batch_size in reversed(batch_size_capture_list):
-
             if self.verbose:
                 self.logger.info(f"batch size {batch_size} graph capturing")
 
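The NOTE in the last hunk (capture the largest batch size first to reduce CUDA graph memory) matches a common multi-graph capture pattern: capture one graph per batch size, in descending order, sharing a single memory pool so the smaller captures reuse the first capture's allocations. The sketch below is illustrative only and is not ColossalAI's implementation; capture_decode_graphs, model, hidden_size, and the warm-up count are assumptions.

import time
from typing import Dict, List

import torch


def capture_decode_graphs(model: torch.nn.Module,
                          batch_size_capture_list: List[int],
                          hidden_size: int,
                          device: str = "cuda"):
    """Capture one CUDA graph per batch size, largest first, sharing one memory pool."""
    graphs: Dict[int, torch.cuda.CUDAGraph] = {}
    static_inputs: Dict[int, torch.Tensor] = {}
    static_outputs: Dict[int, torch.Tensor] = {}
    memory_pool = None  # filled from the first (largest) capture, then shared

    t_capture_begin = time.perf_counter()
    for batch_size in sorted(batch_size_capture_list, reverse=True):
        static_x = torch.zeros(batch_size, hidden_size, device=device)

        # Warm up on a side stream so one-time initialization (cuBLAS handles,
        # autotuning) happens outside of capture, where it can be illegal.
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            for _ in range(3):
                model(static_x)
        torch.cuda.current_stream().wait_stream(side_stream)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=memory_pool):
            static_outputs[batch_size] = model(static_x)
        if memory_pool is None:
            # Later, smaller captures reuse the largest capture's allocations.
            memory_pool = graph.pool()

        graphs[batch_size] = graph
        static_inputs[batch_size] = static_x

    torch.cuda.synchronize()
    print(f"graph capture took {time.perf_counter() - t_capture_begin:.2f} s")
    return graphs, static_inputs, static_outputs

At run time the caller copies a real batch into static_inputs[batch_size] with copy_() and calls graphs[batch_size].replay(); the output tensors in static_outputs[batch_size] are then refreshed in place.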