본문 바로가기
PROGRAMMING/TensorFlow

TensorFlow 모델 학습을 위한 CSV 파일 읽기 예제 코드

by BLADEBONE 2017. 6. 15.

TensorFlow에서 모델 학습 시 데이터를 CSV 파일로부터 읽어오기 위한 예제 코드로 numpy를 이용한 방법과 tensorflow를 이용한 방법으로 정리, 이 때 CSV 파일은 2개가 있는 상황을 가정


0. 데이터 파일: Iris 데이터


data-iris-1.csv data-iris-2.csv



1. numpy 사용


import numpy as np

# first CSV file read
data = np.loadtxt("data-iris-1.csv", delimiter=",", dtype=np.float32) # 75x5 matrix
# slicing data into x and y
x = data[:,0:-1] # from 1st to (n-1)th column, when data has n columns
y = data[:,[-1]] # nth column, when data han n columns
# print variable type
print(type(x))
# print variable shape
print("x is", x.shape, "and y is", y.shape)

# second CSV file read
data = np.loadtxt("data-iris-2.csv", delimiter=",", dtype=np.float32) # 75x5 matrix
# slicing data and appending x and y
x = np.append(x, data[:,0:-1], axis=0)
y = np.append(y, data[:,[-1]], axis=0)
print("x is", x.shape, "and y is", y.shape)

ReadCSV_numpy.py


numpy를 사용하는 경우, CSV 파일이 다수가 있다면 읽어 오는 과정을 반복적으로 수행하여야 한다는 번거로움이 있으며, 이 경우에는 tensorflow 방식으로 사용하는 것이 더 편리함.




2. TensorFlow 사용


import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf

# set CSV file list
filename_queue = tf.train.string_input_producer(
["data-iris-1.csv", "data-iris-2.csv"], shuffle=False, name='filename_queue')

# set tensorflow reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# set record_defaults corresponding to data form
record_defaults = [[0.]]*5 # record_defaults = [[0.], [0.], [0.], [0.], [0.]]
data = tf.decode_csv(value, record_defaults=record_defaults)

# set collecting data and batch option
train_x_batch, train_y_batch = tf.train.batch([data[0:-1], data[-1:]], batch_size=4)

sess = tf.Session()

# start (mandatory)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

for step in range(2):
x_batch, y_batch = sess.run([train_x_batch, train_y_batch])
print(x_batch.shape, y_batch.shape) # print shape of each variable
print(x_batch, y_batch) # print data of each variable

# end (mandatory)
coord.request_stop()
coord.join(threads)

ReadCSV_tensorflow.py


1) CSV 파일이 다수인 경우, filename_queue에 있는 list에 이를 입력 하면 됨.

2) 첫 2행은 python에서 TensorFlow를 import 할 때 나오는 로그를 보지 않기 위한 코드로 삭제해도 됨.



참고) 김성훈 교수, 모두를 위한 머신러닝/딥러닝 강의




반응형

댓글