mirror of https://github.com/InternLM/InternLM
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

"""
A .bin file corresponds to a Dataset instance here.
"""

import json
import mmap
import os
import threading
from pathlib import Path

import numpy as np
import torch

class JsonlDataset(torch.utils.data.Dataset):
    """
    The JSONL format is expected to roughly follow that of The Pile:
    one document per line, of the form:
    ```
    {
        "tokens": List[int],
    }
    ```

    Note that only the "tokens" key is used. (A sketch of building the
    companion ``.meta`` cache appears at the end of this module.)
    """

    def __init__(self, path: str, dataset_type_id: int = 0, min_length=50):
        self.path = path
        self.threadlocal = threading.local()
        resolved_path = Path(path).resolve()
        self.resolved_path = resolved_path
        self.meta = Path(f"{resolved_path}.meta")
        self.type_id = dataset_type_id

        # the .meta cache is only built on the primary worker to prevent overloading NFS,
        # so here we simply require that it already exists
        assert os.path.exists(self.meta), f"The cache file:{self.meta} is not found for file:{self.path}"
        try:
            with open(self.meta, "rb") as f:
                meta = np.load(f)
        except Exception as e:
            print(f"Cannot load file {self.meta}...")
            raise e
        # each row of the meta cache holds the byte offset of a document and its token count
        self.offsets = meta[:, 0]
        self.lengths = meta[:, -1]

        if min_length > 0:
            # drop documents shorter than min_length tokens, keeping the unfiltered stats
            mask = self.lengths >= min_length
            self.old_lengths = self.lengths.copy()
            self.old_length = len(self.offsets)
            self.offsets = self.offsets[mask]
            self.lengths = self.lengths[mask]

    def __getitem__(self, idx):
        f = self._get_mmap()
        position = self.offsets[idx]
        f.seek(position)
        item = f.readline().decode("utf-8")
        try:
            item = json.loads(item)
            item["length"] = len(item["tokens"])  # record the number of tokens in this sample
            item["type_id"] = self.type_id
        except Exception as err:
            raise json.decoder.JSONDecodeError(
                doc=self.path,
                pos=position,
                msg=(
                    f"Error while loading JSONL line in file {self.path} at byte "
                    f"{position}. Contents of line:\n{item}\n{err}"
                ),
            )
        return item

    def get_dataset_name(self):
        return str(self.resolved_path)

    def _get_mmap(self):
        # lazily open a read-only mmap of the file, cached per thread
        if not hasattr(self.threadlocal, "handles"):
            with open(self.path, "rb") as f:
                mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
                self.threadlocal.handles = [f, mm]
                if self.path.endswith(".gz") or self.path.endswith(".bz") or self.path.endswith(".bz2"):
                    raise NotImplementedError(
                        "Compressed files are not supported because .seek() would require "
                        "rereading the entire file, making performance too slow."
                    )
        return self.threadlocal.handles[-1]

    def __setstate__(self, state):
        self.__dict__ = state
        # recreate the thread-local storage dropped in __getstate__
        self.threadlocal = threading.local()

    def __getstate__(self):
        # exclude the thread-local mmap handles so the dataset stays picklable
        # (e.g. when it is sent to DataLoader worker processes)
        d = {}
        for i, v in self.__dict__.items():
            if i != "threadlocal":
                d[i] = v
        return d

    def __del__(self):
        if hasattr(self.threadlocal, "handles"):
            # clean up files we opened on initialization
            while self.threadlocal.handles:
                self.threadlocal.handles.pop().close()

    @staticmethod
    def exists(path):
        return os.path.exists(path)

    def __len__(self):
        # number of documents remaining after the min_length filter
        return len(self.offsets)
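

# Illustrative sketch: one way the ``.meta`` cache required by
# ``JsonlDataset.__init__`` might be produced. The row layout is inferred from
# ``meta[:, 0]`` / ``meta[:, -1]`` above: one row per document, holding the
# byte offset of its line and its token count. The helper name
# ``build_meta_cache`` is hypothetical, not an InternLM API.
def build_meta_cache(jsonl_path: str) -> str:
    """Write ``<jsonl_path>.meta`` with one (byte offset, token count) row per line."""
    rows = []
    offset = 0
    with open(jsonl_path, "rb") as f:
        for line in f:
            doc = json.loads(line)
            rows.append((offset, len(doc["tokens"])))
            offset += len(line)
    meta_path = f"{jsonl_path}.meta"
    # write through an open handle so numpy does not append a ".npy" suffix
    with open(meta_path, "wb") as meta_f:
        np.save(meta_f, np.array(rows, dtype=np.int64))
    return meta_path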
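

# A minimal usage sketch with a placeholder path ("train.jsonl" is assumed, not
# shipped with this file): the __getstate__/__setstate__ hooks above drop the
# open mmap handles when the dataset is pickled, e.g. when it is sent to
# DataLoader worker processes, and each worker re-opens its own handle lazily
# via _get_mmap().
def _collate_keep(batch):
    # keep variable-length token lists as plain dicts instead of stacking tensors
    return batch


if __name__ == "__main__":
    from torch.utils.data import DataLoader

    build_meta_cache("train.jsonl")  # hypothetical helper defined above
    dataset = JsonlDataset("train.jsonl", dataset_type_id=0, min_length=50)
    loader = DataLoader(dataset, batch_size=4, num_workers=2, collate_fn=_collate_keep)
    for batch in loader:
        for sample in batch:
            print(sample["length"], sample["type_id"], sample["tokens"][:8])
        break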