-- client.lua
local client = {}
local safer = require("safer")
local requests = require("requests")
local dkjson = require("dkjson")
local copas = require("copas")
requests.http_socket = require("copas.http")
requests.https_socket = require("copas.http")
--- Sleep for `sleeptime` seconds.
-- When called from inside a coroutine (other than the main one) this uses
-- `copas.sleep`; from the main thread it falls back to `socket.sleep`.
-- @param sleeptime seconds to sleep, passed through unchanged
local function util_sleep(sleeptime)
   local current, is_main = coroutine.running()
   local in_coroutine = current ~= nil and not is_main
   if in_coroutine then
      copas.sleep(sleeptime)
      return
   end
   require("socket").sleep(sleeptime)
end
--- Print a `requests` response to stdout for debugging.
-- Prints the status code and URL, then every field of the response table,
-- then every header, and finally either the fields of the value returned
-- by `response.json()` or, if that value is falsy, the error it returned.
-- @param response response table returned by the `requests` library
local function debug_response(response)
   print(response.status_code)
   print(response.url)
   local function dump(tbl)
      for key, value in pairs(tbl) do
         print(key, value)
      end
   end
   dump(response)
   dump(response.headers)
   local decoded, decode_err = response.json()
   if not decoded then
      print(decode_err)
      return
   end
   dump(decoded)
end
--- Build a protected HTTP caller for one `requests` verb.
-- The returned function passes its single argument (a URL string or a
-- request table) to `requests[verb]` under `pcall`. It returns the response
-- on success, or `nil, err` if the call raised.
-- @tparam string verb name of the `requests` function, e.g. "get" or "post"
-- @treturn function `(req) -> response` or `(req) -> nil, err`
local function do_request(verb)
   return function(req)
      local succeeded, result = pcall(requests[verb], req)
      if not succeeded then
         return nil, result
      end
      return result
   end
end
local do_get, do_post = do_request("get"), do_request("post")
local Client = safer.readonly {
   --- Register this SGA with the CSBase server.
   -- Validates the arguments, then POSTs the SGA description to
   -- `<config.csbase_server>/v1/sga`. If the request fails at the transport
   -- level, it logs the error, sleeps `config.register_retry_s` seconds and
   -- retries; this repeats until a response is received. On a 201 or 302
   -- reply it stores the `actions` table sent back by the server, marks the
   -- client as registered and sends an initial status report.
   -- @param sga_type value sent as the SGA `type`
   -- @tparam table nodes table keyed by node name; each key must be a string
   --   made only of letters, digits, `_` or `.`
   -- @tparam table persisted_jobs table with `lost` and `retrieved` tables
   -- @return `true` on success, or `nil, err`
   register = function(self, sga_type, nodes, persisted_jobs)
      assert(type(nodes) == "table")
      assert(type(persisted_jobs) == "table")
      assert(type(persisted_jobs.lost) == "table")
      assert(type(persisted_jobs.retrieved) == "table")
      local sga_name = self.config.sga_name
      for node_name, _ in pairs(nodes) do
         -- Keys may be non-strings (e.g. an array of nodes); calling :match
         -- on a number would raise instead of reporting a bad node name.
         if type(node_name) ~= "string"
            or not node_name:match("^[a-zA-Z0-9_.]+$") then
            return nil, "Invalid node name "..tostring(node_name)
         end
      end
      local response, err
      repeat
         response, err = do_post({
            url = self.config.csbase_server.."/v1/sga",
            headers = {["Content-Type"] = "application/json"},
            data = dkjson.encode({
               name = sga_name,
               type = sga_type,
               platform = self.config.platform,
               project_root_dir = self.config.project_root_dir,
               algorithm_root_dir = self.config.algorithm_root_dir,
               sandbox_root_dir = self.config.sandbox_root_dir,
               nodes = nodes,
               actions = {
                  path = self.base_uri.."/v1/sga/"..sga_name.."/path",
                  job = self.base_uri.."/v1/sga/"..sga_name.."/job",
                  shutdown = self.base_uri.."/v1/sga/"..sga_name.."/shutdown",
               },
               persistent_data = {
                  lost = persisted_jobs.lost,
                  retrieved = persisted_jobs.retrieved,
               },
               resources = self.config.resources,
               extra_config = self.config.extra_config,
            }),
         })
         if not response then
            -- Transport-level failure: keep retrying until the server answers.
            self.logger:error(err)
            self.logger:info("Will retry...")
            util_sleep(self.config.register_retry_s)
         end
      until response
      -- debug_response(response)
      if response.status_code ~= 201 and response.status_code ~= 302 then
         return nil, "Received status code "..response.status_code.." "
      end
      -- heartbeat() and status() rely on the `actions` table from this reply.
      -- Reject a body that does not decode to a table with an `actions`
      -- table, rather than failing later with an index error.
      local decoded, data = pcall(response.json)
      if not decoded or type(data) ~= "table" or type(data.actions) ~= "table" then
         return nil, "Invalid registration response from "..self.config.csbase_server
      end
      self.actions = data.actions
      self.registered = true
      self:status()
      self.logger:info("sga "..self.config.sga_name.." registered to "..self.config.csbase_server)
      return true
   end,

   --- Ping the server's `heartbeat` action.
   -- @return `true` if the server answered 200, `false` for any other
   --   status, or `nil, err` if not registered, if the server did not
   --   provide a heartbeat action, or on transport failure
   heartbeat = function(self)
      if not self.registered then
         return nil, "not registered"
      end
      local action = self.actions and self.actions["heartbeat"]
      if not action then
         return nil, "heartbeat action not provided by server"
      end
      local response, err = do_get(self.config.csbase_server..action.uri)
      if not response then
         return nil, err
      end
      return response.status_code == 200
   end,

   --- POST the nodes' status to the server's `status` action.
   -- @param nodes_status value sent as `nodes` in the JSON body (nil when
   --   called with no argument, as register() does)
   -- @return `true` if the server answered 200, `false` for any other
   --   status, or `nil, err` if the server did not provide a status action
   --   or on transport failure
   status = function(self, nodes_status)
      local action = self.actions and self.actions["status"]
      if not action then
         return nil, "status action not provided by server"
      end
      local response, err = do_post({
         url = self.config.csbase_server..action.uri,
         headers = {["Content-Type"] = "application/json"},
         data = dkjson.encode({
            nodes = nodes_status,
         }),
      })
      if not response then
         return nil, err
      end
      return response.status_code == 200
   end,

   completion = function(self, cmd_id, walltime_s, usertime_s, systime_s)