BaseSampler
`BaseSampler` is a generic base class for sampling environments and agents in Sotopia's parallel simulation framework. It defines a `sample` method that draws an environment and a list of agents according to the given parameters; the base class itself leaves this method unimplemented.
Class Definition
```python
class BaseSampler(Generic[ObsType, ActType]):
    def __init__(
        self,
        env_candidates: Sequence[EnvironmentProfile | str] | None = None,
        agent_candidates: Sequence[AgentProfile | str] | None = None,
    ) -> None:
        self.env_candidates = env_candidates
        self.agent_candidates = agent_candidates
```
Parameters
`env_candidates` (`Sequence[EnvironmentProfile | str] | None`, optional): Candidate environments, given as `EnvironmentProfile` objects or string profile IDs. Defaults to `None`.
`agent_candidates` (`Sequence[AgentProfile | str] | None`, optional): Candidate agents, given as `AgentProfile` objects or string profile IDs. Defaults to `None`.
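Each candidate may be either a profile object or a string, so an implementation has to normalize candidates before using them. The following is a minimal sketch that assumes strings are IDs looked up in a registry. `Profile`, `PROFILE_REGISTRY` and `resolve_candidates` are hypothetical stand-ins, not Sotopia APIs, and treating `None` as "use every known profile" is an assumption of this sketch.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence


@dataclass
class Profile:
    """Hypothetical stand-in for EnvironmentProfile / AgentProfile."""
    pk: str
    name: str


# Hypothetical lookup table from string IDs to profile objects
PROFILE_REGISTRY: dict[str, Profile] = {
    "env-1": Profile(pk="env-1", name="negotiation"),
    "env-2": Profile(pk="env-2", name="small talk"),
}


def resolve_candidates(candidates: Sequence[Profile | str] | None) -> list[Profile]:
    """Turn a mixed sequence of profiles and string IDs into profile objects."""
    if candidates is None:
        # Assumption: no explicit candidates means "use every known profile"
        return list(PROFILE_REGISTRY.values())
    # Look up strings by ID; pass profile objects through unchanged
    return [PROFILE_REGISTRY[c] if isinstance(c, str) else c for c in candidates]


mixed = resolve_candidates(["env-1", Profile(pk="custom", name="debate")])
print([p.name for p in mixed])  # ['negotiation', 'debate']
```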
Methods
sample
```python
def sample(
    self,
    agent_classes: Type[BaseAgent[ObsType, ActType]]
    | list[Type[BaseAgent[ObsType, ActType]]],
    n_agent: int = 2,
    replacement: bool = True,
    size: int = 1,
    env_params: dict[str, Any] = {},
    agents_params: list[dict[str, Any]] = [{}, {}],
) -> Generator[EnvAgentCombo[ObsType, ActType], None, None]:
    # Abstract in BaseSampler: subclasses must override this method
    raise NotImplementedError
```
Description
Sample an environment and a list of agents.
Parameters
`agent_classes` (`Type[BaseAgent[ObsType, ActType]] | list[Type[BaseAgent[ObsType, ActType]]]`): A single agent class used for every sampled agent, or a list with one agent class per agent.
`n_agent` (`int`, optional): Number of agents in each sample. Defaults to `2`.
`replacement` (`bool`, optional): Whether to sample with replacement. Defaults to `True`.
`size` (`int`, optional): Number of samples (environment-agent combinations) to generate. Defaults to `1`.
`env_params` (`dict[str, Any]`, optional): Parameters for the environment. Defaults to `{}`.
`agents_params` (`list[dict[str, Any]]`, optional): Parameters for the agents, one dict per agent. Defaults to `[{}, {}]`, which matches the default `n_agent=2`.
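The descriptions of `agent_classes` and `agents_params` imply one agent class and one parameter dict per agent slot. This is why the default `agents_params=[{}, {}]` lines up with the default `n_agent=2`. The sketch below shows argument checks a subclass might run before sampling. `normalize_agent_args` and `CustomAgent` are hypothetical, and rejecting a length mismatch with `ValueError` is an assumption rather than documented behavior.

```python
from __future__ import annotations

from typing import Any


class CustomAgent:
    """Hypothetical stand-in for a BaseAgent subclass."""


def normalize_agent_args(
    agent_classes: type | list[type],
    n_agent: int,
    agents_params: list[dict[str, Any]],
) -> list[tuple[type, dict[str, Any]]]:
    """Pair each agent slot with its class and its parameter dict."""
    # A single class is reused for every agent
    if not isinstance(agent_classes, list):
        agent_classes = [agent_classes] * n_agent
    # A list must supply exactly one class per agent
    if len(agent_classes) != n_agent:
        raise ValueError(f"expected {n_agent} agent classes, got {len(agent_classes)}")
    # agents_params is positional: one dict per agent
    if len(agents_params) != n_agent:
        raise ValueError(f"expected {n_agent} agents_params entries, got {len(agents_params)}")
    return list(zip(agent_classes, agents_params))


print(len(normalize_agent_args(CustomAgent, 2, [{}, {}])))  # 2
try:
    # The default agents_params=[{}, {}] no longer fits once n_agent=3
    normalize_agent_args(CustomAgent, 3, [{}, {}])
except ValueError as err:
    print(err)  # expected 3 agents_params entries, got 2
```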
Returns
`Generator[EnvAgentCombo[ObsType, ActType], None, None]`: A generator yielding tuples, each containing an environment and a list of agents.
Usage Example
```python
from sotopia.agents.base_agent import BaseAgent
from sotopia.samplers.base_sampler import BaseSampler

# Define a custom agent class inheriting from BaseAgent
class CustomAgent(BaseAgent):
    pass

# Initialize the sampler. BaseSampler.sample() is abstract, so substitute a
# subclass that implements it; on BaseSampler itself the call below raises
# NotImplementedError.
sampler = BaseSampler()

# Sample 5 environment-agent combinations with 3 agents each;
# agent_classes and agents_params provide one entry per agent
samples = sampler.sample(
    agent_classes=[CustomAgent] * 3,
    n_agent=3,
    size=5,
    agents_params=[{}, {}, {}],
)

# Iterate over the generated (environment, agents) pairs
for env, agents in samples:
    print(f"Environment: {env}")
    print(f"Agents: {agents}")
```
Note: `BaseSampler.sample` raises `NotImplementedError`. Subclasses must override it to provide an actual sampling strategy.
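A concrete sampler must subclass `BaseSampler` and override `sample`. The sketch below is self-contained. It uses a local stand-in for `BaseSampler` that mirrors the documented signature, so it runs without Sotopia installed. `ToyEnv`, `ToyAgent` and `ToyUniformSampler` are hypothetical. Reading `replacement` as "may the same agent profile fill several slots within one combination" is an assumption of this sketch, not necessarily how Sotopia's own samplers behave.

```python
from __future__ import annotations

import random
from typing import Any, Generator, Generic, TypeVar

ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")


class BaseSampler(Generic[ObsType, ActType]):
    """Local stand-in mirroring the documented BaseSampler (no Sotopia needed)."""

    def __init__(self, env_candidates=None, agent_candidates=None) -> None:
        self.env_candidates = env_candidates
        self.agent_candidates = agent_candidates

    def sample(self, agent_classes, n_agent=2, replacement=True, size=1,
               env_params={}, agents_params=[{}, {}]):
        # Abstract in the base class, as documented
        raise NotImplementedError


class ToyEnv:
    """Hypothetical environment that records its profile and parameters."""

    def __init__(self, profile: str, **params: Any) -> None:
        self.profile = profile
        self.params = params


class ToyAgent:
    """Hypothetical agent that records its profile and parameters."""

    def __init__(self, profile: str, **params: Any) -> None:
        self.profile = profile
        self.params = params


class ToyUniformSampler(BaseSampler[str, str]):
    """Hypothetical subclass that draws candidates uniformly at random."""

    def sample(
        self,
        agent_classes: type | list[type],
        n_agent: int = 2,
        replacement: bool = True,
        size: int = 1,
        env_params: dict[str, Any] = {},
        agents_params: list[dict[str, Any]] = [{}, {}],
    ) -> Generator[tuple[ToyEnv, list[Any]], None, None]:
        # A single class is reused for every agent slot
        if not isinstance(agent_classes, list):
            agent_classes = [agent_classes] * n_agent
        if len(agent_classes) != n_agent or len(agents_params) != n_agent:
            raise ValueError("need one agent class and one params dict per agent")
        if not self.env_candidates or not self.agent_candidates:
            raise ValueError("this sketch requires explicit candidates")
        for _ in range(size):
            env = ToyEnv(random.choice(self.env_candidates), **env_params)
            # Assumed meaning of `replacement`: may one agent profile fill
            # several slots of the same combination?
            if replacement:
                profiles = random.choices(self.agent_candidates, k=n_agent)
            else:
                profiles = random.sample(self.agent_candidates, n_agent)
            agents = [
                cls(profile, **params)
                for cls, profile, params in zip(agent_classes, profiles, agents_params)
            ]
            yield env, agents


# Usage: three 2-agent combinations, no profile repeated within a combination
sampler = ToyUniformSampler(
    env_candidates=["env-a", "env-b"],
    agent_candidates=["alice", "bob", "carol"],
)
for env, agents in sampler.sample(agent_classes=ToyAgent, size=3, replacement=False):
    print(env.profile, [agent.profile for agent in agents])
```

Because `sample` is written as a generator, combinations are produced lazily, one per iteration. Argument errors therefore surface only when iteration starts, not at the call itself.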