Client

Client represents a higher-level interface to the Training API.

class abeja.train.Client(organization_id: str | None = None, job_definition_name: str | None = None, training_job_id: str | None = None, credential: Dict[str, str] | None = None, timeout: int | None = None, max_retry_count: int | None = None)

A high-level client for the Training API.

from abeja.train import Client

client = Client(organization_id='1234567890123')
Params:
  • organization_id (str): The organization ID. Taken from os.environ['ABEJA_ORGANIZATION_ID'] if omitted.

  • job_definition_name (str): The job definition name. Taken from os.environ['TRAINING_JOB_DEFINITION_NAME'] if omitted.

  • training_job_id (str): The training job ID. Taken from os.environ['TRAINING_JOB_ID'] if omitted.
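The fallback rule above can be sketched as follows. `resolve_client_ids` and `ENV_FALLBACKS` are hypothetical names, not part of the SDK. The sketch only mirrors the documented behavior: an explicit argument is used as-is, and an omitted one is read from its environment variable (returning None when that variable is unset is an assumption).

```python
import os
from typing import Dict, Mapping, Optional

# Hypothetical mapping from constructor argument to the environment
# variable documented as its fallback.
ENV_FALLBACKS = {
    'organization_id': 'ABEJA_ORGANIZATION_ID',
    'job_definition_name': 'TRAINING_JOB_DEFINITION_NAME',
    'training_job_id': 'TRAINING_JOB_ID',
}


def resolve_client_ids(organization_id: Optional[str] = None,
                       job_definition_name: Optional[str] = None,
                       training_job_id: Optional[str] = None,
                       env: Optional[Mapping[str, str]] = None) -> Dict[str, Optional[str]]:
    """Hypothetical helper: use each explicit argument, else read its env var (None if unset)."""
    env = os.environ if env is None else env
    given = {
        'organization_id': organization_id,
        'job_definition_name': job_definition_name,
        'training_job_id': training_job_id,
    }
    return {key: value if value is not None else env.get(ENV_FALLBACKS[key])
            for key, value in given.items()}
```

Passing `env` explicitly makes the rule easy to check without touching the real process environment.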

download_training_result(training_job_id: str | None = None, path: str | None = None) -> None

Download the training artifact, which includes the model file or logs. The artifact itself is a zip file, which this function extracts.

API reference: GET /organizations/<organization_id>/training/definitions/<job_definition_name>/jobs/<training_job_id>/result
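The endpoint above is a path template. As an illustration, the hypothetical helper below (not part of the SDK) fills it in from the three IDs. Percent-encoding each segment is a defensive assumption, and the API host is deliberately left out because this section does not name it.

```python
from urllib.parse import quote


def training_result_endpoint(organization_id: str,
                             job_definition_name: str,
                             training_job_id: str) -> str:
    """Hypothetical helper: fill the documented path template (host not included)."""
    # Percent-encode each segment so a stray '/' cannot change the path shape.
    segments = [quote(s, safe='') for s in (organization_id, job_definition_name, training_job_id)]
    return '/organizations/{}/training/definitions/{}/jobs/{}/result'.format(*segments)


print(training_result_endpoint('1234567890123', 'my-job-definition', '1600000000000'))
# /organizations/1234567890123/training/definitions/my-job-definition/jobs/1600000000000/result
```

Here `my-job-definition` is a made-up definition name used only for the example.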

Request Syntax:
training_job_id = '1600000000000'
client.download_training_result(training_job_id)
Params:
  • training_job_id (str): ID of the training job whose result is downloaded [optional]

  • path (str): Target directory into which the result is extracted [optional]

Raises:
  • abeja.exceptions.BadRequest

  • abeja.exceptions.Unauthorized

  • abeja.exceptions.Forbidden

  • abeja.exceptions.NotFound

  • abeja.exceptions.MethodNotAllowed

  • abeja.exceptions.InternalServerError

  • json.JSONDecodeError

  • PermissionError

  • ValueError

  • IOError
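Because the artifact is a zip archive, the extraction step amounts to unpacking it into `path`. The sketch below shows only that step, using the standard-library `zipfile` module. `extract_artifact` is a hypothetical name, the archive is assumed to already be in memory as bytes, and the authenticated download itself is omitted.

```python
import io
import os
import zipfile
from typing import List


def extract_artifact(archive_bytes: bytes, path: str) -> List[str]:
    """Hypothetical helper: unpack an in-memory zip archive into `path`.

    Creates `path` if needed and returns the archive's member names.
    """
    os.makedirs(path, exist_ok=True)
    with zipfile.ZipFile(io.BytesIO(archive_bytes)) as archive:
        archive.extractall(path)
        return archive.namelist()
```

Nested members such as `logs/train.log` are written into matching subdirectories under `path`.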

update_statistics(statistics: Statistics) -> None

Report job statistics to ABEJA Platform.

API reference: POST /organizations/<organization_id>/training/definitions/<job_definition_name>/jobs/<training_job_id>/statistics

Request Syntax:
from abeja.train import Client
from abeja.train.statistics import Statistics as ABEJAStatistics

client = Client()

statistics = ABEJAStatistics(num_epochs=10, epoch=1)
statistics.add_stage(name=ABEJAStatistics.STAGE_TRAIN, accuracy=0.90, loss=0.10)
statistics.add_stage(name=ABEJAStatistics.STAGE_VALIDATION, accuracy=0.75, loss=0.07)

client.update_statistics(statistics)
Params:
  • statistics (Statistics): The statistics of the training job to report.

Returns:

None
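A common pattern (an assumption; the API does not require it) is to call `update_statistics` once per epoch. The sketch below shows that loop with the reporting step injected as a callback so it runs without the SDK. `run_and_report`, `run_epoch`, `notify`, and `EpochMetrics` are hypothetical names. The commented-out `notify` uses only calls documented in this section and assumes a `client` created as in the example above.

```python
from typing import Callable, Dict, Tuple

# Hypothetical shape: stage name -> (accuracy, loss) for one epoch.
EpochMetrics = Dict[str, Tuple[float, float]]


def run_and_report(num_epochs: int,
                   run_epoch: Callable[[int], EpochMetrics],
                   notify: Callable[[int, int, EpochMetrics], None]) -> None:
    """Run epochs 1..num_epochs and report each epoch's metrics as soon as it finishes."""
    for epoch in range(1, num_epochs + 1):
        metrics = run_epoch(epoch)
        notify(num_epochs, epoch, metrics)


# With the SDK, `notify` could build a Statistics object (not executed here):
#
#   def notify(num_epochs, epoch, metrics):
#       stats = ABEJAStatistics(num_epochs=num_epochs, epoch=epoch)
#       for stage, (accuracy, loss) in metrics.items():
#           stats.add_stage(name=stage, accuracy=accuracy, loss=loss)
#       client.update_statistics(stats)
```

Keeping the reporter as a callback lets the training loop be tested offline with a stub that just records the calls.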

Statistics

class abeja.train.statistics.Statistics(num_epochs: int | None = None, epoch: int | None = None, progress_percentage: float | None = None, **kwargs)
STAGE_TRAIN = 'train'
STAGE_VALIDATION = 'validation'
add_stage(name: str, accuracy: float | None = None, loss: float | None = None, **kwargs) -> None

Add information for a stage, such as training or validation.

Params:
  • name (str): Name of the stage. STAGE_TRAIN and STAGE_VALIDATION are provided as constants, but any string can be used.

  • accuracy (float): Accuracy of the stage; must be between 0 and 1.

  • loss (float): Loss of the stage; must be between 0 and 1.

Returns:

None

Raises:
  • ValueError
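The range rule for `accuracy` and `loss` can be sketched as a small check. `check_rate` is hypothetical; it assumes inclusive bounds and that `None` (the default) skips the check. The SDK's own validation and error message may differ.

```python
from typing import Optional


def check_rate(name: str, value: Optional[float]) -> None:
    """Hypothetical check: raise ValueError unless value is None or within [0, 1].

    Inclusive bounds are an assumption. NaN also fails, since every comparison with it is False.
    """
    if value is None:
        return
    if not 0.0 <= value <= 1.0:
        raise ValueError('{} must be between 0 and 1, got {!r}'.format(name, value))
```

Under this rule a percentage such as `90.0` is rejected, so values should be passed as fractions (`0.90`).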

classmethod from_response(response: Dict[str, Any] | None) -> Statistics | None
get_statistics() -> dict

Get the statistics as a dictionary.

Return type:

dict

Returns:

Response Syntax:

{
    'num_epochs': 10,
    'epoch': 1,
    'progress_percentage': 90,
}