TensorFlow에서 모델 학습 시 데이터를 CSV 파일로부터 읽어오기 위한 예제 코드로 numpy를 이용한 방법과 tensorflow를 이용한 방법으로 정리, 이 때 CSV 파일은 2개가 있는 상황을 가정
0. 데이터 파일: Iris 데이터
data-iris-1.csv data-iris-2.csv
1. numpy 사용
import numpy as np
# first CSV file read
data = np.loadtxt("data-iris-1.csv", delimiter=",", dtype=np.float32) # 75x5 matrix
# slicing data into x and y
x = data[:,0:-1] # from 1st to (n-1)th column, when data has n columns
y = data[:,[-1]] # nth column, when data han n columns
# print variable type
print(type(x))
# print variable shape
print("x is", x.shape, "and y is", y.shape)
# second CSV file read
data = np.loadtxt("data-iris-2.csv", delimiter=",", dtype=np.float32) # 75x5 matrix
# slicing data and appending x and y
x = np.append(x, data[:,0:-1], axis=0)
y = np.append(y, data[:,[-1]], axis=0)
print("x is", x.shape, "and y is", y.shape)
numpy를 사용하는 경우, CSV 파일이 다수가 있다면 읽어 오는 과정을 반복적으로 수행하여야 한다는 번거로움이 있으며, 이 경우에는 tensorflow 방식으로 사용하는 것이 더 편리함.
2. TensorFlow 사용
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
# set CSV file list
filename_queue = tf.train.string_input_producer(
["data-iris-1.csv", "data-iris-2.csv"], shuffle=False, name='filename_queue')
# set tensorflow reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# set record_defaults corresponding to data form
record_defaults = [[0.]]*5 # record_defaults = [[0.], [0.], [0.], [0.], [0.]]
data = tf.decode_csv(value, record_defaults=record_defaults)
# set collecting data and batch option
train_x_batch, train_y_batch = tf.train.batch([data[0:-1], data[-1:]], batch_size=4)
sess = tf.Session()
# start (mandatory)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for step in range(2):
x_batch, y_batch = sess.run([train_x_batch, train_y_batch])
print(x_batch.shape, y_batch.shape) # print shape of each variable
print(x_batch, y_batch) # print data of each variable
# end (mandatory)
coord.request_stop()
coord.join(threads)
1) CSV 파일이 다수인 경우, filename_queue에 있는 list에 이를 입력 하면 됨.
2) 첫 2행은 python에서 TensorFlow를 import 할 때 나오는 로그를 보지 않기 위한 코드로 삭제해도 됨.
참고) 김성훈 교수, 모두를 위한 머신러닝/딥러닝 강의
반응형
'PROGRAMMING > TensorFlow' 카테고리의 다른 글
PyCharm에서 TensorFlow를 위한 새 프로젝트 생성 (2) | 2017.03.29 |
---|---|
윈도우10, Anaconda, TensorFlow, PyCharm 환경설정 (29) | 2017.03.08 |
댓글