Client

Client provides a higher-level interface to the Training API.

class abeja.train.Client(organization_id: str = None, job_definition_name: str = None, training_job_id: str = None, credential: typing.Dict[str, str] = None, timeout: typing.Union[int, NoneType] = None, max_retry_count: typing.Union[int, NoneType] = None) → None

A high-level client for the Training API.

from abeja.train import Client

client = Client(organization_id='1234567890123')
Params:
  • organization_id (str): The organization ID. Taken from os.environ['ABEJA_ORGANIZATION_ID'] if omitted.
  • job_definition_name (str): The job definition name. Taken from os.environ['TRAINING_JOB_DEFINITION_NAME'] if omitted.
  • training_job_id (str): The training job ID. Taken from os.environ['TRAINING_JOB_ID'] if omitted.
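The environment-variable fallbacks listed above can be illustrated with a small sketch. `resolve_client_params` is a hypothetical helper, not part of the SDK; it only mirrors the documented rule that an omitted parameter is read from the named environment variable, and it treats `None` as "omitted" because that is the signature's default.

```python
import os
from typing import Dict, Mapping, Optional


def resolve_client_params(
    organization_id: Optional[str] = None,
    job_definition_name: Optional[str] = None,
    training_job_id: Optional[str] = None,
    environ: Mapping[str, str] = os.environ,
) -> Dict[str, Optional[str]]:
    """Hypothetical helper mirroring the documented fallbacks (not SDK code)."""

    def pick(value: Optional[str], env_key: str) -> Optional[str]:
        # An explicit argument wins; None means "omitted", so read the variable.
        return value if value is not None else environ.get(env_key)

    return {
        "organization_id": pick(organization_id, "ABEJA_ORGANIZATION_ID"),
        "job_definition_name": pick(job_definition_name, "TRAINING_JOB_DEFINITION_NAME"),
        "training_job_id": pick(training_job_id, "TRAINING_JOB_ID"),
    }


# A plain dict stands in for os.environ so the example does not touch the real environment.
env = {"ABEJA_ORGANIZATION_ID": "1234567890123", "TRAINING_JOB_ID": "1600000000000"}
params = resolve_client_params(job_definition_name="my-definition", environ=env)
print(params)
```

Passing the mapping explicitly keeps the sketch testable; in real use the default `os.environ` would be consulted.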
download_training_result(training_job_id: str = None, path: str = None) → None

Download the training artifact, which contains the model files or logs. The artifact is a ZIP archive, and this method extracts its contents after downloading it.

API reference: GET /organizations/<organization_id>/training/definitions/<job_definition_name>/jobs/<training_job_id>/result

Request Syntax:
training_job_id = '1600000000000'
client.download_training_result(training_job_id)
Params:
  • training_job_id (str): ID of the training job whose result is downloaded [optional]
  • path (str): Target path into which the result is extracted [optional]
Raises:
  • abeja.exceptions.BadRequest
  • abeja.exceptions.Unauthorized
  • abeja.exceptions.Forbidden
  • abeja.exceptions.NotFound
  • abeja.exceptions.MethodNotAllowed
  • abeja.exceptions.InternalServerError
  • json.JSONDecodeError
  • PermissionError
  • ValueError
  • IOError
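Because the artifact is a ZIP archive, the extraction step described above amounts to unpacking it into the target path. The sketch below uses only the standard `zipfile` module; `extract_training_artifact` is a hypothetical helper, and the archive members (`model.h5`, `logs/training.log`) are made up for illustration rather than taken from a real training result.

```python
import io
import os
import tempfile
import zipfile
from typing import List


def extract_training_artifact(archive_bytes: bytes, path: str) -> List[str]:
    """Hypothetical helper: unpack a ZIP payload into `path` and return member names."""
    os.makedirs(path, exist_ok=True)
    with zipfile.ZipFile(io.BytesIO(archive_bytes)) as archive:
        archive.extractall(path)
        return archive.namelist()


# Build a small in-memory ZIP that stands in for a downloaded training artifact.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("model.h5", b"fake model weights")
    archive.writestr("logs/training.log", "epoch 1: loss=0.10\n")

target = tempfile.mkdtemp()
members = extract_training_artifact(buffer.getvalue(), target)
print(sorted(members))  # ['logs/training.log', 'model.h5']
```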
update_statistics(statistics: abeja.training.statistics.Statistics) → None

Report training job statistics to the ABEJA Platform.

API reference: POST /organizations/<organization_id>/training/definitions/<job_definition_name>/jobs/<training_job_id>/statistics

Request Syntax:
from abeja.train import Client
from abeja.train.statistics import Statistics as ABEJAStatistics

client = Client()

statistics = ABEJAStatistics(num_epochs=10, epoch=1)
statistics.add_stage(name=ABEJAStatistics.STAGE_TRAIN, accuracy=0.90, loss=0.10)
statistics.add_stage(name=ABEJAStatistics.STAGE_VALIDATION, accuracy=0.75, loss=0.07)

client.update_statistics(statistics)
Params:
  • statistics (Statistics): The job statistics to report.
Returns:
None
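To show roughly what the documented POST endpoint involves, the following sketch assembles (but does not send) a request with `urllib.request`. The base URL `https://api.example.com`, the helper name `build_statistics_request`, and the JSON body shape are all assumptions for illustration; authentication is omitted, and the SDK's real transport may differ.

```python
import json
import urllib.parse
import urllib.request
from typing import Any, Dict

# Hypothetical base URL: this document does not state the real API host.
BASE_URL = "https://api.example.com"


def build_statistics_request(
    organization_id: str,
    job_definition_name: str,
    training_job_id: str,
    statistics: Dict[str, Any],
) -> urllib.request.Request:
    """Hypothetical helper: prepare (but do not send) a POST to the statistics path."""
    # Percent-encode each path segment so names with '/' or spaces stay in one segment.
    org, name, job = (
        urllib.parse.quote(part, safe="")
        for part in (organization_id, job_definition_name, training_job_id)
    )
    path = f"/organizations/{org}/training/definitions/{name}/jobs/{job}/statistics"
    body = json.dumps(statistics).encode("utf-8")  # assumed JSON body shape
    return urllib.request.Request(
        BASE_URL + path,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_statistics_request(
    "1234567890123", "my-definition", "1600000000000",
    {"num_epochs": 10, "epoch": 1, "progress_percentage": 90},
)
print(request.get_method(), request.full_url)
```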

Statistics

class abeja.train.statistics.Statistics(num_epochs: int = None, epoch: int = None, progress_percentage: float = None, **kwargs) → None
STAGE_TRAIN = 'train'
STAGE_VALIDATION = 'validation'
add_stage(name: str, accuracy: float = None, loss: float = None, **kwargs) → None

Add information about a stage (for example, training or validation).

Params:
  • name (str): Name of the stage. STAGE_TRAIN and STAGE_VALIDATION are provided as constants, but any string can be used.
  • accuracy (float): Accuracy rate of the stage; the value must be between 0 and 1.
  • loss (float): Loss rate of the stage; the value must be between 0 and 1.
Returns:
None
Raises:
  • ValueError
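The range rule for `accuracy` and `loss` can be illustrated with a stripped-down stand-in. `MiniStatistics` is hypothetical, not the SDK class: it assumes the `ValueError` listed above is raised for values outside [0, 1], and it keeps stages in a plain dict because this document does not show how the real class stores them.

```python
from typing import Any, Dict, Optional


class MiniStatistics:
    """Hypothetical stand-in for abeja.train.statistics.Statistics (illustration only)."""

    STAGE_TRAIN = "train"
    STAGE_VALIDATION = "validation"

    def __init__(self, num_epochs: Optional[int] = None, epoch: Optional[int] = None) -> None:
        self.num_epochs = num_epochs
        self.epoch = epoch
        self.stages: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _check_unit_range(field: str, value: Optional[float]) -> None:
        # None means "not reported"; any reported value must lie in [0, 1].
        if value is not None and not 0.0 <= value <= 1.0:
            raise ValueError(f"{field} must be between 0 and 1, got {value!r}")

    def add_stage(self, name: str, accuracy: Optional[float] = None,
                  loss: Optional[float] = None, **kwargs: Any) -> None:
        self._check_unit_range("accuracy", accuracy)
        self._check_unit_range("loss", loss)
        self.stages[name] = {"accuracy": accuracy, "loss": loss, **kwargs}


stats = MiniStatistics(num_epochs=10, epoch=1)
stats.add_stage(MiniStatistics.STAGE_TRAIN, accuracy=0.90, loss=0.10)
try:
    stats.add_stage(MiniStatistics.STAGE_VALIDATION, accuracy=75.0)  # a percentage, not a ratio
except ValueError as exc:
    print(exc)  # accuracy must be between 0 and 1, got 75.0
```

Validating before storing means a rejected stage leaves the object unchanged, so the validation stage above is never recorded.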
classmethod from_response(response: typing.Union[typing.Dict[str, typing.Any], NoneType]) → typing.Union[Statistics, NoneType]
get_statistics() → dict

Get the statistics, including stage information, as a dict.

Return type:
dict
Returns:

Response Syntax:

{
    'num_epochs': 10,
    'epoch': 1,
    'progress_percentage': 90,
}
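Since `get_statistics()` returns a plain dict, downstream code can read it with ordinary dictionary access. The sketch below is hypothetical (`format_progress` is not part of the SDK). It uses `.get` because the constructor defaults `num_epochs`, `epoch`, and `progress_percentage` to `None`, and this document does not say whether unset fields appear in the returned dict, so the helper tolerates both.

```python
from typing import Any, Dict


def format_progress(statistics: Dict[str, Any]) -> str:
    """Hypothetical helper: render a dict shaped like the Response Syntax above."""
    epoch = statistics.get("epoch")
    num_epochs = statistics.get("num_epochs")
    percentage = statistics.get("progress_percentage")
    parts = []
    if epoch is not None and num_epochs is not None:
        parts.append(f"epoch {epoch}/{num_epochs}")
    if percentage is not None:
        parts.append(f"{percentage}% complete")
    return ", ".join(parts) if parts else "no progress reported"


# Sample value copied from the Response Syntax above (not real SDK output).
sample = {"num_epochs": 10, "epoch": 1, "progress_percentage": 90}
print(format_progress(sample))  # epoch 1/10, 90% complete
```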