首页 > 其他分享 >使用人类棋手棋盘数据训练围棋机器人,实现数据预处理

使用人类棋手棋盘数据训练围棋机器人,实现数据预处理

时间:2023-06-14 11:37:46浏览次数:43  
标签:name zip 棋手 数据 self num file data 预处理


知己知彼,百战不殆。我们要打造一个能胜过人类的机器人,就必须要让机器人掌握人类的围棋思维模式,因此我们就需要使用人类棋手留下的棋盘数据训练机器人,让它从数据中掌握人类围棋思维存在的模式和套路。

幸运的是,我们能够通过围棋服务器拿到很多由人落子后产生的棋盘数据。很多围棋服务器公开了这些数据,这些围棋数据以一种叫Smart Game Format的方式存储,我们可以将其下载下来进行预处理后用于训练我们的神经网络,如此得到的网络,它的落子能力将远远超过上一节我们训练的网络机器人。

我们从当下最流行的围棋服务器下载棋盘数据,这个服务器叫KGS(Kiseido Go Server).在下载数据前,我们先了解具体的数据格式。它是一种文本格式数据,它通常用两个大写字母来表示棋盘属性,例如表示棋盘规格时使用的字母是SZ,然后在后面用大括号来容纳属性对应的数值,对于一个9*9的棋盘而言,对应的描述属性为SZ[9]。

它用W来表示白子,如果白子落在第三行第三列,对应的记录就是W[cc],也就是它使用字符次序来表示数字,因此c就表示数字3,同时B表示黑子,如果描述黑子落在第7行,第3列,对应的属性描述就是B[gc],字母g表示7。如果在某一步白棋或黑子pass,对应的描述就是B[],W[],也就是中括号内没有内容。

由此我们看下面一段数据对棋盘的描述:

(;FF[4] GM[1] SZ[9] HA[0] KM[6.5] RU[Japanese] RE[W+9.5] ;B[gc];W[cc];B[cg];W[gg];B[hf];W[gf];B[hg];W[hh];B[ge];W[df];B[dg] ;W[eh];B[cf];W[be];B[eg];W[fh];B[de];W[ec];B[fb];W[eb];B[ea];W[da] ;B[fa];W[cb];B[bf];W[fc];B[gb];W[fe];B[gd];W[ig];B[bd];W[he];B[ff] ;W[fg];B[ef];W[hd];B[fd];W[bi];B[bh];W[bc];B[cd];W[dc];B[ac];W[ab] ;B[ad];W[hc];B[ci];W[ed];B[ee];W[dh];B[ch];W[di];B[hb];W[ib];B[ha] ;W[ic];B[dd];W[ia];B[];
TW[aa][ba][bb][ca][db][ei][fi][gh][gi][hf][hg][hi][id][ie][if] [ih][ii]
TB[ae][af][ag][ah][ai][be][bg][bi][ce][df][fe][ga] W[])

其中FF[4]表示数据格式的版本号,有点类似于操作系统版本。GM[1]表示比赛第一盘,HA表示让子,HA[0]表示没有让子。RU[Japanese]表示围棋遵循日本规则,RE[W+9.5]表示白子以9.5分优势获胜,KM[6.5]表示第二落子的人获得6.5分补偿。接下来 以分好分割的就是双方落子方式。最后TW表示的是白子地盘,TB表示黑子占据的地盘。

理解了数据格式后,我们可以通过网址 https://www.u-go.net/gamerecords/ 下载棋盘数据:

这上面都存储了六段以上高手对弈的棋盘数据。我们接下来将会创建一个爬虫机器人,爬去网页,分析里面链接后自动将数据下载到本地并解压,在后面我们会具体给出爬虫的实现代码,当爬虫运行后,它会解析页面,找出下载链接,依次把文件下载到指定文件夹中,其运行信息如下:

>>>Downloading content/gdrive/My Drive/GO_RECORD/KGS-2006-19-10388-.tar.gz
worker is running
>>>Downloading content/gdrive/My Drive/GO_RECORD/KGS-2005-19-13941-.tar.gz
worker is running
>>>Downloading content/gdrive/My Drive/GO_RECORD/KGS-2004-19-12106-.tar.gz
worker is running
>>>Downloading content/gdrive/My Drive/GO_RECORD/KGS-2003-19-7582-.tar.gz
worker is running
>>>Downloading content/gdrive/My Drive/GO_RECORD/KGS-2002-19-3646-.tar.gz
worker is running
>>>Downloading content/gdrive/My Drive/GO_RECORD/KGS-2001-19-2298-.tar.gz

下载完数据后,我们会用代码解读棋盘数据,并将数据所表示的棋盘落子过程重放一遍,棋盘数据的解读烦琐耗时,为了将精力集中到网络训练上,我们将直接使用一个已经完成的数据解读类来帮我们解读棋盘数据。

首先我们先构造一段虚拟棋盘数据:

"(;GM[1] FF[4] SZ[9];B[ee];W[ef];B[ff])" + \
";W[df];B[fe];W[fc];B[ef];W[gd];B[fb]"

然后使用棋盘数据读取工具类Sgf_game读取上面信息,将其转换成白棋和黑棋的落子信息,然后启动一个虚拟棋盘,将上面的落子步骤显示出来,当我们正确读取上面棋盘信息后,我们可以输出以下模拟棋盘:

19  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
18  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
17  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
16  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
15  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
14  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
13  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
12  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
11  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
10  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 9  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 8  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 7  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 6  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 5  .  .  .  . x .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 4  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 3  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 2  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 1  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
   A B C D E F G H J K L M N O P Q R S T
19  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
18  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
17  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
16  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
15  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
14  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
13  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
12  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
11  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
10  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 9  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 8  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 7  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 6  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 5  .  .  .  . x .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 4  .  .  .  . o .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 3  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 2  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 1  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
   A B C D E F G H J K L M N O P Q R S T
19  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
18  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
17  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
16  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
15  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
14  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
13  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
12  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
11  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
10  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 9  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 8  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 7  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 6  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 5  .  .  .  . x .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 4  .  .  .  . ox .  .  .  .  .  .  .  .  .  .  .  .  . 
 3  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 2  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
 1  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
   A B C D E F G H J K L M N O P Q R S T

完成了数据的解析后,我们就得创建数据处理器,将下载的棋盘数据转换成网络可以识别的向量格式,然后喂给网络,滋养网络的发育。它首先从下载的棋盘描述文件中选取出一部分进行解压,然后读取解压后的数据文件,将它描述的棋盘转换为上一节对应的棋盘编码,同时将当前棋盘与下一步落子对应起来。

我们要把数据分割成两部分,其中时间在2014年前的数据作为测试数据,之后的数据作为训练数据。我们把数据读入内存,按照上面描描述解析数据后,将解析后的数据存储起来以便以后使用,因为数据解析是非常耗时耗力的”脏活累活“,我们尽量做一次即可。

首先我们完成下载数据的代码:

import os
import sys
import multiprocessing
import six
from urllib.request import urlopen, urlretrieve

#创建下载线程函数
def  worker(url_and_target):
  try:
    (url, target_path) = url_and_target
    print('>>>Downloading ' + target_path)
    urlretrieve(url, target_path)
  except (KeyboardInterrupt, SystemExit):
    print('Exiting download worker')
    
class KGSDownloader:
  def __init__(self, kgs_url = 'https://www.u-go.net/gamerecords/', 
               download_page = 'kgs_index.html',
              data_directory = '/content/gdrive/My Drive/GO_RECORD/'):
    self.kgs_url = kgs_url
    self.download_page = download_page
    self.data_directory = data_directory
    #下载文件信息
    self.file_info = []
    #下载数据对应的url
    self.urls = []
    #启动下载页面解析流程
    self.loading()
    
  def  download_files(self):
    print('begin download')
    '''
    根据CPU核数创建下载线程同时下载棋盘数据
    '''
    if not os.path.isdir(self.data_directory):
      os.makedirs(self.data_directory)
    
    urls_to_download = []
    print('file_info: ', self.file_info)
    
    for file_info in self.file_info:
      url = file_info['url']
      file_name = file_info['filename']
      #如果文件没有下载过就进行下载
      print('filename is : ', file_name)
      if not os.path.isfile(self.data_directory + '/' + file_name):
        urls_to_download.append((url, self.data_directory + '/' + file_name))
      
    cores = multiprocessing.cpu_count()
    #根据CPU核数创建线程池
    pool = multiprocessing.Pool(processes = cores)
    print('cores: ', cores)
    try:
      #将要下载的文件URL分发给每个下载线程
      print('pool imap: ', urls_to_download)
      it = pool.imap(worker, urls_to_download)
      print('it: ', it)
      for _ in it:
        pass
      #关闭线程池防止资源泄露
      pool.close()
      pool.join()
      print('pool imap end')
    except KeyboardInterrupt:
      print('>>>Caught KeyboardInterrupt, terminating works')
      pool.terminate()
      pool.join()
      sys.exit(-1)
      
  def  create_download_page(self):
    print('create_download_page: ', )
    
    if os.path.isfile(self.download_page):
      print('>>> Reading download page: ', self.download_page)
      download_file = open(self.download_page, 'r')
      download_contents = download_file.read()
      print('contents: ', download_contents)
      download_file.close()
    else:
      print('>>> Downloading download page')
      fp = urlopen(self.kgs_url)
      data = six.text_type(fp.read())
      fp.close()
      download_contents = data
      download_file = open(self.download_page, 'w')
      download_file.write(download_contents)
      download_file.close()
      
    return download_contents
  
  def  loading(self):
    '''
    从html页面中将下载数据的文件名以及对应的url抽取出来
    '''
    download_contents = self.create_download_page()
    print('download contents: ', download_contents)
    
    split_page = [item for item in download_contents.split('<a href="') if item.startswith("https://")]
    for item in split_page:
      #在html页面源码中,数据下载链接在"Download"字符串前面
      download_url = item.split('">Download')[0]
      if download_url.endswith('.tar.gz'):
        self.urls.append(download_url)
    
    '''
    下载文件名格式如下:
    KGS-2019_01-19-2095-.tar.gz
    2019是年份,2095是盘数
    2015年之前的文件名在年份之后跟着的是'-'而不是'_'这点要注意
    '''
    for url in self.urls:
      filename = os.path.basename(url)
      split_file_name = filename.split('-')
      num_games = int(split_file_name[len(split_file_name) - 2])
      print(filename + ' ' + str(num_games))
     
      self.file_info.append({'url': url, 'filename':filename,
                            'num_games' : num_games})
      
downloader = KGSDownloader()
downloader.download_files()

在上面代码中,我们启动一个线程池,你的电脑有几核,它就能生成几个线程同时下载数据。首先代码先从解析下载页面的html代码,从中解析出下载链接,最后再将下载链接依次分发给下载线程进行下载。

当把数据下载完毕后,我们需要从下载的数据中选取需要的数据。下载数据总共有17000盘棋局作用,我们使用下面代码从下载数据中选取需要的数据量:

import random
import os

'''
把数据分成两部分,一部分是测试数据,一部分是训练数据,为了保持数据集稳定,我们只采用不晚于2014年12月的数据。
下面代码先将下载数据中选定一定的棋盘数作为测试数据集,剩下的全部作为训练数据集
'''

class  Sampler:
  def  __init__(self, data_dir = '/content/gdrive/My Drive/GO_RECORD/',
                num_test_games = 100,
                cap_year = 2015, seed = 1337):
    self.data_dir = data_dir
    self.num_test_games = num_test_games
    self.test_games = []
    self.train_games = []
    self.test_folder = 'test_samples.py'
    self.cap_year = cap_year
    
    random.seed(seed)
    self.compute_test_samples()
    
  def  draw_data(self, data_type, num_samples):
    '''
    data_type 表明要抽取的数据是训练数据还是测试数据
    '''
    if  data_type == 'test':
      return  self.test_games
    elif  data_type == 'train' and num_samples is not None:
      return self.draw_training_sampels(num_samples)
    elif  data_type == 'train' and num_samples is None:
      return self.draw_all_training()
    
    raise  ValueError(data_type + ' is not a valid data type')
    
  def  draw_samples(self, num_sample_games):
    available_games = []
    loader = KGSDownloader(data_directory = self.data_dir)
    
    for fileinfo in loader.file_info:
      filename = fileinfo['filename']
      year = int(filename.split('-')[1].split('_')[0])
      if year > self.cap_year:
        continue
      
      num_games = fileinfo['num_games']
      for i in range(num_games):
        available_games.append((filename, i))
      
    print('>>>Total number of games used: ' + str(len(available_games)))
    
    sample_set = set()
    while len(sample_set) < num_sample_games:
      sample = random.choice(available_games)
      if sample not in sample_set:
        sample_set.add(sample)
        
    print('Drawn ' + str(num_sample_games) + ' samples:' )
    return list(sample_set)
  
  def  draw_training_games(self):
    '''
    从下载数据中抽取训练数据,这些数据对应于2014年之后的棋盘记录
    同时cap_year以后的数据暂时不考虑以便维持训练数据的稳定性
    '''
    loader = KGSDownloader(data_directory = self.data_dir)
    for file_info in loader.file_info:
      filename = file_info['filename']
      year = int(filename.split('-')[1].split('_')[0])
      if  year > self.cap_year:
        continue
      num_games = file_info['num_games']
      for i in range(num_games):
        sample = (filename, i)
        if sample not in self.test_games:
          self.train_games.append(sample)
          
    print('total num training games: ' + str(len(self.train_games)))
    
  def  compute_test_samples(self):
    '''
    将2014年之前的棋盘数据作为测试数据放置到test_folder文件夹中
    '''
    if not os.path.isfile(self.test_folder):
      test_games = self.draw_samples(self.num_test_games)
      test_sample_file = open(self.test_folder, 'w')
      for sample in test_games:
        test_sample_file.write(str(sample) + '\n')
      test_sample_file.close()
      
    test_sample_file = open(self.test_folder, 'r')
    sample_contents = test_sample_file.read()
    test_sample_file.close()
    
    for line in sample_contents.split('\n'):
      if line != "":
        (filename, index) = eval(line)
        self.test_games.append((filename, index))
        
  def  draw_training_sampels(self, num_sample_games):
    available_games = []
    loader = KGSDownloader(data_directory = self.data_dir)
    for fileinfo in loader.file_info:
      filename = fileinfo['filename']
      year = int(filename.split('-')[1].split('_')[0])
      if year > self.cap_year:
        continue
      num_games = fileinfo['num_games']
      for i in range(num_games):
        available_games.append((filename, i))
    print('total num games: ' + str(len(available_games)))
    
    sample_set = set()
    while len(sample_set) < num_sample_games:
      #从所有数据中随机选取一个
      sample = random.choice(available_games)
      #由于测试数据集都是在2014年之前,因此不属于测试数据集的数据都可以作为训练数据
      if sample not in self.test_games:
        sample_set.add(sample)
    print('Drawn ' + str(num_sample_games) + ' samples:')
    return list(sample_set)
  
  
  def  draw_all_training(self):
    available_games = []
    loader = KGSDownloader(data_directory = self.data_dir)
    
    for fileinfo in loader.file_info:
      filename = fileinfo['filename']
      year = int(filename.split('-')[1].split('_')[0])
      if year > self.cap_year:
        continue
        
      if 'num_games' in fileinfo.keys():
        num_games = fileinfo['num_games']
      else:
        continue
      
      for i in range(num_games):
        available_games.append((filename, i))
    
    print('total num games: ' + str(len(available_games)))
    
    sample_set = set()
    for sample in available_games:
      if sample not in self.test_games:
        sample_set.add(sample)
        
    print('Drawn all samples, ie ' + str(len(sample_set)) + ' samples:')
    return list(sample_set)

上面代码将数据分成两部分,一部分作为测试数据,一部分作为训练数据。同时提供了灵活性,例如支持我们从17000盘数据中抽样出100盘数据等等。最后我们把数据读入内存,然后转换成前面我们讲过的棋盘编码:

import tarfile
import gzip
import glob
import shutil
import os.path

class GoDataProcessor:
  def  __init__(self, encoder = 'oneplane', data_directory = '/content/gdrive/My Drive/GO_RECORD'):
    if encoder == 'oneplane':
      self.encoder = OnePlaneEncoder(19) 
      
    self.data_dir = data_directory
    
  def  load_go_data(self, data_type = 'train', num_samples = 1000):
    #从下载数据中抽取出给定数量的训练数据
    sampler = Sampler(data_dir = self.data_dir)
    data = sampler.draw_data(data_type, num_samples)
    
    #将选中的文件数据进行解压,index 表示第几盘
    zip_names = set()
    indices_by_zip_name = {}
    for filename, index in data:
      zip_names.add(filename)
      if filename not in indices_by_zip_name:
        indices_by_zip_name[filename] = []
        
      indices_by_zip_name[filename].append(index)
      
    for zip_name in zip_names:
      #创建解压后的文件名
      base_name = zip_name.replace('.tar.gz', '')
      data_file_name = base_name + data_type
      print('process zip file with name: ', self.data_dir + '/' + data_file_name)
      print('is file check: ', os.path.isfile(self.data_dir + '/' + data_file_name))
      #if not os.path.isfile(self.data_dir + '/' + data_file_name):
      
      self.process_zip(zip_name, data_file_name, indices_by_zip_name[zip_name])
    
    #将数据描述的棋盘转为为上一节描述的编码和对应的落子
    features_and_labels = self.consolidate_games(data_type, data)
    return  features_and_labels
  
  def  process_zip(self, zip_file_name, data_file_name, game_list):
    #先把数据文件解压出来
    tar_file = self.unzip_data(zip_file_name)
    zip_file = tarfile.open(self.data_dir + '/' + tar_file)
    #获得.tar.gz解压后文件夹中所有文件的名字集合
    name_list = zip_file.getnames()
    #这里得到当前棋盘总共下了多少步棋,game_list对应的是第几盘比赛
    total_examples = self.num_total_examples(zip_file, game_list, name_list)
    #shape = [19. 19]
    shape = self.encoder.shape()
    #feature_shape 是数组,每个元素是[19,19]二维向量
    feature_shape = np.insert(shape, 0, np.asarray([total_examples]))
    features = np.zeros(feature_shape)
    #lables是一维数组,每个元素对应落子位置
    labels = np.zeros((total_examples, ))
    print('process_zip with features len: ', len(features))
    features_len = len(features)
    
    counter = 0
    for index in game_list:
      name = name_list[index + 1]
      if not name.endswith('.sgf'):
        raise ValueError(name + ' is not a valid sgf')
        
      sgf_content = zip_file.extractfile(name).read()
      sgf = Sgf_game.from_string(sgf_content)
      
      '''
      水平高的一方可能会让子,于是另一方可直接连续落子,我们先处理这种情况
      '''
      game_state, first_move_done = self.get_handicap(sgf)
      
      for item in sgf.main_sequence_iter():
        #依次将落子步骤读取出来
        color, move_tuple = item.get_move()
        point = None
        if color is not None:
          if move_tuple is not None:
            row, col = move_tuple
            point = Point(row + 1, col + 1)
            move = Move.play(point)
          else:
            move = Move.pass_turn()
          if  first_move_done and point is not None:
            #如果有让子,那么把对方落子后的棋盘当做训练数据,然后另一方落子方式当做训练标签
              
            features[counter] = self.encoder.encode(game_state)
            labels[counter] = self.encoder.encode_point(point)
            counter += 1
          
          #先按照落子步骤形成棋盘,下一次读取落子时它就会变成训练数据
          game_state = game_state.apply_move(move)
          first_move_done = True
      
    feature_file_base = self.data_dir + '/' + data_file_name + '_features_%d'
    label_file_base = self.data_dir + '/' + data_file_name + '_label_%d'
      
    #我们将加工好的数据存储成文件
    chunk = 0
    chunksize = 1024
    #每1024条记录当做一个chunk,每一个chunk单独存储
    while features.shape[0] >= chunksize:
      feature_file = feature_file_base % chunk
      label_file = label_file_base % chunk
      chunk += 1
      current_features, features = features[:chunksize], features[chunksize:]
      current_labels, labels = labels[:chunksize], labels[chunksize:]
      np.save(feature_file, current_features)
      np.save(label_file, current_labels)
        
  def  unzip_data(self, zip_file_name):
    #.tar.gz文件经过了两层压缩,首先解压gz压缩
    this_gz = gzip.open(self.data_dir + '/' + zip_file_name)
    #去掉尾部的.gz后缀
    tar_file = zip_file_name[0:-3]
    #创建.tar文件,将解压后gz压缩后的内容拷贝到该文件
    this_tar = open(self.data_dir + '/' + tar_file, 'wb')
    shutil.copyfileobj(this_gz, this_tar)
    return tar_file
    
  
  def num_total_examples(self, zip_file, game_list, name_list):
    '''
    #根据棋盘描述文件中的落子次数推算出训练数据的长度,每一次落子前的棋盘会成为训练数据,
    落子则对应训练标签,一旦落子后形成的棋盘就会成为新的训练数据
    '''
    total_examples = 0
    for index in game_list:
      name = name_list[index + 1]
      if name.endswith('.sgf'):
        #zip_file对应解压后的tar文件,其中包含很多.sgf文件,这里把指定的sgf文件内容读取出来
        sfg_content = zip_file.extractfile(name).read()
        sgf = Sgf_game.from_string(sfg_content)
        game_state, first_move_done = self.get_handicap(sgf)
        
        num_moves = 0
        for item in sgf.main_sequence_iter():
          color, move = item.get_move()
          if color is not None:
            if first_move_done:
              num_moves += 1
            first_move_done = True
            
        total_examples = total_examples + num_moves
      else:
        raise ValueError(name + ' is not a valid sgf')
    
    return total_examples
  
  @staticmethod
  def  get_handicap(sgf):
    #将让子时对应的落子摆到棋盘上
    go_board = Board(19, 19)
    first_move_done = False
    move = None
    game_state = GameState.new_game(19)
    if sgf.get_handicap() is not None and sgf.get_handicap() != 0:
      for setup in sgf.get_root().get_setup_stones():
        for move in setup:
          row, col = move
          go_board.place_stone(Player.black, Point(row + 1, col + 1))
        
      first_move_done = True
      game_state = GameState(go_board, Player.white, None, move)
      
    return game_state, first_move_done
  
  #前面我们把数据存储成多个小段,这里我们把多个小段读入内存合作一个整体
  def  consolidate_games(self, data_type, samples):
    files_needed = set(file_name for file_name , index in samples)
    file_names = []
    for zip_file_name in files_needed:
      file_name = zip_file_name.replace('.tar.gz', '') + data_type
      file_names.append(file_name)
    
    feature_list = []
    label_list = []
    for file_name in file_names:
      file_prefix = file_name.replace('.tar.gz', '')
      base = self.data_dir + '/' + file_prefix + '_features_*.npy'
      print('consolidate with file: ', base)
      for feature_file in glob.glob(base):
        label_file = feature_file.replace('features', 'labels')
        x = np.load(feature_file)
        y = np.load(label_file)
        x = x.astype('float32')
        y = to_categorical(y.astype(int), 19 * 19)
        feature_list.append(x)
        label_list.append(y)
    
    features = np.concatenate(feature_list, axis = 0)
    labels = np.concatenate(label_list, axis = 0)
    np.save('{}/features_{}.npy'.format(self.data_dir, data_type), features)
    np.save('{}/labels_{}.npy'.format(self.data_dir, data_type), labels)
    
    return features, labels

上面的代码将下载后的棋盘数据解压,然后读取sfg格式文件,并将它们编码转换成前面我们说过的棋盘编码,由此我们就可以获得用于训练网络的数据。但运行上面的代码将非常缓慢耗时,因此我们要使用多进程机制加载数据以便提升速度和效率,首先我们将创建一个DataGenerator,它将像水泵一样将数据抽取出来传递给网络:

class DataGenerator:
  '''
  创建一个数据抽取水泵,按照网络需要每次从数据池中抽取一小部分数据用于网络训练
  '''
  def  __init__(self, data_directory, samples):
    self.data_directory = data_directory
    #samples表示要抽取的数据量
    self.samples = samples
    self.files = set(file_name for file_name, index in samples)
  
  def  get_num_samples(self, batch_size = 128, num_classes = 19 * 19):
    '''
    为了加快数据读取速度,我们’按需‘抽取数据而不是一下子读取大量数据
    '''
    if  self.num_samples is not None:
      return  self.num_samples
    else:
      self.num_samples = 0
      for X, y in self._generate(batch_size = batch_size, num_classes = num_classes):
        self.num_samples += X.shape[0]
        
      return  self.num_samples
    
  def  _generate(self, batch_size, num_classes):
    for zip_file_name in self.files:
      file_name = zip_file_name.replace('.tar.gz', '') + 'train'
      base = self.data_director + '/' + file_name + '_features_*.npy'
      for feature_file in glob.glob(base):
        label_file = feature_file.replace('features', 'labels')
        x = np.load(feature_file)
        y = np.load(label_file)
        x = x.astype('float32')
        y = to_categorical(y.astype(int), num_classes)
        while x.shape[0] >= batch_size:
          x_batch, x = x[:batch_size], x[batch_size:]
          y_batch, y = y[:batch_size], y[batch_size:]
          yield x_batch, y_batch
          
  def  generate(self, batch_size = 128, num_classes = 19 * 19):
    while  True:
      for item in self._generate(batch_size, num_classes):
        yield  item

接下来我们改进GoDataProcessor,使用多线程去实现文件的解压,读取并编码成训练数据:

#将前面的GoDataProcessor改进为多线程版本
import tarfile
import gzip
import glob
import shutil
import os.path
import numpy as np

def  worker(jobinfo):
  #工作线程
  try:
    '''
    实例化GoDataProcessor,调用它的process_zip解压给定压缩文件,同时解析sgf文件,将它们转换
    为棋盘编码,这个过程可以使用多线程加速
    '''
    clazz, encoder, zip_file, data_file_name, game_list = jobinfo
    clazz(encoder=encoder).process_zip(zip_file, data_file_name, game_list)
  except (KeyboardInterrupt, SystemExit):
    raise  Exception('>>> Exiting child process.')
    

class GoDataProcessor:
  def  __init__(self, encoder = 'oneplane', data_directory = '/content/gdrive/My Drive/GO_RECORD'):
    if encoder == 'oneplane':
      self.encoder = OnePlaneEncoder(19) 
    
    self.encoder_string = encoder
    self.data_dir = data_directory
    
  def  load_go_data(self, data_type = 'train', num_samples = 1000,
                   use_generator = False):
    #从下载数据中抽取出给定数量的训练数据
    sampler = Sampler(data_dir = self.data_dir)
    data = sampler.draw_data(data_type, num_samples)
    
    #启动线程池
    self.map_to_workers(data_type, data)
    if use_generator:
      #将解析后的数据分批次喂给网络
      generator = DataGenerator(self.data_dir, data)
      return generator
    else:
      #按照老方式一下子将所有数据推给网络
      features_and_labels = self.consolidate_games(data_type, data)
      return  features_and_labels
    
    
  
  def  process_zip(self, zip_file_name, data_file_name, game_list):
    #先把数据文件解压出来
    tar_file = self.unzip_data(zip_file_name)
    zip_file = tarfile.open(self.data_dir + '/' + tar_file)
    #获得.tar.gz解压后文件夹中所有文件的名字集合
    name_list = zip_file.getnames()
    #这里得到当前棋盘总共下了多少步棋,game_list对应的是第几盘比赛
    total_examples = self.num_total_examples(zip_file, game_list, name_list)
    #shape = [19. 19]
    shape = self.encoder.shape()
    #feature_shape 是数组,每个元素是[19,19]二维向量
    feature_shape = np.insert(shape, 0, np.asarray([total_examples]))
    features = np.zeros(feature_shape)
    #lables是一维数组,每个元素对应落子位置
    labels = np.zeros((total_examples, ))
    print('process_zip with features len: ', len(features))
    features_len = len(features)
    
    counter = 0
    for index in game_list:
      name = name_list[index + 1]
      if not name.endswith('.sgf'):
        raise ValueError(name + ' is not a valid sgf')
        
      sgf_content = zip_file.extractfile(name).read()
      sgf = Sgf_game.from_string(sgf_content)
      
      '''
      水平高的一方可能会让子,于是另一方可直接连续落子,我们先处理这种情况
      '''
      game_state, first_move_done = self.get_handicap(sgf)
      
      for item in sgf.main_sequence_iter():
        #依次将落子步骤读取出来
        color, move_tuple = item.get_move()
        point = None
        if color is not None:
          if move_tuple is not None:
            row, col = move_tuple
            point = Point(row + 1, col + 1)
            move = Move.play(point)
          else:
            move = Move.pass_turn()
          if  first_move_done and point is not None:
            #如果有让子,那么把对方落子后的棋盘当做训练数据,然后另一方落子方式当做训练标签
              
            features[counter] = self.encoder.encode(game_state)
            labels[counter] = self.encoder.encode_point(point)
            counter += 1
          
          #先按照落子步骤形成棋盘,下一次读取落子时它就会变成训练数据
          game_state = game_state.apply_move(move)
          first_move_done = True
      
    feature_file_base = self.data_dir + '/' + data_file_name + '_features_%d'
    label_file_base = self.data_dir + '/' + data_file_name + '_label_%d'
      
    #我们将加工好的数据存储成文件
    chunk = 0
    chunksize = 1024
    #每1024条记录当做一个chunk,每一个chunk单独存储
    while features.shape[0] >= chunksize:
      feature_file = feature_file_base % chunk
      label_file = label_file_base % chunk
      chunk += 1
      current_features, features = features[:chunksize], features[chunksize:]
      current_labels, labels = labels[:chunksize], labels[chunksize:]
      np.save(feature_file, current_features)
      np.save(label_file, current_labels)
        
  def  unzip_data(self, zip_file_name):
    #.tar.gz文件经过了两层压缩,首先解压gz压缩
    this_gz = gzip.open(self.data_dir + '/' + zip_file_name)
    #去掉尾部的.gz后缀
    tar_file = zip_file_name[0:-3]
    #创建.tar文件,将解压后gz压缩后的内容拷贝到该文件
    this_tar = open(self.data_dir + '/' + tar_file, 'wb')
    shutil.copyfileobj(this_gz, this_tar)
    return tar_file
    
  
  def num_total_examples(self, zip_file, game_list, name_list):
    '''
    #根据棋盘描述文件中的落子次数推算出训练数据的长度,每一次落子前的棋盘会成为训练数据,
    落子则对应训练标签,一旦落子后形成的棋盘就会成为新的训练数据
    '''
    total_examples = 0
    for index in game_list:
      name = name_list[index + 1]
      if name.endswith('.sgf'):
        #zip_file对应解压后的tar文件,其中包含很多.sgf文件,这里把指定的sgf文件内容读取出来
        sfg_content = zip_file.extractfile(name).read()
        sgf = Sgf_game.from_string(sfg_content)
        game_state, first_move_done = self.get_handicap(sgf)
        
        num_moves = 0
        for item in sgf.main_sequence_iter():
          color, move = item.get_move()
          if color is not None:
            if first_move_done:
              num_moves += 1
            first_move_done = True
            
        total_examples = total_examples + num_moves
      else:
        raise ValueError(name + ' is not a valid sgf')
    
    return total_examples
  
  @staticmethod
  def  get_handicap(sgf):
    #将让子时对应的落子摆到棋盘上
    go_board = Board(19, 19)
    first_move_done = False
    move = None
    game_state = GameState.new_game(19)
    if sgf.get_handicap() is not None and sgf.get_handicap() != 0:
      for setup in sgf.get_root().get_setup_stones():
        for move in setup:
          row, col = move
          go_board.place_stone(Player.black, Point(row + 1, col + 1))
        
      first_move_done = True
      game_state = GameState(go_board, Player.white, None, move)
      
    return game_state, first_move_done
  
  #前面我们把数据存储成多个小段,这里我们把多个小段读入内存合作一个整体
  def  consolidate_games(self, data_type, samples):
    files_needed = set(file_name for file_name , index in samples)
    file_names = []
    for zip_file_name in files_needed:
      file_name = zip_file_name.replace('.tar.gz', '') + data_type
      file_names.append(file_name)
    
    feature_list = []
    label_list = []
    for file_name in file_names:
      file_prefix = file_name.replace('.tar.gz', '')
      base = self.data_dir + '/' + file_prefix + '_features_*.npy'
      print('consolidate with file: ', base)
      for feature_file in glob.glob(base):
        label_file = feature_file.replace('features', 'labels')
        x = np.load(feature_file)
        y = np.load(label_file)
        x = x.astype('float32')
        y = to_categorical(y.astype(int), 19 * 19)
        feature_list.append(x)
        label_list.append(y)
    
    features = np.concatenate(feature_list, axis = 0)
    labels = np.concatenate(label_list, axis = 0)
    np.save('{}/features_{}.npy'.format(self.data_dir, data_type), features)
    np.save('{}/labels_{}.npy'.format(self.data_dir, data_type), labels)
    
    return features, labels
  
  def  map_to_workers(self, data_type, samples):
    #将选中的文件数据进行解压,index 表示第几盘
    zip_names = set()
    indices_by_zip_name = {}
    for filename, index in samples:
      zip_names.add(filename)
      if filename not in indices_by_zip_name:
        indices_by_zip_name[filename] = []
        
      indices_by_zip_name[filename].append(index)
    
    zips_to_process = []
    for zip_name in zip_names:
      #创建解压后的文件名
      base_name = zip_name.replace('.tar.gz', '')
      data_file_name = base_name + data_type
      zips_to_process.append((self.__class__, self.encoder_string, zip_name,
                             data_file_name, indices_by_zip_name[zip_name]))
      
      cores = multiprocessing.cpu_count()
      pool = multiprocessing.Pool(processes = cores)
      p = pool.map_async(worker, zips_to_process)
      try:
        _ = p.get()
      except KeyboardInterrupt:
        pool.terminate()
        pool.join()
        sys.exit(-1)

上面代码跟以前代码差别不大,唯一差别在于使用多线程执行process_zip,也就是将文件的解压,读取,以及编码成训练数据的过程线程化,从而依赖多线程成倍提升效率。

本节代码比较繁琐,请参考视频加深理解。

更多技术信息,包括操作系统,编译器,面试算法,机器学习,人工智能,请关照我的公众号:


标签:name,zip,棋手,数据,self,num,file,data,预处理
From: https://blog.51cto.com/u_16160261/6476512

相关文章

  • NodeJS研究笔记:利用Buffer类的二进制数据读取接口解析ELF文件格式
    javascript作为前端开发语言,自古来对二进制数据的读取解析方面的支持都很薄弱,一般来说,解析二进制数据时,往往是将数据转换成字符串,然后运用各种字符串操作技巧来实现二进制数据的读取。由于NodeJS作为后台服务器开发平台,数理逻辑的设计需求超越javascript作为前端语言时界面UI的设......
  • go实现高并发高可用分布式系统:设计类似kafka的高并发海量数据存储机制2
    上一节我们完成了数据的存储和索引,本节我们看如何写入数据和进行查询。我们将创建一个Segment对象,它一方面接收发送来的请求,也就是Record数据结构,然后将数据写入到store和index,基本架构如下:在前面章节中,我们使用代码定义了Record的数据结构,现在我们需要使用protobuf来重新定义它,一......
  • 在FreeSWITCH中使用Lua脚本来将电话记录存储到MySQL数据库中
    在FreeSWITCH中使用Lua脚本来将电话记录存储到MySQL数据库中,需要做以下几个步骤:安装MySQL客户端库首先需要通过包管理器(如apt-get或yum)安装MySQL客户端库,以便FreeSWITCH能够与MySQL数据库进行通信。例如,在Ubuntu系统中,可以运行以下命令进行安装:sudoapt-getinstalllibmysq......
  • go实现高并发高可用分布式系统:设计类似kafka的高并发海量数据存储机制1
    上一节我们实现了日志微服务,它以http服务器的模式运行,客户端通过json方式将日志数据post过来,然后通过httpget的方式读取日志。当时我们的实现是将所有日志信息添加到数组末尾,这意味着所有日志信息都会保存在内存中。但分布式系统的日志数量将非常巨大,例如推特一天的日志数量就达到......
  • jexcel_删除行并同步数据库
    写在*.aspx中1//删除行OK2varmyDeleteRow=function(){3varDBID=document.getElementById("my_textbox").value;4//vartempConfirm=confirm("DBID为:"+DBID);//弹出确认框5vartempConfirm=confir......
  • jexcel_增加行并同步数据库
    写在*.aspx中1//增加行OK2varaddRow=function(){3varfieldName="type";//字段名4varmodifyValue="请输入";//值5//vartempConfirm=confirm("modifyValue:"+modifyValue+"......
  • MySQL 表信息查询,便于补数据库结构设计文档
    MySQL表信息查询,便于补数据库结构设计文档selectc.table_name表名,t.TABLE_COMMENT表说明,c.COLUMN_NAME列名,c.COLUMN......
  • 如何成功实施一个数据治理项目?实施步骤有哪些?
    企业数字化转型以数据为中心,通过数据驱动业务发展、管理协同和运营。因此,数字化转型关键在于数据,数据治理则需先行。从而更好激发数据生产要素潜能,实现业务数据化、数据价值化,助力企业数字化转型。那么何为数据治理?国际数据管理协会(DAMA)在其《DAMA数据管理知识体系指南(第2版)》一......
  • windows使用navicate 导出导入MongoDB数据
    1.下载安装navicate以及mongodb-database-tools-windowsmongodb-database-tools-windows下载地址 https://www.mongodb.com/try/download/database-tools 2.navicate设置MongoDBdump、mongorestore可执行文件路径(mongodb-database-tools里的bin目录)3.选择要备份或恢复......
  • 拼多多接口|api接口数据采集获取商品详情数据源代码Java演示
    ​拼多多提供了商品API,可以通过该API获取拼多多所有商品的详细信息,具体步骤如下: 申请开放平台接入。注册获取apikey和apisecret,调用API时需提供。调用拼多多API,获取商品详情。请求参数:参数说明通用参数说明version:API版本key:调用key,测试key:test_api_......