Client provides a higher-level interface to the Training API.
abeja.train.Client(organization_id: str = None, job_definition_name: str = None, training_job_id: str = None, credential: typing.Dict[str, str] = None, timeout: typing.Union[int, NoneType] = None, max_retry_count: typing.Union[int, NoneType] = None) → None
A high-level client for the Training API.
from abeja.train import Client
client = Client(organization_id='1234567890123')
download_training_result(training_job_id: str = None, path: str = None) → None
Download the training artifact, which contains the model file or logs. The artifact is a zip file; this method downloads it and extracts its contents.
API reference: GET /organizations/<organization_id>/training/definitions/<job_definition_name>/jobs/<training_job_id>/result
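The endpoint is scoped by three identifiers, which presumably correspond to the constructor arguments of the same names. As a hedged sketch that assumes only the path template shown above, the path can be assembled as follows; `result_path` is a hypothetical helper, not part of the SDK.

```python
from urllib.parse import quote


def result_path(organization_id: str, job_definition_name: str, training_job_id: str) -> str:
    """Build the result path from the documented template (hypothetical helper)."""
    # Percent-encode each segment so a name containing '/' or spaces stays one segment.
    parts = [quote(p, safe="") for p in (organization_id, job_definition_name, training_job_id)]
    return "/organizations/{}/training/definitions/{}/jobs/{}/result".format(*parts)


print(result_path("1234567890123", "my-job-definition", "1600000000000"))
```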
training_job_id = '1600000000000'
client.download_training_result(training_job_id)
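Because the artifact is a zip archive, the extraction step amounts to unpacking bytes into a directory. The sketch below uses only the standard library and assumes the archive is already in memory; `extract_artifact` is a hypothetical name, not the SDK's implementation, and it skips any entry that would resolve outside the target directory.

```python
import io
import os
import zipfile


def extract_artifact(data: bytes, dest: str) -> list:
    """Extract an in-memory zip archive into dest and return the extracted names.

    Hypothetical helper for illustration; not part of the SDK.
    """
    root = os.path.realpath(dest)
    extracted = []
    with zipfile.ZipFile(io.BytesIO(data)) as archive:
        for member in archive.namelist():
            target = os.path.realpath(os.path.join(root, member))
            # Skip entries such as '../x' that would land outside dest.
            if os.path.commonpath([root, target]) != root:
                continue
            # extract() creates any missing parent directories, including root.
            archive.extract(member, root)
            extracted.append(member)
    return extracted
```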
update_statistics(statistics: abeja.training.statistics.Statistics) → None
Report job statistics to the ABEJA Platform.
API reference: POST /organizations/<organization_id>/training/definitions/<job_definition_name>/jobs/<training_job_id>/statistics
from abeja.train import Client
from abeja.train.statistics import Statistics as ABEJAStatistics
client = Client()
statistics = ABEJAStatistics(num_epochs=10, epoch=1)
statistics.add_stage(name=ABEJAStatistics.STAGE_TRAIN, accuracy=90.0, loss=0.10)
statistics.add_stage(name=ABEJAStatistics.STAGE_VALIDATION, accuracy=75.0, loss=0.07)
client.update_statistics(statistics)
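In a training loop, the example above is typically repeated once per epoch. The sketch below wraps it in a hypothetical helper, `report_epoch` (not part of the SDK), that relies only on the calls shown on this page. The client and the statistics class are passed in as arguments so the helper can be exercised with fakes.

```python
from typing import Any, Tuple

# (accuracy, loss) for one stage of one epoch.
StageMetrics = Tuple[float, float]


def report_epoch(client: Any, statistics_cls: Any, num_epochs: int, epoch: int,
                 train: StageMetrics, validation: StageMetrics) -> Any:
    """Build one statistics object for an epoch and send it via the client.

    Hypothetical helper; statistics_cls is expected to behave like
    abeja.train.statistics.Statistics and client like abeja.train.Client.
    """
    statistics = statistics_cls(num_epochs=num_epochs, epoch=epoch)
    train_acc, train_loss = train
    val_acc, val_loss = validation
    statistics.add_stage(name=statistics_cls.STAGE_TRAIN, accuracy=train_acc, loss=train_loss)
    statistics.add_stage(name=statistics_cls.STAGE_VALIDATION, accuracy=val_acc, loss=val_loss)
    client.update_statistics(statistics)
    return statistics
```

With the SDK, a loop might call `report_epoch(client, ABEJAStatistics, num_epochs, epoch, (train_acc, train_loss), (val_acc, val_loss))` after each epoch, reusing `client` and `ABEJAStatistics` from the example above.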
Parameters: statistics (abeja.train.statistics.Statistics): the job statistics to report.
abeja.train.statistics.Statistics(num_epochs: int = None, epoch: int = None, progress_percentage: float = None, **kwargs) → None
STAGE_TRAIN = 'train'
STAGE_VALIDATION = 'validation'
add_stage(name: str, accuracy: float = None, loss: float = None, **kwargs) → None
Add information for one stage, such as STAGE_TRAIN or STAGE_VALIDATION.
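For running training code offline, for example in unit tests, a stand-in that exposes the same documented surface can replace Statistics. This is a hypothetical sketch. Only the constructor parameters, the STAGE_* constants, and the add_stage signature come from this page. Storing stages in a dict keyed by name is an assumption and need not match the SDK's internal layout.

```python
from typing import Any, Dict, Optional


class FakeStatistics:
    """Hypothetical offline stand-in for Statistics; not part of the SDK."""

    STAGE_TRAIN = "train"
    STAGE_VALIDATION = "validation"

    def __init__(self, num_epochs: Optional[int] = None, epoch: Optional[int] = None,
                 progress_percentage: Optional[float] = None, **kwargs: Any) -> None:
        self.num_epochs = num_epochs
        self.epoch = epoch
        self.progress_percentage = progress_percentage
        self.extra = dict(kwargs)
        # Stage records keyed by stage name (assumed layout).
        self.stages: Dict[str, Dict[str, Any]] = {}

    def add_stage(self, name: str, accuracy: Optional[float] = None,
                  loss: Optional[float] = None, **kwargs: Any) -> None:
        # Adding the same stage name twice overwrites the earlier record.
        self.stages[name] = {"accuracy": accuracy, "loss": loss, **kwargs}
```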
from_response(response: typing.Union[typing.Dict[str, typing.Any], NoneType]) → typing.Union[Statistics, NoneType]
get_statistics() → dict
Return the statistics, including stage information, as a dict.
Response Syntax:
{
'num_epochs': 10,
'epoch': 1,
'progress_percentage': 90,
}
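The response syntax above can be read into a typed object. This is a hedged sketch. `StatisticsSummary` is hypothetical and reads only the three top-level fields shown. It also assumes that a missing response yields None, which the Optional types in the from_response signature suggest but this page does not state.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class StatisticsSummary:
    """Hypothetical typed view of the top-level response fields; not part of the SDK."""

    num_epochs: Optional[int] = None
    epoch: Optional[int] = None
    progress_percentage: Optional[float] = None

    @classmethod
    def from_response(cls, response: Optional[Dict[str, Any]]) -> Optional["StatisticsSummary"]:
        # Assumed contract: no response in, no object out.
        if response is None:
            return None
        # Missing keys become None; any other keys are ignored.
        return cls(
            num_epochs=response.get("num_epochs"),
            epoch=response.get("epoch"),
            progress_percentage=response.get("progress_percentage"),
        )
```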