Commit 964db1c4 authored by gijs's avatar gijs

Some catalog restructuring and start of charrnn

parent 41c6b250
# -*- coding: utf-8 -*-
from flask import Flask, request
import glob
import os.path
import json
import re
import subprocess
app = Flask(__name__)
# checkpoint_dir = '/home/algolit/torch-rnn/cvs'
sample_bin = '/home/algolit/torch-rnn/sample.lua'
checkpoint_dir = './test-data'
def get_datasets():
return [os.path.basename(p) for p in glob.glob(os.path.join(checkpoint_dir, '*'))]
def get_checkpoints(dataset):
return [int(os.path.splitext(os.path.basename(p))[0]) for p in glob.glob(os.path.join(checkpoint_dir, dataset, '*.t7'))]
def get_checkpoint_path(dataset, checkpoint):
return os.path.join(checkpoint_dir, dataset, '{}.t7'.format(checkpoint))
@app.after_request
def add_headers(response):
response.headers.add('Access-Control-Allow-Origin', '*')
response.headers.add('Access-Control-Allow-Headers',
'Content-Type,Authorization')
return response
@app.route('/datasets', methods=["GET"])
def datasets():
datasets = get_datasets()
datasets.sort()
return json.dumps(datasets)
@app.route('/checkpoints/<dataset>', methods=["GET"])
def checkpoints(dataset):
dataset = re.sub(r'[^\w-]', '', dataset)
checkpoints = get_checkpoints(dataset)
checkpoints.sort()
return json.dumps(checkpoints)
@app.route('/generate', methods=["POST"])
def generate():
dataset = re.sub(r'[^\w-]', '', request.form['dataset'])
checkpoint = int(request.form['checkpoint'])
length = int(request.form['length'])
checkpoint_file = get_checkpoint_path(dataset, checkpoint)
args = ['th', sample_bin, '-checkpoint', checkpoint, '-length', length]
#generated_text = subprocess.check_output(args)
return json.dumps({'text': json.dumps(args)})
return json.dumps({'text': generated_text})
if __name__ == '__main__':
app.run(host="localhost", port=5556, debug=True)
# torch-rnn/checkpoint-*
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment