Client represents a higher-level interface to the Training API.
abeja.train.Client(organization_id: Optional[str] = None, job_definition_name: Optional[str] = None, training_job_id: Optional[str] = None, credential: Optional[Dict[str, str]] = None, timeout: Optional[int] = None, max_retry_count: Optional[int] = None)
A high-level client for the Training API.
from abeja.train import Client
client = Client(organization_id='1234567890123')
Params:
organization_id (str): The organization ID. Taken from os.environ['ABEJA_ORGANIZATION_ID'] if omitted.
job_definition_name (str): The job definition name. Taken from os.environ['TRAINING_JOB_DEFINITION_NAME'] if omitted.
training_job_id (str): The training job ID. Taken from os.environ['TRAINING_JOB_ID'] if omitted.
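The environment-variable fallback described above follows a common pattern: an explicitly passed argument wins, and otherwise the variable is read from the environment. The sketch below illustrates only that pattern. `resolve_option` is a hypothetical helper, not part of the SDK, and the real `Client` may behave differently when a value is missing (for example, by raising an error instead of returning `None`).

```python
import os
from typing import Mapping, Optional


def resolve_option(explicit: Optional[str], env_var: str,
                   environ: Mapping[str, str] = os.environ) -> Optional[str]:
    """Hypothetical helper: prefer the explicit value, else read env_var (or None)."""
    # An explicitly passed argument takes precedence over the environment.
    if explicit is not None:
        return explicit
    # Fall back to the environment; .get() returns None when the variable is unset.
    return environ.get(env_var)


# Illustrative environment; the IDs below are made up.
env = {'ABEJA_ORGANIZATION_ID': '1234567890123', 'TRAINING_JOB_ID': '1600000000000'}
organization_id = resolve_option(None, 'ABEJA_ORGANIZATION_ID', env)      # from env
training_job_id = resolve_option('1700000000000', 'TRAINING_JOB_ID', env)  # explicit wins
job_definition_name = resolve_option(None, 'TRAINING_JOB_DEFINITION_NAME', env)  # unset
```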
download_training_result(training_job_id: Optional[str] = None, path: Optional[str] = None) → None
Download the training artifact, which contains the model file and/or logs. The training artifact itself is a zip file; this function extracts it.
API reference: GET /organizations/<organization_id>/training/definitions/<job_definition_name>/jobs/<training_job_id>/result
training_job_id = '1600000000000'
client.download_training_result(training_job_id)
Params:
path (str): target path where the result is extracted [optional]
Raises:
abeja.exceptions.BadRequest
abeja.exceptions.Unauthorized
abeja.exceptions.Forbidden
abeja.exceptions.NotFound
abeja.exceptions.MethodNotAllowed
abeja.exceptions.InternalServerError
json.JSONDecodeError
PermissionError
ValueError
IOError
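The method is documented to download a zip archive and extract it. The standard-library sketch below covers only the extraction step and is not necessarily how the SDK implements it. `extract_training_artifact` is hypothetical, and its default of the current directory for `path` is an assumption of this sketch, not the SDK's documented default.

```python
import io
import os
import zipfile
from typing import List


def extract_training_artifact(zip_bytes: bytes, path: str = '.') -> List[str]:
    """Hypothetical: extract a zipped artifact into `path`; return the member names."""
    # Create the target directory if it does not exist yet.
    os.makedirs(path, exist_ok=True)
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as archive:
        # extractall() keeps the archive's directory layout under `path`. Per the
        # zipfile docs it strips absolute-path prefixes and '..' components, but
        # archives from untrusted sources should still be inspected first.
        archive.extractall(path)
        return archive.namelist()
```

Taking the archive as bytes keeps the sketch independent of how the file is downloaded.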
update_statistics(statistics: abeja.training.statistics.Statistics) → None
Notify ABEJA Platform of the job statistics.
API reference: POST /organizations/<organization_id>/training/definitions/<job_definition_name>/jobs/<training_job_id>/statistics
from abeja.train import Client
from abeja.train.statistics import Statistics as ABEJAStatistics
client = Client()
statistics = ABEJAStatistics(num_epochs=10, epoch=1)
# accuracy and loss must be between 0 and 1
statistics.add_stage(name=ABEJAStatistics.STAGE_TRAIN, accuracy=0.90, loss=0.10)
statistics.add_stage(name=ABEJAStatistics.STAGE_VALIDATION, accuracy=0.75, loss=0.07)
client.update_statistics(statistics)
Params:
statistics (abeja.train.statistics.Statistics): job statistics to notify
Return type: None
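A natural place to call `update_statistics` is at the end of each epoch. The sketch below is minimal and assumes your training loop yields per-epoch metrics as dicts with the hypothetical keys `train_acc`, `train_loss`, `val_acc` and `val_loss`. `report_epochs` is not part of the SDK. The client and the statistics class are passed in as parameters, so the loop can run against stand-ins. With the SDK you would pass `Client()` and the `Statistics` class imported in the example above. The loop uses only the calls documented on this page.

```python
from typing import Iterable, Mapping


def report_epochs(client, statistics_cls, history: Iterable[Mapping[str, float]],
                  num_epochs: int) -> int:
    """Hypothetical: send one statistics update per epoch; return the count sent."""
    sent = 0
    # Epochs are numbered from 1, matching the epoch=1 used in the example above.
    for epoch, metrics in enumerate(history, start=1):
        statistics = statistics_cls(num_epochs=num_epochs, epoch=epoch)
        statistics.add_stage(name=statistics_cls.STAGE_TRAIN,
                             accuracy=metrics['train_acc'], loss=metrics['train_loss'])
        statistics.add_stage(name=statistics_cls.STAGE_VALIDATION,
                             accuracy=metrics['val_acc'], loss=metrics['val_loss'])
        client.update_statistics(statistics)
        sent += 1
    return sent
```

Injecting the dependencies this way keeps the reporting logic separate from network access, which is convenient when testing the training script offline.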
abeja.train.statistics.Statistics(num_epochs: Optional[int] = None, epoch: Optional[int] = None, progress_percentage: Optional[float] = None, **kwargs)
STAGE_TRAIN = 'train'
STAGE_VALIDATION = 'validation'
add_stage(name: str, accuracy: Optional[float] = None, loss: Optional[float] = None, **kwargs) → None
Add stage information.
Params:
name (str): name of the stage. STAGE_TRAIN and STAGE_VALIDATION are provided as constants, but any string can be used.
accuracy (float): accuracy rate; the value must be between 0 and 1.
loss (float): loss rate; the value must be between 0 and 1.
Return type: None
Raises:
ValueError
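Accuracy and loss are documented to lie between 0 and 1. Metrics that a framework reports as percentages (for example 90.0) therefore need scaling before they are passed to `add_stage`. The sketch below uses a small hypothetical helper, `as_rate`, that converts the value and applies its own range check. It mirrors the documented constraint and is not the SDK's validation code.

```python
from typing import Optional


def as_rate(value: Optional[float], *, percent: bool = False) -> Optional[float]:
    """Hypothetical: return `value` as a rate in [0, 1], scaling from percent if asked."""
    if value is None:
        return None  # accuracy and loss are optional in add_stage
    rate = value / 100.0 if percent else float(value)
    # Reject values outside the documented range before they reach add_stage.
    if not 0.0 <= rate <= 1.0:
        raise ValueError(f'rate must be between 0 and 1, got {rate}')
    return rate
```

For example, `statistics.add_stage(name=ABEJAStatistics.STAGE_TRAIN, accuracy=as_rate(90.0, percent=True), loss=as_rate(0.10))` passes 0.9 and 0.1.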
from_response(response: Optional[Dict[str, Any]]) → Optional[abeja.training.statistics.Statistics]
get_statistics() → dict
Get stage information.
Return type: dict
Response Syntax:
{
'num_epochs': 10,
'epoch': 1,
'progress_percentage': 90,
}
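The dict returned by `get_statistics` can be turned into a human-readable progress line for logging. The sketch below uses only the keys shown in the response syntax above. `describe_progress` is hypothetical, and it assumes that keys may be absent or `None` when they were never set, which this page does not confirm.

```python
from typing import Any, Dict


def describe_progress(stats: Dict[str, Any]) -> str:
    """Hypothetical: build a one-line summary from a get_statistics()-style dict."""
    epoch = stats.get('epoch')
    num_epochs = stats.get('num_epochs')
    percentage = stats.get('progress_percentage')
    parts = []
    # Only report fields that are present; missing keys are treated as unset.
    if epoch is not None and num_epochs is not None:
        parts.append(f'epoch {epoch}/{num_epochs}')
    if percentage is not None:
        parts.append(f'{percentage}% complete')
    return ', '.join(parts) if parts else 'no progress reported'
```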