fix bugs in attention.py and request_handler.py

pull/5258/head
yuehuayingxueluo 2024-01-08 12:35:06 +08:00 committed by FrankLeeeee
parent bfd9b1b494
commit 47e53eaa1c
6 changed files with 208 additions and 60 deletions

@@ -214,9 +214,6 @@ class InferenceEngine:
             List[str]: Decoded finished sequences generated by one step.
         """
-        if self.verbose:
-            self.logger.info("Running generation step")
-
         output_list = []
         batch = self.request_handler.schedule()

@@ -224,6 +221,7 @@ class InferenceEngine:
             batch,
             self.k_cahce,
             self.v_cache,
+            padding_id=self.tokenizer.pad_token_id,
         )
         logits = logits[:, -1, :]

@@ -110,6 +110,10 @@ class RequestHandler:
             self.prefill_batch.init_batch(self.running_list.prefill)
             return self.prefill_batch

+        if not self.running_batch.is_empty:
+            for seq in self.running_batch.sequences_set:
+                self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
+
         return self.running_batch

     def add_sequence(self, req: Sequence):
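Note on the added scheduling step: each decoding iteration appends exactly one token per running sequence, so the handler reserves one more KV-cache slot (and, when the last block is already full, one more block) before returning the running batch. A minimal sketch of that bookkeeping, using hypothetical names rather than the project's cache-manager API:

def allocate_token(block_table: list, sentence_len: int, block_size: int, free_blocks: list) -> None:
    """Reserve room for the newest token of a sequence that now holds `sentence_len` tokens (illustrative only)."""
    slot = (sentence_len - 1) % block_size  # position of the new token inside its block
    if slot == 0:
        # The previous block is full, so the new token opens a fresh block.
        block_table.append(free_blocks.pop())


# block_size=4: growing from 4 to 5 tokens needs a second block, 5 to 6 does not.
table, free = [0], [9, 8, 7]
allocate_token(table, 5, 4, free)
assert table == [0, 7]
allocate_token(table, 6, 4, free)
assert table == [0, 7]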

@@ -29,47 +29,50 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
             for block_idx in range(block_num - 1):
                 cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
                 token_id += block_size
-            cache[block_tables[i][block_num - 1]] = source[i][token_id:seq_len].permute(1, 2, 0)
+            cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute(
+                1, 2, 0
+            )
     elif type == "decoding":
         assert len(source[0]) == 1, "seq_len should be equal to 1 when decoding."
         source = source.squeeze(1)
         slot_idx = (lengths + block_size - 1) % block_size
         for i in range(bsz):
-            cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i].permute(0, 1)
+            cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i]
     return cache


-def convert_kvcache(source, cache, lengths, block_tables):
+def convert_kvcache(cache, lengths, block_tables):
     """
     Func: convert key/value cache for calculation
-    Args: key/value(source): shape [bsz, 1, num_heads, head_size]
-          cache: shape [num_blocks, num_heads, head_size, block_size]
+    Args: cache: shape [num_blocks, num_heads, head_size, block_size]
           lengths: key/value length
           block_tables
     """
     num_blocks, num_heads, head_size, block_size = cache.shape
     needed_blocks = (lengths + block_size - 1) // block_size
-    num_remaing_tokens = (lengths - 1) % block_size
+    num_remaing_tokens = lengths % block_size
+    num_remaing_tokens[num_remaing_tokens == 0] += block_size
     bsz = block_tables.shape[0]
     seq_len = max(lengths)

     padded_cache = []
     for i in range(bsz):
+        cache1 = cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size)
+        cache2 = cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1)
+
         _cache = torch.cat(
             (
-                cache[block_tables[i][: needed_blocks[i] - 1]].permute((3, 0, 1, 2)).reshape(-1, num_heads, head_size),
-                cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 1, 0),
+                cache1,
+                cache2,
             ),
             dim=0,
         )
-        concat_cache = torch.cat((_cache, source[i]), dim=0)
-        padding = seq_len - concat_cache.size(0)
+        padding = seq_len - _cache.size(0)
         if padding > 0:
-            concat_cache = F.pad(concat_cache, (0, 0, 0, 0, 0, 1))
-        padded_cache.append(concat_cache)
+            _cache = F.pad(_cache, (0, 0, 0, 0, 0, 1))
+        padded_cache.append(_cache)
     return torch.stack(padded_cache, dim=0)
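The hunk above changes the block arithmetic: `convert_kvcache` no longer receives the freshly projected key/value (`source` is gone) and instead reads the whole history from the cache, so the last block must expose `lengths % block_size` valid tokens, with a full last block counting as `block_size` rather than 0. A small sketch of that bookkeeping with plain PyTorch and illustrative values:

import torch

# Paged-KV bookkeeping: block_size tokens per physical cache block.
block_size = 16
lengths = torch.tensor([5, 16, 33])  # tokens currently held by each sequence

needed_blocks = (lengths + block_size - 1) // block_size  # ceil(lengths / block_size)
num_remaining = lengths % block_size                      # valid tokens in the last block
num_remaining[num_remaining == 0] += block_size           # a full last block keeps block_size tokens

print(needed_blocks.tolist())  # [1, 1, 3]
print(num_remaining.tolist())  # [5, 16, 1]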

@@ -1,11 +1,22 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
+import math
 from typing import List, Optional, Tuple

 import torch
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
+import torch.nn as nn
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaForCausalLM,
+    LlamaModel,
+    repeat_kv,
+)

+from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache
 from colossalai.inference.struct import BatchInfo
-from colossalai.kernel.triton import context_attention_unpadded
+from flash_attn.bert_padding import index_first_axis, pad_input  # noqa


 def rotate_half(x):
@@ -27,24 +38,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     return q_embed, k_embed


-# Copied from transformers.models.llama.modeling_llama.repeat_kv
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
 def llama_causal_lm_forward(
     self: LlamaForCausalLM,
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
+    padding_id: int = None,
 ):
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     hidden_states = llama_model_forward(
@@ -52,6 +51,7 @@ def llama_causal_lm_forward(
         batch=batch,
         k_caches=k_caches,
         v_caches=v_caches,
+        padding_id=padding_id,
     )
     logits = self.lm_head(hidden_states)
     return logits
@@ -62,13 +62,20 @@ def llama_model_forward(
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
+    padding_id: int = None,
 ):
     input_ids = batch.get_batch_inputs()
     block_tables = batch.get_block_table_tensor()
     sequence_lengths = batch.get_sequence_lengths()

-    # Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
-    position_ids = generate_padding_position_id(input_ids)
+    attention_mask = batch.get_attn_mask(padding_id)
+
+    if batch.is_prompts:
+        # Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
+        position_ids = generate_padding_position_id(attention_mask)
+    else:
+        position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1)

     hidden_states = self.embed_tokens(input_ids)

     for layer_id, decoder_layer in enumerate(self.layers):
@@ -80,6 +87,7 @@ def llama_model_forward(
             v_cache=v_caches[layer_id],
             is_prompts=batch.is_prompts,
             sequence_lengths=sequence_lengths,
+            attention_mask=attention_mask,
         )

     hidden_states = self.norm(hidden_states)
@@ -96,6 +104,7 @@ def llama_decoder_layer_forward(
     v_cache: torch.Tensor = None,
     is_prompts: bool = True,
     sequence_lengths: int = None,
+    attention_mask: torch.Tensor = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     residual = hidden_states
@@ -109,6 +118,7 @@ def llama_decoder_layer_forward(
         v_cache=v_cache,
         is_prompts=is_prompts,
         sequence_lengths=sequence_lengths,
+        attention_mask=attention_mask,
     )

     hidden_states = residual + hidden_states
@@ -132,6 +142,7 @@ def llama_attn_forward(
     v_cache: torch.Tensor = None,
     is_prompts: bool = True,
     sequence_lengths: torch.Tensor = None,
+    attention_mask: torch.Tensor = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
@@ -139,9 +150,7 @@ def llama_attn_forward(
     key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-    kv_seq_len = key_states.shape[-2]
-    if not is_prompts:
-        kv_seq_len = kv_seq_len + sequence_lengths[0].item()
+    kv_seq_len = sequence_lengths[0].item()

     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -153,20 +162,26 @@ def llama_attn_forward(
     key_states = key_states.transpose(1, 2)
     value_states = value_states.transpose(1, 2)

-    query_states = query_states.view(-1, self.num_heads, self.head_dim)
-    key_states = key_states.view(-1, self.num_heads, self.head_dim)
-    value_states = value_states.view(-1, self.num_heads, self.head_dim)
-
-    _, _, _, block_size = k_cache.shape
-
-    # NOTE: context_attention_unpadded is used for testing accuracy and we can only use aligned inputs.
-    # The code below will be uncommented after the development of attention-related kernel is completed.
     if is_prompts:
-        attn_output = context_attention_unpadded(
-            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+        attn_output = pad_context_forward(
+            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
+        )
+    else:
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        attn_output = pad_decoding_forward(
+            query_states,
+            key_states,
+            value_states,
+            k_cache,
+            v_cache,
+            sequence_lengths,
+            block_tables,
+            attention_mask,
+            self.layer_idx,
+            self.attention_dropout,
+            self.training,
         )
-    # else:
-    #     attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)

     attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -175,13 +190,129 @@ def llama_attn_forward(
     return attn_output


-def generate_padding_position_id(input_ids: torch.Tensor) -> torch.Tensor:
-    # Replace this code and use a more flexible method to obtain padding_id, avoiding directly setting padding_id like this.
-    padding_id = 2
-    attention_mask = input_ids.ne(padding_id).long()
+def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
     position_ids = attention_mask.long().cumsum(-1) - 1
     position_ids.masked_fill_(attention_mask == 0, 1)
     return position_ids
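`generate_padding_position_id` now takes the batch attention mask directly instead of re-deriving it from a hard-coded padding id. With left padding, the cumulative sum of the mask minus one yields 0-based positions for real tokens, and padded slots get a dummy value. A worked example with an illustrative mask:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence of length 3
                               [1, 1, 1, 1, 1]])  # full-length sequence

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# Decoding branch in llama_model_forward: the newest token's position is one less
# than the number of real tokens seen so far.
print((attention_mask.sum(dim=-1) - 1).reshape(-1, 1))
# tensor([[2],
#         [4]])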

-# def unpad_inputs(input_ids: torch.Tensor):
+def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
+    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape
+    q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
+    k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
+    v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
+    return (q, k, v, indices, seqlens)
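`unpading_input` relies on flash-attn's `index_first_axis` to keep only the rows that correspond to real tokens. The same gather can be sketched with plain `index_select`, which is enough to see the shapes it returns (values here are illustrative):

import torch

attention_mask = torch.tensor([[0, 1, 1],
                               [1, 1, 1]])
bsz, seq_len, num_heads, head_dim = 2, 3, 4, 8
q = torch.randn(bsz, seq_len, num_heads, head_dim)

seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                      # real tokens per sequence
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # flat positions of real tokens
q_unpad = q.reshape(bsz * seq_len, num_heads, head_dim).index_select(0, indices)

print(seqlens.tolist())  # [2, 3]
print(q_unpad.shape)     # torch.Size([5, 4, 8])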

+def pad_decoding_forward(
+    query: torch.Tensor,  # [bsz, 1, num_heads, head_size]
+    key: torch.Tensor,
+    value: torch.Tensor,
+    k_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+    v_cache: torch.Tensor,
+    lengths: torch.Tensor,  # [num_seqs]: input_lengths + output_lengths
+    block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence]
+    attn_mask: torch.Tensor = None,
+    layer_id: int = 0,
+    attention_dropout: float = None,
+    training: bool = False,
+):
+    bsz, query_length, num_heads, head_size = query.shape
+    seq_len = max(lengths)
+
+    copy_to_cache(key, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
+    copy_to_cache(value, v_cache, lengths=lengths, block_tables=block_tables, type="decoding")
+
+    key = convert_kvcache(k_cache, lengths, block_tables)  # bsz, seqlen,
+    value = convert_kvcache(v_cache, lengths, block_tables)
+
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
+    if attn_weights.size() != (bsz, num_heads, 1, seq_len):
+        raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
+
+    if attn_mask is not None:
+        padding_mask = AttentionMaskConverter._expand_mask(attn_mask, query.dtype, query_length)
+
+    attn_mask = AttentionMaskConverter._make_causal_mask(
+        (bsz, query_length), query.dtype, query.device, past_key_values_length=seq_len - query_length
+    )
+
+    if padding_mask is not None:
+        attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(query.dtype).min)
+
+    attn_weights += attn_mask
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training)
+
+    attn_output = torch.matmul(attn_weights, value)
+
+    if attn_output.size() != (bsz, num_heads, 1, head_size):
+        raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
+
+    attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)
+
+    return attn_output
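In the decoding path each sequence contributes exactly one new token, so `copy_to_cache(..., type="decoding")` writes it at slot `(lengths + block_size - 1) % block_size` of the sequence's last block, which is the same as `(lengths - 1) % block_size`. A quick numeric check with illustrative lengths:

import torch

block_size = 16
# Sequence lengths after appending the token being decoded.
lengths = torch.tensor([1, 16, 17, 33])

slot_idx = (lengths + block_size - 1) % block_size
last_block = (lengths + block_size - 1) // block_size - 1
print(slot_idx.tolist())    # [0, 15, 0, 0]
print(last_block.tolist())  # [0, 0, 1, 2]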

+def pad_context_forward(
+    q: torch.Tensor,  # [batch_size, seq_len, num_heads, head_size]
+    k: torch.Tensor,  # [batch_size, seq_len, num_kv_heads, head_size]
+    v: torch.Tensor,
+    k_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+    v_cache: torch.Tensor,
+    context_lengths: torch.Tensor,  # [num_seqs]
+    block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence]
+    attn_mask: torch.Tensor = None,
+):
+    # Firt, do shape verification
+    bsz, seq_len, num_heads, head_size = q.shape
+    num_kv_heads = k.shape[-2]
+    assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
+    num_kv_groups = num_heads // num_kv_heads
+    block_size = k_cache.shape[-1]
+    assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
+    block_tables.shape[-1] * block_size
+
+    shape = (bsz, seq_len, num_heads, head_size)
+    input_shape = shape[:2]
+
+    # Copy kv to memory(rotary embedded)
+    copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
+    copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables)
+
+    q = q.transpose(1, 2)
+    k = repeat_kv(k.transpose(1, 2), num_kv_groups)
+    v = repeat_kv(v.transpose(1, 2), num_kv_groups)
+
+    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
+
+    if attn_mask is not None:
+        padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len)
+
+    attn_mask = AttentionMaskConverter._make_causal_mask(
+        (bsz, seq_len), q.dtype, q.device, past_key_values_length=seq_len - seq_len
+    )
+
+    if padding_mask is not None:
+        attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(q.dtype).min)
+
+    if attn_weights.size() != (bsz, num_heads, seq_len, seq_len):
+        raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.")
+    if attn_mask is not None:
+        attn_weights += attn_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+    attn_output = torch.matmul(attn_weights, v)
+
+    if attn_output.size() != (bsz, num_heads, seq_len, head_size):
+        raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,seq_len,head_size)}.")
+
+    attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
+    del attn_weights
+
+    return attn_output
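Both padded paths combine a causal mask with the padding mask into an additive bias before the softmax; the code above does it through transformers' `AttentionMaskConverter`, but the effect can be reproduced with plain PyTorch (toy shapes, names chosen for illustration):

import math

import torch

bsz, num_heads, seq_len, head_size = 1, 2, 4, 8
attention_mask = torch.tensor([[0, 1, 1, 1]])  # first position is padding

q = torch.randn(bsz, num_heads, seq_len, head_size)
k = torch.randn(bsz, num_heads, seq_len, head_size)
v = torch.randn(bsz, num_heads, seq_len, head_size)

min_val = torch.finfo(q.dtype).min
future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)  # causal part: block future keys
pad_cols = attention_mask[:, None, None, :] == 0                                 # block padded key positions
bias = torch.zeros(bsz, 1, seq_len, seq_len).masked_fill(future | pad_cols, min_val)

attn_weights = q @ k.transpose(-1, -2) / math.sqrt(head_size) + bias
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = attn_weights @ v
print(attn_output.shape)  # torch.Size([1, 2, 4, 8])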

@@ -321,5 +321,13 @@ class BatchInfo:
         return torch.tensor(len_list, dtype=torch.int, device=self.device)

+    def get_attn_mask(self, padding_id: int) -> torch.Tensor:
+        past_values = []
+
+        for seq in self.sequences_set:
+            past_values.append(seq.input_token_id + seq.output_token_id)
+
+        return torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()
+
     def __repr__(self) -> str:
         return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
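`get_attn_mask` rebuilds the batch mask from each sequence's full token history (`input_token_id + output_token_id`), so the mask grows by one column per decoding step; `padding_id` is the tokenizer pad token that the engine now forwards (see the `padding_id=self.tokenizer.pad_token_id` change above). A toy example of the resulting mask, with made-up token ids:

import torch

padding_id = 2
# Two sequences padded to a common length: prompt tokens plus the tokens generated so far.
past_values = [[2, 2, 5, 9, 11],
               [4, 6, 7, 8, 12]]

attn_mask = torch.tensor(past_values, dtype=torch.int).ne(padding_id).long()
print(attn_mask)
# tensor([[0, 0, 1, 1, 1],
#         [1, 1, 1, 1, 1]])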

@@ -9,7 +9,7 @@ from transformers import AutoTokenizer, GenerationConfig
 import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
-from colossalai.testing import spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 def setup_seed(seed):
@@ -24,21 +24,24 @@ def check_inference_engine(test_cai=False):
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     model = transformers.LlamaForCausalLM(
         transformers.LlamaConfig(
-            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
+            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
         )
     ).cuda()

     inputs = [
-        "介绍一下北京,",
+        "介绍一下今天的北京,",
         "介绍一下武汉,",
     ]

+    output_len = 16
+    do_sample = True
+
     if test_cai:
-        inference_config = InferenceConfig(max_output_len=1)
+        inference_config = InferenceConfig(max_output_len=output_len)
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=True, top_p=0.5, top_k=50)
+        generation_config = GenerationConfig(do_sample=do_sample, top_p=0.5, top_k=50)
         outputs = inference_engine.generate(generation_config)
     else:
         tokenizer.pad_token = tokenizer.eos_token
@@ -46,7 +49,7 @@ def check_inference_engine(test_cai=False):
         inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
-            do_sample=True, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=1
+            do_sample=do_sample, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=output_len
         )
         outputs = model.generate(inputs, generation_config=generation_config)
         outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
@@ -64,6 +67,7 @@ def run_dist(rank, world_size, port):

 @pytest.mark.dist
+@rerun_if_address_is_in_use()
 def test_inference_engine():
     spawn(run_dist, 1)