diff --git a/tests/unit/vertexai/genai/test_evals.py b/tests/unit/vertexai/genai/test_evals.py index 355d5d9942..4135c08532 100644 --- a/tests/unit/vertexai/genai/test_evals.py +++ b/tests/unit/vertexai/genai/test_evals.py @@ -3144,6 +3144,94 @@ def test_run_inference_with_agent_engine_with_response_column_raises_error( "'intermediate_events' or 'response' columns" ) in str(excinfo.value) + @mock.patch.object(_evals_utils, "EvalDatasetLoader") + @mock.patch("vertexai._genai._evals_common.vertexai.Client") + def test_run_inference_with_agent_engine_falls_back_to_managed_sessions_api( + self, + mock_vertexai_client, + mock_eval_dataset_loader, + ): + """Tests that run_inference falls back to the managed Sessions API + when the agent engine does not have create_session registered.""" + mock_df = pd.DataFrame( + { + "prompt": ["agent prompt"], + "session_inputs": [ + { + "user_id": "123", + "state": {"a": "1"}, + } + ], + } + ) + mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict( + orient="records" + ) + + # Create a mock agent engine WITHOUT create_session (simulates agents + # deployed via Console, gcloud, or source code deployment). + mock_agent_engine = mock.Mock( + spec=["api_client", "api_resource", "stream_query"], + ) + mock_agent_engine.api_resource.name = ( + "projects/test-project/locations/us-central1/reasoningEngines/123" + ) + + # Mock the managed Sessions API to return a session. + mock_session_operation = mock.Mock() + mock_session_operation.response.name = ( + "projects/test-project/locations/us-central1" + "/reasoningEngines/123/sessions/managed-session-1" + ) + mock_agent_engine.api_client.sessions.create.return_value = ( + mock_session_operation + ) + + stream_query_return_value = [ + { + "id": "1", + "content": {"parts": [{"text": "intermediate1"}]}, + "timestamp": 123, + "author": "model", + }, + { + "id": "2", + "content": {"parts": [{"text": "agent response"}]}, + "timestamp": 124, + "author": "model", + }, + ] + mock_agent_engine.stream_query.return_value = iter(stream_query_return_value) + mock_vertexai_client.return_value.agent_engines.get.return_value = ( + mock_agent_engine + ) + + inference_result = self.client.evals.run_inference( + agent="projects/test-project/locations/us-central1/reasoningEngines/123", + src=mock_df, + ) + + # Verify the managed Sessions API was called as fallback. + mock_agent_engine.api_client.sessions.create.assert_called_once_with( + name="projects/test-project/locations/us-central1/reasoningEngines/123", + user_id="123", + config=vertexai_genai_types.CreateAgentEngineSessionConfig( + session_state={"a": "1"}, + ), + ) + + # Verify stream_query was called with the session ID extracted from + # the managed session's resource name. + mock_agent_engine.stream_query.assert_called_once_with( + user_id="123", + session_id="managed-session-1", + message="agent prompt", + ) + + # Verify the inference results are correct. + assert inference_result.eval_dataset_df["response"].iloc[0] == "agent response" + assert inference_result.candidate_name == "agent_engine_0" + @mock.patch.object(_evals_utils, "EvalDatasetLoader") @mock.patch("vertexai._genai._evals_common.InMemorySessionService") # fmt: skip @mock.patch("vertexai._genai._evals_common.Runner") diff --git a/vertexai/_genai/_evals_common.py b/vertexai/_genai/_evals_common.py index bfe5951fea..201135b731 100644 --- a/vertexai/_genai/_evals_common.py +++ b/vertexai/_genai/_evals_common.py @@ -1964,6 +1964,74 @@ def _run_agent( os.environ["GOOGLE_CLOUD_LOCATION"] = original_location +def _create_agent_engine_session( + *, + agent_engine: types.AgentEngine, + user_id: str, + session_state: Optional[dict[str, Any]] = None, +) -> Any: + """Creates a session for an agent engine and returns the session ID. + + First attempts to use the agent engine's own `create_session` operation + (available for agents deployed via AdkApp). If the agent engine does not + have `create_session` registered, falls back to the managed Vertex AI + Sessions API. + + Args: + agent_engine: The AgentEngine instance. + user_id: The user ID for the session. + session_state: Optional initial state for the session. + + Returns: + The session ID string. + + Raises: + RuntimeError: If the session could not be created via either path. + """ + try: + session = agent_engine.create_session( # type: ignore[attr-defined] + user_id=user_id, + state=session_state, + ) + return session["id"] + except AttributeError as exc: + # Agent engine does not have create_session registered (e.g. deployed + # via Console, gcloud, or source code deployment without AdkApp). + # Fall back to the managed Vertex AI Sessions API. + logger.info( + "Agent engine does not have 'create_session' operation registered." + " Falling back to managed Sessions API." + ) + if agent_engine.api_resource is None: + raise RuntimeError( + "Failed to create session: agent_engine.api_resource is None." + ) from exc + if agent_engine.api_client is None: + raise RuntimeError( + "Failed to create session: agent_engine.api_client is None." + ) from exc + operation = agent_engine.api_client.sessions.create( + name=agent_engine.api_resource.name, + user_id=user_id, + config=types.CreateAgentEngineSessionConfig( + session_state=session_state, + ), + ) + if operation.response and operation.response.name: + # Session name format: + # projects/{p}/locations/{l}/reasoningEngines/{re}/sessions/{id} + return operation.response.name.split("/")[-1] + elif operation.error: + raise RuntimeError( + f"Failed to create session via managed API: {operation.error}" + ) from exc + else: + raise RuntimeError( + "Failed to create session via managed API: " + "operation returned no response." + ) from exc + + def _execute_agent_run_with_retry( row: pd.Series, contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict], @@ -1975,9 +2043,10 @@ def _execute_agent_run_with_retry( session_inputs = _get_session_inputs(row) user_id = session_inputs.user_id session_state = session_inputs.state - session = agent_engine.create_session( # type: ignore[attr-defined] + session_id = _create_agent_engine_session( + agent_engine=agent_engine, user_id=user_id, - state=session_state, + session_state=session_state, ) except KeyError as e: return {"error": f"Failed to get all required agent engine inputs: {e}"} @@ -1988,7 +2057,7 @@ def _execute_agent_run_with_retry( responses = [] for event in agent_engine.stream_query( # type: ignore[attr-defined] user_id=user_id, - session_id=session["id"], + session_id=session_id, message=contents, ): if event and CONTENT in event and PARTS in event[CONTENT]: