@@ -3,17 +3,16 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

 # copied from fairseq/fairseq/data/indexed_dataset.py
 # Removed IndexedRawTextDataset since it relied on Fairseq dictionary
 # other slight modifications to remove fairseq dependencies
 # Added document index to index file and made it accessible.
 #    An empty sentence no longer separates documents.

-from functools import lru_cache
 import os
 import shutil
 import struct
+from functools import lru_cache
 from itertools import accumulate

 import numpy as np
@@ -88,16 +87,7 @@ def write_longs(f, a):
     f.write(np.array(a, dtype=np.int64))


-dtypes = {
-    1: np.uint8,
-    2: np.int8,
-    3: np.int16,
-    4: np.int32,
-    5: np.int64,
-    6: np.float,
-    7: np.double,
-    8: np.uint16
-}
+dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: float, 7: np.double, 8: np.uint16}


 def code(dtype):
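Both dict rewrites in this patch swap the deprecated `np.float` alias for the builtin `float` (the alias was removed from NumPy after its 1.20 deprecation) while collapsing the literals onto one line. A minimal sketch of how the type-code table round-trips; the `code()` body below is a stand-in for the function the hunk cuts off:

```python
import numpy as np

dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32,
          5: np.int64, 6: float, 7: np.double, 8: np.uint16}

def code(dtype):
    # reverse lookup: numpy dtype -> on-disk type code (stand-in body)
    for k in dtypes:
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)

assert dtypes[code(np.int32)] is np.int32  # code 4 round-trips to np.int32
```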
@@ -136,10 +126,8 @@ class IndexedDataset(torch.utils.data.Dataset):
     def read_index(self, path):
         with open(index_file_path(path), 'rb') as f:
             magic = f.read(8)
-            assert magic == self._HDR_MAGIC, (
-                'Index file doesn\'t match expected format. '
-                'Make sure that --dataset-impl is configured properly.'
-            )
+            assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. '
+                                              'Make sure that --dataset-impl is configured properly.')
             version = f.read(8)
             assert struct.unpack('<Q', version) == (1,)
             code, self.element_size = struct.unpack('<QQ', f.read(16))
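The assert rewrite is wrapping-only, but the surrounding reads document the legacy `.idx` header: an 8-byte magic, a little-endian uint64 version, then `<QQ` for the dtype code and element size. A sketch of a matching writer; the magic value is an assumption taken from fairseq's `IndexedDataset` (`b'TNTIDX\x00\x00'`), since the excerpt never shows `_HDR_MAGIC` for this class:

```python
import struct

def write_index_header(f, dtype_code, element_size):
    # assumed magic; mirrors the reads in read_index() above
    f.write(b'TNTIDX\x00\x00')                       # 8-byte magic
    f.write(struct.pack('<Q', 1))                    # version, checked as (1,)
    f.write(struct.pack('<QQ', dtype_code, element_size))
```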
@@ -198,13 +186,11 @@ class IndexedDataset(torch.utils.data.Dataset):
     @staticmethod
     def exists(path):
-        return (
-            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
-        )
+        return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))

     @property
     def supports_prefetch(self):
-        return False # avoid prefetching to save memory
+        return False  # avoid prefetching to save memory


 class IndexedCachedDataset(IndexedDataset):
@@ -233,7 +219,7 @@ class IndexedCachedDataset(IndexedDataset):
         for i in indices:
             self.cache_index[i] = ptx
             size = self.data_offsets[i + 1] - self.data_offsets[i]
-            a = self.cache[ptx: ptx + size]
+            a = self.cache[ptx:ptx + size]
             self.data_file.seek(self.data_offsets[i] * self.element_size)
             self.data_file.readinto(a)
             ptx += size
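Only the slice spacing changes here; the surrounding loop is the cache fill: every requested item is read back-to-back into one preallocated `self.cache` array, with `cache_index` recording where each item starts. A hypothetical usage sketch (the file prefix is invented, and the class is assumed importable from this module):

```python
ds = IndexedCachedDataset('tokens')   # hypothetical prefix -> tokens.idx / tokens.bin
wanted = sorted({5, 2, 9})
ds.prefetch(wanted)                   # one pass over the data file fills self.cache
first = ds[2]                         # now served from memory via cache_index
```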
@@ -250,7 +236,7 @@ class IndexedCachedDataset(IndexedDataset):
             tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
             a = np.empty(tensor_size, dtype=self.dtype)
             ptx = self.cache_index[i]
-            np.copyto(a, self.cache[ptx: ptx + a.size])
+            np.copyto(a, self.cache[ptx:ptx + a.size])
             return a
         elif isinstance(idx, slice):
             # Hack just to make this work, can optimizer later if necessary
@@ -261,15 +247,7 @@
 class IndexedDatasetBuilder(object):
-    element_sizes = {
-        np.uint8: 1,
-        np.int8: 1,
-        np.int16: 2,
-        np.int32: 4,
-        np.int64: 8,
-        np.float: 4,
-        np.double: 8
-    }
+    element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, float: 4, np.double: 8}

     def __init__(self, out_file, dtype=np.int32):
         self.out_file = open(out_file, 'wb')
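`element_sizes` gets the same `np.float` to `float` substitution. For orientation, a hedged sketch of driving this builder, assuming fairseq's `add_item`/`finalize` API and invented file names:

```python
import numpy as np
import torch

builder = IndexedDatasetBuilder('data.bin', dtype=np.int32)   # hypothetical paths
for sentence in ([1, 2, 3], [4, 5]):
    builder.add_item(torch.IntTensor(sentence))   # appends tokens to data.bin
builder.finalize('data.idx')                      # writes the index read_index() parses
```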
@@ -332,12 +310,15 @@ def _warmup_mmap_file(path):
 class MMapIndexedDataset(torch.utils.data.Dataset):
     class Index(object):
         _HDR_MAGIC = b'MMIDIDX\x00\x00'

         @classmethod
         def writer(cls, path, dtype):
             class _Writer(object):
                 def __enter__(self):
                     self._file = open(path, 'wb')
@@ -384,10 +365,8 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
         def __init__(self, path, skip_warmup=False):
             with open(path, 'rb') as stream:
                 magic_test = stream.read(9)
-                assert self._HDR_MAGIC == magic_test, (
-                    'Index file doesn\'t match expected format. '
-                    'Make sure that --dataset-impl is configured properly.'
-                )
+                assert self._HDR_MAGIC == magic_test, ('Index file doesn\'t match expected format. '
+                                                       'Make sure that --dataset-impl is configured properly.')
                 version = struct.unpack('<Q', stream.read(8))
                 assert (1,) == version
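Again a wrapping-only change, but it documents the mmap index header: a 9-byte magic (`b'MMIDIDX\x00\x00'`, declared at the top of `Index` above) followed by a little-endian uint64 version. A standalone validation sketch:

```python
import struct

def check_index_header(path):
    # mirrors the asserts above: 9-byte magic, then a '<Q' version
    with open(path, 'rb') as stream:
        assert stream.read(9) == b'MMIDIDX\x00\x00', 'bad magic'
        (version,) = struct.unpack('<Q', stream.read(8))
        assert version == 1
        # the dtype code and the entry/document counts follow; the reads
        # below consume them as self._len and self._doc_count
        return version
```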
@@ -406,16 +385,16 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
             self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
             self._bin_buffer = memoryview(self._bin_buffer_mmap)
             print("    reading sizes...")
-            self._sizes = np.frombuffer(
-                self._bin_buffer,
-                dtype=np.int32,
-                count=self._len,
-                offset=offset)
+            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
             print("    reading pointers...")
-            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
-                                           offset=offset + self._sizes.nbytes)
+            self._pointers = np.frombuffer(self._bin_buffer,
+                                           dtype=np.int64,
+                                           count=self._len,
+                                           offset=offset + self._sizes.nbytes)
             print("    reading document index...")
-            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
-                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)
+            self._doc_idx = np.frombuffer(self._bin_buffer,
+                                          dtype=np.int64,
+                                          count=self._doc_count,
+                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)

         def __del__(self):
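The three `np.frombuffer` calls are zero-copy views over one `memoryview` of the mmapped index, and their chained offsets spell out the on-disk layout: sizes (`int32 x len`), then pointers (`int64 x len`), then the document index this fork added (`int64 x doc_count`). A sketch of the offset arithmetic (function name and arguments are illustrative):

```python
import numpy as np

def index_offsets(header_bytes, n_items, n_docs):
    # byte layout implied by the three frombuffer calls above:
    # [header][sizes: int32 x n_items][pointers: int64 x n_items][doc_idx: int64 x n_docs]
    sizes_off = header_bytes
    pointers_off = sizes_off + n_items * np.dtype(np.int32).itemsize
    doc_idx_off = pointers_off + n_items * np.dtype(np.int64).itemsize
    end = doc_idx_off + n_docs * np.dtype(np.int64).itemsize
    return sizes_off, pointers_off, doc_idx_off, end
```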
@@ -480,8 +459,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         if isinstance(idx, int):
             ptr, size = self._index[idx]
-            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
-                                     count=size, offset=ptr)
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
             return np_array
         elif isinstance(idx, slice):
             start, stop, step = idx.indices(len(self))
@@ -491,8 +469,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
             sizes = self._index._sizes[idx]
             offsets = list(accumulate(sizes))
             total_size = sum(sizes)
-            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
-                                     count=total_size, offset=ptr)
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
             sents = np.split(np_array, offsets[:-1])
             return sents
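The slice path reads all requested sentences with one contiguous `frombuffer` call and splits the flat array at the running totals of `sizes`; only the call layout changes in this hunk. The split arithmetic, worked on toy numbers:

```python
import numpy as np
from itertools import accumulate

sizes = [3, 2, 4]                      # per-sentence token counts in the slice
flat = np.arange(sum(sizes))           # stands in for the contiguous frombuffer read
offsets = list(accumulate(sizes))      # [3, 5, 9]
sents = np.split(flat, offsets[:-1])   # -> arrays of length 3, 2, 4
assert [len(s) for s in sents] == sizes
```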
@@ -506,8 +483,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
         if length is None:
             length = size - offset
         ptr += offset * np.dtype(self._index.dtype).itemsize
-        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
-                                 count=length, offset=ptr)
+        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr)
         return np_array

     @property
@@ -530,12 +506,11 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
     @staticmethod
     def exists(path):
-        return (
-            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
-        )
+        return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))


 class MMapIndexedDatasetBuilder(object):
     def __init__(self, out_file, dtype=np.int64):
         self._data_file = open(out_file, 'wb')
         self._dtype = dtype
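The excerpt ends with the mmap builder, which defaults to `int64` tokens. A usage sketch, assuming the Megatron-style builder API (`add_item`/`end_document`/`finalize`, where `end_document` records the boundaries that populate `doc_idx`; this API is not shown in the hunks above) and invented paths:

```python
import numpy as np
import torch

# assumes MMapIndexedDatasetBuilder / MMapIndexedDataset are importable from this module
builder = MMapIndexedDatasetBuilder('corpus.bin', dtype=np.int64)
for doc in ([[1, 2, 3], [4, 5]], [[6]]):       # two documents of tokenized sentences
    for sentence in doc:
        builder.add_item(torch.LongTensor(sentence))
    builder.end_document()                     # document boundary for the doc index
builder.finalize('corpus.idx')                 # writes the MMIDIDX index

ds = MMapIndexedDataset('corpus')              # prefix resolves to corpus.bin / corpus.idx
```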